├── .flakeheaven.toml
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   └── workflows
│       ├── linting.yaml
│       └── tests.yaml
├── .gitignore
├── LICENSE.md
├── README.md
├── packages.txt
├── poetry.toml
├── pyproject.toml
├── pyrobbot
│   ├── __init__.py
│   ├── __main__.py
│   ├── app
│   │   ├── .streamlit
│   │   │   └── config.toml
│   │   ├── __init__.py
│   │   ├── app.py
│   │   ├── app_page_templates.py
│   │   ├── app_utils.py
│   │   ├── data
│   │   │   ├── assistant_avatar.png
│   │   │   ├── powered-by-openai-badge-outlined-on-dark.svg
│   │   │   ├── success.wav
│   │   │   ├── user_avatar.png
│   │   │   └── warning.wav
│   │   └── multipage.py
│   ├── argparse_wrapper.py
│   ├── chat.py
│   ├── chat_configs.py
│   ├── chat_context.py
│   ├── command_definitions.py
│   ├── embeddings_database.py
│   ├── general_utils.py
│   ├── internet_utils.py
│   ├── openai_utils.py
│   ├── sst_and_tts.py
│   ├── tokens.py
│   └── voice_chat.py
└── tests
    ├── conftest.py
    ├── smoke
    │   ├── test_app.py
    │   └── test_commands.py
    └── unit
        ├── test_chat.py
        ├── test_internet_utils.py
        ├── test_text_to_speech.py
        └── test_voice_chat.py
/.flakeheaven.toml:
--------------------------------------------------------------------------------
1 | [tool.flakeheaven]
2 | exclude = [".*/", "tmp/", "*/tmp/", "*.ipynb"]
3 | format = "colored"
4 | # Show line of source code in output, with syntax highlighting
5 | show_source = true
6 | style = "google"
7 |
8 | # list of plugins and rules for them
9 | [tool.flakeheaven.plugins]
10 | # Deactivate all rules for all plugins by default
11 | "*" = ["-*"]
12 | # Activate only those plugins not covered by ruff
13 | pydoclint = [
14 | "+*",
15 | "-DOC105",
16 | "-DOC106",
17 | "-DOC107",
18 | "-DOC109",
19 | "-DOC110",
20 | "-DOC203",
21 | "-DOC301",
22 | "-DOC403",
23 | "-DOC404",
24 | ]
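# These plugins run via the `_flake8` poe task defined in pyproject.toml (`flakeheaven lint .`)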
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Desktop (please complete the following information):**
27 | - OS: [e.g. iOS]
28 | - Browser [e.g. chrome, safari]
29 | - Version [e.g. 22]
30 |
31 | **Smartphone (please complete the following information):**
32 | - Device: [e.g. iPhone6]
33 | - OS: [e.g. iOS8.1]
34 | - Browser [e.g. stock browser, safari]
35 | - Version [e.g. 22]
36 |
37 | **Additional context**
38 | Add any other context about the problem here.
39 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/workflows/linting.yaml:
--------------------------------------------------------------------------------
1 | #.github/workflows/linting.yaml
2 | name: Linting Checks
3 |
4 | on:
5 | pull_request:
6 | branches:
7 | - main
8 | - develop
9 | paths:
10 | - '**.py'
11 | - '.github/workflows/linting.yaml'
12 | push:
13 | branches:
14 | - '**' # Every branch
15 | paths:
16 | - '**.py'
17 | - '.github/workflows/linting.yaml'
18 |
19 | jobs:
20 | linting:
21 | if: github.repository_owner == 'paulovcmedeiros'
22 | name: Run Linters
23 | runs-on: ubuntu-latest
24 | steps:
25 | #----------------------------------------------
26 | # check-out repo and set-up python
27 | #----------------------------------------------
28 | - name: Check out repository
29 | uses: actions/checkout@v3
30 | - name: Set up python
31 | id: setup-python
32 | uses: actions/setup-python@v4
33 | with:
34 | python-version: '3.9'
35 |
36 | #----------------------------------------------
37 | # --- configure poetry & install project ----
38 | #----------------------------------------------
39 | - name: Install Poetry
40 | uses: snok/install-poetry@v1
41 | with:
42 | virtualenvs-create: true
43 | virtualenvs-in-project: true
44 |
45 | - name: Install poethepoet
46 | run: poetry self add 'poethepoet[poetry_plugin]'
47 |
48 | - name: Load cached venv (if cache exists)
49 | id: cached-poetry-dependencies
50 | uses: actions/cache@v3
51 | with:
52 | path: .venv
53 | key: ${{ github.job }}-venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }}
54 |
55 | - name: Install dependencies (if venv cache is not found)
56 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
57 | run: poetry install --no-interaction --no-root --only main,linting
58 |
59 | - name: Install the project itself
60 | run: poetry install --no-interaction --only-root
61 |
62 | #----------------------------------------------
63 | # Run the linting checks
64 | #----------------------------------------------
65 | - name: Run linters
66 | run: |
67 | poetry devtools lint
68 |
69 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yaml:
--------------------------------------------------------------------------------
1 | #.github/workflows/tests.yaml
2 | name: Unit Tests
3 |
4 | on:
5 | pull_request:
6 | branches:
7 | - main
8 | - develop
9 | push:
10 | branches:
11 | - '**' # Every branch
12 |
13 | jobs:
14 | tests:
15 | if: github.repository_owner == 'paulovcmedeiros'
16 | strategy:
17 | fail-fast: true
18 | matrix:
19 | os: [ "ubuntu-latest" ]
20 | env: [ "pytest" ]
21 | python-version: [ "3.9" ]
22 |
23 | name: "${{ matrix.os }}, python=${{ matrix.python-version }}"
24 | runs-on: ${{ matrix.os }}
25 |
26 | container:
27 | image: python:${{ matrix.python-version }}-bullseye
28 | env:
29 | COVERAGE_FILE: ".coverage.${{ matrix.env }}.${{ matrix.python-version }}"
30 |
31 | steps:
32 | #----------------------------------------------
33 | # check-out repo
34 | #----------------------------------------------
35 | - name: Check out repository
36 | uses: actions/checkout@v3
37 |
38 | #----------------------------------------------
39 | # Install Audio Gear
40 | #----------------------------------------------
41 | - name: Install PortAudio and PulseAudio
42 | run: |
43 | apt-get update
44 | apt-get --assume-yes install portaudio19-dev python-all-dev pulseaudio ffmpeg
45 |
46 | #----------------------------------------------
47 | # --- configure poetry & install project ----
48 | #----------------------------------------------
49 | - name: Install Poetry
50 | uses: snok/install-poetry@v1
51 | with:
52 | virtualenvs-create: true
53 | virtualenvs-in-project: true
54 |
55 | - name: Load cached venv (if cache exists)
56 | id: cached-poetry-dependencies
57 | uses: actions/cache@v3
58 | with:
59 | path: .venv
60 | key: ${{ github.job }}-venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml', '**/poetry.toml') }}
61 |
62 | - name: Install dependencies (if venv cache is not found)
63 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
64 | run: poetry install --no-interaction --no-root --only main,test
65 |
66 | - name: Install the project itself
67 | run: poetry install --no-interaction --only-root
68 |
69 | #----------------------------------------------
70 | # run test suite and report coverage
71 | #----------------------------------------------
72 | - name: Run tests
73 | env:
74 | SDL_VIDEODRIVER: "dummy"
75 | SDL_AUDIODRIVER: "disk"
76 | run: |
77 | poetry run pytest
78 |
79 | - name: Upload test coverage report to Codecov
80 | uses: codecov/codecov-action@v3
81 | with:
82 | token: ${{ secrets.CODECOV_TOKEN }}
83 | files: ./.coverage.xml
84 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | # Vim
163 | *.swp
164 |
165 | # Temporary files and directories
166 | tmp/
167 |
168 | # Linting artifacts
169 | .flakeheaven_cache/
170 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Paulo V. C. Medeiros
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # [pyRobBot](https://github.com/paulovcmedeiros/pyRobBot): Chat with GPT LLMs over voice, text or both. All with access to the internet.
4 |
5 |
6 | [Downloads](https://www.pepy.tech/projects/pyrobbot)
7 | [PyPI](https://pypi.org/project/pyrobbot/)
8 | [Demo app](https://pyrobbot.streamlit.app)
9 | [Powered by OpenAI](https://openai.com/blog/openai-api)
10 |
11 |
12 | [Poetry](https://python-poetry.org/)
13 | [PRs welcome](https://github.com/paulovcmedeiros/pyRobBot/pulls)
14 | [Linting](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/linting.yaml)
15 | [Tests](https://github.com/paulovcmedeiros/pyRobBot/actions/workflows/tests.yaml)
16 | [Codecov](https://codecov.io/gh/paulovcmedeiros/pyRobBot)
17 |
18 |
19 |
20 | PyRobBot is a python package that uses OpenAI's [GPT large language models (LLMs)](https://platform.openai.com/docs/models) to implement a fully configurable **personal assistant** that, on top of the traditional chatbot interface, can also speak and listen to you using AI-generated **human-like** voices.
21 |
22 |
23 | ## Features
24 |
25 | Features include, but are not limited to:
26 |
27 | - [x] Voice Chat
28 | - Continuous voice input and output
29 | - No need to press a button: the assistant will keep listening until you stop talking
30 |
31 | - [x] Internet access: The assistant will **search the web** to find the answers it doesn't have in its training data
32 | - E.g. latest news, current events, weather forecasts, etc.
33 | - Powered by [DuckDuckGo Search](https://github.com/deedy5/duckduckgo_search)
34 |
35 | - [x] Web browser user interface
36 | - See our [demo app on Streamlit Community Cloud](https://pyrobbot.streamlit.app)
37 | - Voice chat with:
38 | - **Continuous voice input and output** (using [streamlit-webrtc](https://github.com/whitphx/streamlit-webrtc))
39 | - If you prefer, manual on/off toggling of the microphone (using [streamlit_mic_recorder](https://github.com/B4PT0R/streamlit-mic-recorder))
40 | - A familiar text interface integrated with the voice chat, for those who prefer a traditional chatbot experience
41 | - Your voice prompts and the assistant's voice replies are shown as text in the chat window
42 | - You may also send prompts as text even when voice detection is enabled
43 | - Add/remove conversations dynamically
44 | - Automatic/editable conversation summary title
45 | - Autosave & retrieve chat history
46 | - Resume even text & voice conversations started outside the web interface
47 |
48 |
49 | - [x] Chat via terminal
50 | - For a more "Wake up, Neo" experience
51 |
52 | - [x] Fully configurable
53 | - Large number of supported languages (*e.g.*, `rob --lang pt-br`)
54 | - Support for multiple LLMs through the OpenAI API
55 | - Choose your preferred Text-to-Speech (TTS) and Speech-To-Text (STT) engines (google/openai)
56 | - Control over the parameters passed to the OpenAI API, with (hopefully) sensible defaults
57 | - Ability to pass base directives to the LLM
58 | - E.g., to make it adopt a persona, but you decide which directives to pass
59 | - Dynamically modifiable AI parameters in each chat separately
60 | - No need to restart the chat
61 |
62 | - [x] Chat context handling using [embeddings](https://platform.openai.com/docs/guides/embeddings)
63 | - [x] Estimated API token usage and associated costs
64 | - [x] OpenAI API key is **never** stored on disk
65 |
66 |
67 |
68 | ## System Requirements
69 | - Python >= 3.9
70 | - A valid [OpenAI API key](https://platform.openai.com/account/api-keys)
71 | - Set it in the Web UI or through the environment variable `OPENAI_API_KEY` (see the example below this list)
72 | - To enable voice chat, you also need:
73 | - [PortAudio](https://www.portaudio.com/docs/v19-doxydocs/index.html)
74 | - Install on Ubuntu with `sudo apt-get --assume-yes install portaudio19-dev python-all-dev`
75 | - Install on CentOS/RHEL with `sudo yum install portaudio portaudio-devel`
76 | - [ffmpeg](https://ffmpeg.org/download.html)
77 | - Install on Ubuntu with `sudo apt-get --assume-yes install ffmpeg`
78 | - Install on CentOS/RHEL with `sudo yum install ffmpeg`
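
For example, on Linux/macOS you can set the API key for the current shell session with (a sketch; substitute your actual key):
```shell
export OPENAI_API_KEY="your-api-key-here"
```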
79 |
80 | ## Installation
81 | This, naturally, assumes your system fulfills all [requirements](#system-requirements).
82 |
83 | ### Regular Installation
84 | The recommended way for most users.
85 |
86 | #### Using pip
87 | ```shell
88 | pip install pyrobbot
89 | ```
90 | #### From the GitHub repository
91 | ```shell
92 | pip install git+https://github.com/paulovcmedeiros/pyRobBot.git
93 | ```
94 |
95 | ### Developer-Mode Installation
96 | The recommended way for those who want to contribute to the project. We use [poetry](https://python-poetry.org) with the [poethepoet](https://poethepoet.natn.io/index.html) plugin. To get everything set up, run:
97 | ```shell
98 | # Remove any previous install
99 | curl -sSL https://install.python-poetry.org | python3 - --uninstall
100 | rm -rf ${HOME}/.cache/pypoetry/ ${HOME}/.local/bin/poetry ${HOME}/.local/share/pypoetry
101 | # Download and install poetry
102 | curl -sSL https://install.python-poetry.org | python3 -
103 | # Install needed poetry plugin(s)
104 | poetry self add 'poethepoet[poetry_plugin]'
105 | ```
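
Then, from the repository root, you can install the project and run the dev tasks. A sketch (task names taken from the `[tool.poe.tasks]` section of `pyproject.toml`, exposed through the configured `devtools` poetry command):
```shell
poetry install
poetry devtools lint   # run all configured linters
poetry devtools test   # run the test suite
```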
106 |
107 |
108 | ## Basic Usage
109 | Upon successful installation, you should be able to run
110 | ```shell
111 | rob [opts] SUBCOMMAND [subcommand_opts]
112 | ```
113 | where `[opts]` and `[subcommand_opts]` denote optional command line arguments
114 | that apply, respectively, to `rob` in general and to `SUBCOMMAND`
115 | specifically.
116 |
117 | **Please run `rob -h` for information** about the supported subcommands
118 | and general `rob` options. For info about specific subcommands and the
119 | options that apply to them only, **please run `rob SUBCOMMAND -h`** (note
120 | that the `-h` goes after the subcommand in this case).
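
For instance, to start a voice chat in Brazilian Portuguese (using the `--lang` option mentioned under Features; run `rob voice -h` to confirm the options your installed version supports):
```shell
rob voice --lang pt-br
```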
121 |
122 | ### Using the Web UI (default, supports voice & text chat)
123 | ```shell
124 | rob
125 | ```
126 | See also our [demo Streamlit app](https://pyrobbot.streamlit.app)!
127 |
128 | ### Chatting Only by Voice
129 | ```shell
130 | rob voice
131 | ```
132 |
133 | ### Running on the Terminal
134 | ```shell
135 | rob .
136 | ```
137 |
138 | ## Disclaimers
139 | This project's main purpose has been to serve as a learning exercise for me, as well as a tool for experimenting with the OpenAI API, GPT LLMs and text-to-speech/speech-to-text.
140 |
141 | While it does not claim to be the best or most robust OpenAI-powered chatbot out there, it *does* aim to provide a friendly user interface that is easy to install, use and configure.
142 |
143 | Feel free to open an [issue](https://github.com/paulovcmedeiros/pyRobBot/issues) or, even better, [submit a pull request](https://github.com/paulovcmedeiros/pyRobBot/pulls) if you find a bug or have a suggestion.
144 |
145 | Last but not least: This project is **independently developed** and **not** affiliated, endorsed, or sponsored by OpenAI in any way. It is separate and distinct from OpenAI’s own products and services.
146 |
--------------------------------------------------------------------------------
/packages.txt:
--------------------------------------------------------------------------------
1 | ffmpeg
2 | portaudio19-dev
3 | python-all-dev
4 |
--------------------------------------------------------------------------------
/poetry.toml:
--------------------------------------------------------------------------------
1 | [virtualenvs]
2 | create = true
3 | in-project = true
4 | prefer-active-python = true
5 |
6 | [virtualenvs.options]
7 | system-site-packages = false
8 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | authors = ["Paulo V C Medeiros "]
3 | description = """\
4 | Chat with GPT LLMs over voice, text or both. With access to the internet.\
5 | Powered by OpenAI.\
6 | """
7 | license = "MIT"
8 | name = "pyrobbot"
9 | readme = "README.md"
10 | version = "0.7.7"
11 |
12 | [build-system]
13 | build-backend = "poetry.core.masonry.api"
14 | requires = ["poetry-core"]
15 |
16 | [tool.poetry.scripts]
17 | rob = "pyrobbot.__main__:main"
18 |
19 | [tool.poetry.dependencies]
20 | # Python version
21 | python = ">=3.9,<3.9.7 || >3.9.7,<3.13"
22 | # Deps that should have been openai deps
23 | matplotlib = "^3.8.0"
24 | plotly = "^5.18.0"
25 | scikit-learn = "^1.3.2"
26 | scipy = "^1.11.3"
27 | # Other dependencies
28 | loguru = "^0.7.2"
29 | numpy = "^1.26.1"
30 | openai = "^1.13.3"
31 | pandas = "^2.2.0"
32 | pillow = "^10.2.0"
33 | pydantic = "^2.6.1"
34 | streamlit = "^1.31.1"
35 | tiktoken = "^0.6.0"
36 | # Text to speech
37 | audio-recorder-streamlit = "^0.0.8"
38 | beautifulsoup4 = "^4.12.3"
39 | chime = "^0.7.0"
40 | duckduckgo-search = "^5.0"
41 | gtts = "^2.5.1"
42 | httpx = "^0.26.0"
43 | ipinfo = "^5.0.1"
44 | pydub = "^0.25.1"
45 | pygame = "^2.5.2"
46 | setuptools = "^68.2.2" # Needed by webrtcvad-wheels
47 | sounddevice = "^0.4.6"
48 | soundfile = "^0.12.1"
49 | speechrecognition = "^3.10.0"
50 | streamlit-mic-recorder = "^0.0.4"
51 | streamlit-webrtc = "^0.47.1"
52 | twilio = "^9.0.0"
53 | tzlocal = "^5.2"
54 | unidecode = "^1.3.7"
55 | webrtcvad-wheels = "^2.0.11.post1"
56 |
57 | [tool.poetry.group.dev.dependencies]
58 | ipython = "^8.16.1"
59 |
60 | [tool.poetry.group.linting.dependencies]
61 | black = "^24.2.0"
62 | flakeheaven = "^3.3.0"
63 | isort = "^5.13.2"
64 | pydoclint = "^0.4.0"
65 | ruff = "^0.3.0"
66 |
67 | [tool.poetry.group.test.dependencies]
68 | pytest = "^8.0.0"
69 | pytest-cov = "^4.1.0"
70 | pytest-mock = "^3.12.0"
71 | pytest-order = "^1.2.0"
72 | pytest-xdist = "^3.5.0"
73 | python-lorem = "^1.3.0.post1"
74 |
75 | ##################
76 | # Linter configs #
77 | ##################
78 |
79 | [tool.black]
80 | line-length = 90
81 |
82 | [tool.flakeheaven]
83 | base = ".flakeheaven.toml"
84 |
85 | [tool.isort]
86 | line_length = 90
87 | profile = "black"
88 |
89 | [tool.ruff]
90 | line-length = 90
91 |
92 | [tool.ruff.lint]
93 | # C901: Function is too complex. Ignoring this for now but will be removed later.
94 | ignore = ["C901", "D105", "EXE001", "RET504", "RUF012"]
95 | select = [
96 | "A",
97 | "ARG",
98 | "B",
99 | "BLE",
100 | "C4",
101 | "C90",
102 | "D",
103 | "E",
104 | "ERA",
105 | "EXE",
106 | "F",
107 | "G",
108 | "I",
109 | "N",
110 | "PD",
111 | "PERF",
112 | "PIE",
113 | "PL",
114 | "PT",
115 | "Q",
116 | "RET",
117 | "RSE",
118 | "RUF",
119 | "S",
120 | "SIM",
121 | "SLF",
122 | "T20",
123 | "W",
124 | ]
125 |
126 | [tool.ruff.lint.per-file-ignores]
127 | # S101: Use of `assert` detected
128 | "tests/**/*.py" = [
129 | "D100",
130 | "D101",
131 | "D102",
132 | "D103",
133 | "D104",
134 | "D105",
135 | "D106",
136 | "D107",
137 | "E501",
138 | "S101",
139 | "SLF001",
140 | ]
141 |
142 | [tool.ruff.lint.pydocstyle]
143 | convention = "google"
144 |
145 | ##################
146 | # pytest configs #
147 | ##################
148 |
149 | [tool.pytest.ini_options]
150 | addopts = """\
151 | -n auto -v --cache-clear \
152 | --failed-first \
153 | --cov-reset --cov-report=term-missing:skip-covered \
154 | --cov-report=xml:.coverage.xml --cov=./ \
155 | """
156 | log_cli_level = "INFO"
157 | testpaths = ["tests/smoke", "tests/unit"]
158 |
159 | [tool.coverage.report]
160 | omit = ["**/app/*"]
161 |
162 | ####################################
163 | # Leave configs for `poe` separate #
164 | ####################################
165 |
166 | [tool.poe]
167 | poetry_command = "devtools"
168 |
169 | [tool.poe.tasks]
170 | _black = "black ."
171 | _isort = "isort ."
172 | _ruff = "ruff check ."
173 | # Test-related tasks
174 | pytest = "pytest"
175 | test = ["pytest"]
176 | # Tasks to be run as pre-push checks
177 | pre-push-checks = ["lint", "pytest"]
178 |
179 | [tool.poe.tasks._flake8]
180 | cmd = "flakeheaven lint ."
181 | env = {FLAKEHEAVEN_CACHE_TIMEOUT = "0"}
182 |
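# Invoked as `poetry devtools lint` (check only) or `poetry devtools lint --fix`
# (auto-fix; assumed from the boolean `fix` arg and the switch cases below)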
183 | [tool.poe.tasks.lint]
184 | args = [{name = "fix", type = "boolean", default = false}]
185 | control = {expr = "fix"}
186 |
187 | [[tool.poe.tasks.lint.switch]]
188 | case = "True"
189 | sequence = ["_isort", "_black", "_ruff --fix", "_flake8"]
190 |
191 | [[tool.poe.tasks.lint.switch]]
192 | case = "False"
193 | sequence = ["_isort --check-only", "_black --check --diff", "_ruff", "_flake8"]
194 |
--------------------------------------------------------------------------------
/pyrobbot/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Unnoficial OpenAI API UI and CLI tool."""
3 | import os
4 | import sys
5 | import tempfile
6 | import uuid
7 | from collections import defaultdict
8 | from dataclasses import dataclass
9 | from importlib.metadata import metadata, version
10 | from pathlib import Path
11 |
12 | import ipinfo
13 | import requests
14 | from loguru import logger
15 |
16 | logger.remove()
17 | logger.add(
18 | sys.stderr,
19 | level=os.environ.get("LOGLEVEL", os.environ.get("LOGURU_LEVEL", "INFO")),
20 | )
21 |
22 | os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "hide"
23 |
24 |
25 | @dataclass
26 | class GeneralDefinitions:
27 | """General definitions for the package."""
28 |
29 | # Main package info
30 | RUN_ID = uuid.uuid4().hex
31 | PACKAGE_NAME = __name__
32 | VERSION = version(__name__)
33 | PACKAGE_DESCRIPTION = metadata(__name__)["Summary"]
34 |
35 | # Main package directories
36 | PACKAGE_DIRECTORY = Path(__file__).parent
37 | PACKAGE_CACHE_DIRECTORY = Path.home() / ".cache" / PACKAGE_NAME
38 | _PACKAGE_TMPDIR = tempfile.TemporaryDirectory()
39 | PACKAGE_TMPDIR = Path(_PACKAGE_TMPDIR.name)
40 |
41 | # Constants related to the app
42 | APP_NAME = "pyRobBot"
43 | APP_DIR = PACKAGE_DIRECTORY / "app"
44 | APP_PATH = APP_DIR / "app.py"
45 | PARSED_ARGS_FILE = PACKAGE_TMPDIR / f"parsed_args_{RUN_ID}.pkl"
46 |
47 | # Location info
48 | IPINFO = defaultdict(lambda: "unknown")
49 | try:
50 | IPINFO = ipinfo.getHandler().getDetails().all
51 | except (
52 | requests.exceptions.ReadTimeout,
53 | requests.exceptions.ConnectionError,
54 | ipinfo.exceptions.RequestQuotaExceededError,
55 | ) as error:
56 | logger.warning("Cannot get current location info. {}", error)
57 |
--------------------------------------------------------------------------------
/pyrobbot/__main__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Program's entry point."""
3 | from .argparse_wrapper import get_parsed_args
4 |
5 |
6 | def main(argv=None):
7 | """Program's main routine."""
8 | args = get_parsed_args(argv=argv)
9 | args.run_command(args=args)
10 |
--------------------------------------------------------------------------------
/pyrobbot/app/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | # Streamlit configs.
2 | # See the Streamlit configuration docs.
3 | [browser]
4 | gatherUsageStats = false
5 |
6 | [runner]
7 | fastReruns = true
8 |
9 | [server]
10 | runOnSave = true
11 |
12 | [theme]
13 | base = "light"
14 | # Colors
15 | primaryColor = "#2BB5E8"
16 |
--------------------------------------------------------------------------------
/pyrobbot/app/__init__.py:
--------------------------------------------------------------------------------
1 | """UI for the package."""
2 |
--------------------------------------------------------------------------------
/pyrobbot/app/app.py:
--------------------------------------------------------------------------------
1 | """Entrypoint for the package's UI."""
2 |
3 | from pyrobbot import GeneralDefinitions
4 | from pyrobbot.app.multipage import MultipageChatbotApp
5 |
6 |
7 | def run_app():
8 | """Create and run an instance of the pacage's app."""
9 | MultipageChatbotApp(
10 | page_title=GeneralDefinitions.APP_NAME,
11 | page_icon=":speech_balloon:",
12 | layout="wide",
13 | ).render()
14 |
15 |
16 | if __name__ == "__main__":
17 | run_app()
18 |
--------------------------------------------------------------------------------
/pyrobbot/app/app_page_templates.py:
--------------------------------------------------------------------------------
1 | """Utilities for creating pages in a streamlit app."""
2 |
3 | import base64
4 | import contextlib
5 | import datetime
6 | import queue
7 | import time
8 | import uuid
9 | from abc import ABC, abstractmethod
10 | from pathlib import Path
11 | from typing import TYPE_CHECKING, Union
12 |
13 | import streamlit as st
14 | from audio_recorder_streamlit import audio_recorder
15 | from loguru import logger
16 | from pydub import AudioSegment
17 | from pydub.exceptions import CouldntDecodeError
18 | from streamlit_mic_recorder import mic_recorder
19 |
20 | from pyrobbot.chat_configs import VoiceChatConfigs
21 |
22 | from .app_utils import (
23 | AsyncReplier,
24 | WebAppChat,
25 | filter_page_info_from_queue,
26 | get_avatar_images,
27 | load_chime,
28 | )
29 |
30 | if TYPE_CHECKING:
31 | from .multipage import MultipageChatbotApp
32 |
33 | # Sentinel object for when a chat is recovered from cache
34 | _RecoveredChat = object()
35 |
36 |
37 | class AppPage(ABC):
38 | """Abstract base class for a page within a streamlit application."""
39 |
40 | def __init__(
41 | self, parent: "MultipageChatbotApp", sidebar_title: str = "", page_title: str = ""
42 | ):
43 | """Initializes a new instance of the AppPage class.
44 |
45 | Args:
46 | parent (MultipageChatbotApp): The parent app of the page.
47 | sidebar_title (str, optional): The title to be displayed in the sidebar.
48 | Defaults to an empty string.
49 | page_title (str, optional): The title to be displayed on the page.
50 | Defaults to an empty string.
51 | """
52 | self.page_id = str(uuid.uuid4())
53 | self.parent = parent
54 | self.page_number = self.parent.state.get("n_created_pages", 0) + 1
55 |
56 | chat_number_for_title = f"Chat #{self.page_number}"
57 | if page_title is _RecoveredChat:
58 | self.fallback_page_title = f"{chat_number_for_title.strip('#')} (Recovered)"
59 | page_title = None
60 | else:
61 | self.fallback_page_title = chat_number_for_title
62 | if page_title:
63 | self.title = page_title
64 |
65 | self._fallback_sidebar_title = page_title if page_title else chat_number_for_title
66 | if sidebar_title:
67 | self.sidebar_title = sidebar_title
68 |
69 | @property
70 | def state(self):
71 | """Return the state of the page, for persistence of data."""
72 | if self.page_id not in self.parent.state:
73 | self.parent.state[self.page_id] = {}
74 | return self.parent.state[self.page_id]
75 |
76 | @property
77 | def sidebar_title(self):
78 | """Get the title of the page in the sidebar."""
79 | return self.state.get("sidebar_title", self._fallback_sidebar_title)
80 |
81 | @sidebar_title.setter
82 | def sidebar_title(self, value: str):
83 | """Set the sidebar title for the page."""
84 | self.state["sidebar_title"] = value
85 |
86 | @property
87 | def title(self):
88 | """Get the title of the page."""
89 | return self.state.get("page_title", self.fallback_page_title)
90 |
91 | @title.setter
92 | def title(self, value: str):
93 | """Set the title of the page."""
94 | self.state["page_title"] = value
95 |
96 | @abstractmethod
97 | def render(self):
98 | """Create the page."""
99 |
100 | def continuous_mic_recorder(self):
101 | """Record audio from the microphone in a continuous loop."""
102 | audio_bytes = audio_recorder(
103 | text="", icon_size="2x", energy_threshold=-1, key=f"AR_{self.page_id}"
104 | )
105 |
106 | if audio_bytes is None:
107 | return AudioSegment.silent(duration=0)
108 |
109 | return AudioSegment(data=audio_bytes)
110 |
111 | def manual_switch_mic_recorder(self):
112 | """Record audio from the microphone."""
113 | red_square = "\U0001F7E5"
114 | microphone = "\U0001F3A4"
115 | play_button = "\U000025B6"
116 |
117 | recording = mic_recorder(
118 | key=f"audiorecorder_widget_{self.page_id}",
119 | start_prompt=play_button + microphone,
120 | stop_prompt=red_square,
121 | just_once=True,
122 | use_container_width=True,
123 | )
124 |
125 | if recording is None:
126 | return AudioSegment.silent(duration=0)
127 |
128 | return AudioSegment(
129 | data=recording["bytes"],
130 | sample_width=recording["sample_width"],
131 | frame_rate=recording["sample_rate"],
132 | channels=1,
133 | )
134 |
135 | def render_custom_audio_player(
136 | self,
137 | audio: Union[AudioSegment, str, Path, None],
138 | parent_element=None,
139 | autoplay: bool = True,
140 | hidden=False,
141 | ):
142 | """Autoplay an audio segment in the streamlit app."""
143 | # Adapted from:
146 |
147 | if audio is None:
148 | logger.debug("No audio to play. Not rendering audio player.")
149 | return
150 |
151 | if isinstance(audio, (str, Path)):
152 | audio = AudioSegment.from_file(audio, format="mp3")
153 | elif not isinstance(audio, AudioSegment):
154 | raise TypeError(f"Invalid type for audio: {type(audio)}")
155 |
156 | autoplay = "autoplay" if autoplay else ""
157 | hidden = "hidden" if hidden else ""
158 |
159 | data = audio.export(format="mp3").read()
160 | b64 = base64.b64encode(data).decode()
161 | md = f"""
162 |
165 | """
166 | parent_element = parent_element or st
167 | parent_element.markdown(md, unsafe_allow_html=True)
168 | if autoplay:
169 | time.sleep(audio.duration_seconds)
170 |
171 |
172 | class ChatBotPage(AppPage):
173 | """Implement a chatbot page in a streamlit application, inheriting from AppPage."""
174 |
175 | def __init__(
176 | self,
177 | parent: "MultipageChatbotApp",
178 | chat_obj: WebAppChat = None,
179 | sidebar_title: str = "",
180 | page_title: str = "",
181 | ):
182 | """Initialize new instance of the ChatBotPage class with an opt WebAppChat object.
183 |
184 | Args:
185 | parent (MultipageChatbotApp): The parent app of the page.
186 | chat_obj (WebAppChat): The chat object. Defaults to None.
187 | sidebar_title (str): The sidebar title for the chatbot page.
188 | Defaults to an empty string.
189 | page_title (str): The title for the chatbot page.
190 | Defaults to an empty string.
191 | """
192 | super().__init__(
193 | parent=parent, sidebar_title=sidebar_title, page_title=page_title
194 | )
195 |
196 | if chat_obj:
197 | logger.debug("Setting page chat to chat with ID=<{}>", chat_obj.id)
198 | self.chat_obj = chat_obj
199 | else:
200 | logger.debug("ChatBotPage created wihout specific chat. Creating default.")
201 | _ = self.chat_obj
202 | logger.debug("Default chat id=<{}>", self.chat_obj.id)
203 |
204 | self.avatars = get_avatar_images()
205 |
206 | @property
207 | def chat_configs(self) -> VoiceChatConfigs:
208 | """Return the configs used for the page's chat object."""
209 | if "chat_configs" not in self.state:
210 | self.state["chat_configs"] = self.parent.state["chat_configs"]
211 | return self.state["chat_configs"]
212 |
213 | @chat_configs.setter
214 | def chat_configs(self, value: VoiceChatConfigs):
215 | self.state["chat_configs"] = VoiceChatConfigs.model_validate(value)
216 | if "chat_obj" in self.state:
217 | del self.state["chat_obj"]
218 |
219 | @property
220 | def chat_obj(self) -> WebAppChat:
221 | """Return the chat object responsible for the queries on this page."""
222 | if "chat_obj" not in self.state:
223 | self.chat_obj = WebAppChat(
224 | configs=self.chat_configs, openai_client=self.parent.openai_client
225 | )
226 | return self.state["chat_obj"]
227 |
228 | @chat_obj.setter
229 | def chat_obj(self, new_chat_obj: WebAppChat):
230 | current_chat = self.state.get("chat_obj")
231 | if current_chat:
232 | logger.debug(
233 | "Copy new_chat=<{}> into current_chat=<{}>. Current chat ID kept.",
234 | new_chat_obj.id,
235 | current_chat.id,
236 | )
237 | current_chat.save_cache()
238 | new_chat_obj.id = current_chat.id
239 | new_chat_obj.openai_client = self.parent.openai_client
240 | self.state["chat_obj"] = new_chat_obj
241 | self.state["chat_configs"] = new_chat_obj.configs
242 | new_chat_obj.save_cache()
243 |
244 | @property
245 | def chat_history(self) -> list[dict[str, str]]:
246 | """Return the chat history of the page."""
247 | if "messages" not in self.state:
248 | self.state["messages"] = []
249 | return self.state["messages"]
250 |
251 | def render_chat_history(self):
252 | """Render the chat history of the page. Do not include system messages."""
253 | with st.chat_message("assistant", avatar=self.avatars["assistant"]):
254 | st.markdown(self.chat_obj.initial_greeting)
255 |
256 | for message in self.chat_history:
257 | role = message["role"]
258 | if role == "system":
259 | continue
260 | with st.chat_message(role, avatar=self.avatars.get(role)):
261 | with contextlib.suppress(KeyError):
262 | if role == "assistant":
263 | st.caption(message["chat_model"])
264 | else:
265 | st.caption(message["timestamp"])
266 | st.markdown(message["content"])
267 | with contextlib.suppress(KeyError):
268 | if audio := message.get("reply_audio_file_path"):
269 | with contextlib.suppress(CouldntDecodeError):
270 | self.render_custom_audio_player(audio, autoplay=False)
271 |
272 | def render_cost_estimate_page(self):
273 | """Render the estimated costs information in the chat."""
274 | general_df = self.chat_obj.general_token_usage_db.get_usage_balance_dataframe()
275 | chat_df = self.chat_obj.token_usage_db.get_usage_balance_dataframe()
276 | dfs = {"All Recorded Chats": general_df, "Current Chat": chat_df}
277 |
278 | st.header(dfs["Current Chat"].attrs["description"], divider="rainbow")
279 | with st.container():
280 | for category, df in dfs.items():
281 | st.subheader(f"**{category}**")
282 | st.dataframe(df)
283 | st.write()
284 | st.caption(df.attrs["disclaimer"])
285 |
286 | @property
287 | def voice_output(self) -> bool:
288 | """Return the state of the voice output toggle."""
289 | return st.session_state.get("toggle_voice_output", False)
290 |
291 | def play_chime(self, chime_type: str = "success", parent_element=None):
292 | """Sound a chime to send notificatons to the user."""
293 | chime = load_chime(chime_type)
294 | self.render_custom_audio_player(
295 | chime, hidden=True, autoplay=True, parent_element=parent_element
296 | )
297 |
298 | def render_title(self):
299 | """Render the title of the chatbot page."""
300 | with st.container(height=145, border=False):
301 | self.title_container = st.empty()
302 | self.title_container.subheader(self.title, divider="rainbow")
303 | left, _ = st.columns([0.7, 0.3])
304 | with left:
305 | self.status_msg_container = st.empty()
306 |
307 | @property
308 | def direct_text_prompt(self):
309 | """Render chat inut widgets and return the user's input."""
310 | placeholder = (
311 | f"Send a message to {self.chat_obj.assistant_name} ({self.chat_obj.model})"
312 | )
313 | text_from_manual_audio_recorder = ""
314 | with st.container():
315 | left, right = st.columns([0.9, 0.1])
316 | with left:
317 | text_from_chat_input_widget = st.chat_input(placeholder=placeholder)
318 | with right:
319 | if not st.session_state.get("toggle_continuous_voice_input"):
320 | audio = self.manual_switch_mic_recorder()
321 | text_from_manual_audio_recorder = self.chat_obj.stt(audio).text
322 |
323 | return text_from_chat_input_widget or text_from_manual_audio_recorder
324 |
325 | @property
326 | def continuous_text_prompt(self):
327 | """Wait until a promp from the continuous stream is ready and return it."""
328 | if not st.session_state.get("toggle_continuous_voice_input"):
329 | return None
330 |
331 | if not self.parent.continuous_audio_input_engine_is_running:
332 | logger.warning("Continuous audio input engine is not running!!!")
333 | self.status_msg_container.error(
334 | "The continuous audio input engine is not running!!!"
335 | )
336 | return None
337 |
338 | logger.debug("Running on continuous audio prompt. Waiting user input...")
339 | with self.status_msg_container:
340 | self.play_chime(chime_type="warning")
341 | with st.spinner(f"{self.chat_obj.assistant_name} is listening..."):
342 | while True:
343 | with self.parent.text_prompt_queue.mutex:
344 | this_page_prompt_queue = filter_page_info_from_queue(
345 | app_page=self, the_queue=self.parent.text_prompt_queue
346 | )
347 | with contextlib.suppress(queue.Empty):
348 | if prompt := this_page_prompt_queue.get_nowait()["text"]:
349 | this_page_prompt_queue.task_done()
350 | break
351 | logger.trace("Still waiting for user text prompt...")
352 | time.sleep(0.1)
353 |
354 | logger.debug("Done getting user input: {}", prompt)
355 | return prompt
356 |
357 | def _render_chatbot_page(self): # noqa: PLR0915
358 | """Render a chatbot page.
359 |
360 | Adapted from:
361 |
362 |
363 | """
364 | self.chat_obj.reply_only_as_text = not self.voice_output
365 |
366 | self.render_title()
367 | chat_msgs_container = st.container(height=550, border=False)
368 | with chat_msgs_container:
369 | self.render_chat_history()
370 |
371 | # The inputs should be rendered after the chat history. There is a performance
372 | # penalty otherwise, as rendering the history causes streamlit to rerun the
373 | # entire page
374 | direct_text_prompt = self.direct_text_prompt
375 | continuous_stt_prompt = "" if direct_text_prompt else self.continuous_text_prompt
376 | prompt = direct_text_prompt or continuous_stt_prompt
377 |
378 | if prompt:
379 | logger.opt(colors=True).debug("Recived prompt: {}", prompt)
380 | self.parent.reply_ongoing.set()
381 |
382 | if continuous_stt_prompt:
383 | self.play_chime("success")
384 | self.status_msg_container.success("Got your message!")
385 | time.sleep(0.5)
386 | elif continuous_stt_prompt:
387 | self.status_msg_container.warning(
388 | "Could not understand your message. Please try again."
389 | )
390 | logger.opt(colors=True).debug("Received empty prompt")
391 | self.parent.reply_ongoing.clear()
392 |
393 | if prompt:
394 | with chat_msgs_container:
395 | # Process user input
396 | if prompt:
397 | time_now = datetime.datetime.now().replace(microsecond=0)
398 | self.state.update({"chat_started": True})
399 | # Display user message in chat message container
400 | with st.chat_message("user", avatar=self.avatars["user"]):
401 | st.caption(time_now)
402 | st.markdown(prompt)
403 | self.chat_history.append(
404 | {
405 | "role": "user",
406 | "name": self.chat_obj.username,
407 | "content": prompt,
408 | "timestamp": time_now,
409 | }
410 | )
411 |
412 | # Display (stream) assistant response in chat message container
413 | with st.chat_message("assistant", avatar=self.avatars["assistant"]):
414 | # Process text and audio replies asynchronously
415 | replier = AsyncReplier(self, prompt)
416 | reply = replier.stream_text_and_audio_reply()
417 | self.chat_history.append(
418 | {
419 | "role": "assistant",
420 | "name": self.chat_obj.assistant_name,
421 | "content": reply["text"],
422 | "reply_audio_file_path": reply["audio"],
423 | "chat_model": self.chat_obj.model,
424 | }
425 | )
426 |
427 | # Reset title according to conversation initial contents
428 | min_history_len_for_summary = 3
429 | if (
430 | "page_title" not in self.state
431 | and len(self.chat_history) > min_history_len_for_summary
432 | ):
433 | logger.debug("Working out conversation topic...")
434 | prompt = "Summarize the previous messages in max 4 words"
435 | title = "".join(self.chat_obj.respond_system_prompt(prompt))
436 | self.chat_obj.metadata["page_title"] = title
437 | self.chat_obj.metadata["sidebar_title"] = title
438 | self.chat_obj.save_cache()
439 |
440 | self.title = title
441 | self.sidebar_title = title
442 | self.title_container.header(title, divider="rainbow")
443 |
444 | # Clear the prompt queue for this page, to remove old prompts
445 | with self.parent.continuous_user_prompt_queue.mutex:
446 | filter_page_info_from_queue(
447 | app_page=self,
448 | the_queue=self.parent.continuous_user_prompt_queue,
449 | )
450 | with self.parent.text_prompt_queue.mutex:
451 | filter_page_info_from_queue(
452 | app_page=self, the_queue=self.parent.text_prompt_queue
453 | )
454 |
455 | replier.join()
456 | self.parent.reply_ongoing.clear()
457 |
458 | if continuous_stt_prompt and not self.parent.reply_ongoing.is_set():
459 | logger.opt(colors=True).debug(
460 | "Rerunning the app to wait for new input..."
461 | )
462 | st.rerun()
463 |
464 | def render(self):
465 | """Render the app's chatbot or costs page, depending on user choice."""
466 |
467 | def _trim_page_padding():
468 | md = """
469 |
477 | """
478 | st.markdown(md, unsafe_allow_html=True)
479 |
480 | _trim_page_padding()
481 | if st.session_state.get("toggle_show_costs"):
482 | self.render_cost_estimate_page()
483 | else:
484 | self._render_chatbot_page()
485 | logger.debug("Reached the end of the chatbot page.")
486 |
--------------------------------------------------------------------------------
/pyrobbot/app/app_utils.py:
--------------------------------------------------------------------------------
1 | """Utility functions and classes for the app."""
2 |
3 | import contextlib
4 | import datetime
5 | import os
6 | import queue
7 | import threading
8 | from typing import TYPE_CHECKING
9 |
10 | import streamlit as st
11 | from loguru import logger
12 | from PIL import Image
13 | from pydub import AudioSegment
14 | from streamlit.runtime.scriptrunner import add_script_run_ctx
15 | from twilio.rest import Client as TwilioClient
16 |
17 | from pyrobbot import GeneralDefinitions
18 | from pyrobbot.chat import AssistantResponseChunk
19 | from pyrobbot.voice_chat import VoiceChat
20 |
21 | if TYPE_CHECKING:
22 | from .app_page_templates import AppPage
23 |
24 |
25 | class WebAppChat(VoiceChat):
26 | """A chat object for web apps."""
27 |
28 | def __init__(self, **kwargs):
29 | """Initialize a new instance of the WebAppChat class."""
30 | super().__init__(**kwargs)
31 | self.tts_conversion_watcher_thread.start()
32 | self.handle_update_audio_history_thread.start()
33 |
34 |
35 | class AsyncReplier:
36 | """Asynchronously reply to a prompt and stream the text & audio reply."""
37 |
38 | def __init__(self, app_page: "AppPage", prompt: str):
39 | """Initialize a new instance of the AsyncReplier class."""
40 | self.app_page = app_page
41 | self.prompt = prompt
42 |
43 | self.chat_obj = app_page.chat_obj
44 | self.question_answer_chunks_queue = queue.Queue()
45 |
46 | self.threads = [
47 | threading.Thread(name="queue_text_chunks", target=self.queue_text_chunks),
48 | threading.Thread(name="play_queued_audios", target=self.play_queued_audios),
49 | ]
50 |
51 | self.start()
52 |
53 | def start(self):
54 | """Start the threads."""
55 | for thread in self.threads:
56 | add_script_run_ctx(thread)
57 | thread.start()
58 |
59 | def join(self):
60 | """Wait for all threads to finish."""
61 | logger.debug("Waiting for {} to finish...", type(self).__name__)
62 | for thread in self.threads:
63 | thread.join()
64 | logger.debug("All {} threads finished", type(self).__name__)
65 |
66 | def queue_text_chunks(self):
67 | """Get chunks of the text reply to the prompt and queue them for display."""
68 | exchange_id = None
69 | for chunk in self.chat_obj.answer_question(self.prompt):
70 | self.question_answer_chunks_queue.put(chunk)
71 | exchange_id = chunk.exchange_id
72 | self.question_answer_chunks_queue.put(
73 | AssistantResponseChunk(exchange_id=exchange_id, content=None)
74 | )
75 |
76 | def play_queued_audios(self):
77 | """Play queued audio segments."""
78 | while True:
79 | try:
80 | logger.debug(
81 | "Waiting for item from the audio reply chunk queue ({}) items so far",
82 | self.chat_obj.play_speech_queue.qsize(),
83 | )
84 | speech_queue_item = self.chat_obj.play_speech_queue.get()
85 | audio = speech_queue_item["speech"]
86 | if audio is None:
87 | logger.debug("Got `None`. No more audio reply chunks to play")
88 | self.chat_obj.play_speech_queue.task_done()
89 | break
90 |
91 | logger.debug("Playing audio reply chunk ({}s)", audio.duration_seconds)
92 | self.app_page.render_custom_audio_player(
93 | audio,
94 | parent_element=self.app_page.status_msg_container,
95 | autoplay=True,
96 | hidden=True,
97 | )
98 | logger.debug(
99 | "Done playing audio reply chunk ({}s)", audio.duration_seconds
100 | )
101 | self.chat_obj.play_speech_queue.task_done()
102 | except Exception as error: # noqa: BLE001
103 | # NB: do not reference `audio` here; it may be unbound if the
104 | # failure happened at the queue `get` call above
105 | logger.opt(exception=True).debug("Error playing audio reply chunk")
106 | logger.error(error)
107 | break
108 | finally:
109 | self.app_page.status_msg_container.empty()
110 |
111 | def stream_text_and_audio_reply(self):
112 | """Stream the text and audio reply to the display."""
113 | text_reply_container = st.empty()
114 | audio_reply_container = st.empty()
115 |
116 | chunk = AssistantResponseChunk(exchange_id=None, content="")
117 | full_response = ""
118 | text_reply_container.markdown("▌")
119 | self.app_page.status_msg_container.empty()
120 | while chunk.content is not None:
121 | logger.trace("Waiting for text or audio chunks...")
122 | # Render text
123 | with contextlib.suppress(queue.Empty):
124 | chunk = self.question_answer_chunks_queue.get_nowait()
125 | if chunk.content is not None:
126 | full_response += chunk.content
127 | text_reply_container.markdown(full_response + "▌")
128 | self.question_answer_chunks_queue.task_done()
129 |
130 | text_reply_container.caption(datetime.datetime.now().replace(microsecond=0))
131 | text_reply_container.markdown(full_response)
132 |
133 | logger.debug("Waiting for the audio reply to finish...")
134 | self.chat_obj.play_speech_queue.join()
135 |
136 | logger.debug("Getting path to full audio file for the reply...")
137 | history_entry_for_this_reply = (
138 | self.chat_obj.context_handler.database.retrieve_history(
139 | exchange_id=chunk.exchange_id
140 | )
141 | )
142 | full_audio_fpath = history_entry_for_this_reply["reply_audio_file_path"].iloc[0]
143 | if full_audio_fpath is None:
144 | logger.warning("Path to full audio file not available")
145 | else:
146 | logger.debug("Got path to full audio file: {}", full_audio_fpath)
147 | self.app_page.render_custom_audio_player(
148 | full_audio_fpath, parent_element=audio_reply_container, autoplay=False
149 | )
150 |
151 | return {"text": full_response, "audio": full_audio_fpath}
152 |
153 |
154 | @st.cache_data
155 | def get_ice_servers():
156 | """Use Twilio's TURN server as recommended by the streamlit-webrtc developers."""
157 | try:
158 | account_sid = os.environ["TWILIO_ACCOUNT_SID"]
159 | auth_token = os.environ["TWILIO_AUTH_TOKEN"]
160 | except KeyError:
161 | logger.warning(
162 | "Twilio credentials are not set. Cannot use their TURN servers. "
163 | "Falling back to a free STUN server from Google."
164 | )
165 | return [{"urls": ["stun:stun.l.google.com:19302"]}]
166 |
167 | client = TwilioClient(account_sid, auth_token)
168 | token = client.tokens.create()
169 | return token.ice_servers
170 |
171 |
172 | def filter_page_info_from_queue(app_page: "AppPage", the_queue: queue.Queue):
173 | """Filter `app_page`'s data from `queue` inplace. Return queue of items in `app_page`.
174 |
175 | **Use with original_queue.mutex!!**
176 |
177 | Args:
178 | app_page: The page whose entries should be removed.
179 | the_queue: The queue to be filtered.
180 |
181 | Returns:
182 | queue.Queue: The queue with only the entries from `app_page`.
183 |
184 | Example:
185 | ```
186 | with the_queue.mutex:
187 | this_page_data = filter_page_info_from_queue(app_page, the_queue)
188 | ```
189 | """
190 | queue_with_only_entries_from_other_pages = queue.Queue()
191 | items_from_page_queue = queue.Queue()
192 | while the_queue.queue:
193 | original_queue_entry = the_queue.queue.popleft()
194 | if original_queue_entry["page"].page_id == app_page.page_id:
195 | items_from_page_queue.put(original_queue_entry)
196 | else:
197 | queue_with_only_entries_from_other_pages.put(original_queue_entry)
198 |
199 | the_queue.queue = queue_with_only_entries_from_other_pages.queue
200 | return items_from_page_queue
201 |
202 |
203 | @st.cache_data
204 | def get_avatar_images():
205 | """Return the avatar images for the assistant and the user."""
206 | avatar_files_dir = GeneralDefinitions.APP_DIR / "data"
207 | assistant_avatar_file_path = avatar_files_dir / "assistant_avatar.png"
208 | user_avatar_file_path = avatar_files_dir / "user_avatar.png"
209 | assistant_avatar_image = Image.open(assistant_avatar_file_path)
210 | user_avatar_image = Image.open(user_avatar_file_path)
211 |
212 | return {"assistant": assistant_avatar_image, "user": user_avatar_image}
213 |
214 |
215 | @st.cache_data
216 | def load_chime(chime_type: str) -> AudioSegment:
217 | """Load a chime sound from the data directory."""
218 | return AudioSegment.from_file(
219 | GeneralDefinitions.APP_DIR / "data" / f"{chime_type}.wav", format="wav"
220 | )
221 |
--------------------------------------------------------------------------------
/pyrobbot/app/data/assistant_avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/assistant_avatar.png
--------------------------------------------------------------------------------
/pyrobbot/app/data/powered-by-openai-badge-outlined-on-dark.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/pyrobbot/app/data/success.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/success.wav
--------------------------------------------------------------------------------
/pyrobbot/app/data/user_avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/user_avatar.png
--------------------------------------------------------------------------------
/pyrobbot/app/data/warning.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/paulovcmedeiros/pyRobBot/7e77d3b1aee052cfa350a806371de75b9b713ad6/pyrobbot/app/data/warning.wav
--------------------------------------------------------------------------------
/pyrobbot/argparse_wrapper.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Wrappers for argparse functionality."""
3 | import argparse
4 | import contextlib
5 | import sys
6 |
7 | from pydantic import BaseModel
8 |
9 | from . import GeneralDefinitions
10 | from .chat_configs import ChatOptions, VoiceChatConfigs
11 | from .command_definitions import (
12 | accounting_report,
13 | browser_chat,
14 | terminal_chat,
15 | voice_chat,
16 | )
17 |
18 |
19 | def _populate_parser_from_pydantic_model(parser, model: BaseModel):
20 |     _argparse2pydantic = {
21 | "type": model.get_type,
22 | "default": model.get_default,
23 | "choices": model.get_allowed_values,
24 | "help": model.get_description,
25 | }
26 |
27 | for field_name, field in model.model_fields.items():
28 | with contextlib.suppress(AttributeError):
29 | if not field.json_schema_extra.get("changeable", True):
30 | continue
31 |
32 | args_opts = {
33 |             key: _argparse2pydantic[key](field_name)
34 |             for key in _argparse2pydantic
35 |             if _argparse2pydantic[key](field_name) is not None
36 | }
37 |
38 | if args_opts.get("type") == bool:
39 | if args_opts.get("default") is True:
40 | args_opts["action"] = "store_false"
41 | else:
42 | args_opts["action"] = "store_true"
43 | args_opts.pop("default", None)
44 | args_opts.pop("type", None)
45 |
46 | args_opts["required"] = field.is_required()
47 | if "help" in args_opts:
48 | args_opts["help"] = f"{args_opts['help']} (default: %(default)s)"
49 | if "default" in args_opts and isinstance(args_opts["default"], (list, tuple)):
50 | args_opts.pop("type", None)
51 | args_opts["nargs"] = "*"
52 |
53 | parser.add_argument(f"--{field_name.replace('_', '-')}", **args_opts)
54 |
55 | return parser
56 |
57 |
58 | def get_parsed_args(argv=None, default_command="ui"):
59 | """Get parsed command line arguments.
60 |
61 | Args:
62 | argv (list): A list of passed command line args.
63 | default_command (str, optional): The default command to run.
64 |
65 | Returns:
66 | argparse.Namespace: Parsed command line arguments.
67 |
68 | """
69 | if argv is None:
70 | argv = sys.argv[1:]
71 | first_argv = next(iter(argv), "'")
72 | info_flags = ["--version", "-v", "-h", "--help"]
73 | if not argv or (first_argv.startswith("-") and first_argv not in info_flags):
74 | argv = [default_command, *argv]
75 |
76 | # Main parser that will handle the script's commands
77 | main_parser = argparse.ArgumentParser(
78 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
79 | )
80 | main_parser.add_argument(
81 | "--version",
82 | "-v",
83 | action="version",
84 | version=f"{GeneralDefinitions.PACKAGE_NAME} v" + GeneralDefinitions.VERSION,
85 | )
86 | subparsers = main_parser.add_subparsers(
87 | title="commands",
88 | dest="command",
89 | required=True,
90 | description=(
91 | "Valid commands (note that commands also accept their "
92 | + "own arguments, in particular [-h]):"
93 | ),
94 | help="command description",
95 | )
96 |
97 | # Common options to most commands
98 | chat_options_parser = _populate_parser_from_pydantic_model(
99 | parser=argparse.ArgumentParser(
100 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False
101 | ),
102 | model=ChatOptions,
103 | )
104 | chat_options_parser.add_argument(
105 | "--report-accounting-when-done",
106 | action="store_true",
107 | help="Report estimated costs when done with the chat.",
108 | )
109 |
110 | # Web app chat
111 | parser_ui = subparsers.add_parser(
112 | "ui",
113 | aliases=["app", "webapp", "browser"],
114 | parents=[chat_options_parser],
115 | help="Run the chat UI on the browser.",
116 | )
117 | parser_ui.set_defaults(run_command=browser_chat)
118 |
119 | # Voice chat
120 | voice_options_parser = _populate_parser_from_pydantic_model(
121 | parser=argparse.ArgumentParser(
122 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, add_help=False
123 | ),
124 | model=VoiceChatConfigs,
125 | )
126 | parser_voice_chat = subparsers.add_parser(
127 | "voice",
128 | aliases=["v", "speech", "talk"],
129 | parents=[voice_options_parser],
130 | help="Run the chat over voice only.",
131 | )
132 | parser_voice_chat.set_defaults(run_command=voice_chat)
133 |
134 | # Terminal chat
135 | parser_terminal = subparsers.add_parser(
136 | "terminal",
137 | aliases=["."],
138 | parents=[chat_options_parser],
139 | help="Run the chat on the terminal.",
140 | )
141 | parser_terminal.set_defaults(run_command=terminal_chat)
142 |
143 | # Accounting report
144 | parser_accounting = subparsers.add_parser(
145 | "accounting",
146 | aliases=["acc"],
147 | help="Show the estimated number of used tokens and associated costs, and exit.",
148 | )
149 | parser_accounting.set_defaults(run_command=accounting_report)
150 |
151 | return main_parser.parse_args(argv)
152 |
--------------------------------------------------------------------------------
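Editor's note: `get_parsed_args` injects a default command whenever the first CLI token is absent or is an option flag. A side-effect-free sketch of the resulting behaviour (`--private-mode` is the store-true flag generated from the `private_mode` config field):

```python
from pyrobbot.argparse_wrapper import get_parsed_args

args = get_parsed_args([])                  # argv becomes ["ui"]: web app by default
args = get_parsed_args(["--private-mode"])  # becomes ["ui", "--private-mode"]
args = get_parsed_args(["terminal"])        # explicit command: used as-is
# Each subparser binds a handler via set_defaults(run_command=...), so dispatch is:
# args.run_command(args)  # e.g. terminal_chat(args)
```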
/pyrobbot/chat.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Implementation of the Chat class."""
3 | import contextlib
4 | import json
5 | import shutil
6 | import uuid
7 | from collections import defaultdict
8 | from datetime import datetime
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 | import openai
13 | from attr import dataclass
14 | from loguru import logger
15 | from pydub import AudioSegment
16 | from tzlocal import get_localzone
17 |
18 | from . import GeneralDefinitions
19 | from .chat_configs import ChatOptions
20 | from .chat_context import EmbeddingBasedChatContext, FullHistoryChatContext
21 | from .general_utils import (
22 | AlternativeConstructors,
23 | ReachedMaxNumberOfAttemptsError,
24 | get_call_traceback,
25 | )
26 | from .internet_utils import websearch
27 | from .openai_utils import OpenAiClientWrapper, make_api_chat_completion_call
28 | from .sst_and_tts import SpeechToText, TextToSpeech
29 | from .tokens import PRICE_PER_K_TOKENS_EMBEDDINGS, TokenUsageDatabase
30 |
31 |
32 | @dataclass
33 | class AssistantResponseChunk:
34 | """A chunk of the assistant's response."""
35 |
36 | exchange_id: str
37 | content: str
38 | chunk_type: str = "text"
39 |
40 |
41 | class Chat(AlternativeConstructors):
42 | """Manages conversations with an AI chat model.
43 |
44 | This class encapsulates the chat behavior, including handling the chat context,
45 | managing cache directories, and interfacing with the OpenAI API for generating chat
46 | responses.
47 | """
48 |
49 | _translation_cache = defaultdict(dict)
50 | default_configs = ChatOptions()
51 |
52 | def __init__(
53 | self,
54 | openai_client: OpenAiClientWrapper = None,
55 | configs: ChatOptions = default_configs,
56 | ):
57 | """Initializes a chat instance.
58 |
59 | Args:
60 | configs (ChatOptions, optional): The configurations for this chat session.
61 |             openai_client (OpenAiClientWrapper, optional): The OpenAI client wrapper to use.
62 |
63 | Raises:
64 | NotImplementedError: If the context model specified in configs is unknown.
65 | """
66 | self.id = str(uuid.uuid4())
67 | logger.trace(
68 |             "Init chat {}, as requested from <{}>", self.id, get_call_traceback()
69 | )
70 | logger.debug("Init chat {}", self.id)
71 |
72 |         self._code_marker = "\uE001"  # private-use unicode char marking code blocks
73 |
74 | self._passed_configs = configs
75 | for field in self._passed_configs.model_fields:
76 | setattr(self, field, self._passed_configs[field])
77 |
78 | try:
79 | self.openai_client = (
80 | openai_client
81 | if openai_client is not None
82 | else OpenAiClientWrapper(
83 | timeout=self.timeout, private_mode=self.private_mode
84 | )
85 | )
86 | except openai.OpenAIError as error:
87 | logger.opt(exception=True).debug(error)
88 | logger.error(
89 | "Cannot connect to OpenAI API. Please verify your API key. {}.", error
90 | )
91 | self.openai_client = None
92 |
93 | @property
94 | def base_directive(self):
95 | """Return the base directive for the LLM."""
96 | code_marker = self._code_marker
97 | local_datetime = datetime.now(get_localzone()).isoformat(timespec="seconds")
98 | msg_content = (
99 | f"Your name is {self.assistant_name}. Your model is {self.model}\n"
100 | f"You are a helpful assistant to {self.username}\n"
101 | f"You have internet access\n"
102 | f"You MUST ALWAYS write {code_marker} before AND after code blocks. Example: "
103 | f"```foo ... ``` MUST become {code_marker}```foo ... ```{code_marker}\n"
104 | f"The current city is {GeneralDefinitions.IPINFO['city']} in "
105 | f"{GeneralDefinitions.IPINFO['country_name']}\n"
106 | f"The local datetime is {local_datetime}\n"
107 | f"\nYou must follow all directives by {self.system_name} "
108 | f"unless otherwise instructed by {self.username}\n"
109 | f"You must reply in {self.language} unless otherwise instructed\n"
110 | "> If asked to look up online, web, internet etc, *you MUST ALWAYS AGREE*\n"
111 |             "\n*IMPORTANT*: If you can't find information either in your training data "
112 |             "or in the conversation context then, unless you have either already "
113 | f"looked up online or not understood {self.username}'s prompt, you must "
114 | "do as follows:\n"
115 | " > Do *NOT* apologise nor say you are sorry nor give any excuses.\n"
116 | " > Do *NOT* ask for permission to lookup online.\n"
117 | " > STATE CLEARLY that you will look it up online.\n"
118 |             + "\n".join([f"{instruct.strip(' .')}." for instruct in self.ai_instructions])
119 | )
120 | return {"role": "system", "name": self.system_name, "content": msg_content}
121 |
122 | @property
123 | def configs(self):
124 | """Return the chat's configs after initialisation."""
125 | configs_dict = {}
126 | for field_name in self._passed_configs.model_fields:
127 | configs_dict[field_name] = getattr(self, field_name)
128 | return self._passed_configs.model_validate(configs_dict)
129 |
130 | @property
131 | def tmp_dir(self):
132 | """Return the temporary directory for the chat."""
133 | return Path(self._tmp_dir.name)
134 |
135 | @property
136 | def cache_dir(self):
137 | """Return the cache directory for this chat."""
138 | parent_dir = self.openai_client.get_cache_dir(private_mode=self.private_mode)
139 | directory = parent_dir / f"chat_{self.id}"
140 | directory.mkdir(parents=True, exist_ok=True)
141 | return directory
142 |
143 | @property
144 | def configs_file(self):
145 | """File to store the chat's configs."""
146 | return self.cache_dir / "configs.json"
147 |
148 | @property
149 | def context_file_path(self):
150 | """Return the path to the file that stores the chat context and history."""
151 | return self.cache_dir / "embeddings.db"
152 |
153 | @property
154 | def context_handler(self):
155 | """Return the chat's context handler."""
156 | if self.context_model == "full-history":
157 | return FullHistoryChatContext(parent_chat=self)
158 |
159 | if self.context_model in PRICE_PER_K_TOKENS_EMBEDDINGS:
160 | return EmbeddingBasedChatContext(parent_chat=self)
161 |
162 | raise NotImplementedError(f"Unknown context model: {self.context_model}")
163 |
164 | @property
165 | def token_usage_db(self):
166 | """Return the chat's token usage database."""
167 | return TokenUsageDatabase(fpath=self.cache_dir / "chat_token_usage.db")
168 |
169 | @property
170 | def general_token_usage_db(self):
171 | """Return the general token usage database for all chats.
172 |
173 | Even private-mode chats will use this database to keep track of total token usage.
174 | """
175 | general_cache_dir = self.openai_client.get_cache_dir(private_mode=False)
176 | return TokenUsageDatabase(fpath=general_cache_dir.parent / "token_usage.db")
177 |
178 | @property
179 | def metadata_file(self):
180 | """File to store the chat metadata."""
181 | return self.cache_dir / "metadata.json"
182 |
183 | @property
184 | def metadata(self):
185 | """Keep metadata associated with the chat."""
186 | try:
187 | _ = self._metadata
188 | except AttributeError:
189 | try:
190 | with open(self.metadata_file, "r") as f:
191 | self._metadata = json.load(f)
192 | except (FileNotFoundError, json.decoder.JSONDecodeError):
193 | self._metadata = {}
194 | return self._metadata
195 |
196 | @metadata.setter
197 | def metadata(self, value):
198 | self._metadata = dict(value)
199 |
200 | def save_cache(self):
201 | """Store the chat's configs and metadata to the cache directory."""
202 | self.configs.export(self.configs_file)
203 |
204 | metadata = self.metadata # Trigger loading metadata if not yet done
205 | metadata["chat_id"] = self.id
206 | with open(self.metadata_file, "w") as metadata_f:
207 | json.dump(metadata, metadata_f, indent=2)
208 |
209 | def clear_cache(self):
210 | """Remove the cache directory."""
211 | logger.debug("Clearing cache for chat {}", self.id)
212 | shutil.rmtree(self.cache_dir, ignore_errors=True)
213 |
214 | def load_history(self):
215 | """Load chat history from cache."""
216 | return self.context_handler.load_history()
217 |
218 | @property
219 | def initial_greeting(self):
220 | """Return the initial greeting for the chat."""
221 | default_greeting = f"Hi! I'm {self.assistant_name}. How can I assist you?"
222 | user_set_greeting = False
223 | with contextlib.suppress(AttributeError):
224 | user_set_greeting = self._initial_greeting != ""
225 |
226 | if not user_set_greeting:
227 | self._initial_greeting = default_greeting
228 |
229 | custom_greeting = user_set_greeting and self._initial_greeting != default_greeting
230 | if custom_greeting or self.language[:2] != "en":
231 | self._initial_greeting = self._translate(self._initial_greeting)
232 |
233 | return self._initial_greeting
234 |
235 | @initial_greeting.setter
236 | def initial_greeting(self, value: str):
237 | self._initial_greeting = str(value).strip()
238 |
239 | def respond_user_prompt(self, prompt: str, **kwargs):
240 | """Respond to a user prompt."""
241 | yield from self._respond_prompt(prompt=prompt, role="user", **kwargs)
242 |
243 | def respond_system_prompt(
244 | self, prompt: str, add_to_history=False, skip_check=True, **kwargs
245 | ):
246 | """Respond to a system prompt."""
247 | for response_chunk in self._respond_prompt(
248 | prompt=prompt,
249 | role="system",
250 | add_to_history=add_to_history,
251 | skip_check=skip_check,
252 | **kwargs,
253 | ):
254 | yield response_chunk.content
255 |
256 | def yield_response_from_msg(
257 | self, prompt_msg: dict, add_to_history: bool = True, **kwargs
258 | ):
259 | """Yield response from a prompt message."""
260 | exchange_id = str(uuid.uuid4())
261 | code_marker = self._code_marker
262 | try:
263 | inside_code_block = False
264 | for answer_chunk in self._yield_response_from_msg(
265 | exchange_id=exchange_id,
266 | prompt_msg=prompt_msg,
267 | add_to_history=add_to_history,
268 | **kwargs,
269 | ):
270 | code_marker_detected = code_marker in answer_chunk
271 | inside_code_block = (code_marker_detected and not inside_code_block) or (
272 | inside_code_block and not code_marker_detected
273 | )
274 | yield AssistantResponseChunk(
275 | exchange_id=exchange_id,
276 | content=answer_chunk.strip(code_marker),
277 | chunk_type="code" if inside_code_block else "text",
278 | )
279 |
280 | except (ReachedMaxNumberOfAttemptsError, openai.OpenAIError) as error:
281 | yield self.response_failure_message(exchange_id=exchange_id, error=error)
282 |
283 | def start(self):
284 | """Start the chat."""
285 | # ruff: noqa: T201
286 | print(f"{self.assistant_name}> {self.initial_greeting}\n")
287 | try:
288 | while True:
289 | question = input(f"{self.username}> ").strip()
290 | if not question:
291 | continue
292 | print(f"{self.assistant_name}> ", end="", flush=True)
293 | for chunk in self.respond_user_prompt(prompt=question):
294 | print(chunk.content, end="", flush=True)
295 | print()
296 | print()
297 | except (KeyboardInterrupt, EOFError):
298 | print("", end="\r")
299 | logger.info("Leaving chat")
300 |
301 | def report_token_usage(self, report_current_chat=True, report_general: bool = False):
302 | """Report token usage and associated costs."""
303 | dfs = {}
304 | if report_general:
305 | dfs["All Recorded Chats"] = (
306 | self.general_token_usage_db.get_usage_balance_dataframe()
307 | )
308 | if report_current_chat:
309 | dfs["Current Chat"] = self.token_usage_db.get_usage_balance_dataframe()
310 |
311 | if dfs:
312 | for category, df in dfs.items():
313 | header = f"{df.attrs['description']}: {category}"
314 | table_separator = "=" * (len(header) + 4)
315 | print(table_separator)
316 |                 print(f"  {header}  ")
317 | print(table_separator)
318 | print(df)
319 | print()
320 | print(df.attrs["disclaimer"])
321 |
322 | def response_failure_message(
323 | self, exchange_id: Optional[str] = "", error: Optional[Exception] = None
324 | ):
325 |         """Return the error message shown when getting a response fails."""
326 | msg = "Could not get a response right now."
327 | if error is not None:
328 | msg += f" The reason seems to be: {error} "
329 | msg += "Please check your connection or OpenAI API key."
330 | logger.opt(exception=True).debug(error)
331 | return AssistantResponseChunk(exchange_id=exchange_id, content=msg)
332 |
333 | def stt(self, speech: AudioSegment):
334 | """Convert audio to text."""
335 | return SpeechToText(
336 | speech=speech,
337 | openai_client=self.openai_client,
338 | engine=self.stt_engine,
339 | language=self.language,
340 | timeout=self.timeout,
341 | general_token_usage_db=self.general_token_usage_db,
342 | token_usage_db=self.token_usage_db,
343 | )
344 |
345 | def tts(self, text: str):
346 | """Convert text to audio."""
347 | return TextToSpeech(
348 | text=text,
349 | openai_client=self.openai_client,
350 | language=self.language,
351 | engine=self.tts_engine,
352 | openai_tts_voice=self.openai_tts_voice,
353 | timeout=self.timeout,
354 | general_token_usage_db=self.general_token_usage_db,
355 | token_usage_db=self.token_usage_db,
356 | )
357 |
358 | def _yield_response_from_msg(
359 | self,
360 | exchange_id,
361 | prompt_msg: dict,
362 | add_to_history: bool = True,
363 | skip_check: bool = False,
364 | ):
365 | """Yield response from a prompt message (lower level interface)."""
366 | # Get appropriate context for prompt from the context handler
367 | context = self.context_handler.get_context(msg=prompt_msg)
368 |
369 | # Make API request and yield response chunks
370 | full_reply_content = ""
371 | for chunk in make_api_chat_completion_call(
372 | conversation=[self.base_directive, *context, prompt_msg], chat_obj=self
373 | ):
374 | full_reply_content += chunk.strip(self._code_marker)
375 | yield chunk
376 |
377 | if not skip_check:
378 | last_msg_exchange = (
379 | f"`user` says: {prompt_msg['content']}\n"
380 | f"`you` replies: {full_reply_content}"
381 | )
382 | system_check_msg = (
383 | "Consider the following dialogue between `user` and `you` "
384 | "AND NOTHING MORE:\n\n"
385 | f"{last_msg_exchange}\n\n"
386 | "Now answer the following question using only 'yes' or 'no':\n"
387 |                 "Were `you` able to provide a good answer to the `user`'s prompt, without "
388 | "neither `you` nor `user` asking or implying the need or intention to "
389 | "perform a search or lookup online, on the web or the internet?\n"
390 | )
391 |
392 | reply = "".join(self.respond_system_prompt(prompt=system_check_msg))
393 | reply = reply.strip(".' ").lower()
394 | if ("no" in reply) or (self._translate("no") in reply):
395 | instructions_for_web_search = (
396 | "You are a professional web searcher. You will be presented with a "
397 | "dialogue between `user` and `you`. Considering the dialogue and "
398 | "relevant previous messages, write "
399 | "the best short web search query to look for an answer to the "
400 | "`user`'s prompt. You MUST follow the rules below:\n"
401 | "* Write *only the query* and nothing else\n"
402 | "* DO NOT RESTRICT the search to any particular website "
403 | "unless otherwise instructed\n"
404 | "* You MUST reply in the `user`'s language unless otherwise asked\n\n"
405 | "The `dialogue` is:"
406 | )
407 | instructions_for_web_search += f"\n\n{last_msg_exchange}"
408 | internet_query = "".join(
409 | self.respond_system_prompt(prompt=instructions_for_web_search)
410 | )
411 | yield "\n\n" + self._translate(
412 |                     "Searching the web now. My search is:"
413 | ) + f" '{internet_query}'..."
414 | web_results_json_dumps = "\n\n".join(
415 | json.dumps(result, indent=2) for result in websearch(internet_query)
416 | )
417 | if web_results_json_dumps:
418 | logger.opt(colors=True).debug(
419 |                         "Web search returned: {}...", web_results_json_dumps
420 | )
421 | original_prompt = prompt_msg["content"]
422 | prompt = (
423 | "You are a talented data analyst, "
424 | "capable of summarising any information, even complex `json`. "
425 | "You will be shown a `json` and a `prompt`. Your task is to "
426 | "summarise the `json` to answer the `prompt`. "
427 | "You MUST follow the rules below:\n\n"
428 |                 "* *ALWAYS* provide a meaningful summary of the `json`\n"
429 | "* *Do NOT include links* or anything a human can't pronounce, "
430 | "unless otherwise instructed\n"
431 | "* Prefer searches without quotes but use them if needed\n"
432 | "* Answer in human language (i.e., no json, etc)\n"
433 | "* Answer in the `user`'s language unless otherwise asked\n"
434 | "* Make sure to point out that the information is from a quick "
435 |                 "web search and may be inaccurate\n"
436 | "* Mention the sources shortly WITHOUT MENTIONING WEB LINKS\n\n"
437 | "The `json` and the `prompt` are presented below:\n"
438 | )
439 | prompt += f"\n```json\n{web_results_json_dumps}\n```\n"
440 | prompt += f"\n`prompt`: '{original_prompt}'"
441 |
442 | yield "\n\n" + self._translate(
443 | " I've got some results. Let me summarise them for you..."
444 | )
445 |
446 | full_reply_content += " "
447 | yield "\n\n"
448 | for chunk in self.respond_system_prompt(prompt=prompt):
449 | full_reply_content += chunk.strip(self._code_marker)
450 | yield chunk
451 | else:
452 | yield self._translate(
453 | "Sorry, but I couldn't find anything on the web this time."
454 | )
455 |
456 | if add_to_history:
457 | # Put current chat exchange in context handler's history
458 | self.context_handler.add_to_history(
459 | exchange_id=exchange_id,
460 | msg_list=[
461 | prompt_msg,
462 | {"role": "assistant", "content": full_reply_content},
463 | ],
464 | )
465 |
466 | def _respond_prompt(self, prompt: str, role: str, **kwargs):
467 | prompt_as_msg = {"role": role.lower().strip(), "content": prompt.strip()}
468 | yield from self.yield_response_from_msg(prompt_as_msg, **kwargs)
469 |
470 | def _translate(self, text):
471 | lang = self.language
472 |
473 |         cached_translation = type(self)._translation_cache[text].get(lang)  # noqa: SLF001
474 | if cached_translation:
475 | return cached_translation
476 |
477 | logger.debug("Processing translation of '{}' to '{}'...", text, lang)
478 | translation_prompt = (
479 | f"Translate the text between triple quotes below to {lang}. "
480 | "DO NOT WRITE ANYTHING ELSE. Only the translation. "
481 | f"If the text is already in {lang}, then don't translate. Just return ''.\n"
482 | f"'''{text}'''"
483 | )
484 | translation = "".join(self.respond_system_prompt(prompt=translation_prompt))
485 |
486 | translation = translation.strip(" '\"")
487 | if not translation.strip():
488 | translation = text.strip()
489 |
490 | logger.debug("Translated '{}' to '{}' as '{}'", text, lang, translation)
491 | type(self)._translation_cache[text][lang] = translation # noqa: SLF001
492 | type(self)._translation_cache[translation][lang] = translation # noqa: SLF001
493 |
494 | return translation
495 |
496 | def __del__(self):
497 | """Delete the chat instance."""
498 | logger.debug("Deleting chat {}", self.id)
499 | chat_started = self.context_handler.database.n_entries > 0
500 | if self.private_mode or not chat_started:
501 | self.clear_cache()
502 | else:
503 |             self.save_cache()
505 |
--------------------------------------------------------------------------------
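Editor's note: a minimal sketch of driving `Chat` programmatically, using only the interfaces shown above (`from_dict` comes from the `AlternativeConstructors` mixin). It assumes a valid OpenAI API key is available to `OpenAiClientWrapper`.

```python
from pyrobbot.chat import Chat

chat = Chat.from_dict({"assistant_name": "Rob", "language": "en"})
print(chat.initial_greeting)
for chunk in chat.respond_user_prompt("What is the capital of Norway?"):
    # Each AssistantResponseChunk carries exchange_id, content and chunk_type
    print(chunk.content, end="", flush=True)
print()
chat.report_token_usage()
```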
/pyrobbot/chat_configs.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Registration and validation of options."""
3 | import argparse
4 | import json
5 | import types
6 | import typing
7 | from getpass import getuser
8 | from pathlib import Path
9 | from typing import Literal, Optional, get_args, get_origin
10 |
11 | from pydantic import BaseModel, Field
12 |
13 | from . import GeneralDefinitions
14 | from .tokens import PRICE_PER_K_TOKENS_EMBEDDINGS, PRICE_PER_K_TOKENS_LLM
15 |
16 |
17 | class BaseConfigModel(BaseModel, extra="forbid"):
18 | """Base model for configuring options."""
19 |
20 | @classmethod
21 | def get_allowed_values(cls, field: str):
22 | """Return a tuple of allowed values for `field`."""
23 | annotation = cls._get_field_param(field=field, param="annotation")
24 | if isinstance(annotation, type(Literal[""])):
25 | return get_args(annotation)
26 | return None
27 |
28 | @classmethod
29 | def get_type(cls, field: str):
30 | """Return type of `field`."""
31 | type_hint = typing.get_type_hints(cls)[field]
32 | if isinstance(type_hint, type):
33 | if isinstance(type_hint, types.GenericAlias):
34 | return get_origin(type_hint)
35 | return type_hint
36 | type_hint_first_arg = get_args(type_hint)[0]
37 | if isinstance(type_hint_first_arg, type):
38 | return type_hint_first_arg
39 | return None
40 |
41 | @classmethod
42 | def get_default(cls, field: str):
43 | """Return allowed value(s) for `field`."""
44 | return cls.model_fields[field].get_default()
45 |
46 | @classmethod
47 | def get_description(cls, field: str):
48 | """Return description of `field`."""
49 | return cls._get_field_param(field=field, param="description")
50 |
51 | @classmethod
52 | def from_cli_args(cls, cli_args: argparse.Namespace):
53 | """Return an instance of the class from CLI args."""
54 | relevant_args = {
55 | k: v
56 | for k, v in vars(cli_args).items()
57 | if k in cls.model_fields and v is not None
58 | }
59 | return cls.model_validate(relevant_args)
60 |
61 | @classmethod
62 | def _get_field_param(cls, field: str, param: str):
63 | """Return param `param` of field `field`."""
64 | return getattr(cls.model_fields[field], param, None)
65 |
66 | def __getitem__(self, item):
67 | """Make possible to retrieve values as in a dict."""
68 | try:
69 | return getattr(self, item)
70 | except AttributeError as error:
71 | raise KeyError(item) from error
72 |
73 | def export(self, fpath: Path):
74 | """Export the model's data to a file."""
75 | with open(fpath, "w") as configs_file:
76 | configs_file.write(self.model_dump_json(indent=2, exclude_unset=True))
77 |
78 | @classmethod
79 | def from_file(cls, fpath: Path):
80 | """Return an instance of the class given configs stored in a json file."""
81 | with open(fpath, "r") as configs_file:
82 | return cls.model_validate(json.load(configs_file))
83 |
84 |
85 | class OpenAiApiCallOptions(BaseConfigModel):
86 | """Model for configuring options for OpenAI API calls."""
87 |
88 | _openai_url = "https://platform.openai.com/docs/api-reference/chat/create#chat-create"
89 | _models_url = "https://platform.openai.com/docs/models"
90 |
91 | model: Literal[tuple(PRICE_PER_K_TOKENS_LLM)] = Field(
92 | default=next(iter(PRICE_PER_K_TOKENS_LLM)),
93 | description=f"OpenAI LLM model to use. See {_openai_url}-model and {_models_url}",
94 | )
95 | max_tokens: Optional[int] = Field(
96 | default=None, gt=0, description=f"See <{_openai_url}-max_tokens>"
97 | )
98 | presence_penalty: Optional[float] = Field(
99 | default=None, ge=-2.0, le=2.0, description=f"See <{_openai_url}-presence_penalty>"
100 | )
101 | frequency_penalty: Optional[float] = Field(
102 | default=None,
103 | ge=-2.0,
104 | le=2.0,
105 | description=f"See <{_openai_url}-frequency_penalty>",
106 | )
107 | temperature: Optional[float] = Field(
108 | default=None, ge=0.0, le=2.0, description=f"See <{_openai_url}-temperature>"
109 | )
110 | top_p: Optional[float] = Field(
111 | default=None, ge=0.0, le=1.0, description=f"See <{_openai_url}-top_p>"
112 | )
113 | timeout: Optional[float] = Field(
114 | default=10.0, gt=0.0, description="Timeout for API requests in seconds"
115 | )
116 |
117 |
118 | class ChatOptions(OpenAiApiCallOptions):
119 | """Model for the chat's configuration options."""
120 |
121 | username: str = Field(default=getuser(), description="Name of the chat's user")
122 | assistant_name: str = Field(default="Rob", description="Name of the chat's assistant")
123 | system_name: str = Field(
124 | default=f"{GeneralDefinitions.PACKAGE_NAME}_system",
125 | description="Name of the chat's system",
126 | )
127 | ai_instructions: tuple[str, ...] = Field(
128 | default=(
129 | "You answer correctly.",
130 | "You do not lie or make up information unless explicitly asked to do so.",
131 | ),
132 | description="Initial instructions for the AI",
133 | )
134 | context_model: Literal[tuple(PRICE_PER_K_TOKENS_EMBEDDINGS)] = Field(
135 | default=next(iter(PRICE_PER_K_TOKENS_EMBEDDINGS)),
136 | description=(
137 | "Model to use for chat context (~memory). "
138 | + "Once picked, it cannot be changed."
139 | ),
140 | json_schema_extra={"frozen": True},
141 | )
142 | initial_greeting: Optional[str] = Field(
143 | default="", description="Initial greeting given by the assistant"
144 | )
145 | private_mode: Optional[bool] = Field(
146 | default=False,
147 | description="Toggle private mode. If this flag is used, the chat will not "
148 | + "be logged and the chat history will not be saved.",
149 | )
150 | api_connection_max_n_attempts: int = Field(
151 | default=5,
152 | gt=0,
153 | description="Maximum number of attempts to connect to the OpenAI API",
154 | )
155 | language: str = Field(
156 | default="en",
157 | description="Initial language adopted by the assistant. Use either the ISO-639-1 "
158 | "format (e.g. 'pt'), or an RFC5646 language tag (e.g. 'pt-br').",
159 | )
160 | tts_engine: Literal["openai", "google"] = Field(
161 | default="openai",
162 | description="The text-to-speech engine to use. The `google` engine is free "
163 | "(for now, at least), but the `openai` engine (which will charge from your "
164 | "API credits) sounds more natural.",
165 | )
166 | stt_engine: Literal["openai", "google"] = Field(
167 | default="google",
168 | description="The preferred speech-to-text engine to use. The `google` engine is "
169 |         "free (for now, at least); the `openai` engine is less susceptible to outages.",
170 | )
171 | openai_tts_voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = (
172 | Field(default="onyx", description="Voice to use for OpenAI's TTS")
173 | )
174 |
175 |
176 | class VoiceAssistantConfigs(BaseConfigModel):
177 |     """Model for the voice assistant's configuration options."""
178 |
179 | exit_expressions: list[str] = Field(
180 | default=["bye-bye", "ok bye-bye", "okay bye-bye"],
181 | description="Expression(s) to use in order to exit the chat",
182 | json_schema_extra={"changeable": False},
183 | )
184 |
185 | cancel_expressions: list[str] = Field(
186 | default=["ok", "okay", "cancel", "stop", "listen"],
187 | description="Word(s) to use in order to cancel the current reply",
188 | json_schema_extra={"changeable": False},
189 | )
190 |
191 | min_speech_duration_seconds: float = Field(
192 | default=0.1,
193 | gt=0,
194 | description="Minimum duration of speech (in seconds) for the assistant to listen",
195 | json_schema_extra={"changeable": False},
196 | )
197 | inactivity_timeout_seconds: int = Field(
198 | default=1,
199 | gt=0,
200 |         description="How long (in seconds) the user should be inactive "
201 |         "for the assistant to stop listening",
202 | )
203 | speech_likelihood_threshold: float = Field(
204 | default=0.5,
205 | ge=0.0,
206 | le=1.0,
207 | description="Accept audio as speech if the likelihood is above this threshold",
208 | json_schema_extra={"changeable": False},
209 | )
210 |     # sample_rate and frame_duration have to be consistent with the values accepted by
211 | # the webrtcvad package
212 | sample_rate: Literal[8000, 16000, 32000, 48000] = Field(
213 | default=48000,
214 | description="Sample rate for audio recording, in Hz.",
215 | json_schema_extra={"changeable": False},
216 | )
217 | frame_duration: Literal[10, 20, 30] = Field(
218 | default=20,
219 | description="Frame duration for audio recording, in milliseconds.",
220 | json_schema_extra={"changeable": False},
221 | )
222 | reply_only_as_text: Optional[bool] = Field(
223 | default=None, description="Reply only as text. The assistant will not speak."
224 | )
225 | skip_initial_greeting: Optional[bool] = Field(
226 | default=None, description="Skip initial greeting."
227 | )
228 |
229 |
230 | class VoiceChatConfigs(ChatOptions, VoiceAssistantConfigs):
231 | """Model for the voice chat's configuration options."""
232 |
--------------------------------------------------------------------------------
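Editor's note: these config models double as introspectable schemas for the argparse wrapper shown earlier. A short round-trip sketch (field values are illustrative):

```python
from pathlib import Path
from pyrobbot.chat_configs import ChatOptions

opts = ChatOptions(language="pt-br", tts_engine="google")
print(ChatOptions.get_default("language"))           # 'en'
print(ChatOptions.get_allowed_values("tts_engine"))  # ('openai', 'google')
print(opts["assistant_name"])                        # dict-style access: 'Rob'

opts.export(Path("/tmp/configs.json"))  # writes only explicitly-set fields
restored = ChatOptions.from_file(Path("/tmp/configs.json"))
assert restored.language == "pt-br"
```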
/pyrobbot/chat_context.py:
--------------------------------------------------------------------------------
1 | """Chat context/history management."""
2 |
3 | import ast
4 | import itertools
5 | from abc import ABC, abstractmethod
6 | from datetime import datetime, timezone
7 | from typing import TYPE_CHECKING
8 |
9 | import numpy as np
10 | import openai
11 | import pandas as pd
12 | from scipy.spatial.distance import cosine as cosine_distance
13 |
14 | from .embeddings_database import EmbeddingsDatabase
15 | from .general_utils import retry
16 |
17 | if TYPE_CHECKING:
18 | from .chat import Chat
19 |
20 |
21 | class ChatContext(ABC):
22 | """Abstract base class for representing the context of a chat."""
23 |
24 | def __init__(self, parent_chat: "Chat"):
25 | """Initialise the instance given a parent `Chat` object."""
26 | self.parent_chat = parent_chat
27 | self.database = EmbeddingsDatabase(
28 | db_path=self.context_file_path, embedding_model=self.embedding_model
29 | )
30 | self._msg_fields_for_context = ["role", "content"]
31 |
32 | @property
33 | def embedding_model(self):
34 | """Return the embedding model used for context management."""
35 | return self.parent_chat.context_model
36 |
37 | @property
38 | def context_file_path(self):
39 | """Return the path to the context file."""
40 | return self.parent_chat.context_file_path
41 |
42 | def add_to_history(self, exchange_id: str, msg_list: list[dict]):
43 | """Add message exchange to history."""
44 | self.database.insert_message_exchange(
45 | exchange_id=exchange_id,
46 | chat_model=self.parent_chat.model,
47 | message_exchange=msg_list,
48 | embedding=self.request_embedding(msg_list=msg_list),
49 | )
50 |
51 | def load_history(self) -> list[dict]:
52 | """Load the chat history."""
53 | db_history_df = self.database.retrieve_history()
54 |
55 | # Convert unix timestamps to datetime objs at the local timezone
56 | db_history_df["timestamp"] = db_history_df["timestamp"].apply(
57 | lambda ts: datetime.fromtimestamp(ts)
58 | .replace(microsecond=0, tzinfo=timezone.utc)
59 | .astimezone(tz=None)
60 | .replace(tzinfo=None)
61 | )
62 |
63 | msg_exchanges = db_history_df["message_exchange"].apply(ast.literal_eval).tolist()
64 | # Add timestamps and path to eventual audio files to messages
65 | for i_msg_exchange, timestamp in enumerate(db_history_df["timestamp"]):
66 | # Index 0 is for the user's message, index 1 is for the assistant's reply
67 | msg_exchanges[i_msg_exchange][0]["timestamp"] = timestamp
68 | msg_exchanges[i_msg_exchange][1]["reply_audio_file_path"] = db_history_df[
69 | "reply_audio_file_path"
70 | ].iloc[i_msg_exchange]
71 | msg_exchanges[i_msg_exchange][1]["chat_model"] = db_history_df[
72 | "chat_model"
73 | ].iloc[i_msg_exchange]
74 |
75 | return list(itertools.chain.from_iterable(msg_exchanges))
76 |
77 | def get_context(self, msg: dict):
78 | """Return messages to serve as context for `msg` when requesting a completion."""
79 | return _make_list_of_context_msgs(
80 | history=self.select_relevant_history(msg=msg),
81 | system_name=self.parent_chat.system_name,
82 | )
83 |
84 | @abstractmethod
85 | def request_embedding(self, msg_list: list[dict]):
86 | """Request embedding from OpenAI API."""
87 |
88 | @abstractmethod
89 | def select_relevant_history(self, msg: dict):
90 | """Select chat history msgs to use as context for `msg`."""
91 |
92 |
93 | class FullHistoryChatContext(ChatContext):
94 | """Context class using full chat history."""
95 |
96 | # Implement abstract methods
97 | def request_embedding(self, msg_list: list[dict]): # noqa: ARG002
98 | """Return a placeholder embedding."""
99 | return
100 |
101 | def select_relevant_history(self, msg: dict): # noqa: ARG002
102 | """Select chat history msgs to use as context for `msg`."""
103 | history = []
104 | for full_history_msg in self.load_history():
105 | history_msg = {
106 | k: v
107 | for k, v in full_history_msg.items()
108 | if k in self._msg_fields_for_context
109 | }
110 | history.append(history_msg)
111 | return history
112 |
113 |
114 | class EmbeddingBasedChatContext(ChatContext):
115 | """Chat context using embedding models."""
116 |
117 | def request_embedding_for_text(self, text: str):
118 | """Request embedding for `text` from OpenAI according to used embedding model."""
119 | embedding_request = request_embedding_from_openai(
120 | text=text,
121 | model=self.embedding_model,
122 | openai_client=self.parent_chat.openai_client,
123 | )
124 |
125 | # Update parent chat's token usage db with tokens used in embedding request
126 | for db in [
127 | self.parent_chat.general_token_usage_db,
128 | self.parent_chat.token_usage_db,
129 | ]:
130 | for comm_type, n_tokens in embedding_request["tokens_usage"].items():
131 | input_or_output_kwargs = {f"n_{comm_type}_tokens": n_tokens}
132 | db.insert_data(model=self.embedding_model, **input_or_output_kwargs)
133 |
134 | return embedding_request["embedding"]
135 |
136 | # Implement abstract methods
137 | def request_embedding(self, msg_list: list[dict]):
138 | """Convert `msg_list` into a paragraph and get embedding from OpenAI API call."""
139 | text = "\n".join(
140 | [f"{msg['role'].strip()}: {msg['content'].strip()}" for msg in msg_list]
141 | )
142 | return self.request_embedding_for_text(text=text)
143 |
144 | def select_relevant_history(self, msg: dict):
145 | """Select chat history msgs to use as context for `msg`."""
146 | relevant_history = []
147 | for full_context_msg in _select_relevant_history(
148 | history_df=self.database.retrieve_history(),
149 | embedding=self.request_embedding_for_text(text=msg["content"]),
150 | ):
151 | context_msg = {
152 | k: v
153 | for k, v in full_context_msg.items()
154 | if k in self._msg_fields_for_context
155 | }
156 | relevant_history.append(context_msg)
157 | return relevant_history
158 |
159 |
160 | @retry()
161 | def request_embedding_from_openai(text: str, model: str, openai_client: openai.OpenAI):
162 | """Request embedding for `text` according to context model `model` from OpenAI."""
163 | text = text.strip()
164 | embedding_request = openai_client.embeddings.create(input=[text], model=model)
165 |
166 | embedding = embedding_request.data[0].embedding
167 |
168 | input_tokens = embedding_request.usage.prompt_tokens
169 | output_tokens = embedding_request.usage.total_tokens - input_tokens
170 | tokens_usage = {"input": input_tokens, "output": output_tokens}
171 |
172 | return {"embedding": embedding, "tokens_usage": tokens_usage}
173 |
174 |
175 | def _make_list_of_context_msgs(history: list[dict], system_name: str):
176 | sys_directives = "Considering the previous messages, answer the next message:"
177 | sys_msg = {"role": "system", "name": system_name, "content": sys_directives}
178 | return [*history, sys_msg]
179 |
180 |
181 | def _select_relevant_history(
182 | history_df: pd.DataFrame,
183 | embedding: np.ndarray,
184 | max_n_prompt_reply_pairs: int = 5,
185 | max_n_tailing_prompt_reply_pairs: int = 2,
186 | ):
187 | history_df["embedding"] = (
188 | history_df["embedding"].apply(ast.literal_eval).apply(np.array)
189 | )
190 |     history_df["similarity"] = history_df["embedding"].apply(
191 |         lambda x: 1.0 - cosine_distance(x, embedding)  # scipy's `cosine` is a distance
192 |     )
193 |
194 | # Get the last messages added to the history
195 | df_last_n_chats = history_df.tail(max_n_tailing_prompt_reply_pairs)
196 |
197 | # Get the most similar messages
198 | df_similar_chats = (
199 | history_df.sort_values("similarity", ascending=False)
200 | .head(max_n_prompt_reply_pairs)
201 | .sort_values("timestamp")
202 | )
203 |
204 | df_context = pd.concat([df_similar_chats, df_last_n_chats])
205 |     selected_history = (
206 |         df_context["message_exchange"].drop_duplicates().apply(ast.literal_eval)
207 |     ).tolist()
208 |
209 | return list(itertools.chain.from_iterable(selected_history))
210 |
--------------------------------------------------------------------------------
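Editor's note: for embedding-based contexts, `get_context` returns up to five of the most similar past exchanges plus the last two, followed by a fixed system directive. A hedged sketch of the output shape (a `chat` instance with some history is assumed to exist):

```python
msg = {"role": "user", "content": "What did we decide about the schedule?"}
context = chat.context_handler.get_context(msg=msg)

# `context` is a list of {"role", "content"} dicts ending with the directive:
# {"role": "system", "name": chat.system_name,
#  "content": "Considering the previous messages, answer the next message:"}
conversation = [chat.base_directive, *context, msg]  # as assembled in chat.py
```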
/pyrobbot/command_definitions.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | """Commands supported by the package's script."""
3 | import subprocess
4 |
5 | from loguru import logger
6 |
7 | from . import GeneralDefinitions
8 | from .chat import Chat
9 | from .chat_configs import ChatOptions
10 | from .voice_chat import VoiceChat
11 |
12 |
13 | def voice_chat(args):
14 | """Start a voice-based chat."""
15 | VoiceChat.from_cli_args(cli_args=args).start()
16 |
17 |
18 | def browser_chat(args):
19 | """Run the chat on the browser."""
20 | ChatOptions.from_cli_args(args).export(fpath=GeneralDefinitions.PARSED_ARGS_FILE)
21 | try:
22 | subprocess.run(
23 | [ # noqa: S603, S607
24 | "streamlit",
25 | "run",
26 | GeneralDefinitions.APP_PATH.as_posix(),
27 | "--",
28 | GeneralDefinitions.PARSED_ARGS_FILE.as_posix(),
29 | ],
30 | cwd=GeneralDefinitions.APP_DIR.as_posix(),
31 | check=True,
32 | )
33 | except (KeyboardInterrupt, EOFError):
34 | logger.info("Exiting.")
35 |
36 |
37 | def terminal_chat(args):
38 | """Run the chat on the terminal."""
39 | chat = Chat.from_cli_args(cli_args=args)
40 | chat.start()
41 | if args.report_accounting_when_done:
42 | chat.report_token_usage(report_general=True)
43 |
44 |
45 | def accounting_report(args):
46 | """Show the accumulated costs of the chat and exit."""
47 | chat = Chat.from_cli_args(cli_args=args)
48 |     # Prevent the chat from creating an entry in the cache directory
49 | chat.private_mode = True
50 | chat.report_token_usage(report_general=True, report_current_chat=False)
51 |
--------------------------------------------------------------------------------
/pyrobbot/embeddings_database.py:
--------------------------------------------------------------------------------
1 | """Management of embeddings/chat history storage and retrieval."""
2 |
3 | import datetime
4 | import json
5 | import sqlite3
6 | from pathlib import Path
7 | from typing import Union
8 |
9 | import pandas as pd
10 | from loguru import logger
11 |
12 |
13 | class EmbeddingsDatabase:
14 | """Class for managing an SQLite database storing embeddings and associated data."""
15 |
16 | def __init__(self, db_path: Path, embedding_model: str):
17 | """Initialise the EmbeddingsDatabase object.
18 |
19 | Args:
20 | db_path (Path): The path to the SQLite database file.
21 | embedding_model (str): The embedding model associated with this database.
22 | """
23 | self.db_path = db_path
24 | self.embedding_model = embedding_model
25 | self.create()
26 |
27 | def create(self):
28 | """Create the necessary tables and triggers in the SQLite database."""
29 | self.db_path.parent.mkdir(parents=True, exist_ok=True)
30 | conn = sqlite3.connect(self.db_path)
31 |
32 |         # SQL to create the needed tables
33 | create_table_sqls = {
34 | "embedding_model": """
35 | CREATE TABLE IF NOT EXISTS embedding_model (
36 | created_timestamp INTEGER NOT NULL,
37 | embedding_model TEXT NOT NULL,
38 | PRIMARY KEY (embedding_model)
39 | )
40 | """,
41 | "messages": """
42 | CREATE TABLE IF NOT EXISTS messages (
43 | id TEXT PRIMARY KEY NOT NULL,
44 | timestamp INTEGER NOT NULL,
45 | chat_model TEXT NOT NULL,
46 | message_exchange TEXT NOT NULL,
47 | embedding TEXT
48 | )
49 | """,
50 | "reply_audio_files": """
51 | CREATE TABLE IF NOT EXISTS reply_audio_files (
52 | id TEXT PRIMARY KEY NOT NULL,
53 | file_path TEXT NOT NULL,
54 | FOREIGN KEY (id) REFERENCES messages(id) ON DELETE CASCADE
55 | )
56 | """,
57 | }
58 |
59 | with conn:
60 | for table_name, table_create_sql in create_table_sqls.items():
61 | # Create tables
62 | conn.execute(table_create_sql)
63 |
64 | # Create triggers to prevent modification after insertion
65 | conn.execute(
66 | f"""
67 | CREATE TRIGGER IF NOT EXISTS prevent_{table_name}_modification
68 | BEFORE UPDATE ON {table_name}
69 | BEGIN
70 | SELECT RAISE(FAIL, 'Table "{table_name}": modification not allowed');
71 | END;
72 | """
73 | )
74 |
75 | # Close the connection to the database
76 | conn.close()
77 |
78 | def get_embedding_model(self):
79 | """Retrieve the database's embedding model.
80 |
81 | Returns:
82 |             str: The embedding model, or None if the database is not yet initialised.
83 | """
84 | conn = sqlite3.connect(self.db_path)
85 | query = "SELECT embedding_model FROM embedding_model;"
86 | # Execute the query and fetch the result
87 | embedding_model = None
88 | with conn:
89 | cur = conn.cursor()
90 | cur.execute(query)
91 | result = cur.fetchone()
92 | embedding_model = result[0] if result else None
93 |
94 | conn.close()
95 |
96 | return embedding_model
97 |
98 | def insert_message_exchange(
99 | self, exchange_id, chat_model, message_exchange, embedding
100 | ):
101 | """Insert a message exchange into the database's 'messages' table.
102 |
103 | Args:
104 | exchange_id (str): The id of the message exchange.
105 | chat_model (str): The chat model.
106 | message_exchange: The message exchange.
107 | embedding: The embedding associated with the message exchange.
108 |
109 | Raises:
110 | ValueError: If the database already contains a different embedding model.
111 | """
112 | stored_embedding_model = self.get_embedding_model()
113 | if stored_embedding_model is None:
114 | self._init_database()
115 | elif stored_embedding_model != self.embedding_model:
116 | raise ValueError(
117 | "Database already contains a different embedding model: "
118 | f"{self.get_embedding_model()}.\n"
119 | "Cannot continue."
120 | )
121 |
122 | timestamp = int(datetime.datetime.utcnow().timestamp())
123 | message_exchange = json.dumps(message_exchange)
124 | embedding = json.dumps(embedding)
125 | conn = sqlite3.connect(self.db_path)
126 | sql = """
127 | INSERT INTO messages (id, timestamp, chat_model, message_exchange, embedding)
128 | VALUES (?, ?, ?, ?, ?)"""
129 | with conn:
130 | conn.execute(
131 | sql, (exchange_id, timestamp, chat_model, message_exchange, embedding)
132 | )
133 | conn.close()
134 |
135 | def insert_assistant_audio_file_path(
136 | self, exchange_id: str, file_path: Union[str, Path]
137 | ):
138 | """Insert the path to the assistant's reply audio file into the database.
139 |
140 | Args:
141 | exchange_id: The id of the message exchange.
142 | file_path: Path to the assistant's reply audio file.
143 | """
144 |         file_path = Path(file_path).as_posix()
145 | conn = sqlite3.connect(self.db_path)
146 | with conn:
147 | # Check if the corresponding id exists in the messages table
148 | cursor = conn.cursor()
149 | cursor.execute("SELECT 1 FROM messages WHERE id=?", (exchange_id,))
150 | exists = cursor.fetchone() is not None
151 | if exists:
152 | # Insert into reply_audio_files
153 | cursor.execute(
154 | "INSERT INTO reply_audio_files (id, file_path) VALUES (?, ?)",
155 | (exchange_id, file_path),
156 | )
157 | else:
158 | logger.error("The corresponding id does not exist in the messages table")
159 | conn.close()
160 |
161 | def retrieve_history(self, exchange_id=None):
162 | """Retrieve data from all tables in the db combined in a single dataframe."""
163 | query = """
164 | SELECT messages.id,
165 | messages.timestamp,
166 | messages.chat_model,
167 | messages.message_exchange,
168 | reply_audio_files.file_path AS reply_audio_file_path,
169 | embedding
170 | FROM messages
171 | LEFT JOIN reply_audio_files
172 | ON messages.id = reply_audio_files.id
173 | """
174 | if exchange_id:
175 |             query += " WHERE messages.id = ?"  # parameterised to avoid SQL injection
176 |
177 | conn = sqlite3.connect(self.db_path)
178 | with conn:
179 |             data_df = pd.read_sql_query(query, conn, params=(exchange_id,) if exchange_id else None)
180 | conn.close()
181 |
182 | return data_df
183 |
184 | @property
185 | def n_entries(self):
186 | """Return the number of entries in the `messages` table."""
187 | conn = sqlite3.connect(self.db_path)
188 | query = "SELECT COUNT(*) FROM messages;"
189 | with conn:
190 | cur = conn.cursor()
191 | cur.execute(query)
192 | result = cur.fetchone()
193 | conn.close()
194 | return result[0]
195 |
196 | def _init_database(self):
197 | """Initialise the 'embedding_model' table in the database."""
198 | conn = sqlite3.connect(self.db_path)
199 | create_time = int(datetime.datetime.utcnow().timestamp())
200 | sql = "INSERT INTO embedding_model "
201 | sql += "(created_timestamp, embedding_model) VALUES (?, ?);"
202 | with conn:
203 | conn.execute(sql, (create_time, self.embedding_model))
204 | conn.close()
205 |
--------------------------------------------------------------------------------
/pyrobbot/general_utils.py:
--------------------------------------------------------------------------------
1 | """General utility functions and classes."""
2 |
3 | import difflib
4 | import inspect
5 | import json
6 | import re
7 | import time
8 | from functools import wraps
9 | from pathlib import Path
10 | from typing import Optional
11 |
12 | import httpx
13 | import openai
14 | from loguru import logger
15 | from pydub import AudioSegment
16 | from pydub.silence import detect_leading_silence
17 |
18 |
19 | class ReachedMaxNumberOfAttemptsError(Exception):
20 | """Error raised when the max number of attempts has been reached."""
21 |
22 |
23 | def _get_lower_alphanumeric(string: str):
24 | """Return a string with only lowercase alphanumeric characters."""
25 | return re.sub("[^0-9a-zA-Z]+", " ", string.strip().lower())
26 |
27 |
28 | def str2_minus_str1(str1: str, str2: str):
29 |     """Return the characters in str2 that are not in str1."""
30 | output_list = [diff for diff in difflib.ndiff(str1, str2) if diff[0] == "+"]
31 | str_diff = "".join(el.replace("+ ", "") for el in output_list if el.startswith("+"))
32 | return str_diff
33 |
34 |
35 | def get_call_traceback(depth=5):
36 | """Get the traceback of the call to the function."""
37 | curframe = inspect.currentframe()
38 | callframe = inspect.getouterframes(curframe)
39 | call_path = []
40 | for iframe, frame in enumerate(callframe):
41 | fpath = frame.filename
42 | lineno = frame.lineno
43 | function = frame.function
44 |         code_context = frame.code_context[0].strip() if frame.code_context else ""
45 | call_path.append(
46 | {
47 | "fpath": fpath,
48 | "lineno": lineno,
49 | "function": function,
50 | "code_context": code_context,
51 | }
52 | )
53 | if iframe == depth:
54 | break
55 | return call_path
56 |
57 |
58 | def trim_beginning(audio: AudioSegment, **kwargs):
59 | """Trim the beginning of the audio to remove silence."""
60 | beginning = detect_leading_silence(audio, **kwargs)
61 | return audio[beginning:]
62 |
63 |
64 | def trim_ending(audio: AudioSegment, **kwargs):
65 | """Trim the ending of the audio to remove silence."""
66 | audio = trim_beginning(audio.reverse(), **kwargs)
67 | return audio.reverse()
68 |
69 |
70 | def trim_silence(audio: AudioSegment, **kwargs):
71 | """Trim the silence from the beginning and ending of the audio."""
72 | kwargs["silence_threshold"] = kwargs.get("silence_threshold", -40.0)
73 | audio = trim_beginning(audio, **kwargs)
74 | return trim_ending(audio, **kwargs)
75 |
76 |
77 | def retry(
78 | max_n_attempts: int = 5,
79 |     handled_errors: tuple[type[Exception], ...] = (
80 | openai.APITimeoutError,
81 | httpx.HTTPError,
82 | RuntimeError,
83 | ),
84 | error_msg: Optional[str] = None,
85 | ):
86 | """Retry executing the decorated function/generator."""
87 |
88 | def retry_or_fail(error):
89 | """Decide whether to retry or fail based on the number of attempts."""
90 | retry_or_fail.execution_count = getattr(retry_or_fail, "execution_count", 0) + 1
91 |
92 | if retry_or_fail.execution_count < max_n_attempts:
93 | logger.warning(
94 | "{}. Making new attempt ({}/{})...",
95 | error,
96 | retry_or_fail.execution_count + 1,
97 | max_n_attempts,
98 | )
99 | time.sleep(1)
100 | else:
101 | raise ReachedMaxNumberOfAttemptsError(error_msg) from error
102 |
103 | def retry_decorator(function):
104 | """Wrap `function`."""
105 |
106 | @wraps(function)
107 | def wrapper_f(*args, **kwargs):
108 | while True:
109 | try:
110 | return function(*args, **kwargs)
111 | except handled_errors as error: # noqa: PERF203
112 | retry_or_fail(error=error)
113 |
114 | @wraps(function)
115 | def wrapper_generator_f(*args, **kwargs):
116 | success = False
117 | while not success:
118 | try:
119 | yield from function(*args, **kwargs)
120 | except handled_errors as error: # noqa: PERF203
121 | retry_or_fail(error=error)
122 | else:
123 | success = True
124 |
125 | return wrapper_generator_f if inspect.isgeneratorfunction(function) else wrapper_f
126 |
127 | return retry_decorator
128 |
129 |
130 | class AlternativeConstructors:
131 | """Mixin class for alternative constructors."""
132 |
133 | @classmethod
134 | def from_dict(cls, configs: dict, **kwargs):
135 | """Creates an instance from a configuration dictionary.
136 |
137 |         Validates the configuration dictionary against the class's config model
138 |         and uses the result to instantiate the class.
139 |
140 | Args:
141 | configs (dict): The configuration options as a dictionary.
142 | **kwargs: Additional keyword arguments to pass to the class constructor.
143 |
144 | Returns:
145 | cls: An instance of Chat initialized with the given configurations.
146 | """
147 | return cls(configs=cls.default_configs.model_validate(configs), **kwargs)
148 |
149 | @classmethod
150 | def from_cli_args(cls, cli_args, **kwargs):
151 | """Creates an instance from CLI arguments.
152 |
153 | Extracts relevant options from the CLI arguments and initializes a class instance
154 | with them.
155 |
156 | Args:
157 | cli_args: The command line arguments.
158 | **kwargs: Additional keyword arguments to pass to the class constructor.
159 |
160 | Returns:
161 | cls: An instance of the class initialized with CLI-specified configurations.
162 | """
163 | chat_opts = {
164 | k: v
165 | for k, v in vars(cli_args).items()
166 | if k in cls.default_configs.model_fields and v is not None
167 | }
168 | return cls.from_dict(chat_opts, **kwargs)
169 |
170 | @classmethod
171 | def from_cache(cls, cache_dir: Path, **kwargs):
172 | """Loads an instance from a cache directory.
173 |
174 | Args:
175 | cache_dir (Path): The path to the cache directory.
176 | **kwargs: Additional keyword arguments to pass to the class constructor.
177 |
178 | Returns:
179 | cls: An instance of the class loaded with cached configurations and metadata.
180 | """
181 | try:
182 | with open(cache_dir / "configs.json", "r") as configs_f:
183 | new_configs = json.load(configs_f)
184 | except FileNotFoundError:
185 | logger.warning(
186 | "Could not find config file in cache directory <{}>. "
187 | + "Creating {} with default configs.",
188 | cache_dir,
189 | cls.__name__,
190 | )
191 | new_configs = cls.default_configs.model_dump()
192 |
193 | try:
194 | with open(cache_dir / "metadata.json", "r") as metadata_f:
195 | new_metadata = json.load(metadata_f)
196 | except FileNotFoundError:
197 | logger.warning(
198 | "Could not find metadata file in cache directory <{}>. "
199 | + "Creating {} with default metadata.",
200 | cache_dir,
201 | cls.__name__,
202 | )
203 | new_metadata = None
204 |
205 | new = cls.from_dict(new_configs, **kwargs)
206 | if new_metadata is not None:
207 | new.metadata = new_metadata
208 | logger.debug(
209 |             "Resetting chat_id from cache: {} --> {}.",
210 | new.id,
211 | new.metadata["chat_id"],
212 | )
213 | new.id = new.metadata["chat_id"]
214 |
215 | return new
216 |
--------------------------------------------------------------------------------
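Editor's note: `retry` wraps plain functions and generators alike. A small self-contained sketch with a deliberately flaky function (attempt counts and messages are illustrative):

```python
from pyrobbot.general_utils import ReachedMaxNumberOfAttemptsError, retry

@retry(max_n_attempts=3, handled_errors=(RuntimeError,), error_msg="gave up")
def flaky(counter={"n": 0}):  # mutable default doubles as a crude call counter
    counter["n"] += 1
    if counter["n"] < 3:
        raise RuntimeError("transient failure")
    return "ok"

try:
    print(flaky())  # fails twice, then succeeds on the third attempt
except ReachedMaxNumberOfAttemptsError:
    print("all attempts failed")
```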
/pyrobbot/internet_utils.py:
--------------------------------------------------------------------------------
1 | """Internet search module for the package."""
2 |
3 | import asyncio
4 | import re
5 |
6 | import numpy as np
7 | import requests
8 | from bs4 import BeautifulSoup
9 | from bs4.element import Comment
10 | from duckduckgo_search import AsyncDDGS
11 | from loguru import logger
12 | from sklearn.feature_extraction.text import TfidfVectorizer
13 | from sklearn.metrics.pairwise import cosine_similarity
14 | from unidecode import unidecode
15 |
16 | from . import GeneralDefinitions
17 | from .general_utils import retry
18 |
19 |
20 | def cosine_similarity_sentences(sentence1, sentence2):
21 | """Compute the cosine similarity between two sentences."""
22 | vectorizer = TfidfVectorizer()
23 | vectors = vectorizer.fit_transform([sentence1, sentence2])
24 | similarity = cosine_similarity(vectors[0], vectors[1])
25 | return similarity[0][0]
26 |
27 |
28 | def element_is_visible(element):
29 | """Return True if the element is visible."""
30 | tags_to_exclude = [
31 | "[document]",
32 | "head",
33 | "header",
34 | "html",
35 | "input",
36 | "meta",
37 | "noscript",
38 | "script",
39 | "style",
41 | "title",
42 | ]
43 | if element.parent.name in tags_to_exclude or isinstance(element, Comment):
44 | return False
45 | return True
46 |
47 |
48 | def extract_text_from_html(body):
49 | """Extract the text from an HTML document."""
50 | soup = BeautifulSoup(body, "html.parser")
51 |
52 | page_has_captcha = soup.find("div", id="recaptcha") is not None
53 | if page_has_captcha:
54 | return ""
55 |
56 | texts = soup.find_all(string=True)
57 | visible_texts = filter(element_is_visible, texts)
58 | return " ".join(t.strip() for t in visible_texts if t.strip())
59 |
60 |
61 | def find_whole_word_index(my_string, my_substring):
62 | """Find the index of a substring in a string, but only if it is a whole word match."""
63 | pattern = re.compile(r"\b{}\b".format(re.escape(my_substring)))
64 | match = pattern.search(my_string)
65 |
66 | if match:
67 | return match.start()
68 | return -1 # Substring not found
69 |
70 |
71 | async def async_raw_websearch(
72 | query: str,
73 | max_results: int = 5,
74 | region: str = GeneralDefinitions.IPINFO["country_name"],
75 | ):
76 | """Search the web using DuckDuckGo Search API."""
77 | async with AsyncDDGS(proxies=None) as addgs:
78 | results = await addgs.text(
79 | keywords=query,
80 | region=region,
81 | max_results=max_results,
82 | backend="html",
83 | )
84 | return results
85 |
86 |
87 | def raw_websearch(
88 | query: str,
89 | max_results: int = 5,
90 | region: str = GeneralDefinitions.IPINFO["country_name"],
91 | ):
92 | """Search the web using DuckDuckGo Search API."""
93 | raw_results = asyncio.run(
94 | async_raw_websearch(query=query, max_results=max_results, region=region)
95 | )
96 | raw_results = raw_results or []
97 |
98 | results = []
99 | for result in raw_results:
100 | if not isinstance(result, dict):
101 | logger.error("Expected a `dict`, got type {}: {}", type(result), result)
102 | results.append({})
103 | continue
104 |
105 | if result.get("body") is None:
106 | continue
107 |
108 | try:
109 | response = requests.get(result["href"], allow_redirects=False, timeout=10)
110 | except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
111 | continue
112 | else:
113 | content_type = response.headers.get("content-type")
114 | if (not content_type) or ("text/html" not in content_type):
115 | continue
116 | html = unidecode(extract_text_from_html(response.text))
117 |
118 | summary = unidecode(result["body"])
119 | relevance = cosine_similarity_sentences(query.lower(), summary.lower())
120 |
121 | relevance_threshold = 1e-2
122 | if relevance < relevance_threshold:
123 | continue
124 |
125 | new_results = {
126 | "href": result["href"],
127 | "summary": summary,
128 | "detailed": html,
129 | "relevance": relevance,
130 | }
131 | results.append(new_results)
132 | return results
133 |
134 |
135 | @retry(error_msg="Error performing web search")
136 | def websearch(query, **kwargs):
137 | """Search the web using DuckDuckGo Search API."""
138 | raw_results = raw_websearch(query, **kwargs)
139 | raw_results = iter(
140 | sorted(raw_results, key=lambda x: x.get("relevance", 0.0), reverse=True)
141 | )
142 | min_relevant_keyword_length = 4
143 | min_n_words = 40
144 |
145 | for result in raw_results:
146 | html = result.get("detailed", "")
147 |
148 | index_first_query_word_to_appear = np.inf
149 | for word in unidecode(query).split():
150 | if len(word) < min_relevant_keyword_length:
151 | continue
152 | index = find_whole_word_index(html.lower(), word.lower())
153 | if -1 < index < index_first_query_word_to_appear:
154 | index_first_query_word_to_appear = index
155 | if -1 < index_first_query_word_to_appear < np.inf:
156 | html = html[index_first_query_word_to_appear:]
157 |
158 | selected_words = html.split()[:500]
159 | if len(selected_words) < min_n_words:
160 | # Don't return results with less than approx one paragraph
161 | continue
162 |
163 | html = " ".join(selected_words)
164 |
165 | yield {
166 | "href": result.get("href", ""),
167 | "summary": result.get("summary", ""),
168 | "detailed": html,
169 |             "relevance": result.get("relevance", 0.0),
170 | }
171 | break
172 |
173 | for result in raw_results:
174 | yield {
175 | "href": result.get("href", ""),
176 | "summary": result.get("summary", ""),
177 |             "relevance": result.get("relevance", 0.0),
178 | }
179 |
--------------------------------------------------------------------------------
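
The relevance filter in `raw_websearch` is plain TF-IDF cosine similarity between the query and each result summary. A quick standalone check of what that score looks like (values are indicative; the exact numbers depend on scikit-learn's tokenization):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def cosine_similarity_sentences(sentence1, sentence2):
    vectors = TfidfVectorizer().fit_transform([sentence1, sentence2])
    return cosine_similarity(vectors[0], vectors[1])[0][0]


query = "weather forecast lisbon"
print(cosine_similarity_sentences(query, "lisbon weather forecast for today"))
# clearly above the 1e-2 threshold used by raw_websearch
print(cosine_similarity_sentences(query, "recipe for chocolate cake"))
# 0.0: no shared terms, so this result would be discarded
```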
/pyrobbot/openai_utils.py:
--------------------------------------------------------------------------------
1 | """Utils for using the OpenAI API."""
2 |
3 | import hashlib
4 | import shutil
5 | from typing import TYPE_CHECKING, Optional
6 |
7 | import openai
8 | from loguru import logger
9 |
10 | from . import GeneralDefinitions
11 | from .chat_configs import OpenAiApiCallOptions
12 | from .general_utils import retry
13 | from .tokens import get_n_tokens_from_msgs
14 |
15 | if TYPE_CHECKING:
16 | from .chat import Chat
17 |
18 |
19 | class OpenAiClientWrapper(openai.OpenAI):
20 | """Wrapper for OpenAI API client."""
21 |
22 | def __init__(self, *args, private_mode: bool = False, **kwargs):
23 | """Initialize the OpenAI API client wrapper."""
24 | super().__init__(*args, **kwargs)
25 | self.private_mode = private_mode
26 |
27 | self.required_cache_files = [
28 | "chat_token_usage.db",
29 | "configs.json",
30 | "embeddings.db",
31 | "metadata.json",
32 | ]
33 | self.clear_invalid_cache_dirs()
34 |
35 | @property
36 | def cache_dir(self):
37 | """Return client's cache dir according to the privacy configs."""
38 | return self.get_cache_dir(private_mode=self.private_mode)
39 |
40 | @property
41 | def saved_chat_cache_paths(self):
42 |         """Get the filepaths of saved chat contexts, oldest first (by st_ctime)."""
43 | yield from sorted(
44 |             self.cache_dir.glob("chat_*/"),
45 | key=lambda fpath: fpath.stat().st_ctime,
46 | )
47 |
48 | def clear_invalid_cache_dirs(self):
49 | """Remove cache directories that are missing required files."""
50 | for directory in self.cache_dir.glob("chat_*/"):
51 | if not all(
52 | (directory / fname).exists() for fname in self.required_cache_files
53 | ):
54 | logger.debug(f"Removing invalid cache directory: {directory}")
55 | shutil.rmtree(directory, ignore_errors=True)
56 |
57 | def get_cache_dir(self, private_mode: Optional[bool] = None):
58 | """Return the directory where the chats using the client will be stored."""
59 | if private_mode is None:
60 | private_mode = self.private_mode
61 |
62 | if private_mode:
63 | client_id = "demo"
64 | parent_dir = GeneralDefinitions.PACKAGE_TMPDIR
65 | else:
66 | client_id = hashlib.sha256(self.api_key.encode("utf-8")).hexdigest()
67 | parent_dir = GeneralDefinitions.PACKAGE_CACHE_DIRECTORY
68 |
69 | directory = parent_dir / f"user_{client_id}"
70 | directory.mkdir(parents=True, exist_ok=True)
71 |
72 | return directory
73 |
74 |
75 | def make_api_chat_completion_call(conversation: list, chat_obj: "Chat"):
76 | """Stream a chat completion from OpenAI API given a conversation and a chat object.
77 |
78 | Args:
79 | conversation (list): A list of messages passed as input for the completion.
80 | chat_obj (Chat): Chat object containing the configurations for the chat.
81 |
82 | Yields:
83 | str: Chunks of text generated by the API in response to the conversation.
84 | """
85 | api_call_args = {}
86 | for field in OpenAiApiCallOptions.model_fields:
87 | if getattr(chat_obj, field) is not None:
88 | api_call_args[field] = getattr(chat_obj, field)
89 |
90 | logger.trace(
91 | "Making OpenAI API call with chat=<{}>, args {} and messages {}",
92 | chat_obj.id,
93 | api_call_args,
94 | conversation,
95 | )
96 |
97 | @retry(error_msg="Problems connecting to OpenAI API")
98 | def stream_reply(conversation, **api_call_args):
99 | # Update the chat's token usage database with tokens used in chat input
100 | # Do this here because every attempt consumes tokens, even if it fails
101 | n_tokens = get_n_tokens_from_msgs(messages=conversation, model=chat_obj.model)
102 | for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]:
103 | db.insert_data(model=chat_obj.model, n_input_tokens=n_tokens)
104 |
105 | full_reply_content = ""
106 | for completion_chunk in chat_obj.openai_client.chat.completions.create(
107 | messages=conversation, stream=True, **api_call_args
108 | ):
109 | reply_chunk = getattr(completion_chunk.choices[0].delta, "content", "")
110 | if reply_chunk is None:
111 | break
112 | full_reply_content += reply_chunk
113 | yield reply_chunk
114 |
115 | # Update the chat's token usage database with tokens used in chat output
116 | reply_as_msg = {"role": "assistant", "content": full_reply_content}
117 | n_tokens = get_n_tokens_from_msgs(messages=[reply_as_msg], model=chat_obj.model)
118 | for db in [chat_obj.general_token_usage_db, chat_obj.token_usage_db]:
119 | db.insert_data(model=chat_obj.model, n_output_tokens=n_tokens)
120 |
121 | logger.trace("Done with OpenAI API call")
122 | yield from stream_reply(conversation, **api_call_args)
123 |
--------------------------------------------------------------------------------
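
A small sketch of the per-user cache-directory scheme `get_cache_dir` implements: the API key itself is never written to disk; only its SHA-256 digest names the user's directory. The `/tmp/pyrobbot-cache` parent below is just an example path:

```python
import hashlib
from pathlib import Path


def cache_dir_for(api_key: str, parent: Path) -> Path:
    # Hash the key so the directory name cannot leak the credential
    client_id = hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    directory = parent / f"user_{client_id}"
    directory.mkdir(parents=True, exist_ok=True)
    return directory


print(cache_dir_for("sk-example", Path("/tmp/pyrobbot-cache")))
# /tmp/pyrobbot-cache/user_<64-hex-char digest>
```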
/pyrobbot/sst_and_tts.py:
--------------------------------------------------------------------------------
1 | """Code related to speech-to-text and text-to-speech conversions."""
2 |
3 | import io
4 | import socket
5 | import uuid
6 | from dataclasses import dataclass, field
7 | from typing import Literal
8 |
9 | import numpy as np
10 | import speech_recognition as sr
11 | from gtts import gTTS
12 | from loguru import logger
13 | from openai import OpenAI
14 | from pydub import AudioSegment
15 |
16 | from .general_utils import retry
17 | from .tokens import TokenUsageDatabase
18 |
19 |
20 | @dataclass
21 | class SpeechAndTextConfigs:
22 | """Configs for speech-to-text and text-to-speech."""
23 |
24 | openai_client: OpenAI
25 | general_token_usage_db: TokenUsageDatabase
26 | token_usage_db: TokenUsageDatabase
27 | engine: Literal["openai", "google"] = "google"
28 | language: str = "en"
29 | timeout: int = 10
30 |
31 |
32 | @dataclass
33 | class SpeechToText(SpeechAndTextConfigs):
34 | """Class for converting speech to text."""
35 |
36 | speech: AudioSegment = None
37 | _text: str = field(init=False, default="")
38 |
39 | def __post_init__(self):
40 | if not self.speech:
41 | self.speech = AudioSegment.silent(duration=0)
42 | self.recogniser = sr.Recognizer()
43 | self.recogniser.operation_timeout = self.timeout
44 |
45 | wav_buffer = io.BytesIO()
46 | self.speech.export(wav_buffer, format="wav")
47 | wav_buffer.seek(0)
48 | with sr.AudioFile(wav_buffer) as source:
49 | self.audio_data = self.recogniser.listen(source)
50 |
51 | @property
52 | def text(self) -> str:
53 | """Return the text from the speech."""
54 | if not self._text:
55 | self._text = self._stt()
56 | return self._text
57 |
58 | def _stt(self) -> str:
59 | """Perform speech-to-text."""
60 | if not self.speech:
61 | logger.debug("No speech detected")
62 | return ""
63 |
64 | if self.engine == "openai":
65 | stt_function = self._stt_openai
66 | fallback_stt_function = self._stt_google
67 | fallback_name = "google"
68 | else:
69 | stt_function = self._stt_google
70 | fallback_stt_function = self._stt_openai
71 | fallback_name = "openai"
72 |
73 | conversion_id = uuid.uuid4()
74 | logger.debug(
75 | "Converting audio to text ({} STT). Process {}.", self.engine, conversion_id
76 | )
77 | try:
78 | rtn = stt_function()
79 | except (
80 | ConnectionResetError,
81 | socket.timeout,
82 | sr.exceptions.RequestError,
83 | ) as error:
84 | logger.error(error)
85 | logger.error(
86 | "{}: Can't communicate with `{}` speech-to-text API right now",
87 | conversion_id,
88 | self.engine,
89 | )
90 | logger.warning(
91 | "{}: Trying to use `{}` STT instead", conversion_id, fallback_name
92 | )
93 | rtn = fallback_stt_function()
94 | except sr.exceptions.UnknownValueError:
95 | logger.opt(colors=True).debug(
96 | "{}: Can't understand audio", conversion_id
97 | )
98 | rtn = ""
99 |
100 | self._text = rtn.strip()
101 | logger.opt(colors=True).debug(
102 | "{}: Done with STT: {}", conversion_id, self._text
103 | )
104 |
105 | return self._text
106 |
107 | @retry()
108 | def _stt_openai(self):
109 | """Perform speech-to-text using OpenAI's API."""
110 | wav_buffer = io.BytesIO(self.audio_data.get_wav_data())
111 | wav_buffer.name = "audio.wav"
112 | with wav_buffer as audio_file_buffer:
113 | transcript = self.openai_client.audio.transcriptions.create(
114 | model="whisper-1",
115 | file=audio_file_buffer,
116 | language=self.language.split("-")[0], # put in ISO-639-1 format
117 | prompt=f"The language is {self.language}. "
118 | "Do not transcribe if you think the audio is noise.",
119 | )
120 |
121 | for db in [
122 | self.general_token_usage_db,
123 | self.token_usage_db,
124 | ]:
125 | db.insert_data(
126 | model="whisper-1",
127 | n_input_tokens=int(np.ceil(self.speech.duration_seconds)),
128 | )
129 |
130 | return transcript.text
131 |
132 | def _stt_google(self):
133 | """Perform speech-to-text using Google's API."""
134 | return self.recogniser.recognize_google(
135 | audio_data=self.audio_data, language=self.language
136 | )
137 |
138 |
139 | @dataclass
140 | class TextToSpeech(SpeechAndTextConfigs):
141 | """Class for converting text to speech."""
142 |
143 | text: str = ""
144 | openai_tts_voice: str = ""
145 | _speech: AudioSegment = field(init=False, default=None)
146 |
147 | def __post_init__(self):
148 | self.text = self.text.strip()
149 |
150 | @property
151 | def speech(self) -> AudioSegment:
152 | """Return the speech from the text."""
153 | if not self._speech:
154 | self._speech = self._tts()
155 | return self._speech
156 |
157 | def set_sample_rate(self, sample_rate: int):
158 | """Set the sample rate of the speech."""
159 | self._speech = self.speech.set_frame_rate(sample_rate)
160 |
161 | def _tts(self):
162 | logger.debug("Running {} TTS on text '{}'", self.engine, self.text)
163 | rtn = self._tts_openai() if self.engine == "openai" else self._tts_google()
164 | logger.debug("Done with TTS for '{}'", self.text)
165 |
166 | return rtn
167 |
168 | def _tts_openai(self) -> AudioSegment:
169 | """Convert text to speech using OpenAI's TTS. Return an AudioSegment object."""
170 | openai_tts_model = "tts-1"
171 |
172 | @retry()
173 | def _create_speech(*args, **kwargs):
174 | for db in [
175 | self.general_token_usage_db,
176 | self.token_usage_db,
177 | ]:
178 | db.insert_data(model=openai_tts_model, n_input_tokens=len(self.text))
179 | return self.openai_client.audio.speech.create(*args, **kwargs)
180 |
181 | response = _create_speech(
182 | input=self.text,
183 | model=openai_tts_model,
184 | voice=self.openai_tts_voice,
185 | response_format="mp3",
186 | timeout=self.timeout,
187 | )
188 |
189 | mp3_buffer = io.BytesIO()
190 | for mp3_stream_chunk in response.iter_bytes(chunk_size=4096):
191 | mp3_buffer.write(mp3_stream_chunk)
192 | mp3_buffer.seek(0)
193 |
194 | audio = AudioSegment.from_mp3(mp3_buffer)
195 | audio += 8 # Increase volume a bit
196 | return audio
197 |
198 | def _tts_google(self) -> AudioSegment:
199 |         """Convert text to speech using Google's TTS. Return an AudioSegment."""
200 | tts = gTTS(self.text, lang=self.language)
201 | mp3_buffer = io.BytesIO()
202 | tts.write_to_fp(mp3_buffer)
203 | mp3_buffer.seek(0)
204 |
205 | return AudioSegment.from_mp3(mp3_buffer)
206 |
--------------------------------------------------------------------------------
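
Both TTS backends above decode an in-memory MP3 stream into a pydub `AudioSegment`. A minimal sketch of that round trip using the Google path (gTTS needs network access, and pydub needs ffmpeg installed):

```python
import io

from gtts import gTTS
from pydub import AudioSegment

tts = gTTS("Hello from the text-to-speech sketch", lang="en")
mp3_buffer = io.BytesIO()
tts.write_to_fp(mp3_buffer)  # gTTS writes MP3 bytes into the buffer
mp3_buffer.seek(0)  # rewind before decoding

speech = AudioSegment.from_mp3(mp3_buffer)
speech += 8  # pydub reads `+ n` as a gain of n dB, as in _tts_openai above
print(f"{speech.duration_seconds:.1f} s at {speech.frame_rate} Hz")
```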
/pyrobbot/tokens.py:
--------------------------------------------------------------------------------
1 | """Management of token usage and costs for OpenAI API."""
2 |
3 | import contextlib
4 | import datetime
5 | import sqlite3
6 | from pathlib import Path
7 | from typing import Optional
8 |
9 | import pandas as pd
10 | import tiktoken
11 |
12 | # See for the latest prices.
13 | PRICE_PER_K_TOKENS_LLM = {
14 | # Continuous model upgrades (models that point to the latest versions)
15 | "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
16 | "gpt-4-turbo-preview": {"input": 0.01, "output": 0.03},
17 | "gpt-4": {"input": 0.03, "output": 0.06},
18 | "gpt-3.5-turbo-16k": {"input": 0.001, "output": 0.002}, # -> gpt-3.5-turbo-16k-0613
19 | "gpt-4-32k": {"input": 0.06, "output": 0.12},
20 | # Static model versions
21 | # GPT 3
22 | "gpt-3.5-turbo-0125": {"input": 0.0015, "output": 0.002},
23 | "gpt-3.5-turbo-1106": {"input": 0.001, "output": 0.002},
24 | "gpt-3.5-turbo-0613": {"input": 0.0015, "output": 0.002}, # Deprecated, 2024-06-13
25 | "gpt-3.5-turbo-16k-0613": {"input": 0.001, "output": 0.002}, # Deprecated, 2024-06-13
26 | # GPT 4
27 | "gpt-4-0125-preview": {"input": 0.01, "output": 0.03},
28 | "gpt-4-1106-preview": {"input": 0.01, "output": 0.03},
29 | "gpt-4-0613": {"input": 0.03, "output": 0.06},
30 | "gpt-4-32k-0613": {"input": 0.06, "output": 0.12},
31 | }
32 | PRICE_PER_K_TOKENS_EMBEDDINGS = {
33 | "text-embedding-3-small": {"input": 0.00002, "output": 0.0},
34 | "text-embedding-3-large": {"input": 0.00013, "output": 0.0},
35 | "text-embedding-ada-002": {"input": 0.0001, "output": 0.0},
36 | "text-embedding-ada-002-v2": {"input": 0.0001, "output": 0.0},
37 | "text-davinci:002": {"input": 0.0020, "output": 0.020},
38 | "full-history": {"input": 0.0, "output": 0.0},
39 | }
40 | PRICE_PER_K_TOKENS_TTS_AND_STT = {
41 | "tts-1": {"input": 0.015, "output": 0.0},
42 | "tts-1-hd": {"input": 0.03, "output": 0.0},
43 | "whisper-1": {"input": 0.006, "output": 0.0},
44 | }
45 | PRICE_PER_K_TOKENS = (
46 | PRICE_PER_K_TOKENS_LLM
47 | | PRICE_PER_K_TOKENS_EMBEDDINGS
48 | | PRICE_PER_K_TOKENS_TTS_AND_STT
49 | )
50 |
51 |
52 | class TokenUsageDatabase:
53 | """Manages a database to store estimated token usage and costs for OpenAI API."""
54 |
55 | def __init__(self, fpath: Path):
56 | """Initialize a TokenUsageDatabase instance."""
57 | self.fpath = fpath
58 | self.token_price = {}
59 | for model, price_per_k_tokens in PRICE_PER_K_TOKENS.items():
60 | self.token_price[model] = {
61 | k: v / 1000.0 for k, v in price_per_k_tokens.items()
62 | }
63 |
64 | self.create()
65 |
66 | def create(self):
67 | """Create the database if it doesn't exist."""
68 | self.fpath.parent.mkdir(parents=True, exist_ok=True)
69 | conn = sqlite3.connect(self.fpath)
70 | cursor = conn.cursor()
71 |
72 | # Create a table to store the data with 'timestamp' as the primary key
73 | cursor.execute(
74 | """
75 | CREATE TABLE IF NOT EXISTS token_costs (
76 | timestamp INTEGER NOT NULL,
77 | model TEXT NOT NULL,
78 | n_input_tokens INTEGER NOT NULL,
79 | n_output_tokens INTEGER NOT NULL,
80 | cost_input_tokens REAL NOT NULL,
81 | cost_output_tokens REAL NOT NULL
82 | )
83 | """
84 | )
85 |
86 | conn.commit()
87 | conn.close()
88 |
89 | def insert_data(
90 | self,
91 | model: str,
92 | n_input_tokens: int = 0,
93 | n_output_tokens: int = 0,
94 | timestamp: Optional[int] = None,
95 | ):
96 | """Insert the data into the token_costs table."""
97 | if model is None:
98 | return
99 |
100 | conn = sqlite3.connect(self.fpath)
101 | cursor = conn.cursor()
102 |
103 | # Insert the data into the table
104 | cursor.execute(
105 | """
106 | INSERT INTO token_costs (
107 | timestamp,
108 | model,
109 | n_input_tokens,
110 | n_output_tokens,
111 | cost_input_tokens,
112 | cost_output_tokens
113 | )
114 | VALUES (?, ?, ?, ?, ?, ?)
115 | """,
116 | (
117 |                 timestamp or int(datetime.datetime.now(datetime.timezone.utc).timestamp()),
118 | model,
119 | n_input_tokens,
120 | n_output_tokens,
121 | n_input_tokens * self.token_price[model]["input"],
122 | n_output_tokens * self.token_price[model]["output"],
123 | ),
124 | )
125 |
126 | conn.commit()
127 | conn.close()
128 |
129 | def get_usage_balance_dataframe(self):
130 | """Get a dataframe with the accumulated token usage and costs."""
131 | conn = sqlite3.connect(self.fpath)
132 | query = """
133 | SELECT
134 | model as Model,
135 | MIN(timestamp) AS "First Used",
136 | SUM(n_input_tokens) AS "Tokens: In",
137 | SUM(n_output_tokens) AS "Tokens: Out",
138 | SUM(n_input_tokens + n_output_tokens) AS "Tokens: Tot.",
139 | SUM(cost_input_tokens) AS "Cost ($): In",
140 | SUM(cost_output_tokens) AS "Cost ($): Out",
141 | SUM(cost_input_tokens + cost_output_tokens) AS "Cost ($): Tot."
142 | FROM token_costs
143 | GROUP BY model
144 | ORDER BY "Cost ($): Tot." DESC
145 | """
146 |
147 | usage_df = pd.read_sql_query(query, con=conn)
148 | conn.close()
149 |
150 | usage_df["First Used"] = pd.to_datetime(usage_df["First Used"], unit="s")
151 |
152 | usage_df = _group_columns_by_prefix(_add_totals_row(usage_df))
153 |
154 | # Add metadata to returned dataframe
155 | usage_df.attrs["description"] = "Estimated token usage and associated costs"
156 | link = "https://platform.openai.com/account/usage"
157 | disclaimers = [
158 | "Note: These are only estimates. Actual costs may vary.",
159 | f"Please visit <{link}> to follow your actual usage and costs.",
160 | ]
161 | usage_df.attrs["disclaimer"] = "\n".join(disclaimers)
162 |
163 | return usage_df
164 |
165 |
166 | def get_n_tokens_from_msgs(messages: list[dict], model: str):
167 | """Returns the number of tokens used by a list of messages."""
168 | # Adapted from
169 | #
170 | encoding = tiktoken.get_encoding("cl100k_base")
171 | with contextlib.suppress(KeyError):
172 | encoding = tiktoken.encoding_for_model(model)
173 |
174 | # OpenAI's original function was implemented for gpt-3.5-turbo-0613, but we'll use
175 | # it for all models for now. We are only interested in estimates, after all.
176 | num_tokens = 0
177 | for message in messages:
178 | # every message follows {role/name}\n{content}\n
179 | num_tokens += 4
180 | for key, value in message.items():
181 | if not isinstance(value, str):
182 | raise TypeError(
183 | f"Value for key '{key}' has type {type(value)}. Expected str: {value}"
184 | )
185 | num_tokens += len(encoding.encode(value))
186 | if key == "name": # if there's a name, the role is omitted
187 | num_tokens += -1 # role is always required and always 1 token
188 | num_tokens += 2 # every reply is primed with assistant
189 | return num_tokens
190 |
191 |
192 | def _group_columns_by_prefix(dataframe: pd.DataFrame):
193 | dataframe = dataframe.copy()
194 | col_tuples_for_multiindex = dataframe.columns.str.split(": ", expand=True).to_numpy()
195 | dataframe.columns = pd.MultiIndex.from_tuples(
196 | [("", x[0]) if pd.isna(x[1]) else x for x in col_tuples_for_multiindex]
197 | )
198 | return dataframe
199 |
200 |
201 | def _add_totals_row(accounting_df: pd.DataFrame):
202 | dtypes = accounting_df.dtypes
203 | sums_df = accounting_df.sum(numeric_only=True).rename("Total").to_frame().T
204 | return pd.concat([accounting_df, sums_df]).astype(dtypes).fillna(" ")
205 |
--------------------------------------------------------------------------------
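
To make the accounting concrete, here is a rough cost estimate for a single exchange using the same tiktoken encoding and the per-message overhead from `get_n_tokens_from_msgs`. The prices are the `gpt-3.5-turbo` entries from the table above; all figures are estimates:

```python
import tiktoken

PRICE_PER_K = {"gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015}}


def estimate_cost(model: str, n_in: int, n_out: int) -> float:
    price = PRICE_PER_K[model]
    return (n_in * price["input"] + n_out * price["output"]) / 1000.0


enc = tiktoken.get_encoding("cl100k_base")
# +4 per message plus +2 reply priming, mirroring get_n_tokens_from_msgs
n_in = len(enc.encode("What is the capital of Portugal?")) + 4 + 2
n_out = len(enc.encode("The capital of Portugal is Lisbon."))
print(f"~${estimate_cost('gpt-3.5-turbo', n_in, n_out):.6f}")
```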
/pyrobbot/voice_chat.py:
--------------------------------------------------------------------------------
1 | """Code related to the voice chat feature."""
2 |
3 | import contextlib
4 | import io
5 | import queue
6 | import threading
7 | import time
8 | from collections import defaultdict, deque
9 | from datetime import datetime
10 |
11 | import chime
12 | import numpy as np
13 | import pydub
14 | import pygame
15 | import soundfile as sf
16 | import webrtcvad
17 | from loguru import logger
18 | from pydub import AudioSegment
19 |
20 | from .chat import Chat
21 | from .chat_configs import VoiceChatConfigs
22 | from .general_utils import _get_lower_alphanumeric, str2_minus_str1
23 | from .sst_and_tts import TextToSpeech
24 |
25 | try:
26 | import sounddevice as sd
27 | except OSError as error:
28 | logger.exception(error)
29 | logger.error(
30 | "Can't use module `sounddevice`. Please check your system's PortAudio install."
31 | )
32 | _sounddevice_imported = False
33 | else:
34 | _sounddevice_imported = True
35 |
36 | try:
37 | # Test if pydub's AudioSegment can be used
38 | with contextlib.suppress(pydub.exceptions.CouldntDecodeError):
39 | AudioSegment.from_mp3(io.BytesIO())
40 | except (ImportError, OSError, FileNotFoundError) as error:
41 | logger.exception(error)
42 | logger.error("Can't use module `pydub`. Please check your system's ffmpeg install.")
43 | _pydub_usable = False
44 | else:
45 | _pydub_usable = True
46 |
47 |
48 | class VoiceChat(Chat):
49 | """Class for converting text to speech and speech to text."""
50 |
51 | default_configs = VoiceChatConfigs()
52 |
53 | def __init__(self, configs: VoiceChatConfigs = default_configs, **kwargs):
54 | """Initializes a chat instance."""
55 | super().__init__(configs=configs, **kwargs)
56 | _check_needed_imports()
57 |
58 | self.block_size = int((self.sample_rate * self.frame_duration) / 1000)
59 |
60 |         self.vad = webrtcvad.Vad(2)  # aggressiveness mode 2 (scale: 0 to 3)
61 |
62 | self.default_chime_theme = "big-sur"
63 | chime.theme(self.default_chime_theme)
64 |
65 | # Create queues and threads for handling the chat
66 | # 1. Watching for questions from the user
67 | self.questions_queue = queue.Queue()
68 | self.questions_listening_watcher_thread = threading.Thread(
69 | target=self.handle_question_listening,
70 | args=(self.questions_queue,),
71 | daemon=True,
72 | )
73 | # 2. Converting assistant's text reply to speech and playing it
74 | self.tts_conversion_queue = queue.Queue()
75 | self.play_speech_queue = queue.Queue()
76 | self.tts_conversion_watcher_thread = threading.Thread(
77 | target=self.handle_tts_conversion_queue,
78 | args=(self.tts_conversion_queue,),
79 | daemon=True,
80 | )
81 | self.play_speech_thread = threading.Thread(
82 | target=self.handle_play_speech_queue,
83 | args=(self.play_speech_queue,),
84 | daemon=True,
85 | ) # TODO: Do not start this in webchat
86 | # 3. Watching for expressions that cancel the reply or exit the chat
87 | self.check_for_interrupt_expressions_queue = queue.Queue()
88 | self.check_for_interrupt_expressions_thread = threading.Thread(
89 | target=self.check_for_interrupt_expressions_handler,
90 | args=(self.check_for_interrupt_expressions_queue,),
91 | daemon=True,
92 | )
93 | self.interrupt_reply = threading.Event()
94 | self.exit_chat = threading.Event()
95 |
96 | # Keep track of played audios to update the history db
97 | self.current_answer_audios_queue = queue.Queue()
98 | self.handle_update_audio_history_thread = threading.Thread(
99 | target=self.handle_update_audio_history,
100 | args=(self.current_answer_audios_queue,),
101 | daemon=True,
102 | )
103 |
104 | @property
105 | def mixer(self):
106 | """Return the mixer object."""
107 | mixer = getattr(self, "_mixer", None)
108 | if mixer is not None:
109 | return mixer
110 |
111 | self._mixer = pygame.mixer
112 | try:
113 | self.mixer.init(
114 | frequency=self.sample_rate, channels=1, buffer=self.block_size
115 | )
116 | except pygame.error as error:
117 | logger.exception(error)
118 | logger.error(
119 | "Can't initialize the mixer. Please check your system's audio settings."
120 | )
121 | logger.warning("Voice chat may not be available or may not work as expected.")
122 | return self._mixer
123 |
124 | def start(self):
125 | """Start the chat."""
126 | # ruff: noqa: T201
127 | self.tts_conversion_watcher_thread.start()
128 | self.play_speech_thread.start()
129 | if not self.skip_initial_greeting:
130 | tts_entry = {"exchange_id": self.id, "text": self.initial_greeting}
131 | self.tts_conversion_queue.put(tts_entry)
132 | while self._assistant_still_replying():
133 | pygame.time.wait(50)
134 | self.questions_listening_watcher_thread.start()
135 | self.check_for_interrupt_expressions_thread.start()
136 | self.handle_update_audio_history_thread.start()
137 |
138 | with contextlib.suppress(KeyboardInterrupt, EOFError):
139 | while not self.exit_chat.is_set():
140 | self.tts_conversion_queue.join()
141 | self.play_speech_queue.join()
142 | self.current_answer_audios_queue.join()
143 |
144 | if self.interrupt_reply.is_set():
145 | logger.opt(colors=True).debug(
146 | "Interrupting the reply"
147 | )
148 | with self.check_for_interrupt_expressions_queue.mutex:
149 | self.check_for_interrupt_expressions_queue.queue.clear()
150 | with contextlib.suppress(pygame.error):
151 | self.mixer.stop()
152 | with self.questions_queue.mutex:
153 | self.questions_queue.queue.clear()
154 | chime.theme("material")
155 | chime.error()
156 | chime.theme(self.default_chime_theme)
157 | time.sleep(0.25)
158 |
159 | chime.warning()
160 | self.interrupt_reply.clear()
161 | logger.debug(f"{self.assistant_name}> Waiting for user input...")
162 | question = self.questions_queue.get()
163 | self.questions_queue.task_done()
164 |
165 | if question is None:
166 | self.exit_chat.set()
167 | else:
168 | chime.success()
169 | for chunk in self.answer_question(question):
170 | if chunk.chunk_type == "code":
171 | print(chunk.content, end="", flush=True)
172 |
173 | self.exit_chat.set()
174 | chime.info()
175 | logger.debug("Leaving chat")
176 |
177 | def answer_question(self, question: str):
178 | """Answer a question."""
179 | logger.debug("{}> Getting response to '{}'...", self.assistant_name, question)
180 | sentence_for_tts = ""
181 | any_code_chunk_yet = False
182 | for answer_chunk in self.respond_user_prompt(prompt=question):
183 | if self.interrupt_reply.is_set() or self.exit_chat.is_set():
184 | logger.debug("Reply interrupted.")
185 |                 return  # PEP 479: generators must use return, not raise StopIteration
186 | yield answer_chunk
187 |
188 | if not self.reply_only_as_text:
189 | if answer_chunk.chunk_type not in ("text", "code"):
190 | raise NotImplementedError(
191 | "Unexpected chunk type: {}".format(answer_chunk.chunk_type)
192 | )
193 |
194 | if answer_chunk.chunk_type == "text":
195 | # The answer chunk is to be spoken
196 | sentence_for_tts += answer_chunk.content
197 | stripd_chunk = answer_chunk.content.strip()
198 | if stripd_chunk.endswith(("?", "!", ".")):
199 |                     # Don't split when the "." follows a digit (e.g. "3.14")
200 | if stripd_chunk.endswith("."):
201 | with contextlib.suppress(IndexError):
202 | previous_char = sentence_for_tts.strip()[-2]
203 | if previous_char.isdigit():
204 | continue
205 | # Send sentence for TTS even if the request hasn't finished
206 | tts_entry = {
207 | "exchange_id": answer_chunk.exchange_id,
208 | "text": sentence_for_tts,
209 | }
210 | self.tts_conversion_queue.put(tts_entry)
211 | sentence_for_tts = ""
212 | elif answer_chunk.chunk_type == "code" and not any_code_chunk_yet:
213 | msg = self._translate("Code will be displayed in the text output.")
214 | tts_entry = {"exchange_id": answer_chunk.exchange_id, "text": msg}
215 | self.tts_conversion_queue.put(tts_entry)
216 | any_code_chunk_yet = True
217 |
218 | if sentence_for_tts and not self.reply_only_as_text:
219 | tts_entry = {
220 | "exchange_id": answer_chunk.exchange_id,
221 | "text": sentence_for_tts,
222 | }
223 | self.tts_conversion_queue.put(tts_entry)
224 |
225 | # Signal that the current answer is finished
226 | tts_entry = {"exchange_id": answer_chunk.exchange_id, "text": None}
227 | self.tts_conversion_queue.put(tts_entry)
228 |
229 | def handle_update_audio_history(self, current_answer_audios_queue: queue.Queue):
230 | """Handle updating the chat history with the replies' audio file paths."""
231 | # Merge all AudioSegments in self.current_answer_audios_queue into a single one
232 | merged_audios = defaultdict(AudioSegment.empty)
233 | while not self.exit_chat.is_set():
234 | try:
235 | logger.debug("Waiting for reply audio chunks to concatenate and save...")
236 | audio_chunk_queue_item = current_answer_audios_queue.get()
237 | reply_audio_chunk = audio_chunk_queue_item["speech"]
238 | exchange_id = audio_chunk_queue_item["exchange_id"]
239 | logger.debug("Received audio chunk for response ID {}", exchange_id)
240 |
241 | if reply_audio_chunk is not None:
242 | # Reply not yet finished
243 | merged_audios[exchange_id] += reply_audio_chunk
244 | logger.debug(
245 | "Response ID {} audio: {}s so far",
246 | exchange_id,
247 | merged_audios[exchange_id].duration_seconds,
248 | )
249 | current_answer_audios_queue.task_done()
250 | continue
251 |
252 | # Now the reply has finished
253 | logger.debug(
254 | "Creating a single audio file for response ID {}...", exchange_id
255 | )
256 | merged_audio = merged_audios[exchange_id]
257 | # Update the chat history with the audio file path
258 | fpath = self.audio_cache_dir() / f"{datetime.now().isoformat()}.mp3"
259 | logger.debug("Updating chat history with audio file path {}", fpath)
260 | self.context_handler.database.insert_assistant_audio_file_path(
261 | exchange_id=exchange_id, file_path=fpath
262 | )
263 | # Save the combined audio as an mp3 file in the cache directory
264 | merged_audio.export(fpath, format="mp3")
265 | logger.debug("File {} stored", fpath)
266 | del merged_audios[exchange_id]
267 | current_answer_audios_queue.task_done()
268 | except Exception as error: # noqa: BLE001
269 | logger.error(error)
270 | logger.opt(exception=True).debug(error)
271 |
272 | def speak(self, tts: TextToSpeech):
273 | """Reproduce audio from a pygame Sound object."""
274 | tts.set_sample_rate(self.sample_rate)
275 | self.mixer.Sound(tts.speech.raw_data).play()
276 | audio_recorded_while_assistant_replies = self.listen(
277 | duration_seconds=tts.speech.duration_seconds
278 | )
279 |
280 | msgs_to_compare = {
281 | "assistant_txt": tts.text,
282 | "user_audio": audio_recorded_while_assistant_replies,
283 | }
284 | self.check_for_interrupt_expressions_queue.put(msgs_to_compare)
285 |
286 | while self.mixer.get_busy():
287 | pygame.time.wait(100)
288 |
289 | def check_for_interrupt_expressions_handler(
290 | self, check_for_interrupt_expressions_queue: queue.Queue
291 | ):
292 | """Check for expressions that interrupt the assistant's reply."""
293 | while not self.exit_chat.is_set():
294 | try:
295 | msgs_to_compare = check_for_interrupt_expressions_queue.get()
296 | recorded_prompt = self.stt(speech=msgs_to_compare["user_audio"]).text
297 |
298 | recorded_prompt = _get_lower_alphanumeric(recorded_prompt).strip()
299 | assistant_msg = _get_lower_alphanumeric(
300 | msgs_to_compare.get("assistant_txt", "")
301 | ).strip()
302 |
303 | user_words = str2_minus_str1(
304 | str1=assistant_msg, str2=recorded_prompt
305 | ).strip()
306 | if user_words:
307 | logger.debug(
308 | "Detected user words while assistant was replying: {}",
309 | user_words,
310 | )
311 | if any(
312 | cancel_cmd in user_words for cancel_cmd in self.cancel_expressions
313 | ):
314 | logger.debug(
315 | "Heard '{}'. Signalling for reply to be cancelled...",
316 | user_words,
317 | )
318 | self.interrupt_reply.set()
319 | except Exception as error: # noqa: PERF203, BLE001
320 | logger.opt(exception=True).debug(error)
321 | finally:
322 | check_for_interrupt_expressions_queue.task_done()
323 |
324 | def listen(self, duration_seconds: float = np.inf) -> AudioSegment:
325 | """Record audio from the microphone until user stops."""
326 | # Adapted from
327 | #
329 | debug_msg = "The assistant is listening"
330 | if duration_seconds < np.inf:
331 | debug_msg += f" for {duration_seconds} s"
332 | debug_msg += "..."
333 |
334 | inactivity_timeout_seconds = self.inactivity_timeout_seconds
335 | if duration_seconds < np.inf:
336 | inactivity_timeout_seconds = duration_seconds
337 |
338 | q = queue.Queue()
339 |
340 | def callback(indata, frames, time, status): # noqa: ARG001
341 | """This is called (from a separate thread) for each audio block."""
342 | q.put(indata.copy())
343 |
344 | raw_buffer = io.BytesIO()
345 | start_time = datetime.now()
346 | with self.get_sound_file(raw_buffer, mode="x") as sound_file, sd.InputStream(
347 | samplerate=self.sample_rate,
348 | blocksize=self.block_size,
349 | channels=1,
350 | callback=callback,
351 | dtype="int16", # int16, i.e., 2 bytes per sample
352 | ):
353 | logger.debug("{}", debug_msg)
354 | # Recording will stop after inactivity_timeout_seconds of silence
355 | voice_activity_detected = deque(
356 | maxlen=int((1000.0 * inactivity_timeout_seconds) / self.frame_duration)
357 | )
358 | last_inactivity_checked = datetime.now()
359 | continue_recording = True
360 | speech_detected = False
361 | elapsed_time = 0.0
362 | with contextlib.suppress(KeyboardInterrupt):
363 | while continue_recording and elapsed_time < duration_seconds:
364 | new_data = q.get()
365 | sound_file.write(new_data)
366 |
367 | # Gather voice activity samples for the inactivity check
368 | wav_buffer = _np_array_to_wav_in_memory(
369 | sound_data=new_data,
370 | sample_rate=self.sample_rate,
371 | subtype="PCM_16",
372 | )
373 |
374 | vad_thinks_this_chunk_is_speech = self.vad.is_speech(
375 | wav_buffer, self.sample_rate
376 | )
377 | voice_activity_detected.append(vad_thinks_this_chunk_is_speech)
378 |
379 | # Decide if user has been inactive for too long
380 | now = datetime.now()
381 | if duration_seconds < np.inf:
382 | continue_recording = True
383 | elif (
384 | now - last_inactivity_checked
385 |                     ).total_seconds() >= inactivity_timeout_seconds:
386 | speech_likelihood = 0.0
387 | if len(voice_activity_detected) > 0:
388 | speech_likelihood = sum(voice_activity_detected) / len(
389 | voice_activity_detected
390 | )
391 | continue_recording = (
392 | speech_likelihood >= self.speech_likelihood_threshold
393 | )
394 | if continue_recording:
395 | speech_detected = True
396 | last_inactivity_checked = now
397 |
398 |                     elapsed_time = (now - start_time).total_seconds()
399 |
400 | if speech_detected or duration_seconds < np.inf:
401 | return AudioSegment.from_wav(raw_buffer)
402 | return AudioSegment.empty()
403 |
404 | def handle_question_listening(self, questions_queue: queue.Queue):
405 | """Handle the queue of questions to be answered."""
406 | minimum_prompt_duration_seconds = 0.05
407 | while not self.exit_chat.is_set():
408 | if self._assistant_still_replying():
409 | pygame.time.wait(100)
410 | continue
411 | try:
412 | audio = self.listen()
413 | if audio is None:
414 | questions_queue.put(None)
415 | continue
416 |
417 | if audio.duration_seconds < minimum_prompt_duration_seconds:
418 | continue
419 |
420 | question = self.stt(speech=audio).text
421 |
422 | # Check for the exit expressions
423 | if any(
424 | _get_lower_alphanumeric(question).startswith(
425 | _get_lower_alphanumeric(expr)
426 | )
427 | for expr in self.exit_expressions
428 | ):
429 | questions_queue.put(None)
430 | elif question:
431 | questions_queue.put(question)
432 | except sd.PortAudioError as error:
433 | logger.opt(exception=True).debug(error)
434 | except Exception as error: # noqa: BLE001
435 | logger.opt(exception=True).debug(error)
436 | logger.error(error)
437 |
438 | def handle_play_speech_queue(self, play_speech_queue: queue.Queue[TextToSpeech]):
439 | """Handle the queue of audio segments to be played."""
440 | while not self.exit_chat.is_set():
441 | try:
442 | play_speech_queue_item = play_speech_queue.get()
443 | if play_speech_queue_item["speech"] and not self.interrupt_reply.is_set():
444 | self.speak(play_speech_queue_item["tts_obj"])
445 | except Exception as error: # noqa: BLE001, PERF203
446 | logger.exception(error)
447 | finally:
448 | play_speech_queue.task_done()
449 |
450 | def handle_tts_conversion_queue(self, tts_conversion_queue: queue.Queue):
451 | """Handle the text-to-speech queue."""
452 | logger.debug("Chat {}: TTS conversion handler started.", self.id)
453 | while not self.exit_chat.is_set():
454 | try:
455 | tts_entry = tts_conversion_queue.get()
456 | if tts_entry["text"] is None:
457 |                         # Signal that the current answer is finished
458 | play_speech_queue_item = {
459 | "exchange_id": tts_entry["exchange_id"],
460 | "speech": None,
461 | }
462 | self.play_speech_queue.put(play_speech_queue_item)
463 | self.current_answer_audios_queue.put(play_speech_queue_item)
464 |
465 | logger.debug(
466 |                         "Reply ID {} notified that it has finished",
467 | tts_entry["exchange_id"],
468 | )
469 | tts_conversion_queue.task_done()
470 | continue
471 |
472 | text = tts_entry["text"].strip()
473 | if text and not self.interrupt_reply.is_set():
474 | logger.debug(
475 | "Reply ID {}: received text '{}' for TTS",
476 | tts_entry["exchange_id"],
477 | text,
478 | )
479 |
480 | tts_obj = self.tts(text)
481 | # Trigger the TTS conversion
482 | _ = tts_obj.speech
483 |
484 | logger.debug(
485 | "Reply ID {}: Sending speech for '{}' to the playing queue",
486 | tts_entry["exchange_id"],
487 | text,
488 | )
489 | play_speech_queue_item = {
490 | "exchange_id": tts_entry["exchange_id"],
491 | "tts_obj": tts_obj,
492 | "speech": tts_obj.speech,
493 | }
494 | self.play_speech_queue.put(play_speech_queue_item)
495 | self.current_answer_audios_queue.put(play_speech_queue_item)
496 |
497 |                     # task_done() must run exactly once per item, whichever branch ran
498 | tts_conversion_queue.task_done()
499 |
500 | except Exception as error: # noqa: BLE001
501 | logger.opt(exception=True).debug(error)
502 | logger.error(error)
503 | logger.error("TTS conversion queue handler ended.")
504 |
505 | def get_sound_file(self, wav_buffer: io.BytesIO, mode: str = "r"):
506 | """Return a sound file object."""
507 | return sf.SoundFile(
508 | wav_buffer,
509 | mode=mode,
510 | samplerate=self.sample_rate,
511 | channels=1,
512 | format="wav",
513 | subtype="PCM_16",
514 | )
515 |
516 | def audio_cache_dir(self):
517 | """Return the audio cache directory."""
518 | directory = self.cache_dir / "audio_files"
519 | directory.mkdir(parents=True, exist_ok=True)
520 | return directory
521 |
522 | def _assistant_still_replying(self):
523 | """Check if the assistant is still talking."""
524 | return (
525 | self.mixer.get_busy()
526 | or self.questions_queue.unfinished_tasks > 0
527 | or self.tts_conversion_queue.unfinished_tasks > 0
528 | or self.play_speech_queue.unfinished_tasks > 0
529 | )
530 |
531 |
532 | def _check_needed_imports():
533 | """Check if the needed modules are available."""
534 | if not _sounddevice_imported:
535 | logger.warning(
536 | "Module `sounddevice`, needed for local audio recording, is not available."
537 | )
538 |
539 | if not _pydub_usable:
540 | logger.error(
541 | "Module `pydub`, needed for audio conversion, doesn't seem to be working. "
542 | "Voice chat may not be available or may not work as expected."
543 | )
544 |
545 |
546 | def _np_array_to_wav_in_memory(
547 | sound_data: np.ndarray, sample_rate: int, subtype="PCM_16"
548 | ):
549 | """Convert the recorded array to an in-memory wav file."""
550 | wav_buffer = io.BytesIO()
551 | wav_buffer.name = "audio.wav"
552 | sf.write(wav_buffer, sound_data, sample_rate, subtype=subtype)
553 | wav_buffer.seek(44) # Skip the WAV header
554 | return wav_buffer.read()
555 |
--------------------------------------------------------------------------------
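
The threads in `VoiceChat` coordinate through queues, with a `None` payload as the end-of-answer sentinel and `task_done()` called exactly once per item so that `queue.join()` can tell when an answer is fully processed. A standalone sketch of that pattern (names are illustrative):

```python
import queue
import threading

tts_queue: queue.Queue = queue.Queue()


def tts_worker():
    while True:
        entry = tts_queue.get()
        try:
            if entry["text"] is None:  # sentinel: the answer is finished
                print(f"answer {entry['exchange_id']} done")
            else:
                print(f"would synthesize: {entry['text']!r}")
        finally:
            tts_queue.task_done()  # once per item, whichever branch ran


threading.Thread(target=tts_worker, daemon=True).start()
for sentence in ["Hello there.", "How can I help?"]:
    tts_queue.put({"exchange_id": 1, "text": sentence})
tts_queue.put({"exchange_id": 1, "text": None})
tts_queue.join()  # returns only after every item was marked done
```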
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import io
2 |
3 | import lorem
4 | import numpy as np
5 | import pydub
6 | import pytest
7 | from _pytest.logging import LogCaptureFixture
8 | from loguru import logger
9 |
10 | import pyrobbot
11 | from pyrobbot.chat import Chat
12 | from pyrobbot.chat_configs import ChatOptions, VoiceChatConfigs
13 | from pyrobbot.voice_chat import VoiceChat
14 |
15 |
16 | @pytest.fixture()
17 | def caplog(caplog: LogCaptureFixture):
18 | """Override the default `caplog` fixture to propagate Loguru to the caplog handler."""
19 | # Source:
21 | handler_id = logger.add(
22 | caplog.handler,
23 | format="{message}",
24 | level=0,
25 | filter=lambda record: record["level"].no >= caplog.handler.level,
26 | enqueue=False, # Set to 'True' if your test is spawning child processes.
27 | )
28 | yield caplog
29 | logger.remove(handler_id)
30 |
31 |
32 | # Register markers and constants
33 | def pytest_configure(config):
34 | config.addinivalue_line(
35 | "markers",
36 | "no_chat_completion_create_mocking: do not mock openai.ChatCompletion.create",
37 | )
38 | config.addinivalue_line(
39 | "markers",
40 | "no_embedding_create_mocking: mark test to not mock openai.Embedding.create",
41 | )
42 |
43 | pytest.original_package_cache_directory = (
44 | pyrobbot.GeneralDefinitions.PACKAGE_CACHE_DIRECTORY
45 | )
46 |
47 |
48 | @pytest.fixture(autouse=True)
49 | def _set_env(monkeypatch):
50 | # Make sure we don't consume our tokens in tests
51 | monkeypatch.setenv("OPENAI_API_KEY", "INVALID_API_KEY")
52 | monkeypatch.setenv("STREAMLIT_SERVER_HEADLESS", "true")
53 |
54 |
55 | @pytest.fixture(autouse=True)
56 | def _mocked_general_constants(tmp_path, mocker):
57 | mocker.patch(
58 | "pyrobbot.GeneralDefinitions.PACKAGE_CACHE_DIRECTORY", tmp_path / "cache"
59 | )
60 |
61 |
62 | @pytest.fixture()
63 | def mock_wav_bytes_string():
64 | """Mock a WAV file as a bytes string."""
65 | return (
66 | b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x00\x04\x00"
67 | b"\x00\x00\x04\x00\x00\x01\x00\x08\x00data\x00\x00\x00\x00"
68 | )
69 |
70 |
71 | @pytest.fixture(autouse=True)
72 | def _openai_api_request_mockers(request, mocker):
73 | """Mockers for OpenAI API requests. We don't want to consume our tokens in tests."""
74 |
75 | def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001
76 | """Mock `openai.ChatCompletion.create`. Yield from lorem ipsum instead."""
77 | completion_chunk = type("CompletionChunk", (), {})
78 | completion_chunk_choice = type("CompletionChunkChoice", (), {})
79 | completion_chunk_choice_delta = type("CompletionChunkChoiceDelta", (), {})
80 | for word in lorem.get_paragraph().split():
81 | completion_chunk_choice_delta.content = word + " "
82 | completion_chunk_choice.delta = completion_chunk_choice_delta
83 | completion_chunk.choices = [completion_chunk_choice]
84 | yield completion_chunk
85 |
86 | # Yield some code as well, to test the code filtering
87 | code_path = pyrobbot.GeneralDefinitions.PACKAGE_DIRECTORY / "__init__.py"
88 | for word in [
89 | "```python\n",
90 | *code_path.read_text().splitlines(keepends=True)[:5],
91 | "```\n",
92 | ]:
93 | completion_chunk_choice_delta.content = word + " "
94 | completion_chunk_choice.delta = completion_chunk_choice_delta
95 | completion_chunk.choices = [completion_chunk_choice]
96 | yield completion_chunk
97 |
98 | def _mock_openai_embedding_create(*args, **kwargs): # noqa: ARG001
99 | """Mock `openai.Embedding.create`. Yield from lorem ipsum instead."""
100 | embedding_request_mock_type = type("EmbeddingRequest", (), {})
101 | embedding_mock_type = type("Embedding", (), {})
102 | usage_mock_type = type("Usage", (), {})
103 |
104 | embedding = embedding_mock_type()
105 | embedding.embedding = np.random.rand(512).tolist()
106 | embedding_request = embedding_request_mock_type()
107 | embedding_request.data = [embedding]
108 |
109 | usage = usage_mock_type()
110 | usage.prompt_tokens = 0
111 | usage.total_tokens = 0
112 | embedding_request.usage = usage
113 |
114 | return embedding_request
115 |
116 | if "no_chat_completion_create_mocking" not in request.keywords:
117 | mocker.patch(
118 | "openai.resources.chat.completions.Completions.create",
119 | new=_mock_openai_chat_completion_create,
120 | )
121 | if "no_embedding_create_mocking" not in request.keywords:
122 | mocker.patch(
123 | "openai.resources.embeddings.Embeddings.create",
124 | new=_mock_openai_embedding_create,
125 | )
126 |
127 |
128 | @pytest.fixture(autouse=True)
129 | def _internet_search_mockers(mocker):
130 | """Mockers for the internet search module."""
131 | mocker.patch("duckduckgo_search.DDGS.text", return_value=lorem.get_paragraph())
132 |
133 |
134 | @pytest.fixture()
135 | def _input_builtin_mocker(mocker, user_input):
136 | """Mock the `input` builtin. Raise `KeyboardInterrupt` after the second call."""
137 |
138 |     # Allow two calls so that the chat context handler has a chance to kick in
139 | def _mock_input(*args, **kwargs): # noqa: ARG001
140 | try:
141 | _mock_input.execution_counter += 1
142 | except AttributeError:
143 | _mock_input.execution_counter = 0
144 | if _mock_input.execution_counter > 1:
145 | raise KeyboardInterrupt
146 | return user_input
147 |
148 | mocker.patch( # noqa: PT008
149 | "builtins.input", new=lambda _: _mock_input(user_input=user_input)
150 | )
151 |
152 |
153 | @pytest.fixture(params=ChatOptions.get_allowed_values("model")[:2])
154 | def llm_model(request):
155 | return request.param
156 |
157 |
158 | context_model_values = ChatOptions.get_allowed_values("context_model")
159 |
160 |
161 | @pytest.fixture(params=[context_model_values[0], context_model_values[2]])
162 | def context_model(request):
163 | return request.param
164 |
165 |
166 | @pytest.fixture()
167 | def default_chat_configs(llm_model, context_model):
168 | return ChatOptions(model=llm_model, context_model=context_model)
169 |
170 |
171 | @pytest.fixture()
172 | def default_voice_chat_configs(llm_model, context_model):
173 | return VoiceChatConfigs(model=llm_model, context_model=context_model)
174 |
175 |
176 | @pytest.fixture()
177 | def cli_args_overrides(default_chat_configs):
178 | args = []
179 | for field, value in default_chat_configs.model_dump().items():
180 | if value not in [None, True, False]:
181 | args = [*args, *[f"--{field.replace('_', '-')}", str(value)]]
182 | return args
183 |
184 |
185 | @pytest.fixture()
186 | def default_chat(default_chat_configs):
187 | return Chat(configs=default_chat_configs)
188 |
189 |
190 | @pytest.fixture()
191 | def default_voice_chat(default_voice_chat_configs):
192 | chat = VoiceChat(configs=default_voice_chat_configs)
193 | chat.inactivity_timeout_seconds = 1e-5
194 | chat.tts_engine = "google"
195 | return chat
196 |
197 |
198 | @pytest.fixture(autouse=True)
199 | def _voice_chat_mockers(mocker, mock_wav_bytes_string):
200 | """Mockers for the text-to-speech module."""
201 | mocker.patch(
202 | "pyrobbot.voice_chat.VoiceChat._assistant_still_replying", return_value=False
203 | )
204 |
205 | mock_google_tts_obj = type("mock_gTTS", (), {})
206 | mock_openai_tts_response = type("mock_openai_tts_response", (), {})
207 |
208 | def _mock_iter_bytes(*args, **kwargs): # noqa: ARG001
209 | return [mock_wav_bytes_string]
210 |
211 | mock_openai_tts_response.iter_bytes = _mock_iter_bytes
212 |
213 | mocker.patch(
214 | "pydub.AudioSegment.from_mp3",
215 | return_value=pydub.AudioSegment.from_wav(io.BytesIO(mock_wav_bytes_string)),
216 | )
217 | mocker.patch("gtts.gTTS", return_value=mock_google_tts_obj)
218 | mocker.patch(
219 | "openai.resources.audio.speech.Speech.create",
220 | return_value=mock_openai_tts_response,
221 | )
222 | mock_transcription = type("MockTranscription", (), {})
223 | mock_transcription.text = "patched"
224 | mocker.patch(
225 | "openai.resources.audio.transcriptions.Transcriptions.create",
226 | return_value=mock_transcription,
227 | )
228 | mocker.patch(
229 | "speech_recognition.Recognizer.recognize_google",
230 | return_value=mock_transcription.text,
231 | )
232 |
233 | mocker.patch("webrtcvad.Vad.is_speech", return_value=False)
234 | mocker.patch("pygame.mixer.init")
235 | mocker.patch("chime.play_wav")
237 |
--------------------------------------------------------------------------------
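
The API mocks above fake OpenAI response objects with bare classes built via `type(name, bases, namespace)`, relying on plain attribute lookup instead of full mock specs. A tiny standalone illustration of that trick:

```python
# Build a duck-typed stand-in for a streamed completion chunk
completion_chunk = type("CompletionChunk", (), {})
delta = type("Delta", (), {})
delta.content = "hello "
choice = type("Choice", (), {})
choice.delta = delta
completion_chunk.choices = [choice]

# Consumers that only read attributes can't tell the difference
print(completion_chunk.choices[0].delta.content)  # "hello "
```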
/tests/smoke/test_app.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import streamlit
4 | import streamlit_webrtc.component
5 |
6 | from pyrobbot.app import app
7 |
8 |
9 | def test_app(mocker, default_voice_chat_configs):
10 | class MockAttrDict(streamlit.runtime.state.session_state_proxy.SessionStateProxy):
11 | def __getattr__(self, attr):
12 | return self.get(attr, mocker.MagicMock())
13 |
14 | def __getitem__(self, key):
15 | with contextlib.suppress(KeyError):
16 | return super().__getitem__(key)
17 | return mocker.MagicMock()
18 |
19 | mocker.patch.object(streamlit, "session_state", new=MockAttrDict())
20 | mocker.patch.object(
21 | streamlit.runtime.state.session_state_proxy,
22 | "SessionStateProxy",
23 | new=MockAttrDict,
24 | )
25 | mocker.patch("streamlit.chat_input", return_value="foobar")
26 | mocker.patch(
27 | "pyrobbot.chat_configs.VoiceChatConfigs.from_file",
28 | return_value=default_voice_chat_configs,
29 | )
30 | mocker.patch.object(
31 | streamlit_webrtc.component,
32 | "webrtc_streamer",
33 | mocker.MagicMock(return_value=mocker.MagicMock()),
34 | )
35 |
36 | mocker.patch("streamlit.number_input", return_value=0)
37 |
38 | mocker.patch(
39 | "pyrobbot.chat_configs.VoiceChatConfigs.model_validate",
40 | return_value=default_voice_chat_configs,
41 | )
42 |
43 | app.run_app()
44 |
--------------------------------------------------------------------------------
/tests/smoke/test_commands.py:
--------------------------------------------------------------------------------
1 | import io
2 | import subprocess
3 |
4 | import pytest
5 | from pydub import AudioSegment
6 |
7 | from pyrobbot.__main__ import main
8 | from pyrobbot.argparse_wrapper import get_parsed_args
9 |
10 |
11 | def test_default_command():
12 | args = get_parsed_args(argv=[])
13 | assert args.command == "ui"
14 |
15 |
16 | @pytest.mark.usefixtures("_input_builtin_mocker")
17 | @pytest.mark.parametrize("user_input", ["Hi!", ""], ids=["regular-input", "empty-input"])
18 | def test_terminal_command(cli_args_overrides):
19 | args = ["terminal", "--report-accounting-when-done", *cli_args_overrides]
20 | args = list(dict.fromkeys(args))
21 | main(args)
22 |
23 |
24 | def test_accounting_command():
25 | main(["accounting"])
26 |
27 |
28 | def test_ui_command(mocker, caplog):
29 | original_run = subprocess.run
30 |
31 | def new_run(*args, **kwargs):
32 | kwargs.pop("timeout", None)
33 | try:
34 | original_run(*args, **kwargs, timeout=0.5)
35 | except subprocess.TimeoutExpired as error:
36 | raise KeyboardInterrupt from error
37 |
38 | mocker.patch("subprocess.run", new=new_run)
39 | main(["ui"])
40 | assert "Exiting." in caplog.text
41 |
42 |
43 | @pytest.mark.parametrize("stt", ["google", "openai"])
44 | @pytest.mark.parametrize("tts", ["google", "openai"])
45 | def test_voice_chat(mocker, mock_wav_bytes_string, tts, stt):
46 |     # Alternate between returning audio and None so the listening function is
47 |     # exercised first and the chat then terminates
48 | def _mock_listen(*args, **kwargs): # noqa: ARG001
49 | try:
50 | _mock_listen.execution_counter += 1
51 | except AttributeError:
52 | _mock_listen.execution_counter = 0
53 | if _mock_listen.execution_counter % 2:
54 | return None
55 | return AudioSegment.from_wav(io.BytesIO(mock_wav_bytes_string))
56 |
57 | mocker.patch("pyrobbot.voice_chat.VoiceChat.listen", _mock_listen)
58 | main(["voice", "--tts", tts, "--stt", stt])
59 |
--------------------------------------------------------------------------------
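
`_mock_listen` above (like `_mock_input` in conftest.py) keeps state between calls by hanging a counter attribute on the function object itself, so no class or global is needed. A standalone sketch:

```python
def flaky_listen():
    try:
        flaky_listen.calls += 1
    except AttributeError:
        flaky_listen.calls = 0  # first call initializes the counter
    # Even calls return "audio", odd calls return None (end of chat)
    return None if flaky_listen.calls % 2 else "audio"


print([flaky_listen() for _ in range(4)])  # ['audio', None, 'audio', None]
```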
/tests/unit/test_chat.py:
--------------------------------------------------------------------------------
1 | import openai
2 | import pytest
3 |
4 | from pyrobbot import GeneralDefinitions
5 | from pyrobbot.chat import Chat
6 | from pyrobbot.chat_configs import ChatOptions
7 |
8 |
9 | @pytest.mark.order(1)
10 | @pytest.mark.usefixtures("_input_builtin_mocker")
11 | @pytest.mark.no_chat_completion_create_mocking()
12 | @pytest.mark.parametrize("user_input", ["regular-input"])
13 | def testbed_doesnt_actually_connect_to_openai(caplog):
14 | llm = ChatOptions.get_allowed_values("model")[0]
15 | context_model = ChatOptions.get_allowed_values("context_model")[0]
16 | chat_configs = ChatOptions(model=llm, context_model=context_model)
17 | chat = Chat(configs=chat_configs)
18 |
19 | chat.start()
20 | success = chat.response_failure_message().content in caplog.text
21 |
22 | err_msg = "Refuse to continue: Testbed is trying to connect to OpenAI API!"
23 | err_msg += f"\nThis is what the logger says:\n{caplog.text}"
24 | if not success:
25 | pytest.exit(err_msg)
26 |
27 |
28 | @pytest.mark.order(2)
29 | def test_we_are_using_tmp_cachedir():
30 | try:
31 | assert (
32 | pytest.original_package_cache_directory
33 | != GeneralDefinitions.PACKAGE_CACHE_DIRECTORY
34 | )
35 |
36 | except AssertionError:
37 | pytest.exit(
38 | "Refuse to continue: Tests attempted to use the package's real cache dir "
39 | + f"({GeneralDefinitions.PACKAGE_CACHE_DIRECTORY})!"
40 | )
41 |
42 |
43 | @pytest.mark.usefixtures("_input_builtin_mocker")
44 | @pytest.mark.parametrize("user_input", ["Hi!", ""], ids=["regular-input", "empty-input"])
45 | def test_terminal_chat(default_chat):
46 | default_chat.start()
47 | default_chat.__del__() # Just to trigger testing the custom del method
48 |
49 |
50 | def test_chat_configs(default_chat, default_chat_configs):
51 | assert default_chat._passed_configs == default_chat_configs
52 |
53 |
54 | @pytest.mark.no_chat_completion_create_mocking()
55 | @pytest.mark.usefixtures("_input_builtin_mocker")
56 | @pytest.mark.parametrize("user_input", ["regular-input"])
57 | def test_request_timeout_retry(mocker, default_chat, caplog):
58 | def _mock_openai_chat_completion_create(*args, **kwargs): # noqa: ARG001
59 | raise openai.APITimeoutError("Mocked timeout error was not caught!")
60 |
61 | mocker.patch(
62 | "openai.resources.chat.completions.Completions.create",
63 | new=_mock_openai_chat_completion_create,
64 | )
65 | mocker.patch("time.sleep") # Don't waste time sleeping in tests
66 | default_chat.start()
67 | assert "APITimeoutError" in caplog.text
68 |
69 |
70 | def test_can_read_chat_from_cache(default_chat):
71 | default_chat.save_cache()
72 | new_chat = Chat.from_cache(default_chat.cache_dir)
73 | assert new_chat.configs == default_chat.configs
74 |
75 |
76 | def test_create_from_cache_returns_default_chat_if_invalid_cachedir(default_chat, caplog):
77 | _ = Chat.from_cache(default_chat.cache_dir / "foobar")
78 | assert "Creating Chat with default configs" in caplog.text
79 |
80 |
81 | @pytest.mark.usefixtures("_input_builtin_mocker")
82 | @pytest.mark.parametrize("user_input", ["regular-input"])
83 | def test_internet_search_can_be_triggered(default_chat, mocker):
84 | mocker.patch(
85 | "pyrobbot.openai_utils.make_api_chat_completion_call", return_value=iter(["yes"])
86 | )
87 | mocker.patch("pyrobbot.chat.Chat.respond_system_prompt", return_value=iter(["yes"]))
88 | mocker.patch(
89 | "pyrobbot.internet_utils.raw_websearch",
90 | return_value=iter(
91 | [
92 | {
93 | "href": "foo/bar",
94 | "summary": 50 * "foo ",
95 | "detailed": 50 * "foo ",
96 | "relevance": 1.0,
97 | }
98 | ]
99 | ),
100 | )
101 | default_chat.start()
102 |
--------------------------------------------------------------------------------
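The first two tests in `test_chat.py` are guard tests: they are forced to run first via the `order` marker (provided by the `pytest-order` plugin) and call `pytest.exit` rather than a plain `assert`, so a broken mocking setup aborts the whole session instead of letting every later test hit the real OpenAI API or the real cache directory. A minimal sketch of the pattern, with a hypothetical `environment_is_sandboxed()` standing in for the suite's real invariants:

import pytest


def environment_is_sandboxed():
    """Hypothetical check; stand-in for whatever invariant the suite relies on."""
    return True


@pytest.mark.order(1)  # ordering comes from the pytest-order plugin
def test_guard_environment():
    if not environment_is_sandboxed():
        # Unlike a failing assert, pytest.exit stops the entire session, so
        # no later test can run against the unsafe environment.
        pytest.exit("Refuse to continue: test environment is not sandboxed!")
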
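`test_request_timeout_retry` combines two patches: the completion call is forced to always raise `APITimeoutError`, and `time.sleep` is stubbed out so the retry/backoff loop finishes instantly. The sleep-stubbing trick is broadly useful; a self-contained sketch with a hypothetical retry helper, runnable under pytest with the pytest-mock plugin:

import time


def call_with_retries(func, max_attempts=3):
    """Hypothetical retry loop with a fixed backoff between attempts."""
    for attempt in range(1, max_attempts + 1):
        try:
            return func()
        except TimeoutError:
            if attempt == max_attempts:
                raise
            time.sleep(2**attempt)  # would be slow without the patch below


def test_retries_do_not_sleep(mocker):
    mocker.patch("time.sleep")  # make the backoff instantaneous
    attempts = []

    def flaky():
        attempts.append(1)
        if len(attempts) < 3:
            raise TimeoutError
        return "ok"

    assert call_with_retries(flaky) == "ok"
    assert len(attempts) == 3
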
/tests/unit/test_internet_utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import duckduckgo_search
4 |
5 | from pyrobbot.internet_utils import websearch
6 |
7 | # Do the search once at import time; it misbehaves if called inside tests or fixtures. Leave it like this for now.
8 | search_results = []
9 | with contextlib.suppress(duckduckgo_search.exceptions.DuckDuckGoSearchException):
10 | search_results = list(websearch("foobar"))
11 |
12 |
13 | def test_websearch():
14 | for i_result, result in enumerate(search_results):
15 | assert isinstance(result, dict)
16 | assert ("detailed" in result) == (i_result == 0)
17 | for key in ["summary", "relevance", "href"]:
18 | assert key in result
19 |
--------------------------------------------------------------------------------
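Because the module-level search above suppresses `DuckDuckGoSearchException`, `search_results` may be empty whenever the network (or DuckDuckGo) is unavailable, and `test_websearch` then passes vacuously over zero iterations. A common alternative, sketched here with a hypothetical `fetch_results()` helper, moves the call into a fixture and skips explicitly, so offline runs show up as skips rather than silent passes:

import pytest


def fetch_results():
    """Hypothetical stand-in for a network call that may raise on failure."""
    return [{"href": "https://example.com", "summary": "...", "relevance": 1.0}]


@pytest.fixture(scope="module")
def search_results():
    try:
        return fetch_results()
    except Exception as error:  # e.g. DuckDuckGoSearchException
        pytest.skip(f"Web search unavailable: {error}")


def test_results_have_expected_keys(search_results):
    for result in search_results:
        for key in ("summary", "relevance", "href"):
            assert key in result
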
/tests/unit/test_text_to_speech.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from pydub import AudioSegment
3 |
4 | from pyrobbot.sst_and_tts import SpeechToText
5 |
6 |
7 | @pytest.mark.parametrize("stt_engine", ["google", "openai"])
8 | def test_stt(default_voice_chat, stt_engine):
9 |     """Test speech-to-text conversion with each supported engine."""
10 | default_voice_chat.stt_engine = stt_engine
11 | stt = SpeechToText(
12 | openai_client=default_voice_chat.openai_client,
13 | speech=AudioSegment.silent(duration=100),
14 | engine=stt_engine,
15 | general_token_usage_db=default_voice_chat.general_token_usage_db,
16 | token_usage_db=default_voice_chat.token_usage_db,
17 | )
18 | assert stt.text == "patched"
19 |
--------------------------------------------------------------------------------
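`test_stt` asserts `stt.text == "patched"`, which only holds because a fixture elsewhere in the suite (the shared `conftest.py` is not part of this section) patches speech recognition away. One way such a fixture could be written, as a sketch only: it assumes `SpeechToText.text` is a property, and the real conftest may instead patch each backend at a lower level so the engine dispatch is still exercised:

import pytest


@pytest.fixture(autouse=True)
def _mock_stt(mocker):
    # Short-circuit the `text` property so neither the google nor the openai
    # backend is ever contacted. Requires the pytest-mock plugin.
    mocker.patch(
        "pyrobbot.sst_and_tts.SpeechToText.text",
        new_callable=mocker.PropertyMock,
        return_value="patched",
    )
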
/tests/unit/test_voice_chat.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 |
3 | import pytest
4 | from pydantic import ValidationError
5 | from pydub import AudioSegment
6 | from sounddevice import PortAudioError
7 |
8 | from pyrobbot.chat_configs import VoiceChatConfigs
9 | from pyrobbot.sst_and_tts import TextToSpeech
10 | from pyrobbot.voice_chat import VoiceChat
11 |
12 |
13 | def test_soundcard_import_check(mocker, caplog):
14 |     """Test the warning issued when the `sounddevice` module is not available."""
15 | mocker.patch("pyrobbot.voice_chat._sounddevice_imported", False)
16 | _ = VoiceChat(configs=VoiceChatConfigs())
17 | msg = "Module `sounddevice`, needed for local audio recording, is not available."
18 | assert msg in caplog.text
19 |
20 |
21 | @pytest.mark.parametrize("param_name", ["sample_rate", "frame_duration"])
22 | def test_cannot_instantiate_assistant_with_invalid_webrtcvad_params(param_name):
23 | """Test that the voice chat cannot be instantiated with invalid webrtcvad params."""
24 | with pytest.raises(ValidationError, match="Input should be"):
25 | VoiceChat(configs=VoiceChatConfigs(**{param_name: 1}))
26 |
27 |
28 | def test_listen(default_voice_chat):
29 | """Test the listen method."""
30 | with contextlib.suppress(PortAudioError, pytest.PytestUnraisableExceptionWarning):
31 | default_voice_chat.listen()
32 |
33 |
34 | def test_speak(default_voice_chat, mocker):
35 | tts = TextToSpeech(
36 | openai_client=default_voice_chat.openai_client,
37 | text="foo",
38 | general_token_usage_db=default_voice_chat.general_token_usage_db,
39 | token_usage_db=default_voice_chat.token_usage_db,
40 | )
41 | mocker.patch("pygame.mixer.Sound")
42 | mocker.patch("pyrobbot.voice_chat._get_lower_alphanumeric", return_value="ok cancel")
43 | mocker.patch(
44 | "pyrobbot.voice_chat.VoiceChat.listen",
45 | return_value=AudioSegment.silent(duration=150),
46 | )
47 | default_voice_chat.speak(tts)
48 |
49 |
50 | def test_answer_question(default_voice_chat):
51 | default_voice_chat.answer_question("foo")
52 |
53 |
54 | def test_interrupt_reply(default_voice_chat):
55 | default_voice_chat.interrupt_reply.set()
56 | default_voice_chat.questions_queue.get = lambda: None
57 | default_voice_chat.questions_queue.task_done = lambda: None
58 | default_voice_chat.start()
59 |
60 |
61 | def test_handle_interrupt_expressions(default_voice_chat, mocker):
62 | mocker.patch("pyrobbot.general_utils.str2_minus_str1", return_value="cancel")
63 | default_voice_chat.questions_queue.get = lambda: None
64 | default_voice_chat.questions_queue.task_done = lambda: None
65 | default_voice_chat.questions_queue.answer_question = lambda _question: None
66 | msgs_to_compare = {
67 | "assistant_txt": "foo",
68 | "user_audio": AudioSegment.silent(duration=150),
69 | }
70 | default_voice_chat.check_for_interrupt_expressions_queue.put(msgs_to_compare)
71 | default_voice_chat.start()
72 |
--------------------------------------------------------------------------------
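`test_soundcard_import_check` flips the module-level flag `pyrobbot.voice_chat._sounddevice_imported` to simulate a missing optional dependency. A minimal sketch of how such an import guard is typically laid out; only the flag name and the logged message come from the tests above, and the `Recorder` class is hypothetical:

import logging

logger = logging.getLogger(__name__)

# Record at import time whether the optional dependency could be loaded.
try:
    import sounddevice  # noqa: F401

    _sounddevice_imported = True
except ImportError:
    _sounddevice_imported = False


class Recorder:
    """Hypothetical consumer that warns instead of crashing when audio is off."""

    def __init__(self):
        if not _sounddevice_imported:
            logger.error(
                "Module `sounddevice`, needed for local audio recording, "
                "is not available."
            )

Keeping the try/except at module scope means the flag is evaluated exactly once, and tests can patch the flag itself instead of manipulating `sys.modules`.
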
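`test_interrupt_reply` and `test_handle_interrupt_expressions` stub queue methods by assigning lambdas directly on the instance, which works but leaves the object permanently altered. `mocker.patch.object` achieves the same effect and restores the original at teardown; a self-contained comparison with a hypothetical `Worker`, runnable under pytest with the pytest-mock plugin:

import queue


class Worker:
    """Hypothetical stand-in for an object that drains a queue on start()."""

    def __init__(self):
        self.questions_queue = queue.Queue()

    def start(self):
        item = self.questions_queue.get()
        self.questions_queue.task_done()
        return item


def test_start_with_stubbed_queue(mocker):
    worker = Worker()
    # Equivalent to `worker.questions_queue.get = lambda: None`, but undone
    # automatically when the test finishes:
    mocker.patch.object(worker.questions_queue, "get", return_value=None)
    mocker.patch.object(worker.questions_queue, "task_done")
    assert worker.start() is None
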