├── .flakeheaven.toml ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── linting.yaml │ └── tests.yaml ├── .gitignore ├── LICENSE.md ├── README.md ├── packages.txt ├── poetry.toml ├── pyproject.toml ├── pyrobbot ├── __init__.py ├── __main__.py ├── app │ ├── .streamlit │ │ └── config.toml │ ├── __init__.py │ ├── app.py │ ├── app_page_templates.py │ ├── app_utils.py │ ├── data │ │ ├── assistant_avatar.png │ │ ├── powered-by-openai-badge-outlined-on-dark.svg │ │ ├── success.wav │ │ ├── user_avatar.png │ │ └── warning.wav │ └── multipage.py ├── argparse_wrapper.py ├── chat.py ├── chat_configs.py ├── chat_context.py ├── command_definitions.py ├── embeddings_database.py ├── general_utils.py ├── internet_utils.py ├── openai_utils.py ├── sst_and_tts.py ├── tokens.py └── voice_chat.py └── tests ├── conftest.py ├── smoke ├── test_app.py └── test_commands.py └── unit ├── test_chat.py ├── test_internet_utils.py ├── test_text_to_speech.py └── test_voice_chat.py /.flakeheaven.toml: -------------------------------------------------------------------------------- 1 | [tool.flakeheaven] 2 | exclude = [".*/", "tmp/", "*/tmp/", "*.ipynb"] 3 | format = "colored" 4 | # Show line of source code in output, with syntax highlighting 5 | show_source = true 6 | style = "google" 7 | 8 | # list of plugins and rules for them 9 | [tool.flakeheaven.plugins] 10 | # Deactivate all rules for all plugins by default 11 | "*" = ["-*"] 12 | # Activate only those plugins not covered by ruff 13 | pydoclint = [ 14 | "+*", 15 | "-DOC105", 16 | "-DOC106", 17 | "-DOC107", 18 | "-DOC109", 19 | "-DOC110", 20 | "-DOC203", 21 | "-DOC301", 22 | "-DOC403", 23 | "-DOC404", 24 | ] 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 
15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/linting.yaml: -------------------------------------------------------------------------------- 1 | #.github/workflows/linting.yaml 2 | name: Linting Checks 3 | 4 | on: 5 | pull_request: 6 | branches: 7 | - main 8 | - develop 9 | paths: 10 | - '**.py' 11 | - '.github/workflows/linting.yaml' 12 | push: 13 | branches: 14 | - '**' # Every branch 15 | paths: 16 | - '**.py' 17 | - '.github/workflows/linting.yaml' 18 | 19 | jobs: 20 | linting: 21 | if: github.repository_owner == 'paulovcmedeiros' 22 | name: Run Linters 23 | runs-on: ubuntu-latest 24 | steps: 25 | #---------------------------------------------- 26 | # check-out repo and set-up python 27 | #---------------------------------------------- 28 | - name: Check out repository 29 | uses: actions/checkout@v3 30 | - name: Set up python 31 | id: setup-python 32 | uses: actions/setup-python@v4 33 | with: 34 | python-version: '3.9' 35 | 36 | #---------------------------------------------- 37 | # --- configure poetry & install project ---- 38 | #---------------------------------------------- 39 | - name: Install Poetry 40 | uses: snok/install-poetry@v1 41 | with: 42 | virtualenvs-create: true 43 | virtualenvs-in-project: true 44 | 45 | - name: Install poethepoet 46 | run: poetry self add 'poethepoet[poetry_plugin]' 47 | 48 | - name: Load cached venv (if cache exists) 49 | id: cached-poetry-dependencies 50 | uses: actions/cache@v3 51 | with: 52 | path: .venv 53 | key: ${{ github.job }}-venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }} 54 | 55 | - name: Install dependencies (if venv cache is not found) 56 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 57 | run: poetry install --no-interaction --no-root --only main,linting 58 | 59 | - name: Install the project itself 60 | run: poetry install --no-interaction --only-root 61 | 62 | #---------------------------------------------- 63 | # Run the linting checks 64 | #---------------------------------------------- 65 | - name: Run linters 66 | run: | 67 | poetry devtools lint 68 | 69 | -------------------------------------------------------------------------------- /.github/workflows/tests.yaml: -------------------------------------------------------------------------------- 1 | #.github/workflows/tests.yaml 2 | name: Unit Tests 3 | 4 | on: 5 | pull_request: 6 | branches: 7 | - main 8 | - develop 9 | push: 10 | branches: 11 | - '**' # Every branch 12 | 13 | jobs: 14 | tests: 15 | if: github.repository_owner == 'paulovcmedeiros' 16 | strategy: 17 | fail-fast: true 18 | matrix: 19 | os: [ "ubuntu-latest" ] 20 | env: [ "pytest" ] 21 | python-version: [ "3.9" ] 22 | 23 | name: "${{ matrix.os }}, python=${{ matrix.python-version }}" 24 | runs-on: ${{ matrix.os }} 25 | 26 | container: 27 | image: python:${{ matrix.python-version }}-bullseye 28 | env: 29 | COVERAGE_FILE: ".coverage.${{ matrix.env }}.${{ matrix.python-version }}" 30 | 31 | steps: 32 | #---------------------------------------------- 33 | # check-out repo 34 | #---------------------------------------------- 35 | - name: Check out repository 36 | uses: actions/checkout@v3 37 | 38 | 
#---------------------------------------------- 39 | # Install Audio Gear 40 | #---------------------------------------------- 41 | - name: Install PortAudio and PulseAudio 42 | run: | 43 | apt-get update 44 | apt-get --assume-yes install portaudio19-dev python-all-dev pulseaudio ffmpeg 45 | 46 | #---------------------------------------------- 47 | # --- configure poetry & install project ---- 48 | #---------------------------------------------- 49 | - name: Install Poetry 50 | uses: snok/install-poetry@v1 51 | with: 52 | virtualenvs-create: true 53 | virtualenvs-in-project: true 54 | 55 | - name: Load cached venv (if cache exists) 56 | id: cached-poetry-dependencies 57 | uses: actions/cache@v3 58 | with: 59 | path: .venv 60 | key: ${{ github.job }}-venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }} 61 | 62 | - name: Install dependencies (if venv cache is not found) 63 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 64 | run: poetry install --no-interaction --no-root --only main,test 65 | 66 | - name: Install the project itself 67 | run: poetry install --no-interaction --only-root 68 | 69 | #---------------------------------------------- 70 | # run test suite and report coverage 71 | #---------------------------------------------- 72 | - name: Run tests 73 | env: 74 | SDL_VIDEODRIVER: "dummy" 75 | SDL_AUDIODRIVER: "disk" 76 | run: | 77 | poetry run pytest 78 | 79 | - name: Upload test coverage report to Codecov 80 | uses: codecov/codecov-action@v3 81 | with: 82 | token: ${{ secrets.CODECOV_TOKEN }} 83 | files: ./.coverage.xml 84 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | 162 | # Vim 163 | *.swp 164 | 165 | # Temporary files and directories 166 | tmp/ 167 | 168 | # Linting artifacts 169 | .flakeheaven_cache/ 170 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Paulo V. C. Medeiros 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |
3 | [![pyrobbot-logo](https://github.com/paulovcmedeiros/pyRobBot/blob/main/pyrobbot/app/data/assistant_avatar.png?raw=true)](https://github.com/paulovcmedeiros/pyRobBot)
4 | # [pyRobBot](https://github.com/paulovcmedeiros/pyRobBot): Chat with GPT LLMs over voice, text, or both. All with access to the internet.
5 |
6 | [![Pepy Total Downloads](https://img.shields.io/pepy/dt/pyrobbot?style=flat&label=Downloads)](https://www.pepy.tech/projects/pyrobbot)
7 | [![PyPI - Version](https://img.shields.io/pypi/v/pyrobbot)](https://pypi.org/project/pyrobbot/)
8 | [![Streamlit App](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://pyrobbot.streamlit.app)
9 | [![Powered by OpenAI](https://github.com/paulovcmedeiros/pyRobBot/blob/main/pyrobbot/app/data/powered-by-openai-badge-outlined-on-dark.svg?raw=true)](https://openai.com/blog/openai-api)
10 |
11 |
12 | [![Poetry](https://img.shields.io/endpoint?url=https://python-poetry.org/badge/v0.json)](https://python-poetry.org/)
13 | [![Contributors Welcome](https://img.shields.io/badge/Contributors-welcome-.svg)](https://github.com/paulovcmedeiros/pyRobBot/pulls)
14 | [![Linting](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/linting.yaml/badge.svg)](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/linting.yaml)
15 | [![Tests](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/tests.yaml/badge.svg)](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/tests.yaml)
16 | [![codecov](https://codecov.io/gh/paulovcmedeiros/pyRobBot/graph/badge.svg?token=XI8G1WH9O6)](https://codecov.io/gh/paulovcmedeiros/pyRobBot)
17 |
18 |
19 |
20 | PyRobBot is a Python package that uses OpenAI's [GPT large language models (LLMs)](https://platform.openai.com/docs/models) to implement a fully configurable **personal assistant** that, on top of the traditional chatbot interface, can also speak and listen to you using AI-generated **human-like** voices.
21 |
22 |
23 | ## Features
24 |
25 | Features include, but are not limited to:
26 |
27 | - [x] Voice Chat
28 |   - Continuous voice input and output
29 |   - No need to press a button: the assistant will keep listening until you stop talking
30 |
31 | - [x] Internet access: The assistant will **search the web** to find the answers it doesn't have in its training data
32 |   - E.g., latest news, current events, weather forecasts, etc.
33 |   - Powered by [DuckDuckGo Search](https://github.com/deedy5/duckduckgo_search)
34 |
35 | - [x] Web browser user interface
36 |   - See our [demo app on Streamlit Community Cloud](https://pyrobbot.streamlit.app)
37 |   - Voice chat with:
38 |     - **Continuous voice input and output** (using [streamlit-webrtc](https://github.com/whitphx/streamlit-webrtc))
39 |     - If you prefer, manual on/off toggling of the microphone (using [streamlit_mic_recorder](https://github.com/B4PT0R/streamlit-mic-recorder))
40 |   - A familiar text interface integrated with the voice chat, for those who prefer a traditional chatbot experience
41 |     - Your voice prompts and the assistant's voice replies are shown as text in the chat window
42 |     - You may also send prompts as text even when voice detection is enabled
43 |   - Add/remove conversations dynamically
44 |   - Automatic/editable conversation summary title
45 |   - Autosave & retrieve chat history
46 |     - Resume even the text & voice conversations started outside the web interface
47 |
48 |
49 | - [x] Chat via terminal
50 |   - For a more "Wake up, Neo" experience
51 |
52 | - [x] Fully configurable
53 |   - Large number of supported languages (*e.g.*, `rob --lang pt-br`)
54 |   - Support for multiple LLMs through the OpenAI API
55 |   - Choose your preferred Text-to-Speech (TTS) and Speech-to-Text (STT) engines (google/openai)
56 |   - Control over the parameters passed to the OpenAI API, with (hopefully) sensible defaults
57 |   - Ability to pass base directives to the LLM
58 |     - E.g., to make it adopt a persona, but you decide which directives to pass
59 |   - Dynamically modifiable AI parameters in each chat separately
60 |     - No need to restart the chat
61 |
62 | - [x] Chat context handling using [embeddings](https://platform.openai.com/docs/guides/embeddings)
63 | - [x] Estimated API token usage and associated costs
64 | - [x] OpenAI API key is **never** stored on disk
65 |
66 |
67 |
68 | ## System Requirements
69 | - Python >= 3.9
70 | - A valid [OpenAI API key](https://platform.openai.com/account/api-keys)
71 |   - Set it in the Web UI or through the environment variable `OPENAI_API_KEY`
72 | - To enable voice chat, you also need:
73 |   - [PortAudio](https://www.portaudio.com/docs/v19-doxydocs/index.html)
74 |     - Install on Ubuntu with `sudo apt-get --assume-yes install portaudio19-dev python-all-dev`
75 |     - Install on CentOS/RHEL with `sudo yum install portaudio portaudio-devel`
76 |   - [ffmpeg](https://ffmpeg.org/download.html)
77 |     - Install on Ubuntu with `sudo apt-get --assume-yes install ffmpeg`
78 |     - Install on CentOS/RHEL with `sudo yum install ffmpeg`
79 |
80 | ## Installation
81 | This, naturally, assumes your system fulfills all [requirements](#system-requirements).
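If you want, you can sanity-check the voice-chat requirements from Python before going further. The snippet below is an illustrative sketch only, not part of the package; it assumes that `ffmpeg` must be discoverable on your `PATH` and that importing `sounddevice` raises an `OSError` when PortAudio is missing.

```python
# Pre-flight check for pyRobBot's voice-chat requirements (illustrative sketch only).
import os
import shutil

if not os.environ.get("OPENAI_API_KEY"):
    print("OPENAI_API_KEY is not set; you can also enter the key in the web UI instead.")
if shutil.which("ffmpeg") is None:
    print("ffmpeg was not found on your PATH.")
try:
    import sounddevice  # noqa: F401  # only imports cleanly when PortAudio is present
except OSError:
    print("PortAudio seems to be missing or broken.")
```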
82 |
83 | ### Regular Installation
84 | The recommended way for most users.
85 |
86 | #### Using pip
87 | ```shell
88 | pip install pyrobbot
89 | ```
90 | #### From the GitHub repository
91 | ```shell
92 | pip install git+https://github.com/paulovcmedeiros/pyRobBot.git
93 | ```
94 |
95 | ### Developer-Mode Installation
96 | The recommended way for those who want to contribute to the project. We use [poetry](https://python-poetry.org) with the [poethepoet](https://poethepoet.natn.io/index.html) plugin. To get everything set up, run:
97 | ```shell
98 | # Clean up any previous install
99 | curl -sSL https://install.python-poetry.org | python3 - --uninstall
100 | rm -rf ${HOME}/.cache/pypoetry/ ${HOME}/.local/bin/poetry ${HOME}/.local/share/pypoetry
101 | # Download and install poetry
102 | curl -sSL https://install.python-poetry.org | python3 -
103 | # Install needed poetry plugin(s)
104 | poetry self add 'poethepoet[poetry_plugin]'
105 | ```
106 |
107 |
108 | ## Basic Usage
109 | Upon successful installation, you should be able to run
110 | ```shell
111 | rob [opts] SUBCOMMAND [subcommand_opts]
112 | ```
113 | where `[opts]` and `[subcommand_opts]` denote optional command line arguments
114 | that apply, respectively, to `rob` in general and to `SUBCOMMAND`
115 | specifically.
116 |
117 | **Please run `rob -h` for information** about the supported subcommands
118 | and general `rob` options. For info about specific subcommands and the
119 | options that apply to them only, **please run `rob SUBCOMMAND -h`** (note
120 | that the `-h` goes after the subcommand in this case).
121 |
122 | ### Using the Web UI (default, supports voice & text chat)
123 | ```shell
124 | rob
125 | ```
126 | See also our [demo Streamlit app](https://pyrobbot.streamlit.app)!
127 |
128 | ### Chatting Only by Voice
129 | ```shell
130 | rob voice
131 | ```
132 |
133 | ### Running on the Terminal
134 | ```shell
135 | rob .
136 | ```
137 |
138 | ## Disclaimers
139 | This project's main purpose has been to serve as a learning exercise for me, as well as a tool for experimenting with the OpenAI API, GPT LLMs, and text-to-speech/speech-to-text.
140 |
141 | While it does not claim to be the best or most robust OpenAI-powered chatbot out there, it *does* aim to provide a friendly user interface that is easy to install, use, and configure.
142 |
143 | Feel free to open an [issue](https://github.com/paulovcmedeiros/pyRobBot/issues) or, even better, [submit a pull request](https://github.com/paulovcmedeiros/pyRobBot/pulls) if you find a bug or have a suggestion.
144 |
145 | Last but not least: This project is **independently developed** and **not** affiliated, endorsed, or sponsored by OpenAI in any way. It is separate and distinct from OpenAI’s own products and services.
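For completeness: everything `rob` does can also be triggered from Python, since the console script is just the package's `pyrobbot.__main__.main` entry point (declared under `[tool.poetry.scripts]` in `pyproject.toml`). A minimal sketch:

```python
# Equivalent to running `rob accounting` on the command line.
from pyrobbot.__main__ import main

main(["accounting"])  # parses the args and dispatches to the matching subcommand
```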
146 | -------------------------------------------------------------------------------- /packages.txt: -------------------------------------------------------------------------------- 1 | ffmpeg 2 | portaudio19-dev 3 | python-all-dev 4 | -------------------------------------------------------------------------------- /poetry.toml: -------------------------------------------------------------------------------- 1 | [virtualenvs] 2 | create = true 3 | in-project = true 4 | prefer-active-python = true 5 | 6 | [virtualenvs.options] 7 | system-site-packages = false 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | authors = ["Paulo V C Medeiros "] 3 | description = """\ 4 | Chat with GPT LLMs over voice, text or both. With access to the internet.\ 5 | Powered by OpenAI.\ 6 | """ 7 | license = "MIT" 8 | name = "pyrobbot" 9 | readme = "README.md" 10 | version = "0.7.7" 11 | 12 | [build-system] 13 | build-backend = "poetry.core.masonry.api" 14 | requires = ["poetry-core"] 15 | 16 | [tool.poetry.scripts] 17 | rob = "pyrobbot.__main__:main" 18 | 19 | [tool.poetry.dependencies] 20 | # Python version 21 | python = ">=3.9,<3.9.7 || >3.9.7,<3.13" 22 | # Deps that should have been openapi deps 23 | matplotlib = "^3.8.0" 24 | plotly = "^5.18.0" 25 | scikit-learn = "^1.3.2" 26 | scipy = "^1.11.3" 27 | # Other dependencies 28 | loguru = "^0.7.2" 29 | numpy = "^1.26.1" 30 | openai = "^1.13.3" 31 | pandas = "^2.2.0" 32 | pillow = "^10.2.0" 33 | pydantic = "^2.6.1" 34 | streamlit = "^1.31.1" 35 | tiktoken = "^0.6.0" 36 | # Text to speech 37 | audio-recorder-streamlit = "^0.0.8" 38 | beautifulsoup4 = "^4.12.3" 39 | chime = "^0.7.0" 40 | duckduckgo-search = "^5.0" 41 | gtts = "^2.5.1" 42 | httpx = "^0.26.0" 43 | ipinfo = "^5.0.1" 44 | pydub = "^0.25.1" 45 | pygame = "^2.5.2" 46 | setuptools = "^68.2.2" # Needed by webrtcvad-wheels 47 | sounddevice = "^0.4.6" 48 | soundfile = "^0.12.1" 49 | speechrecognition = "^3.10.0" 50 | streamlit-mic-recorder = "^0.0.4" 51 | streamlit-webrtc = "^0.47.1" 52 | twilio = "^9.0.0" 53 | tzlocal = "^5.2" 54 | unidecode = "^1.3.7" 55 | webrtcvad-wheels = "^2.0.11.post1" 56 | 57 | [tool.poetry.group.dev.dependencies] 58 | ipython = "^8.16.1" 59 | 60 | [tool.poetry.group.linting.dependencies] 61 | black = "^24.2.0" 62 | flakeheaven = "^3.3.0" 63 | isort = "^5.13.2" 64 | pydoclint = "^0.4.0" 65 | ruff = "^0.3.0" 66 | 67 | [tool.poetry.group.test.dependencies] 68 | pytest = "^8.0.0" 69 | pytest-cov = "^4.1.0" 70 | pytest-mock = "^3.12.0" 71 | pytest-order = "^1.2.0" 72 | pytest-xdist = "^3.5.0" 73 | python-lorem = "^1.3.0.post1" 74 | 75 | ################## 76 | # Linter configs # 77 | ################## 78 | 79 | [tool.black] 80 | line-length = 90 81 | 82 | [tool.flakeheaven] 83 | base = ".flakeheaven.toml" 84 | 85 | [tool.isort] 86 | line_length = 90 87 | profile = "black" 88 | 89 | [tool.ruff] 90 | line-length = 90 91 | 92 | [tool.ruff.lint] 93 | # C901: Function is too complex. Ignoring this for now but will be removed later. 
94 | ignore = ["C901", "D105", "EXE001", "RET504", "RUF012"] 95 | select = [ 96 | "A", 97 | "ARG", 98 | "B", 99 | "BLE", 100 | "C4", 101 | "C90", 102 | "D", 103 | "E", 104 | "ERA", 105 | "EXE", 106 | "F", 107 | "G", 108 | "I", 109 | "N", 110 | "PD", 111 | "PERF", 112 | "PIE", 113 | "PL", 114 | "PT", 115 | "Q", 116 | "RET", 117 | "RSE", 118 | "RUF", 119 | "S", 120 | "SIM", 121 | "SLF", 122 | "T20", 123 | "W", 124 | ] 125 | 126 | [tool.ruff.lint.per-file-ignores] 127 | # S101: Use of `assert` detected 128 | "tests/**/*.py" = [ 129 | "D100", 130 | "D101", 131 | "D102", 132 | "D103", 133 | "D104", 134 | "D105", 135 | "D106", 136 | "D107", 137 | "E501", 138 | "S101", 139 | "SLF001", 140 | ] 141 | 142 | [tool.ruff.lint.pydocstyle] 143 | convention = "google" 144 | 145 | ################## 146 | # pytest configs # 147 | ################## 148 | 149 | [tool.pytest.ini_options] 150 | addopts = """\ 151 | -n auto -v --cache-clear \ 152 | --failed-first \ 153 | --cov-reset --cov-report=term-missing:skip-covered \ 154 | --cov-report=xml:.coverage.xml --cov=./ \ 155 | """ 156 | log_cli_level = "INFO" 157 | testpaths = ["tests/smoke", "tests/unit"] 158 | 159 | [tool.coverage.report] 160 | omit = ["**/app/*"] 161 | 162 | #################################### 163 | # Leave configs for `poe` separate # 164 | #################################### 165 | 166 | [tool.poe] 167 | poetry_command = "devtools" 168 | 169 | [tool.poe.tasks] 170 | _black = "black ." 171 | _isort = "isort ." 172 | _ruff = "ruff check ." 173 | # Test-related tasks 174 | pytest = "pytest" 175 | test = ["pytest"] 176 | # Tasks to be run as pre-push checks 177 | pre-push-checks = ["lint", "pytest"] 178 | 179 | [tool.poe.tasks._flake8] 180 | cmd = "flakeheaven lint ." 181 | env = {FLAKEHEAVEN_CACHE_TIMEOUT = "0"} 182 | 183 | [tool.poe.tasks.lint] 184 | args = [{name = "fix", type = "boolean", default = false}] 185 | control = {expr = "fix"} 186 | 187 | [[tool.poe.tasks.lint.switch]] 188 | case = "True" 189 | sequence = ["_isort", "_black", "_ruff --fix", "_flake8"] 190 | 191 | [[tool.poe.tasks.lint.switch]] 192 | case = "False" 193 | sequence = ["_isort --check-only", "_black --check --diff", "_ruff", "_flake8"] 194 | -------------------------------------------------------------------------------- /pyrobbot/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Unnoficial OpenAI API UI and CLI tool.""" 3 | import os 4 | import sys 5 | import tempfile 6 | import uuid 7 | from collections import defaultdict 8 | from dataclasses import dataclass 9 | from importlib.metadata import metadata, version 10 | from pathlib import Path 11 | 12 | import ipinfo 13 | import requests 14 | from loguru import logger 15 | 16 | logger.remove() 17 | logger.add( 18 | sys.stderr, 19 | level=os.environ.get("LOGLEVEL", os.environ.get("LOGURU_LEVEL", "INFO")), 20 | ) 21 | 22 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide" 23 | 24 | 25 | @dataclass 26 | class GeneralDefinitions: 27 | """General definitions for the package.""" 28 | 29 | # Main package info 30 | RUN_ID = uuid.uuid4().hex 31 | PACKAGE_NAME = __name__ 32 | VERSION = version(__name__) 33 | PACKAGE_DESCRIPTION = metadata(__name__)["Summary"] 34 | 35 | # Main package directories 36 | PACKAGE_DIRECTORY = Path(__file__).parent 37 | PACKAGE_CACHE_DIRECTORY = Path.home() / ".cache" / PACKAGE_NAME 38 | _PACKAGE_TMPDIR = tempfile.TemporaryDirectory() 39 | PACKAGE_TMPDIR = Path(_PACKAGE_TMPDIR.name) 40 | 41 | # Constants related to the app 
42 |     APP_NAME = "pyRobBot"
43 |     APP_DIR = PACKAGE_DIRECTORY / "app"
44 |     APP_PATH = APP_DIR / "app.py"
45 |     PARSED_ARGS_FILE = PACKAGE_TMPDIR / f"parsed_args_{RUN_ID}.pkl"
46 |
47 |     # Location info
48 |     IPINFO = defaultdict(lambda: "unknown")
49 |     try:
50 |         IPINFO = ipinfo.getHandler().getDetails().all
51 |     except (
52 |         requests.exceptions.ReadTimeout,
53 |         requests.exceptions.ConnectionError,
54 |         ipinfo.exceptions.RequestQuotaExceededError,
55 |     ) as error:
56 |         logger.warning("Cannot get current location info. {}", error)
--------------------------------------------------------------------------------
/pyrobbot/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Program's entry point."""
3 | from .argparse_wrapper import get_parsed_args
4 |
5 |
6 | def main(argv=None):
7 |     """Program's main routine."""
8 |     args = get_parsed_args(argv=argv)
9 |     args.run_command(args=args)
--------------------------------------------------------------------------------
/pyrobbot/app/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | # Streamlit configs.
2 | # See .
3 | [browser]
4 | gatherUsageStats = false
5 |
6 | [runner]
7 | fastReruns = true
8 |
9 | [server]
10 | runOnSave = true
11 |
12 | [theme]
13 | base = "light"
14 | # Colors
15 | primaryColor = "#2BB5E8"
--------------------------------------------------------------------------------
/pyrobbot/app/__init__.py:
--------------------------------------------------------------------------------
1 | """UI for the package."""
--------------------------------------------------------------------------------
/pyrobbot/app/app.py:
--------------------------------------------------------------------------------
1 | """Entrypoint for the package's UI."""
2 |
3 | from pyrobbot import GeneralDefinitions
4 | from pyrobbot.app.multipage import MultipageChatbotApp
5 |
6 |
7 | def run_app():
8 |     """Create and run an instance of the package's app."""
9 |     MultipageChatbotApp(
10 |         page_title=GeneralDefinitions.APP_NAME,
11 |         page_icon=":speech_balloon:",
12 |         layout="wide",
13 |     ).render()
14 |
15 |
16 | if __name__ == "__main__":
17 |     run_app()
--------------------------------------------------------------------------------
/pyrobbot/app/app_page_templates.py:
--------------------------------------------------------------------------------
1 | """Utilities for creating pages in a streamlit app."""
2 |
3 | import base64
4 | import contextlib
5 | import datetime
6 | import queue
7 | import time
8 | import uuid
9 | from abc import ABC, abstractmethod
10 | from pathlib import Path
11 | from typing import TYPE_CHECKING, Union
12 |
13 | import streamlit as st
14 | from audio_recorder_streamlit import audio_recorder
15 | from loguru import logger
16 | from pydub import AudioSegment
17 | from pydub.exceptions import CouldntDecodeError
18 | from streamlit_mic_recorder import mic_recorder
19 |
20 | from pyrobbot.chat_configs import VoiceChatConfigs
21 |
22 | from .app_utils import (
23 |     AsyncReplier,
24 |     WebAppChat,
25 |     filter_page_info_from_queue,
26 |     get_avatar_images,
27 |     load_chime,
28 | )
29 |
30 | if TYPE_CHECKING:
31 |     from .multipage import MultipageChatbotApp
32 |
33 | # Sentinel object for when a chat is recovered from cache
34 | _RecoveredChat = object()
35 |
36 |
37 | class AppPage(ABC):
38 |     """Abstract base class for a page within a streamlit application."""
39 |
40 |     def __init__(
41 |         self, parent:
"MultipageChatbotApp", sidebar_title: str = "", page_title: str = "" 42 | ): 43 | """Initializes a new instance of the AppPage class. 44 | 45 | Args: 46 | parent (MultipageChatbotApp): The parent app of the page. 47 | sidebar_title (str, optional): The title to be displayed in the sidebar. 48 | Defaults to an empty string. 49 | page_title (str, optional): The title to be displayed on the page. 50 | Defaults to an empty string. 51 | """ 52 | self.page_id = str(uuid.uuid4()) 53 | self.parent = parent 54 | self.page_number = self.parent.state.get("n_created_pages", 0) + 1 55 | 56 | chat_number_for_title = f"Chat #{self.page_number}" 57 | if page_title is _RecoveredChat: 58 | self.fallback_page_title = f"{chat_number_for_title.strip('#')} (Recovered)" 59 | page_title = None 60 | else: 61 | self.fallback_page_title = chat_number_for_title 62 | if page_title: 63 | self.title = page_title 64 | 65 | self._fallback_sidebar_title = page_title if page_title else chat_number_for_title 66 | if sidebar_title: 67 | self.sidebar_title = sidebar_title 68 | 69 | @property 70 | def state(self): 71 | """Return the state of the page, for persistence of data.""" 72 | if self.page_id not in self.parent.state: 73 | self.parent.state[self.page_id] = {} 74 | return self.parent.state[self.page_id] 75 | 76 | @property 77 | def sidebar_title(self): 78 | """Get the title of the page in the sidebar.""" 79 | return self.state.get("sidebar_title", self._fallback_sidebar_title) 80 | 81 | @sidebar_title.setter 82 | def sidebar_title(self, value: str): 83 | """Set the sidebar title for the page.""" 84 | self.state["sidebar_title"] = value 85 | 86 | @property 87 | def title(self): 88 | """Get the title of the page.""" 89 | return self.state.get("page_title", self.fallback_page_title) 90 | 91 | @title.setter 92 | def title(self, value: str): 93 | """Set the title of the page.""" 94 | self.state["page_title"] = value 95 | 96 | @abstractmethod 97 | def render(self): 98 | """Create the page.""" 99 | 100 | def continuous_mic_recorder(self): 101 | """Record audio from the microphone in a continuous loop.""" 102 | audio_bytes = audio_recorder( 103 | text="", icon_size="2x", energy_threshold=-1, key=f"AR_{self.page_id}" 104 | ) 105 | 106 | if audio_bytes is None: 107 | return AudioSegment.silent(duration=0) 108 | 109 | return AudioSegment(data=audio_bytes) 110 | 111 | def manual_switch_mic_recorder(self): 112 | """Record audio from the microphone.""" 113 | red_square = "\U0001F7E5" 114 | microphone = "\U0001F3A4" 115 | play_button = "\U000025B6" 116 | 117 | recording = mic_recorder( 118 | key=f"audiorecorder_widget_{self.page_id}", 119 | start_prompt=play_button + microphone, 120 | stop_prompt=red_square, 121 | just_once=True, 122 | use_container_width=True, 123 | ) 124 | 125 | if recording is None: 126 | return AudioSegment.silent(duration=0) 127 | 128 | return AudioSegment( 129 | data=recording["bytes"], 130 | sample_width=recording["sample_width"], 131 | frame_rate=recording["sample_rate"], 132 | channels=1, 133 | ) 134 | 135 | def render_custom_audio_player( 136 | self, 137 | audio: Union[AudioSegment, str, Path, None], 138 | parent_element=None, 139 | autoplay: bool = True, 140 | hidden=False, 141 | ): 142 | """Autoplay an audio segment in the streamlit app.""" 143 | # Adaped from: 146 | 147 | if audio is None: 148 | logger.debug("No audio to play. 
Not rendering audio player.") 149 | return 150 | 151 | if isinstance(audio, (str, Path)): 152 | audio = AudioSegment.from_file(audio, format="mp3") 153 | elif not isinstance(audio, AudioSegment): 154 | raise TypeError(f"Invalid type for audio: {type(audio)}") 155 | 156 | autoplay = "autoplay" if autoplay else "" 157 | hidden = "hidden" if hidden else "" 158 | 159 | data = audio.export(format="mp3").read() 160 | b64 = base64.b64encode(data).decode() 161 | md = f""" 162 | 165 | """ 166 | parent_element = parent_element or st 167 | parent_element.markdown(md, unsafe_allow_html=True) 168 | if autoplay: 169 | time.sleep(audio.duration_seconds) 170 | 171 | 172 | class ChatBotPage(AppPage): 173 | """Implement a chatbot page in a streamlit application, inheriting from AppPage.""" 174 | 175 | def __init__( 176 | self, 177 | parent: "MultipageChatbotApp", 178 | chat_obj: WebAppChat = None, 179 | sidebar_title: str = "", 180 | page_title: str = "", 181 | ): 182 | """Initialize new instance of the ChatBotPage class with an opt WebAppChat object. 183 | 184 | Args: 185 | parent (MultipageChatbotApp): The parent app of the page. 186 | chat_obj (WebAppChat): The chat object. Defaults to None. 187 | sidebar_title (str): The sidebar title for the chatbot page. 188 | Defaults to an empty string. 189 | page_title (str): The title for the chatbot page. 190 | Defaults to an empty string. 191 | """ 192 | super().__init__( 193 | parent=parent, sidebar_title=sidebar_title, page_title=page_title 194 | ) 195 | 196 | if chat_obj: 197 | logger.debug("Setting page chat to chat with ID=<{}>", chat_obj.id) 198 | self.chat_obj = chat_obj 199 | else: 200 | logger.debug("ChatBotPage created wihout specific chat. Creating default.") 201 | _ = self.chat_obj 202 | logger.debug("Default chat id=<{}>", self.chat_obj.id) 203 | 204 | self.avatars = get_avatar_images() 205 | 206 | @property 207 | def chat_configs(self) -> VoiceChatConfigs: 208 | """Return the configs used for the page's chat object.""" 209 | if "chat_configs" not in self.state: 210 | self.state["chat_configs"] = self.parent.state["chat_configs"] 211 | return self.state["chat_configs"] 212 | 213 | @chat_configs.setter 214 | def chat_configs(self, value: VoiceChatConfigs): 215 | self.state["chat_configs"] = VoiceChatConfigs.model_validate(value) 216 | if "chat_obj" in self.state: 217 | del self.state["chat_obj"] 218 | 219 | @property 220 | def chat_obj(self) -> WebAppChat: 221 | """Return the chat object responsible for the queries on this page.""" 222 | if "chat_obj" not in self.state: 223 | self.chat_obj = WebAppChat( 224 | configs=self.chat_configs, openai_client=self.parent.openai_client 225 | ) 226 | return self.state["chat_obj"] 227 | 228 | @chat_obj.setter 229 | def chat_obj(self, new_chat_obj: WebAppChat): 230 | current_chat = self.state.get("chat_obj") 231 | if current_chat: 232 | logger.debug( 233 | "Copy new_chat=<{}> into current_chat=<{}>. 
Current chat ID kept.", 234 | new_chat_obj.id, 235 | current_chat.id, 236 | ) 237 | current_chat.save_cache() 238 | new_chat_obj.id = current_chat.id 239 | new_chat_obj.openai_client = self.parent.openai_client 240 | self.state["chat_obj"] = new_chat_obj 241 | self.state["chat_configs"] = new_chat_obj.configs 242 | new_chat_obj.save_cache() 243 | 244 | @property 245 | def chat_history(self) -> list[dict[str, str]]: 246 | """Return the chat history of the page.""" 247 | if "messages" not in self.state: 248 | self.state["messages"] = [] 249 | return self.state["messages"] 250 | 251 | def render_chat_history(self): 252 | """Render the chat history of the page. Do not include system messages.""" 253 | with st.chat_message("assistant", avatar=self.avatars["assistant"]): 254 | st.markdown(self.chat_obj.initial_greeting) 255 | 256 | for message in self.chat_history: 257 | role = message["role"] 258 | if role == "system": 259 | continue 260 | with st.chat_message(role, avatar=self.avatars.get(role)): 261 | with contextlib.suppress(KeyError): 262 | if role == "assistant": 263 | st.caption(message["chat_model"]) 264 | else: 265 | st.caption(message["timestamp"]) 266 | st.markdown(message["content"]) 267 | with contextlib.suppress(KeyError): 268 | if audio := message.get("reply_audio_file_path"): 269 | with contextlib.suppress(CouldntDecodeError): 270 | self.render_custom_audio_player(audio, autoplay=False) 271 | 272 | def render_cost_estimate_page(self): 273 | """Render the estimated costs information in the chat.""" 274 | general_df = self.chat_obj.general_token_usage_db.get_usage_balance_dataframe() 275 | chat_df = self.chat_obj.token_usage_db.get_usage_balance_dataframe() 276 | dfs = {"All Recorded Chats": general_df, "Current Chat": chat_df} 277 | 278 | st.header(dfs["Current Chat"].attrs["description"], divider="rainbow") 279 | with st.container(): 280 | for category, df in dfs.items(): 281 | st.subheader(f"**{category}**") 282 | st.dataframe(df) 283 | st.write() 284 | st.caption(df.attrs["disclaimer"]) 285 | 286 | @property 287 | def voice_output(self) -> bool: 288 | """Return the state of the voice output toggle.""" 289 | return st.session_state.get("toggle_voice_output", False) 290 | 291 | def play_chime(self, chime_type: str = "success", parent_element=None): 292 | """Sound a chime to send notificatons to the user.""" 293 | chime = load_chime(chime_type) 294 | self.render_custom_audio_player( 295 | chime, hidden=True, autoplay=True, parent_element=parent_element 296 | ) 297 | 298 | def render_title(self): 299 | """Render the title of the chatbot page.""" 300 | with st.container(height=145, border=False): 301 | self.title_container = st.empty() 302 | self.title_container.subheader(self.title, divider="rainbow") 303 | left, _ = st.columns([0.7, 0.3]) 304 | with left: 305 | self.status_msg_container = st.empty() 306 | 307 | @property 308 | def direct_text_prompt(self): 309 | """Render chat inut widgets and return the user's input.""" 310 | placeholder = ( 311 | f"Send a message to {self.chat_obj.assistant_name} ({self.chat_obj.model})" 312 | ) 313 | text_from_manual_audio_recorder = "" 314 | with st.container(): 315 | left, right = st.columns([0.9, 0.1]) 316 | with left: 317 | text_from_chat_input_widget = st.chat_input(placeholder=placeholder) 318 | with right: 319 | if not st.session_state.get("toggle_continuous_voice_input"): 320 | audio = self.manual_switch_mic_recorder() 321 | text_from_manual_audio_recorder = self.chat_obj.stt(audio).text 322 | 323 | return text_from_chat_input_widget or 
text_from_manual_audio_recorder 324 | 325 | @property 326 | def continuous_text_prompt(self): 327 | """Wait until a promp from the continuous stream is ready and return it.""" 328 | if not st.session_state.get("toggle_continuous_voice_input"): 329 | return None 330 | 331 | if not self.parent.continuous_audio_input_engine_is_running: 332 | logger.warning("Continuous audio input engine is not running!!!") 333 | self.status_msg_container.error( 334 | "The continuous audio input engine is not running!!!" 335 | ) 336 | return None 337 | 338 | logger.debug("Running on continuous audio prompt. Waiting user input...") 339 | with self.status_msg_container: 340 | self.play_chime(chime_type="warning") 341 | with st.spinner(f"{self.chat_obj.assistant_name} is listening..."): 342 | while True: 343 | with self.parent.text_prompt_queue.mutex: 344 | this_page_prompt_queue = filter_page_info_from_queue( 345 | app_page=self, the_queue=self.parent.text_prompt_queue 346 | ) 347 | with contextlib.suppress(queue.Empty): 348 | if prompt := this_page_prompt_queue.get_nowait()["text"]: 349 | this_page_prompt_queue.task_done() 350 | break 351 | logger.trace("Still waiting for user text prompt...") 352 | time.sleep(0.1) 353 | 354 | logger.debug("Done getting user input: {}", prompt) 355 | return prompt 356 | 357 | def _render_chatbot_page(self): # noqa: PLR0915 358 | """Render a chatbot page. 359 | 360 | Adapted from: 361 | 362 | 363 | """ 364 | self.chat_obj.reply_only_as_text = not self.voice_output 365 | 366 | self.render_title() 367 | chat_msgs_container = st.container(height=550, border=False) 368 | with chat_msgs_container: 369 | self.render_chat_history() 370 | 371 | # The inputs should be rendered after the chat history. There is a performance 372 | # penalty otherwise, as rendering the history causes streamlit to rerun the 373 | # entire page 374 | direct_text_prompt = self.direct_text_prompt 375 | continuous_stt_prompt = "" if direct_text_prompt else self.continuous_text_prompt 376 | prompt = direct_text_prompt or continuous_stt_prompt 377 | 378 | if prompt: 379 | logger.opt(colors=True).debug("Recived prompt: {}", prompt) 380 | self.parent.reply_ongoing.set() 381 | 382 | if continuous_stt_prompt: 383 | self.play_chime("success") 384 | self.status_msg_container.success("Got your message!") 385 | time.sleep(0.5) 386 | elif continuous_stt_prompt: 387 | self.status_msg_container.warning( 388 | "Could not understand your message. Please try again." 
389 | ) 390 | logger.opt(colors=True).debug("Received empty prompt") 391 | self.parent.reply_ongoing.clear() 392 | 393 | if prompt: 394 | with chat_msgs_container: 395 | # Process user input 396 | if prompt: 397 | time_now = datetime.datetime.now().replace(microsecond=0) 398 | self.state.update({"chat_started": True}) 399 | # Display user message in chat message container 400 | with st.chat_message("user", avatar=self.avatars["user"]): 401 | st.caption(time_now) 402 | st.markdown(prompt) 403 | self.chat_history.append( 404 | { 405 | "role": "user", 406 | "name": self.chat_obj.username, 407 | "content": prompt, 408 | "timestamp": time_now, 409 | } 410 | ) 411 | 412 | # Display (stream) assistant response in chat message container 413 | with st.chat_message("assistant", avatar=self.avatars["assistant"]): 414 | # Process text and audio replies asynchronously 415 | replier = AsyncReplier(self, prompt) 416 | reply = replier.stream_text_and_audio_reply() 417 | self.chat_history.append( 418 | { 419 | "role": "assistant", 420 | "name": self.chat_obj.assistant_name, 421 | "content": reply["text"], 422 | "reply_audio_file_path": reply["audio"], 423 | "chat_model": self.chat_obj.model, 424 | } 425 | ) 426 | 427 | # Reset title according to conversation initial contents 428 | min_history_len_for_summary = 3 429 | if ( 430 | "page_title" not in self.state 431 | and len(self.chat_history) > min_history_len_for_summary 432 | ): 433 | logger.debug("Working out conversation topic...") 434 | prompt = "Summarize the previous messages in max 4 words" 435 | title = "".join(self.chat_obj.respond_system_prompt(prompt)) 436 | self.chat_obj.metadata["page_title"] = title 437 | self.chat_obj.metadata["sidebar_title"] = title 438 | self.chat_obj.save_cache() 439 | 440 | self.title = title 441 | self.sidebar_title = title 442 | self.title_container.header(title, divider="rainbow") 443 | 444 | # Clear the prompt queue for this page, to remove old prompts 445 | with self.parent.continuous_user_prompt_queue.mutex: 446 | filter_page_info_from_queue( 447 | app_page=self, 448 | the_queue=self.parent.continuous_user_prompt_queue, 449 | ) 450 | with self.parent.text_prompt_queue.mutex: 451 | filter_page_info_from_queue( 452 | app_page=self, the_queue=self.parent.text_prompt_queue 453 | ) 454 | 455 | replier.join() 456 | self.parent.reply_ongoing.clear() 457 | 458 | if continuous_stt_prompt and not self.parent.reply_ongoing.is_set(): 459 | logger.opt(colors=True).debug( 460 | "Rerunning the app to wait for new input..." 
461 | ) 462 | st.rerun() 463 | 464 | def render(self): 465 | """Render the app's chatbot or costs page, depending on user choice.""" 466 | 467 | def _trim_page_padding(): 468 | md = """ 469 | 477 | """ 478 | st.markdown(md, unsafe_allow_html=True) 479 | 480 | _trim_page_padding() 481 | if st.session_state.get("toggle_show_costs"): 482 | self.render_cost_estimate_page() 483 | else: 484 | self._render_chatbot_page() 485 | logger.debug("Reached the end of the chatbot page.") 486 | -------------------------------------------------------------------------------- /pyrobbot/app/app_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions and classes for the app.""" 2 | 3 | import contextlib 4 | import datetime 5 | import os 6 | import queue 7 | import threading 8 | from typing import TYPE_CHECKING 9 | 10 | import streamlit as st 11 | from loguru import logger 12 | from PIL import Image 13 | from pydub import AudioSegment 14 | from streamlit.runtime.scriptrunner import add_script_run_ctx 15 | from twilio.rest import Client as TwilioClient 16 | 17 | from pyrobbot import GeneralDefinitions 18 | from pyrobbot.chat import AssistantResponseChunk 19 | from pyrobbot.voice_chat import VoiceChat 20 | 21 | if TYPE_CHECKING: 22 | from .app_page_templates import AppPage 23 | 24 | 25 | class WebAppChat(VoiceChat): 26 | """A chat object for web apps.""" 27 | 28 | def __init__(self, **kwargs): 29 | """Initialize a new instance of the WebAppChat class.""" 30 | super().__init__(**kwargs) 31 | self.tts_conversion_watcher_thread.start() 32 | self.handle_update_audio_history_thread.start() 33 | 34 | 35 | class AsyncReplier: 36 | """Asynchronously reply to a prompt and stream the text & audio reply.""" 37 | 38 | def __init__(self, app_page: "AppPage", prompt: str): 39 | """Initialize a new instance of the AsyncReplier class.""" 40 | self.app_page = app_page 41 | self.prompt = prompt 42 | 43 | self.chat_obj = app_page.chat_obj 44 | self.question_answer_chunks_queue = queue.Queue() 45 | 46 | self.threads = [ 47 | threading.Thread(name="queue_text_chunks", target=self.queue_text_chunks), 48 | threading.Thread(name="play_queued_audios", target=self.play_queued_audios), 49 | ] 50 | 51 | self.start() 52 | 53 | def start(self): 54 | """Start the threads.""" 55 | for thread in self.threads: 56 | add_script_run_ctx(thread) 57 | thread.start() 58 | 59 | def join(self): 60 | """Wait for all threads to finish.""" 61 | logger.debug("Waiting for {} to finish...", type(self).__name__) 62 | for thread in self.threads: 63 | thread.join() 64 | logger.debug("All {} threads finished", type(self).__name__) 65 | 66 | def queue_text_chunks(self): 67 | """Get chunks of the text reply to the prompt and queue them for display.""" 68 | exchange_id = None 69 | for chunk in self.chat_obj.answer_question(self.prompt): 70 | self.question_answer_chunks_queue.put(chunk) 71 | exchange_id = chunk.exchange_id 72 | self.question_answer_chunks_queue.put( 73 | AssistantResponseChunk(exchange_id=exchange_id, content=None) 74 | ) 75 | 76 | def play_queued_audios(self): 77 | """Play queued audio segments.""" 78 | while True: 79 | try: 80 | logger.debug( 81 | "Waiting for item from the audio reply chunk queue ({}) items so far", 82 | self.chat_obj.play_speech_queue.qsize(), 83 | ) 84 | speech_queue_item = self.chat_obj.play_speech_queue.get() 85 | audio = speech_queue_item["speech"] 86 | if audio is None: 87 | logger.debug("Got `None`. 
No more audio reply chunks to play") 88 | self.chat_obj.play_speech_queue.task_done() 89 | break 90 | 91 | logger.debug("Playing audio reply chunk ({}s)", audio.duration_seconds) 92 | self.app_page.render_custom_audio_player( 93 | audio, 94 | parent_element=self.app_page.status_msg_container, 95 | autoplay=True, 96 | hidden=True, 97 | ) 98 | logger.debug( 99 | "Done playing audio reply chunk ({}s)", audio.duration_seconds 100 | ) 101 | self.chat_obj.play_speech_queue.task_done() 102 | except Exception as error: # noqa: BLE001 103 | logger.opt(exception=True).debug( 104 | "Error playing audio reply chunk ({}s)", audio.duration_seconds 105 | ) 106 | logger.error(error) 107 | break 108 | finally: 109 | self.app_page.status_msg_container.empty() 110 | 111 | def stream_text_and_audio_reply(self): 112 | """Stream the text and audio reply to the display.""" 113 | text_reply_container = st.empty() 114 | audio_reply_container = st.empty() 115 | 116 | chunk = AssistantResponseChunk(exchange_id=None, content="") 117 | full_response = "" 118 | text_reply_container.markdown("▌") 119 | self.app_page.status_msg_container.empty() 120 | while chunk.content is not None: 121 | logger.trace("Waiting for text or audio chunks...") 122 | # Render text 123 | with contextlib.suppress(queue.Empty): 124 | chunk = self.question_answer_chunks_queue.get_nowait() 125 | if chunk.content is not None: 126 | full_response += chunk.content 127 | text_reply_container.markdown(full_response + "▌") 128 | self.question_answer_chunks_queue.task_done() 129 | 130 | text_reply_container.caption(datetime.datetime.now().replace(microsecond=0)) 131 | text_reply_container.markdown(full_response) 132 | 133 | logger.debug("Waiting for the audio reply to finish...") 134 | self.chat_obj.play_speech_queue.join() 135 | 136 | logger.debug("Getting path to full audio file for the reply...") 137 | history_entry_for_this_reply = ( 138 | self.chat_obj.context_handler.database.retrieve_history( 139 | exchange_id=chunk.exchange_id 140 | ) 141 | ) 142 | full_audio_fpath = history_entry_for_this_reply["reply_audio_file_path"].iloc[0] 143 | if full_audio_fpath is None: 144 | logger.warning("Path to full audio file not available") 145 | else: 146 | logger.debug("Got path to full audio file: {}", full_audio_fpath) 147 | self.app_page.render_custom_audio_player( 148 | full_audio_fpath, parent_element=audio_reply_container, autoplay=False 149 | ) 150 | 151 | return {"text": full_response, "audio": full_audio_fpath} 152 | 153 | 154 | @st.cache_data 155 | def get_ice_servers(): 156 | """Use Twilio's TURN server as recommended by the streamlit-webrtc developers.""" 157 | try: 158 | account_sid = os.environ["TWILIO_ACCOUNT_SID"] 159 | auth_token = os.environ["TWILIO_AUTH_TOKEN"] 160 | except KeyError: 161 | logger.warning( 162 | "Twilio credentials are not set. Cannot use their TURN servers. " 163 | "Falling back to a free STUN server from Google." 164 | ) 165 | return [{"urls": ["stun:stun.l.google.com:19302"]}] 166 | 167 | client = TwilioClient(account_sid, auth_token) 168 | token = client.tokens.create() 169 | return token.ice_servers 170 | 171 | 172 | def filter_page_info_from_queue(app_page: "AppPage", the_queue: queue.Queue): 173 | """Filter `app_page`'s data from `queue` inplace. Return queue of items in `app_page`. 174 | 175 | **Use with original_queue.mutex!!** 176 | 177 | Args: 178 | app_page: The page whose entries should be removed. 179 | the_queue: The queue to be filtered. 
180 | 
181 |     Returns:
182 |         queue.Queue: The queue with only the entries from `app_page`.
183 | 
184 |     Example:
185 |     ```
186 |     with the_queue.mutex:
187 |         this_page_data = filter_page_info_from_queue(app_page, the_queue)
188 |     ```
189 |     """
190 |     queue_with_only_entries_from_other_pages = queue.Queue()
191 |     items_from_page_queue = queue.Queue()
192 |     while the_queue.queue:
193 |         original_queue_entry = the_queue.queue.popleft()
194 |         if original_queue_entry["page"].page_id == app_page.page_id:
195 |             items_from_page_queue.put(original_queue_entry)
196 |         else:
197 |             queue_with_only_entries_from_other_pages.put(original_queue_entry)
198 | 
199 |     the_queue.queue = queue_with_only_entries_from_other_pages.queue
200 |     return items_from_page_queue
201 | 
202 | 
203 | @st.cache_data
204 | def get_avatar_images():
205 |     """Return the avatar images for the assistant and the user."""
206 |     avatar_files_dir = GeneralDefinitions.APP_DIR / "data"
207 |     assistant_avatar_file_path = avatar_files_dir / "assistant_avatar.png"
208 |     user_avatar_file_path = avatar_files_dir / "user_avatar.png"
209 |     assistant_avatar_image = Image.open(assistant_avatar_file_path)
210 |     user_avatar_image = Image.open(user_avatar_file_path)
211 | 
212 |     return {"assistant": assistant_avatar_image, "user": user_avatar_image}
213 | 
214 | 
215 | @st.cache_data
216 | def load_chime(chime_type: str) -> AudioSegment:
217 |     """Load a chime sound from the data directory."""
218 |     return AudioSegment.from_file(
219 |         GeneralDefinitions.APP_DIR / "data" / f"{chime_type}.wav", format="wav"
220 |     )
221 | 
--------------------------------------------------------------------------------
/pyrobbot/app/data/assistant_avatar.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/assistant_avatar.png
--------------------------------------------------------------------------------
/pyrobbot/app/data/powered-by-openai-badge-outlined-on-dark.svg:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pyrobbot/app/data/success.wav:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/success.wav
--------------------------------------------------------------------------------
/pyrobbot/app/data/user_avatar.png:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/user_avatar.png
--------------------------------------------------------------------------------
/pyrobbot/app/data/warning.wav:
--------------------------------------------------------------------------------
 https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/warning.wav
--------------------------------------------------------------------------------
/pyrobbot/argparse_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Wrappers for argparse functionality."""
3 | import argparse
4 | import contextlib
5 | import sys
6 | 
7 | from pydantic import BaseModel
8 | 
9 | from . import GeneralDefinitions
10 | from .chat_configs import ChatOptions, VoiceChatConfigs
11 | from .command_definitions import (
12 |     accounting_report,
13 |     browser_chat,
14 |     terminal_chat,
15 |     voice_chat,
16 | )
17 | 
18 | 
19 | def _populate_parser_from_pydantic_model(parser, model: BaseModel):
20 |     _argparse2pydantic = {
21 |         "type": model.get_type,
22 |         "default": model.get_default,
23 |         "choices": model.get_allowed_values,
24 |         "help": model.get_description,
25 |     }
26 | 
27 |     for field_name, field in model.model_fields.items():
28 |         with contextlib.suppress(AttributeError):
29 |             if not field.json_schema_extra.get("changeable", True):
30 |                 continue
31 | 
32 |         args_opts = {
33 |             key: _argparse2pydantic[key](field_name)
34 |             for key in _argparse2pydantic
35 |             if _argparse2pydantic[key](field_name) is not None
36 |         }
37 | 
38 |         if args_opts.get("type") == bool:
39 |             if args_opts.get("default") is True:
40 |                 args_opts["action"] = "store_false"
41 |             else:
42 |                 args_opts["action"] = "store_true"
43 |             args_opts.pop("default", None)
44 |             args_opts.pop("type", None)
45 | 
46 |         args_opts["required"] = field.is_required()
47 |         if "help" in args_opts:
48 |             args_opts["help"] = f"{args_opts['help']} (default: %(default)s)"
49 |         if "default" in args_opts and isinstance(args_opts["default"], (list, tuple)):
50 |             args_opts.pop("type", None)
51 |             args_opts["nargs"] = "*"
52 | 
53 |         parser.add_argument(f"--{field_name.replace('_', '-')}", **args_opts)
54 | 
55 |     return parser
56 | 
57 | 
58 | def get_parsed_args(argv=None, default_command="ui"):
59 |     """Get parsed command line arguments.
60 | 
61 |     Args:
62 |         argv (list): A list of passed command line args.
63 |         default_command (str, optional): The default command to run.
64 | 
65 |     Returns:
66 |         argparse.Namespace: Parsed command line arguments.
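
    Example (illustrative sketch of the public behaviour):
        >>> args = get_parsed_args(["accounting"])
        >>> args.run_command.__name__
        'accounting_report'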
67 | 68 | """ 69 | if argv is None: 70 | argv = sys.argv[1:] 71 | first_argv = next(iter(argv), "'") 72 | info_flags = ["--version", "-v", "-h", "--help"] 73 | if not argv or (first_argv.startswith("-") and first_argv not in info_flags): 74 | argv = [default_command, *argv] 75 | 76 | # Main parser that will handle the script's commands 77 | main_parser = argparse.ArgumentParser( 78 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 79 | ) 80 | main_parser.add_argument( 81 | "--version", 82 | "-v", 83 | action="version", 84 | version=f"{GeneralDefinitions.PACKAGE_NAME} v" + GeneralDefinitions.VERSION, 85 | ) 86 | subparsers = main_parser.add_subparsers( 87 | title="commands", 88 | dest="command", 89 | required=True, 90 | description=( 91 | "Valid commands (note that commands also accept their " 92 | + "own arguments, in particular [-h]):" 93 | ), 94 | help="command description", 95 | ) 96 | 97 | # Common options to most commands 98 | chat_options_parser = _populate_parser_from_pydantic_model( 99 | parser=argparse.ArgumentParser( 100 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False 101 | ), 102 | model=ChatOptions, 103 | ) 104 | chat_options_parser.add_argument( 105 | "--report-accounting-when-done", 106 | action="store_true", 107 | help="Report estimated costs when done with the chat.", 108 | ) 109 | 110 | # Web app chat 111 | parser_ui = subparsers.add_parser( 112 | "ui", 113 | aliases=["app", "webapp", "browser"], 114 | parents=[chat_options_parser], 115 | help="Run the chat UI on the browser.", 116 | ) 117 | parser_ui.set_defaults(run_command=browser_chat) 118 | 119 | # Voice chat 120 | voice_options_parser = _populate_parser_from_pydantic_model( 121 | parser=argparse.ArgumentParser( 122 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False 123 | ), 124 | model=VoiceChatConfigs, 125 | ) 126 | parser_voice_chat = subparsers.add_parser( 127 | "voice", 128 | aliases=["v", "speech", "talk"], 129 | parents=[voice_options_parser], 130 | help="Run the chat over voice only.", 131 | ) 132 | parser_voice_chat.set_defaults(run_command=voice_chat) 133 | 134 | # Terminal chat 135 | parser_terminal = subparsers.add_parser( 136 | "terminal", 137 | aliases=["."], 138 | parents=[chat_options_parser], 139 | help="Run the chat on the terminal.", 140 | ) 141 | parser_terminal.set_defaults(run_command=terminal_chat) 142 | 143 | # Accounting report 144 | parser_accounting = subparsers.add_parser( 145 | "accounting", 146 | aliases=["acc"], 147 | help="Show the estimated number of used tokens and associated costs, and exit.", 148 | ) 149 | parser_accounting.set_defaults(run_command=accounting_report) 150 | 151 | return main_parser.parse_args(argv) 152 | -------------------------------------------------------------------------------- /pyrobbot/chat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Implementation of the Chat class.""" 3 | import contextlib 4 | import json 5 | import shutil 6 | import uuid 7 | from collections import defaultdict 8 | from datetime import datetime 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import openai 13 | from attr import dataclass 14 | from loguru import logger 15 | from pydub import AudioSegment 16 | from tzlocal import get_localzone 17 | 18 | from . 
import GeneralDefinitions
19 | from .chat_configs import ChatOptions
20 | from .chat_context import EmbeddingBasedChatContext, FullHistoryChatContext
21 | from .general_utils import (
22 |     AlternativeConstructors,
23 |     ReachedMaxNumberOfAttemptsError,
24 |     get_call_traceback,
25 | )
26 | from .internet_utils import websearch
27 | from .openai_utils import OpenAiClientWrapper, make_api_chat_completion_call
28 | from .sst_and_tts import SpeechToText, TextToSpeech
29 | from .tokens import PRICE_PER_K_TOKENS_EMBEDDINGS, TokenUsageDatabase
30 | 
31 | 
32 | @dataclass
33 | class AssistantResponseChunk:
34 |     """A chunk of the assistant's response."""
35 | 
36 |     exchange_id: str
37 |     content: str
38 |     chunk_type: str = "text"
39 | 
40 | 
41 | class Chat(AlternativeConstructors):
42 |     """Manages conversations with an AI chat model.
43 | 
44 |     This class encapsulates the chat behavior, including handling the chat context,
45 |     managing cache directories, and interfacing with the OpenAI API for generating chat
46 |     responses.
47 |     """
48 | 
49 |     _translation_cache = defaultdict(dict)
50 |     default_configs = ChatOptions()
51 | 
52 |     def __init__(
53 |         self,
54 |         openai_client: OpenAiClientWrapper = None,
55 |         configs: ChatOptions = default_configs,
56 |     ):
57 |         """Initializes a chat instance.
58 | 
59 |         Args:
60 |             configs (ChatOptions, optional): The configurations for this chat session.
61 |             openai_client (openai.OpenAI, optional): An OpenAiClientWrapper instance.
62 | 
63 |         Raises:
64 |             NotImplementedError: If the context model specified in configs is unknown.
65 |         """
66 |         self.id = str(uuid.uuid4())
67 |         logger.trace(
68 |             "Init chat {}, as requested from <{}>", self.id, get_call_traceback()
69 |         )
70 |         logger.debug("Init chat {}", self.id)
71 | 
72 |         self._code_marker = "\uE001"  # Private-use char marking code-block boundaries in replies
73 | 
74 |         self._passed_configs = configs
75 |         for field in self._passed_configs.model_fields:
76 |             setattr(self, field, self._passed_configs[field])
77 | 
78 |         try:
79 |             self.openai_client = (
80 |                 openai_client
81 |                 if openai_client is not None
82 |                 else OpenAiClientWrapper(
83 |                     timeout=self.timeout, private_mode=self.private_mode
84 |                 )
85 |             )
86 |         except openai.OpenAIError as error:
87 |             logger.opt(exception=True).debug(error)
88 |             logger.error(
89 |                 "Cannot connect to OpenAI API. Please verify your API key. {}.", error
90 |             )
91 |             self.openai_client = None
92 | 
93 |     @property
94 |     def base_directive(self):
95 |         """Return the base directive for the LLM."""
96 |         code_marker = self._code_marker
97 |         local_datetime = datetime.now(get_localzone()).isoformat(timespec="seconds")
98 |         msg_content = (
99 |             f"Your name is {self.assistant_name}. Your model is {self.model}\n"
100 |             f"You are a helpful assistant to {self.username}\n"
101 |             f"You have internet access\n"
102 |             f"You MUST ALWAYS write {code_marker} before AND after code blocks. Example: "
103 |             f"```foo ... 
```{code_marker}\n" 104 | f"The current city is {GeneralDefinitions.IPINFO['city']} in " 105 | f"{GeneralDefinitions.IPINFO['country_name']}\n" 106 | f"The local datetime is {local_datetime}\n" 107 | f"\nYou must follow all directives by {self.system_name} " 108 | f"unless otherwise instructed by {self.username}\n" 109 | f"You must reply in {self.language} unless otherwise instructed\n" 110 | "> If asked to look up online, web, internet etc, *you MUST ALWAYS AGREE*\n" 111 | "\n*IMPORTANT*: If you can't find information neither in your training data " 112 | "nor in the conversation context then, unless you have either already " 113 | f"looked up online or not understood {self.username}'s prompt, you must " 114 | "do as follows:\n" 115 | " > Do *NOT* apologise nor say you are sorry nor give any excuses.\n" 116 | " > Do *NOT* ask for permission to lookup online.\n" 117 | " > STATE CLEARLY that you will look it up online.\n" 118 | "\n".join([f"{instruct.strip(' .')}." for instruct in self.ai_instructions]) 119 | ) 120 | return {"role": "system", "name": self.system_name, "content": msg_content} 121 | 122 | @property 123 | def configs(self): 124 | """Return the chat's configs after initialisation.""" 125 | configs_dict = {} 126 | for field_name in self._passed_configs.model_fields: 127 | configs_dict[field_name] = getattr(self, field_name) 128 | return self._passed_configs.model_validate(configs_dict) 129 | 130 | @property 131 | def tmp_dir(self): 132 | """Return the temporary directory for the chat.""" 133 | return Path(self._tmp_dir.name) 134 | 135 | @property 136 | def cache_dir(self): 137 | """Return the cache directory for this chat.""" 138 | parent_dir = self.openai_client.get_cache_dir(private_mode=self.private_mode) 139 | directory = parent_dir / f"chat_{self.id}" 140 | directory.mkdir(parents=True, exist_ok=True) 141 | return directory 142 | 143 | @property 144 | def configs_file(self): 145 | """File to store the chat's configs.""" 146 | return self.cache_dir / "configs.json" 147 | 148 | @property 149 | def context_file_path(self): 150 | """Return the path to the file that stores the chat context and history.""" 151 | return self.cache_dir / "embeddings.db" 152 | 153 | @property 154 | def context_handler(self): 155 | """Return the chat's context handler.""" 156 | if self.context_model == "full-history": 157 | return FullHistoryChatContext(parent_chat=self) 158 | 159 | if self.context_model in PRICE_PER_K_TOKENS_EMBEDDINGS: 160 | return EmbeddingBasedChatContext(parent_chat=self) 161 | 162 | raise NotImplementedError(f"Unknown context model: {self.context_model}") 163 | 164 | @property 165 | def token_usage_db(self): 166 | """Return the chat's token usage database.""" 167 | return TokenUsageDatabase(fpath=self.cache_dir / "chat_token_usage.db") 168 | 169 | @property 170 | def general_token_usage_db(self): 171 | """Return the general token usage database for all chats. 172 | 173 | Even private-mode chats will use this database to keep track of total token usage. 
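
        The database file sits one level above the per-user cache directories,
        so a single file tracks token usage across all chats.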
174 | """ 175 | general_cache_dir = self.openai_client.get_cache_dir(private_mode=False) 176 | return TokenUsageDatabase(fpath=general_cache_dir.parent / "token_usage.db") 177 | 178 | @property 179 | def metadata_file(self): 180 | """File to store the chat metadata.""" 181 | return self.cache_dir / "metadata.json" 182 | 183 | @property 184 | def metadata(self): 185 | """Keep metadata associated with the chat.""" 186 | try: 187 | _ = self._metadata 188 | except AttributeError: 189 | try: 190 | with open(self.metadata_file, "r") as f: 191 | self._metadata = json.load(f) 192 | except (FileNotFoundError, json.decoder.JSONDecodeError): 193 | self._metadata = {} 194 | return self._metadata 195 | 196 | @metadata.setter 197 | def metadata(self, value): 198 | self._metadata = dict(value) 199 | 200 | def save_cache(self): 201 | """Store the chat's configs and metadata to the cache directory.""" 202 | self.configs.export(self.configs_file) 203 | 204 | metadata = self.metadata # Trigger loading metadata if not yet done 205 | metadata["chat_id"] = self.id 206 | with open(self.metadata_file, "w") as metadata_f: 207 | json.dump(metadata, metadata_f, indent=2) 208 | 209 | def clear_cache(self): 210 | """Remove the cache directory.""" 211 | logger.debug("Clearing cache for chat {}", self.id) 212 | shutil.rmtree(self.cache_dir, ignore_errors=True) 213 | 214 | def load_history(self): 215 | """Load chat history from cache.""" 216 | return self.context_handler.load_history() 217 | 218 | @property 219 | def initial_greeting(self): 220 | """Return the initial greeting for the chat.""" 221 | default_greeting = f"Hi! I'm {self.assistant_name}. How can I assist you?" 222 | user_set_greeting = False 223 | with contextlib.suppress(AttributeError): 224 | user_set_greeting = self._initial_greeting != "" 225 | 226 | if not user_set_greeting: 227 | self._initial_greeting = default_greeting 228 | 229 | custom_greeting = user_set_greeting and self._initial_greeting != default_greeting 230 | if custom_greeting or self.language[:2] != "en": 231 | self._initial_greeting = self._translate(self._initial_greeting) 232 | 233 | return self._initial_greeting 234 | 235 | @initial_greeting.setter 236 | def initial_greeting(self, value: str): 237 | self._initial_greeting = str(value).strip() 238 | 239 | def respond_user_prompt(self, prompt: str, **kwargs): 240 | """Respond to a user prompt.""" 241 | yield from self._respond_prompt(prompt=prompt, role="user", **kwargs) 242 | 243 | def respond_system_prompt( 244 | self, prompt: str, add_to_history=False, skip_check=True, **kwargs 245 | ): 246 | """Respond to a system prompt.""" 247 | for response_chunk in self._respond_prompt( 248 | prompt=prompt, 249 | role="system", 250 | add_to_history=add_to_history, 251 | skip_check=skip_check, 252 | **kwargs, 253 | ): 254 | yield response_chunk.content 255 | 256 | def yield_response_from_msg( 257 | self, prompt_msg: dict, add_to_history: bool = True, **kwargs 258 | ): 259 | """Yield response from a prompt message.""" 260 | exchange_id = str(uuid.uuid4()) 261 | code_marker = self._code_marker 262 | try: 263 | inside_code_block = False 264 | for answer_chunk in self._yield_response_from_msg( 265 | exchange_id=exchange_id, 266 | prompt_msg=prompt_msg, 267 | add_to_history=add_to_history, 268 | **kwargs, 269 | ): 270 | code_marker_detected = code_marker in answer_chunk 271 | inside_code_block = (code_marker_detected and not inside_code_block) or ( 272 | inside_code_block and not code_marker_detected 273 | ) 274 | yield AssistantResponseChunk( 275 | 
exchange_id=exchange_id,
276 |                     content=answer_chunk.strip(code_marker),
277 |                     chunk_type="code" if inside_code_block else "text",
278 |                 )
279 | 
280 |         except (ReachedMaxNumberOfAttemptsError, openai.OpenAIError) as error:
281 |             yield self.response_failure_message(exchange_id=exchange_id, error=error)
282 | 
283 |     def start(self):
284 |         """Start the chat."""
285 |         # ruff: noqa: T201
286 |         print(f"{self.assistant_name}> {self.initial_greeting}\n")
287 |         try:
288 |             while True:
289 |                 question = input(f"{self.username}> ").strip()
290 |                 if not question:
291 |                     continue
292 |                 print(f"{self.assistant_name}> ", end="", flush=True)
293 |                 for chunk in self.respond_user_prompt(prompt=question):
294 |                     print(chunk.content, end="", flush=True)
295 |                 print()
296 |                 print()
297 |         except (KeyboardInterrupt, EOFError):
298 |             print("", end="\r")
299 |             logger.info("Leaving chat")
300 | 
301 |     def report_token_usage(self, report_current_chat=True, report_general: bool = False):
302 |         """Report token usage and associated costs."""
303 |         dfs = {}
304 |         if report_general:
305 |             dfs["All Recorded Chats"] = (
306 |                 self.general_token_usage_db.get_usage_balance_dataframe()
307 |             )
308 |         if report_current_chat:
309 |             dfs["Current Chat"] = self.token_usage_db.get_usage_balance_dataframe()
310 | 
311 |         if dfs:
312 |             for category, df in dfs.items():
313 |                 header = f"{df.attrs['description']}: {category}"
314 |                 table_separator = "=" * (len(header) + 4)
315 |                 print(table_separator)
316 |                 print(f" {header} ")
317 |                 print(table_separator)
318 |                 print(df)
319 |                 print()
320 |                 print(df.attrs["disclaimer"])
321 | 
322 |     def response_failure_message(
323 |         self, exchange_id: Optional[str] = "", error: Optional[Exception] = None
324 |     ):
325 |         """Return the message to be displayed when getting a response fails."""
326 |         msg = "Could not get a response right now."
327 |         if error is not None:
328 |             msg += f" The reason seems to be: {error} "
329 |         msg += "Please check your connection or OpenAI API key."
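        # Note (illustrative): the failure message travels through the same
        # channel as normal replies, i.e. as an `AssistantResponseChunk`, so
        # callers need no separate error-handling path.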
330 |         logger.opt(exception=True).debug(error)
331 |         return AssistantResponseChunk(exchange_id=exchange_id, content=msg)
332 | 
333 |     def stt(self, speech: AudioSegment):
334 |         """Convert audio to text."""
335 |         return SpeechToText(
336 |             speech=speech,
337 |             openai_client=self.openai_client,
338 |             engine=self.stt_engine,
339 |             language=self.language,
340 |             timeout=self.timeout,
341 |             general_token_usage_db=self.general_token_usage_db,
342 |             token_usage_db=self.token_usage_db,
343 |         )
344 | 
345 |     def tts(self, text: str):
346 |         """Convert text to audio."""
347 |         return TextToSpeech(
348 |             text=text,
349 |             openai_client=self.openai_client,
350 |             language=self.language,
351 |             engine=self.tts_engine,
352 |             openai_tts_voice=self.openai_tts_voice,
353 |             timeout=self.timeout,
354 |             general_token_usage_db=self.general_token_usage_db,
355 |             token_usage_db=self.token_usage_db,
356 |         )
357 | 
358 |     def _yield_response_from_msg(
359 |         self,
360 |         exchange_id,
361 |         prompt_msg: dict,
362 |         add_to_history: bool = True,
363 |         skip_check: bool = False,
364 |     ):
365 |         """Yield response from a prompt message (lower level interface)."""
366 |         # Get appropriate context for prompt from the context handler
367 |         context = self.context_handler.get_context(msg=prompt_msg)
368 | 
369 |         # Make API request and yield response chunks
370 |         full_reply_content = ""
371 |         for chunk in make_api_chat_completion_call(
372 |             conversation=[self.base_directive, *context, prompt_msg], chat_obj=self
373 |         ):
374 |             full_reply_content += chunk.strip(self._code_marker)
375 |             yield chunk
376 | 
377 |         if not skip_check:
378 |             last_msg_exchange = (
379 |                 f"`user` says: {prompt_msg['content']}\n"
380 |                 f"`you` replies: {full_reply_content}"
381 |             )
382 |             system_check_msg = (
383 |                 "Consider the following dialogue between `user` and `you` "
384 |                 "AND NOTHING MORE:\n\n"
385 |                 f"{last_msg_exchange}\n\n"
386 |                 "Now answer the following question using only 'yes' or 'no':\n"
387 |                 "Were `you` able to provide a good answer to the `user`'s prompt, without "
388 |                 "either `you` or `user` asking or implying the need or intention to "
389 |                 "perform a search or lookup online, on the web or the internet?\n"
390 |             )
391 | 
392 |             reply = "".join(self.respond_system_prompt(prompt=system_check_msg))
393 |             reply = reply.strip(".' ").lower()
394 |             if ("no" in reply) or (self._translate("no") in reply):
395 |                 instructions_for_web_search = (
396 |                     "You are a professional web searcher. You will be presented with a "
397 |                     "dialogue between `user` and `you`. Considering the dialogue and "
398 |                     "relevant previous messages, write "
399 |                     "the best short web search query to look for an answer to the "
400 |                     "`user`'s prompt. You MUST follow the rules below:\n"
401 |                     "* Write *only the query* and nothing else\n"
402 |                     "* DO NOT RESTRICT the search to any particular website "
403 |                     "unless otherwise instructed\n"
404 |                     "* You MUST reply in the `user`'s language unless otherwise asked\n\n"
405 |                     "The `dialogue` is:"
406 |                 )
407 |                 instructions_for_web_search += f"\n\n{last_msg_exchange}"
408 |                 internet_query = "".join(
409 |                     self.respond_system_prompt(prompt=instructions_for_web_search)
410 |                 )
411 |                 yield "\n\n" + self._translate(
412 |                     "Searching the web now. My search is: "
413 |                 ) + f" '{internet_query}'..."
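                # Each result yielded by `websearch` (see internet_utils.py) is a
                # dict with keys like "href", "summary", "detailed" and
                # "relevance"; they are serialised to JSON below so the LLM can
                # summarise them.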
414 |                 web_results_json_dumps = "\n\n".join(
415 |                     json.dumps(result, indent=2) for result in websearch(internet_query)
416 |                 )
417 |                 if web_results_json_dumps:
418 |                     logger.opt(colors=True).debug(
419 |                         "Web search rtn: {}...", web_results_json_dumps
420 |                     )
421 |                     original_prompt = prompt_msg["content"]
422 |                     prompt = (
423 |                         "You are a talented data analyst, "
424 |                         "capable of summarising any information, even complex `json`. "
425 |                         "You will be shown a `json` and a `prompt`. Your task is to "
426 |                         "summarise the `json` to answer the `prompt`. "
427 |                         "You MUST follow the rules below:\n\n"
428 |                         "* *ALWAYS* provide a meaningful summary of the `json`\n"
429 |                         "* *Do NOT include links* or anything a human can't pronounce, "
430 |                         "unless otherwise instructed\n"
431 |                         "* Prefer searches without quotes but use them if needed\n"
432 |                         "* Answer in human language (i.e., no json, etc)\n"
433 |                         "* Answer in the `user`'s language unless otherwise asked\n"
434 |                         "* Make sure to point out that the information is from a quick "
435 |                         "web search and may be inaccurate\n"
436 |                         "* Mention the sources shortly WITHOUT MENTIONING WEB LINKS\n\n"
437 |                         "The `json` and the `prompt` are presented below:\n"
438 |                     )
439 |                     prompt += f"\n```json\n{web_results_json_dumps}\n```\n"
440 |                     prompt += f"\n`prompt`: '{original_prompt}'"
441 | 
442 |                     yield "\n\n" + self._translate(
443 |                         " I've got some results. Let me summarise them for you..."
444 |                     )
445 | 
446 |                     full_reply_content += " "
447 |                     yield "\n\n"
448 |                     for chunk in self.respond_system_prompt(prompt=prompt):
449 |                         full_reply_content += chunk.strip(self._code_marker)
450 |                         yield chunk
451 |                 else:
452 |                     yield self._translate(
453 |                         "Sorry, but I couldn't find anything on the web this time."
454 |                     )
455 | 
456 |         if add_to_history:
457 |             # Put current chat exchange in context handler's history
458 |             self.context_handler.add_to_history(
459 |                 exchange_id=exchange_id,
460 |                 msg_list=[
461 |                     prompt_msg,
462 |                     {"role": "assistant", "content": full_reply_content},
463 |                 ],
464 |             )
465 | 
466 |     def _respond_prompt(self, prompt: str, role: str, **kwargs):
467 |         prompt_as_msg = {"role": role.lower().strip(), "content": prompt.strip()}
468 |         yield from self.yield_response_from_msg(prompt_as_msg, **kwargs)
469 | 
470 |     def _translate(self, text):
471 |         lang = self.language
472 | 
473 |         cached_translation = type(self)._translation_cache[text].get(lang)  # noqa SLF001
474 |         if cached_translation:
475 |             return cached_translation
476 | 
477 |         logger.debug("Processing translation of '{}' to '{}'...", text, lang)
478 |         translation_prompt = (
479 |             f"Translate the text between triple quotes below to {lang}. "
480 |             "DO NOT WRITE ANYTHING ELSE. Only the translation. "
481 |             f"If the text is already in {lang}, then don't translate. 
Just return ''.\n" 482 | f"'''{text}'''" 483 | ) 484 | translation = "".join(self.respond_system_prompt(prompt=translation_prompt)) 485 | 486 | translation = translation.strip(" '\"") 487 | if not translation.strip(): 488 | translation = text.strip() 489 | 490 | logger.debug("Translated '{}' to '{}' as '{}'", text, lang, translation) 491 | type(self)._translation_cache[text][lang] = translation # noqa: SLF001 492 | type(self)._translation_cache[translation][lang] = translation # noqa: SLF001 493 | 494 | return translation 495 | 496 | def __del__(self): 497 | """Delete the chat instance.""" 498 | logger.debug("Deleting chat {}", self.id) 499 | chat_started = self.context_handler.database.n_entries > 0 500 | if self.private_mode or not chat_started: 501 | self.clear_cache() 502 | else: 503 | self.save_cache() 504 | self.clear_cache() 505 | -------------------------------------------------------------------------------- /pyrobbot/chat_configs.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Registration and validation of options.""" 3 | import argparse 4 | import json 5 | import types 6 | import typing 7 | from getpass import getuser 8 | from pathlib import Path 9 | from typing import Literal, Optional, get_args, get_origin 10 | 11 | from pydantic import BaseModel, Field 12 | 13 | from . import GeneralDefinitions 14 | from .tokens import PRICE_PER_K_TOKENS_EMBEDDINGS, PRICE_PER_K_TOKENS_LLM 15 | 16 | 17 | class BaseConfigModel(BaseModel, extra="forbid"): 18 | """Base model for configuring options.""" 19 | 20 | @classmethod 21 | def get_allowed_values(cls, field: str): 22 | """Return a tuple of allowed values for `field`.""" 23 | annotation = cls._get_field_param(field=field, param="annotation") 24 | if isinstance(annotation, type(Literal[""])): 25 | return get_args(annotation) 26 | return None 27 | 28 | @classmethod 29 | def get_type(cls, field: str): 30 | """Return type of `field`.""" 31 | type_hint = typing.get_type_hints(cls)[field] 32 | if isinstance(type_hint, type): 33 | if isinstance(type_hint, types.GenericAlias): 34 | return get_origin(type_hint) 35 | return type_hint 36 | type_hint_first_arg = get_args(type_hint)[0] 37 | if isinstance(type_hint_first_arg, type): 38 | return type_hint_first_arg 39 | return None 40 | 41 | @classmethod 42 | def get_default(cls, field: str): 43 | """Return allowed value(s) for `field`.""" 44 | return cls.model_fields[field].get_default() 45 | 46 | @classmethod 47 | def get_description(cls, field: str): 48 | """Return description of `field`.""" 49 | return cls._get_field_param(field=field, param="description") 50 | 51 | @classmethod 52 | def from_cli_args(cls, cli_args: argparse.Namespace): 53 | """Return an instance of the class from CLI args.""" 54 | relevant_args = { 55 | k: v 56 | for k, v in vars(cli_args).items() 57 | if k in cls.model_fields and v is not None 58 | } 59 | return cls.model_validate(relevant_args) 60 | 61 | @classmethod 62 | def _get_field_param(cls, field: str, param: str): 63 | """Return param `param` of field `field`.""" 64 | return getattr(cls.model_fields[field], param, None) 65 | 66 | def __getitem__(self, item): 67 | """Make possible to retrieve values as in a dict.""" 68 | try: 69 | return getattr(self, item) 70 | except AttributeError as error: 71 | raise KeyError(item) from error 72 | 73 | def export(self, fpath: Path): 74 | """Export the model's data to a file.""" 75 | with open(fpath, "w") as configs_file: 76 | 
configs_file.write(self.model_dump_json(indent=2, exclude_unset=True)) 77 | 78 | @classmethod 79 | def from_file(cls, fpath: Path): 80 | """Return an instance of the class given configs stored in a json file.""" 81 | with open(fpath, "r") as configs_file: 82 | return cls.model_validate(json.load(configs_file)) 83 | 84 | 85 | class OpenAiApiCallOptions(BaseConfigModel): 86 | """Model for configuring options for OpenAI API calls.""" 87 | 88 | _openai_url = "https://platform.openai.com/docs/api-reference/chat/create#chat-create" 89 | _models_url = "https://platform.openai.com/docs/models" 90 | 91 | model: Literal[tuple(PRICE_PER_K_TOKENS_LLM)] = Field( 92 | default=next(iter(PRICE_PER_K_TOKENS_LLM)), 93 | description=f"OpenAI LLM model to use. See {_openai_url}-model and {_models_url}", 94 | ) 95 | max_tokens: Optional[int] = Field( 96 | default=None, gt=0, description=f"See <{_openai_url}-max_tokens>" 97 | ) 98 | presence_penalty: Optional[float] = Field( 99 | default=None, ge=-2.0, le=2.0, description=f"See <{_openai_url}-presence_penalty>" 100 | ) 101 | frequency_penalty: Optional[float] = Field( 102 | default=None, 103 | ge=-2.0, 104 | le=2.0, 105 | description=f"See <{_openai_url}-frequency_penalty>", 106 | ) 107 | temperature: Optional[float] = Field( 108 | default=None, ge=0.0, le=2.0, description=f"See <{_openai_url}-temperature>" 109 | ) 110 | top_p: Optional[float] = Field( 111 | default=None, ge=0.0, le=1.0, description=f"See <{_openai_url}-top_p>" 112 | ) 113 | timeout: Optional[float] = Field( 114 | default=10.0, gt=0.0, description="Timeout for API requests in seconds" 115 | ) 116 | 117 | 118 | class ChatOptions(OpenAiApiCallOptions): 119 | """Model for the chat's configuration options.""" 120 | 121 | username: str = Field(default=getuser(), description="Name of the chat's user") 122 | assistant_name: str = Field(default="Rob", description="Name of the chat's assistant") 123 | system_name: str = Field( 124 | default=f"{GeneralDefinitions.PACKAGE_NAME}_system", 125 | description="Name of the chat's system", 126 | ) 127 | ai_instructions: tuple[str, ...] = Field( 128 | default=( 129 | "You answer correctly.", 130 | "You do not lie or make up information unless explicitly asked to do so.", 131 | ), 132 | description="Initial instructions for the AI", 133 | ) 134 | context_model: Literal[tuple(PRICE_PER_K_TOKENS_EMBEDDINGS)] = Field( 135 | default=next(iter(PRICE_PER_K_TOKENS_EMBEDDINGS)), 136 | description=( 137 | "Model to use for chat context (~memory). " 138 | + "Once picked, it cannot be changed." 139 | ), 140 | json_schema_extra={"frozen": True}, 141 | ) 142 | initial_greeting: Optional[str] = Field( 143 | default="", description="Initial greeting given by the assistant" 144 | ) 145 | private_mode: Optional[bool] = Field( 146 | default=False, 147 | description="Toggle private mode. If this flag is used, the chat will not " 148 | + "be logged and the chat history will not be saved.", 149 | ) 150 | api_connection_max_n_attempts: int = Field( 151 | default=5, 152 | gt=0, 153 | description="Maximum number of attempts to connect to the OpenAI API", 154 | ) 155 | language: str = Field( 156 | default="en", 157 | description="Initial language adopted by the assistant. Use either the ISO-639-1 " 158 | "format (e.g. 'pt'), or an RFC5646 language tag (e.g. 'pt-br').", 159 | ) 160 | tts_engine: Literal["openai", "google"] = Field( 161 | default="openai", 162 | description="The text-to-speech engine to use. 
The `google` engine is free "
163 |         "(for now, at least), but the `openai` engine (which will charge from your "
164 |         "API credits) sounds more natural.",
165 |     )
166 |     stt_engine: Literal["openai", "google"] = Field(
167 |         default="google",
168 |         description="The preferred speech-to-text engine to use. The `google` engine is "
169 |         "free (for now, at least); the `openai` engine is less susceptible to outages.",
170 |     )
171 |     openai_tts_voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = (
172 |         Field(default="onyx", description="Voice to use for OpenAI's TTS")
173 |     )
174 | 
175 | 
176 | class VoiceAssistantConfigs(BaseConfigModel):
177 |     """Model for the voice assistant's configuration options."""
178 | 
179 |     exit_expressions: list[str] = Field(
180 |         default=["bye-bye", "ok bye-bye", "okay bye-bye"],
181 |         description="Expression(s) to use in order to exit the chat",
182 |         json_schema_extra={"changeable": False},
183 |     )
184 | 
185 |     cancel_expressions: list[str] = Field(
186 |         default=["ok", "okay", "cancel", "stop", "listen"],
187 |         description="Word(s) to use in order to cancel the current reply",
188 |         json_schema_extra={"changeable": False},
189 |     )
190 | 
191 |     min_speech_duration_seconds: float = Field(
192 |         default=0.1,
193 |         gt=0,
194 |         description="Minimum duration of speech (in seconds) for the assistant to listen",
195 |         json_schema_extra={"changeable": False},
196 |     )
197 |     inactivity_timeout_seconds: int = Field(
198 |         default=1,
199 |         gt=0,
200 |         description="How long the user should remain inactive "
201 |         "for the assistant to stop listening",
202 |     )
203 |     speech_likelihood_threshold: float = Field(
204 |         default=0.5,
205 |         ge=0.0,
206 |         le=1.0,
207 |         description="Accept audio as speech if the likelihood is above this threshold",
208 |         json_schema_extra={"changeable": False},
209 |     )
210 |     # sample_rate and frame_duration have to be consistent with the values accepted by
211 |     # the webrtcvad package
212 |     sample_rate: Literal[8000, 16000, 32000, 48000] = Field(
213 |         default=48000,
214 |         description="Sample rate for audio recording, in Hz.",
215 |         json_schema_extra={"changeable": False},
216 |     )
217 |     frame_duration: Literal[10, 20, 30] = Field(
218 |         default=20,
219 |         description="Frame duration for audio recording, in milliseconds.",
220 |         json_schema_extra={"changeable": False},
221 |     )
222 |     reply_only_as_text: Optional[bool] = Field(
223 |         default=None, description="Reply only as text. The assistant will not speak."
224 |     )
225 |     skip_initial_greeting: Optional[bool] = Field(
226 |         default=None, description="Skip initial greeting."
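        # Note: fields tagged with json_schema_extra={"changeable": False} above
        # are not exposed as CLI flags (see _populate_parser_from_pydantic_model
        # in argparse_wrapper.py, which skips them).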
227 | ) 228 | 229 | 230 | class VoiceChatConfigs(ChatOptions, VoiceAssistantConfigs): 231 | """Model for the voice chat's configuration options.""" 232 | -------------------------------------------------------------------------------- /pyrobbot/chat_context.py: -------------------------------------------------------------------------------- 1 | """Chat context/history management.""" 2 | 3 | import ast 4 | import itertools 5 | from abc import ABC, abstractmethod 6 | from datetime import datetime, timezone 7 | from typing import TYPE_CHECKING 8 | 9 | import numpy as np 10 | import openai 11 | import pandas as pd 12 | from scipy.spatial.distance import cosine as cosine_similarity 13 | 14 | from .embeddings_database import EmbeddingsDatabase 15 | from .general_utils import retry 16 | 17 | if TYPE_CHECKING: 18 | from .chat import Chat 19 | 20 | 21 | class ChatContext(ABC): 22 | """Abstract base class for representing the context of a chat.""" 23 | 24 | def __init__(self, parent_chat: "Chat"): 25 | """Initialise the instance given a parent `Chat` object.""" 26 | self.parent_chat = parent_chat 27 | self.database = EmbeddingsDatabase( 28 | db_path=self.context_file_path, embedding_model=self.embedding_model 29 | ) 30 | self._msg_fields_for_context = ["role", "content"] 31 | 32 | @property 33 | def embedding_model(self): 34 | """Return the embedding model used for context management.""" 35 | return self.parent_chat.context_model 36 | 37 | @property 38 | def context_file_path(self): 39 | """Return the path to the context file.""" 40 | return self.parent_chat.context_file_path 41 | 42 | def add_to_history(self, exchange_id: str, msg_list: list[dict]): 43 | """Add message exchange to history.""" 44 | self.database.insert_message_exchange( 45 | exchange_id=exchange_id, 46 | chat_model=self.parent_chat.model, 47 | message_exchange=msg_list, 48 | embedding=self.request_embedding(msg_list=msg_list), 49 | ) 50 | 51 | def load_history(self) -> list[dict]: 52 | """Load the chat history.""" 53 | db_history_df = self.database.retrieve_history() 54 | 55 | # Convert unix timestamps to datetime objs at the local timezone 56 | db_history_df["timestamp"] = db_history_df["timestamp"].apply( 57 | lambda ts: datetime.fromtimestamp(ts) 58 | .replace(microsecond=0, tzinfo=timezone.utc) 59 | .astimezone(tz=None) 60 | .replace(tzinfo=None) 61 | ) 62 | 63 | msg_exchanges = db_history_df["message_exchange"].apply(ast.literal_eval).tolist() 64 | # Add timestamps and path to eventual audio files to messages 65 | for i_msg_exchange, timestamp in enumerate(db_history_df["timestamp"]): 66 | # Index 0 is for the user's message, index 1 is for the assistant's reply 67 | msg_exchanges[i_msg_exchange][0]["timestamp"] = timestamp 68 | msg_exchanges[i_msg_exchange][1]["reply_audio_file_path"] = db_history_df[ 69 | "reply_audio_file_path" 70 | ].iloc[i_msg_exchange] 71 | msg_exchanges[i_msg_exchange][1]["chat_model"] = db_history_df[ 72 | "chat_model" 73 | ].iloc[i_msg_exchange] 74 | 75 | return list(itertools.chain.from_iterable(msg_exchanges)) 76 | 77 | def get_context(self, msg: dict): 78 | """Return messages to serve as context for `msg` when requesting a completion.""" 79 | return _make_list_of_context_msgs( 80 | history=self.select_relevant_history(msg=msg), 81 | system_name=self.parent_chat.system_name, 82 | ) 83 | 84 | @abstractmethod 85 | def request_embedding(self, msg_list: list[dict]): 86 | """Request embedding from OpenAI API.""" 87 | 88 | @abstractmethod 89 | def select_relevant_history(self, msg: dict): 90 | """Select 
chat history msgs to use as context for `msg`.""" 91 | 92 | 93 | class FullHistoryChatContext(ChatContext): 94 | """Context class using full chat history.""" 95 | 96 | # Implement abstract methods 97 | def request_embedding(self, msg_list: list[dict]): # noqa: ARG002 98 | """Return a placeholder embedding.""" 99 | return 100 | 101 | def select_relevant_history(self, msg: dict): # noqa: ARG002 102 | """Select chat history msgs to use as context for `msg`.""" 103 | history = [] 104 | for full_history_msg in self.load_history(): 105 | history_msg = { 106 | k: v 107 | for k, v in full_history_msg.items() 108 | if k in self._msg_fields_for_context 109 | } 110 | history.append(history_msg) 111 | return history 112 | 113 | 114 | class EmbeddingBasedChatContext(ChatContext): 115 | """Chat context using embedding models.""" 116 | 117 | def request_embedding_for_text(self, text: str): 118 | """Request embedding for `text` from OpenAI according to used embedding model.""" 119 | embedding_request = request_embedding_from_openai( 120 | text=text, 121 | model=self.embedding_model, 122 | openai_client=self.parent_chat.openai_client, 123 | ) 124 | 125 | # Update parent chat's token usage db with tokens used in embedding request 126 | for db in [ 127 | self.parent_chat.general_token_usage_db, 128 | self.parent_chat.token_usage_db, 129 | ]: 130 | for comm_type, n_tokens in embedding_request["tokens_usage"].items(): 131 | input_or_output_kwargs = {f"n_{comm_type}_tokens": n_tokens} 132 | db.insert_data(model=self.embedding_model, **input_or_output_kwargs) 133 | 134 | return embedding_request["embedding"] 135 | 136 | # Implement abstract methods 137 | def request_embedding(self, msg_list: list[dict]): 138 | """Convert `msg_list` into a paragraph and get embedding from OpenAI API call.""" 139 | text = "\n".join( 140 | [f"{msg['role'].strip()}: {msg['content'].strip()}" for msg in msg_list] 141 | ) 142 | return self.request_embedding_for_text(text=text) 143 | 144 | def select_relevant_history(self, msg: dict): 145 | """Select chat history msgs to use as context for `msg`.""" 146 | relevant_history = [] 147 | for full_context_msg in _select_relevant_history( 148 | history_df=self.database.retrieve_history(), 149 | embedding=self.request_embedding_for_text(text=msg["content"]), 150 | ): 151 | context_msg = { 152 | k: v 153 | for k, v in full_context_msg.items() 154 | if k in self._msg_fields_for_context 155 | } 156 | relevant_history.append(context_msg) 157 | return relevant_history 158 | 159 | 160 | @retry() 161 | def request_embedding_from_openai(text: str, model: str, openai_client: openai.OpenAI): 162 | """Request embedding for `text` according to context model `model` from OpenAI.""" 163 | text = text.strip() 164 | embedding_request = openai_client.embeddings.create(input=[text], model=model) 165 | 166 | embedding = embedding_request.data[0].embedding 167 | 168 | input_tokens = embedding_request.usage.prompt_tokens 169 | output_tokens = embedding_request.usage.total_tokens - input_tokens 170 | tokens_usage = {"input": input_tokens, "output": output_tokens} 171 | 172 | return {"embedding": embedding, "tokens_usage": tokens_usage} 173 | 174 | 175 | def _make_list_of_context_msgs(history: list[dict], system_name: str): 176 | sys_directives = "Considering the previous messages, answer the next message:" 177 | sys_msg = {"role": "system", "name": system_name, "content": sys_directives} 178 | return [*history, sys_msg] 179 | 180 | 181 | def _select_relevant_history( 182 | history_df: pd.DataFrame, 183 | embedding: 
np.ndarray,
184 |     max_n_prompt_reply_pairs: int = 5,
185 |     max_n_tailing_prompt_reply_pairs: int = 2,
186 | ):
187 |     history_df["embedding"] = (
188 |         history_df["embedding"].apply(ast.literal_eval).apply(np.array)
189 |     )
190 |     history_df["similarity"] = history_df["embedding"].apply(
191 |         lambda x: 1.0 - cosine_similarity(x, embedding)  # scipy's `cosine` is a distance
192 |     )
193 | 
194 |     # Get the last messages added to the history
195 |     df_last_n_chats = history_df.tail(max_n_tailing_prompt_reply_pairs)
196 | 
197 |     # Get the most similar messages
198 |     df_similar_chats = (
199 |         history_df.sort_values("similarity", ascending=False)
200 |         .head(max_n_prompt_reply_pairs)
201 |         .sort_values("timestamp")
202 |     )
203 | 
204 |     df_context = pd.concat([df_similar_chats, df_last_n_chats])
205 |     selected_history = (
206 |         df_context["message_exchange"].apply(ast.literal_eval).drop_duplicates()
207 |     ).tolist()
208 | 
209 |     return list(itertools.chain.from_iterable(selected_history))
--------------------------------------------------------------------------------
/pyrobbot/command_definitions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Commands supported by the package's script."""
3 | import subprocess
4 | 
5 | from loguru import logger
6 | 
7 | from . import GeneralDefinitions
8 | from .chat import Chat
9 | from .chat_configs import ChatOptions
10 | from .voice_chat import VoiceChat
11 | 
12 | 
13 | def voice_chat(args):
14 |     """Start a voice-based chat."""
15 |     VoiceChat.from_cli_args(cli_args=args).start()
16 | 
17 | 
18 | def browser_chat(args):
19 |     """Run the chat on the browser."""
20 |     ChatOptions.from_cli_args(args).export(fpath=GeneralDefinitions.PARSED_ARGS_FILE)
21 |     try:
22 |         subprocess.run(
23 |             [  # noqa: S603, S607
24 |                 "streamlit",
25 |                 "run",
26 |                 GeneralDefinitions.APP_PATH.as_posix(),
27 |                 "--",
28 |                 GeneralDefinitions.PARSED_ARGS_FILE.as_posix(),
29 |             ],
30 |             cwd=GeneralDefinitions.APP_DIR.as_posix(),
31 |             check=True,
32 |         )
33 |     except (KeyboardInterrupt, EOFError):
34 |         logger.info("Exiting.")
35 | 
36 | 
37 | def terminal_chat(args):
38 |     """Run the chat on the terminal."""
39 |     chat = Chat.from_cli_args(cli_args=args)
40 |     chat.start()
41 |     if args.report_accounting_when_done:
42 |         chat.report_token_usage(report_general=True)
43 | 
44 | 
45 | def accounting_report(args):
46 |     """Show the accumulated costs of the chat and exit."""
47 |     chat = Chat.from_cli_args(cli_args=args)
48 |     # Prevent chat from creating entry in the cache directory
49 |     chat.private_mode = True
50 |     chat.report_token_usage(report_general=True, report_current_chat=False)
--------------------------------------------------------------------------------
/pyrobbot/embeddings_database.py:
--------------------------------------------------------------------------------
1 | """Management of embeddings/chat history storage and retrieval."""
2 | 
3 | import datetime
4 | import json
5 | import sqlite3
6 | from pathlib import Path
7 | from typing import Union
8 | 
9 | import pandas as pd
10 | from loguru import logger
11 | 
12 | 
13 | class EmbeddingsDatabase:
14 |     """Class for managing an SQLite database storing embeddings and associated data."""
15 | 
16 |     def __init__(self, db_path: Path, embedding_model: str):
17 |         """Initialise the EmbeddingsDatabase object.
18 | 
19 |         Args:
20 |             db_path (Path): The path to the SQLite database file.
21 |             embedding_model (str): The embedding model associated with this database.
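
        Example (illustrative; the model name below is just a placeholder):
            >>> db = EmbeddingsDatabase(Path("/tmp/embeddings.db"), "some-model")
            >>> db.n_entries
            0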
22 |         """
23 |         self.db_path = db_path
24 |         self.embedding_model = embedding_model
25 |         self.create()
26 | 
27 |     def create(self):
28 |         """Create the necessary tables and triggers in the SQLite database."""
29 |         self.db_path.parent.mkdir(parents=True, exist_ok=True)
30 |         conn = sqlite3.connect(self.db_path)
31 | 
32 |         # SQL to create the needed tables
33 |         create_table_sqls = {
34 |             "embedding_model": """
35 |                 CREATE TABLE IF NOT EXISTS embedding_model (
36 |                     created_timestamp INTEGER NOT NULL,
37 |                     embedding_model TEXT NOT NULL,
38 |                     PRIMARY KEY (embedding_model)
39 |                 )
40 |             """,
41 |             "messages": """
42 |                 CREATE TABLE IF NOT EXISTS messages (
43 |                     id TEXT PRIMARY KEY NOT NULL,
44 |                     timestamp INTEGER NOT NULL,
45 |                     chat_model TEXT NOT NULL,
46 |                     message_exchange TEXT NOT NULL,
47 |                     embedding TEXT
48 |                 )
49 |             """,
50 |             "reply_audio_files": """
51 |                 CREATE TABLE IF NOT EXISTS reply_audio_files (
52 |                     id TEXT PRIMARY KEY NOT NULL,
53 |                     file_path TEXT NOT NULL,
54 |                     FOREIGN KEY (id) REFERENCES messages(id) ON DELETE CASCADE
55 |                 )
56 |             """,
57 |         }
58 | 
59 |         with conn:
60 |             for table_name, table_create_sql in create_table_sqls.items():
61 |                 # Create tables
62 |                 conn.execute(table_create_sql)
63 | 
64 |                 # Create triggers to prevent modification after insertion
65 |                 conn.execute(
66 |                     f"""
67 |                     CREATE TRIGGER IF NOT EXISTS prevent_{table_name}_modification
68 |                     BEFORE UPDATE ON {table_name}
69 |                     BEGIN
70 |                         SELECT RAISE(FAIL, 'Table "{table_name}": modification not allowed');
71 |                     END;
72 |                     """
73 |                 )
74 | 
75 |         # Close the connection to the database
76 |         conn.close()
77 | 
78 |     def get_embedding_model(self):
79 |         """Retrieve the database's embedding model.
80 | 
81 |         Returns:
82 |             str: The embedding model, or None if the database is not yet initialised.
83 |         """
84 |         conn = sqlite3.connect(self.db_path)
85 |         query = "SELECT embedding_model FROM embedding_model;"
86 |         # Execute the query and fetch the result
87 |         embedding_model = None
88 |         with conn:
89 |             cur = conn.cursor()
90 |             cur.execute(query)
91 |             result = cur.fetchone()
92 |             embedding_model = result[0] if result else None
93 | 
94 |         conn.close()
95 | 
96 |         return embedding_model
97 | 
98 |     def insert_message_exchange(
99 |         self, exchange_id, chat_model, message_exchange, embedding
100 |     ):
101 |         """Insert a message exchange into the database's 'messages' table.
102 | 
103 |         Args:
104 |             exchange_id (str): The id of the message exchange.
105 |             chat_model (str): The chat model.
106 |             message_exchange: The message exchange.
107 |             embedding: The embedding associated with the message exchange.
108 | 
109 |         Raises:
110 |             ValueError: If the database already contains a different embedding model.
111 |         """
112 |         stored_embedding_model = self.get_embedding_model()
113 |         if stored_embedding_model is None:
114 |             self._init_database()
115 |         elif stored_embedding_model != self.embedding_model:
116 |             raise ValueError(
117 |                 "Database already contains a different embedding model: "
118 |                 f"{self.get_embedding_model()}.\n"
119 |                 "Cannot continue."
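                # This guard keeps a single database file from mixing embeddings
                # produced by different, incompatible embedding models.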
120 | ) 121 | 122 | timestamp = int(datetime.datetime.utcnow().timestamp()) 123 | message_exchange = json.dumps(message_exchange) 124 | embedding = json.dumps(embedding) 125 | conn = sqlite3.connect(self.db_path) 126 | sql = """ 127 | INSERT INTO messages (id, timestamp, chat_model, message_exchange, embedding) 128 | VALUES (?, ?, ?, ?, ?)""" 129 | with conn: 130 | conn.execute( 131 | sql, (exchange_id, timestamp, chat_model, message_exchange, embedding) 132 | ) 133 | conn.close() 134 | 135 | def insert_assistant_audio_file_path( 136 | self, exchange_id: str, file_path: Union[str, Path] 137 | ): 138 | """Insert the path to the assistant's reply audio file into the database. 139 | 140 | Args: 141 | exchange_id: The id of the message exchange. 142 | file_path: Path to the assistant's reply audio file. 143 | """ 144 | file_path = file_path.as_posix() 145 | conn = sqlite3.connect(self.db_path) 146 | with conn: 147 | # Check if the corresponding id exists in the messages table 148 | cursor = conn.cursor() 149 | cursor.execute("SELECT 1 FROM messages WHERE id=?", (exchange_id,)) 150 | exists = cursor.fetchone() is not None 151 | if exists: 152 | # Insert into reply_audio_files 153 | cursor.execute( 154 | "INSERT INTO reply_audio_files (id, file_path) VALUES (?, ?)", 155 | (exchange_id, file_path), 156 | ) 157 | else: 158 | logger.error("The corresponding id does not exist in the messages table") 159 | conn.close() 160 | 161 | def retrieve_history(self, exchange_id=None): 162 | """Retrieve data from all tables in the db combined in a single dataframe.""" 163 | query = """ 164 | SELECT messages.id, 165 | messages.timestamp, 166 | messages.chat_model, 167 | messages.message_exchange, 168 | reply_audio_files.file_path AS reply_audio_file_path, 169 | embedding 170 | FROM messages 171 | LEFT JOIN reply_audio_files 172 | ON messages.id = reply_audio_files.id 173 | """ 174 | if exchange_id: 175 | query += f" WHERE messages.id = '{exchange_id}'" 176 | 177 | conn = sqlite3.connect(self.db_path) 178 | with conn: 179 | data_df = pd.read_sql_query(query, conn) 180 | conn.close() 181 | 182 | return data_df 183 | 184 | @property 185 | def n_entries(self): 186 | """Return the number of entries in the `messages` table.""" 187 | conn = sqlite3.connect(self.db_path) 188 | query = "SELECT COUNT(*) FROM messages;" 189 | with conn: 190 | cur = conn.cursor() 191 | cur.execute(query) 192 | result = cur.fetchone() 193 | conn.close() 194 | return result[0] 195 | 196 | def _init_database(self): 197 | """Initialise the 'embedding_model' table in the database.""" 198 | conn = sqlite3.connect(self.db_path) 199 | create_time = int(datetime.datetime.utcnow().timestamp()) 200 | sql = "INSERT INTO embedding_model " 201 | sql += "(created_timestamp, embedding_model) VALUES (?, ?);" 202 | with conn: 203 | conn.execute(sql, (create_time, self.embedding_model)) 204 | conn.close() 205 | -------------------------------------------------------------------------------- /pyrobbot/general_utils.py: -------------------------------------------------------------------------------- 1 | """General utility functions and classes.""" 2 | 3 | import difflib 4 | import inspect 5 | import json 6 | import re 7 | import time 8 | from functools import wraps 9 | from pathlib import Path 10 | from typing import Optional 11 | 12 | import httpx 13 | import openai 14 | from loguru import logger 15 | from pydub import AudioSegment 16 | from pydub.silence import detect_leading_silence 17 | 18 | 19 | class ReachedMaxNumberOfAttemptsError(Exception): 20 | 
"""Error raised when the max number of attempts has been reached.""" 21 | 22 | 23 | def _get_lower_alphanumeric(string: str): 24 | """Return a string with only lowercase alphanumeric characters.""" 25 | return re.sub("[^0-9a-zA-Z]+", " ", string.strip().lower()) 26 | 27 | 28 | def str2_minus_str1(str1: str, str2: str): 29 | """Return the words in str2 that are not in str1.""" 30 | output_list = [diff for diff in difflib.ndiff(str1, str2) if diff[0] == "+"] 31 | str_diff = "".join(el.replace("+ ", "") for el in output_list if el.startswith("+")) 32 | return str_diff 33 | 34 | 35 | def get_call_traceback(depth=5): 36 | """Get the traceback of the call to the function.""" 37 | curframe = inspect.currentframe() 38 | callframe = inspect.getouterframes(curframe) 39 | call_path = [] 40 | for iframe, frame in enumerate(callframe): 41 | fpath = frame.filename 42 | lineno = frame.lineno 43 | function = frame.function 44 | code_context = frame.code_context[0].strip() 45 | call_path.append( 46 | { 47 | "fpath": fpath, 48 | "lineno": lineno, 49 | "function": function, 50 | "code_context": code_context, 51 | } 52 | ) 53 | if iframe == depth: 54 | break 55 | return call_path 56 | 57 | 58 | def trim_beginning(audio: AudioSegment, **kwargs): 59 | """Trim the beginning of the audio to remove silence.""" 60 | beginning = detect_leading_silence(audio, **kwargs) 61 | return audio[beginning:] 62 | 63 | 64 | def trim_ending(audio: AudioSegment, **kwargs): 65 | """Trim the ending of the audio to remove silence.""" 66 | audio = trim_beginning(audio.reverse(), **kwargs) 67 | return audio.reverse() 68 | 69 | 70 | def trim_silence(audio: AudioSegment, **kwargs): 71 | """Trim the silence from the beginning and ending of the audio.""" 72 | kwargs["silence_threshold"] = kwargs.get("silence_threshold", -40.0) 73 | audio = trim_beginning(audio, **kwargs) 74 | return trim_ending(audio, **kwargs) 75 | 76 | 77 | def retry( 78 | max_n_attempts: int = 5, 79 | handled_errors: tuple[Exception, ...] = ( 80 | openai.APITimeoutError, 81 | httpx.HTTPError, 82 | RuntimeError, 83 | ), 84 | error_msg: Optional[str] = None, 85 | ): 86 | """Retry executing the decorated function/generator.""" 87 | 88 | def retry_or_fail(error): 89 | """Decide whether to retry or fail based on the number of attempts.""" 90 | retry_or_fail.execution_count = getattr(retry_or_fail, "execution_count", 0) + 1 91 | 92 | if retry_or_fail.execution_count < max_n_attempts: 93 | logger.warning( 94 | "{}. 
Making new attempt ({}/{})...", 95 | error, 96 | retry_or_fail.execution_count + 1, 97 | max_n_attempts, 98 | ) 99 | time.sleep(1) 100 | else: 101 | raise ReachedMaxNumberOfAttemptsError(error_msg) from error 102 | 103 | def retry_decorator(function): 104 | """Wrap `function`.""" 105 | 106 | @wraps(function) 107 | def wrapper_f(*args, **kwargs): 108 | while True: 109 | try: 110 | return function(*args, **kwargs) 111 | except handled_errors as error: # noqa: PERF203 112 | retry_or_fail(error=error) 113 | 114 | @wraps(function) 115 | def wrapper_generator_f(*args, **kwargs): 116 | success = False 117 | while not success: 118 | try: 119 | yield from function(*args, **kwargs) 120 | except handled_errors as error: # noqa: PERF203 121 | retry_or_fail(error=error) 122 | else: 123 | success = True 124 | 125 | return wrapper_generator_f if inspect.isgeneratorfunction(function) else wrapper_f 126 | 127 | return retry_decorator 128 | 129 | 130 | class AlternativeConstructors: 131 | """Mixin class for alternative constructors.""" 132 | 133 | @classmethod 134 | def from_dict(cls, configs: dict, **kwargs): 135 | """Creates an instance from a configuration dictionary. 136 | 137 | Converts the configuration dictionary into a instance of this class 138 | and uses it to instantiate the Chat class. 139 | 140 | Args: 141 | configs (dict): The configuration options as a dictionary. 142 | **kwargs: Additional keyword arguments to pass to the class constructor. 143 | 144 | Returns: 145 | cls: An instance of Chat initialized with the given configurations. 146 | """ 147 | return cls(configs=cls.default_configs.model_validate(configs), **kwargs) 148 | 149 | @classmethod 150 | def from_cli_args(cls, cli_args, **kwargs): 151 | """Creates an instance from CLI arguments. 152 | 153 | Extracts relevant options from the CLI arguments and initializes a class instance 154 | with them. 155 | 156 | Args: 157 | cli_args: The command line arguments. 158 | **kwargs: Additional keyword arguments to pass to the class constructor. 159 | 160 | Returns: 161 | cls: An instance of the class initialized with CLI-specified configurations. 162 | """ 163 | chat_opts = { 164 | k: v 165 | for k, v in vars(cli_args).items() 166 | if k in cls.default_configs.model_fields and v is not None 167 | } 168 | return cls.from_dict(chat_opts, **kwargs) 169 | 170 | @classmethod 171 | def from_cache(cls, cache_dir: Path, **kwargs): 172 | """Loads an instance from a cache directory. 173 | 174 | Args: 175 | cache_dir (Path): The path to the cache directory. 176 | **kwargs: Additional keyword arguments to pass to the class constructor. 177 | 178 | Returns: 179 | cls: An instance of the class loaded with cached configurations and metadata. 180 | """ 181 | try: 182 | with open(cache_dir / "configs.json", "r") as configs_f: 183 | new_configs = json.load(configs_f) 184 | except FileNotFoundError: 185 | logger.warning( 186 | "Could not find config file in cache directory <{}>. " 187 | + "Creating {} with default configs.", 188 | cache_dir, 189 | cls.__name__, 190 | ) 191 | new_configs = cls.default_configs.model_dump() 192 | 193 | try: 194 | with open(cache_dir / "metadata.json", "r") as metadata_f: 195 | new_metadata = json.load(metadata_f) 196 | except FileNotFoundError: 197 | logger.warning( 198 | "Could not find metadata file in cache directory <{}>. 
" 199 | + "Creating {} with default metadata.", 200 | cache_dir, 201 | cls.__name__, 202 | ) 203 | new_metadata = None 204 | 205 | new = cls.from_dict(new_configs, **kwargs) 206 | if new_metadata is not None: 207 | new.metadata = new_metadata 208 | logger.debug( 209 | "Reseting chat_id from cache: {} --> {}.", 210 | new.id, 211 | new.metadata["chat_id"], 212 | ) 213 | new.id = new.metadata["chat_id"] 214 | 215 | return new 216 | -------------------------------------------------------------------------------- /pyrobbot/internet_utils.py: -------------------------------------------------------------------------------- 1 | """Internet search module for the package.""" 2 | 3 | import asyncio 4 | import re 5 | 6 | import numpy as np 7 | import requests 8 | from bs4 import BeautifulSoup 9 | from bs4.element import Comment 10 | from duckduckgo_search import AsyncDDGS 11 | from loguru import logger 12 | from sklearn.feature_extraction.text import TfidfVectorizer 13 | from sklearn.metrics.pairwise import cosine_similarity 14 | from unidecode import unidecode 15 | 16 | from . import GeneralDefinitions 17 | from .general_utils import retry 18 | 19 | 20 | def cosine_similarity_sentences(sentence1, sentence2): 21 | """Compute the cosine similarity between two sentences.""" 22 | vectorizer = TfidfVectorizer() 23 | vectors = vectorizer.fit_transform([sentence1, sentence2]) 24 | similarity = cosine_similarity(vectors[0], vectors[1]) 25 | return similarity[0][0] 26 | 27 | 28 | def element_is_visible(element): 29 | """Return True if the element is visible.""" 30 | tags_to_exclude = [ 31 | "[document]", 32 | "head", 33 | "header", 34 | "html", 35 | "input", 36 | "meta", 37 | "noscript", 38 | "script", 39 | "style", 40 | "style", 41 | "title", 42 | ] 43 | if element.parent.name in tags_to_exclude or isinstance(element, Comment): 44 | return False 45 | return True 46 | 47 | 48 | def extract_text_from_html(body): 49 | """Extract the text from an HTML document.""" 50 | soup = BeautifulSoup(body, "html.parser") 51 | 52 | page_has_captcha = soup.find("div", id="recaptcha") is not None 53 | if page_has_captcha: 54 | return "" 55 | 56 | texts = soup.find_all(string=True) 57 | visible_texts = filter(element_is_visible, texts) 58 | return " ".join(t.strip() for t in visible_texts if t.strip()) 59 | 60 | 61 | def find_whole_word_index(my_string, my_substring): 62 | """Find the index of a substring in a string, but only if it is a whole word match.""" 63 | pattern = re.compile(r"\b{}\b".format(re.escape(my_substring))) 64 | match = pattern.search(my_string) 65 | 66 | if match: 67 | return match.start() 68 | return -1 # Substring not found 69 | 70 | 71 | async def async_raw_websearch( 72 | query: str, 73 | max_results: int = 5, 74 | region: str = GeneralDefinitions.IPINFO["country_name"], 75 | ): 76 | """Search the web using DuckDuckGo Search API.""" 77 | async with AsyncDDGS(proxies=None) as addgs: 78 | results = await addgs.text( 79 | keywords=query, 80 | region=region, 81 | max_results=max_results, 82 | backend="html", 83 | ) 84 | return results 85 | 86 | 87 | def raw_websearch( 88 | query: str, 89 | max_results: int = 5, 90 | region: str = GeneralDefinitions.IPINFO["country_name"], 91 | ): 92 | """Search the web using DuckDuckGo Search API.""" 93 | raw_results = asyncio.run( 94 | async_raw_websearch(query=query, max_results=max_results, region=region) 95 | ) 96 | raw_results = raw_results or [] 97 | 98 | results = [] 99 | for result in raw_results: 100 | if not isinstance(result, dict): 101 | logger.error("Expected a 
`dict`, got type {}: {}", type(result), result) 102 | results.append({}) 103 | continue 104 | 105 | if result.get("body") is None: 106 | continue 107 | 108 | try: 109 | response = requests.get(result["href"], allow_redirects=False, timeout=10) 110 | except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout): 111 | continue 112 | else: 113 | content_type = response.headers.get("content-type") 114 | if (not content_type) or ("text/html" not in content_type): 115 | continue 116 | html = unidecode(extract_text_from_html(response.text)) 117 | 118 | summary = unidecode(result["body"]) 119 | relevance = cosine_similarity_sentences(query.lower(), summary.lower()) 120 | 121 | relevance_threshold = 1e-2 122 | if relevance < relevance_threshold: 123 | continue 124 | 125 | new_results = { 126 | "href": result["href"], 127 | "summary": summary, 128 | "detailed": html, 129 | "relevance": relevance, 130 | } 131 | results.append(new_results) 132 | return results 133 | 134 | 135 | @retry(error_msg="Error performing web search") 136 | def websearch(query, **kwargs): 137 | """Search the web using DuckDuckGo Search API.""" 138 | raw_results = raw_websearch(query, **kwargs) 139 | raw_results = iter( 140 | sorted(raw_results, key=lambda x: x.get("relevance", 0.0), reverse=True) 141 | ) 142 | min_relevant_keyword_length = 4 143 | min_n_words = 40 144 | 145 | for result in raw_results: 146 | html = result.get("detailed", "") 147 | 148 | index_first_query_word_to_appear = np.inf 149 | for word in unidecode(query).split(): 150 | if len(word) < min_relevant_keyword_length: 151 | continue 152 | index = find_whole_word_index(html.lower(), word.lower()) 153 | if -1 < index < index_first_query_word_to_appear: 154 | index_first_query_word_to_appear = index 155 | if -1 < index_first_query_word_to_appear < np.inf: 156 | html = html[index_first_query_word_to_appear:] 157 | 158 | selected_words = html.split()[:500] 159 | if len(selected_words) < min_n_words: 160 | # Don't return results with less than approx one paragraph 161 | continue 162 | 163 | html = " ".join(selected_words) 164 | 165 | yield { 166 | "href": result.get("href", ""), 167 | "summary": result.get("summary", ""), 168 | "detailed": html, 169 | "relevance": result.get("relevance", ""), 170 | } 171 | break 172 | 173 | for result in raw_results: 174 | yield { 175 | "href": result.get("href", ""), 176 | "summary": result.get("summary", ""), 177 | "relevance": result.get("relevance", ""), 178 | } 179 | -------------------------------------------------------------------------------- /pyrobbot/openai_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for using the OpenAI API.""" 2 | 3 | import hashlib 4 | import shutil 5 | from typing import TYPE_CHECKING, Optional 6 | 7 | import openai 8 | from loguru import logger 9 | 10 | from . 
import GeneralDefinitions 11 | from .chat_configs import OpenAiApiCallOptions 12 | from .general_utils import retry 13 | from .tokens import get_n_tokens_from_msgs 14 | 15 | if TYPE_CHECKING: 16 | from .chat import Chat 17 | 18 | 19 | class OpenAiClientWrapper(openai.OpenAI): 20 | """Wrapper for OpenAI API client.""" 21 | 22 | def __init__(self, *args, private_mode: bool = False, **kwargs): 23 | """Initialize the OpenAI API client wrapper.""" 24 | super().__init__(*args, **kwargs) 25 | self.private_mode = private_mode 26 | 27 | self.required_cache_files = [ 28 | "chat_token_usage.db", 29 | "configs.json", 30 | "embeddings.db", 31 | "metadata.json", 32 | ] 33 | self.clear_invalid_cache_dirs() 34 | 35 | @property 36 | def cache_dir(self): 37 | """Return client's cache dir according to the privacy configs.""" 38 | return self.get_cache_dir(private_mode=self.private_mode) 39 | 40 | @property 41 | def saved_chat_cache_paths(self): 42 | """Get the filepaths of saved chat contexts, sorted by last modified.""" 43 | yield from sorted( 44 | (direc for direc in self.cache_dir.glob("chat_*/")), 45 | key=lambda fpath: fpath.stat().st_ctime, 46 | ) 47 | 48 | def clear_invalid_cache_dirs(self): 49 | """Remove cache directories that are missing required files.""" 50 | for directory in self.cache_dir.glob("chat_*/"): 51 | if not all( 52 | (directory / fname).exists() for fname in self.required_cache_files 53 | ): 54 | logger.debug(f"Removing invalid cache directory: {directory}") 55 | shutil.rmtree(directory, ignore_errors=True) 56 | 57 | def get_cache_dir(self, private_mode: Optional[bool] = None): 58 | """Return the directory where the chats using the client will be stored.""" 59 | if private_mode is None: 60 | private_mode = self.private_mode 61 | 62 | if private_mode: 63 | client_id = "demo" 64 | parent_dir = GeneralDefinitions.PACKAGE_TMPDIR 65 | else: 66 | client_id = hashlib.sha256(self.api_key.encode("utf-8")).hexdigest() 67 | parent_dir = GeneralDefinitions.PACKAGE_CACHE_DIRECTORY 68 | 69 | directory = parent_dir / f"user_{client_id}" 70 | directory.mkdir(parents=True, exist_ok=True) 71 | 72 | return directory 73 | 74 | 75 | def make_api_chat_completion_call(conversation: list, chat_obj: "Chat"): 76 | """Stream a chat completion from OpenAI API given a conversation and a chat object. 77 | 78 | Args: 79 | conversation (list): A list of messages passed as input for the completion. 80 | chat_obj (Chat): Chat object containing the configurations for the chat. 81 | 82 | Yields: 83 | str: Chunks of text generated by the API in response to the conversation. 
84 | """ 85 | api_call_args = {} 86 | for field in OpenAiApiCallOptions.model_fields: 87 | if getattr(chat_obj, field) is not None: 88 | api_call_args[field] = getattr(chat_obj, field) 89 | 90 | logger.trace( 91 | "Making OpenAI API call with chat=<{}>, args {} and messages {}", 92 | chat_obj.id, 93 | api_call_args, 94 | conversation, 95 | ) 96 | 97 | @retry(error_msg="Problems connecting to OpenAI API") 98 | def stream_reply(conversation, **api_call_args): 99 | # Update the chat's token usage database with tokens used in chat input 100 | # Do this here because every attempt consumes tokens, even if it fails 101 | n_tokens = get_n_tokens_from_msgs(messages=conversation, model=chat_obj.model) 102 | for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]: 103 | db.insert_data(model=chat_obj.model, n_input_tokens=n_tokens) 104 | 105 | full_reply_content = "" 106 | for completion_chunk in chat_obj.openai_client.chat.completions.create( 107 | messages=conversation, stream=True, **api_call_args 108 | ): 109 | reply_chunk = getattr(completion_chunk.choices[0].delta, "content", "") 110 | if reply_chunk is None: 111 | break 112 | full_reply_content += reply_chunk 113 | yield reply_chunk 114 | 115 | # Update the chat's token usage database with tokens used in chat output 116 | reply_as_msg = {"role": "assistant", "content": full_reply_content} 117 | n_tokens = get_n_tokens_from_msgs(messages=[reply_as_msg], model=chat_obj.model) 118 | for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]: 119 | db.insert_data(model=chat_obj.model, n_output_tokens=n_tokens) 120 | 121 | logger.trace("Done with OpenAI API call") 122 | yield from stream_reply(conversation, **api_call_args) 123 | -------------------------------------------------------------------------------- /pyrobbot/sst_and_tts.py: -------------------------------------------------------------------------------- 1 | """Code related to speech-to-text and text-to-speech conversions.""" 2 | 3 | import io 4 | import socket 5 | import uuid 6 | from dataclasses import dataclass, field 7 | from typing import Literal 8 | 9 | import numpy as np 10 | import speech_recognition as sr 11 | from gtts import gTTS 12 | from loguru import logger 13 | from openai import OpenAI 14 | from pydub import AudioSegment 15 | 16 | from .general_utils import retry 17 | from .tokens import TokenUsageDatabase 18 | 19 | 20 | @dataclass 21 | class SpeechAndTextConfigs: 22 | """Configs for speech-to-text and text-to-speech.""" 23 | 24 | openai_client: OpenAI 25 | general_token_usage_db: TokenUsageDatabase 26 | token_usage_db: TokenUsageDatabase 27 | engine: Literal["openai", "google"] = "google" 28 | language: str = "en" 29 | timeout: int = 10 30 | 31 | 32 | @dataclass 33 | class SpeechToText(SpeechAndTextConfigs): 34 | """Class for converting speech to text.""" 35 | 36 | speech: AudioSegment = None 37 | _text: str = field(init=False, default="") 38 | 39 | def __post_init__(self): 40 | if not self.speech: 41 | self.speech = AudioSegment.silent(duration=0) 42 | self.recogniser = sr.Recognizer() 43 | self.recogniser.operation_timeout = self.timeout 44 | 45 | wav_buffer = io.BytesIO() 46 | self.speech.export(wav_buffer, format="wav") 47 | wav_buffer.seek(0) 48 | with sr.AudioFile(wav_buffer) as source: 49 | self.audio_data = self.recogniser.listen(source) 50 | 51 | @property 52 | def text(self) -> str: 53 | """Return the text from the speech.""" 54 | if not self._text: 55 | self._text = self._stt() 56 | return self._text 57 | 58 | def _stt(self) -> str: 59 | 
"""Perform speech-to-text.""" 60 | if not self.speech: 61 | logger.debug("No speech detected") 62 | return "" 63 | 64 | if self.engine == "openai": 65 | stt_function = self._stt_openai 66 | fallback_stt_function = self._stt_google 67 | fallback_name = "google" 68 | else: 69 | stt_function = self._stt_google 70 | fallback_stt_function = self._stt_openai 71 | fallback_name = "openai" 72 | 73 | conversion_id = uuid.uuid4() 74 | logger.debug( 75 | "Converting audio to text ({} STT). Process {}.", self.engine, conversion_id 76 | ) 77 | try: 78 | rtn = stt_function() 79 | except ( 80 | ConnectionResetError, 81 | socket.timeout, 82 | sr.exceptions.RequestError, 83 | ) as error: 84 | logger.error(error) 85 | logger.error( 86 | "{}: Can't communicate with `{}` speech-to-text API right now", 87 | conversion_id, 88 | self.engine, 89 | ) 90 | logger.warning( 91 | "{}: Trying to use `{}` STT instead", conversion_id, fallback_name 92 | ) 93 | rtn = fallback_stt_function() 94 | except sr.exceptions.UnknownValueError: 95 | logger.opt(colors=True).debug( 96 | "{}: Can't understand audio", conversion_id 97 | ) 98 | rtn = "" 99 | 100 | self._text = rtn.strip() 101 | logger.opt(colors=True).debug( 102 | "{}: Done with STT: {}", conversion_id, self._text 103 | ) 104 | 105 | return self._text 106 | 107 | @retry() 108 | def _stt_openai(self): 109 | """Perform speech-to-text using OpenAI's API.""" 110 | wav_buffer = io.BytesIO(self.audio_data.get_wav_data()) 111 | wav_buffer.name = "audio.wav" 112 | with wav_buffer as audio_file_buffer: 113 | transcript = self.openai_client.audio.transcriptions.create( 114 | model="whisper-1", 115 | file=audio_file_buffer, 116 | language=self.language.split("-")[0], # put in ISO-639-1 format 117 | prompt=f"The language is {self.language}. " 118 | "Do not transcribe if you think the audio is noise.", 119 | ) 120 | 121 | for db in [ 122 | self.general_token_usage_db, 123 | self.token_usage_db, 124 | ]: 125 | db.insert_data( 126 | model="whisper-1", 127 | n_input_tokens=int(np.ceil(self.speech.duration_seconds)), 128 | ) 129 | 130 | return transcript.text 131 | 132 | def _stt_google(self): 133 | """Perform speech-to-text using Google's API.""" 134 | return self.recogniser.recognize_google( 135 | audio_data=self.audio_data, language=self.language 136 | ) 137 | 138 | 139 | @dataclass 140 | class TextToSpeech(SpeechAndTextConfigs): 141 | """Class for converting text to speech.""" 142 | 143 | text: str = "" 144 | openai_tts_voice: str = "" 145 | _speech: AudioSegment = field(init=False, default=None) 146 | 147 | def __post_init__(self): 148 | self.text = self.text.strip() 149 | 150 | @property 151 | def speech(self) -> AudioSegment: 152 | """Return the speech from the text.""" 153 | if not self._speech: 154 | self._speech = self._tts() 155 | return self._speech 156 | 157 | def set_sample_rate(self, sample_rate: int): 158 | """Set the sample rate of the speech.""" 159 | self._speech = self.speech.set_frame_rate(sample_rate) 160 | 161 | def _tts(self): 162 | logger.debug("Running {} TTS on text '{}'", self.engine, self.text) 163 | rtn = self._tts_openai() if self.engine == "openai" else self._tts_google() 164 | logger.debug("Done with TTS for '{}'", self.text) 165 | 166 | return rtn 167 | 168 | def _tts_openai(self) -> AudioSegment: 169 | """Convert text to speech using OpenAI's TTS. 
Return an AudioSegment object.""" 170 | openai_tts_model = "tts-1" 171 | 172 | @retry() 173 | def _create_speech(*args, **kwargs): 174 | for db in [ 175 | self.general_token_usage_db, 176 | self.token_usage_db, 177 | ]: 178 | db.insert_data(model=openai_tts_model, n_input_tokens=len(self.text)) 179 | return self.openai_client.audio.speech.create(*args, **kwargs) 180 | 181 | response = _create_speech( 182 | input=self.text, 183 | model=openai_tts_model, 184 | voice=self.openai_tts_voice, 185 | response_format="mp3", 186 | timeout=self.timeout, 187 | ) 188 | 189 | mp3_buffer = io.BytesIO() 190 | for mp3_stream_chunk in response.iter_bytes(chunk_size=4096): 191 | mp3_buffer.write(mp3_stream_chunk) 192 | mp3_buffer.seek(0) 193 | 194 | audio = AudioSegment.from_mp3(mp3_buffer) 195 | audio += 8 # Increase volume a bit 196 | return audio 197 | 198 | def _tts_google(self) -> AudioSegment: 199 | """Convert text to speech using Google's TTS. Return a WAV BytesIO object.""" 200 | tts = gTTS(self.text, lang=self.language) 201 | mp3_buffer = io.BytesIO() 202 | tts.write_to_fp(mp3_buffer) 203 | mp3_buffer.seek(0) 204 | 205 | return AudioSegment.from_mp3(mp3_buffer) 206 | -------------------------------------------------------------------------------- /pyrobbot/tokens.py: -------------------------------------------------------------------------------- 1 | """Management of token usage and costs for OpenAI API.""" 2 | 3 | import contextlib 4 | import datetime 5 | import sqlite3 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import pandas as pd 10 | import tiktoken 11 | 12 | # See for the latest prices. 13 | PRICE_PER_K_TOKENS_LLM = { 14 | # Continuous model upgrades (models that point to the latest versions) 15 | "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}, 16 | "gpt-4-turbo-preview": {"input": 0.01, "output": 0.03}, 17 | "gpt-4": {"input": 0.03, "output": 0.06}, 18 | "gpt-3.5-turbo-16k": {"input": 0.001, "output": 0.002}, # -> gpt-3.5-turbo-16k-0613 19 | "gpt-4-32k": {"input": 0.06, "output": 0.12}, 20 | # Static model versions 21 | # GPT 3 22 | "gpt-3.5-turbo-0125": {"input": 0.0015, "output": 0.002}, 23 | "gpt-3.5-turbo-1106": {"input": 0.001, "output": 0.002}, 24 | "gpt-3.5-turbo-0613": {"input": 0.0015, "output": 0.002}, # Deprecated, 2024-06-13 25 | "gpt-3.5-turbo-16k-0613": {"input": 0.001, "output": 0.002}, # Deprecated, 2024-06-13 26 | # GPT 4 27 | "gpt-4-0125-preview": {"input": 0.01, "output": 0.03}, 28 | "gpt-4-1106-preview": {"input": 0.01, "output": 0.03}, 29 | "gpt-4-0613": {"input": 0.03, "output": 0.06}, 30 | "gpt-4-32k-0613": {"input": 0.06, "output": 0.12}, 31 | } 32 | PRICE_PER_K_TOKENS_EMBEDDINGS = { 33 | "text-embedding-3-small": {"input": 0.00002, "output": 0.0}, 34 | "text-embedding-3-large": {"input": 0.00013, "output": 0.0}, 35 | "text-embedding-ada-002": {"input": 0.0001, "output": 0.0}, 36 | "text-embedding-ada-002-v2": {"input": 0.0001, "output": 0.0}, 37 | "text-davinci:002": {"input": 0.0020, "output": 0.020}, 38 | "full-history": {"input": 0.0, "output": 0.0}, 39 | } 40 | PRICE_PER_K_TOKENS_TTS_AND_STT = { 41 | "tts-1": {"input": 0.015, "output": 0.0}, 42 | "tts-1-hd": {"input": 0.03, "output": 0.0}, 43 | "whisper-1": {"input": 0.006, "output": 0.0}, 44 | } 45 | PRICE_PER_K_TOKENS = ( 46 | PRICE_PER_K_TOKENS_LLM 47 | | PRICE_PER_K_TOKENS_EMBEDDINGS 48 | | PRICE_PER_K_TOKENS_TTS_AND_STT 49 | ) 50 | 51 | 52 | class TokenUsageDatabase: 53 | """Manages a database to store estimated token usage and costs for OpenAI API.""" 54 | 55 | def 
__init__(self, fpath: Path): 56 | """Initialize a TokenUsageDatabase instance.""" 57 | self.fpath = fpath 58 | self.token_price = {} 59 | for model, price_per_k_tokens in PRICE_PER_K_TOKENS.items(): 60 | self.token_price[model] = { 61 | k: v / 1000.0 for k, v in price_per_k_tokens.items() 62 | } 63 | 64 | self.create() 65 | 66 | def create(self): 67 | """Create the database if it doesn't exist.""" 68 | self.fpath.parent.mkdir(parents=True, exist_ok=True) 69 | conn = sqlite3.connect(self.fpath) 70 | cursor = conn.cursor() 71 | 72 | # Create a table to store the data with 'timestamp' as the primary key 73 | cursor.execute( 74 | """ 75 | CREATE TABLE IF NOT EXISTS token_costs ( 76 | timestamp INTEGER NOT NULL, 77 | model TEXT NOT NULL, 78 | n_input_tokens INTEGER NOT NULL, 79 | n_output_tokens INTEGER NOT NULL, 80 | cost_input_tokens REAL NOT NULL, 81 | cost_output_tokens REAL NOT NULL 82 | ) 83 | """ 84 | ) 85 | 86 | conn.commit() 87 | conn.close() 88 | 89 | def insert_data( 90 | self, 91 | model: str, 92 | n_input_tokens: int = 0, 93 | n_output_tokens: int = 0, 94 | timestamp: Optional[int] = None, 95 | ): 96 | """Insert the data into the token_costs table.""" 97 | if model is None: 98 | return 99 | 100 | conn = sqlite3.connect(self.fpath) 101 | cursor = conn.cursor() 102 | 103 | # Insert the data into the table 104 | cursor.execute( 105 | """ 106 | INSERT INTO token_costs ( 107 | timestamp, 108 | model, 109 | n_input_tokens, 110 | n_output_tokens, 111 | cost_input_tokens, 112 | cost_output_tokens 113 | ) 114 | VALUES (?, ?, ?, ?, ?, ?) 115 | """, 116 | ( 117 | timestamp or int(datetime.datetime.utcnow().timestamp()), 118 | model, 119 | n_input_tokens, 120 | n_output_tokens, 121 | n_input_tokens * self.token_price[model]["input"], 122 | n_output_tokens * self.token_price[model]["output"], 123 | ), 124 | ) 125 | 126 | conn.commit() 127 | conn.close() 128 | 129 | def get_usage_balance_dataframe(self): 130 | """Get a dataframe with the accumulated token usage and costs.""" 131 | conn = sqlite3.connect(self.fpath) 132 | query = """ 133 | SELECT 134 | model as Model, 135 | MIN(timestamp) AS "First Used", 136 | SUM(n_input_tokens) AS "Tokens: In", 137 | SUM(n_output_tokens) AS "Tokens: Out", 138 | SUM(n_input_tokens + n_output_tokens) AS "Tokens: Tot.", 139 | SUM(cost_input_tokens) AS "Cost ($): In", 140 | SUM(cost_output_tokens) AS "Cost ($): Out", 141 | SUM(cost_input_tokens + cost_output_tokens) AS "Cost ($): Tot." 142 | FROM token_costs 143 | GROUP BY model 144 | ORDER BY "Cost ($): Tot." DESC 145 | """ 146 | 147 | usage_df = pd.read_sql_query(query, con=conn) 148 | conn.close() 149 | 150 | usage_df["First Used"] = pd.to_datetime(usage_df["First Used"], unit="s") 151 | 152 | usage_df = _group_columns_by_prefix(_add_totals_row(usage_df)) 153 | 154 | # Add metadata to returned dataframe 155 | usage_df.attrs["description"] = "Estimated token usage and associated costs" 156 | link = "https://platform.openai.com/account/usage" 157 | disclaimers = [ 158 | "Note: These are only estimates. 
Actual costs may vary.", 159 | f"Please visit <{link}> to follow your actual usage and costs.", 160 | ] 161 | usage_df.attrs["disclaimer"] = "\n".join(disclaimers) 162 | 163 | return usage_df 164 | 165 | 166 | def get_n_tokens_from_msgs(messages: list[dict], model: str): 167 | """Returns the number of tokens used by a list of messages.""" 168 | # Adapted from 169 | # 170 | encoding = tiktoken.get_encoding("cl100k_base") 171 | with contextlib.suppress(KeyError): 172 | encoding = tiktoken.encoding_for_model(model) 173 | 174 | # OpenAI's original function was implemented for gpt-3.5-turbo-0613, but we'll use 175 | # it for all models for now. We are only interested in estimates, after all. 176 | num_tokens = 0 177 | for message in messages: 178 | # every message follows {role/name}\n{content}\n 179 | num_tokens += 4 180 | for key, value in message.items(): 181 | if not isinstance(value, str): 182 | raise TypeError( 183 | f"Value for key '{key}' has type {type(value)}. Expected str: {value}" 184 | ) 185 | num_tokens += len(encoding.encode(value)) 186 | if key == "name": # if there's a name, the role is omitted 187 | num_tokens += -1 # role is always required and always 1 token 188 | num_tokens += 2 # every reply is primed with assistant 189 | return num_tokens 190 | 191 | 192 | def _group_columns_by_prefix(dataframe: pd.DataFrame): 193 | dataframe = dataframe.copy() 194 | col_tuples_for_multiindex = dataframe.columns.str.split(": ", expand=True).to_numpy() 195 | dataframe.columns = pd.MultiIndex.from_tuples( 196 | [("", x[0]) if pd.isna(x[1]) else x for x in col_tuples_for_multiindex] 197 | ) 198 | return dataframe 199 | 200 | 201 | def _add_totals_row(accounting_df: pd.DataFrame): 202 | dtypes = accounting_df.dtypes 203 | sums_df = accounting_df.sum(numeric_only=True).rename("Total").to_frame().T 204 | return pd.concat([accounting_df, sums_df]).astype(dtypes).fillna(" ") 205 | -------------------------------------------------------------------------------- /pyrobbot/voice_chat.py: -------------------------------------------------------------------------------- 1 | """Code related to the voice chat feature.""" 2 | 3 | import contextlib 4 | import io 5 | import queue 6 | import threading 7 | import time 8 | from collections import defaultdict, deque 9 | from datetime import datetime 10 | 11 | import chime 12 | import numpy as np 13 | import pydub 14 | import pygame 15 | import soundfile as sf 16 | import webrtcvad 17 | from loguru import logger 18 | from pydub import AudioSegment 19 | 20 | from .chat import Chat 21 | from .chat_configs import VoiceChatConfigs 22 | from .general_utils import _get_lower_alphanumeric, str2_minus_str1 23 | from .sst_and_tts import TextToSpeech 24 | 25 | try: 26 | import sounddevice as sd 27 | except OSError as error: 28 | logger.exception(error) 29 | logger.error( 30 | "Can't use module `sounddevice`. Please check your system's PortAudio install." 31 | ) 32 | _sounddevice_imported = False 33 | else: 34 | _sounddevice_imported = True 35 | 36 | try: 37 | # Test if pydub's AudioSegment can be used 38 | with contextlib.suppress(pydub.exceptions.CouldntDecodeError): 39 | AudioSegment.from_mp3(io.BytesIO()) 40 | except (ImportError, OSError, FileNotFoundError) as error: 41 | logger.exception(error) 42 | logger.error("Can't use module `pydub`. 
Please check your system's ffmpeg install.") 43 | _pydub_usable = False 44 | else: 45 | _pydub_usable = True 46 | 47 | 48 | class VoiceChat(Chat): 49 | """Class for converting text to speech and speech to text.""" 50 | 51 | default_configs = VoiceChatConfigs() 52 | 53 | def __init__(self, configs: VoiceChatConfigs = default_configs, **kwargs): 54 | """Initializes a chat instance.""" 55 | super().__init__(configs=configs, **kwargs) 56 | _check_needed_imports() 57 | 58 | self.block_size = int((self.sample_rate * self.frame_duration) / 1000) 59 | 60 | self.vad = webrtcvad.Vad(2) 61 | 62 | self.default_chime_theme = "big-sur" 63 | chime.theme(self.default_chime_theme) 64 | 65 | # Create queues and threads for handling the chat 66 | # 1. Watching for questions from the user 67 | self.questions_queue = queue.Queue() 68 | self.questions_listening_watcher_thread = threading.Thread( 69 | target=self.handle_question_listening, 70 | args=(self.questions_queue,), 71 | daemon=True, 72 | ) 73 | # 2. Converting assistant's text reply to speech and playing it 74 | self.tts_conversion_queue = queue.Queue() 75 | self.play_speech_queue = queue.Queue() 76 | self.tts_conversion_watcher_thread = threading.Thread( 77 | target=self.handle_tts_conversion_queue, 78 | args=(self.tts_conversion_queue,), 79 | daemon=True, 80 | ) 81 | self.play_speech_thread = threading.Thread( 82 | target=self.handle_play_speech_queue, 83 | args=(self.play_speech_queue,), 84 | daemon=True, 85 | ) # TODO: Do not start this in webchat 86 | # 3. Watching for expressions that cancel the reply or exit the chat 87 | self.check_for_interrupt_expressions_queue = queue.Queue() 88 | self.check_for_interrupt_expressions_thread = threading.Thread( 89 | target=self.check_for_interrupt_expressions_handler, 90 | args=(self.check_for_interrupt_expressions_queue,), 91 | daemon=True, 92 | ) 93 | self.interrupt_reply = threading.Event() 94 | self.exit_chat = threading.Event() 95 | 96 | # Keep track of played audios to update the history db 97 | self.current_answer_audios_queue = queue.Queue() 98 | self.handle_update_audio_history_thread = threading.Thread( 99 | target=self.handle_update_audio_history, 100 | args=(self.current_answer_audios_queue,), 101 | daemon=True, 102 | ) 103 | 104 | @property 105 | def mixer(self): 106 | """Return the mixer object.""" 107 | mixer = getattr(self, "_mixer", None) 108 | if mixer is not None: 109 | return mixer 110 | 111 | self._mixer = pygame.mixer 112 | try: 113 | self.mixer.init( 114 | frequency=self.sample_rate, channels=1, buffer=self.block_size 115 | ) 116 | except pygame.error as error: 117 | logger.exception(error) 118 | logger.error( 119 | "Can't initialize the mixer. Please check your system's audio settings." 
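                # A pygame.error here usually means there is no usable audio output
                # device (e.g. a headless machine); the uninitialized mixer object is
                # still returned, so callers degrade gracefully instead of crashing.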
120 | ) 121 | logger.warning("Voice chat may not be available or may not work as expected.") 122 | return self._mixer 123 | 124 | def start(self): 125 | """Start the chat.""" 126 | # ruff: noqa: T201 127 | self.tts_conversion_watcher_thread.start() 128 | self.play_speech_thread.start() 129 | if not self.skip_initial_greeting: 130 | tts_entry = {"exchange_id": self.id, "text": self.initial_greeting} 131 | self.tts_conversion_queue.put(tts_entry) 132 | while self._assistant_still_replying(): 133 | pygame.time.wait(50) 134 | self.questions_listening_watcher_thread.start() 135 | self.check_for_interrupt_expressions_thread.start() 136 | self.handle_update_audio_history_thread.start() 137 | 138 | with contextlib.suppress(KeyboardInterrupt, EOFError): 139 | while not self.exit_chat.is_set(): 140 | self.tts_conversion_queue.join() 141 | self.play_speech_queue.join() 142 | self.current_answer_audios_queue.join() 143 | 144 | if self.interrupt_reply.is_set(): 145 | logger.opt(colors=True).debug( 146 | "Interrupting the reply" 147 | ) 148 | with self.check_for_interrupt_expressions_queue.mutex: 149 | self.check_for_interrupt_expressions_queue.queue.clear() 150 | with contextlib.suppress(pygame.error): 151 | self.mixer.stop() 152 | with self.questions_queue.mutex: 153 | self.questions_queue.queue.clear() 154 | chime.theme("material") 155 | chime.error() 156 | chime.theme(self.default_chime_theme) 157 | time.sleep(0.25) 158 | 159 | chime.warning() 160 | self.interrupt_reply.clear() 161 | logger.debug(f"{self.assistant_name}> Waiting for user input...") 162 | question = self.questions_queue.get() 163 | self.questions_queue.task_done() 164 | 165 | if question is None: 166 | self.exit_chat.set() 167 | else: 168 | chime.success() 169 | for chunk in self.answer_question(question): 170 | if chunk.chunk_type == "code": 171 | print(chunk.content, end="", flush=True) 172 | 173 | self.exit_chat.set() 174 | chime.info() 175 | logger.debug("Leaving chat") 176 | 177 | def answer_question(self, question: str): 178 | """Answer a question.""" 179 | logger.debug("{}> Getting response to '{}'...", self.assistant_name, question) 180 | sentence_for_tts = "" 181 | any_code_chunk_yet = False 182 | for answer_chunk in self.respond_user_prompt(prompt=question): 183 | if self.interrupt_reply.is_set() or self.exit_chat.is_set(): 184 | logger.debug("Reply interrupted.") 185 | raise StopIteration 186 | yield answer_chunk 187 | 188 | if not self.reply_only_as_text: 189 | if answer_chunk.chunk_type not in ("text", "code"): 190 | raise NotImplementedError( 191 | "Unexpected chunk type: {}".format(answer_chunk.chunk_type) 192 | ) 193 | 194 | if answer_chunk.chunk_type == "text": 195 | # The answer chunk is to be spoken 196 | sentence_for_tts += answer_chunk.content 197 | stripd_chunk = answer_chunk.content.strip() 198 | if stripd_chunk.endswith(("?", "!", ".")): 199 | # Check if second last character is a number, to avoid splitting 200 | if stripd_chunk.endswith("."): 201 | with contextlib.suppress(IndexError): 202 | previous_char = sentence_for_tts.strip()[-2] 203 | if previous_char.isdigit(): 204 | continue 205 | # Send sentence for TTS even if the request hasn't finished 206 | tts_entry = { 207 | "exchange_id": answer_chunk.exchange_id, 208 | "text": sentence_for_tts, 209 | } 210 | self.tts_conversion_queue.put(tts_entry) 211 | sentence_for_tts = "" 212 | elif answer_chunk.chunk_type == "code" and not any_code_chunk_yet: 213 | msg = self._translate("Code will be displayed in the text output.") 214 | tts_entry = {"exchange_id": 
answer_chunk.exchange_id, "text": msg} 215 | self.tts_conversion_queue.put(tts_entry) 216 | any_code_chunk_yet = True 217 | 218 | if sentence_for_tts and not self.reply_only_as_text: 219 | tts_entry = { 220 | "exchange_id": answer_chunk.exchange_id, 221 | "text": sentence_for_tts, 222 | } 223 | self.tts_conversion_queue.put(tts_entry) 224 | 225 | # Signal that the current answer is finished 226 | tts_entry = {"exchange_id": answer_chunk.exchange_id, "text": None} 227 | self.tts_conversion_queue.put(tts_entry) 228 | 229 | def handle_update_audio_history(self, current_answer_audios_queue: queue.Queue): 230 | """Handle updating the chat history with the replies' audio file paths.""" 231 | # Merge all AudioSegments in self.current_answer_audios_queue into a single one 232 | merged_audios = defaultdict(AudioSegment.empty) 233 | while not self.exit_chat.is_set(): 234 | try: 235 | logger.debug("Waiting for reply audio chunks to concatenate and save...") 236 | audio_chunk_queue_item = current_answer_audios_queue.get() 237 | reply_audio_chunk = audio_chunk_queue_item["speech"] 238 | exchange_id = audio_chunk_queue_item["exchange_id"] 239 | logger.debug("Received audio chunk for response ID {}", exchange_id) 240 | 241 | if reply_audio_chunk is not None: 242 | # Reply not yet finished 243 | merged_audios[exchange_id] += reply_audio_chunk 244 | logger.debug( 245 | "Response ID {} audio: {}s so far", 246 | exchange_id, 247 | merged_audios[exchange_id].duration_seconds, 248 | ) 249 | current_answer_audios_queue.task_done() 250 | continue 251 | 252 | # Now the reply has finished 253 | logger.debug( 254 | "Creating a single audio file for response ID {}...", exchange_id 255 | ) 256 | merged_audio = merged_audios[exchange_id] 257 | # Update the chat history with the audio file path 258 | fpath = self.audio_cache_dir() / f"{datetime.now().isoformat()}.mp3" 259 | logger.debug("Updating chat history with audio file path {}", fpath) 260 | self.context_handler.database.insert_assistant_audio_file_path( 261 | exchange_id=exchange_id, file_path=fpath 262 | ) 263 | # Save the combined audio as an mp3 file in the cache directory 264 | merged_audio.export(fpath, format="mp3") 265 | logger.debug("File {} stored", fpath) 266 | del merged_audios[exchange_id] 267 | current_answer_audios_queue.task_done() 268 | except Exception as error: # noqa: BLE001 269 | logger.error(error) 270 | logger.opt(exception=True).debug(error) 271 | 272 | def speak(self, tts: TextToSpeech): 273 | """Reproduce audio from a pygame Sound object.""" 274 | tts.set_sample_rate(self.sample_rate) 275 | self.mixer.Sound(tts.speech.raw_data).play() 276 | audio_recorded_while_assistant_replies = self.listen( 277 | duration_seconds=tts.speech.duration_seconds 278 | ) 279 | 280 | msgs_to_compare = { 281 | "assistant_txt": tts.text, 282 | "user_audio": audio_recorded_while_assistant_replies, 283 | } 284 | self.check_for_interrupt_expressions_queue.put(msgs_to_compare) 285 | 286 | while self.mixer.get_busy(): 287 | pygame.time.wait(100) 288 | 289 | def check_for_interrupt_expressions_handler( 290 | self, check_for_interrupt_expressions_queue: queue.Queue 291 | ): 292 | """Check for expressions that interrupt the assistant's reply.""" 293 | while not self.exit_chat.is_set(): 294 | try: 295 | msgs_to_compare = check_for_interrupt_expressions_queue.get() 296 | recorded_prompt = self.stt(speech=msgs_to_compare["user_audio"]).text 297 | 298 | recorded_prompt = _get_lower_alphanumeric(recorded_prompt).strip() 299 | assistant_msg = _get_lower_alphanumeric( 300 | 
msgs_to_compare.get("assistant_txt", "") 301 | ).strip() 302 | 303 | user_words = str2_minus_str1( 304 | str1=assistant_msg, str2=recorded_prompt 305 | ).strip() 306 | if user_words: 307 | logger.debug( 308 | "Detected user words while assistant was replying: {}", 309 | user_words, 310 | ) 311 | if any( 312 | cancel_cmd in user_words for cancel_cmd in self.cancel_expressions 313 | ): 314 | logger.debug( 315 | "Heard '{}'. Signalling for reply to be cancelled...", 316 | user_words, 317 | ) 318 | self.interrupt_reply.set() 319 | except Exception as error: # noqa: PERF203, BLE001 320 | logger.opt(exception=True).debug(error) 321 | finally: 322 | check_for_interrupt_expressions_queue.task_done() 323 | 324 | def listen(self, duration_seconds: float = np.inf) -> AudioSegment: 325 | """Record audio from the microphone until user stops.""" 326 | # Adapted from 327 | # 329 | debug_msg = "The assistant is listening" 330 | if duration_seconds < np.inf: 331 | debug_msg += f" for {duration_seconds} s" 332 | debug_msg += "..." 333 | 334 | inactivity_timeout_seconds = self.inactivity_timeout_seconds 335 | if duration_seconds < np.inf: 336 | inactivity_timeout_seconds = duration_seconds 337 | 338 | q = queue.Queue() 339 | 340 | def callback(indata, frames, time, status): # noqa: ARG001 341 | """This is called (from a separate thread) for each audio block.""" 342 | q.put(indata.copy()) 343 | 344 | raw_buffer = io.BytesIO() 345 | start_time = datetime.now() 346 | with self.get_sound_file(raw_buffer, mode="x") as sound_file, sd.InputStream( 347 | samplerate=self.sample_rate, 348 | blocksize=self.block_size, 349 | channels=1, 350 | callback=callback, 351 | dtype="int16", # int16, i.e., 2 bytes per sample 352 | ): 353 | logger.debug("{}", debug_msg) 354 | # Recording will stop after inactivity_timeout_seconds of silence 355 | voice_activity_detected = deque( 356 | maxlen=int((1000.0 * inactivity_timeout_seconds) / self.frame_duration) 357 | ) 358 | last_inactivity_checked = datetime.now() 359 | continue_recording = True 360 | speech_detected = False 361 | elapsed_time = 0.0 362 | with contextlib.suppress(KeyboardInterrupt): 363 | while continue_recording and elapsed_time < duration_seconds: 364 | new_data = q.get() 365 | sound_file.write(new_data) 366 | 367 | # Gather voice activity samples for the inactivity check 368 | wav_buffer = _np_array_to_wav_in_memory( 369 | sound_data=new_data, 370 | sample_rate=self.sample_rate, 371 | subtype="PCM_16", 372 | ) 373 | 374 | vad_thinks_this_chunk_is_speech = self.vad.is_speech( 375 | wav_buffer, self.sample_rate 376 | ) 377 | voice_activity_detected.append(vad_thinks_this_chunk_is_speech) 378 | 379 | # Decide if user has been inactive for too long 380 | now = datetime.now() 381 | if duration_seconds < np.inf: 382 | continue_recording = True 383 | elif ( 384 | now - last_inactivity_checked 385 | ).seconds >= inactivity_timeout_seconds: 386 | speech_likelihood = 0.0 387 | if len(voice_activity_detected) > 0: 388 | speech_likelihood = sum(voice_activity_detected) / len( 389 | voice_activity_detected 390 | ) 391 | continue_recording = ( 392 | speech_likelihood >= self.speech_likelihood_threshold 393 | ) 394 | if continue_recording: 395 | speech_detected = True 396 | last_inactivity_checked = now 397 | 398 | elapsed_time = (now - start_time).seconds 399 | 400 | if speech_detected or duration_seconds < np.inf: 401 | return AudioSegment.from_wav(raw_buffer) 402 | return AudioSegment.empty() 403 | 404 | def handle_question_listening(self, questions_queue: queue.Queue): 405 | 
"""Handle the queue of questions to be answered.""" 406 | minimum_prompt_duration_seconds = 0.05 407 | while not self.exit_chat.is_set(): 408 | if self._assistant_still_replying(): 409 | pygame.time.wait(100) 410 | continue 411 | try: 412 | audio = self.listen() 413 | if audio is None: 414 | questions_queue.put(None) 415 | continue 416 | 417 | if audio.duration_seconds < minimum_prompt_duration_seconds: 418 | continue 419 | 420 | question = self.stt(speech=audio).text 421 | 422 | # Check for the exit expressions 423 | if any( 424 | _get_lower_alphanumeric(question).startswith( 425 | _get_lower_alphanumeric(expr) 426 | ) 427 | for expr in self.exit_expressions 428 | ): 429 | questions_queue.put(None) 430 | elif question: 431 | questions_queue.put(question) 432 | except sd.PortAudioError as error: 433 | logger.opt(exception=True).debug(error) 434 | except Exception as error: # noqa: BLE001 435 | logger.opt(exception=True).debug(error) 436 | logger.error(error) 437 | 438 | def handle_play_speech_queue(self, play_speech_queue: queue.Queue[TextToSpeech]): 439 | """Handle the queue of audio segments to be played.""" 440 | while not self.exit_chat.is_set(): 441 | try: 442 | play_speech_queue_item = play_speech_queue.get() 443 | if play_speech_queue_item["speech"] and not self.interrupt_reply.is_set(): 444 | self.speak(play_speech_queue_item["tts_obj"]) 445 | except Exception as error: # noqa: BLE001, PERF203 446 | logger.exception(error) 447 | finally: 448 | play_speech_queue.task_done() 449 | 450 | def handle_tts_conversion_queue(self, tts_conversion_queue: queue.Queue): 451 | """Handle the text-to-speech queue.""" 452 | logger.debug("Chat {}: TTS conversion handler started.", self.id) 453 | while not self.exit_chat.is_set(): 454 | try: 455 | tts_entry = tts_conversion_queue.get() 456 | if tts_entry["text"] is None: 457 | # Signal that the current anwer is finished 458 | play_speech_queue_item = { 459 | "exchange_id": tts_entry["exchange_id"], 460 | "speech": None, 461 | } 462 | self.play_speech_queue.put(play_speech_queue_item) 463 | self.current_answer_audios_queue.put(play_speech_queue_item) 464 | 465 | logger.debug( 466 | "Reply ID {} notified that is has finished", 467 | tts_entry["exchange_id"], 468 | ) 469 | tts_conversion_queue.task_done() 470 | continue 471 | 472 | text = tts_entry["text"].strip() 473 | if text and not self.interrupt_reply.is_set(): 474 | logger.debug( 475 | "Reply ID {}: received text '{}' for TTS", 476 | tts_entry["exchange_id"], 477 | text, 478 | ) 479 | 480 | tts_obj = self.tts(text) 481 | # Trigger the TTS conversion 482 | _ = tts_obj.speech 483 | 484 | logger.debug( 485 | "Reply ID {}: Sending speech for '{}' to the playing queue", 486 | tts_entry["exchange_id"], 487 | text, 488 | ) 489 | play_speech_queue_item = { 490 | "exchange_id": tts_entry["exchange_id"], 491 | "tts_obj": tts_obj, 492 | "speech": tts_obj.speech, 493 | } 494 | self.play_speech_queue.put(play_speech_queue_item) 495 | self.current_answer_audios_queue.put(play_speech_queue_item) 496 | 497 | # Pay attention to the indentation level 498 | tts_conversion_queue.task_done() 499 | 500 | except Exception as error: # noqa: BLE001 501 | logger.opt(exception=True).debug(error) 502 | logger.error(error) 503 | logger.error("TTS conversion queue handler ended.") 504 | 505 | def get_sound_file(self, wav_buffer: io.BytesIO, mode: str = "r"): 506 | """Return a sound file object.""" 507 | return sf.SoundFile( 508 | wav_buffer, 509 | mode=mode, 510 | samplerate=self.sample_rate, 511 | channels=1, 512 | 
format="wav", 513 | subtype="PCM_16", 514 | ) 515 | 516 | def audio_cache_dir(self): 517 | """Return the audio cache directory.""" 518 | directory = self.cache_dir / "audio_files" 519 | directory.mkdir(parents=True, exist_ok=True) 520 | return directory 521 | 522 | def _assistant_still_replying(self): 523 | """Check if the assistant is still talking.""" 524 | return ( 525 | self.mixer.get_busy() 526 | or self.questions_queue.unfinished_tasks > 0 527 | or self.tts_conversion_queue.unfinished_tasks > 0 528 | or self.play_speech_queue.unfinished_tasks > 0 529 | ) 530 | 531 | 532 | def _check_needed_imports(): 533 | """Check if the needed modules are available.""" 534 | if not _sounddevice_imported: 535 | logger.warning( 536 | "Module `sounddevice`, needed for local audio recording, is not available." 537 | ) 538 | 539 | if not _pydub_usable: 540 | logger.error( 541 | "Module `pydub`, needed for audio conversion, doesn't seem to be working. " 542 | "Voice chat may not be available or may not work as expected." 543 | ) 544 | 545 | 546 | def _np_array_to_wav_in_memory( 547 | sound_data: np.ndarray, sample_rate: int, subtype="PCM_16" 548 | ): 549 | """Convert the recorded array to an in-memory wav file.""" 550 | wav_buffer = io.BytesIO() 551 | wav_buffer.name = "audio.wav" 552 | sf.write(wav_buffer, sound_data, sample_rate, subtype=subtype) 553 | wav_buffer.seek(44) # Skip the WAV header 554 | return wav_buffer.read() 555 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import io 2 | 3 | import lorem 4 | import numpy as np 5 | import pydub 6 | import pytest 7 | from _pytest.logging import LogCaptureFixture 8 | from loguru import logger 9 | 10 | import pyrobbot 11 | from pyrobbot.chat import Chat 12 | from pyrobbot.chat_configs import ChatOptions, VoiceChatConfigs 13 | from pyrobbot.voice_chat import VoiceChat 14 | 15 | 16 | @pytest.fixture() 17 | def caplog(caplog: LogCaptureFixture): 18 | """Override the default `caplog` fixture to propagate Loguru to the caplog handler.""" 19 | # Source: 21 | handler_id = logger.add( 22 | caplog.handler, 23 | format="{message}", 24 | level=0, 25 | filter=lambda record: record["level"].no >= caplog.handler.level, 26 | enqueue=False, # Set to 'True' if your test is spawning child processes. 
27 | ) 28 | yield caplog 29 | logger.remove(handler_id) 30 | 31 | 32 | # Register markers and constants 33 | def pytest_configure(config): 34 | config.addinivalue_line( 35 | "markers", 36 | "no_chat_completion_create_mocking: do not mock openai.ChatCompletion.create", 37 | ) 38 | config.addinivalue_line( 39 | "markers", 40 | "no_embedding_create_mocking: mark test to not mock openai.Embedding.create", 41 | ) 42 | 43 | pytest.original_package_cache_directory = ( 44 | pyrobbot.GeneralDefinitions.PACKAGE_CACHE_DIRECTORY 45 | ) 46 | 47 | 48 | @pytest.fixture(autouse=True) 49 | def _set_env(monkeypatch): 50 | # Make sure we don't consume our tokens in tests 51 | monkeypatch.setenv("OPENAI_API_KEY", "INVALID_API_KEY") 52 | monkeypatch.setenv("STREAMLIT_SERVER_HEADLESS", "true") 53 | 54 | 55 | @pytest.fixture(autouse=True) 56 | def _mocked_general_constants(tmp_path, mocker): 57 | mocker.patch( 58 | "pyrobbot.GeneralDefinitions.PACKAGE_CACHE_DIRECTORY", tmp_path / "cache" 59 | ) 60 | 61 | 62 | @pytest.fixture() 63 | def mock_wav_bytes_string(): 64 | """Mock a WAV file as a bytes string.""" 65 | return ( 66 | b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x00\x04\x00" 67 | b"\x00\x00\x04\x00\x00\x01\x00\x08\x00data\x00\x00\x00\x00" 68 | ) 69 | 70 | 71 | @pytest.fixture(autouse=True) 72 | def _openai_api_request_mockers(request, mocker): 73 | """Mockers for OpenAI API requests. We don't want to consume our tokens in tests.""" 74 | 75 | def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001 76 | """Mock `openai.ChatCompletion.create`. Yield from lorem ipsum instead.""" 77 | completion_chunk = type("CompletionChunk", (), {}) 78 | completion_chunk_choice = type("CompletionChunkChoice", (), {}) 79 | completion_chunk_choice_delta = type("CompletionChunkChoiceDelta", (), {}) 80 | for word in lorem.get_paragraph().split(): 81 | completion_chunk_choice_delta.content = word + " " 82 | completion_chunk_choice.delta = completion_chunk_choice_delta 83 | completion_chunk.choices = [completion_chunk_choice] 84 | yield completion_chunk 85 | 86 | # Yield some code as well, to test the code filtering 87 | code_path = pyrobbot.GeneralDefinitions.PACKAGE_DIRECTORY / "__init__.py" 88 | for word in [ 89 | "```python\n", 90 | *code_path.read_text().splitlines(keepends=True)[:5], 91 | "```\n", 92 | ]: 93 | completion_chunk_choice_delta.content = word + " " 94 | completion_chunk_choice.delta = completion_chunk_choice_delta 95 | completion_chunk.choices = [completion_chunk_choice] 96 | yield completion_chunk 97 | 98 | def _mock_openai_embedding_create(*args, **kwargs): # noqa: ARG001 99 | """Mock `openai.Embedding.create`. 
Yield from lorem ipsum instead.""" 100 | embedding_request_mock_type = type("EmbeddingRequest", (), {}) 101 | embedding_mock_type = type("Embedding", (), {}) 102 | usage_mock_type = type("Usage", (), {}) 103 | 104 | embedding = embedding_mock_type() 105 | embedding.embedding = np.random.rand(512).tolist() 106 | embedding_request = embedding_request_mock_type() 107 | embedding_request.data = [embedding] 108 | 109 | usage = usage_mock_type() 110 | usage.prompt_tokens = 0 111 | usage.total_tokens = 0 112 | embedding_request.usage = usage 113 | 114 | return embedding_request 115 | 116 | if "no_chat_completion_create_mocking" not in request.keywords: 117 | mocker.patch( 118 | "openai.resources.chat.completions.Completions.create", 119 | new=_mock_openai_chat_completion_create, 120 | ) 121 | if "no_embedding_create_mocking" not in request.keywords: 122 | mocker.patch( 123 | "openai.resources.embeddings.Embeddings.create", 124 | new=_mock_openai_embedding_create, 125 | ) 126 | 127 | 128 | @pytest.fixture(autouse=True) 129 | def _internet_search_mockers(mocker): 130 | """Mockers for the internet search module.""" 131 | mocker.patch("duckduckgo_search.DDGS.text", return_value=lorem.get_paragraph()) 132 | 133 | 134 | @pytest.fixture() 135 | def _input_builtin_mocker(mocker, user_input): 136 | """Mock the `input` builtin. Raise `KeyboardInterrupt` after the second call.""" 137 | 138 | # We allow two calls in order to allow for the chat context handler to kick in 139 | def _mock_input(*args, **kwargs): # noqa: ARG001 140 | try: 141 | _mock_input.execution_counter += 1 142 | except AttributeError: 143 | _mock_input.execution_counter = 0 144 | if _mock_input.execution_counter > 1: 145 | raise KeyboardInterrupt 146 | return user_input 147 | 148 | mocker.patch( # noqa: PT008 149 | "builtins.input", new=lambda _: _mock_input(user_input=user_input) 150 | ) 151 | 152 | 153 | @pytest.fixture(params=ChatOptions.get_allowed_values("model")[:2]) 154 | def llm_model(request): 155 | return request.param 156 | 157 | 158 | context_model_values = ChatOptions.get_allowed_values("context_model") 159 | 160 | 161 | @pytest.fixture(params=[context_model_values[0], context_model_values[2]]) 162 | def context_model(request): 163 | return request.param 164 | 165 | 166 | @pytest.fixture() 167 | def default_chat_configs(llm_model, context_model): 168 | return ChatOptions(model=llm_model, context_model=context_model) 169 | 170 | 171 | @pytest.fixture() 172 | def default_voice_chat_configs(llm_model, context_model): 173 | return VoiceChatConfigs(model=llm_model, context_model=context_model) 174 | 175 | 176 | @pytest.fixture() 177 | def cli_args_overrides(default_chat_configs): 178 | args = [] 179 | for field, value in default_chat_configs.model_dump().items(): 180 | if value not in [None, True, False]: 181 | args = [*args, *[f"--{field.replace('_', '-')}", str(value)]] 182 | return args 183 | 184 | 185 | @pytest.fixture() 186 | def default_chat(default_chat_configs): 187 | return Chat(configs=default_chat_configs) 188 | 189 | 190 | @pytest.fixture() 191 | def default_voice_chat(default_voice_chat_configs): 192 | chat = VoiceChat(configs=default_voice_chat_configs) 193 | chat.inactivity_timeout_seconds = 1e-5 194 | chat.tts_engine = "google" 195 | return chat 196 | 197 | 198 | @pytest.fixture(autouse=True) 199 | def _voice_chat_mockers(mocker, mock_wav_bytes_string): 200 | """Mockers for the text-to-speech module.""" 201 | mocker.patch( 202 | "pyrobbot.voice_chat.VoiceChat._assistant_still_replying", return_value=False 203 | ) 
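    # The stand-in objects below replace gTTS objects and the OpenAI TTS/STT
    # responses, so the voice-chat tests never hit the network or spend API tokens.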
204 | 205 | mock_google_tts_obj = type("mock_gTTS", (), {}) 206 | mock_openai_tts_response = type("mock_openai_tts_response", (), {}) 207 | 208 | def _mock_iter_bytes(*args, **kwargs): # noqa: ARG001 209 | return [mock_wav_bytes_string] 210 | 211 | mock_openai_tts_response.iter_bytes = _mock_iter_bytes 212 | 213 | mocker.patch( 214 | "pydub.AudioSegment.from_mp3", 215 | return_value=pydub.AudioSegment.from_wav(io.BytesIO(mock_wav_bytes_string)), 216 | ) 217 | mocker.patch("gtts.gTTS", return_value=mock_google_tts_obj) 218 | mocker.patch( 219 | "openai.resources.audio.speech.Speech.create", 220 | return_value=mock_openai_tts_response, 221 | ) 222 | mock_transcription = type("MockTranscription", (), {}) 223 | mock_transcription.text = "patched" 224 | mocker.patch( 225 | "openai.resources.audio.transcriptions.Transcriptions.create", 226 | return_value=mock_transcription, 227 | ) 228 | mocker.patch( 229 | "speech_recognition.Recognizer.recognize_google", 230 | return_value=mock_transcription.text, 231 | ) 232 | 233 | mocker.patch("webrtcvad.Vad.is_speech", return_value=False) 234 | mocker.patch("pygame.mixer.init") 235 | mocker.patch("chime.play_wav") 236 | mocker.patch("chime.play_wav") 237 | -------------------------------------------------------------------------------- /tests/smoke/test_app.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import streamlit 4 | import streamlit_webrtc.component 5 | 6 | from pyrobbot.app import app 7 | 8 | 9 | def test_app(mocker, default_voice_chat_configs): 10 | class MockAttrDict(streamlit.runtime.state.session_state_proxy.SessionStateProxy): 11 | def __getattr__(self, attr): 12 | return self.get(attr, mocker.MagicMock()) 13 | 14 | def __getitem__(self, key): 15 | with contextlib.suppress(KeyError): 16 | return super().__getitem__(key) 17 | return mocker.MagicMock() 18 | 19 | mocker.patch.object(streamlit, "session_state", new=MockAttrDict()) 20 | mocker.patch.object( 21 | streamlit.runtime.state.session_state_proxy, 22 | "SessionStateProxy", 23 | new=MockAttrDict, 24 | ) 25 | mocker.patch("streamlit.chat_input", return_value="foobar") 26 | mocker.patch( 27 | "pyrobbot.chat_configs.VoiceChatConfigs.from_file", 28 | return_value=default_voice_chat_configs, 29 | ) 30 | mocker.patch.object( 31 | streamlit_webrtc.component, 32 | "webrtc_streamer", 33 | mocker.MagicMock(return_value=mocker.MagicMock()), 34 | ) 35 | 36 | mocker.patch("streamlit.number_input", return_value=0) 37 | 38 | mocker.patch( 39 | "pyrobbot.chat_configs.VoiceChatConfigs.model_validate", 40 | return_value=default_voice_chat_configs, 41 | ) 42 | 43 | app.run_app() 44 | -------------------------------------------------------------------------------- /tests/smoke/test_commands.py: -------------------------------------------------------------------------------- 1 | import io 2 | import subprocess 3 | 4 | import pytest 5 | from pydub import AudioSegment 6 | 7 | from pyrobbot.__main__ import main 8 | from pyrobbot.argparse_wrapper import get_parsed_args 9 | 10 | 11 | def test_default_command(): 12 | args = get_parsed_args(argv=[]) 13 | assert args.command == "ui" 14 | 15 | 16 | @pytest.mark.usefixtures("_input_builtin_mocker") 17 | @pytest.mark.parametrize("user_input", ["Hi!", ""], ids=["regular-input", "empty-input"]) 18 | def test_terminal_command(cli_args_overrides): 19 | args = ["terminal", "--report-accounting-when-done", *cli_args_overrides] 20 | args = list(dict.fromkeys(args)) 21 | main(args) 22 | 23 | 24 | def 
test_accounting_command(): 25 | main(["accounting"]) 26 | 27 | 28 | def test_ui_command(mocker, caplog): 29 | original_run = subprocess.run 30 | 31 | def new_run(*args, **kwargs): 32 | kwargs.pop("timeout", None) 33 | try: 34 | original_run(*args, **kwargs, timeout=0.5) 35 | except subprocess.TimeoutExpired as error: 36 | raise KeyboardInterrupt from error 37 | 38 | mocker.patch("subprocess.run", new=new_run) 39 | main(["ui"]) 40 | assert "Exiting." in caplog.text 41 | 42 | 43 | @pytest.mark.parametrize("stt", ["google", "openai"]) 44 | @pytest.mark.parametrize("tts", ["google", "openai"]) 45 | def test_voice_chat(mocker, mock_wav_bytes_string, tts, stt): 46 | # We allow even number of calls in order to let the function be tested first and 47 | # then terminate the chat 48 | def _mock_listen(*args, **kwargs): # noqa: ARG001 49 | try: 50 | _mock_listen.execution_counter += 1 51 | except AttributeError: 52 | _mock_listen.execution_counter = 0 53 | if _mock_listen.execution_counter % 2: 54 | return None 55 | return AudioSegment.from_wav(io.BytesIO(mock_wav_bytes_string)) 56 | 57 | mocker.patch("pyrobbot.voice_chat.VoiceChat.listen", _mock_listen) 58 | main(["voice", "--tts", tts, "--stt", stt]) 59 | -------------------------------------------------------------------------------- /tests/unit/test_chat.py: -------------------------------------------------------------------------------- 1 | import openai 2 | import pytest 3 | 4 | from pyrobbot import GeneralDefinitions 5 | from pyrobbot.chat import Chat 6 | from pyrobbot.chat_configs import ChatOptions 7 | 8 | 9 | @pytest.mark.order(1) 10 | @pytest.mark.usefixtures("_input_builtin_mocker") 11 | @pytest.mark.no_chat_completion_create_mocking() 12 | @pytest.mark.parametrize("user_input", ["regular-input"]) 13 | def testbed_doesnt_actually_connect_to_openai(caplog): 14 | llm = ChatOptions.get_allowed_values("model")[0] 15 | context_model = ChatOptions.get_allowed_values("context_model")[0] 16 | chat_configs = ChatOptions(model=llm, context_model=context_model) 17 | chat = Chat(configs=chat_configs) 18 | 19 | chat.start() 20 | success = chat.response_failure_message().content in caplog.text 21 | 22 | err_msg = "Refuse to continue: Testbed is trying to connect to OpenAI API!" 23 | err_msg += f"\nThis is what the logger says:\n{caplog.text}" 24 | if not success: 25 | pytest.exit(err_msg) 26 | 27 | 28 | @pytest.mark.order(2) 29 | def test_we_are_using_tmp_cachedir(): 30 | try: 31 | assert ( 32 | pytest.original_package_cache_directory 33 | != GeneralDefinitions.PACKAGE_CACHE_DIRECTORY 34 | ) 35 | 36 | except AssertionError: 37 | pytest.exit( 38 | "Refuse to continue: Tests attempted to use the package's real cache dir " 39 | + f"({GeneralDefinitions.PACKAGE_CACHE_DIRECTORY})!" 
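            # pytest.exit aborts the entire test session: carrying on after this
            # check fails could let tests touch the user's real package cache.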
40 | ) 41 | 42 | 43 | @pytest.mark.usefixtures("_input_builtin_mocker") 44 | @pytest.mark.parametrize("user_input", ["Hi!", ""], ids=["regular-input", "empty-input"]) 45 | def test_terminal_chat(default_chat): 46 | default_chat.start() 47 | default_chat.__del__() # Just to trigger testing the custom del method 48 | 49 | 50 | def test_chat_configs(default_chat, default_chat_configs): 51 | assert default_chat._passed_configs == default_chat_configs 52 | 53 | 54 | @pytest.mark.no_chat_completion_create_mocking() 55 | @pytest.mark.usefixtures("_input_builtin_mocker") 56 | @pytest.mark.parametrize("user_input", ["regular-input"]) 57 | def test_request_timeout_retry(mocker, default_chat, caplog): 58 | def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001 59 | raise openai.APITimeoutError("Mocked timeout error was not caught!") 60 | 61 | mocker.patch( 62 | "openai.resources.chat.completions.Completions.create", 63 | new=_mock_openai_chat_completion_create, 64 | ) 65 | mocker.patch("time.sleep") # Don't waste time sleeping in tests 66 | default_chat.start() 67 | assert "APITimeoutError" in caplog.text 68 | 69 | 70 | def test_can_read_chat_from_cache(default_chat): 71 | default_chat.save_cache() 72 | new_chat = Chat.from_cache(default_chat.cache_dir) 73 | assert new_chat.configs == default_chat.configs 74 | 75 | 76 | def test_create_from_cache_returns_default_chat_if_invalid_cachedir(default_chat, caplog): 77 | _ = Chat.from_cache(default_chat.cache_dir / "foobar") 78 | assert "Creating Chat with default configs" in caplog.text 79 | 80 | 81 | @pytest.mark.usefixtures("_input_builtin_mocker") 82 | @pytest.mark.parametrize("user_input", ["regular-input"]) 83 | def test_internet_search_can_be_triggered(default_chat, mocker): 84 | mocker.patch( 85 | "pyrobbot.openai_utils.make_api_chat_completion_call", return_value=iter(["yes"]) 86 | ) 87 | mocker.patch("pyrobbot.chat.Chat.respond_system_prompt", return_value=iter(["yes"])) 88 | mocker.patch( 89 | "pyrobbot.internet_utils.raw_websearch", 90 | return_value=iter( 91 | [ 92 | { 93 | "href": "foo/bar", 94 | "summary": 50 * "foo ", 95 | "detailed": 50 * "foo ", 96 | "relevance": 1.0, 97 | } 98 | ] 99 | ), 100 | ) 101 | default_chat.start() 102 | -------------------------------------------------------------------------------- /tests/unit/test_internet_utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | 3 | import duckduckgo_search 4 | 5 | from pyrobbot.internet_utils import websearch 6 | 7 | # if called inside tests or fixtures. Leave it like this for now. 
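# (The search below therefore runs once, at import time; test_websearch only checks
# whatever results were cached then.)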
8 | search_results = []
9 | with contextlib.suppress(duckduckgo_search.exceptions.DuckDuckGoSearchException):
10 |     search_results = list(websearch("foobar"))
11 | 
12 | 
13 | def test_websearch():
14 |     for i_result, result in enumerate(search_results):
15 |         assert isinstance(result, dict)
16 |         assert ("detailed" in result) == (i_result == 0)
17 |         for key in ["summary", "relevance", "href"]:
18 |             assert key in result
19 | 
--------------------------------------------------------------------------------
/tests/unit/test_text_to_speech.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydub import AudioSegment
3 | 
4 | from pyrobbot.sst_and_tts import SpeechToText
5 | 
6 | 
7 | @pytest.mark.parametrize("stt_engine", ["google", "openai"])
8 | def test_stt(default_voice_chat, stt_engine):
9 |     """Test the speech-to-text method."""
10 |     default_voice_chat.stt_engine = stt_engine
11 |     stt = SpeechToText(
12 |         openai_client=default_voice_chat.openai_client,
13 |         speech=AudioSegment.silent(duration=100),
14 |         engine=stt_engine,
15 |         general_token_usage_db=default_voice_chat.general_token_usage_db,
16 |         token_usage_db=default_voice_chat.token_usage_db,
17 |     )
18 |     assert stt.text == "patched"
19 | 
--------------------------------------------------------------------------------
/tests/unit/test_voice_chat.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | 
3 | import pytest
4 | from pydantic import ValidationError
5 | from pydub import AudioSegment
6 | from sounddevice import PortAudioError
7 | 
8 | from pyrobbot.chat_configs import VoiceChatConfigs
9 | from pyrobbot.sst_and_tts import TextToSpeech
10 | from pyrobbot.voice_chat import VoiceChat
11 | 
12 | 
13 | def test_soundcard_import_check(mocker, caplog):
14 |     """Test the warning logged when the `sounddevice` module is not available."""
15 |     mocker.patch("pyrobbot.voice_chat._sounddevice_imported", False)
16 |     _ = VoiceChat(configs=VoiceChatConfigs())
17 |     msg = "Module `sounddevice`, needed for local audio recording, is not available."
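    # The VoiceChat instance is still created; the test only expects the warning in
    # the captured logs.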
18 |     assert msg in caplog.text
19 | 
20 | 
21 | @pytest.mark.parametrize("param_name", ["sample_rate", "frame_duration"])
22 | def test_cannot_instantiate_assistant_with_invalid_webrtcvad_params(param_name):
23 |     """Test that the voice chat cannot be instantiated with invalid webrtcvad params."""
24 |     with pytest.raises(ValidationError, match="Input should be"):
25 |         VoiceChat(configs=VoiceChatConfigs(**{param_name: 1}))
26 | 
27 | 
28 | def test_listen(default_voice_chat):
29 |     """Test the listen method."""
30 |     with contextlib.suppress(PortAudioError, pytest.PytestUnraisableExceptionWarning):
31 |         default_voice_chat.listen()
32 | 
33 | 
34 | def test_speak(default_voice_chat, mocker):
35 |     tts = TextToSpeech(
36 |         openai_client=default_voice_chat.openai_client,
37 |         text="foo",
38 |         general_token_usage_db=default_voice_chat.general_token_usage_db,
39 |         token_usage_db=default_voice_chat.token_usage_db,
40 |     )
41 |     mocker.patch("pygame.mixer.Sound")
42 |     mocker.patch("pyrobbot.voice_chat._get_lower_alphanumeric", return_value="ok cancel")
43 |     mocker.patch(
44 |         "pyrobbot.voice_chat.VoiceChat.listen",
45 |         return_value=AudioSegment.silent(duration=150),
46 |     )
47 |     default_voice_chat.speak(tts)
48 | 
49 | 
50 | def test_answer_question(default_voice_chat):
51 |     default_voice_chat.answer_question("foo")
52 | 
53 | 
54 | def test_interrupt_reply(default_voice_chat):
55 |     default_voice_chat.interrupt_reply.set()
56 |     default_voice_chat.questions_queue.get = lambda: None
57 |     default_voice_chat.questions_queue.task_done = lambda: None
58 |     default_voice_chat.start()
59 | 
60 | 
61 | def test_handle_interrupt_expressions(default_voice_chat, mocker):
62 |     mocker.patch("pyrobbot.general_utils.str2_minus_str1", return_value="cancel")
63 |     default_voice_chat.questions_queue.get = lambda: None
64 |     default_voice_chat.questions_queue.task_done = lambda: None
65 |     default_voice_chat.questions_queue.answer_question = lambda _question: None
66 |     msgs_to_compare = {
67 |         "assistant_txt": "foo",
68 |         "user_audio": AudioSegment.silent(duration=150),
69 |     }
70 |     default_voice_chat.check_for_interrupt_expressions_queue.put(msgs_to_compare)
71 |     default_voice_chat.start()
72 | 
--------------------------------------------------------------------------------
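
For a quick orientation on how the token-accounting utilities in pyrobbot/tokens.py fit
together, here is a minimal, self-contained sketch. It uses only names defined in that
file; the database path, model name, and token counts are illustrative values, not
anything taken from the project itself:

    from pathlib import Path

    from pyrobbot.tokens import TokenUsageDatabase

    # Open (or create) a usage database at a scratch location (illustrative path).
    db = TokenUsageDatabase(fpath=Path("/tmp/pyrobbot-demo/token_usage.db"))

    # Record one hypothetical exchange: 120 prompt tokens in, 250 completion tokens
    # out. Costs are derived from the PRICE_PER_K_TOKENS table at insertion time.
    db.insert_data(model="gpt-3.5-turbo", n_input_tokens=120, n_output_tokens=250)

    # Retrieve the accumulated estimates as a pandas DataFrame with grouped columns
    # and a "Total" row appended.
    usage_df = db.get_usage_balance_dataframe()
    print(usage_df.attrs["description"])
    print(usage_df)

This insert_data API is the same one that chat.py, openai_utils.py and sst_and_tts.py
call whenever tokens are consumed, as seen in the modules above.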