├── .editorconfig ├── .gitattributes ├── .github ├── _typos.toml ├── copilot-instructions.md └── workflows │ ├── codeql-analysis.yml │ └── python.yml ├── .gitignore ├── .vscode ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── assets ├── barplot_test4.png ├── distributions_test4.svg ├── scatterplot_basic1.png └── ts_lineplot5.svg ├── notebooks ├── base │ └── base.ipynb ├── llm │ ├── 01_openai_chat_completion.ipynb │ ├── 02_ollama_intro.ipynb │ ├── 03_llm_streaming.ipynb │ ├── 10_gpt-4-v.ipynb │ └── 20_embeddings.ipynb ├── statistics │ └── statistics.ipynb └── viz │ └── viz.ipynb ├── noxfile.py ├── pyproject.toml ├── src └── not_again_ai │ ├── __init__.py │ ├── base │ ├── __init__.py │ ├── file_system.py │ └── parallel.py │ ├── data │ ├── __init__.py │ ├── brave_search_api.py │ └── web.py │ ├── llm │ ├── __init__.py │ ├── chat_completion │ │ ├── __init__.py │ │ ├── interface.py │ │ ├── providers │ │ │ ├── __init__.py │ │ │ ├── anthropic_api.py │ │ │ ├── gemini_api.py │ │ │ ├── ollama_api.py │ │ │ └── openai_api.py │ │ └── types.py │ ├── embedding │ │ ├── __init__.py │ │ ├── interface.py │ │ ├── providers │ │ │ ├── __init__.py │ │ │ ├── ollama_api.py │ │ │ └── openai_api.py │ │ └── types.py │ ├── image_gen │ │ ├── __init__.py │ │ ├── interface.py │ │ ├── providers │ │ │ ├── __init__.py │ │ │ └── openai_api.py │ │ └── types.py │ └── prompting │ │ ├── __init__.py │ │ ├── compile_prompt.py │ │ ├── interface.py │ │ ├── providers │ │ ├── __init__.py │ │ └── openai_tiktoken.py │ │ └── types.py │ ├── py.typed │ ├── statistics │ ├── __init__.py │ └── dependence.py │ └── viz │ ├── __init__.py │ ├── barplots.py │ ├── distributions.py │ ├── scatterplot.py │ ├── time_series.py │ └── utils.py ├── tests ├── __init__.py ├── base │ ├── __init__.py │ ├── test_file_system.py │ └── test_parallel.py ├── data │ ├── __init__.py │ ├── test_brave_search_api.py │ └── test_web.py ├── llm │ ├── __init__.py │ ├── chat_completion │ │ ├── __init__.py │ │ ├── test_chat_completion.py │ │ └── test_chat_completion_stream.py │ ├── embedding │ │ ├── __init__.py │ │ └── test_embedding.py │ ├── image_gen │ │ └── test_image_gen.py │ ├── prompting │ │ ├── __init__.py │ │ ├── test_compile_messages.py │ │ └── test_tokenizer.py │ └── sample_images │ │ ├── SKDiagram.png │ │ ├── SKInfographic.png │ │ ├── body_lotion.png │ │ ├── cat.jpg │ │ ├── dog.jpg │ │ ├── numbers.png │ │ ├── soap.png │ │ ├── sunlit_lounge.png │ │ └── sunlit_lounge_mask.png ├── statistics │ ├── __init__.py │ └── test_dependence.py └── viz │ ├── __init__.py │ ├── test_barplot.py │ ├── test_distributions.py │ ├── test_scatterplot.py │ └── test_time_series.py └── uv.lock /.editorconfig: -------------------------------------------------------------------------------- 1 | # For information about this file, see: https://editorconfig.org/ 2 | root = true 3 | 4 | # For ease of fitting multiple editor panes side by side consistently, set all file types to use 5 | # the same relaxed max line length permitted in PEP 8. 
6 | [*] 7 | max_line_length = 120 8 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto -------------------------------------------------------------------------------- /.github/_typos.toml: -------------------------------------------------------------------------------- 1 | [default.extend-identifiers] 2 | arange = "arange" # np.arange 3 | 4 | [files] 5 | extend-exclude = ["*.ipynb"] -------------------------------------------------------------------------------- /.github/copilot-instructions.md: -------------------------------------------------------------------------------- 1 | # Python Rules 2 | - The user is using Python version >= 3.11 with uv as the Python package and project manager. 3 | - Follow the Google Python Style Guide. 4 | - Instead of importing `Optional` from typing, use the `|` syntax. 5 | - Always add appropriate type hints such that the code would pass a mypy type check. 6 | - For type hints, use `list`, not `List`. For example, if the variable is `[{"name": "Jane", "age": 32}, {"name": "Amy", "age": 28}]` the type hint should be `list[dict]`. 7 | - If the user is using Pydantic, it is version >=2.10. 8 | - Always prefer pathlib for dealing with files. Use `Path.open` instead of `open`. 9 | - Prefer to use pendulum instead of datetime. 10 | - Prefer to use loguru instead of logging. 11 | - Prefer httpx for HTTP requests instead of requests. 12 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [ "main" ] 20 | schedule: 21 | - cron: '30 5 * * *' 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: [ 'python' ] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 
51 | 52 | # For details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 53 | # queries: security-extended,security-and-quality 54 | 55 | 56 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 57 | # If this step fails, then you should remove it and run the build manually (see below) 58 | - name: Autobuild 59 | uses: github/codeql-action/autobuild@v2 60 | 61 | # ℹ️ Command-line programs to run using the OS shell. 62 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 63 | 64 | # If the Autobuild fails above, remove it and uncomment the following three lines, 65 | # modifying them (or adding more) to build your code. If your project needs a custom build, please refer to the EXAMPLE below for guidance. 66 | 67 | # - run: | 68 | # echo "Run, Build Application using script" 69 | # ./location_of_script_within_repo/buildscript.sh 70 | 71 | - name: Perform CodeQL Analysis 72 | uses: github/codeql-action/analyze@v2 73 | -------------------------------------------------------------------------------- /.github/workflows/python.yml: -------------------------------------------------------------------------------- 1 | name: python 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: ["main"] 7 | push: 8 | branches: ["main"] 9 | 10 | env: 11 | UV_VERSION: "0.6.9" 12 | PYTHON_VERSION: "3.12" 13 | 14 | jobs: 15 | test: 16 | runs-on: ubuntu-24.04 17 | strategy: 18 | matrix: 19 | python-version: [ "3.11", "3.12" ] 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Install uv 23 | uses: astral-sh/setup-uv@v5 24 | # Caching is enabled by default for GitHub-hosted runners: 25 | # https://github.com/astral-sh/setup-uv?tab=readme-ov-file#enable-caching 26 | with: 27 | version: ${{ env.UV_VERSION }} 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v5 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Install dependencies 33 | run: uv sync --locked --all-extras 34 | - name: Test with Nox 35 | env: 36 | OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} 37 | OPENAI_ORG_ID: ${{ secrets.OPENAI_ORG_ID }} 38 | SKIP_TESTS_NAAI: "tests/llm/chat_completion tests/llm/embedding tests/llm/image_gen tests/data" 39 | run: uv run nox -s test-${{ matrix.python-version }} 40 | quality: 41 | runs-on: ubuntu-24.04 42 | strategy: 43 | matrix: 44 | nox-session: ["lint", "type_check", "typos"] 45 | steps: 46 | - uses: actions/checkout@v4 47 | - name: Install uv 48 | uses: astral-sh/setup-uv@v5 49 | with: 50 | version: ${{ env.UV_VERSION }} 51 | - name: Set up Python 52 | uses: actions/setup-python@v5 53 | with: 54 | python-version: ${{ env.PYTHON_VERSION }} 55 | - name: Install dependencies 56 | run: uv sync --locked --all-extras 57 | - name: Run Nox session 58 | run: uv run nox -s ${{ matrix.nox-session }} 59 | lock-check: 60 | runs-on: ubuntu-24.04 61 | steps: 62 | - uses: actions/checkout@v4 63 | - name: Install uv 64 | uses: astral-sh/setup-uv@v5 65 | with: 66 | version: ${{ env.UV_VERSION }} 67 | - name: Set up Python 68 | uses: actions/setup-python@v5 69 | with: 70 | python-version: ${{ env.PYTHON_VERSION }} 71 | - name: Validate Lockfile Up-to-date 72 | run: uv lock --check -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # 
Project Specific 2 | notebooks/**/*.png 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | pip-wheel-metadata/ 27 | share/python-wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | *.py,cover 54 | .hypothesis/ 55 | .pytest_cache/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # Ruff 135 | .ruff_cache/ -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python Debugger: Current File", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | } 15 | ] 16 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "ruff.configuration": "./pyproject.toml", 3 | "notebook.formatOnSave.enabled": true, 4 | "notebook.codeActionsOnSave": { 5 | "notebook.source.fixAll": "explicit", 6 | "notebook.source.organizeImports": "explicit" 7 | }, 8 | "[python]": { 9 | "editor.formatOnSave": true, 10 | "editor.codeActionsOnSave": { 11 | "source.fixAll": "explicit", 12 | "source.organizeImports": "explicit" 13 | }, 14 | "editor.defaultFormatter": "charliermarsh.ruff" 15 | }, 16 | "python.testing.unittestEnabled": false, 17 | "python.testing.pytestEnabled": true, 18 | "python.testing.pytestArgs": ["-s"], 19 | "markdown.extension.orderedList.marker": "one", 20 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022-2024 DaveCoDev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # not-again-ai 2 | 3 | [![GitHub Actions][github-actions-badge]](https://github.com/johnthagen/python-blueprint/actions) 4 | [![uv][uv-badge]](https://github.com/astral-sh/uv) 5 | [![Nox][nox-badge]](https://github.com/wntrblm/nox) 6 | [![Ruff][ruff-badge]](https://github.com/astral-sh/ruff) 7 | [![Type checked with mypy][mypy-badge]](https://mypy-lang.org/) 8 | 9 | [github-actions-badge]: https://github.com/johnthagen/python-blueprint/workflows/python/badge.svg 10 | [uv-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json 11 | [nox-badge]: https://img.shields.io/badge/%F0%9F%A6%8A-Nox-D85E00.svg 12 | [black-badge]: https://img.shields.io/badge/code%20style-black-000000.svg 13 | [ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json 14 | [mypy-badge]: https://www.mypy-lang.org/static/mypy_badge.svg 15 | 16 | **not-again-ai** is a collection of various building blocks that come up over and over again when developing AI products. 17 | The key goals of this package are to have simple, yet flexible interfaces and to minimize dependencies. 18 | You are encouraged to **a)** use this as a template for your own Python package, or 19 | **b)** copy and paste functions into your own projects instead of installing the package. 20 | We make this easier by limiting the number of dependencies and using an MIT license. 21 | 22 | **Documentation** is available in the individual **[notebooks](notebooks)** and in docstrings within the source code. 23 | 24 | # Installation 25 | 26 | Requires Python 3.11 or 3.12, which can be installed with [uv](https://docs.astral.sh/uv/getting-started/installation/) by running the command `uv python install 3.12`. 27 | 28 | Install the entire package from [PyPI](https://pypi.org/project/not-again-ai/) with: 29 | 30 | ```bash 31 | $ pip install not_again_ai[data,llm,statistics,viz] 32 | ``` 33 | 34 | The package is split into subpackages, so you can install only the parts you need. 35 | 36 | ### Base 37 | 1. `pip install not_again_ai` 38 | 39 | 40 | ### Data 41 | 1. `pip install not_again_ai[data]` 42 | 1. Run `crawl4ai-setup` to complete the crawl4ai post-installation setup. 43 | 1. Set the `BRAVE_SEARCH_API_KEY` environment variable to use the Brave Search API for web data extraction. 44 | 1. Get the API key from https://api-dashboard.search.brave.com/app/keys. You must have at least the Free "Data for Search" subscription. 45 | 46 | 47 | ### LLM 48 | 1. `pip install not_again_ai[llm]` 49 | 1. Set up the OpenAI API 50 | 1. Go to https://platform.openai.com/settings/profile?tab=api-keys to get your API key. 51 | 1. (Optional) Set the `OPENAI_API_KEY` and the `OPENAI_ORG_ID` environment variables. 52 | 1. Set up Azure OpenAI (AOAI) 53 | 1. Using AOAI requires using Entra ID authentication. See https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/managed-identity for how to set this up for your AOAI deployment. 54 | * Requires the correct role assigned to your user account and being signed into the Azure CLI. 55 | 1. (Optional) Set the `AZURE_OPENAI_ENDPOINT` environment variable. 56 | 1. If you wish to use Ollama: 57 | 1. Follow the instructions at https://github.com/ollama/ollama to install Ollama for your system. 58 | 1. 
(Optional) [Add Ollama as a startup service (recommended)](https://github.com/ollama/ollama/blob/main/docs/linux.md#adding-ollama-as-a-startup-service-recommended) 59 | 1. (Optional) To make the Ollama service accessible on your local network from a Linux server, add the following to the `/etc/systemd/system/ollama.service` file, which will make Ollama available at `http://<host>:11434`: 60 | ```bash 61 | [Service] 62 | ... 63 | Environment="OLLAMA_HOST=0.0.0.0" 64 | ``` 65 | 1. It is recommended to always have the latest version of Ollama. To update Ollama, check the [docs](https://github.com/ollama/ollama/blob/main/docs/). The command for Linux is: `curl -fsSL https://ollama.com/install.sh | sh` 66 | 67 | 68 | ### Statistics 69 | 1. `pip install not_again_ai[statistics]` 70 | 71 | 72 | ### Visualization 73 | 1. `pip install not_again_ai[viz]` 74 | 75 | 76 | # Development Information 77 | 78 | This package uses [uv](https://docs.astral.sh/uv/) to manage dependencies and 79 | isolated [Python virtual environments](https://docs.python.org/3/library/venv.html). 80 | 81 | To proceed, 82 | [install uv globally](https://docs.astral.sh/uv/getting-started/installation/) 83 | onto your system. 84 | 85 | To install a specific version of Python: 86 | 87 | ```shell 88 | uv python install 3.12 89 | ``` 90 | 91 | ## Dependencies 92 | 93 | Dependencies are defined in [`pyproject.toml`](./pyproject.toml) and specific versions are locked 94 | into [`uv.lock`](./uv.lock). This allows for exactly reproducible environments across 95 | all machines that use the project, both during development and in production. 96 | 97 | To install all dependencies into an isolated virtual environment: 98 | 99 | ```shell 100 | uv sync --all-extras --all-groups 101 | ``` 102 | 103 | To upgrade all dependencies to their latest versions: 104 | 105 | ```shell 106 | uv lock --upgrade 107 | ``` 108 | 109 | ## Packaging 110 | 111 | This project is designed as a Python package, meaning that it can be bundled up and redistributed 112 | as a single compressed file. 113 | 114 | Packaging is configured by the [`pyproject.toml`](./pyproject.toml). 115 | 116 | To package the project as both a 117 | [source distribution](https://packaging.python.org/en/latest/flow/#the-source-distribution-sdist) and 118 | a [wheel](https://packaging.python.org/en/latest/specifications/binary-distribution-format/): 119 | 120 | ```bash 121 | $ uv build 122 | ``` 123 | 124 | This will generate `dist/not-again-ai-<version>.tar.gz` and `dist/not_again_ai-<version>-py3-none-any.whl`. 125 | 126 | 127 | ## Publish Distributions to PyPI 128 | 129 | Source and wheel redistributable packages can 130 | be [published to PyPI](https://docs.astral.sh/uv/guides/package/) or installed 131 | directly from the filesystem using `pip`. 132 | 133 | ```shell 134 | uv publish 135 | ``` 136 | 137 | # Enforcing Code Quality 138 | 139 | Automated code quality checks are performed using [Nox](https://nox.thea.codes/en/stable/). Nox 140 | will automatically create virtual environments and run commands based on 141 | [`noxfile.py`](./noxfile.py) for unit testing, PEP 8 style guide checking, code formatting, and 142 | type checking. 143 | 144 | To run all default sessions: 145 | 146 | ```shell 147 | uv run nox 148 | ``` 149 | 150 | ## Unit Testing 151 | 152 | Unit testing is performed with [pytest](https://pytest.org/). pytest has become the de facto Python 153 | unit testing framework. 
Some key advantages over the built-in 154 | [unittest](https://docs.python.org/3/library/unittest.html) module are: 155 | 156 | 1. Significantly less boilerplate needed for tests. 157 | 2. PEP 8 compliant names (e.g. `pytest.raises()` instead of `self.assertRaises()`). 158 | 3. Vibrant ecosystem of plugins. 159 | 160 | pytest automatically discovers and runs tests by recursively searching folders for `.py` 161 | files prefixed with `test` and running any functions within them prefixed by `test`. 162 | 163 | The `tests` folder is created as a Python package (i.e. there is an `__init__.py` file within it) 164 | because this helps `pytest` uniquely namespace the test files. Without this, two test files cannot 165 | be named the same, even if they are in different subdirectories. 166 | 167 | Code coverage is provided by the [pytest-cov](https://pytest-cov.readthedocs.io/en/latest/) plugin. 168 | 169 | When running a unit test Nox session (e.g. `nox -s test`), an HTML report is generated in 170 | the `htmlcov` folder showing each source file and which lines were executed during unit testing. 171 | Open `htmlcov/index.html` in a web browser to view the report. Code coverage reports help identify 172 | areas of the project that are currently not tested. 173 | 174 | pytest and code coverage are configured in [`pyproject.toml`](./pyproject.toml). 175 | 176 | To run selected tests: 177 | 178 | ```bash 179 | (.venv) $ uv run nox -s test -- -k "test_web" 180 | ``` 181 | 182 | ## Code Style Checking 183 | 184 | [PEP 8](https://peps.python.org/pep-0008/) is the universally accepted style guide for Python 185 | code. PEP 8 code compliance is verified using [Ruff][Ruff]. Ruff is configured in the 186 | `[tool.ruff]` section of [`pyproject.toml`](./pyproject.toml). 187 | 188 | [Ruff]: https://github.com/astral-sh/ruff 189 | 190 | To lint code, run: 191 | 192 | ```bash 193 | (.venv) $ uv run nox -s lint 194 | ``` 195 | 196 | To automatically fix fixable lint errors, run: 197 | 198 | ```bash 199 | (.venv) $ uv run nox -s lint_fix 200 | ``` 201 | 202 | ## Automated Code Formatting 203 | 204 | [Ruff][Ruff] is used to automatically format code and group and sort imports. 205 | 206 | To automatically format code, run: 207 | 208 | ```bash 209 | (.venv) $ uv run nox -s fmt 210 | ``` 211 | 212 | To verify code has been formatted, such as in a CI job: 213 | 214 | ```bash 215 | (.venv) $ uv run nox -s fmt_check 216 | ``` 217 | 218 | ## Type Checking 219 | 220 | [Type annotations](https://docs.python.org/3/library/typing.html) allow developers to include 221 | optional static typing information in Python source code. This allows static analyzers such 222 | as [mypy](http://mypy-lang.org/), [PyCharm](https://www.jetbrains.com/pycharm/), 223 | or [Pyright](https://github.com/microsoft/pyright) to check that functions are used with the 224 | correct types before runtime. 225 | 226 | 227 | ```python 228 | def factorial(n: int) -> int: 229 | ... 230 | ``` 231 | 232 | mypy is configured in [`pyproject.toml`](./pyproject.toml). To type check code, run: 233 | 234 | ```bash 235 | (.venv) $ uv run nox -s type_check 236 | ``` 237 | 238 | ### Distributing Type Annotations 239 | 240 | [PEP 561](https://www.python.org/dev/peps/pep-0561/) defines how a Python package should 241 | communicate the presence of inline type annotations to static type 242 | checkers. [mypy's documentation](https://mypy.readthedocs.io/en/stable/installed_packages.html) 243 | provides further examples on how to do this. 
244 | 245 | Mypy looks for the existence of a file named [`py.typed`](./src/not_again_ai/py.typed) in the root of the 246 | installed package to indicate that inline type annotations should be checked. 247 | 248 | ## Typos 249 | 250 | Check for typos using [typos](https://github.com/crate-ci/typos): 251 | 252 | ```bash 253 | (.venv) $ uv run nox -s typos 254 | ``` 255 | 256 | ## Continuous Integration 257 | 258 | Continuous integration is provided by [GitHub Actions](https://github.com/features/actions). This 259 | runs all tests, linting, and type checking for every commit and pull request to the repository. 260 | 261 | GitHub Actions is configured in [`.github/workflows/python.yml`](./.github/workflows/python.yml). 262 | 263 | ## [Visual Studio Code](https://code.visualstudio.com/docs/languages/python) 264 | 265 | Install the [Python extension](https://marketplace.visualstudio.com/items?itemName=ms-python.python) for VSCode. 266 | 267 | Install the [Ruff extension](https://marketplace.visualstudio.com/items?itemName=charliermarsh.ruff) for VSCode. 268 | 269 | Default settings are configured in [`.vscode/settings.json`](./.vscode/settings.json), which will enable Ruff with consistent settings. 270 | 271 | # Attributions 272 | [python-blueprint](https://github.com/johnthagen/python-blueprint) for the Python package skeleton. 273 | 274 | This project uses Crawl4AI (https://github.com/unclecode/crawl4ai) for web data extraction. 275 | -------------------------------------------------------------------------------- /assets/barplot_test4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/assets/barplot_test4.png -------------------------------------------------------------------------------- /assets/scatterplot_basic1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/assets/scatterplot_basic1.png -------------------------------------------------------------------------------- /notebooks/base/base.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Base not-again-ai contains various helper functions\n", 8 | "\n", 9 | "These range from file system operations to parallelization.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## File System\n" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 1, 22 | "metadata": {}, 23 | "outputs": [], 24 | "source": [ 25 | "import shutil\n", 26 | "\n", 27 | "from not_again_ai.base.file_system import create_file_dir\n", 28 | "\n", 29 | "# Create a directory and its parent directories for a specified Path.\n", 30 | "create_file_dir(\"test/test.txt\")\n", 31 | "\n", 32 | "# Cleanup the directory.\n", 33 | "shutil.rmtree(\"test\")" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "data": { 43 | "text/plain": [ 44 | "'1.07 KB'" 45 | ] 46 | }, 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "output_type": "execute_result" 50 | } 51 | ], 52 | "source": [ 53 | "from not_again_ai.base.file_system import readable_size\n", 54 | "\n", 55 | "# Convert a size in bytes to a human-readable format.\n", 56 | "readable_size(1099) # 1.07 KB"
57 | ] 58 | }, 59 | { 60 | "cell_type": "markdown", 61 | "metadata": {}, 62 | "source": [ 63 | "## Parallelization\n" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 3, 69 | "metadata": {}, 70 | "outputs": [ 71 | { 72 | "data": { 73 | "text/plain": [ 74 | "[8, 2]" 75 | ] 76 | }, 77 | "execution_count": 3, 78 | "metadata": {}, 79 | "output_type": "execute_result" 80 | } 81 | ], 82 | "source": [ 83 | "from not_again_ai.base.parallel import embarrassingly_parallel_simple\n", 84 | "\n", 85 | "\n", 86 | "# embarrassingly_parallel_simple allows you to execute a list of functions (that take no arguments) in parallel and returns the results in the order the functions were provided.\n", 87 | "def do_something() -> int:\n", 88 | " return 8\n", 89 | "\n", 90 | "\n", 91 | "def do_something2() -> int:\n", 92 | " return 2\n", 93 | "\n", 94 | "\n", 95 | "result = embarrassingly_parallel_simple([do_something, do_something2], num_processes=2)\n", 96 | "result # [8, 2]" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "metadata": {}, 103 | "outputs": [ 104 | { 105 | "data": { 106 | "text/plain": [ 107 | "[4, 9, 16]" 108 | ] 109 | }, 110 | "execution_count": 4, 111 | "metadata": {}, 112 | "output_type": "execute_result" 113 | } 114 | ], 115 | "source": [ 116 | "import multiprocessing\n", 117 | "import random\n", 118 | "import time\n", 119 | "\n", 120 | "from not_again_ai.base.parallel import embarrassingly_parallel\n", 121 | "\n", 122 | "\n", 123 | "# Simulate a function that takes some time by sleeping before multiplying its two arguments.\n", 124 | "def mult(x: float, y: float) -> float:\n", 125 | " time.sleep(random.uniform(0, 1))\n", 126 | " return x * y\n", 127 | "\n", 128 | "\n", 129 | "args = ((2, 2), (3, 3), (4, 4))\n", 130 | "result = embarrassingly_parallel(mult, args, None, num_processes=multiprocessing.cpu_count())\n", 131 | "result" 132 | ] 133 | } 134 | ], 135 | "metadata": { 136 | "kernelspec": { 137 | "display_name": ".venv", 138 | "language": "python", 139 | "name": "python3" 140 | }, 141 | "language_info": { 142 | "codemirror_mode": { 143 | "name": "ipython", 144 | "version": 3 145 | }, 146 | "file_extension": ".py", 147 | "mimetype": "text/x-python", 148 | "name": "python", 149 | "nbconvert_exporter": "python", 150 | "pygments_lexer": "ipython3", 151 | "version": "3.12.3" 152 | } 153 | }, 154 | "nbformat": 4, 155 | "nbformat_minor": 2 156 | } 157 | -------------------------------------------------------------------------------- /notebooks/llm/01_openai_chat_completion.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using OpenAI Chat Completions\n", 8 | "\n", 9 | "This notebook covers how to use the Chat Completions API and other features such as creating prompts and function calling.\n" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "## Instantiating the OpenAI Client\n", 17 | "\n", 18 | "The OpenAI client object is used to get responses from the API. 
This will automatically read your API key and organization ID from your environment variables.\n", 19 | "\n", 20 | "You can optionally pass in your API key and organization ID as the arguments `api_key` and `organization`.\n" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 1, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "from not_again_ai.llm.chat_completion.providers.openai_api import openai_client\n", 30 | "\n", 31 | "client = openai_client()" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## Basic Chat Completion\n", 39 | "\n", 40 | "The `chat_completion` function is an easy way to get responses from OpenAI models.\n", 41 | "It requires the prompt to the model to be formatted in the chat completion format;\n", 42 | "see the [API reference](https://platform.openai.com/docs/api-reference/chat/create) for more details.\n" 43 | ] 44 | }, 45 | { 46 | "cell_type": "code", 47 | "execution_count": 2, 48 | "metadata": {}, 49 | "outputs": [ 50 | { 51 | "data": { 52 | "text/plain": [ 53 | "'Hello! How can I assist you today?'" 54 | ] 55 | }, 56 | "execution_count": 2, 57 | "metadata": {}, 58 | "output_type": "execute_result" 59 | } 60 | ], 61 | "source": [ 62 | "from not_again_ai.llm.chat_completion import chat_completion\n", 63 | "from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, SystemMessage, UserMessage\n", 64 | "\n", 65 | "messages = [\n", 66 | " SystemMessage(content=\"You are a helpful assistant.\"),\n", 67 | " UserMessage(content=\"Hello!\"),\n", 68 | "]\n", 69 | "request = ChatCompletionRequest(\n", 70 | " messages=messages,\n", 71 | " model=\"gpt-4o-mini-2024-07-18\",\n", 72 | " max_completion_tokens=100,\n", 73 | ")\n", 74 | "response = chat_completion(request, \"openai\", client)\n", 75 | "\n", 76 | "response.choices[0].message.content" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "## Creating Prompts\n", 84 | "\n", 85 | "Injecting variables into prompts is a common task, and we provide the `compile_messages` function, which uses [Liquid templating](https://jg-rp.github.io/liquid/).\n", 86 | "\n", 87 | "In the `messages` argument, the \"content\" field can be a [Python Liquid](https://jg-rp.github.io/liquid/introduction/getting-started) template string, allowing for more dynamic prompts that support not only variable injection, but also conditional logic, loops, and comments.\n" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 3, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "text/plain": [ 98 | "[SystemMessage(content='- You are a helpful assistant trying to extract places that occur in a given text.\\n- You must identify all the places in the text and return them in a list like this: [\"place1\", \"place2\", \"place3\"].', role=, name=None),\n", 99 | " UserMessage(content='Here is the text I want you to extract places from:\\nI went to Paris and Berlin.', role=, name=None)]" 100 | ] 101 | }, 102 | "execution_count": 3, 103 | "metadata": {}, 104 | "output_type": "execute_result" 105 | } 106 | ], 107 | "source": [ 108 | "from not_again_ai.llm.prompting.compile_prompt import compile_messages\n", 109 | "\n", 110 | "place_extraction_prompt = [\n", 111 | " SystemMessage(\n", 112 | " content=\"\"\"- You are a helpful assistant trying to extract places that occur in a given text.\n", 113 | "- You must identify all the places in the text and return them in a list like this: [\"place1\", \"place2\", \"place3\"].\"\"\"\n", 114
| " ),\n", 115 | " UserMessage(\n", 116 | " content=\"\"\"Here is the text I want you to extract places from:\n", 117 | "{%- # The user's input text goes below %}\n", 118 | "{{text}}\"\"\",\n", 119 | " ),\n", 120 | "]\n", 121 | "\n", 122 | "variables = {\n", 123 | " \"text\": \"I went to Paris and Berlin.\",\n", 124 | "}\n", 125 | "\n", 126 | "messages = compile_messages(messages=place_extraction_prompt, variables=variables)\n", 127 | "messages" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "## Token Management\n", 135 | "\n", 136 | "While the OpenAI chat completion will return the tokens used, the `num_tokens_from_messages` helper can be used to compute the number of tokens used in a list of messages before calling the API.\n", 137 | "\n", 138 | "We explicitly require a tokenizer since loading it has some overhead, so we want to avoid doing so many times for certain use cases.\n", 139 | "\n", 140 | "NOTE: This function not support counting tokens used by function calling.\n" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 4, 146 | "metadata": {}, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "78\n" 153 | ] 154 | } 155 | ], 156 | "source": [ 157 | "from not_again_ai.llm.prompting.providers.openai_tiktoken import TokenizerOpenAI\n", 158 | "\n", 159 | "tokenizer = TokenizerOpenAI(model=\"gpt-4o-mini-2024-07-18\")\n", 160 | "num_tokens = tokenizer.num_tokens_in_messages(messages=messages)\n", 161 | "print(num_tokens)" 162 | ] 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "metadata": {}, 167 | "source": [ 168 | "## Chat Completion with Function Calling and other Parameters\n", 169 | "\n", 170 | "The `chat_completion` function can also be used to call functions in the prompt and a myriad of other commonly used parameters like temperature, max_tokens, and logprobs. 
See the docstring for more details.\n", 171 | "\n", 172 | "See the [10_gpt-4-v.ipynb](10_gpt-4-v.ipynb) notebook for full details on how to use the vision features of `chat_completion` and `compile_messages`.\n" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 5, 178 | "metadata": {}, 179 | "outputs": [ 180 | { 181 | "data": { 182 | "text/plain": [ 183 | "ChatCompletionResponse(choices=[ChatCompletionChoice(message=AssistantMessage(content='', role=, name=None, refusal=None, tool_calls=[ToolCall(id='call_dwyBECUXUbPyuJH6oGxE3DFz', function=Function(name='get_current_weather', arguments={'location': 'Boston, MA', 'format': 'fahrenheit'}), type='function')]), finish_reason='tool_calls', json_message=None, logprobs=None, extras={}), ChatCompletionChoice(message=AssistantMessage(content='', role=, name=None, refusal=None, tool_calls=[ToolCall(id='call_yB03eV0flHXXKI6STtHMvPpm', function=Function(name='get_current_weather', arguments={'location': 'Boston, MA', 'format': 'fahrenheit'}), type='function')]), finish_reason='tool_calls', json_message=None, logprobs=None, extras={})], errors='', completion_tokens=46, prompt_tokens=99, completion_detailed_tokens=None, prompt_detailed_tokens=None, response_duration=0.8277, system_fingerprint='fp_e4fa3702df', extras={'prompt_filter_results': None})" 184 | ] 185 | }, 186 | "execution_count": 5, 187 | "metadata": {}, 188 | "output_type": "execute_result" 189 | } 190 | ], 191 | "source": [ 192 | "# Define a tool to get the current weather\n", 193 | "tools = [\n", 194 | " {\n", 195 | " \"type\": \"function\",\n", 196 | " \"function\": {\n", 197 | " \"name\": \"get_current_weather\",\n", 198 | " \"description\": \"Get the current weather\",\n", 199 | " \"parameters\": {\n", 200 | " \"type\": \"object\",\n", 201 | " \"properties\": {\n", 202 | " \"location\": {\n", 203 | " \"type\": \"string\",\n", 204 | " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", 205 | " },\n", 206 | " \"format\": {\n", 207 | " \"type\": \"string\",\n", 208 | " \"enum\": [\"celsius\", \"fahrenheit\"],\n", 209 | " \"description\": \"The temperature unit to use. Infer this from the users location.\",\n", 210 | " },\n", 211 | " },\n", 212 | " \"required\": [\"location\", \"format\"],\n", 213 | " },\n", 214 | " },\n", 215 | " },\n", 216 | "]\n", 217 | "# Ask the model to call the function\n", 218 | "messages = [\n", 219 | " UserMessage(\n", 220 | " content=\"What's the current weather like in {{ city_state }} today? Call the get_current_weather function.\",\n", 221 | " )\n", 222 | "]\n", 223 | "\n", 224 | "messages = compile_messages(messages=messages, variables={\"city_state\": \"Boston, MA\"})\n", 225 | "\n", 226 | "client = openai_client()\n", 227 | "\n", 228 | "request = ChatCompletionRequest(\n", 229 | " messages=messages,\n", 230 | " model=\"gpt-4o-mini-2024-07-18\",\n", 231 | " client=client,\n", 232 | " tools=tools,\n", 233 | " tool_choice=\"required\", # Force the model to use the tool\n", 234 | " max_completion_tokens=300,\n", 235 | " temperature=0,\n", 236 | " log_probs=True,\n", 237 | " top_log_probs=2, # returns the log probabilities of the top 2 tokens\n", 238 | " seed=42, # Set the seed for reproducibility. 
The API will also return a `system_fingerprint` field to monitor changes in the backend.\n", 239 | " n=2, # Generate 2 completions at once\n", 240 | ")\n", 241 | "response = chat_completion(request, \"openai\", client)\n", 242 | "response" 243 | ] 244 | } 245 | ], 246 | "metadata": { 247 | "kernelspec": { 248 | "display_name": ".venv", 249 | "language": "python", 250 | "name": "python3" 251 | }, 252 | "language_info": { 253 | "codemirror_mode": { 254 | "name": "ipython", 255 | "version": 3 256 | }, 257 | "file_extension": ".py", 258 | "mimetype": "text/x-python", 259 | "name": "python", 260 | "nbconvert_exporter": "python", 261 | "pygments_lexer": "ipython3", 262 | "version": "3.12.3" 263 | } 264 | }, 265 | "nbformat": 4, 266 | "nbformat_minor": 2 267 | } 268 | -------------------------------------------------------------------------------- /notebooks/llm/02_ollama_intro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using Ollama\n", 8 | "\n", 9 | "[Ollama](https://github.com/ollama/ollama) is a simple way to get started with running language models locally.\n", 10 | "\n", 11 | "We provide helpers to interface with Ollama by wrapping the [ollama-python](https://github.com/ollama/ollama-python) package.\n", 12 | "\n", 13 | "## Installation\n", 14 | "\n", 15 | "See the main README for installation instructions.\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Instantiating the Ollama client\n", 23 | "\n", 24 | "We use the `Client` class from Ollama to allow customizing the host. By default, the `ollama_client` function will try to read in the `OLLAMA_HOST` environment variable. If it is not set, you must provide a host. Generally, the default is `http://localhost:11434`.\n" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 1, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_client\n", 34 | "\n", 35 | "client = ollama_client()" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## Basic Chat Completion\n", 43 | "\n", 44 | "The same `chat_completion` function used for OpenAI, etc. can be used to call models hosted on Ollama.\n", 45 | "\n", 46 | "We assume that the model `phi4` has already been pulled into Ollama. If not, you can do so with the command `ollama pull phi4` in your terminal.\n" 47 | ] 48 | }, 49 | { 50 | "cell_type": "code", 51 | "execution_count": 2, 52 | "metadata": {}, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "ChatCompletionResponse(choices=[ChatCompletionChoice(message=AssistantMessage(content=\"Hi there! How can I assist you today? Whether it's answering questions, providing information, or helping with a specific task, feel free to let me know what you need! 
😊\", role=, name=None, refusal=None, tool_calls=None), finish_reason='stop', json_message=None, logprobs=None, extras=None)], errors='', completion_tokens=39, prompt_tokens=24, completion_detailed_tokens=None, prompt_detailed_tokens=None, response_duration=4.5843, system_fingerprint=None, extras=None)" 58 | ] 59 | }, 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "from not_again_ai.llm.chat_completion import chat_completion\n", 67 | "from not_again_ai.llm.chat_completion.types import ChatCompletionRequest, SystemMessage, UserMessage\n", 68 | "\n", 69 | "messages = [\n", 70 | " SystemMessage(content=\"You are a helpful assistant.\"),\n", 71 | " UserMessage(content=\"Hello!\"),\n", 72 | "]\n", 73 | "\n", 74 | "request = ChatCompletionRequest(\n", 75 | " messages=messages,\n", 76 | " model=\"phi4\",\n", 77 | " context_window=4000, # Set context_window because Ollama's default is small.\n", 78 | ")\n", 79 | "\n", 80 | "response = chat_completion(request, provider=\"ollama\", client=client)\n", 81 | "response" 82 | ] 83 | }, 84 | { 85 | "cell_type": "markdown", 86 | "metadata": {}, 87 | "source": [ 88 | "## Chat Completion with Other Features\n", 89 | "\n", 90 | "The Ollama API also supports several other features, such as JSON mode, temperature, and max_tokens. The `ChatCompletionRequest` class has fields for all of these, including ones specific to Ollama such as `top_k`.\n" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 3, 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "data": { 100 | "text/plain": [ 101 | "ChatCompletionResponse(choices=[ChatCompletionChoice(message=AssistantMessage(content='{\\n \"random_number\": 47\\n} \\n\\n', role=, name=None, refusal=None, tool_calls=None), finish_reason='stop', json_message={'random_number': 47}, logprobs=None, extras=None)], errors='', completion_tokens=12, prompt_tokens=40, completion_detailed_tokens=None, prompt_detailed_tokens=None, response_duration=4.258, system_fingerprint=None, extras=None)" 102 | ] 103 | }, 104 | "execution_count": 3, 105 | "metadata": {}, 106 | "output_type": "execute_result" 107 | } 108 | ], 109 | "source": [ 110 | "messages = [\n", 111 | " SystemMessage(content=\"You are a helpful assistant.\"),\n", 112 | " UserMessage(content=\"Generate a random number between 0 and 100 and structure the response using JSON.\"),\n", 113 | "]\n", 114 | "\n", 115 | "request = ChatCompletionRequest(\n", 116 | " messages=messages,\n", 117 | " model=\"phi4\",\n", 118 | " max_completion_tokens=300,\n", 119 | " context_window=1000,\n", 120 | " temperature=1.51,\n", 121 | " json_mode=True,\n", 122 | " top_k=5,\n", 123 | " seed=6,\n", 124 | ")\n", 125 | "\n", 126 | "response = chat_completion(request, provider=\"ollama\", client=client)\n", 127 | "response" 128 | ] 129 | } 130 | ], 131 | "metadata": { 132 | "kernelspec": { 133 | "display_name": ".venv", 134 | "language": "python", 135 | "name": "python3" 136 | }, 137 | "language_info": { 138 | "codemirror_mode": { 139 | "name": "ipython", 140 | "version": 3 141 | }, 142 | "file_extension": ".py", 143 | "mimetype": "text/x-python", 144 | "name": "python", 145 | "nbconvert_exporter": "python", 146 | "pygments_lexer": "ipython3", 147 | "version": "3.12.3" 148 | } 149 | }, 150 | "nbformat": 4, 151 | "nbformat_minor": 2 152 | } 153 | -------------------------------------------------------------------------------- /notebooks/llm/10_gpt-4-v.ipynb: 
-------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using GPT-4V\n" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "This notebook demonstrates how to use GPT-4V's image capabilities directly through the OpenAI API.\n", 15 | "We provide helper functions to simplify the creation of prompts and understanding which parameters are available, while maintaining the complete flexibility that the API offers.\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Creating Prompts\n", 23 | "\n", 24 | "Prompts for vision-enabled models follow the same familiar [chat completion](https://platform.openai.com/docs/guides/text-generation/chat-completions-api) format as requests to non-vision models.\n", 25 | "\n", 26 | "However, including images in the prompt requires a slightly different format. Images are available to the models in two ways: by passing a URL to an image or by passing the base64 encoded image directly in the request.\n", 27 | "Note that images can be passed in the `user`, `system`, and `assistant` messages; however, currently they cannot be in the _first_ message [[source]](https://platform.openai.com/docs/guides/vision).\n", 28 | "\n", 29 | "We can have messages containing text as before, but when we want to include images with a message, `content` becomes a list. That list can contain both text and image messages, in any order. We use the `encode_image` function to convert the image to base64 encoding. The optional `detail` parameter in the `image_url` message specifies the quality of the image. It can be either `low` or `high`. For more details on how images are processed and the associated costs, refer to the [OpenAI API documentation](https://platform.openai.com/docs/guides/vision/low-or-high-fidelity-image-understanding). 
Other providers may not have this functionality.\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 1, 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "data": { 39 | "text/plain": [ 40 | "[SystemMessage(content='You are a helpful assistant.', role=, name=None),\n", 41 | " UserMessage(content=[TextContent(type=, text='Based on these infographics, can you summarize how Semantic Kernel works in exactly one sentence?'), ImageContent(type=, image_url=ImageUrl(url='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAADKsA...', detail=)), ImageContent(type=, image_url=ImageUrl(url='data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAADWAA...', detail=))], role=, name=None)]" 42 | ] 43 | }, 44 | "execution_count": 1, 45 | "metadata": {}, 46 | "output_type": "execute_result" 47 | } 48 | ], 49 | "source": [ 50 | "from pathlib import Path\n", 51 | "\n", 52 | "from not_again_ai.llm.chat_completion.types import (\n", 53 | " ImageContent,\n", 54 | " ImageDetail,\n", 55 | " ImageUrl,\n", 56 | " SystemMessage,\n", 57 | " TextContent,\n", 58 | " UserMessage,\n", 59 | ")\n", 60 | "from not_again_ai.llm.prompting.compile_prompt import compile_messages, encode_image\n", 61 | "\n", 62 | "sk_infographic = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKInfographic.png\"\n", 63 | "sk_diagram = Path.cwd().parent.parent / \"tests\" / \"llm\" / \"sample_images\" / \"SKDiagram.png\"\n", 64 | "\n", 65 | "messages = [\n", 66 | " SystemMessage(content=\"You are a helpful {{ persona }}.\"),\n", 67 | " UserMessage(\n", 68 | " content=[\n", 69 | " TextContent(\n", 70 | " text=\"Based on these infographics, can you summarize how {{ library }} works in exactly one sentence?\"\n", 71 | " ),\n", 72 | " ImageContent(\n", 73 | " image_url=ImageUrl(url=f\"data:image/png;base64,{encode_image(sk_infographic)}\", detail=ImageDetail.HIGH)\n", 74 | " ),\n", 75 | " ImageContent(\n", 76 | " image_url=ImageUrl(url=f\"data:image/png;base64,{encode_image(sk_diagram)}\", detail=ImageDetail.LOW)\n", 77 | " ),\n", 78 | " ],\n", 79 | " ),\n", 80 | "]\n", 81 | "\n", 82 | "prompt = compile_messages(messages, variables={\"persona\": \"assistant\", \"library\": \"Semantic Kernel\"})\n", 83 | "\n", 84 | "# Truncate the url fields to avoid cluttering the output\n", 85 | "prompt[1].content[1].image_url.url = prompt[1].content[1].image_url.url[0:50] + \"...\"\n", 86 | "prompt[1].content[2].image_url.url = prompt[1].content[2].image_url.url[0:50] + \"...\"\n", 87 | "prompt" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "Here are the two images that were encoded:\n", 95 | "\n", 96 | "![SKInfographic](https://github.com/DaveCoDev/not-again-ai/blob/main/tests/llm/sample_images/SKInfographic.png?raw=true)\n", 97 | "\n", 98 | "![SKDiagram](https://github.com/DaveCoDev/not-again-ai/blob/main/tests/llm/sample_images/SKDiagram.png?raw=true)\n" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "## Making an API Request\n", 106 | "\n", 107 | "With the prompt formatted, making the request is easy.\n", 108 | "\n", 109 | "### Simplifying the response format\n", 110 | "\n", 111 | "The response from the API is quite verbose. We can simplify it by extracting only what is needed, depending on the parameters we provided in our request.\n", 112 | "\n", 113 | "Using our helper functions, let's send a request which tries to use all the available parameters. Notice that we use `n=2` to get two completions in one request. 
However, due to the seed, they should always be equivalent. NOTE: We have noticed that the `seed` parameter is hit or miss and does not generate the same completions in all scenarios.\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 2, 119 | "metadata": {}, 120 | "outputs": [ 121 | { 122 | "data": { 123 | "text/plain": [ 124 | "'Semantic Kernel is a framework that integrates various AI services and plugins to manage and execute tasks by processing prompts, utilizing memory, planning, and invoking functions to deliver results efficiently.'" 125 | ] 126 | }, 127 | "execution_count": 2, 128 | "metadata": {}, 129 | "output_type": "execute_result" 130 | } 131 | ], 132 | "source": [ 133 | "from not_again_ai.llm.chat_completion import chat_completion\n", 134 | "from not_again_ai.llm.chat_completion.providers.openai_api import openai_client\n", 135 | "from not_again_ai.llm.chat_completion.types import ChatCompletionRequest\n", 136 | "\n", 137 | "client = openai_client()\n", 138 | "\n", 139 | "prompt = compile_messages(messages, variables={\"persona\": \"assistant\", \"library\": \"Semantic Kernel\"})\n", 140 | "\n", 141 | "request = ChatCompletionRequest(\n", 142 | " messages=prompt,\n", 143 | " model=\"gpt-4o-mini-2024-07-18\",\n", 144 | " max_completion_tokens=200,\n", 145 | " temperature=0.5,\n", 146 | " seed=42,\n", 147 | " n=2,\n", 148 | ")\n", 149 | "response = chat_completion(request, \"openai\", client)\n", 150 | "response.choices[0].message.content" 151 | ] 152 | } 153 | ], 154 | "metadata": { 155 | "kernelspec": { 156 | "display_name": ".venv", 157 | "language": "python", 158 | "name": "python3" 159 | }, 160 | "language_info": { 161 | "codemirror_mode": { 162 | "name": "ipython", 163 | "version": 3 164 | }, 165 | "file_extension": ".py", 166 | "mimetype": "text/x-python", 167 | "name": "python", 168 | "nbconvert_exporter": "python", 169 | "pygments_lexer": "ipython3", 170 | "version": "3.12.3" 171 | } 172 | }, 173 | "nbformat": 4, 174 | "nbformat_minor": 2 175 | } 176 | -------------------------------------------------------------------------------- /notebooks/llm/20_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Using Embeddings\n", 8 | "\n", 9 | "This notebook covers how you can use the `create_embeddings` function to create embeddings for text.\n", 10 | "\n", 11 | "## Embeddings with Ollama\n", 12 | "\n", 13 | "First, we instantiate the Ollama client, which is identical to the client we use for chat completions, detailed in [02_ollama_intro.ipynb](./02_ollama_intro.ipynb).\n", 14 | "\n", 15 | "Then we use the general `create_embeddings` function to create embeddings, passing the `EmbeddingRequest`, \"ollama\" as the provider, and the client we initialized.\n" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 10, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "name": "stdout", 25 | "output_type": "stream", 26 | "text": [ 27 | "[-0.011252048, 0.053802438, -0.011280932, -0.0730897, -0.038289726]\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "from not_again_ai.llm.embedding import EmbeddingRequest, create_embeddings\n", 33 | "from not_again_ai.llm.embedding.providers.ollama_api import ollama_client as ollama_embedding_client\n", 34 | "\n", 35 | "ollama_client = ollama_embedding_client()\n", 36 | "\n", 37 | "request = EmbeddingRequest(input=\"This is some text that I want to embed!\", 
model=\"snowflake-arctic-embed2\")\n", 38 | "response = create_embeddings(request, \"ollama\", ollama_client)\n", 39 | "print(response.embeddings[0].embedding[:5])" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Embeddings with OpenAI\n", 47 | "\n", 48 | "The OpenAI client is identical to the client we use for chat completions, detailed in [01_openai_chat_completion.ipynb](./01_openai_chat_completion.ipynb).\n", 49 | "\n", 50 | "We then use the general `create_embeddings` function to create embeddings passing the `EmbeddingRequest`, \"openai\" as the provider, and client we initialized.\n", 51 | "\n", 52 | "Note that OpenAI supports additional parameters, such as `dimensions` which we show here.\n" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 11, 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "[-0.16265206038951874, 0.11295679211616516, 0.980196475982666]\n" 65 | ] 66 | } 67 | ], 68 | "source": [ 69 | "from not_again_ai.llm.embedding.providers.openai_api import openai_client as openai_embedding_client\n", 70 | "\n", 71 | "openai_client = openai_embedding_client()\n", 72 | "request = EmbeddingRequest(\n", 73 | " input=\"This is some text that I want to embed with OpenAI!\", model=\"text-embedding-3-small\", dimensions=3\n", 74 | ")\n", 75 | "response = create_embeddings(request, \"openai\", openai_client)\n", 76 | "print(response.embeddings[0].embedding[:5])" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "Finally, we can batch requests with either provider.\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 12, 89 | "metadata": {}, 90 | "outputs": [ 91 | { 92 | "name": "stdout", 93 | "output_type": "stream", 94 | "text": [ 95 | "First embedding: [-0.024765074, 0.051080894, 0.021982849, -0.076628484, -0.07709133]\n", 96 | "Second embedding: [0.02124997, 0.04338046, -0.011488909, -0.03943117, -0.037866518]\n" 97 | ] 98 | } 99 | ], 100 | "source": [ 101 | "request = EmbeddingRequest(\n", 102 | " input=[\"This is some text that I want to embed with OpenAI!\", \"And embed this too!\"],\n", 103 | " model=\"snowflake-arctic-embed2\",\n", 104 | ")\n", 105 | "\n", 106 | "responses = create_embeddings(request, \"ollama\", ollama_client)\n", 107 | "print(f\"First embedding: {responses.embeddings[0].embedding[:5]}\")\n", 108 | "print(f\"Second embedding: {responses.embeddings[1].embedding[:5]}\")" 109 | ] 110 | } 111 | ], 112 | "metadata": { 113 | "kernelspec": { 114 | "display_name": ".venv", 115 | "language": "python", 116 | "name": "python3" 117 | }, 118 | "language_info": { 119 | "codemirror_mode": { 120 | "name": "ipython", 121 | "version": 3 122 | }, 123 | "file_extension": ".py", 124 | "mimetype": "text/x-python", 125 | "name": "python", 126 | "nbconvert_exporter": "python", 127 | "pygments_lexer": "ipython3", 128 | "version": "3.12.3" 129 | } 130 | }, 131 | "nbformat": 4, 132 | "nbformat_minor": 2 133 | } 134 | -------------------------------------------------------------------------------- /notebooks/statistics/statistics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Statistics" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Measuring dependence between variables\n", 15 | "### Pearson 
Correlation and Predictive Power Score\n", 16 | "\n", 17 | "Let's set up two variables, a categorical variable `x` and a binary variable `y`, where `x` *mostly predicts* `y`, and see the scores for both methods.\n" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 1, 23 | "metadata": {}, 24 | "outputs": [ 25 | { 26 | "data": { 27 | "text/plain": [ 28 | "array(['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a',\n", 29 | "       'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a',\n", 30 | "       'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a',\n", 31 | "       'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'b', 'b',\n", 32 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 33 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 34 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 35 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 36 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 37 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 38 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b', 'b',\n", 39 | "       'b', 'b', 'b', 'b', 'b', 'b', 'b'], dtype=' None: 18 |     s.install(".[data,llm,statistics,viz]", "pytest", "pytest-asyncio", "pytest-cov", "pytest-randomly") 19 |     s.run_install( 20 |         "uv", 21 |         "sync", 22 |         "--locked", 23 |         "--all-extras", 24 |         env={"UV_PROJECT_ENVIRONMENT": s.virtualenv.location}, 25 |     ) 26 | 27 |     # Skip tests in directories specified by the SKIP_TESTS_NAAI environment variable. 28 |     skip_tests = os.getenv("SKIP_TESTS_NAAI", "") 29 |     skip_tests += " tests/llm/chat_completion/ tests/llm/embedding/ tests/llm/image_gen/" 30 |     skip_args = [f"--ignore={dir}" for dir in skip_tests.split()] if skip_tests else [] 31 | 32 |     s.run( 33 |         "python", 34 |         "-m", 35 |         "pytest", 36 |         "--cov=not_again_ai", 37 |         "--cov-report=html", 38 |         "--cov-report=term", 39 |         "tests", 40 |         *skip_args, 41 |         "-W ignore::DeprecationWarning", 42 |         *s.posargs, 43 |     ) 44 | 45 | 46 | # For some sessions, set venv_backend="none" to simply execute scripts within the existing 47 | # uv-generated virtual environment, rather than have nox create a new one for each session. 48 | @session(venv_backend="none") 49 | @parametrize( 50 |     "command", 51 |     [ 52 |         param( 53 |             [ 54 |                 "ruff", 55 |                 "check", 56 |                 ".", 57 |                 "--select", 58 |                 "I", 59 |                 # Also remove unused imports. 
60 | "--select", 61 | "F401", 62 | "--extend-fixable", 63 | "F401", 64 | "--fix", 65 | ], 66 | id="sort_imports", 67 | ), 68 | param(["ruff", "format", "."], id="format"), 69 | ], 70 | ) 71 | def fmt(s: Session, command: list[str]) -> None: 72 | s.run(*command) 73 | 74 | 75 | @session(venv_backend="none") 76 | @parametrize( 77 | "command", 78 | [ 79 | param(["ruff", "check", "."], id="lint_check"), 80 | param(["ruff", "format", "--check", "."], id="format_check"), 81 | ], 82 | ) 83 | def lint(s: Session, command: list[str]) -> None: 84 | s.run(*command) 85 | 86 | 87 | @session(venv_backend="none") 88 | def lint_fix(s: Session) -> None: 89 | s.run("ruff", "check", ".", "--extend-fixable", "F401", "--fix") 90 | 91 | 92 | @session(venv_backend="none") 93 | def type_check(s: Session) -> None: 94 | s.run("mypy", "src", "tests", "noxfile.py") 95 | 96 | 97 | @session(venv_backend="none") 98 | def typos(s: Session) -> None: 99 | s.run("typos", "-c", ".github/_typos.toml") 100 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "not-again-ai" 3 | version = "0.20.0" 4 | description = "Designed to once and for all collect all the little things that come up over and over again in AI projects and put them in one place." 5 | authors = [ 6 | { name = "DaveCoDev", email = "dave.co.dev@gmail.com" } 7 | ] 8 | license = "MIT" 9 | readme = "README.md" 10 | repository = "https://github.com/DaveCoDev/not-again-ai" 11 | documentation = "https://github.com/DaveCoDev/not-again-ai" 12 | classifiers = [ 13 | "Development Status :: 3 - Alpha", 14 | "Intended Audience :: Developers", 15 | "Intended Audience :: Science/Research", 16 | "License :: OSI Approved :: MIT License", 17 | "Operating System :: OS Independent", 18 | "Programming Language :: Python", 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: Python :: 3.11", 21 | "Programming Language :: Python :: 3.12", 22 | "Typing :: Typed", 23 | ] 24 | requires-python = ">=3.11" 25 | dependencies = [ 26 | "loguru>=0.7,<1.0", 27 | "pydantic>=2.11,<3.0", 28 | ] 29 | 30 | [project.urls] 31 | Homepage = "https://github.com/DaveCoDev/not-again-ai" 32 | Documentation = "https://davecodev.github.io/not-again-ai/" 33 | Repository = "https://github.com/DaveCoDev/not-again-ai" 34 | 35 | [project.optional-dependencies] 36 | data = [ 37 | "Crawl4AI>=0.6,<1.0", 38 | "httpx>=0.28,<1.0", 39 | "markitdown[pdf]==0.1.2" 40 | ] 41 | llm = [ 42 | "anthropic>=0.50,<1.0", 43 | "azure-identity>=1.21,<2.0", 44 | "google-genai>1.12,<2.0", 45 | "ollama>=0.4,<1.0", 46 | "openai>=1.76,<2.0", 47 | "python-liquid>=2.0,<3.0", 48 | "tiktoken>=0.9,<1.0" 49 | ] 50 | statistics = [ 51 | "numpy>=2.2,<3.0", 52 | "scikit-learn>=1.6,<2.0", 53 | "scipy>=1.15" 54 | ] 55 | viz = [ 56 | "numpy>=2.2,<3.0", 57 | "pandas>=2.2,<3.0", 58 | "seaborn>=0.13,<1.0", 59 | ] 60 | 61 | [dependency-groups] 62 | dev = [ 63 | "ipykernel", 64 | "ipywidgets", 65 | ] 66 | nox = [ 67 | "nox", 68 | ] 69 | test = [ 70 | "pytest", 71 | "pytest-asyncio", 72 | "pytest-cov", 73 | "pytest-randomly", 74 | ] 75 | type_check = [ 76 | "mypy", 77 | # Add "types-" stub packages as needed: https://github.com/python/typeshed/tree/main/stubs 78 | ] 79 | lint = [ 80 | "ruff", 81 | ] 82 | typos = [ 83 | "typos", 84 | ] 85 | 86 | [build-system] 87 | requires = ["hatchling"] 88 | build-backend = "hatchling.build" 89 | 90 | [tool.uv] 91 | default-groups = "all" 92 | 93 | 
[tool.mypy] 94 | ignore_missing_imports = true 95 | strict = true 96 | # If certain strict config options are too pedantic for a project, 97 | # disable them selectively here by setting to false. 98 | disallow_untyped_calls = false 99 | 100 | [tool.ruff] 101 | line-length = 120 102 | target-version = "py312" 103 | src = ["src", "tests"] 104 | 105 | [tool.ruff.lint] 106 | select = [ 107 | "F", # pyflakes 108 | "E", # pycodestyle 109 | "I", # isort 110 | "N", # pep8-naming 111 | "UP", # pyupgrade 112 | "RUF", # ruff 113 | "B", # flake8-bugbear 114 | "C4", # flake8-comprehensions 115 | "ISC", # flake8-implicit-str-concat 116 | "PIE", # flake8-pie 117 | "PT", # flake-pytest-style 118 | "PTH", # flake8-use-pathlib 119 | "SIM", # flake8-simplify 120 | "TID", # flake8-tidy-imports 121 | ] 122 | extend-ignore = ["E501"] 123 | unfixable = ["F401"] 124 | 125 | [tool.ruff.lint.isort] 126 | force-sort-within-sections = true 127 | split-on-trailing-comma = false 128 | 129 | [tool.ruff.lint.flake8-tidy-imports] 130 | ban-relative-imports = "all" 131 | 132 | [tool.pytest.ini_options] 133 | addopts = [ 134 | "--strict-config", 135 | "--strict-markers", 136 | ] 137 | xfail_strict = true 138 | filterwarnings = [ 139 | # When running tests, treat warnings as errors (e.g. -Werror). 140 | # See: https://docs.pytest.org/en/latest/reference/reference.html#confval-filterwarnings 141 | "error", 142 | # Add additional warning suppressions as needed here. For example, if a third-party library 143 | # is throwing a deprecation warning that needs to be fixed upstream: 144 | "ignore::DeprecationWarning", 145 | "ignore::pytest.PytestUnraisableExceptionWarning" 146 | ] 147 | asyncio_mode = "auto" 148 | asyncio_default_fixture_loop_scope = "function" 149 | 150 | [tool.coverage.run] 151 | branch = true 152 | -------------------------------------------------------------------------------- /src/not_again_ai/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/base/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/base/file_system.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | def create_file_dir(filepath: str | Path) -> None: 5 | """Creates the parent directories for the specified filepath. 6 | Does not throw any errors if the directories already exist. 7 | 8 | Args: 9 | filepath (str | Path): path to a file 10 | """ 11 | root_path = Path(filepath).parent 12 | root_path.mkdir(parents=True, exist_ok=True) 13 | 14 | 15 | def readable_size(size: float) -> str: 16 | """Convert a file size given in bytes to a human-readable format. 
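For example, a quick sketch of the expected output:

    >>> readable_size(2048)
    '2.00 KB'
    >>> readable_size(123)
    '123.00 B'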
17 | 18 |     Args: 19 |         size (float): file size in bytes 20 | 21 |     Returns: 22 |         str: human-readable file size 23 |     """ 24 |     # Define the suffixes for each size unit 25 |     suffixes = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"] 26 | 27 |     # Start with bytes 28 |     count = 0 29 |     while size >= 1024 and count < len(suffixes) - 1: 30 |         count += 1 31 |         size /= 1024 32 | 33 |     # Format the size to two decimal places and append the appropriate suffix 34 |     return f"{size:.2f} {suffixes[count]}" 35 | -------------------------------------------------------------------------------- /src/not_again_ai/base/parallel.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from multiprocessing.pool import ThreadPool 3 | from typing import Any 4 | 5 | 6 | def embarrassingly_parallel( 7 |     func: Callable[..., Any], 8 |     args_list: tuple[tuple[Any, ...], ...] | None, 9 |     kwargs_list: list[dict[str, Any]] | None = None, 10 |     num_processes: int = 1, 11 | ) -> list[Any]: 12 |     """Call a function multiple times in parallel, providing either positional arguments, keyword arguments, 13 |     or both. Return the results in a list ordered to match the input arguments. 14 | 15 |     If both are provided, positional and keyword arguments must be aligned in the same order 16 |     and each list must be the same length. 17 | 18 |     Args: 19 |         func (Callable[..., Any]): Any function 20 |         args_list (tuple[tuple[Any, ...], ...] | None): A tuple of tuples, each containing positional arguments. 21 |         kwargs_list (list[dict[str, Any]] | None, optional): A list of dictionaries containing keyword arguments. Defaults to None. 22 |         num_processes (int, optional): Number of parallel processors to use. Defaults to 1. 23 | 24 |     Raises: 25 |         ValueError: If positional and keyword arguments are not aligned in the 26 |             same order or if the lists are not the same length. 27 |         ValueError: If neither positional nor keyword arguments are provided. 28 | 29 |     Returns: 30 |         list[Any]: list of the returns of each function call in order of the args_list or kwargs_list. 31 |     """ 32 | 33 |     pool = ThreadPool(processes=num_processes) 34 |     results = {} 35 |     if (args_list is not None) and (kwargs_list is None): 36 |         for idx, args in enumerate(args_list): 37 |             results[idx] = pool.apply_async(func, args) 38 |     elif (args_list is None) and (kwargs_list is not None): 39 |         for idx, kwargs in enumerate(kwargs_list): 40 |             results[idx] = pool.apply_async(func, kwds=kwargs) 41 |     elif (args_list is not None) and (kwargs_list is not None): 42 |         # in this case args_list and kwargs_list must be of the same length 43 |         if len(args_list) == len(kwargs_list): 44 |             for idx, (args, kwargs) in enumerate(zip(args_list, kwargs_list, strict=True)): 45 |                 results[idx] = pool.apply_async(func, args, kwargs) 46 |         else: 47 |             pool.close() 48 |             pool.terminate() 49 |             raise ValueError("args_list and kwargs_list must be of the same length") 50 |     else: 51 |         pool.close() 52 |         pool.terminate() 53 |         raise ValueError("either args_list or kwargs_list must be provided") 54 | 55 |     return_results = [] 56 |     for _, res in results.items(): 57 |         return_results.append(res.get()) 58 | 59 |     pool.close() 60 |     pool.terminate() 61 |     return return_results 62 | 63 | 64 | def embarrassingly_parallel_simple(funcs: list[Callable[..., Any]], num_processes: int = 1) -> list[Any]: 65 |     """Executes the given functions in parallel and returns the results in the same order as the funcs were provided. 
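A minimal usage sketch (the lambdas here are illustrative placeholders for real work):

    >>> embarrassingly_parallel_simple([lambda: 1, lambda: 2], num_processes=2)
    [1, 2]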
66 | 67 | Args: 68 | funcs (list[Callable[..., Any]]): A list of any functions that take no arguments. 69 | num_processes (int, optional): Number of parallel processors to use. Defaults to 1. 70 | 71 | Returns: 72 | list[Any]: list of the returns of each function call in order of the provided funcs. 73 | """ 74 | 75 | pool = ThreadPool(processes=num_processes) 76 | results = {} 77 | for idx, func in enumerate(funcs): 78 | results[idx] = pool.apply_async(func) 79 | 80 | return_results = [] 81 | for _, res in results.items(): 82 | return_results.append(res.get()) 83 | 84 | pool.close() 85 | pool.terminate() 86 | return return_results 87 | -------------------------------------------------------------------------------- /src/not_again_ai/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/data/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/data/brave_search_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import httpx 4 | from loguru import logger 5 | from pydantic import BaseModel 6 | 7 | 8 | class SearchWebResult(BaseModel): 9 | title: str 10 | url: str 11 | description: str 12 | netloc: str | None = None 13 | 14 | 15 | class SearchWebResults(BaseModel): 16 | results: list[SearchWebResult] 17 | 18 | 19 | async def search( 20 | query: str, 21 | count: int = 20, 22 | offset: int = 0, 23 | country: str = "US", 24 | search_lang: str = "en", 25 | ui_lang: str = "en-US", 26 | freshness: str | None = None, 27 | timezone: str = "America/New_York", 28 | state: str = "MA", 29 | user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.", 30 | ) -> SearchWebResults: 31 | """ 32 | Search using Brave Search API. 
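Usage sketch (assumes BRAVE_SEARCH_API_KEY is set and that this is awaited from an async context; the query is illustrative):

    >>> results = await search("not-again-ai python", count=5)
    >>> [r.title for r in results.results]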
33 | 34 |     Args: 35 |         query: The search query string 36 |         count: Number of search results to return (1-20, default 20) 37 |         offset: Number of search results to skip (default 0) 38 |         country: Country code for search results (default "US") 39 |         search_lang: Language for search (default "en") 40 |         ui_lang: User interface language (default "en-US") 41 |         freshness: Freshness of results ("pd", "pw", "pm", "py" or YYYY-MM-DDtoYYYY-MM-DD or None) 42 |         timezone: Timezone for search results (default "America/New_York") 43 |         state: State for search results (default "MA") 44 |         user_agent: User agent string for the request (default is a common browser UA) 45 | 46 |     Returns: 47 |         SearchWebResults: A model containing the search results 48 | 49 |     Raises: 50 |         httpx.HTTPError: If the request fails 51 |         ValueError: If BRAVE_SEARCH_API_KEY is not set 52 |     """ 53 |     api_key = os.getenv("BRAVE_SEARCH_API_KEY") 54 |     if not api_key: 55 |         raise ValueError("BRAVE_SEARCH_API_KEY environment variable is not set") 56 | 57 |     url = "https://api.search.brave.com/res/v1/web/search" 58 | 59 |     headers = { 60 |         "Accept": "application/json", 61 |         "Accept-Encoding": "gzip", 62 |         "X-Subscription-Token": api_key, 63 |         "X-Loc-Country": country, 64 |         "X-Loc-Timezone": timezone, 65 |         "X-Loc-State": state, 66 |         "User-Agent": user_agent, 67 |     } 68 | 69 |     params: dict[str, str | int | bool] = { 70 |         "q": query, 71 |         "count": count, 72 |         "offset": offset, 73 |         "country": country, 74 |         "search_lang": search_lang, 75 |         "ui_lang": ui_lang, 76 |         "text_decorations": False, 77 |         "spellcheck": False, 78 |         "units": "imperial", 79 |         "extra_snippets": False, 80 |         "safesearch": "off", 81 |     } 82 | 83 |     # Add optional parameters if provided 84 |     if freshness: 85 |         params["freshness"] = freshness 86 | 87 |     try: 88 |         async with httpx.AsyncClient() as client: 89 |             response = await client.get(url, headers=headers, params=params) 90 |             response.raise_for_status() 91 |             data = response.json() 92 |             results_list: list[SearchWebResult] = [] 93 |             for item in data.get("web", {}).get("results", []): 94 |                 result = SearchWebResult( 95 |                     title=item.get("title", ""), 96 |                     url=item.get("url", ""), 97 |                     description=item.get("snippet", ""), 98 |                     netloc=item.get("meta_url", {}).get("netloc", None), 99 |                 ) 100 |                 results_list.append(result) 101 |             return SearchWebResults(results=results_list) 102 | 103 |     except httpx.HTTPError as e: 104 |         logger.error(f"HTTP error during Brave search: {e}") 105 |         raise 106 |     except Exception as e: 107 |         logger.error(f"Unexpected error during Brave search: {e}") 108 |         raise 109 | 110 | 111 | class SearchNewsResult(BaseModel): 112 |     title: str 113 |     url: str 114 |     description: str 115 |     age: str 116 |     netloc: str | None = None 117 | 118 | 119 | class SearchNewsResults(BaseModel): 120 |     results: list[SearchNewsResult] 121 | 122 | 123 | async def search_news( 124 |     query: str, 125 |     count: int = 20, 126 |     offset: int = 0, 127 |     country: str = "US", 128 |     search_lang: str = "en", 129 |     ui_lang: str = "en-US", 130 |     freshness: str | None = None, 131 |     user_agent: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/134.0.0.0 Safari/537.36 Edg/134.0.0.", 132 | ) -> SearchNewsResults: 133 |     """ 134 |     Search news using Brave News Search API. 
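Usage sketch (assumes BRAVE_SEARCH_API_KEY is set and an async context; the query and freshness value are illustrative):

    >>> news = await search_news("AI research", count=5, freshness="pw")
    >>> [(r.title, r.age) for r in news.results]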
135 | 136 |     Args: 137 |         query: The search query string 138 |         count: Number of news results to return (1-20, default 20) 139 |         offset: Number of search results to skip (default 0) 140 |         country: Country code for search results (default "US") 141 |         search_lang: Language for search (default "en") 142 |         ui_lang: User interface language (default "en-US") 143 |         freshness: Freshness of results ("pd", "pw", "pm", "py" or YYYY-MM-DDtoYYYY-MM-DD or None) 144 |         user_agent: User agent string for the request (default is a common browser UA) 145 | 146 |     Returns: 147 |         SearchNewsResults: A model containing the news search results 148 | 149 |     Raises: 150 |         httpx.HTTPError: If the request fails 151 |         ValueError: If BRAVE_SEARCH_API_KEY is not set 152 |     """ 153 |     api_key = os.getenv("BRAVE_SEARCH_API_KEY") 154 |     if not api_key: 155 |         raise ValueError("BRAVE_SEARCH_API_KEY environment variable is not set") 156 | 157 |     url = "https://api.search.brave.com/res/v1/news/search" 158 | 159 |     headers = { 160 |         "Accept": "application/json", 161 |         "Accept-Encoding": "gzip", 162 |         "X-Subscription-Token": api_key, 163 |         "User-Agent": user_agent, 164 |     } 165 | 166 |     params: dict[str, str | int | bool] = { 167 |         "q": query, 168 |         "count": count, 169 |         "offset": offset, 170 |         "country": country, 171 |         "search_lang": search_lang, 172 |         "ui_lang": ui_lang, 173 |         "spellcheck": False, 174 |         "safesearch": "off", 175 |     } 176 | 177 |     # Add optional parameters if provided 178 |     if freshness: 179 |         params["freshness"] = freshness 180 | 181 |     try: 182 |         async with httpx.AsyncClient() as client: 183 |             response = await client.get(url, headers=headers, params=params) 184 |             response.raise_for_status() 185 |             data = response.json() 186 |             results_list: list[SearchNewsResult] = [] 187 |             for item in data.get("results", []): 188 |                 result = SearchNewsResult( 189 |                     title=item.get("title", ""), 190 |                     url=item.get("url", ""), 191 |                     description=item.get("description", ""), 192 |                     age=item.get("age", ""), 193 |                     netloc=item.get("meta_url", {}).get("netloc", None), 194 |                 ) 195 |                 results_list.append(result) 196 |             return SearchNewsResults(results=results_list) 197 | 198 |     except httpx.HTTPError as e: 199 |         logger.error(f"HTTP error during Brave news search: {e}") 200 |         raise 201 |     except Exception as e: 202 |         logger.error(f"Unexpected error during Brave news search: {e}") 203 |         raise 204 | -------------------------------------------------------------------------------- /src/not_again_ai/data/web.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import io 3 | import mimetypes 4 | from pathlib import Path 5 | import re 6 | from urllib.parse import urlparse 7 | 8 | from crawl4ai import AsyncWebCrawler, CacheMode 9 | from crawl4ai.async_configs import BrowserConfig, CrawlerRunConfig 10 | from crawl4ai.markdown_generation_strategy import DefaultMarkdownGenerator 11 | import httpx 12 | from markitdown import MarkItDown, StreamInfo 13 | from pydantic import BaseModel 14 | 15 | 16 | class Link(BaseModel): 17 |     url: str 18 |     text: str 19 | 20 | 21 | class URLResult(BaseModel): 22 |     url: str 23 |     markdown: str 24 |     links: list[Link] = [] 25 | 26 | 27 | async def _markitdown_bytes_to_str(file_bytes: bytes, filename_extension: str) -> str: 28 |     """ 29 |     Convert a file using MarkItDown defaults. 
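A usage sketch (the CSV bytes are illustrative):

    >>> text = await _markitdown_bytes_to_str(b"a,b\n1,2\n", ".csv")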
30 | """ 31 | with io.BytesIO(file_bytes) as temp: 32 | result = await asyncio.to_thread( 33 | MarkItDown(enable_plugins=False).convert, 34 | source=temp, 35 | stream_info=StreamInfo(extension=filename_extension), 36 | ) 37 | text = result.text_content 38 | return text 39 | 40 | 41 | def _detect_pdf_extension(url: str) -> bool: 42 | """ 43 | Detect if the URL is a PDF based on its extension. 44 | """ 45 | parsed_url = urlparse(url) 46 | filename = Path(parsed_url.path).name 47 | return mimetypes.guess_type(filename)[0] == "application/pdf" 48 | 49 | 50 | def _detect_google_sheets(url: str) -> bool: 51 | """ 52 | Detect if the URL is a Google Sheets document. 53 | """ 54 | is_google_sheets = url.startswith("https://docs.google.com/spreadsheets/") 55 | return is_google_sheets 56 | 57 | 58 | async def _handle_pdf_content(url: str) -> URLResult: 59 | md = MarkItDown(enable_plugins=False) 60 | result = md.convert(url) 61 | url_result = URLResult( 62 | url=url, 63 | markdown=result.markdown or "", 64 | links=[], 65 | ) 66 | return url_result 67 | 68 | 69 | async def _handle_google_sheets_content(url: str) -> URLResult: 70 | """ 71 | Handle Google Sheets by using the export URL to get the raw content. 72 | """ 73 | edit_pattern = r"https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9-_]+)/edit" 74 | export_pattern = r"https://docs\.google\.com/spreadsheets/d/([a-zA-Z0-9-_]+)/export\?format=csv" 75 | 76 | # Check if it's already an export URL 77 | export_match = re.search(export_pattern, url) 78 | if export_match: 79 | export_url = url 80 | else: 81 | # Check if it's an edit URL and extract document ID 82 | edit_match = re.search(edit_pattern, url) 83 | if edit_match: 84 | doc_id = edit_match.group(1) 85 | export_url = f"https://docs.google.com/spreadsheets/d/{doc_id}/export?format=csv&gid=0" 86 | else: 87 | return await _handle_web_content(url) 88 | 89 | async with httpx.AsyncClient(follow_redirects=True) as client: 90 | response = await client.get(export_url) 91 | response.raise_for_status() 92 | csv_bytes = response.content 93 | 94 | # Convert CSV to markdown using MarkItDown 95 | markdown_content = await _markitdown_bytes_to_str(csv_bytes, ".csv") 96 | 97 | url_result = URLResult( 98 | url=url, 99 | markdown=markdown_content, 100 | links=[], 101 | ) 102 | return url_result 103 | 104 | 105 | async def _handle_web_content(url: str) -> URLResult: 106 | browser_config = BrowserConfig( 107 | browser_type="chromium", 108 | headless=True, 109 | verbose=False, 110 | user_agent_mode="random", 111 | java_script_enabled=True, 112 | ) 113 | run_config = CrawlerRunConfig( 114 | scan_full_page=True, 115 | user_agent_mode="random", 116 | cache_mode=CacheMode.DISABLED, 117 | markdown_generator=DefaultMarkdownGenerator(), 118 | ) 119 | 120 | async with AsyncWebCrawler(config=browser_config) as crawler: 121 | result = await crawler.arun( 122 | url=url, 123 | config=run_config, 124 | ) 125 | 126 | if result.response_headers.get("content-type") == "application/pdf": 127 | return await _handle_pdf_content(url) 128 | 129 | links: list[Link] = [] 130 | seen_urls: set[str] = set() 131 | combined_link_data = result.links.get("internal", []) + result.links.get("external", []) 132 | for link_data in combined_link_data: 133 | href = link_data.get("href", "") 134 | if href and href not in seen_urls: 135 | seen_urls.add(href) 136 | link = Link( 137 | url=href, 138 | text=link_data.get("title", "") or link_data.get("text", ""), 139 | ) 140 | links.append(link) 141 | 142 | url_result = URLResult( 143 | url=url, 144 | 
markdown=result.markdown or "", 145 |         links=links, 146 |     ) 147 |     return url_result 148 | 149 | 150 | async def process_url(url: str) -> URLResult: 151 |     """ 152 |     Process a URL, extracting its content as Markdown along with any links found on the page. 153 |     """ 154 |     if _detect_pdf_extension(url): 155 |         url_result = await _handle_pdf_content(url) 156 |     elif _detect_google_sheets(url): 157 |         url_result = await _handle_google_sheets_content(url) 158 |     else: 159 |         url_result = await _handle_web_content(url) 160 |     return url_result 161 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/llm/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/__init__.py: -------------------------------------------------------------------------------- 1 | from not_again_ai.llm.chat_completion.interface import chat_completion, chat_completion_stream 2 | from not_again_ai.llm.chat_completion.types import ChatCompletionRequest 3 | 4 | __all__ = ["ChatCompletionRequest", "chat_completion", "chat_completion_stream"] 5 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/interface.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator, Callable 2 | from typing import Any 3 | 4 | from not_again_ai.llm.chat_completion.providers.anthropic_api import anthropic_chat_completion 5 | from not_again_ai.llm.chat_completion.providers.gemini_api import gemini_chat_completion 6 | from not_again_ai.llm.chat_completion.providers.ollama_api import ollama_chat_completion, ollama_chat_completion_stream 7 | from not_again_ai.llm.chat_completion.providers.openai_api import openai_chat_completion, openai_chat_completion_stream 8 | from not_again_ai.llm.chat_completion.types import ChatCompletionChunk, ChatCompletionRequest, ChatCompletionResponse 9 | 10 | 11 | def chat_completion( 12 |     request: ChatCompletionRequest, 13 |     provider: str, 14 |     client: Callable[..., Any], 15 | ) -> ChatCompletionResponse: 16 |     """Get a chat completion response from the given provider. Currently supported providers: 17 |     - `openai` - OpenAI 18 |     - `azure_openai` - Azure OpenAI 19 |     - `ollama` - Ollama 20 |     - `anthropic` - Anthropic 21 |     - `gemini` - Gemini 22 | 23 |     Args: 24 |         request: Request parameter object 25 |         provider: The supported provider name 26 |         client: Client information, see the provider's implementation for what can be provided 27 | 28 |     Returns: 29 |         ChatCompletionResponse: The chat completion response. 
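Usage sketch (assumes OpenAI credentials are configured; the model name mirrors the notebooks, and `UserMessage` comes from the types module):

    >>> from not_again_ai.llm.chat_completion.types import UserMessage
    >>> from not_again_ai.llm.chat_completion.providers.openai_api import openai_client
    >>> client = openai_client()
    >>> request = ChatCompletionRequest(
    ...     messages=[UserMessage(content="Hello!")],
    ...     model="gpt-4o-mini-2024-07-18",
    ...     max_completion_tokens=100,
    ... )
    >>> response = chat_completion(request, "openai", client)
    >>> response.choices[0].message.content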
30 | """ 31 | if provider == "openai" or provider == "azure_openai": 32 | return openai_chat_completion(request, client) 33 | elif provider == "ollama": 34 | return ollama_chat_completion(request, client) 35 | elif provider == "anthropic": 36 | return anthropic_chat_completion(request, client) 37 | elif provider == "gemini": 38 | return gemini_chat_completion(request, client) 39 | else: 40 | raise ValueError(f"Provider {provider} not supported") 41 | 42 | 43 | async def chat_completion_stream( 44 | request: ChatCompletionRequest, 45 | provider: str, 46 | client: Callable[..., Any], 47 | ) -> AsyncGenerator[ChatCompletionChunk, None]: 48 | """Stream a chat completion response from the given provider. Currently supported providers: 49 | - `openai` - OpenAI 50 | - `azure_openai` - Azure OpenAI 51 | 52 | Args: 53 | request: Request parameter object 54 | provider: The supported provider name 55 | client: Client information, see the provider's implementation for what can be provided 56 | 57 | Returns: 58 | AsyncGenerator[ChatCompletionChunk, None] 59 | """ 60 | request.stream = True 61 | if provider == "openai" or provider == "azure_openai": 62 | async for chunk in openai_chat_completion_stream(request, client): 63 | yield chunk 64 | elif provider == "ollama": 65 | async for chunk in ollama_chat_completion_stream(request, client): 66 | yield chunk 67 | else: 68 | raise ValueError(f"Provider {provider} not supported") 69 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/llm/chat_completion/providers/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/providers/anthropic_api.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import os 3 | import time 4 | from typing import Any 5 | 6 | from anthropic import Anthropic 7 | from anthropic.types import Message 8 | 9 | from not_again_ai.llm.chat_completion.types import ( 10 | AssistantMessage, 11 | ChatCompletionChoice, 12 | ChatCompletionRequest, 13 | ChatCompletionResponse, 14 | Function, 15 | ToolCall, 16 | ) 17 | 18 | ANTHROPIC_PARAMETER_MAP = { 19 | "max_completion_tokens": "max_tokens", 20 | } 21 | 22 | 23 | def anthropic_chat_completion(request: ChatCompletionRequest, client: Callable[..., Any]) -> ChatCompletionResponse: 24 | """Anthropic chat completion function. 
25 | 26 |     TODO 27 |     - Image messages 28 |     - Thinking 29 |     - Citations 30 |     - Stop sequences 31 |     - Documents 32 |     """ 33 |     kwargs = request.model_dump(mode="json", exclude_none=True) 34 | 35 |     # For each key in ANTHROPIC_PARAMETER_MAP 36 |     # If it is not None, set the key in kwargs to the value of the corresponding value in ANTHROPIC_PARAMETER_MAP 37 |     # If it is None, remove that key from kwargs 38 |     for key, value in ANTHROPIC_PARAMETER_MAP.items(): 39 |         if value is not None and key in kwargs: 40 |             kwargs[value] = kwargs.pop(key) 41 |         elif value is None and key in kwargs: 42 |             del kwargs[key] 43 | 44 |     # Handle messages 45 |     # Any system messages need to be removed from messages and concatenated into a single string (in order) 46 |     # Any tool messages need to be converted to a special user message 47 |     # Any assistant messages with tool calls need to be converted. 48 |     system = "" 49 |     new_messages = [] 50 |     for message in kwargs["messages"]: 51 |         if message["role"] == "system": 52 |             system += message["content"] + "\n" 53 |         elif message["role"] == "tool": 54 |             new_messages.append( 55 |                 { 56 |                     "role": "user", 57 |                     "content": [ 58 |                         { 59 |                             "type": "tool_result", 60 |                             "tool_use_id": message["name"], 61 |                             "content": message["content"], 62 |                         } 63 |                     ], 64 |                 } 65 |             ) 66 |         elif message["role"] == "assistant": 67 |             content = [] 68 |             if message.get("content", None): 69 |                 content.append( 70 |                     { 71 |                         "type": "text", 72 |                         "text": message["content"], 73 |                     } 74 |                 ) 75 |             for tool_call in message.get("tool_calls", []): 76 |                 content.append( 77 |                     { 78 |                         "type": "tool_use", 79 |                         "id": tool_call["id"], 80 |                         "name": tool_call["function"]["name"], 81 |                         "input": tool_call["function"]["arguments"], 82 |                     } 83 |                 ) 84 |             new_messages.append( 85 |                 { 86 |                     "role": "assistant", 87 |                     "content": content, 88 |                 } 89 |             ) 90 |         else: 91 |             new_messages.append(message) 92 |     kwargs["messages"] = new_messages 93 |     system = system.strip() 94 |     if system: 95 |         kwargs["system"] = system 96 | 97 |     # Handle tool choice and parallel tool calls 98 |     if kwargs.get("tool_choice") is not None: 99 |         tool_choice_value = kwargs.pop("tool_choice") 100 |         tool_choice = {} 101 |         if tool_choice_value == "none": 102 |             tool_choice["type"] = "none" 103 |         elif tool_choice_value in ["auto", "any"]: 104 |             tool_choice["type"] = "auto" 105 |             if kwargs.get("parallel_tool_calls") is not None: 106 |                 tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"]  # type: ignore 107 |         else: 108 |             tool_choice["name"] = tool_choice_value 109 |             tool_choice["type"] = "tool" 110 |             if kwargs.get("parallel_tool_calls") is not None: 111 |                 tool_choice["disable_parallel_tool_use"] = not kwargs["parallel_tool_calls"]  # type: ignore 112 |         kwargs["tool_choice"] = tool_choice 113 |     kwargs.pop("parallel_tool_calls", None) 114 | 115 |     start_time = time.time() 116 |     response: Message = client(**kwargs) 117 |     end_time = time.time() 118 |     response_duration = round(end_time - start_time, 4) 119 | 120 |     tool_calls: list[ToolCall] = [] 121 |     assistant_message = "" 122 |     for block in response.content: 123 |         if block.type == "text": 124 |             assistant_message += block.text 125 |         elif block.type == "tool_use": 126 |             tool_calls.append( 127 |                 ToolCall( 128 |                     id=block.id, 129 |                     function=Function( 130 |                         name=block.name, 131 |                         arguments=block.input,  # type: ignore 132 |                     ), 133 |                 ) 134 |             ) 135 | 136 |     choice = ChatCompletionChoice( 137 |         message=AssistantMessage( 138 |             content=assistant_message, 139 |             tool_calls=tool_calls, 140 |         ), 141 |         finish_reason=response.stop_reason or 
"stop", 142 | ) 143 | 144 | chat_completion_response = ChatCompletionResponse( 145 | choices=[choice], 146 | errors="", 147 | completion_tokens=response.usage.output_tokens, 148 | prompt_tokens=response.usage.input_tokens, 149 | cache_read_input_tokens=response.usage.cache_read_input_tokens, 150 | cache_creation_input_tokens=response.usage.cache_creation_input_tokens, 151 | response_duration=response_duration, 152 | ) 153 | return chat_completion_response 154 | 155 | 156 | def create_client_callable(client_class: type[Anthropic], **client_args: Any) -> Callable[..., Any]: 157 | """Creates a callable that instantiates and uses an Anthropic client. 158 | 159 | Args: 160 | client_class: The Anthropic client class to instantiate 161 | **client_args: Arguments to pass to the client constructor 162 | 163 | Returns: 164 | A callable that creates a client and returns completion results 165 | """ 166 | filtered_args = {k: v for k, v in client_args.items() if v is not None} 167 | 168 | def client_callable(**kwargs: Any) -> Any: 169 | client = client_class(**filtered_args) 170 | completion = client.beta.messages.create(**kwargs) 171 | return completion 172 | 173 | return client_callable 174 | 175 | 176 | def anthropic_client(api_key: str | None = None) -> Callable[..., Any]: 177 | if not api_key: 178 | api_key = os.environ.get("ANTHROPIC_API_KEY") 179 | client_callable = create_client_callable(Anthropic, api_key=api_key) 180 | return client_callable 181 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/providers/gemini_api.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from collections.abc import Callable 3 | import os 4 | import time 5 | from typing import Any 6 | 7 | from google import genai 8 | from google.genai import types 9 | from google.genai.types import FunctionCall, FunctionCallingConfigMode, GenerateContentResponse 10 | 11 | from not_again_ai.llm.chat_completion.types import ( 12 | AssistantMessage, 13 | ChatCompletionChoice, 14 | ChatCompletionRequest, 15 | ChatCompletionResponse, 16 | Function, 17 | ImageContent, 18 | Role, 19 | TextContent, 20 | ToolCall, 21 | ) 22 | 23 | # This should be all of the options we want to support in types.GenerateContentConfig, that are not handled otherwise 24 | GEMINI_PARAMETER_MAP = { 25 | "max_completion_tokens": "max_output_tokens", 26 | "temperature": "temperature", 27 | "top_p": "top_p", 28 | "top_k": "top_k", 29 | } 30 | 31 | GEMINI_FINISH_REASON_MAP = { 32 | "STOP": "stop", 33 | "MAX_TOKENS": "max_tokens", 34 | "SAFETY": "safety", 35 | "RECITATION": "recitation", 36 | "LANGUAGE": "language", 37 | "OTHER": "other", 38 | "BLOCKLIST": "blocklist", 39 | "PROHIBITED_CONTENT": "prohibited_content", 40 | "SPII": "spii", 41 | "MALFORMED_FUNCTION_CALL": "malformed_function_call", 42 | "IMAGE_SAFETY": "image_safety", 43 | } 44 | 45 | 46 | def gemini_chat_completion(request: ChatCompletionRequest, client: Callable[..., Any]) -> ChatCompletionResponse: 47 | """Experimental Gemini chat completion function.""" 48 | # Handle messages 49 | # Any system messages need to be removed from messages and concatenated into a single string (in order) 50 | system = "" 51 | contents = [] 52 | for message in request.messages: 53 | if message.role == "system": 54 | # Handle both string content and structured content 55 | if isinstance(message.content, str): 56 | system += message.content + "\n" 57 | else: 58 | # If it's a list of content parts, 
extract text content 59 | for part in message.content: 60 | if hasattr(part, "text"): 61 | system += part.text + "\n" 62 | elif message.role == "tool": 63 | tool_name = message.name if message.name is not None else "" 64 | function_response_part = types.Part.from_function_response( 65 | name=tool_name, 66 | response={"result": message.content}, 67 | ) 68 | contents.append( 69 | types.Content( 70 | role="user", 71 | parts=[function_response_part], 72 | ) 73 | ) 74 | elif message.role == "assistant": 75 | if message.content and isinstance(message.content, str): 76 | contents.append(types.Content(role="model", parts=[types.Part(text=message.content)])) 77 | function_parts = [] 78 | if isinstance(message, AssistantMessage) and message.tool_calls: 79 | for tool_call in message.tool_calls: 80 | function_call_part = types.Part( 81 | function_call=FunctionCall( 82 | id=tool_call.id, 83 | name=tool_call.function.name, 84 | args=tool_call.function.arguments, 85 | ) 86 | ) 87 | function_parts.append(function_call_part) 88 | if function_parts: 89 | contents.append(types.Content(role="model", parts=function_parts)) 90 | elif message.role == "user": 91 | if isinstance(message.content, str): 92 | contents.append(types.Content(role="user", parts=[types.Part(text=message.content)])) 93 | elif isinstance(message.content, list): 94 | parts = [] 95 | for part in message.content: 96 | if isinstance(part, TextContent): 97 | parts.append(types.Part(text=part.text)) 98 | elif isinstance(part, ImageContent): 99 | # Extract MIME type and data from data URI 100 | uri_parts = part.image_url.url.split(",", 1) 101 | if len(uri_parts) == 2: 102 | mime_type = uri_parts[0].split(":")[1].split(";")[0] 103 | base64_data = uri_parts[1] 104 | image_data = base64.b64decode(base64_data) 105 | parts.append(types.Part.from_bytes(mime_type=mime_type, data=image_data)) 106 | contents.append(types.Content(role="user", parts=parts)) 107 | 108 | kwargs: dict[str, Any] = {} 109 | kwargs["contents"] = contents 110 | kwargs["model"] = request.model 111 | config: dict[str, Any] = {} 112 | config["system_instruction"] = system.rstrip() 113 | config["automatic_function_calling"] = {"disable": True} 114 | 115 | # Handle the possible tool choice options 116 | if request.tool_choice: 117 | tool_choice = request.tool_choice 118 | if tool_choice == "auto": 119 | config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.AUTO) 120 | elif tool_choice == "any": 121 | config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.ANY) 122 | elif tool_choice == "none": 123 | config["tool_config"] = types.FunctionCallingConfig(mode=FunctionCallingConfigMode.NONE) 124 | elif isinstance(tool_choice, list): 125 | config["tool_config"] = types.FunctionCallingConfig( 126 | mode=FunctionCallingConfigMode.ANY, allowed_function_names=tool_choice 127 | ) 128 | elif tool_choice not in (None, "auto", "any", "none"): 129 | config["tool_config"] = types.FunctionCallingConfig( 130 | mode=FunctionCallingConfigMode.ANY, allowed_function_names=[tool_choice] 131 | ) 132 | 133 | # Handle tools 134 | tools = [] 135 | for tool in request.tools or []: 136 | tools.append(types.Tool(function_declarations=[tool])) # type: ignore 137 | if tools: 138 | config["tools"] = tools 139 | 140 | # Everything else defined in GEMINI_PARAMETER_MAP goes into kwargs["config"] 141 | request_kwargs = request.model_dump(mode="json", exclude_none=True) 142 | for key, value in GEMINI_PARAMETER_MAP.items(): 143 | if value is not None and key in 
request_kwargs: 144 | config[value] = request_kwargs.pop(key) 145 | 146 | kwargs["config"] = types.GenerateContentConfig(**config) 147 | 148 | start_time = time.time() 149 | response: GenerateContentResponse = client(**kwargs) 150 | end_time = time.time() 151 | response_duration = round(end_time - start_time, 4) 152 | 153 | finish_reason = "other" 154 | if response.candidates and response.candidates[0].finish_reason: 155 | finish_reason_str = str(response.candidates[0].finish_reason) 156 | finish_reason = GEMINI_FINISH_REASON_MAP.get(finish_reason_str, "other") 157 | 158 | tool_calls: list[ToolCall] = [] 159 | tool_call_objs = response.function_calls 160 | if tool_call_objs: 161 | for tool_call_obj in tool_call_objs: 162 | tool_call_id = tool_call_obj.id if tool_call_obj.id else "" 163 | tool_calls.append( 164 | ToolCall( 165 | id=tool_call_id, 166 | function=Function( 167 | name=tool_call_obj.name if tool_call_obj.name is not None else "", 168 | arguments=tool_call_obj.args if tool_call_obj.args is not None else {}, 169 | ), 170 | ) 171 | ) 172 | 173 | assistant_message = "" 174 | if ( 175 | response.candidates 176 | and response.candidates[0].content 177 | and response.candidates[0].content.parts 178 | and response.candidates[0].content.parts[0].text 179 | ): 180 | assistant_message = response.candidates[0].content.parts[0].text 181 | 182 | choice = ChatCompletionChoice( 183 | message=AssistantMessage( 184 | role=Role.ASSISTANT, 185 | content=assistant_message, 186 | tool_calls=tool_calls, 187 | ), 188 | finish_reason=finish_reason, 189 | ) 190 | 191 | completion_tokens = 0 192 | # Add null check for usage_metadata 193 | if response.usage_metadata is not None: 194 | if response.usage_metadata.thoughts_token_count: 195 | completion_tokens = response.usage_metadata.thoughts_token_count 196 | if response.usage_metadata.candidates_token_count: 197 | completion_tokens += response.usage_metadata.candidates_token_count 198 | 199 | # Set safe default for prompt_tokens 200 | prompt_tokens = 0 201 | if response.usage_metadata is not None and response.usage_metadata.prompt_token_count: 202 | prompt_tokens = response.usage_metadata.prompt_token_count 203 | 204 | chat_completion_response = ChatCompletionResponse( 205 | choices=[choice], 206 | completion_tokens=completion_tokens, 207 | prompt_tokens=prompt_tokens, 208 | response_duration=response_duration, 209 | ) 210 | return chat_completion_response 211 | 212 | 213 | def create_client_callable(client_class: type[genai.Client], **client_args: Any) -> Callable[..., Any]: 214 | """Creates a callable that instantiates and uses a Google genai client. 
215 | 216 | Args: 217 | client_class: The Google genai client class to instantiate 218 | **client_args: Arguments to pass to the client constructor 219 | 220 | Returns: 221 | A callable that creates a client and returns completion results 222 | """ 223 | filtered_args = {k: v for k, v in client_args.items() if v is not None} 224 | 225 | def client_callable(**kwargs: Any) -> Any: 226 | client = client_class(**filtered_args) 227 | completion = client.models.generate_content(**kwargs) 228 | return completion 229 | 230 | return client_callable 231 | 232 | 233 | def gemini_client(api_key: str | None = None) -> Callable[..., Any]: 234 | if not api_key: 235 | api_key = os.environ.get("GEMINI_API_KEY") 236 | client_callable = create_client_callable(genai.Client, api_key=api_key) 237 | return client_callable 238 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/providers/ollama_api.py: -------------------------------------------------------------------------------- 1 | from collections.abc import AsyncGenerator, Callable 2 | import json 3 | import os 4 | import re 5 | import time 6 | from typing import Any, Literal, cast 7 | 8 | from loguru import logger 9 | from ollama import AsyncClient, ChatResponse, Client, ResponseError 10 | 11 | from not_again_ai.llm.chat_completion.types import ( 12 | AssistantMessage, 13 | ChatCompletionChoice, 14 | ChatCompletionChoiceStream, 15 | ChatCompletionChunk, 16 | ChatCompletionDelta, 17 | ChatCompletionRequest, 18 | ChatCompletionResponse, 19 | Function, 20 | PartialFunction, 21 | PartialToolCall, 22 | Role, 23 | ToolCall, 24 | ) 25 | 26 | OLLAMA_PARAMETER_MAP = { 27 | "frequency_penalty": "repeat_penalty", 28 | "max_completion_tokens": "num_predict", 29 | "context_window": "num_ctx", 30 | "n": None, 31 | "tool_choice": None, 32 | "reasoning_effort": None, 33 | "parallel_tool_calls": None, 34 | "logit_bias": None, 35 | "top_logprobs": None, 36 | "presence_penalty": None, 37 | "max_tokens": "num_predict", 38 | } 39 | 40 | 41 | def validate(request: ChatCompletionRequest) -> None: 42 | if request.json_mode and request.structured_outputs is not None: 43 | raise ValueError("json_schema and json_mode cannot be used together.") 44 | 45 | # Check if any of the parameters set to OLLAMA_PARAMETER_MAP are not None 46 | for key, value in OLLAMA_PARAMETER_MAP.items(): 47 | if value is None and getattr(request, key) is not None: 48 | logger.warning(f"Parameter {key} is not supported by Ollama and will be ignored.") 49 | 50 | # If "stop" is not None, check if it is just a string 51 | if isinstance(request.stop, list): 52 | logger.warning("Parameter 'stop' needs to be a string and not a list. 
It will be ignored.") 53 | request.stop = None 54 | 55 | # Raise an error if both "max_tokens" and "max_completion_tokens" are provided 56 | if request.max_tokens is not None and request.max_completion_tokens is not None: 57 | raise ValueError("`max_tokens` and `max_completion_tokens` cannot both be provided.") 58 | 59 | 60 | def format_kwargs(request: ChatCompletionRequest) -> dict[str, Any]: 61 | kwargs = request.model_dump(mode="json", exclude_none=True) 62 | # For each key in OLLAMA_PARAMETER_MAP 63 | # If it is not None, set the key in kwargs to the value of the corresponding value in OLLAMA_PARAMETER_MAP 64 | # If it is None, remove that key from kwargs 65 | for key, value in OLLAMA_PARAMETER_MAP.items(): 66 | if value is not None and key in kwargs: 67 | kwargs[value] = kwargs.pop(key) 68 | elif value is None and key in kwargs: 69 | del kwargs[key] 70 | 71 | # If json_mode is True, set the format to json 72 | json_mode = kwargs.get("json_mode", None) 73 | if json_mode: 74 | kwargs["format"] = "json" 75 | kwargs.pop("json_mode") 76 | elif json_mode is not None and not json_mode: 77 | kwargs.pop("json_mode") 78 | 79 | # If structured_outputs is not None, set the format to structured_outputs 80 | if kwargs.get("structured_outputs", None): 81 | # Check if the schema is in the OpenAI and pull out the schema 82 | if "schema" in kwargs["structured_outputs"]: 83 | kwargs["format"] = kwargs["structured_outputs"]["schema"] 84 | kwargs.pop("structured_outputs") 85 | else: 86 | kwargs["format"] = kwargs.pop("structured_outputs") 87 | 88 | option_fields = [ 89 | "mirostat", 90 | "mirostat_eta", 91 | "mirostat_tau", 92 | "num_ctx", 93 | "repeat_last_n", 94 | "repeat_penalty", 95 | "temperature", 96 | "seed", 97 | "stop", 98 | "tfs_z", 99 | "num_predict", 100 | "top_k", 101 | "top_p", 102 | "min_p", 103 | ] 104 | # For each field in option_fields, if it is in kwargs, make it under an options dictionary 105 | options = {} 106 | for field in option_fields: 107 | if field in kwargs: 108 | options[field] = kwargs.pop(field) 109 | kwargs["options"] = options 110 | 111 | for message in kwargs["messages"]: 112 | role = message.get("role", None) 113 | # For each ToolMessage, remove the name field 114 | if role is not None and role == "tool": 115 | message.pop("name") 116 | 117 | # For each AssistantMessage with tool calls, remove the id field 118 | if role is not None and role == "assistant" and message.get("tool_calls", None): 119 | for tool_call in message["tool_calls"]: 120 | tool_call.pop("id") 121 | 122 | # Content and images need to be separated 123 | images = [] 124 | content = "" 125 | if isinstance(message["content"], list): 126 | for item in message["content"]: 127 | if item["type"] == "image_url": 128 | image_url = item["image_url"]["url"] 129 | # Remove the data URL prefix if present 130 | if image_url.startswith("data:"): 131 | image_url = image_url.split("base64,", 1)[1] 132 | images.append(image_url) 133 | else: 134 | content += item["text"] 135 | else: 136 | content = message["content"] 137 | 138 | message["content"] = content 139 | if len(images) > 1: 140 | images = images[:1] 141 | logger.warning("Ollama model only supports a single image per message. 
Using only the first image.") 142 |         message["images"] = images 143 | 144 |     return kwargs 145 | 146 | 147 | def ollama_chat_completion( 148 |     request: ChatCompletionRequest, 149 |     client: Callable[..., Any], 150 | ) -> ChatCompletionResponse: 151 |     validate(request) 152 |     kwargs = format_kwargs(request) 153 | 154 |     try: 155 |         start_time = time.time() 156 |         response: ChatResponse = client(**kwargs) 157 |         end_time = time.time() 158 |         response_duration = round(end_time - start_time, 4) 159 |     except ResponseError as e: 160 |         # If the error says "model 'model' not found" use regex then raise a more specific error 161 |         expected_pattern = f"model '{request.model}' not found" 162 |         if re.search(expected_pattern, e.error): 163 |             raise ResponseError(f"Model '{request.model}' not found.") from e 164 |         else: 165 |             raise ResponseError(e.error) from e 166 | 167 |     errors = "" 168 | 169 |     # Handle tool calls 170 |     tool_calls: list[ToolCall] | None = None 171 |     if response.message.tool_calls: 172 |         parsed_tool_calls: list[ToolCall] = [] 173 |         for tool_call in response.message.tool_calls: 174 |             tool_name = tool_call.function.name 175 |             if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]: 176 |                 errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n" 177 |             tool_args = dict(tool_call.function.arguments) 178 |             parsed_tool_calls.append( 179 |                 ToolCall( 180 |                     id="", 181 |                     function=Function( 182 |                         name=tool_name, 183 |                         arguments=tool_args, 184 |                     ), 185 |                 ) 186 |             ) 187 |         tool_calls = parsed_tool_calls 188 | 189 |     json_message = None 190 |     if (request.json_mode or (request.structured_outputs is not None)) and response.message.content: 191 |         try: 192 |             json_message = json.loads(response.message.content) 193 |         except json.JSONDecodeError: 194 |             errors += "Message failed to parse into JSON\n" 195 | 196 |     finish_reason = cast( 197 |         Literal["stop", "length", "tool_calls", "content_filter"], 198 |         "stop" if response.done_reason is None else response.done_reason or "stop", 199 |     ) 200 | 201 |     choice = ChatCompletionChoice( 202 |         message=AssistantMessage( 203 |             content=response.message.content or "", 204 |             tool_calls=tool_calls, 205 |         ), 206 |         finish_reason=finish_reason, 207 |         json_message=json_message, 208 |     ) 209 | 210 |     return ChatCompletionResponse( 211 |         choices=[choice], 212 |         errors=errors.strip(), 213 |         completion_tokens=response.get("eval_count", -1), 214 |         prompt_tokens=response.get("prompt_eval_count", -1), 215 |         response_duration=response_duration, 216 |     ) 217 | 218 | 219 | async def ollama_chat_completion_stream( 220 |     request: ChatCompletionRequest, 221 |     client: Callable[..., Any], 222 | ) -> AsyncGenerator[ChatCompletionChunk, None]: 223 |     validate(request) 224 |     kwargs = format_kwargs(request) 225 | 226 |     start_time = time.time() 227 |     stream = await client(**kwargs) 228 | 229 |     async for chunk in stream: 230 |         errors = "" 231 |         # Handle tool calls 232 |         tool_calls: list[PartialToolCall] | None = None 233 |         if chunk.message.tool_calls: 234 |             parsed_tool_calls: list[PartialToolCall] = [] 235 |             for tool_call in chunk.message.tool_calls: 236 |                 tool_name = tool_call.function.name 237 |                 if request.tools and tool_name not in [tool["function"]["name"] for tool in request.tools]: 238 |                     errors += f"Tool call {tool_call} has an invalid tool name: {tool_name}\n" 239 |                 tool_args = tool_call.function.arguments 240 | 241 |                 parsed_tool_calls.append( 242 |                     PartialToolCall( 243 |                         id="", 244 |                         function=PartialFunction( 245 |                             name=tool_name, 246 | 
arguments=tool_args, 247 |                         ), 248 |                     ) 249 |                 ) 250 |             tool_calls = parsed_tool_calls 251 | 252 |         current_time = time.time() 253 |         response_duration = round(current_time - start_time, 4) 254 | 255 |         delta = ChatCompletionDelta( 256 |             content=chunk.message.content or "", 257 |             role=Role.ASSISTANT, 258 |             tool_calls=tool_calls, 259 |         ) 260 |         choice_obj = ChatCompletionChoiceStream( 261 |             delta=delta, 262 |             finish_reason=chunk.done_reason, 263 |             index=0, 264 |         ) 265 |         chunk_obj = ChatCompletionChunk( 266 |             choices=[choice_obj], 267 |             errors=errors.strip(), 268 |             completion_tokens=chunk.get("eval_count", None), 269 |             prompt_tokens=chunk.get("prompt_eval_count", None), 270 |             response_duration=response_duration, 271 |         ) 272 |         yield chunk_obj 273 | 274 | 275 | def ollama_client( 276 |     host: str | None = None, timeout: float | None = None, async_client: bool = False 277 | ) -> Callable[..., Any]: 278 |     """Create an Ollama client instance based on the specified host, or read the host from the OLLAMA_HOST environment variable. 279 | 280 |     Args: 281 |         host (str, optional): The host URL of the Ollama server. 282 |         timeout (float, optional): The timeout for requests. 283 |         async_client (bool, optional): Whether to return an asynchronous client. Defaults to False. 284 |     Returns: 285 |         Client: An instance of the Ollama client. 286 | 287 |     Examples: 288 |         >>> client = ollama_client(host="http://localhost:11434") 289 |     """ 290 |     if host is None: 291 |         host = os.getenv("OLLAMA_HOST") 292 |         if host is None: 293 |             logger.warning("OLLAMA_HOST environment variable not set, using default host: http://localhost:11434") 294 |             host = "http://localhost:11434" 295 | 296 |     def client_callable(**kwargs: Any) -> Any: 297 |         client = AsyncClient(host=host, timeout=timeout) if async_client else Client(host=host, timeout=timeout) 298 |         return client.chat(**kwargs) 299 | 300 |     return client_callable 301 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/chat_completion/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from typing import Any, Generic, Literal, TypeVar 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class Role(str, Enum): 8 |     ASSISTANT = "assistant" 9 |     DEVELOPER = "developer" 10 |     SYSTEM = "system" 11 |     TOOL = "tool" 12 |     USER = "user" 13 | 14 | 15 | class ContentPartType(str, Enum): 16 |     TEXT = "text" 17 |     IMAGE = "image_url" 18 | 19 | 20 | class TextContent(BaseModel): 21 |     type: Literal[ContentPartType.TEXT] = ContentPartType.TEXT 22 |     text: str 23 | 24 | 25 | class ImageDetail(str, Enum): 26 |     AUTO = "auto" 27 |     LOW = "low" 28 |     HIGH = "high" 29 | 30 | 31 | class ImageUrl(BaseModel): 32 |     url: str 33 |     detail: ImageDetail = ImageDetail.AUTO 34 | 35 | 36 | class ImageContent(BaseModel): 37 |     type: Literal[ContentPartType.IMAGE] = ContentPartType.IMAGE 38 |     image_url: ImageUrl 39 | 40 | 41 | ContentT = TypeVar("ContentT", bound=str | list[TextContent | ImageContent]) 42 | 43 | 44 | class BaseMessage(BaseModel, Generic[ContentT]): 45 |     content: ContentT 46 |     role: Role 47 |     name: str | None = None 48 | 49 | 50 | class Function(BaseModel): 51 |     name: str 52 |     arguments: dict[str, Any] 53 | 54 | 55 | class PartialFunction(BaseModel): 56 |     name: str 57 |     arguments: str | dict[str, Any] 58 | 59 | 60 | class ToolCall(BaseModel): 61 |     id: str 62 |     function: Function 63 |     type: Literal["function"] = "function" 64 | 65 | 66 | class PartialToolCall(BaseModel): 67 |     id: str | None 68 |     function: PartialFunction 69 |     type: Literal["function"] = "function" 70 | 71 | 72 | class 
DeveloperMessage(BaseMessage[str]): 73 | role: Literal[Role.DEVELOPER] = Role.DEVELOPER 74 | 75 | 76 | class SystemMessage(BaseMessage[str]): 77 | role: Literal[Role.SYSTEM] = Role.SYSTEM 78 | 79 | 80 | class UserMessage(BaseMessage[str | list[TextContent | ImageContent]]): 81 | role: Literal[Role.USER] = Role.USER 82 | 83 | 84 | class AssistantMessage(BaseMessage[str]): 85 | role: Literal[Role.ASSISTANT] = Role.ASSISTANT 86 | refusal: str | None = None 87 | tool_calls: list[ToolCall] | None = None 88 | 89 | 90 | class ToolMessage(BaseMessage[str]): 91 | # A tool message's name field will be interpreted as "tool_call_id" 92 | role: Literal[Role.TOOL] = Role.TOOL 93 | 94 | 95 | MessageT = AssistantMessage | DeveloperMessage | SystemMessage | ToolMessage | UserMessage 96 | 97 | 98 | class ChatCompletionRequest(BaseModel): 99 | messages: list[MessageT] 100 | model: str 101 | stream: bool = Field(default=False) 102 | 103 | max_completion_tokens: int | None = Field(default=None) 104 | context_window: int | None = Field(default=None) 105 | logprobs: bool | None = Field(default=None) 106 | n: int | None = Field(default=None) 107 | 108 | tools: list[dict[str, Any]] | None = Field(default=None) 109 | tool_choice: str | None = Field(default=None) 110 | parallel_tool_calls: bool | None = Field(default=None) 111 | json_mode: bool | None = Field(default=None) 112 | structured_outputs: dict[str, Any] | None = Field(default=None) 113 | 114 | temperature: float | None = Field(default=None) 115 | reasoning_effort: Literal["low", "medium", "high"] | None = Field(default=None) 116 | top_p: float | None = Field(default=None) 117 | logit_bias: dict[str, float] | None = Field(default=None) 118 | top_logprobs: int | None = Field(default=None) 119 | frequency_penalty: float | None = Field(default=None) 120 | presence_penalty: float | None = Field(default=None) 121 | stop: str | list[str] | None = Field(default=None) 122 | 123 | seed: int | None = Field(default=None) 124 | 125 | mirostat: int | None = Field(default=None) 126 | mirostat_eta: float | None = Field(default=None) 127 | mirostat_tau: float | None = Field(default=None) 128 | repeat_last_n: int | None = Field(default=None) 129 | tfs_z: float | None = Field(default=None) 130 | top_k: int | None = Field(default=None) 131 | min_p: float | None = Field(default=None) 132 | 133 | max_tokens: int | None = Field( 134 | default=None, 135 | description="Sometimes `max_completion_tokens` is not correctly supported so we provide this as a fallback.", 136 | ) 137 | 138 | 139 | class ChatCompletionChoice(BaseModel): 140 | message: AssistantMessage 141 | finish_reason: str 142 | json_message: dict[str, Any] | None = Field(default=None) 143 | logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None) 144 | 145 | extras: Any | None = Field(default=None) 146 | 147 | 148 | class ChatCompletionResponse(BaseModel): 149 | choices: list[ChatCompletionChoice] 150 | 151 | errors: str = Field(default="") 152 | 153 | completion_tokens: int 154 | prompt_tokens: int 155 | completion_detailed_tokens: dict[str, int] | None = Field(default=None) 156 | prompt_detailed_tokens: dict[str, int] | None = Field(default=None) 157 | cache_read_input_tokens: int | None = Field(default=None) 158 | cache_creation_input_tokens: int | None = Field(default=None) 159 | response_duration: float 160 | 161 | system_fingerprint: str | None = Field(default=None) 162 | 163 | extras: Any | None = Field(default=None) 164 | 165 | 166 | class ChatCompletionDelta(BaseModel): 167 | content: 
str 168 | role: Role = Field(default=Role.ASSISTANT) 169 | 170 | tool_calls: list[PartialToolCall] | None = Field(default=None) 171 | 172 | refusal: str | None = Field(default=None) 173 | 174 | 175 | class ChatCompletionChoiceStream(BaseModel): 176 | delta: ChatCompletionDelta 177 | index: int 178 | finish_reason: Literal["stop", "length", "tool_calls", "content_filter"] | None 179 | 180 | logprobs: list[dict[str, Any] | list[dict[str, Any]]] | None = Field(default=None) 181 | 182 | extras: Any | None = Field(default=None) 183 | 184 | 185 | class ChatCompletionChunk(BaseModel): 186 | choices: list[ChatCompletionChoiceStream] 187 | 188 | errors: str = Field(default="") 189 | 190 | completion_tokens: int | None = Field(default=None) 191 | prompt_tokens: int | None = Field(default=None) 192 | response_duration: float | None = Field(default=None) 193 | 194 | system_fingerprint: str | None = Field(default=None) 195 | extras: Any | None = Field(default=None) 196 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/__init__.py: -------------------------------------------------------------------------------- 1 | from not_again_ai.llm.embedding.interface import create_embeddings 2 | from not_again_ai.llm.embedding.types import EmbeddingRequest 3 | 4 | __all__ = ["EmbeddingRequest", "create_embeddings"] 5 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/interface.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | from not_again_ai.llm.embedding.providers.ollama_api import ollama_create_embeddings 5 | from not_again_ai.llm.embedding.providers.openai_api import openai_create_embeddings 6 | from not_again_ai.llm.embedding.types import EmbeddingRequest, EmbeddingResponse 7 | 8 | 9 | def create_embeddings(request: EmbeddingRequest, provider: str, client: Callable[..., Any]) -> EmbeddingResponse: 10 | """Get an embedding response from the given provider. Currently supported providers: 11 | - `openai` - OpenAI 12 | - `azure_openai` - Azure OpenAI 13 | - `ollama` - Ollama 14 | 15 | Args: 16 | request: Request parameter object 17 | provider: The supported provider name 18 | client: The provider client callable; see the provider's implementation for what can be provided 19 | 20 | Returns: 21 | EmbeddingResponse: The embedding response.
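    Example:
        A minimal usage sketch; the model name below is illustrative:

        >>> from not_again_ai.llm.embedding import EmbeddingRequest, create_embeddings
        >>> from not_again_ai.llm.embedding.providers.openai_api import openai_client
        >>> client = openai_client(api_type="openai")
        >>> request = EmbeddingRequest(input=["hello", "world"], model="text-embedding-3-small")
        >>> response = create_embeddings(request, provider="openai", client=client)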
22 | """ 23 | if provider == "openai" or provider == "azure_openai": 24 | return openai_create_embeddings(request, client) 25 | elif provider == "ollama": 26 | return ollama_create_embeddings(request, client) 27 | else: 28 | raise ValueError(f"Provider {provider} not supported") 29 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/llm/embedding/providers/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/providers/ollama_api.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import os 3 | import re 4 | import time 5 | from typing import Any 6 | 7 | from loguru import logger 8 | from ollama import Client, EmbedResponse, ResponseError 9 | 10 | from not_again_ai.llm.embedding.types import EmbeddingObject, EmbeddingRequest, EmbeddingResponse 11 | 12 | OLLAMA_PARAMETER_MAP = { 13 | "dimensions": None, 14 | } 15 | 16 | 17 | def validate(request: EmbeddingRequest) -> None: 18 | # Check if any of the parameters set to OLLAMA_PARAMETER_MAP are not None 19 | for key, value in OLLAMA_PARAMETER_MAP.items(): 20 | if value is None and getattr(request, key) is not None: 21 | logger.warning(f"Parameter {key} is not supported by Ollama and will be ignored.") 22 | 23 | 24 | def ollama_create_embeddings(request: EmbeddingRequest, client: Callable[..., Any]) -> EmbeddingResponse: 25 | validate(request) 26 | kwargs = request.model_dump(mode="json", exclude_none=True) 27 | 28 | # For each key in OLLAMA_PARAMETER_MAP 29 | # If it is not None, set the key in kwargs to the value of the corresponding value in OLLAMA_PARAMETER_MAP 30 | # If it is None, remove that key from kwargs 31 | for key, value in OLLAMA_PARAMETER_MAP.items(): 32 | if value is not None and key in kwargs: 33 | kwargs[value] = kwargs.pop(key) 34 | elif value is None and key in kwargs: 35 | del kwargs[key] 36 | 37 | # Explicitly set truncate to True (it is the default) 38 | kwargs["truncate"] = True 39 | 40 | try: 41 | start_time = time.time() 42 | response: EmbedResponse = client(**kwargs) 43 | end_time = time.time() 44 | response_duration = round(end_time - start_time, 4) 45 | except ResponseError as e: 46 | # If the error says "model 'model' not found" use regex then raise a more specific error 47 | expected_pattern = f"model '{request.model}' not found" 48 | if re.search(expected_pattern, e.error): 49 | raise ResponseError(f"Model '{request.model}' not found.") from e 50 | else: 51 | raise ResponseError(e.error) from e 52 | 53 | embeddings: list[EmbeddingObject] = [] 54 | for index, embedding in enumerate(response.embeddings): 55 | embeddings.append(EmbeddingObject(embedding=list(embedding), index=index)) 56 | 57 | return EmbeddingResponse( 58 | embeddings=embeddings, 59 | response_duration=response_duration, 60 | total_tokens=response.prompt_eval_count, 61 | ) 62 | 63 | 64 | def ollama_client(host: str | None = None, timeout: float | None = None) -> Callable[..., Any]: 65 | """Create an Ollama client instance based on the specified host or will read from the OLLAMA_HOST environment variable. 66 | 67 | Args: 68 | host (str, optional): The host URL of the Ollama server. 
69 | timeout (float, optional): The timeout for requests. 70 | 71 | Returns: 72 | Callable[..., Any]: A callable that creates an Ollama client and sends an embedding request. 73 | 74 | Examples: 75 | >>> client = ollama_client(host="http://localhost:11434") 76 | """ 77 | if host is None: 78 | host = os.getenv("OLLAMA_HOST") 79 | if host is None: 80 | logger.warning("OLLAMA_HOST environment variable not set, using default host: http://localhost:11434") 81 | host = "http://localhost:11434" 82 | 83 | def client_callable(**kwargs: Any) -> Any: 84 | client = Client(host=host, timeout=timeout) 85 | return client.embed(**kwargs) 86 | 87 | return client_callable 88 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/providers/openai_api.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | import time 3 | from typing import Any, Literal 4 | 5 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 6 | from openai import AzureOpenAI, OpenAI 7 | 8 | from not_again_ai.llm.embedding.types import EmbeddingObject, EmbeddingRequest, EmbeddingResponse 9 | 10 | 11 | def openai_create_embeddings(request: EmbeddingRequest, client: Callable[..., Any]) -> EmbeddingResponse: 12 | kwargs = request.model_dump(mode="json", exclude_none=True) 13 | 14 | start_time = time.time() 15 | response = client(**kwargs) 16 | end_time = time.time() 17 | response_duration = round(end_time - start_time, 4) 18 | 19 | embeddings: list[EmbeddingObject] = [] 20 | for data in response["data"]: 21 | embeddings.append(EmbeddingObject(embedding=data["embedding"], index=data["index"])) 22 | 23 | return EmbeddingResponse( 24 | embeddings=embeddings, 25 | response_duration=response_duration, 26 | total_tokens=response["usage"]["total_tokens"], 27 | ) 28 | 29 | 30 | def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]: 31 | """Creates a callable that instantiates and uses an OpenAI client. 32 | 33 | Args: 34 | client_class: The OpenAI client class to instantiate (OpenAI or AzureOpenAI) 35 | **client_args: Arguments to pass to the client constructor 36 | 37 | Returns: 38 | A callable that creates a client and returns embedding results 39 | """ 40 | filtered_args = {k: v for k, v in client_args.items() if v is not None} 41 | 42 | def client_callable(**kwargs: Any) -> Any: 43 | client = client_class(**filtered_args) 44 | completion = client.embeddings.create(**kwargs) 45 | return completion.to_dict() 46 | 47 | return client_callable 48 | 49 | 50 | class InvalidOAIAPITypeError(Exception): 51 | """Raised when an invalid OAIAPIType string is provided.""" 52 | 53 | 54 | def openai_client( 55 | api_type: Literal["openai", "azure_openai"] = "openai", 56 | api_key: str | None = None, 57 | organization: str | None = None, 58 | aoai_api_version: str = "2024-06-01", 59 | azure_endpoint: str | None = None, 60 | timeout: float | None = None, 61 | max_retries: int | None = None, 62 | ) -> Callable[..., Any]: 63 | """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters. 64 | 65 | It is preferred to use RBAC authentication for Azure OpenAI. You must be signed in with the Azure CLI and have the correct role assigned. 66 | See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521 67 | 68 | Args: 69 | api_type (str, optional): Type of the API to be used.
Accepted values are 'openai' or 'azure_openai'. 70 | Defaults to 'openai'. 71 | api_key (str, optional): The API key to authenticate the client. If not provided, 72 | OpenAI automatically uses `OPENAI_API_KEY` from the environment. 73 | If provided for Azure OpenAI, it will be used for authentication instead of the Azure AD token provider. 74 | organization (str, optional): The ID of the organization. If not provided, 75 | OpenAI automatically uses `OPENAI_ORG_ID` from the environment. 76 | aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning 77 | azure_endpoint (str, optional): The endpoint to use for Azure OpenAI. 78 | timeout (float, optional): By default requests time out after 10 minutes. 79 | max_retries (int, optional): Certain errors are automatically retried 2 times by default, 80 | with a short exponential backoff. Connection errors (for example, due to a network connectivity problem), 81 | 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default. 82 | 83 | Returns: 84 | Callable[..., Any]: A callable that creates a client and returns embedding results 85 | 86 | 87 | Raises: 88 | InvalidOAIAPITypeError: If an invalid API type string is provided. 89 | NotImplementedError: If the specified API type is recognized but support for it has not been implemented. 90 | """ 91 | if api_type not in ["openai", "azure_openai"]: 92 | raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.") 93 | 94 | if api_type == "openai": 95 | return create_client_callable( 96 | OpenAI, 97 | api_key=api_key, 98 | organization=organization, 99 | timeout=timeout, 100 | max_retries=max_retries, 101 | ) 102 | elif api_type == "azure_openai": 103 | if api_key: 104 | return create_client_callable( 105 | AzureOpenAI, 106 | api_version=aoai_api_version, 107 | azure_endpoint=azure_endpoint, 108 | api_key=api_key, 109 | timeout=timeout, 110 | max_retries=max_retries, 111 | ) 112 | else: 113 | azure_credential = DefaultAzureCredential() 114 | ad_token_provider = get_bearer_token_provider( 115 | azure_credential, "https://cognitiveservices.azure.com/.default" 116 | ) 117 | return create_client_callable( 118 | AzureOpenAI, 119 | api_version=aoai_api_version, 120 | azure_endpoint=azure_endpoint, 121 | azure_ad_token_provider=ad_token_provider, 122 | timeout=timeout, 123 | max_retries=max_retries, 124 | ) 125 | else: 126 | raise NotImplementedError(f"API type '{api_type}' is invalid.") 127 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/embedding/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import BaseModel, Field 4 | 5 | 6 | class EmbeddingRequest(BaseModel): 7 | input: str | list[str] 8 | model: str 9 | dimensions: int | None = Field(default=None) 10 | 11 | 12 | class EmbeddingObject(BaseModel): 13 | embedding: list[float] 14 | index: int 15 | 16 | 17 | class EmbeddingResponse(BaseModel): 18 | embeddings: list[EmbeddingObject] 19 | total_tokens: int | None = Field(default=None) 20 | response_duration: float 21 | 22 | errors: str = Field(default="") 23 | extras: Any | None = Field(default=None) 24 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/image_gen/__init__.py:
-------------------------------------------------------------------------------- 1 | from not_again_ai.llm.image_gen.interface import create_image 2 | from not_again_ai.llm.image_gen.types import ImageGenRequest 3 | 4 | __all__ = ["ImageGenRequest", "create_image"] 5 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/image_gen/interface.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | from not_again_ai.llm.image_gen.providers.openai_api import openai_create_image 5 | from not_again_ai.llm.image_gen.types import ImageGenRequest, ImageGenResponse 6 | 7 | 8 | def create_image(request: ImageGenRequest, provider: str, client: Callable[..., Any]) -> ImageGenResponse: 9 | """Get an image response from the given provider. Currently supported providers: 10 | - `openai` - OpenAI 11 | - `azure_openai` - Azure OpenAI 12 | 13 | Args: 14 | request: Request parameter object 15 | provider: The supported provider name 16 | client: The provider client callable; see the provider's implementation for what can be provided 17 | 18 | Returns: 19 | ImageGenResponse: The image generation response. 20 | """ 21 | if provider == "openai" or provider == "azure_openai": 22 | return openai_create_image(request, client) 23 | else: 24 | raise ValueError(f"Provider {provider} not supported") 25 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/image_gen/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/llm/image_gen/providers/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/llm/image_gen/providers/openai_api.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from collections.abc import Callable 3 | import time 4 | from typing import Any, Literal 5 | 6 | from azure.identity import DefaultAzureCredential, get_bearer_token_provider 7 | from openai import AzureOpenAI, OpenAI 8 | from openai.types.images_response import ImagesResponse 9 | 10 | from not_again_ai.llm.image_gen.types import ImageGenRequest, ImageGenResponse 11 | 12 | 13 | def openai_create_image(request: ImageGenRequest, client: Callable[..., Any]) -> ImageGenResponse: 14 | """Create an image using the OpenAI API. 15 | 16 | Args: 17 | request (ImageGenRequest): The request object containing parameters for image generation. 18 | client (Callable[..., Any]): The OpenAI client callable. 19 | 20 | Returns: 21 | ImageGenResponse: The response object containing the generated image and metadata.
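    Example:
        A minimal usage sketch; the model name below is illustrative:

        >>> client = openai_client(api_type="openai")
        >>> request = ImageGenRequest(prompt="A watercolor fox", model="gpt-image-1")
        >>> response = openai_create_image(request, client)
        >>> image_bytes = response.images[0]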
22 | """ 23 | kwargs = request.model_dump(exclude_none=True) 24 | if kwargs.get("images"): 25 | kwargs["image"] = kwargs.pop("images", None) 26 | 27 | start_time = time.time() 28 | response: ImagesResponse = client(**kwargs) 29 | end_time = time.time() 30 | response_duration = round(end_time - start_time, 4) 31 | 32 | images: list[bytes] = [] 33 | if response.data: 34 | for data in response.data: 35 | images.append(base64.b64decode(data.b64_json or "")) 36 | 37 | input_tokens = response.usage.input_tokens if response.usage else -1 38 | output_tokens = response.usage.output_tokens if response.usage else -1 39 | input_tokens_details = response.usage.input_tokens_details.to_dict() if response.usage else {} 40 | image_gen_response = ImageGenResponse( 41 | images=images, 42 | input_tokens=input_tokens, 43 | output_tokens=output_tokens, 44 | input_tokens_details=input_tokens_details, 45 | response_duration=response_duration, 46 | ) 47 | return image_gen_response 48 | 49 | 50 | def create_client_callable(client_class: type[OpenAI | AzureOpenAI], **client_args: Any) -> Callable[..., Any]: 51 | """ 52 | Creates the correct callable depending on the parameters provided. 53 | """ 54 | filtered_args = {k: v for k, v in client_args.items() if v is not None} 55 | 56 | def client_callable(**kwargs: Any) -> Any: 57 | client = client_class(**filtered_args) 58 | # If mask or image is not none, use client.images.edit instead of client.images.generate 59 | if kwargs.get("mask") or kwargs.get("image"): 60 | completion = client.images.edit(**kwargs) 61 | else: 62 | completion = client.images.generate(**kwargs) 63 | return completion 64 | 65 | return client_callable 66 | 67 | 68 | class InvalidOAIAPITypeError(Exception): 69 | """Raised when an invalid OAIAPIType string is provided.""" 70 | 71 | 72 | def openai_client( 73 | api_type: Literal["openai", "azure_openai"] = "openai", 74 | api_key: str | None = None, 75 | organization: str | None = None, 76 | aoai_api_version: str = "2024-06-01", 77 | azure_endpoint: str | None = None, 78 | timeout: float | None = None, 79 | max_retries: int | None = None, 80 | ) -> Callable[..., Any]: 81 | """Create an OpenAI or Azure OpenAI client instance based on the specified API type and other provided parameters. 82 | 83 | It is preferred to use RBAC authentication for Azure OpenAI. You must be signed in with the Azure CLI and have correct role assigned. 84 | See https://techcommunity.microsoft.com/t5/microsoft-developer-community/using-keyless-authentication-with-azure-openai/ba-p/4111521 85 | 86 | Args: 87 | api_type (str, optional): Type of the API to be used. Accepted values are 'openai' or 'azure_openai'. 88 | Defaults to 'openai'. 89 | api_key (str, optional): The API key to authenticate the client. If not provided, 90 | OpenAI automatically uses `OPENAI_API_KEY` from the environment. 91 | If provided for Azure OpenAI, it will be used for authentication instead of the Azure AD token provider. 92 | organization (str, optional): The ID of the organization. If not provided, 93 | OpenAI automotically uses `OPENAI_ORG_ID` from the environment. 94 | aoai_api_version (str, optional): Only applicable if using Azure OpenAI https://learn.microsoft.com/azure/ai-services/openai/reference#rest-api-versioning 95 | azure_endpoint (str, optional): The endpoint to use for Azure OpenAI. 96 | timeout (float, optional): By default requests time out after 10 minutes. 
97 | max_retries (int, optional): Certain errors are automatically retried 2 times by default, 98 | with a short exponential backoff. Connection errors (for example, due to a network connectivity problem), 99 | 408 Request Timeout, 409 Conflict, 429 Rate Limit, and >=500 Internal errors are all retried by default. 100 | 101 | Returns: 102 | Callable[..., Any]: A callable that creates a client and returns image generation results 103 | 104 | 105 | Raises: 106 | InvalidOAIAPITypeError: If an invalid API type string is provided. 107 | NotImplementedError: If the specified API type is recognized but support for it has not been implemented. 108 | """ 109 | if api_type not in ["openai", "azure_openai"]: 110 | raise InvalidOAIAPITypeError(f"Invalid OAIAPIType: {api_type}. Must be 'openai' or 'azure_openai'.") 111 | 112 | if api_type == "openai": 113 | return create_client_callable( 114 | OpenAI, 115 | api_key=api_key, 116 | organization=organization, 117 | timeout=timeout, 118 | max_retries=max_retries, 119 | ) 120 | elif api_type == "azure_openai": 121 | if api_key: 122 | return create_client_callable( 123 | AzureOpenAI, 124 | api_version=aoai_api_version, 125 | azure_endpoint=azure_endpoint, 126 | api_key=api_key, 127 | timeout=timeout, 128 | max_retries=max_retries, 129 | ) 130 | else: 131 | azure_credential = DefaultAzureCredential() 132 | ad_token_provider = get_bearer_token_provider( 133 | azure_credential, "https://cognitiveservices.azure.com/.default" 134 | ) 135 | return create_client_callable( 136 | AzureOpenAI, 137 | api_version=aoai_api_version, 138 | azure_endpoint=azure_endpoint, 139 | azure_ad_token_provider=ad_token_provider, 140 | timeout=timeout, 141 | max_retries=max_retries, 142 | ) 143 | else: 144 | raise NotImplementedError(f"API type '{api_type}' is invalid.") 145 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/image_gen/types.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | 7 | class ImageGenRequest(BaseModel): 8 | prompt: str 9 | model: str 10 | images: list[Path] | None = Field(default=None) 11 | mask: Path | None = Field(default=None) 12 | n: int = Field(default=1) 13 | quality: str | None = Field(default=None) 14 | size: str | None = Field(default=None) 15 | background: str | None = Field(default=None) 16 | moderation: str | None = Field(default=None) 17 | 18 | 19 | class ImageGenResponse(BaseModel): 20 | images: list[bytes] 21 | input_tokens: int 22 | output_tokens: int 23 | response_duration: float 24 | input_tokens_details: dict[str, Any] | None = Field(default=None) 25 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/__init__.py: -------------------------------------------------------------------------------- 1 | from not_again_ai.llm.prompting.interface import Tokenizer 2 | 3 | __all__ = ["Tokenizer"] 4 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/compile_prompt.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from collections.abc import Sequence 3 | from copy import deepcopy 4 | import mimetypes 5 | from pathlib import Path 6 | from typing import Any 7 | 8 | from liquid import render 9 | from openai.lib._pydantic import to_strict_json_schema 10 | from pydantic import BaseModel
11 | 12 | from not_again_ai.llm.chat_completion.types import MessageT 13 | 14 | 15 | def _apply_templates(value: Any, variables: dict[str, str]) -> Any: 16 | """Recursively applies Liquid templating to all string fields within the given value.""" 17 | if isinstance(value, str): 18 | return render(value, **variables) 19 | elif isinstance(value, list): 20 | return [_apply_templates(item, variables) for item in value] 21 | elif isinstance(value, dict): 22 | return {key: _apply_templates(val, variables) for key, val in value.items()} 23 | elif isinstance(value, BaseModel): 24 | # Process each field in the BaseModel by converting it to a dict, 25 | # applying templating to its values, and then re-instantiating the model. 26 | processed_data = {key: _apply_templates(val, variables) for key, val in value.model_dump().items()} 27 | return value.__class__(**processed_data) 28 | else: 29 | return value 30 | 31 | 32 | def compile_messages(messages: Sequence[MessageT], variables: dict[str, str]) -> Sequence[MessageT]: 33 | """Compiles messages using Liquid templating and the provided variables. 34 | Calls render(content_part, **variables) on each text content part. 35 | 36 | Args: 37 | messages: List of MessageT where content can contain Liquid templates. 38 | variables: The variables to inject into the templates. 39 | 40 | Returns: 41 | The same list of messages with the content parts injected with the variables. 42 | """ 43 | messages_formatted = deepcopy(messages) 44 | messages_formatted = [_apply_templates(message, variables) for message in messages_formatted] 45 | return messages_formatted 46 | 47 | 48 | def compile_tools(tools: Sequence[dict[str, Any]], variables: dict[str, str]) -> Sequence[dict[str, Any]]: 49 | """Compiles a list of tool argument dictionaries using Liquid templating and provided variables. 50 | 51 | Each dictionary in the list is deep copied and processed recursively to substitute any Liquid 52 | templates present in its data structure. 53 | 54 | Args: 55 | tools: A list of dictionaries representing tool arguments, where values can include Liquid templates. 56 | variables: A dictionary of variables to substitute into the Liquid templates. 57 | 58 | Returns: 59 | A new list of dictionaries with the Liquid templates replaced by their corresponding variable values. 60 | """ 61 | tools_formatted = deepcopy(tools) 62 | tools_formatted = [_apply_templates(tool, variables) for tool in tools_formatted] 63 | return tools_formatted 64 | 65 | 66 | def encode_image(image_path: Path) -> str: 67 | """Encodes an image file at the given Path to base64. 68 | 69 | Args: 70 | image_path: The path to the image file to encode. 71 | 72 | Returns: 73 | The base64 encoded image as a string. 74 | """ 75 | with Path.open(image_path, "rb") as image_file: 76 | return base64.b64encode(image_file.read()).decode("utf-8") 77 | 78 | 79 | def create_image_url(image_path: Path) -> str: 80 | """Creates a data URL for an image file at the given Path. 81 | 82 | Args: 83 | image_path: The path to the image file to encode. 84 | 85 | Returns: 86 | The data URL for the image. 
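    Example:
        A minimal sketch, using one of the repository's sample images:

        >>> url = create_image_url(Path("tests/llm/sample_images/cat.jpg"))
        >>> url.startswith("data:image/jpeg;base64,")
        True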
87 | """ 88 | image_data = encode_image(image_path) 89 | 90 | valid_mime_types = ["image/jpeg", "image/png", "image/webp", "image/gif"] 91 | 92 | # Get the MIME type from the image file extension 93 | mime_type = mimetypes.guess_type(image_path)[0] 94 | 95 | # Check if the MIME type is valid 96 | # List of valid types is here: https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload 97 | if mime_type not in valid_mime_types: 98 | raise ValueError(f"Invalid MIME type for image: {mime_type}") 99 | 100 | return f"data:{mime_type};base64,{image_data}" 101 | 102 | 103 | def pydantic_to_json_schema( 104 | pydantic_model: type[BaseModel], schema_name: str, description: str | None = None 105 | ) -> dict[str, Any]: 106 | """Converts a Pydantic model to a JSON schema expected by Structured Outputs. 107 | Must adhere to the supported schemas: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas 108 | 109 | Args: 110 | pydantic_model: The Pydantic model to convert. 111 | schema_name: The name of the schema. 112 | description: An optional description of the schema. 113 | 114 | Returns: 115 | A JSON schema dictionary representing the Pydantic model. 116 | """ 117 | converted_pydantic = to_strict_json_schema(pydantic_model) 118 | schema = { 119 | "name": schema_name, 120 | "strict": True, 121 | "schema": converted_pydantic, 122 | } 123 | if description: 124 | schema["description"] = description 125 | return schema 126 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/interface.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Collection, Set 2 | from typing import Literal 3 | 4 | from loguru import logger 5 | 6 | from not_again_ai.llm.chat_completion.types import MessageT 7 | from not_again_ai.llm.prompting.providers.openai_tiktoken import TokenizerOpenAI 8 | from not_again_ai.llm.prompting.types import BaseTokenizer 9 | 10 | 11 | class Tokenizer(BaseTokenizer): 12 | def __init__( 13 | self, 14 | model: str, 15 | provider: str, 16 | allowed_special: Literal["all"] | Set[str] | None = None, 17 | disallowed_special: Literal["all"] | Collection[str] | None = None, 18 | ): 19 | self.model = model 20 | self.provider = provider 21 | self.allowed_special = allowed_special 22 | self.disallowed_special = disallowed_special 23 | 24 | self.init_tokenizer(model, provider, allowed_special, disallowed_special) 25 | 26 | def init_tokenizer( 27 | self, 28 | model: str, 29 | provider: str, 30 | allowed_special: Literal["all"] | Set[str] | None = None, 31 | disallowed_special: Literal["all"] | Collection[str] | None = None, 32 | ) -> None: 33 | if provider == "openai" or provider == "azure_openai": 34 | self.tokenizer = TokenizerOpenAI(model, provider, allowed_special, disallowed_special) 35 | else: 36 | logger.warning(f"Provider {provider} not supported. 
Initializing using tiktoken and gpt-4o.") 37 | self.tokenizer = TokenizerOpenAI("gpt-4o", "openai", allowed_special, disallowed_special) 38 | 39 | def truncate_str(self, text: str, max_len: int) -> str: 40 | return self.tokenizer.truncate_str(text, max_len) 41 | 42 | def num_tokens_in_str(self, text: str) -> int: 43 | return self.tokenizer.num_tokens_in_str(text) 44 | 45 | def num_tokens_in_messages(self, messages: list[MessageT]) -> int: 46 | return self.tokenizer.num_tokens_in_messages(messages) 47 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/providers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/src/not_again_ai/llm/prompting/providers/__init__.py -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/providers/openai_tiktoken.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Collection, Set 2 | from typing import Literal 3 | 4 | from loguru import logger 5 | import tiktoken 6 | 7 | from not_again_ai.llm.chat_completion.types import MessageT 8 | from not_again_ai.llm.prompting.types import BaseTokenizer 9 | 10 | 11 | class TokenizerOpenAI(BaseTokenizer): 12 | def __init__( 13 | self, 14 | model: str, 15 | provider: str = "openai", 16 | allowed_special: Literal["all"] | Set[str] | None = None, 17 | disallowed_special: Literal["all"] | Collection[str] | None = None, 18 | ): 19 | self.model = model 20 | self.provider = provider 21 | self.allowed_special = allowed_special 22 | self.disallowed_special = disallowed_special 23 | 24 | self.init_tokenizer(model, provider, allowed_special, disallowed_special) 25 | 26 | def init_tokenizer( 27 | self, 28 | model: str, 29 | provider: str = "openai", 30 | allowed_special: Literal["all"] | Set[str] | None = None, 31 | disallowed_special: Literal["all"] | Collection[str] | None = None, 32 | ) -> None: 33 | try: 34 | self.encoding = tiktoken.encoding_for_model(model) 35 | except KeyError: 36 | default_encoding = "o200k_base" 37 | logger.warning(f"Model {model} not found. 
Using {default_encoding} encoding.") 38 | self.encoding = tiktoken.get_encoding(default_encoding) 39 | 40 | # Set defaults if not provided 41 | if not allowed_special: 42 | self.allowed_special = set() 43 | if not disallowed_special: 44 | self.disallowed_special = () 45 | 46 | def truncate_str(self, text: str, max_len: int) -> str: 47 | tokens = self.encoding.encode( 48 | text, 49 | allowed_special=self.allowed_special if self.allowed_special is not None else set(), 50 | disallowed_special=self.disallowed_special if self.disallowed_special is not None else (), 51 | ) 52 | if len(tokens) > max_len: 53 | tokens = tokens[:max_len] 54 | truncated_text = self.encoding.decode(tokens) 55 | return truncated_text 56 | else: 57 | return text 58 | 59 | def num_tokens_in_str(self, text: str) -> int: 60 | return len( 61 | self.encoding.encode( 62 | text, 63 | allowed_special=self.allowed_special if self.allowed_special is not None else set(), 64 | disallowed_special=self.disallowed_special if self.disallowed_special is not None else (), 65 | ) 66 | ) 67 | 68 | def num_tokens_in_messages(self, messages: list[MessageT]) -> int: 69 | if self.model in { 70 | "gpt-3.5-turbo-0613", 71 | "gpt-3.5-turbo-16k-0613", 72 | "gpt-3.5-turbo-1106", 73 | "gpt-3.5-turbo-0125", 74 | "gpt-4-0314", 75 | "gpt-4-32k-0314", 76 | "gpt-4-0613", 77 | "gpt-4-32k-0613", 78 | "gpt-4-1106-preview", 79 | "gpt-4-turbo-preview", 80 | "gpt-4-0125-preview", 81 | "gpt-4-turbo", 82 | "gpt-4-turbo-2024-04-09", 83 | "gpt-4o", 84 | "gpt-4o-2024-05-13", 85 | "gpt-4o-2024-08-06", 86 | "gpt-4o-2024-11-20", 87 | "gpt-4o-mini", 88 | "gpt-4o-mini-2024-07-18", 89 | "o1", 90 | "o1-2024-12-17", 91 | "o1-mini", 92 | "o1-mini-2024-09-12", 93 | "o1-preview", 94 | "o1-preview-2024-09-12", 95 | }: 96 | tokens_per_message = 3 # every message follows <|start|>{role/name}\n{content}<|end|>\n 97 | tokens_per_name = 1 # if there's a name, the role is omitted 98 | elif self.model == "gpt-3.5-turbo-0301": 99 | tokens_per_message = 4 100 | tokens_per_name = -1 101 | else: 102 | logger.warning(f"Model {self.model} not supported. 
Assuming gpt-4o encoding.") 103 | tokens_per_message = 3 104 | tokens_per_name = 1 105 | 106 | num_tokens = 0 107 | for message in messages: 108 | num_tokens += tokens_per_message 109 | message_dict = message.model_dump(exclude_none=True) 110 | for key, value in message_dict.items(): 111 | if isinstance(value, str): 112 | num_tokens += len( 113 | self.encoding.encode( 114 | value, 115 | allowed_special=self.allowed_special if self.allowed_special is not None else set(), 116 | disallowed_special=self.disallowed_special if self.disallowed_special is not None else (), 117 | ) 118 | ) 119 | if key == "name": 120 | num_tokens += tokens_per_name 121 | num_tokens += 3 122 | return num_tokens 123 | -------------------------------------------------------------------------------- /src/not_again_ai/llm/prompting/types.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from collections.abc import Collection, Set 3 | from typing import Literal 4 | 5 | from not_again_ai.llm.chat_completion.types import MessageT 6 | 7 | 8 | class BaseTokenizer(ABC): 9 | def __init__( 10 | self, 11 | model: str, 12 | provider: str, 13 | allowed_special: Literal["all"] | Set[str] | None = None, 14 | disallowed_special: Literal["all"] | Collection[str] | None = None, 15 | ): 16 | self.model = model 17 | self.provider = provider 18 | self.allowed_special = allowed_special 19 | self.disallowed_special = disallowed_special 20 | 21 | self.init_tokenizer(model, provider, allowed_special, disallowed_special) 22 | 23 | @abstractmethod 24 | def init_tokenizer( 25 | self, 26 | model: str, 27 | provider: str, 28 | allowed_special: Literal["all"] | Set[str] | None = None, 29 | disallowed_special: Literal["all"] | Collection[str] | None = None, 30 | ) -> None: 31 | pass 32 | 33 | @abstractmethod 34 | def truncate_str(self, text: str, max_len: int) -> str: 35 | pass 36 | 37 | @abstractmethod 38 | def num_tokens_in_str(self, text: str) -> int: 39 | pass 40 | 41 | @abstractmethod 42 | def num_tokens_in_messages(self, messages: list[MessageT]) -> int: 43 | pass 44 | -------------------------------------------------------------------------------- /src/not_again_ai/py.typed: -------------------------------------------------------------------------------- 1 | # Instruct type checkers to look for inline type annotations in this package. 2 | # See PEP 561. 3 | -------------------------------------------------------------------------------- /src/not_again_ai/statistics/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | if ( 4 | importlib.util.find_spec("numpy") is None 5 | or importlib.util.find_spec("scipy") is None 6 | or importlib.util.find_spec("sklearn") is None 7 | ): 8 | raise ImportError( 9 | "not_again_ai.statistics requires the 'statistics' extra to be installed. " 10 | "You can install it using 'pip install not_again_ai[statistics]'." 
11 | ) 12 | else: 13 | import numpy  # noqa: F401 14 | import scipy  # noqa: F401 15 | import sklearn  # noqa: F401 16 | -------------------------------------------------------------------------------- /src/not_again_ai/statistics/dependence.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import numpy.typing as npt 3 | import scipy 4 | import sklearn.metrics as skmetrics 5 | import sklearn.model_selection as skmodel_selection 6 | import sklearn.preprocessing as skpreprocessing 7 | import sklearn.tree as sktree 8 | 9 | 10 | def _process_variable( 11 | x: npt.NDArray[np.int_] | (npt.NDArray[np.float64] | npt.NDArray[np.str_]), 12 | ) -> npt.NDArray[np.int_] | (npt.NDArray[np.float64] | npt.NDArray[np.str_]): 13 | """Process a variable by encoding it as a numeric array.""" 14 | le = skpreprocessing.LabelEncoder() 15 | x = le.fit_transform(x) 16 | return x 17 | 18 | 19 | def pearson_correlation( 20 | x: list[int] 21 | | (list[float] | (list[str] | (npt.NDArray[np.int_] | (npt.NDArray[np.float64] | npt.NDArray[np.str_])))), 22 | y: list[int] 23 | | (list[float] | (list[str] | (npt.NDArray[np.int_] | (npt.NDArray[np.float64] | npt.NDArray[np.str_])))), 24 | is_x_categorical: bool = False, 25 | is_y_categorical: bool = False, 26 | print_diagnostics: bool = False, 27 | ) -> float: 28 | """Absolute value of the Pearson correlation coefficient. 29 | Returns 1 in the case that y contains all of the same values. 30 | 31 | Implemented using scipy https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html 32 | 33 | Args: 34 | x (listlike): first variable 35 | y (listlike): second variable 36 | is_x_categorical (bool): whether x is categorical 37 | is_y_categorical (bool): whether y is categorical 38 | print_diagnostics (bool): whether to print diagnostics to stdout 39 | """ 40 | x_array = np.array(x) 41 | y_array = np.array(y) 42 | 43 | if is_x_categorical: 44 | x_array = _process_variable(x_array) 45 | 46 | if is_y_categorical: 47 | y_array = _process_variable(y_array) 48 | 49 | # check if y contains all of the same values, and if so return 1 50 | if len(np.unique(y_array)) == 1: 51 | if print_diagnostics: 52 | print("y contains all of the same values, returning 1") 53 | return 1.0 54 | 55 | pearsonr = scipy.stats.pearsonr(x_array, y_array) 56 | metric: float = pearsonr.statistic 57 | metric = np.abs(metric) 58 | return metric 59 | 60 | 61 | def pred_power_score_classification( 62 | x: list[int] 63 | | (list[float] | (list[str] | (npt.NDArray[np.int_] | (npt.NDArray[np.float64] | npt.NDArray[np.str_])))), 64 | y: list[int] | (list[str] | npt.NDArray[np.int_]), 65 | cv_splits: int = 5, 66 | print_diagnostics: bool = False, 67 | ) -> float: 68 | """Compute the Predictive Power Score, an asymmetric score that can detect 69 | linear or non-linear relationships between two variables. 70 | For this implementation, the score is computed for a classification task and y must be categorical. 71 | 72 | Returns 1 in the case that y contains all of the same values.
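    Concretely, the implementation below computes the score as (F1_model - F1_naive) / (1 - F1_naive), clipped below at 0, where F1_naive is the better of a majority-class baseline and a uniformly random baseline.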
73 | 74 | Args: 75 | x (listlike of int, float, or string): first variable 76 | y (listlike of int or string): second variable 77 | cv_splits (int): number of cross-validation splits 78 | print_diagnostics (bool): whether to print diagnostics to stdout 79 | """ 80 | x_array = np.array(x) 81 | y_array = np.array(y) 82 | 83 | le = skpreprocessing.LabelEncoder() 84 | # check if x contains any strings 85 | if any(isinstance(elem, str) for elem in x_array): 86 | x_array = le.fit_transform(x_array) 87 | 88 | x_array = x_array.reshape(-1, 1) 89 | y_array = le.fit_transform(y_array) 90 | 91 | # check if y contains all of the same values, and if so return 1 92 | if len(np.unique(y_array)) == 1: 93 | if print_diagnostics: 94 | print("y contains all of the same values, returning 1") 95 | return 1.0 96 | 97 | # Use KFold cross-validation to compute the weighted F1 score 98 | model = sktree.DecisionTreeClassifier(criterion="gini", splitter="best", max_depth=None, random_state=0) 99 | cv_method = skmodel_selection.KFold(n_splits=cv_splits, shuffle=True, random_state=0) 100 | f1_scores = skmodel_selection.cross_val_score( 101 | model, x_array, y_array, cv=cv_method, scoring="f1_weighted", error_score="raise" 102 | ) 103 | f1 = np.mean(f1_scores) 104 | 105 | # find the majority class in y 106 | majority_class = np.argmax(np.bincount(y_array)) 107 | preds = np.ones_like(y_array) * majority_class 108 | f1_null: float = skmetrics.f1_score(y_array, preds, average="weighted") 109 | 110 | # generate random predictions 111 | preds = np.random.choice(np.unique(y_array), size=len(y_array)) 112 | f1_random: float = skmetrics.f1_score(y_array, preds, average="weighted") 113 | 114 | f1_naive = np.max([f1_null, f1_random]) 115 | pps: float = (f1 - f1_naive) / (1 - f1_naive) 116 | 117 | # ensure pps is not negative 118 | pps = np.max([0, pps]) 119 | return pps 120 | -------------------------------------------------------------------------------- /src/not_again_ai/viz/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.util 2 | 3 | if ( 4 | importlib.util.find_spec("numpy") is None 5 | or importlib.util.find_spec("pandas") is None 6 | or importlib.util.find_spec("seaborn") is None 7 | ): 8 | raise ImportError( 9 | "not_again_ai.viz requires the 'viz' extra to be installed. " 10 | "You can install it using 'pip install not_again_ai[viz]'." 11 | ) 12 | else: 13 | import numpy  # noqa: F401 14 | import pandas  # noqa: F401 15 | import seaborn  # noqa: F401 16 | -------------------------------------------------------------------------------- /src/not_again_ai/viz/barplots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import numpy.typing as npt 4 | import seaborn as sns 5 | 6 | from not_again_ai.base.file_system import create_file_dir 7 | from not_again_ai.viz.utils import reset_plot_libs 8 | 9 | 10 | def simple_barplot( 11 | x: list[str] | (list[float] | (npt.NDArray[np.int_] | npt.NDArray[np.float64])), 12 | y: list[str] | (list[float] | (npt.NDArray[np.int_] | npt.NDArray[np.float64])), 13 | save_pathname: str, 14 | order: str | None = None, 15 | orient_bars_vertically: bool = True, 16 | title: str | None = None, 17 | x_label: str | None = None, 18 | y_label: str | None = None, 19 | font_size: float = 48, 20 | height: float = 11, 21 | aspect: float = 2, 22 | ) -> None: 23 | """Saves a simple barplot to the specified pathname.
24 | 25 | Args: 26 | x (listlike): Input listlike data for x-axis. 27 | If orient_bars_vertically is True, this is the category names. 28 | If orient_bars_vertically is False, this is the bar heights (must be numeric). 29 | y (listlike): Input listlike data for y-axis. 30 | If orient_bars_vertically is True, this is the bar heights (must be numeric). 31 | If orient_bars_vertically is False, this is the category names. 32 | save_pathname (str): Filepath to save plot to. Parent directories will be automatically created. 33 | order (str, optional): Order of the bars, either "asc" or "desc". Defaults to None. 34 | orient_bars_vertically (bool, optional): Whether to orient the bars vertically. Defaults to True. 35 | title (str, optional): Title of the plot. Defaults to None. 36 | x_label (str, optional): Label for the x-axis. Defaults to None. 37 | y_label (str, optional): Label for the y-axis. Defaults to None. 38 | font_size (float, optional): Font size. Defaults to 48. 39 | height (float, optional): Height (in inches) of the plot. Defaults to 11. 40 | aspect (float, optional): Aspect ratio of the plot. Defaults to 2. 41 | """ 42 | 43 | sns.set_theme( 44 | style="white", 45 | rc={ 46 | "font.size": font_size, 47 | "axes.titlesize": font_size, 48 | "axes.labelsize": font_size * 0.8, 49 | "xtick.labelsize": font_size * 0.65, 50 | "ytick.labelsize": font_size * 0.65, 51 | "legend.fontsize": font_size * 0.5, 52 | "legend.title_fontsize": font_size * 0.55, 53 | }, 54 | ) 55 | sns.set_color_codes("muted") 56 | 57 | if order: 58 | if order == "asc": 59 | # sort x and y ascending by y 60 | if orient_bars_vertically: 61 | y, x = (list(t) for t in zip(*sorted(zip(y, x, strict=True)), strict=True)) 62 | else: 63 | x, y = (list(t) for t in zip(*sorted(zip(x, y, strict=True)), strict=True)) 64 | elif order == "desc": 65 | # sort x and y descending by y 66 | if orient_bars_vertically: 67 | y, x = (list(t) for t in zip(*sorted(zip(y, x, strict=True), reverse=True), strict=True)) 68 | else: 69 | x, y = (list(t) for t in zip(*sorted(zip(x, y, strict=True), reverse=True), strict=True)) 70 | 71 | ax = sns.barplot(x=x, y=y, color="b", orient="v" if orient_bars_vertically else "h") 72 | ax.figure.set_size_inches(height * aspect, height) 73 | 74 | ax.set_title(title) 75 | ax.set_xlabel(x_label) 76 | ax.set_ylabel(y_label) 77 | 78 | sns.despine() 79 | 80 | create_file_dir(save_pathname) 81 | plt.savefig(save_pathname, bbox_inches="tight") 82 | reset_plot_libs() 83 | -------------------------------------------------------------------------------- /src/not_again_ai/viz/distributions.py: -------------------------------------------------------------------------------- 1 | import matplotlib.patches as mpatches 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import numpy.typing as npt 5 | import seaborn as sns 6 | 7 | from not_again_ai.base.file_system import create_file_dir 8 | from not_again_ai.viz.utils import reset_plot_libs 9 | 10 | 11 | def univariate_distplot( 12 | data: list[float] | npt.NDArray[np.float64], 13 | save_pathname: str, 14 | print_summary: bool = True, 15 | title: str | None = None, 16 | xlabel: str | None = "Value", 17 | ylabel: str | None = "Count", 18 | xlim: tuple[float, float] | None = None, 19 | ylim: tuple[float, float] | None = None, 20 | xticks: npt.ArrayLike | None = None, 21 | yticks: npt.ArrayLike | None = None, 22 | bins: int = 50, 23 | font_size: float = 48, 24 | height: float = 13, 25 | aspect: float = 2.2, 26 | ) -> None: 27 | """Saves a univariate distribution plot
to the specified pathname. 28 | 29 | Args: 30 | data (listlike): Input listlike data to plot distribution of 31 | save_pathname (str): Filepath to save plot to. Parent directories will be automatically created. 32 | print_summary (bool, optional): If True, will print summary statistics. Defaults to True. 33 | title (str, optional): Title of the plot. Defaults to None. 34 | xlabel (str, optional): Set the label for the x-axis. Defaults to 'Value'. 35 | ylabel (str, optional): Set the label for the y-axis. Defaults to 'Count'. 36 | xlim (tuple[float, float], optional): Set the x-axis limits (lower, upper). Defaults to None. 37 | ylim (tuple[float, float], optional): Set the y-axis limits (lower, upper). Defaults to None. 38 | xticks (npt.ArrayLike, optional): Set the x-axis tick locations. Defaults to None. 39 | yticks (npt.ArrayLike, optional): Set the y-axis tick locations. Defaults to None. 40 | bins (int, optional): See matplotlib [histplot documentation](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.hist.html#matplotlib-pyplot-hist) for all options. Defaults to 50. 41 | font_size (float, optional): Font size. Defaults to 48. 42 | height (float, optional): Height (in inches) of the plot. Defaults to 13. 43 | aspect (float, optional): Aspect ratio of the plot, so that `aspect` * `height` gives the width of each facet in inches. Defaults to 2.2. 44 | """ 45 | 46 | sns.set_theme( 47 | style="white", 48 | rc={ 49 | "font.size": font_size, 50 | "axes.titlesize": font_size, 51 | "axes.labelsize": font_size * 0.8, 52 | "xtick.labelsize": font_size * 0.7, 53 | "ytick.labelsize": font_size * 0.7, 54 | "legend.fontsize": font_size * 0.55, 55 | }, 56 | ) 57 | 58 | # precompute summary statistics 59 | mean = np.mean(data) 60 | median = np.median(data) 61 | stdev = np.std(data) 62 | percentile_5 = np.percentile(data, 5) 63 | percentile_95 = np.percentile(data, 95) 64 | 65 | facet_grid = sns.displot(data, bins=bins, height=height, aspect=aspect) 66 | 67 | facet_grid.set( 68 | xlim=xlim, 69 | ylim=ylim, 70 | title=title, 71 | xlabel=xlabel, 72 | ylabel=ylabel, 73 | ) 74 | 75 | if xticks is not None: 76 | facet_grid.set(xticks=xticks) 77 | if yticks is not None: 78 | facet_grid.set(yticks=yticks) 79 | 80 | pastel_colors = sns.color_palette("pastel") 81 | 82 | ax = facet_grid.axes.flatten()[0] 83 | 84 | # plot summary statistic lines 85 | ax.axvline(x=mean, color=pastel_colors[1], ls="--", lw=2.5, label=f"Mean: {mean:.3f}") 86 | ax.axvline(x=median, color=pastel_colors[2], ls="--", lw=2.5, label=f"Median: {median:.3f}") 87 | ax.axvline(x=percentile_5, color=pastel_colors[9], ls="--", lw=2.5, label=f"5 Percentile: {percentile_5:.3f}") 88 | ax.axvline(x=percentile_95, color=pastel_colors[4], ls="--", lw=2.5, label=f"95 Percentile: {percentile_95:.3f}") 89 | # add legend for these lines 90 | handles, _ = ax.get_legend_handles_labels() 91 | # and an empty patch for the stdev statistic 92 | handles.append(mpatches.Patch(color="none", label=f"St Dev: {stdev:.3f}")) 93 | 94 | plt.legend(handles=handles, loc=0) 95 | 96 | create_file_dir(save_pathname) 97 | plt.savefig(save_pathname, bbox_inches="tight") 98 | reset_plot_libs() 99 | 100 | if print_summary: 101 | to_print = ( 102 | "Summary Statistics:", 103 | f"Mean:\t\t{mean:.3f}", 104 | f"Median:\t\t{median:.3f}", 105 | f"5 Percentile:\t{percentile_5:.3f}", 106 | f"95 Percentile:\t{percentile_95:.3f}", 107 | f"St Dev:\t\t{stdev:.3f}", 108 | ) 109 | print("\n".join(to_print)) 110 |
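# Usage sketch (illustrative data and output path, not part of the library):
# import numpy as np
# from not_again_ai.viz.distributions import univariate_distplot
# univariate_distplot(
#     data=np.random.default_rng(0).normal(size=1_000),
#     save_pathname="assets/distributions_example.svg",
#     bins=40,
# )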
-------------------------------------------------------------------------------- /src/not_again_ai/viz/scatterplot.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import numpy.typing as npt 4 | import pandas as pd 5 | import seaborn as sns 6 | 7 | from not_again_ai.base.file_system import create_file_dir 8 | from not_again_ai.viz.utils import reset_plot_libs 9 | 10 | 11 | def scatterplot_basic( 12 | x: list[float] | (npt.NDArray[np.int_] | npt.NDArray[np.float64]), 13 | y: list[float] | (npt.NDArray[np.int_] | npt.NDArray[np.float64]), 14 | save_pathname: str, 15 | title: str | None = None, 16 | xlim: tuple[float, float] | None = None, 17 | ylim: tuple[float, float] | None = None, 18 | font_size: float = 48, 19 | height: float = 13, 20 | aspect: float = 1.2, 21 | ) -> None: 22 | """Saves a basic scatterplot to the specified pathname. 23 | 24 | Args: 25 | x (listlike): Input listlike data for x-axis 26 | y (listlike): Input listlike data for y-axis 27 | save_pathname (str): Filepath to save plot to. Parent directories will be automatically created. 28 | title (str, optional): Title of the plot. Defaults to None. 29 | xlim (tuple[float, float], optional): Set the x-axis limits (lower, upper). Defaults to None. 30 | ylim (tuple[float, float], optional): Set the y-axis limits (lower, upper). Defaults to None. 31 | font_size (float, optional): Font size. Defaults to 48. 32 | height (float, optional): Height (in inches) of the plot. Defaults to 13. 33 | aspect (float, optional): Aspect ratio of the plot, so that `aspect` * `height` gives the width of each facet in inches. Defaults to 1.2. 34 | """ 35 | 36 | sns.set_theme( 37 | style="white", 38 | rc={ 39 | "font.size": font_size, 40 | "axes.titlesize": font_size, 41 | "axes.labelsize": font_size * 0.8, 42 | "xtick.labelsize": font_size * 0.65, 43 | "ytick.labelsize": font_size * 0.65, 44 | "legend.fontsize": font_size * 0.5, 45 | "legend.title_fontsize": font_size * 0.55, 46 | }, 47 | ) 48 | data = pd.DataFrame({"x": x, "y": y}) 49 | ax = sns.scatterplot(data=data, x="x", y="y") 50 | 51 | ax.figure.set_size_inches(height * aspect, height) 52 | ax.set_title(title) 53 | ax.set_xlabel("x") 54 | ax.set_ylabel("y") 55 | 56 | ax.set_xlim(xlim) 57 | ax.set_ylim(ylim) 58 | 59 | sns.despine(top=True, right=True) 60 | 61 | create_file_dir(save_pathname) 62 | plt.savefig(save_pathname, bbox_inches="tight") 63 | reset_plot_libs() 64 | -------------------------------------------------------------------------------- /src/not_again_ai/viz/time_series.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.dates 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import numpy.typing as npt 6 | import pandas as pd 7 | import seaborn as sns 8 | 9 | from not_again_ai.base.file_system import create_file_dir 10 | from not_again_ai.viz.utils import reset_plot_libs 11 | 12 | 13 | def ts_lineplot( 14 | ts_data: list[float] | (npt.NDArray[np.float64] | npt.NDArray[np.int64]), 15 | save_pathname: str, 16 | ts_x: ( 17 | list[float] 18 | | (npt.NDArray[np.float64] | (npt.NDArray[np.datetime64] | (npt.NDArray[np.int64] | pd.Series))) 19 | | None 20 | ) = None, 21 | ts_names: list[str] | None = None, 22 | title: str | None = None, 23 | xlabel: str | None = "Time", 24 | ylabel: str | None = "Value", 25 | legend_title: str | None = None, 26 | xaxis_date_format: str | None = None, 27 | 
xaxis_major_locator: matplotlib.ticker.Locator | None = None, 28 | ylim: tuple[float, float] | None = None, 29 | yticks: npt.ArrayLike | None = None, 30 | font_size: float = 48, 31 | height: float = 13, 32 | aspect: float = 2.2, 33 | linewidth: float = 2, 34 | legend_loc: str | (tuple[float, float] | int) | None = None, 35 | palette: str | (list[str] | (list[float] | (dict[str, str] | matplotlib.colors.Colormap))) = "tab10", 36 | ) -> None: 37 | """Saves a time series plot where each row in `ts_data` is a time series. 38 | Optionally, a specific x axis (like dates) can be provided with `ts_x`. 39 | Names to appear in the legend for each time series can be provided with `ts_names`. 40 | 41 | Args: 42 | ts_data (list of lists or 2D numpy array): Each nested list or row is a time series to be plotted. 43 | save_pathname (str): Filepath to save plot to. Parent directories will be automatically created. 44 | ts_x (listlike, optional): The values that will be used for the x-axis. Defaults to None. 45 | ts_names (list[str], optional): The names of the time series shown on the legend. Defaults to None. 46 | title (str, optional): Title of the plot. Defaults to None. 47 | xlabel (str, optional): Set the label for the x-axis. Defaults to 'Time'. 48 | ylabel (str, optional): Set the label for the y-axis. Defaults to 'Value'. 49 | legend_title (str, optional): Sets the title of the legend. Defaults to None. 50 | xaxis_date_format (str, optional): A dateformat string. See [strftime cheatsheet](https://strftime.org/). Defaults to None. 51 | xaxis_major_locator (matplotlib.ticker.Locator, optional): Matplotlib tick locator. 52 | See [Tick locating](https://matplotlib.org/stable/api/ticker_api.html) or [Date tickers](https://matplotlib.org/stable/api/dates_api.html) for the available options. Defaults to None. 53 | ylim (tuple[float, float], optional): Set the y-axis limits (lower, upper). Defaults to None. 54 | yticks (npt.ArrayLike, optional): Set the y-axis tick locations. Defaults to None. 55 | font_size (float, optional): Font size. Defaults to 48. 56 | height (float, optional): Height (in inches) of the plot. Defaults to 13. 57 | aspect (float, optional): Aspect ratio of the plot, so that `aspect` * `height` gives the width of each facet in inches. Defaults to 2.2. 58 | linewidth (float, optional): Width of each time series line. Defaults to 2. 59 | legend_loc (Union[str, tuple[float, float], int], optional): Matplotlib legend location. 60 | See [matplotlib documentation](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html). Defaults to None. 61 | palette (str, list, dict, or matplotlib.colors.Colormap, optional): Takes the same arguments as [seaborn's lineplot](https://seaborn.pydata.org/generated/seaborn.lineplot.html#seaborn-lineplot) palette argument.
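    Example:
        A minimal usage sketch (data and output path are illustrative):

        >>> ts_lineplot(
        ...     ts_data=[[1, 2, 3], [2, 4, 6]],
        ...     save_pathname="assets/ts_example.svg",
        ...     ts_names=["series_a", "series_b"],
        ... )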
62 | """ 63 | 64 | sns.set_theme( 65 | style="white", 66 | rc={ 67 | "font.size": font_size, 68 | "axes.titlesize": font_size, 69 | "axes.labelsize": font_size * 0.8, 70 | "xtick.labelsize": font_size * 0.65, 71 | "ytick.labelsize": font_size * 0.65, 72 | "legend.fontsize": font_size * 0.5, 73 | "legend.title_fontsize": font_size * 0.55, 74 | }, 75 | ) 76 | # Transpose the list of lists or numpy array 77 | ts_data = np.array(ts_data).T 78 | sns_data = pd.DataFrame(ts_data, columns=ts_names) 79 | if ts_x is None: 80 | ts_x = np.arange(len(ts_data)) 81 | 82 | sns_data["Time"] = ts_x 83 | sns_data = sns_data.melt(id_vars="Time", var_name="Time Series", value_name="Value") 84 | ax = sns.lineplot(data=sns_data, x="Time", y="Value", hue="Time Series", palette=palette, linewidth=linewidth) 85 | 86 | ax.figure.set_size_inches(height * aspect, height) 87 | ax.set_title(title) 88 | ax.set_xlabel(xlabel) 89 | ax.set_ylabel(ylabel) 90 | ax.legend(title=legend_title) 91 | 92 | ax.set_ylim(ylim) 93 | 94 | if legend_loc is not None: 95 | ax.legend(loc=legend_loc) 96 | 97 | if (xaxis_date_format is not None) and (ts_x is not None): 98 | ax.xaxis.set_major_formatter(matplotlib.dates.DateFormatter(xaxis_date_format)) 99 | 100 | if (xaxis_major_locator is not None) and (ts_x is not None): 101 | ax.xaxis.set_major_locator(xaxis_major_locator) 102 | 103 | if yticks is not None: 104 | ax.set(yticks=yticks) 105 | 106 | sns.despine(top=True, right=True) 107 | 108 | create_file_dir(save_pathname) 109 | plt.savefig(save_pathname, bbox_inches="tight") 110 | reset_plot_libs() 111 | -------------------------------------------------------------------------------- /src/not_again_ai/viz/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import seaborn as sns 3 | 4 | 5 | def reset_plot_libs() -> None: 6 | """Resets the plot libraries so that subsequent method calls are not impacted.""" 7 | plt.clf() 8 | plt.cla() 9 | plt.close() 10 | sns.reset_orig() 11 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/__init__.py -------------------------------------------------------------------------------- /tests/base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/base/__init__.py -------------------------------------------------------------------------------- /tests/base/test_file_system.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from not_again_ai.base.file_system import readable_size 4 | 5 | 6 | @pytest.mark.parametrize( 7 | ("size", "expected"), 8 | [ 9 | (0, "0.00 B"), 10 | (523, "523.00 B"), 11 | (2048, "2.00 KB"), 12 | (5242880, "5.00 MB"), 13 | (10737418240, "10.00 GB"), 14 | (1099511627776, "1.00 TB"), 15 | ], 16 | ) 17 | def test_human_readable_size(size: float, expected: str) -> None: 18 | assert readable_size(size) == expected, f"Failed for {size} bytes" 19 | -------------------------------------------------------------------------------- /tests/base/test_parallel.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import random 3 | 
import time 4 | 5 | import pytest 6 | 7 | from not_again_ai.base.parallel import embarrassingly_parallel, embarrassingly_parallel_simple 8 | 9 | 10 | def multby2(x: float, y: float, double: bool = False) -> float: 11 | time.sleep(random.uniform(0, 1)) 12 | if double: 13 | return x * y * 2 14 | else: 15 | return x * y 16 | 17 | 18 | def do_something() -> int: 19 | return 8 20 | 21 | 22 | def do_something2() -> int: 23 | return 2 24 | 25 | 26 | def echo(x: int) -> int: 27 | time.sleep(random.uniform(0, 1) / 10) 28 | return x 29 | 30 | 31 | def test_embarrassingly_parallel() -> None: 32 | args = ((2, 2), (3, 3), (4, 4)) 33 | 34 | result = embarrassingly_parallel(multby2, args, num_processes=multiprocessing.cpu_count()) 35 | 36 | total = 0 37 | for x in result: 38 | total += x 39 | assert total == 4 + 9 + 16 40 | 41 | 42 | def test_embarrassingly_parallel_both() -> None: 43 | args = ((2, 2), (3, 3), (4, 4)) 44 | kwargs = [{"double": True}, {"double": False}, {"double": True}] 45 | 46 | result = embarrassingly_parallel(multby2, args, kwargs, num_processes=multiprocessing.cpu_count()) 47 | 48 | total = 0 49 | for x in result: 50 | total += x 51 | assert total == 8 + 9 + 32 52 | 53 | 54 | def test_embarrassingly_parallel_kwargs() -> None: 55 | kwargs = [{"x": 2, "y": 2, "double": True}, {"x": 3, "y": 3, "double": False}, {"x": 4, "y": 4, "double": True}] 56 | 57 | result = embarrassingly_parallel(multby2, None, kwargs, num_processes=multiprocessing.cpu_count()) 58 | 59 | total = 0 60 | for x in result: 61 | total += x 62 | assert total == 8 + 9 + 32 63 | 64 | 65 | def test_embarrassingly_parallel_exceptions() -> None: 66 | with pytest.raises(ValueError, match="either args_list or kwargs_list must be provided"): 67 | embarrassingly_parallel(multby2, None, None, num_processes=multiprocessing.cpu_count()) 68 | 69 | args = ((2, 2), (3, 3), (4, 4)) 70 | kwargs = [{"double": True}, {"double": False}] 71 | with pytest.raises(ValueError, match="args_list and kwargs_list must be of the same length"): 72 | embarrassingly_parallel(multby2, args, kwargs, num_processes=multiprocessing.cpu_count()) 73 | 74 | 75 | def test_embarrassingly_parallel_ordering() -> None: 76 | args = ((1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,)) 77 | result = embarrassingly_parallel(echo, args, num_processes=3) 78 | assert result == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 79 | 80 | 81 | def test_embarrassingly_parallel_simple() -> None: 82 | result = embarrassingly_parallel_simple([do_something, do_something2], num_processes=2) 83 | assert result == [8, 2] 84 | -------------------------------------------------------------------------------- /tests/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/data/__init__.py -------------------------------------------------------------------------------- /tests/data/test_brave_search_api.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | 5 | from not_again_ai.data.brave_search_api import search, search_news 6 | 7 | 8 | @pytest.mark.parametrize( 9 | ("query", "search_params"), 10 | [ 11 | ("brave search", {}), 12 | ("python programming", {"count": 2, "country": "US"}), 13 | ("machine learning", {"count": 4, "search_lang": "en", "freshness": "pw"}), 14 | ("AI news", {"count": 1, "offset": 5, "country": "GB", "ui_lang": "en-GB"}), 15 | ], 16 | ) 17 | 
async def test_brave_search_api(query: str, search_params: dict[str, Any]) -> None: 18 | """Test the Brave Search API with a sample query and optional parameters.""" 19 | content = await search(query=query, **search_params) 20 | assert content.results, f"No results returned for query: {query}" 21 | 22 | 23 | @pytest.mark.skip("API Cost") 24 | @pytest.mark.parametrize( 25 | ("query", "search_params"), 26 | [ 27 | ("latest tech news", {}), 28 | ("AI breakthrough", {"count": 3, "country": "US"}), 29 | ], 30 | ) 31 | async def test_brave_search_news_api(query: str, search_params: dict[str, Any]) -> None: 32 | """Test the Brave News Search API with a sample query and optional parameters.""" 33 | content = await search_news(query=query, **search_params) 34 | assert content.results, f"No news results returned for query: {query}" 35 | -------------------------------------------------------------------------------- /tests/data/test_web.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from not_again_ai.data.web import process_url 4 | 5 | 6 | @pytest.mark.parametrize( 7 | "url", 8 | [ 9 | "https://example.com", 10 | "https://github.com/unclecode/crawl4ai", 11 | "https://arxiv.org/pdf/1710.02298", 12 | "https://www.youtube.com/watch?v=dQw4w9WgXcQ", 13 | "https://www.nascar.com/news/nascar-craftsman-truck-series/", 14 | "https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/edit?gid=0#gid=0", 15 | "https://docs.google.com/spreadsheets/d/1BxiMVs0XRA5nFMdKvBdBZjgmUUqptlbs74OgvE2upms/export?format=csv&gid=0", 16 | ], 17 | ) 18 | async def test_process_url(url: str) -> None: 19 | content = await process_url(url) 20 | assert content, f"Content should not be empty for URL: {url}" 21 | -------------------------------------------------------------------------------- /tests/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/__init__.py -------------------------------------------------------------------------------- /tests/llm/chat_completion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/chat_completion/__init__.py -------------------------------------------------------------------------------- /tests/llm/embedding/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/embedding/__init__.py -------------------------------------------------------------------------------- /tests/llm/embedding/test_embedding.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from typing import Any 3 | 4 | import pytest 5 | 6 | from not_again_ai.llm.embedding import EmbeddingRequest, create_embeddings 7 | from not_again_ai.llm.embedding.providers.ollama_api import ollama_client 8 | from not_again_ai.llm.embedding.providers.openai_api import openai_client 9 | from not_again_ai.llm.embedding.types import EmbeddingResponse 10 | 11 | 12 | def print_embedding_response(embedding_response: EmbeddingResponse, max_elements: int = 5) -> None: 13 | """Print an EmbeddingResponse, truncating each embedding vector after 
`max_elements` values. 14 | 15 | Args: 16 | embedding_response (EmbeddingResponse): The response to print. 17 | max_elements (int): Maximum number of elements from each embedding vector to display. 18 | Defaults to 5. 19 | """ 20 | print("EmbeddingResponse:") 21 | print(f" Response Duration: {embedding_response.response_duration}") 22 | print(f" Total Tokens: {embedding_response.total_tokens}") 23 | print(f" Extras: {embedding_response.extras}") 24 | print(" Embeddings:") 25 | for obj in embedding_response.embeddings: 26 | vector = obj.embedding 27 | if len(vector) > max_elements: 28 | # Format the first max_elements, then append an ellipsis. 29 | vector_str = ", ".join(f"{value:.4f}" for value in vector[:max_elements]) 30 | vector_str += ", ..." 31 | else: 32 | vector_str = ", ".join(f"{value:.4f}" for value in vector) 33 | print(f" Index {obj.index}: [{vector_str}]") 34 | 35 | 36 | # region OpenAI and Azure OpenAI Embedding 37 | @pytest.fixture( 38 | params=[ 39 | {}, 40 | {"api_type": "azure_openai", "aoai_api_version": "2025-01-01-preview"}, 41 | ] 42 | ) 43 | def openai_aoai_client_fixture(request: pytest.FixtureRequest) -> Callable[..., Any]: 44 | return openai_client(**request.param) 45 | 46 | 47 | def test_create_embeddings(openai_aoai_client_fixture: Callable[..., Any]) -> None: 48 | request = EmbeddingRequest(input="Hello, world!", model="text-embedding-3-small") 49 | response = create_embeddings(request, "openai", openai_aoai_client_fixture) 50 | print_embedding_response(response) 51 | 52 | 53 | def test_create_embeddings_multiple(openai_aoai_client_fixture: Callable[..., Any]) -> None: 54 | request = EmbeddingRequest(input=["Hello, world!", "Hello, world 2!"], model="text-embedding-3-small") 55 | response = create_embeddings(request, "openai", openai_aoai_client_fixture) 56 | print_embedding_response(response) 57 | 58 | 59 | def test_create_embeddings_dimensions(openai_aoai_client_fixture: Callable[..., Any]) -> None: 60 | request = EmbeddingRequest( 61 | input="This is a test of the dimensions parameter", 62 | model="text-embedding-3-large", 63 | dimensions=3, 64 | ) 65 | response = create_embeddings(request, "openai", openai_aoai_client_fixture) 66 | print_embedding_response(response) 67 | 68 | 69 | # endregion 70 | # region Ollama 71 | @pytest.fixture( 72 | params=[ 73 | {}, 74 | ] 75 | ) 76 | def ollama_client_fixture(request: pytest.FixtureRequest) -> Callable[..., Any]: 77 | return ollama_client(**request.param) 78 | 79 | 80 | def test_create_embeddings_ollama(ollama_client_fixture: Callable[..., Any]) -> None: 81 | request = EmbeddingRequest(input="Hello, world!", model="snowflake-arctic-embed2") 82 | response = create_embeddings(request, "ollama", ollama_client_fixture) 83 | print_embedding_response(response) 84 | 85 | 86 | def test_create_embeddings_ollama_missing_param(ollama_client_fixture: Callable[..., Any]) -> None: 87 | request = EmbeddingRequest(input="Hello, world!", model="snowflake-arctic-embed2", dimensions=3) 88 | response = create_embeddings(request, "ollama", ollama_client_fixture) 89 | print_embedding_response(response) 90 | 91 | 92 | def test_create_embeddings_ollama_multiple(ollama_client_fixture: Callable[..., Any]) -> None: 93 | request = EmbeddingRequest(input=["Hello, world!", "Hello, world 2!"], model="snowflake-arctic-embed2") 94 | response = create_embeddings(request, "ollama", ollama_client_fixture) 95 | print_embedding_response(response) 96 | 97 | 98 | # endregion 99 | 
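
The vector-truncation formatting that `print_embedding_response` applies above is easy to exercise on its own. A minimal standalone sketch (the helper name `format_vector` and the sample vector are illustrative, not part of the package):

def format_vector(vector: list[float], max_elements: int = 5) -> str:
    """Format a vector as comma-separated 4-decimal floats, truncated with an ellipsis."""
    shown = ", ".join(f"{value:.4f}" for value in vector[:max_elements])
    if len(vector) > max_elements:
        shown += ", ..."
    return f"[{shown}]"

print(format_vector([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]))
# [0.1000, 0.2000, 0.3000, 0.4000, 0.5000, ...]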
-------------------------------------------------------------------------------- /tests/llm/image_gen/test_image_gen.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Callable 2 | from pathlib import Path 3 | from typing import Any 4 | 5 | import pytest 6 | 7 | from not_again_ai.llm.image_gen import ImageGenRequest, create_image 8 | from not_again_ai.llm.image_gen.providers.openai_api import openai_client 9 | 10 | image_dir = Path(__file__).parents[1] / "sample_images" 11 | body_lotion_image_path = image_dir / "body_lotion.png" 12 | soap_image_path = image_dir / "soap.png" 13 | sunlit_lounge_image_path = image_dir / "sunlit_lounge.png" 14 | sunlit_lounge_mask_image_path = image_dir / "sunlit_lounge_mask.png" 15 | 16 | save_dir = Path(__file__).parents[3] / ".nox" / "temp" 17 | save_dir.mkdir(parents=True, exist_ok=True) 18 | temp_image_path = save_dir / "temp_image.png" 19 | 20 | 21 | @pytest.fixture( 22 | params=[ 23 | {}, 24 | ] 25 | ) 26 | def openai_aoai_client_fixture(request: pytest.FixtureRequest) -> Callable[..., Any]: 27 | return openai_client(**request.param) 28 | 29 | 30 | def test_create_image(openai_aoai_client_fixture: Callable[..., Any]) -> None: 31 | prompt = """A children's book drawing of a veterinarian using a stethoscope to 32 | listen to the heartbeat of a baby otter.""" 33 | request = ImageGenRequest( 34 | prompt=prompt, 35 | model="gpt-image-1", 36 | quality="low", 37 | ) 38 | response = create_image(request, "openai", openai_aoai_client_fixture) 39 | 40 | with temp_image_path.open("wb") as f: 41 | f.write(response.images[0]) 42 | 43 | 44 | def test_edit_image(openai_aoai_client_fixture: Callable[..., Any]) -> None: 45 | prompt = """Generate a photorealistic image of a gift basket on a white background 46 | labeled 'Relax & Unwind' with a ribbon and handwriting-like font, 47 | containing all the items in the reference pictures.""" 48 | request = ImageGenRequest( 49 | prompt=prompt, 50 | model="gpt-image-1", 51 | images=[body_lotion_image_path], 52 | quality="low", 53 | ) 54 | response = create_image(request, "openai", openai_aoai_client_fixture) 55 | 56 | with temp_image_path.open("wb") as f: 57 | f.write(response.images[0]) 58 | 59 | 60 | def test_edit_images(openai_aoai_client_fixture: Callable[..., Any]) -> None: 61 | prompt = """Generate a photorealistic image of a gift basket on a white background 62 | labeled 'Relax & Unwind' with a ribbon and handwriting-like font, 63 | containing all the items in the reference pictures.""" 64 | request = ImageGenRequest( 65 | prompt=prompt, 66 | model="gpt-image-1", 67 | images=[body_lotion_image_path, soap_image_path], 68 | quality="low", 69 | size="1024x1024", 70 | ) 71 | response = create_image(request, "openai", openai_aoai_client_fixture) 72 | 73 | with temp_image_path.open("wb") as f: 74 | f.write(response.images[0]) 75 | 76 | 77 | def test_edit_image_with_mask(openai_aoai_client_fixture: Callable[..., Any]) -> None: 78 | request = ImageGenRequest( 79 | prompt="A sunlit indoor lounge area with a pool containing a flamingo", 80 | model="gpt-image-1", 81 | images=[sunlit_lounge_image_path], 82 | mask=sunlit_lounge_mask_image_path, 83 | quality="low", 84 | size="1024x1024", 85 | ) 86 | response = create_image(request, "openai", openai_aoai_client_fixture) 87 | with temp_image_path.open("wb") as f: 88 | f.write(response.images[0]) 89 | 90 | 91 | def test_create_image_multiple(openai_aoai_client_fixture: Callable[..., Any]) -> None: 92 | prompt = """A 
high-quality 3D-rendered illustration of a color wheel logo. \
93 |     The design features eight symmetrical, petal-shaped leaves arranged in a perfect circular flower pattern. \
94 |     Each leaf is semi-transparent like colored glass, rendered in soft pastel tones including pink, orange, yellow, green, blue, and purple. \
95 |     The petals overlap slightly, creating gentle blended hues where they intersect. \
96 |     The background is a flat, light-toned surface with even, diffused lighting, giving the image a modern, polished, and professional appearance. No text."""
97 | 
98 |     request = ImageGenRequest(
99 |         prompt=prompt,
100 |         model="gpt-image-1",
101 |         n=2,
102 |         quality="medium",
103 |         size="1536x1024",
104 |         background="transparent",
105 |         moderation="low",
106 |     )
107 | 
108 |     response = create_image(request, "openai", openai_aoai_client_fixture)
109 | 
110 |     with temp_image_path.open("wb") as f:
111 |         f.write(response.images[0])
112 |     with (temp_image_path.parent / "temp_image_1.png").open("wb") as f:
113 |         f.write(response.images[1])
114 | 
--------------------------------------------------------------------------------
/tests/llm/prompting/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/prompting/__init__.py
--------------------------------------------------------------------------------
/tests/llm/prompting/test_tokenizer.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from not_again_ai.llm.chat_completion.types import MessageT, SystemMessage, UserMessage
4 | from not_again_ai.llm.prompting import Tokenizer
5 | 
6 | 
7 | @pytest.fixture(
8 |     params=[
9 |         {"model": "gpt-4o-mini-2024-07-18", "provider": "openai"},
10 |         {"model": "o1-mini", "provider": "azure_openai"},
11 |     ]
12 | )
13 | def tokenizer(request: pytest.FixtureRequest) -> Tokenizer:
14 |     return Tokenizer(**request.param)
15 | 
16 | 
17 | @pytest.fixture(
18 |     params=[
19 |         {"model": "gpt-4o-mini-2024-07-18", "provider": "openai"},
20 |         {"model": "o1-mini", "provider": "azure_openai"},
21 |         {"model": "gpt-4o-mini-2024-07-18", "provider": "mistral"},
22 |         {"model": "unknown-model", "provider": "openai"},
23 |     ]
24 | )
25 | def tokenizer_with_unsupported(request: pytest.FixtureRequest) -> Tokenizer:
26 |     return Tokenizer(**request.param)
27 | 
28 | 
29 | # This test runs with every fixture provider, including the unsupported ones
30 | def test_truncate_str(tokenizer_with_unsupported: Tokenizer) -> None:
31 |     text = "This is a test sentence for the function."
32 |     max_len = 3
33 |     result = tokenizer_with_unsupported.truncate_str(text, max_len)
34 |     print(result)
35 | 
36 | 
37 | # This test runs only with the OpenAI and Azure OpenAI providers
38 | def test_truncate_str_no_truncation(tokenizer: Tokenizer) -> None:
39 |     text = "Short text"
40 |     max_len = 20
41 |     result = tokenizer.truncate_str(text, max_len)
42 |     print(result)
43 | 
44 | 
45 | def test_num_tokens_in_str(tokenizer_with_unsupported: Tokenizer) -> None:
46 |     text = "This is a test sentence for the function."
47 | result = tokenizer_with_unsupported.num_tokens_in_str(text) 48 | print(result) 49 | 50 | 51 | def test_num_tokens_allowed_special() -> None: 52 | tokenizer = Tokenizer( 53 | model="gpt-4o-mini-2024-07-18", provider="openai", allowed_special=set(), disallowed_special=() 54 | ) 55 | text = "<|endoftext|>" 56 | result = tokenizer.num_tokens_in_str(text) 57 | print(result) 58 | 59 | 60 | def test_num_tokens_multiple_messages(tokenizer_with_unsupported: Tokenizer) -> None: 61 | messages: list[MessageT] = [SystemMessage(content="System message."), UserMessage(content="User message.")] 62 | result = tokenizer_with_unsupported.num_tokens_in_messages(messages) 63 | print(result) 64 | -------------------------------------------------------------------------------- /tests/llm/sample_images/SKDiagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/SKDiagram.png -------------------------------------------------------------------------------- /tests/llm/sample_images/SKInfographic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/SKInfographic.png -------------------------------------------------------------------------------- /tests/llm/sample_images/body_lotion.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/body_lotion.png -------------------------------------------------------------------------------- /tests/llm/sample_images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/cat.jpg -------------------------------------------------------------------------------- /tests/llm/sample_images/dog.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/dog.jpg -------------------------------------------------------------------------------- /tests/llm/sample_images/numbers.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/numbers.png -------------------------------------------------------------------------------- /tests/llm/sample_images/soap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/soap.png -------------------------------------------------------------------------------- /tests/llm/sample_images/sunlit_lounge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/sunlit_lounge.png -------------------------------------------------------------------------------- /tests/llm/sample_images/sunlit_lounge_mask.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/llm/sample_images/sunlit_lounge_mask.png
--------------------------------------------------------------------------------
/tests/statistics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/statistics/__init__.py
--------------------------------------------------------------------------------
/tests/statistics/test_dependence.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | 
3 | import numpy as np
4 | import numpy.typing as npt
5 | 
6 | from not_again_ai.statistics.dependence import pearson_correlation, pred_power_score_classification
7 | 
8 | 
9 | def _example_1(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
10 |     """Example 1 - x is mostly predictive of y.
11 |     x is categorical (strings), y is binary, both are numpy arrays
12 |     """
13 |     x0 = rs.choice(["a"], 200)
14 |     y0 = rs.choice([0, 1], 200, p=[0.9, 0.1])
15 | 
16 |     x1 = rs.choice(["b"], 300)
17 |     y1 = rs.choice([0, 1], 300, p=[0.1, 0.9])
18 | 
19 |     x = np.concatenate([x0, x1])
20 |     y = np.concatenate([y0, y1])
21 |     return (x, y)
22 | 
23 | 
24 | def _example_2() -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
25 |     """Example 2 - x completely predicts y"""
26 |     x = np.array([1, 1, 0, 0, 0, 0, 0, 0, 0, 0])
27 |     y = np.array(["a", "a", "b", "b", "b", "b", "b", "b", "b", "b"])
28 |     return (x, y)
29 | 
30 | 
31 | def _example_3() -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
32 |     """Example 3 - y contains all of the same values"""
33 |     x = np.array([0, 0, 0, 0, 0, 0])
34 |     y = np.array([3, 3, 3, 3, 3, 3])
35 |     return (x, y)
36 | 
37 | 
38 | def _example_4(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
39 |     """Example 4 - y is multi-class, x is numeric"""
40 |     x0 = rs.normal(3, 0.1, 100)
41 |     y0 = rs.choice([0, 1, 2], 100, p=[0.9, 0.05, 0.05])
42 | 
43 |     x1 = rs.normal(2, 0.1, 100)
44 |     y1 = rs.choice([0, 1, 2], 100, p=[0.05, 0.9, 0.05])
45 | 
46 |     x2 = rs.normal(1, 0.1, 100)
47 |     y2 = rs.choice([0, 1, 2], 100, p=[0.05, 0.05, 0.9])
48 | 
49 |     x = np.concatenate([x0, x1, x2])
50 |     y = np.concatenate([y0, y1, y2])
51 |     return (x, y)
52 | 
53 | 
54 | def _example_5(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
55 |     """Example 5 - x is not predictive of y (random noise)"""
56 |     x = rs.choice(["a", "b"], 500)
57 |     y = rs.choice([0, 1], 500)
58 |     return (x, y)
59 | 
60 | 
61 | def _example_6(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
62 |     """Example 6 - Both variables fully random: correlation should be 0"""
63 |     x = rs.randn(500)
64 |     y = rs.randn(500)
65 |     return (x, y)
66 | 
67 | 
68 | def _example_7(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
69 |     """Example 7 - y = x^2 + noise"""
70 |     x = (rs.rand(500) * 4) - 2
71 |     y = x**2 + (rs.randn(500) * 0.2)
72 |     return (x, y)
73 | 
74 | 
75 | def _example_8(rs: np.random.RandomState) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
76 |     """Example 8 - y = x + noise, correlation should be high"""
77 |     x = rs.randn(500)
78 |     e = rs.randn(500) * 0.2
79 |     y = x + e
80 |     return (x, y)
81 | 
82 | 
83 | def test_pearson_correlation() -> None:
84 |     rs = np.random.RandomState(365)
85 | 
86 |     x, y = 
_example_1(rs)
87 |     res = pearson_correlation(x, y, is_x_categorical=True, is_y_categorical=True)
88 |     assert res > 0.333
89 | 
90 |     x, y = _example_2()
91 |     res = pearson_correlation(x, y, is_x_categorical=True, is_y_categorical=True)
92 |     assert np.isclose(res, 1, atol=1e-6)
93 | 
94 |     x, y = _example_3()
95 |     res = pearson_correlation(x, y, is_x_categorical=True, is_y_categorical=True, print_diagnostics=True)
96 |     assert res == 1
97 | 
98 |     x, y = _example_4(rs)
99 |     res = pearson_correlation(x, y, is_x_categorical=False, is_y_categorical=True)
100 |     assert res > 0.333
101 | 
102 |     x, y = _example_5(rs)
103 |     res = pearson_correlation(x, y, is_x_categorical=True, is_y_categorical=True)
104 |     assert res < 0.333
105 | 
106 |     x, y = _example_6(rs)
107 |     res = pearson_correlation(x, y)
108 |     assert res < 0.333
109 | 
110 |     x, y = _example_7(rs)
111 |     res = pearson_correlation(x, y)
112 |     assert res >= 0
113 | 
114 |     x, y = _example_8(rs)
115 |     res = pearson_correlation(x, y)
116 |     assert res > 0.5
117 | 
118 | 
119 | def test_pred_power_score_classification() -> None:
120 |     rs = np.random.RandomState(365)
121 | 
122 |     x, y = _example_1(rs)
123 |     res = pred_power_score_classification(x, y)
124 |     assert res > 0.333
125 | 
126 |     x, y = _example_2()
127 |     res = pred_power_score_classification(x, y)
128 |     assert res == 1
129 | 
130 |     x, y = _example_3()
131 |     res = pred_power_score_classification(x, y, print_diagnostics=True)
132 |     assert res == 1
133 | 
134 |     x, y = _example_4(rs)
135 |     res = pred_power_score_classification(x, y)
136 |     assert res > 0.333
137 | 
138 |     x, y = _example_5(rs)
139 |     res = pred_power_score_classification(x, y, cv_splits=10)
140 |     assert res < 0.333
141 | 
--------------------------------------------------------------------------------
/tests/viz/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaveCoDev/not-again-ai/93302ff4b8b9eda5398e2c318854677f517445c8/tests/viz/__init__.py
--------------------------------------------------------------------------------
/tests/viz/test_barplot.py:
--------------------------------------------------------------------------------
1 | from not_again_ai.viz.barplots import simple_barplot
2 | 
3 | 
4 | def test_simple_barplot() -> None:
5 |     save_pathname = ".nox/temp/barplot_test1.png"
6 |     x = ["fence", "wall", "gate", "door", "window", "counter", "stair", "curtain", "ceiling", "floor"]
7 |     y = [0.5, 0.01, 0.25, 0.2, 0.1, 0.05, 0.04, 0.03, 0.02, 0.3]
8 |     simple_barplot(
9 |         x,
10 |         y,
11 |         save_pathname,
12 |         order="asc",
13 |         title="Token logits",
14 |         x_label="token",
15 |         y_label="logit",
16 |     )
17 | 
18 |     save_pathname = ".nox/temp/barplot_test2.png"
19 |     simple_barplot(
20 |         y,
21 |         x,
22 |         save_pathname,
23 |         order="desc",
24 |         orient_bars_vertically=False,
25 |         title="Token logits",
26 |         x_label="logit",
27 |         y_label="token",
28 |     )
29 | 
30 |     save_pathname = ".nox/temp/barplot_test3.png"
31 |     simple_barplot(
32 |         y,
33 |         x,
34 |         save_pathname,
35 |         order="asc",
36 |         orient_bars_vertically=False,
37 |         title="Token logits",
38 |         x_label="logit",
39 |         y_label="token",
40 |     )
41 | 
42 |     save_pathname = ".nox/temp/barplot_test4.png"
43 |     simple_barplot(
44 |         x,
45 |         y,
46 |         save_pathname,
47 |         order="desc",
48 |         orient_bars_vertically=True,
49 |         title="Token logits",
50 |         x_label="token",
51 |         y_label="logit",
52 |     )
53 | 
--------------------------------------------------------------------------------
/tests/viz/test_distributions.py:
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from not_again_ai.viz.distributions import univariate_distplot 4 | 5 | 6 | def test_univariate_distplot() -> None: 7 | univariate_distplot(data=np.random.normal(size=1000), save_pathname=".nox/temp/distributions_test1.png") 8 | univariate_distplot( 9 | data=np.random.normal(size=1000), 10 | save_pathname=".nox/temp/distributions_test2.png", 11 | print_summary=False, 12 | title="Test title", 13 | xlabel="Test xlabel", 14 | ylabel="Test ylabel", 15 | xlim=(-5, 5), 16 | ylim=(0, 40), 17 | xticks=None, 18 | yticks=None, 19 | bins=200, 20 | font_size=54, 21 | height=14, 22 | aspect=2.2, 23 | ) 24 | univariate_distplot( 25 | data=[1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 10, 10], 26 | save_pathname=".nox/temp/distributions_test3.png", 27 | print_summary=False, 28 | xlim=None, 29 | ylim=None, 30 | xticks=np.arange(0, 11, 1), 31 | yticks=np.arange(0, 5, 1), 32 | bins=100, 33 | ) 34 | univariate_distplot( 35 | data=np.random.beta(a=0.5, b=0.5, size=10000), 36 | save_pathname=".nox/temp/distributions_test4.svg", 37 | print_summary=False, 38 | bins=100, 39 | title=r"Beta Distribution $\alpha=0.5, \beta=0.5$", 40 | font_size=18, 41 | height=3.91, 42 | aspect=1.8, 43 | ) 44 | -------------------------------------------------------------------------------- /tests/viz/test_scatterplot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from not_again_ai.viz.scatterplot import scatterplot_basic 4 | 5 | 6 | def test_scatterplot_basic() -> None: 7 | rs = np.random.RandomState(365) 8 | x = rs.randn(100) 9 | e = rs.randn(100) * 0.2 10 | y = x + e 11 | scatterplot_basic(x, y, save_pathname=".nox/temp/scatterplot_basic1.png", title="Correlation Chart") 12 | scatterplot_basic( 13 | x, 14 | y, 15 | save_pathname=".nox/temp/scatterplot_basic2.png", 16 | title=None, 17 | xlim=(-10, 10), 18 | ylim=(-5, 5), 19 | font_size=36, 20 | height=15, 21 | aspect=2.2, 22 | ) 23 | -------------------------------------------------------------------------------- /tests/viz/test_time_series.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import numpy as np 3 | import pandas as pd 4 | 5 | from not_again_ai.viz.time_series import ts_lineplot 6 | 7 | 8 | def test_ts_lineplot() -> None: 9 | rs = np.random.RandomState(365) 10 | values = rs.randn(365, 4).cumsum(axis=0).T 11 | dates = pd.date_range("1 1 2021", periods=365, freq="D") 12 | ts_lineplot( 13 | ts_data=values, 14 | save_pathname=".nox/temp/ts_lineplot1.png", 15 | ts_x=None, 16 | ts_names=None, 17 | legend_title="Test Title", 18 | ) 19 | ts_lineplot( 20 | ts_data=values, 21 | save_pathname=".nox/temp/ts_lineplot2.png", 22 | ts_x=dates, 23 | ts_names=None, 24 | xlabel="Months", 25 | ylabel="Number", 26 | ylim=(-15, 15), 27 | ) 28 | ts_lineplot( 29 | ts_data=values, 30 | save_pathname=".nox/temp/ts_lineplot3.png", 31 | ts_x=None, 32 | ts_names=["A", "B", "C", "D"], 33 | height=14, 34 | aspect=1.5, 35 | ) 36 | ts_lineplot( 37 | ts_data=values, 38 | save_pathname=".nox/temp/ts_lineplot4.png", 39 | ts_x=dates, 40 | ts_names=["A", "B", "C", "D"], 41 | title="Example Time Series", 42 | xlabel=None, 43 | ylabel=None, 44 | yticks=np.arange(-30, 40, 10), 45 | xaxis_date_format="%b '%y", 46 | xaxis_major_locator=matplotlib.dates.MonthLocator((1, 3, 5, 7, 9, 11)), 47 | font_size=46, 48 | linewidth=1.8, 49 | legend_loc=2, 50 | 
palette="colorblind", 51 | ) 52 | ts_lineplot( 53 | ts_data=values, 54 | save_pathname=".nox/temp/ts_lineplot5.svg", 55 | ts_x=dates, 56 | ts_names=["A", "B", "C", "D"], 57 | title="Example Time Series", 58 | xlabel=None, 59 | ylabel=None, 60 | yticks=np.arange(-30, 40, 10), 61 | xaxis_date_format="%b '%y", 62 | xaxis_major_locator=matplotlib.dates.MonthLocator((1, 3, 5, 7, 9, 11)), 63 | linewidth=1.5, 64 | legend_loc=2, 65 | palette="colorblind", 66 | font_size=20, 67 | height=4.4, 68 | aspect=1.8, 69 | ) 70 | --------------------------------------------------------------------------------
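
For reference, the wide-to-long reshape at the core of `ts_lineplot` (src/not_again_ai/viz/time_series.py above) can be demonstrated in isolation: each row of `ts_data` becomes a named column, and `melt` then yields one (Time, Time Series, Value) row per observation, which is the long-format frame that `sns.lineplot` consumes. A minimal sketch with made-up toy data:

import numpy as np
import pandas as pd

ts_data = [[1.0, 2.0, 3.0], [10.0, 20.0, 30.0]]  # two series, three time steps
ts_names = ["A", "B"]

# Transpose so each series is a column, then melt into long format
wide = pd.DataFrame(np.array(ts_data).T, columns=ts_names)
wide["Time"] = np.arange(len(wide))
long = wide.melt(id_vars="Time", var_name="Time Series", value_name="Value")
print(long)
#    Time Time Series  Value
# 0     0           A    1.0
# 1     1           A    2.0
# 2     2           A    3.0
# 3     0           B   10.0
# 4     1           B   20.0
# 5     2           B   30.0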