├── .github ├── assets │ └── badges │ │ └── .gitkeep ├── dependabot.yml ├── pr-labeler.yml ├── release-drafter.yml └── workflows │ ├── draft.yml │ ├── pr_labeler.yml │ ├── pre_commit_auto_update.yml │ └── pypi-deploy.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── pyproject.toml ├── requirements-dev.txt ├── skllm ├── __init__.py ├── classification.py ├── config.py ├── datasets │ ├── __init__.py │ ├── multi_class.py │ ├── multi_label.py │ ├── summarization.py │ └── translation.py ├── llm │ ├── anthropic │ │ ├── completion.py │ │ ├── credentials.py │ │ └── mixin.py │ ├── base.py │ ├── gpt │ │ ├── clients │ │ │ ├── llama_cpp │ │ │ │ ├── completion.py │ │ │ │ └── handler.py │ │ │ └── openai │ │ │ │ ├── completion.py │ │ │ │ ├── credentials.py │ │ │ │ ├── embedding.py │ │ │ │ └── tuning.py │ │ ├── completion.py │ │ ├── embedding.py │ │ ├── mixin.py │ │ └── utils.py │ └── vertex │ │ ├── completion.py │ │ ├── mixin.py │ │ └── tuning.py ├── memory │ ├── __init__.py │ ├── _annoy.py │ ├── _sklearn_nn.py │ └── base.py ├── models │ ├── _base │ │ ├── classifier.py │ │ ├── tagger.py │ │ ├── text2text.py │ │ └── vectorizer.py │ ├── anthropic │ │ ├── classification │ │ │ ├── few_shot.py │ │ │ └── zero_shot.py │ │ ├── tagging │ │ │ └── ner.py │ │ └── text2text │ │ │ ├── __init__.py │ │ │ ├── summarization.py │ │ │ └── translation.py │ ├── gpt │ │ ├── classification │ │ │ ├── few_shot.py │ │ │ ├── tunable.py │ │ │ └── zero_shot.py │ │ ├── tagging │ │ │ └── ner.py │ │ ├── text2text │ │ │ ├── __init__.py │ │ │ ├── summarization.py │ │ │ ├── translation.py │ │ │ └── tunable.py │ │ └── vectorization.py │ └── vertex │ │ ├── classification │ │ ├── tunable.py │ │ └── zero_shot.py │ │ └── text2text │ │ ├── __init__.py │ │ └── tunable.py ├── prompts │ ├── builders.py │ └── templates.py ├── text2text.py ├── utils │ ├── __init__.py │ ├── rendering.py │ └── xml.py └── vectorization.py └── tests ├── llm ├── __init__.py ├── 
anthropic │ ├── __init__.py │ └── test_anthropic_mixins.py ├── gpt │ ├── __init__.py │ └── test_gpt_mixins.py └── vertex │ ├── __init__.py │ └── test_vertex_mixins.py └── test_utils.py /.github/assets/badges/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/.github/assets/badges/.gitkeep -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | - package-ecosystem: "github-actions" 8 | directory: "/" 9 | schedule: 10 | interval: "monthly" 11 | -------------------------------------------------------------------------------- /.github/pr-labeler.yml: -------------------------------------------------------------------------------- 1 | feature: ['features/*', 'feature/*', 'feat/*', 'features-*', 'feature-*', 'feat-*'] 2 | fix: ['fixes/*', 'fix/*'] 3 | chore: ['chore/*'] 4 | dependencies: ['update/*'] 5 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | name-template: "v$RESOLVED_VERSION" 2 | tag-template: "v$RESOLVED_VERSION" 3 | categories: 4 | - title: "🚀 Features" 5 | labels: 6 | - "feature" 7 | - "enhancement" 8 | - title: "🐛 Bug Fixes" 9 | labels: 10 | - "fix" 11 | - "bugfix" 12 | - "bug" 13 | - title: "🧹 Maintenance" 14 | labels: 15 | - "maintenance" 16 | - "dependencies" 17 | - "refactoring" 18 | - "cosmetic" 19 | - "chore" 20 | - title: "📝️ Documentation" 21 | labels: 22 | - "documentation" 23 | - "docs" 24 | change-template: "- $TITLE (#$NUMBER)" 25 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions 26 | 
version-resolver: 27 | major: 28 | labels: 29 | - "major" 30 | minor: 31 | labels: 32 | - "minor" 33 | patch: 34 | labels: 35 | - "patch" 36 | default: patch 37 | template: | 38 | ## Changes 39 | 40 | $CHANGES 41 | -------------------------------------------------------------------------------- /.github/workflows/draft.yml: -------------------------------------------------------------------------------- 1 | # Drafts the next Release notes as Pull Requests are merged (or commits are pushed) into "main" or "master" 2 | name: Draft next release 3 | 4 | on: 5 | push: 6 | branches: [main, "master"] 7 | 8 | permissions: 9 | contents: read 10 | 11 | jobs: 12 | update-release-draft: 13 | permissions: 14 | contents: write 15 | pull-requests: write 16 | runs-on: ubuntu-latest 17 | steps: 18 | - uses: release-drafter/release-drafter@v6 19 | env: 20 | GITHUB_TOKEN: ${{ github.token }} 21 | -------------------------------------------------------------------------------- /.github/workflows/pr_labeler.yml: -------------------------------------------------------------------------------- 1 | # This workflow will apply the corresponding label on a pull request 2 | name: PR Labeler 3 | 4 | on: 5 | pull_request_target: 6 | 7 | permissions: 8 | contents: read 9 | pull-requests: write 10 | 11 | jobs: 12 | pr-labeler: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: TimonVS/pr-labeler-action@v5 16 | with: 17 | repo-token: ${{ github.token }} 18 | -------------------------------------------------------------------------------- /.github/workflows/pre_commit_auto_update.yml: -------------------------------------------------------------------------------- 1 | # Run a pre-commit autoupdate every week and open a pull request if needed 2 | name: Pre-commit auto-update 3 | 4 | on: 5 | # At 00:00 on the 1st of every month. 
6 | schedule: 7 | - cron: "0 0 1 * *" 8 | workflow_dispatch: 9 | 10 | permissions: 11 | contents: write 12 | pull-requests: write 13 | 14 | jobs: 15 | pre-commit-auto-update: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Checkout 19 | uses: actions/checkout@v4 20 | - name: Set up Python 21 | uses: actions/setup-python@v5 22 | - name: Install pre-commit 23 | run: pip install pre-commit 24 | - name: Run pre-commit 25 | run: pre-commit autoupdate 26 | - name: Set git config 27 | run: | 28 | git config --local user.email "action@github.com" 29 | git config --local user.name "GitHub Action" 30 | - uses: peter-evans/create-pull-request@v6 31 | with: 32 | token: ${{ github.token }} 33 | branch: update/pre-commit-hooks 34 | title: Update pre-commit hooks 35 | commit-message: "Update pre-commit hooks" 36 | body: Update versions of pre-commit hooks to latest version. 37 | labels: "dependencies,github_actions" 38 | -------------------------------------------------------------------------------- /.github/workflows/pypi-deploy.yml: -------------------------------------------------------------------------------- 1 | name: PyPi Deploy 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | deploy: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Checkout code 13 | uses: actions/checkout@v4 14 | 15 | - name: Setup Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install build twine 24 | 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: __token__ 28 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 29 | run: | 30 | python -m build 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | 
*.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | test.py 162 | tmp.ipynb 163 | tmp.py 164 | *.pickle 165 | *.ipynb 166 | 167 | # vscode 168 | .vscode/ 169 | tmp2.py 170 | tmp.* -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: trailing-whitespace 6 | - id: end-of-file-fixer 7 | - id: check-docstring-first 8 | - id: check-xml 9 | - id: check-json 10 | - id: check-yaml 11 | - id: check-toml 12 | - id: debug-statements 13 | - id: check-executables-have-shebangs 14 | - id: check-case-conflict 15 | - id: check-added-large-files 16 | - id: detect-private-key 17 | # Formatter for Json and Yaml files 18 | - repo: https://github.com/pre-commit/mirrors-prettier 19 | rev: v3.0.0-alpha.9-for-vscode 20 | hooks: 21 | - id: prettier 22 | types: [json, yaml, toml] 23 | # Formatter for markdown files 24 | - repo: https://github.com/executablebooks/mdformat 25 | rev: 0.7.16 26 | hooks: 27 | - id: mdformat 28 | args: ["--number"] 29 | additional_dependencies: 30 | - mdformat-gfm 31 | - mdformat-tables 32 | - mdformat-frontmatter 33 | - mdformat-black 34 | - mdformat-shfmt 35 | # An extremely fast Python linter, written in Rust 36 | - repo: https://github.com/charliermarsh/ruff-pre-commit 37 | rev: "v0.0.263" 38 | hooks: 39 | - id: ruff 40 | args: [--fix, --exit-non-zero-on-fix] 41 | # Python code formatter 42 | - repo: https://github.com/psf/black 43 | rev: 23.3.0 44 | hooks: 45 | - id: black 46 | args: ["--config", "pyproject.toml"] 47 | # Python's import formatter 48 | - repo: https://github.com/PyCQA/isort 49 | rev: 5.12.0 50 | hooks: 51 | - id: isort 52 | # Formats docstrings to follow PEP 257 53 | - repo: https://github.com/PyCQA/docformatter 54 | rev: v1.6.4 55 | hooks: 56 | - id: docformatter 57 | additional_dependencies: [tomli] 58 | args: ["--in-place", 
"--config", "pyproject.toml"] 59 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for 
clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | https://www.linkedin.com/in/iryna-kondrashchenko-673800155/. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. 
Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 
123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Scikit-LLM 2 | 3 | Welcome! We appreciate your interest in contributing to Scikit-LLM. Whether you're a developer, designer, writer, or simply passionate about open source, there are several ways you can help improve this project. This document will guide you through the process of contributing to our Python repository. 4 | 5 | ## How Can I Contribute? 6 | 7 | **IMPORTANT:** We are currently preparing the transition to v.1.0 which will include major code restructuring. Until then, no new pull requests to the main branch will be approved unless discussed in advance via issues or in Discord ! 8 | 9 | There are several ways you can contribute to this project: 10 | 11 | - Bug Fixes: Help us identify and fix issues in the codebase. 12 | - Feature Implementation: Implement new features and enhancements. 13 | - Documentation: Improve the project's documentation, including code comments and README files. 14 | - Testing: Write and improve test cases to ensure the project's quality and reliability. 15 | - Translations: Provide translations for the project's documentation or user interface. 16 | - Bug Reports and Feature Requests: Submit bug reports or suggest new features and improvements. 17 | 18 | **Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix. 
19 | 20 | > ### Legal Notice 21 | > 22 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license. 23 | 24 | ## Development dependencies 25 | 26 | In order to install all development dependencies, run the following command: 27 | 28 | ```shell 29 | pip install -r requirements-dev.txt 30 | ``` 31 | 32 | To ensure that you follow the development workflow, please setup the pre-commit hooks: 33 | 34 | ```shell 35 | pre-commit install 36 | ``` 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Iryna Kondrashchenko 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | logo 3 |
4 | 5 | # Scikit-LLM: Scikit-Learn Meets Large Language Models 6 | 7 | Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks. 8 | 9 | ## Installation 💾 10 | 11 | ```bash 12 | pip install scikit-llm 13 | ``` 14 | 15 | ## Support us 🤝 16 | 17 | You can support the project in the following ways: 18 | 19 | - ⭐ Star Scikit-LLM on GitHub (click the star button in the top right corner) 20 | - 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section or [Discord](https://discord.gg/YDAbwuWK7V) 21 | - 📰 Post about Scikit-LLM on LinkedIn or other platforms 22 | - 🔗 Check out our other projects: Dingo, Falcon 23 | 24 |
25 | 26 | 27 | 28 | 29 | Logo 30 | 31 |

32 | 33 | 34 | 35 | 36 | Logo 37 | 38 | 39 | 40 | ## Quick Start & Documentation 📚 41 | 42 | Quick start example of zero-shot text classification using GPT: 43 | 44 | ```python 45 | # Import the necessary modules 46 | from skllm.datasets import get_classification_dataset 47 | from skllm.config import SKLLMConfig 48 | from skllm.models.gpt.classification.zero_shot import ZeroShotGPTClassifier 49 | 50 | # Configure the credentials 51 | SKLLMConfig.set_openai_key("") 52 | SKLLMConfig.set_openai_org("") 53 | 54 | # Load a demo dataset 55 | X, y = get_classification_dataset() # labels: positive, negative, neutral 56 | 57 | # Initialize the model and make the predictions 58 | clf = ZeroShotGPTClassifier(model="gpt-4") 59 | clf.fit(X,y) 60 | clf.predict(X) 61 | ``` 62 | 63 | For more information please refer to the **[documentation](https://skllm.beastbyte.ai)**. 64 | 65 | ## Citation 66 | 67 | You can cite Scikit-LLM using the following BibTeX: 68 | 69 | ``` 70 | @software{ScikitLLM, 71 | author = {Iryna Kondrashchenko and Oleh Kostromin}, 72 | year = {2023}, 73 | publisher = {beastbyte.ai}, 74 | address = {Linz, Austria}, 75 | title = {Scikit-LLM: Scikit-Learn Meets Large Language Models}, 76 | url = {https://github.com/iryna-kondr/scikit-llm } 77 | } 78 | ``` 79 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | dependencies = [ 7 | "scikit-learn>=1.1.0,<2.0.0", 8 | "pandas>=1.5.0,<3.0.0", 9 | "openai>=1.2.0,<2.0.0", 10 | "tqdm>=4.60.0,<5.0.0", 11 | "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0" 12 | ] 13 | name = "scikit-llm" 14 | version = "1.4.1" 15 | authors = [ 16 | { name="Oleh Kostromin", email="kostromin97@gmail.com" }, 17 | { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" }, 18 | ] 19 | 
description = "Scikit-LLM: Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks." 20 | readme = "README.md" 21 | license = {text = "MIT"} 22 | requires-python = ">=3.9" 23 | classifiers = [ 24 | "Programming Language :: Python :: 3", 25 | "License :: OSI Approved :: MIT License", 26 | "Operating System :: OS Independent", 27 | ] 28 | 29 | [project.optional-dependencies] 30 | gguf = ["llama-cpp-python>=0.2.82,<0.2.83"] 31 | annoy = ["annoy>=1.17.2,<2.0.0"] 32 | 33 | [tool.ruff] 34 | select = [ 35 | # pycodestyle 36 | "E", 37 | # pyflakes 38 | "F", 39 | # pydocstyle 40 | "D", 41 | # flake8-bandit 42 | "S", 43 | # pyupgrade 44 | "UP", 45 | # pep8-naming 46 | "N", 47 | ] 48 | # Error E501 (Line too long) is ignored because of docstrings. 49 | ignore = [ 50 | "S101", 51 | "S301", 52 | "S311", 53 | "D100", 54 | "D200", 55 | "D203", 56 | "D205", 57 | "D401", 58 | "E501", 59 | "N803", 60 | "N806", 61 | "D104", 62 | ] 63 | extend-exclude = ["tests/*.py", "setup.py"] 64 | target-version = "py39" 65 | force-exclude = true 66 | 67 | [tool.ruff.per-file-ignores] 68 | "__init__.py" = ["E402", "F401", "F403", "F811"] 69 | 70 | [tool.ruff.pydocstyle] 71 | convention = "numpy" 72 | 73 | [tool.mypy] 74 | ignore_missing_imports = true 75 | 76 | [tool.black] 77 | preview = true 78 | target-version = ['py39', 'py310', 'py311'] 79 | 80 | [tool.isort] 81 | profile = "black" 82 | filter_files = true 83 | known_first_party = ["skllm", "skllm.*"] 84 | skip = ["__init__.py"] 85 | 86 | [tool.docformatter] 87 | close-quotes-on-newline = true # D209 88 | 89 | [tool.pytest.ini_options] 90 | pythonpath = [ 91 | "." 
92 | ] 93 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pre-commit 2 | black 3 | isort 4 | ruff 5 | docformatter 6 | interrogate 7 | numpy 8 | pandas 9 | -------------------------------------------------------------------------------- /skllm/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.4.1' 2 | __author__ = 'Iryna Kondrashchenko, Oleh Kostromin' 3 | -------------------------------------------------------------------------------- /skllm/classification.py: -------------------------------------------------------------------------------- 1 | ## GPT 2 | 3 | from skllm.models.gpt.classification.zero_shot import ( 4 | ZeroShotGPTClassifier, 5 | MultiLabelZeroShotGPTClassifier, 6 | CoTGPTClassifier, 7 | ) 8 | from skllm.models.gpt.classification.few_shot import ( 9 | FewShotGPTClassifier, 10 | DynamicFewShotGPTClassifier, 11 | MultiLabelFewShotGPTClassifier, 12 | ) 13 | from skllm.models.gpt.classification.tunable import ( 14 | GPTClassifier as TunableGPTClassifier, 15 | ) 16 | 17 | ## Vertex 18 | from skllm.models.vertex.classification.zero_shot import ( 19 | ZeroShotVertexClassifier, 20 | MultiLabelZeroShotVertexClassifier, 21 | ) 22 | from skllm.models.vertex.classification.tunable import ( 23 | VertexClassifier as TunableVertexClassifier, 24 | ) 25 | -------------------------------------------------------------------------------- /skllm/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional 3 | 4 | _OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY" 5 | _OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG" 6 | _AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE" 7 | _AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION" 8 | _GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT" 9 | _GPT_URL_VAR = 
# Names of the environment variables used as the configuration backend.
_OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
_OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
_AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
_AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
_GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
_GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
_ANTHROPIC_KEY_VAR = "SKLLM_CONFIG_ANTHROPIC_KEY"
_GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
_GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
_GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"


class SKLLMConfig:
    """Global scikit-llm configuration.

    All values are stored in process environment variables, so the class is
    stateless and every accessor is a static method. Setters overwrite the
    corresponding variable; getters read it back (returning ``None`` or a
    documented default when unset).
    """

    @staticmethod
    def set_gpt_key(key: str) -> None:
        """Sets the GPT key.

        Parameters
        ----------
        key : str
            GPT key.
        """
        os.environ[_OPENAI_KEY_VAR] = key

    # BUG FIX: this method was missing @staticmethod; calling it on an
    # instance passed the instance itself as `key` and raised a TypeError
    # when assigned into os.environ.
    @staticmethod
    def set_gpt_org(key: str) -> None:
        """Sets the GPT organization ID.

        Parameters
        ----------
        key : str
            GPT organization ID.
        """
        os.environ[_OPENAI_ORG_VAR] = key

    @staticmethod
    def set_openai_key(key: str) -> None:
        """Sets the OpenAI key.

        Parameters
        ----------
        key : str
            OpenAI key.
        """
        os.environ[_OPENAI_KEY_VAR] = key

    @staticmethod
    def get_openai_key() -> Optional[str]:
        """Gets the OpenAI key.

        Returns
        -------
        Optional[str]
            OpenAI key, or None if unset.
        """
        return os.environ.get(_OPENAI_KEY_VAR, None)

    @staticmethod
    def set_openai_org(key: str) -> None:
        """Sets OpenAI organization ID.

        Parameters
        ----------
        key : str
            OpenAI organization ID.
        """
        os.environ[_OPENAI_ORG_VAR] = key

    @staticmethod
    def get_openai_org() -> str:
        """Gets the OpenAI organization ID.

        Returns
        -------
        str
            OpenAI organization ID ("" if unset).
        """
        return os.environ.get(_OPENAI_ORG_VAR, "")

    @staticmethod
    def get_azure_api_base() -> str:
        """Gets the API base for Azure.

        Returns
        -------
        str
            URL to be used as the base for the Azure API.

        Raises
        ------
        RuntimeError
            If the Azure API base has not been configured.
        """
        base = os.environ.get(_AZURE_API_BASE_VAR, None)
        if base is None:
            raise RuntimeError("Azure API base is not set")
        return base

    @staticmethod
    def set_azure_api_base(base: str) -> None:
        """Set the API base for Azure.

        Parameters
        ----------
        base : str
            URL to be used as the base for the Azure API.
        """
        os.environ[_AZURE_API_BASE_VAR] = base

    @staticmethod
    def set_azure_api_version(ver: str) -> None:
        """Set the API version for Azure.

        Parameters
        ----------
        ver : str
            Azure API version.
        """
        os.environ[_AZURE_API_VERSION_VAR] = ver

    @staticmethod
    def get_azure_api_version() -> str:
        """Gets the API version for Azure.

        Returns
        -------
        str
            Azure API version (defaults to "2023-05-15").
        """
        return os.environ.get(_AZURE_API_VERSION_VAR, "2023-05-15")

    @staticmethod
    def get_google_project() -> Optional[str]:
        """Gets the Google Cloud project ID.

        Returns
        -------
        Optional[str]
            Google Cloud project ID, or None if unset.
        """
        return os.environ.get(_GOOGLE_PROJECT, None)

    @staticmethod
    def set_google_project(project: str) -> None:
        """Sets the Google Cloud project ID.

        Parameters
        ----------
        project : str
            Google Cloud project ID.
        """
        os.environ[_GOOGLE_PROJECT] = project

    @staticmethod
    def set_gpt_url(url: str) -> None:
        """Sets the GPT URL.

        Parameters
        ----------
        url : str
            GPT URL.
        """
        os.environ[_GPT_URL_VAR] = url

    @staticmethod
    def get_gpt_url() -> Optional[str]:
        """Gets the GPT URL.

        Returns
        -------
        Optional[str]
            GPT URL, or None if unset.
        """
        return os.environ.get(_GPT_URL_VAR, None)

    @staticmethod
    def set_anthropic_key(key: str) -> None:
        """Sets the Anthropic key.

        Parameters
        ----------
        key : str
            Anthropic key.
        """
        os.environ[_ANTHROPIC_KEY_VAR] = key

    @staticmethod
    def get_anthropic_key() -> Optional[str]:
        """Gets the Anthropic key.

        Returns
        -------
        Optional[str]
            Anthropic key, or None if unset.
        """
        return os.environ.get(_ANTHROPIC_KEY_VAR, None)

    @staticmethod
    def reset_gpt_url() -> None:
        """Resets the GPT URL."""
        os.environ.pop(_GPT_URL_VAR, None)

    @staticmethod
    def get_gguf_download_path() -> str:
        """Gets the path to store the downloaded GGUF files."""
        # Default: ~/.skllm/gguf
        default_path = os.path.join(os.path.expanduser("~"), ".skllm", "gguf")
        return os.environ.get(_GGUF_DOWNLOAD_PATH, default_path)

    @staticmethod
    def get_gguf_max_gpu_layers() -> int:
        """Gets the maximum number of layers to use for the GGUF model."""
        return int(os.environ.get(_GGUF_MAX_GPU_LAYERS, 0))

    @staticmethod
    def set_gguf_max_gpu_layers(n_layers: int) -> None:
        """Sets the maximum number of layers to use for the GGUF model.

        Parameters
        ----------
        n_layers : int
            Number of GPU layers; any value below -1 is clamped to -1
            (-1 conventionally means "all layers").

        Raises
        ------
        ValueError
            If ``n_layers`` is not an integer.
        """
        if not isinstance(n_layers, int):
            raise ValueError("n_layers must be an integer")
        if n_layers < -1:
            n_layers = -1
        os.environ[_GGUF_MAX_GPU_LAYERS] = str(n_layers)

    @staticmethod
    def set_gguf_verbose(verbose: bool) -> None:
        """Sets the verbosity of the GGUF model.

        Raises
        ------
        ValueError
            If ``verbose`` is not a boolean.
        """
        if not isinstance(verbose, bool):
            raise ValueError("verbose must be a boolean")
        os.environ[_GGUF_VERBOSE] = str(verbose)

    @staticmethod
    def get_gguf_verbose() -> bool:
        """Gets the verbosity of the GGUF model."""
        # str(True) == "True", hence the case-insensitive comparison.
        return os.environ.get(_GGUF_VERBOSE, "False").lower() == "true"
def get_classification_dataset():
    """Return a toy sentiment-classification dataset of movie reviews.

    Returns
    -------
    tuple[list[str], list[str]]
        30 review texts and their labels: ten "positive", ten "negative",
        ten "neutral", in that order.
    """
    X = [
        r"I was absolutely blown away by the performances in 'Summer's End'. The acting was top-notch, and the plot had me gripped from start to finish. A truly captivating cinematic experience that I would highly recommend.",
        r"The special effects in 'Star Battles: Nebula Conflict' were out of this world. I felt like I was actually in space. The storyline was incredibly engaging and left me wanting more. Excellent film.",
        r"'The Lost Symphony' was a masterclass in character development and storytelling. The score was hauntingly beautiful and complimented the intense, emotional scenes perfectly. Kudos to the director and cast for creating such a masterpiece.",
        r"I was pleasantly surprised by 'Love in the Time of Cholera'. The romantic storyline was heartwarming and the characters were incredibly realistic. The cinematography was also top-notch. A must-watch for all romance lovers.",
        r"I went into 'Marble Street' with low expectations, but I was pleasantly surprised. The suspense was well-maintained throughout, and the twist at the end was something I did not see coming. Bravo!",
        r"'The Great Plains' is a touching portrayal of life in rural America. The performances were heartfelt and the scenery was breathtaking. I was moved to tears by the end. It's a story that will stay with me for a long time.",
        r"The screenwriting in 'Under the Willow Tree' was superb. The dialogue felt real and the characters were well-rounded. The performances were also fantastic. I haven't enjoyed a movie this much in a while.",
        r"'Nightshade' is a brilliant take on the superhero genre. The protagonist was relatable and the villain was genuinely scary. The action sequences were thrilling and the storyline was engaging. I can't wait for the sequel.",
        r"The cinematography in 'Awakening' was nothing short of spectacular. The visuals alone are worth the ticket price. The storyline was unique and the performances were solid. An overall fantastic film.",
        r"'Eternal Embers' was a cinematic delight. The storytelling was original and the performances were exceptional. The director's vision was truly brought to life on the big screen. A must-see for all movie lovers.",
        r"I was thoroughly disappointed with 'Silver Shadows'. The plot was confusing and the performances were lackluster. I wouldn't recommend wasting your time on this one.",
        r"'The Darkened Path' was a disaster. The storyline was unoriginal, the acting was wooden and the special effects were laughably bad. Save your money and skip this one.",
        r"I had high hopes for 'The Final Frontier', but it failed to deliver. The plot was full of holes and the characters were poorly developed. It was a disappointing experience.",
        r"'The Fall of the Phoenix' was a letdown. The storyline was confusing and the characters were one-dimensional. I found myself checking my watch multiple times throughout the movie.",
        r"I regret wasting my time on 'Emerald City'. The plot was nonsensical and the performances were uninspired. It was a major disappointment.",
        r"I found 'Hollow Echoes' to be a complete mess. The plot was non-existent, the performances were overdone, and the pacing was all over the place. Definitely not worth the hype.",
        r"'Underneath the Stars' was a huge disappointment. The storyline was predictable and the acting was mediocre at best. I was expecting so much more.",
        r"I was left unimpressed by 'River's Edge'. The plot was convoluted, the characters were uninteresting, and the ending was unsatisfying. It's a pass for me.",
        r"The acting in 'Desert Mirage' was subpar, and the plot was boring. I found myself yawning multiple times throughout the movie. Save your time and skip this one.",
        r"'Crimson Dawn' was a major letdown. The plot was cliched and the characters were flat. The special effects were also poorly executed. I wouldn't recommend it.",
        r"'Remember the Days' was utterly forgettable. The storyline was dull, the performances were bland, and the dialogue was cringeworthy. A big disappointment.",
        r"'The Last Frontier' was simply okay. The plot was decent and the performances were acceptable. However, it lacked a certain spark to make it truly memorable.",
        r"'Through the Storm' was not bad, but it wasn't great either. The storyline was somewhat predictable, and the characters were somewhat stereotypical. It was an average movie at best.",
        r"I found 'After the Rain' to be pretty average. The plot was okay and the performances were decent, but it didn't leave a lasting impression on me.",
        r"'Beyond the Horizon' was neither good nor bad. The plot was interesting enough, but the characters were not very well developed. It was an okay watch.",
        r"'The Silent Echo' was a mediocre movie. The storyline was passable and the performances were fair, but it didn't stand out in any way.",
        r"I thought 'The Scent of Roses' was pretty average. The plot was somewhat engaging, and the performances were okay, but it didn't live up to my expectations.",
        r"'Under the Same Sky' was an okay movie. The plot was decent, and the performances were fine, but it lacked depth and originality. It's not a movie I would watch again.",
        r"'Chasing Shadows' was fairly average. The plot was not bad, and the performances were passable, but it lacked a certain spark. It was just okay.",
        r"'Beneath the Surface' was pretty run-of-the-mill. The plot was decent, the performances were okay, but it wasn't particularly memorable. It was an okay movie.",
    ]

    # First ten reviews are positive, next ten negative, last ten neutral.
    y = ["positive"] * 10 + ["negative"] * 10 + ["neutral"] * 10

    return X, y
def get_multilabel_classification_dataset():
    """Return a toy multi-label dataset of e-commerce reviews.

    Returns
    -------
    tuple[list[str], list[list[str]]]
        10 review texts and, for each, the list of aspect labels it mentions.
    """
    # Keep each review next to its labels so the pairing stays readable.
    samples = [
        (
            "The product was of excellent quality, and the packaging was also very good. Highly recommend!",
            ["Quality", "Packaging"],
        ),
        (
            "The delivery was super fast, but the product did not match the information provided on the website.",
            ["Delivery", "Product Information"],
        ),
        (
            "Great variety of products, but the customer support was quite unresponsive.",
            ["Product Variety", "Customer Support"],
        ),
        (
            "Affordable prices and an easy-to-use website. A great shopping experience overall.",
            ["Price", "User Experience"],
        ),
        (
            "The delivery was delayed, and the packaging was damaged. Not a good experience.",
            ["Delivery", "Packaging"],
        ),
        (
            "Excellent customer support, but the return policy is quite complicated.",
            ["Customer Support", "Return Policy"],
        ),
        (
            "The product was not as described. However, the return process was easy and quick.",
            ["Product Information", "Return Policy"],
        ),
        (
            "Great service and fast delivery. The product was also of high quality.",
            ["Service", "Delivery", "Quality"],
        ),
        (
            "The prices are a bit high. However, the product quality and user experience are worth it.",
            ["Price", "Quality", "User Experience"],
        ),
        (
            "The website provides detailed information about products. The delivery was also very fast.",
            ["Product Information", "Delivery"],
        ),
    ]

    X = [text for text, _ in samples]
    y = [labels for _, labels in samples]
    return X, y
def get_summarization_dataset():
    """Return a toy summarization dataset.

    Returns
    -------
    list[str]
        10 short news-style paragraphs to be summarized.
    """
    texts = [
        "The AI research company, OpenAI, has launched a new language model called GPT-4. This model is the latest in a series of transformer-based AI systems designed to perform complex tasks, such as generating human-like text, translating languages, and answering questions. According to OpenAI, GPT-4 is even more powerful and versatile than its predecessors.",
        "John went to the grocery store in the morning to prepare for a small get-together at his house. He bought fresh apples, juicy oranges, and a bottle of milk. Once back home, he used the apples and oranges to make a delicious fruit salad, which he served to his guests in the evening.",
        "The first Mars rover, named Sojourner, was launched by NASA in 1996. The mission was a part of the Mars Pathfinder project and was a major success. The data Sojourner provided about Martian terrain and atmosphere greatly contributed to our understanding of the Red Planet.",
        "A new study suggests that regular exercise can improve memory and cognitive function in older adults. The study, which monitored the health and habits of 500 older adults, recommends 30 minutes of moderate exercise daily for the best results.",
        "The Eiffel Tower, a globally recognized symbol of Paris and France, was completed in 1889 for the World's Fair. Despite its initial criticism and controversy over its unconventional design, the Eiffel Tower has become a beloved landmark and a symbol of French architectural innovation.",
        "Microsoft has announced a new version of its flagship operating system, Windows. The update, which will be rolled out later this year, features improved security protocols and a redesigned user interface, promising a more streamlined and safer user experience.",
        "The WHO declared a new public health emergency due to an outbreak of a previously unknown virus. As the number of cases grows globally, the organization urges nations to ramp up their disease surveillance and response systems.",
        "The 2024 Olympics have been confirmed to take place in Paris, France. This marks the third time the city will host the games, with previous occasions in 1900 and 1924. Preparations are already underway to make the event a grand spectacle.",
        "Apple has introduced its latest iPhone model. The new device boasts a range of features, including an improved camera system, a faster processor, and a longer battery life. It is set to hit the market later this year.",
        "Scientists working in the Amazon rainforest have discovered a new species of bird. The bird, characterized by its unique bright plumage, has created excitement among the global ornithological community.",
    ]
    return texts
def get_translation_dataset():
    """Return a toy translation dataset.

    Returns
    -------
    list[str]
        6 short sentences in Spanish, French, German, Portuguese, Czech and
        Dutch (one per language) to be translated.
    """
    sentences = [
        "Me encanta bailar salsa y bachata. Es una forma divertida de expresarme.",
        "J'ai passé mes dernières vacances en Grèce. Les plages étaient magnifiques.",
        "Ich habe gestern ein tolles Buch gelesen. Die Geschichte war fesselnd bis zum Ende.",
        "Gosto de cozinhar pratos tradicionais italianos. O espaguete à carbonara é um dos meus favoritos.",
        "Mám v plánu letos v létě vyrazit na výlet do Itálie. Doufám, že navštívím Řím a Benátky.",
        "Mijn favoriete hobby is fotograferen. Ik hou ervan om mooie momenten vast te leggen.",
    ]
    return sentences
@retry(max_retries=3)
def get_chat_completion(
    messages: List[Dict],
    key: str,
    model: str = "claude-3-haiku-20240307",
    max_tokens: int = 1000,
    temperature: float = 0.0,
    system: Optional[str] = None,
    json_response: bool = False,
) -> dict:
    """
    Gets a chat completion from the Anthropic Claude API using the Messages API.

    Parameters
    ----------
    messages : List[Dict]
        Input messages to use; each dict may carry "content" and optionally
        "role" (defaults to "user").
    key : str
        The Anthropic API key to use.
    model : str, optional
        The Claude model to use.
    max_tokens : int, optional
        Maximum tokens to generate.
    temperature : float, optional
        Sampling temperature.
    system : str, optional
        System message to set the assistant's behavior.
    json_response : bool, optional
        Whether to request a JSON-formatted response. Defaults to False.

    Returns
    -------
    response : dict
        The completion response from the API.

    Raises
    ------
    TypeError
        If `messages` is not a list.
    ValueError
        If `messages` is empty.
    """
    # FIX: validate the type before emptiness — the old order let any empty
    # non-list (e.g. {}) raise ValueError instead of TypeError, and any
    # non-empty non-list sequence slip through to the formatting step.
    if not isinstance(messages, list):
        raise TypeError("Messages must be a list")
    if not messages:
        raise ValueError("Messages list cannot be empty")

    client = set_credentials(key)

    if json_response and system:
        system = f"{system.rstrip('.')}. Respond in JSON format."
    elif json_response:
        system = "Respond in JSON format."

    # FIX: keep each message's own role instead of hard-coding "user".
    # Messages without a role keep the previous behavior ("user").
    formatted_messages = [
        {
            "role": message.get("role", "user"),
            "content": [
                {
                    "type": "text",
                    "text": message.get("content", ""),
                }
            ],
        }
        for message in messages
    ]

    request: Dict = {
        "model": model,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "messages": formatted_messages,
    }
    # NOTE(review): only forward `system` when provided — the Anthropic SDK
    # distinguishes "not given" from an explicit None; previously None was
    # always passed through. Confirm against the installed SDK version.
    if system is not None:
        request["system"] = system

    response = client.messages.create(**request)
    return response
class ClaudeMixin:
    """A mixin class that provides Claude API key to other classes."""

    # Subclasses flip this to True when they want Claude to answer in JSON
    # (see ClaudeClassifierMixin below).
    _prefer_json_output = False

    def _set_keys(self, key: Optional[str] = None) -> None:
        """Set the Claude API key.

        A value of None means "fall back to the global SKLLMConfig key"
        at call time (see _get_claude_key).
        """
        self.key = key

    def _get_claude_key(self) -> str:
        """Get the Claude key from the class or config file.

        Raises
        ------
        RuntimeError
            If neither the instance nor the global config holds a key.
        """
        key = self.key
        if key is None:
            # Fall back to the key stored via SKLLMConfig.set_anthropic_key.
            key = _Config.get_anthropic_key()
        if key is None:
            raise RuntimeError("Claude API key was not found")
        return key

class ClaudeTextCompletionMixin(ClaudeMixin, BaseTextCompletionMixin):
    """A mixin class that provides text completion capabilities using the Claude API."""

    def _get_chat_completion(
        self,
        model: str,
        messages: Union[str, List[Dict[str, str]]],
        system_message: Optional[str] = None,
        **kwargs: Any,
    ):
        """Gets a chat completion from the Anthropic API.

        Parameters
        ----------
        model : str
            The model to use.
        messages : Union[str, List[Dict[str, str]]]
            input messages to use.
        system_message : Optional[str]
            A system message to use.
        **kwargs : Any
            placeholder.

        Returns
        -------
        completion : dict
        """
        # Normalize to a list of {"content": ...} dicts; any "role" keys in
        # the input are intentionally dropped here.
        if isinstance(messages, str):
            messages = [{"content": messages}]
        elif isinstance(messages, list):
            messages = [{"content": msg["content"]} for msg in messages]

        completion = get_chat_completion(
            messages=messages,
            key=self._get_claude_key(),
            model=model,
            system=system_message,
            json_response=self._prefer_json_output,
            **kwargs,
        )
        return completion

    def _convert_completion_to_str(self, completion: Mapping[str, Any]):
        """Converts Claude API completion to string.

        Handles both SDK response objects (attribute access) and plain
        dicts; returns "" on any extraction failure (best-effort).
        """
        try:
            if hasattr(completion, 'content'):
                return completion.content[0].text
            return completion.get('content', [{}])[0].get('text', '')
        except Exception as e:
            # Deliberate best-effort: report and degrade to empty string.
            print(f"Error converting completion to string: {str(e)}")
            return ""

class ClaudeClassifierMixin(ClaudeTextCompletionMixin, BaseClassifierMixin):
    """A mixin class that provides classification capabilities using Claude API."""

    # Ask the model for JSON so the label can be parsed out reliably.
    _prefer_json_output = True

    def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
        """Extracts the label from a Claude API completion.

        Returns the raw stripped text when JSON output is not preferred,
        otherwise the "label" value of the parsed JSON body; "" if the
        body is not valid JSON or has no "label" key.
        """
        try:
            content = self._convert_completion_to_str(completion)
            if not self._prefer_json_output:
                return content.strip()

            # Attempt to parse content as JSON and extract label
            try:
                data = json.loads(content)
                if "label" in data:
                    return data["label"]
            except json.JSONDecodeError:
                pass
            return ""

        except Exception as e:
            # Best-effort: classification falls back to an empty label.
            print(f"Error extracting label: {str(e)}")
            return ""
class BaseTextCompletionMixin(ABC):
    """Interface for backends that can produce chat completions."""

    @abstractmethod
    def _get_chat_completion(self, **kwargs):
        """Gets a chat completion from the LLM"""
        pass

    @abstractmethod
    def _convert_completion_to_str(self, completion: Any):
        """Converts a completion object to a string"""
        pass


class BaseClassifierMixin(BaseTextCompletionMixin):
    """Completion backend that can additionally extract a class label."""

    @abstractmethod
    def _extract_out_label(self, completion: Any) -> str:
        """Extracts the label from a completion"""
        pass


class BaseEmbeddingMixin(ABC):
    """Interface for backends that can produce embeddings."""

    @abstractmethod
    def _get_embeddings(self, **kwargs):
        """Gets embeddings from the LLM"""
        pass


class BaseTunableMixin(ABC):
    """Interface for backends that support fine-tuning."""

    @abstractmethod
    def _tune(self, X: Any, y: Any):
        # Run the backend-specific fine-tuning job on (X, y).
        pass

    @abstractmethod
    def _set_hyperparameters(self, **kwargs):
        # Store backend-specific tuning hyperparameters on the instance.
        pass
class LlamaHandler:
    """Wraps a local llama_cpp model, downloading and hash-verifying the
    GGUF weights on first use."""

    def maybe_download_model(self, model_name: str, download_url: str, sha256: str) -> str:
        """Return the local path of the model file, downloading it if absent.

        Parameters
        ----------
        model_name : str
            File stem for the model (".gguf" is appended).
        download_url : str
            URL to fetch the weights from.
        sha256 : str
            Expected SHA-256 of the downloaded file.
        """
        download_folder = SKLLMConfig.get_gguf_download_path()
        os.makedirs(download_folder, exist_ok=True)
        model_name = model_name + ".gguf"
        model_path = os.path.join(download_folder, model_name)
        if not os.path.exists(model_path):
            print("The model `{0}` is not found locally.".format(model_name))
            self._download_model(model_name, download_folder, download_url, sha256)
        return model_path

    def _download_model(
        self, model_filename: str, model_path: str, url: str, expected_sha256: str
    ) -> None:
        """Stream the file to a temp path, verify its SHA-256, then rename it
        into place (the temp file lives in the target directory, so the final
        os.rename stays on one filesystem).

        Raises
        ------
        ValueError
            On a non-200 HTTP response or a checksum mismatch.
        """
        full_path = os.path.join(model_path, model_filename)
        temp_file = tempfile.NamedTemporaryFile(delete=False, dir=model_path)
        temp_path = temp_file.name
        temp_file.close()

        try:
            # Context manager ensures the streaming connection is closed.
            with requests.get(url, stream=True) as response:
                if response.status_code != 200:
                    raise ValueError(
                        f"Request failed: HTTP {response.status_code} {response.reason}"
                    )

                total_size_in_bytes = int(response.headers.get("content-length", 0))
                block_size = 1024 * 1024 * 4  # 4 MiB chunks

                hasher = hashlib.sha256()

                with (
                    open(temp_path, "wb") as file,
                    tqdm(
                        desc="Downloading {0}: ".format(model_filename),
                        total=total_size_in_bytes,
                        unit="iB",
                        unit_scale=True,
                    ) as progress_bar,
                ):
                    for data in response.iter_content(block_size):
                        file.write(data)
                        hasher.update(data)
                        progress_bar.update(len(data))

            downloaded_sha256 = hasher.hexdigest()
            if downloaded_sha256 != expected_sha256:
                raise ValueError(
                    f"Expected SHA-256 hash {expected_sha256}, but got {downloaded_sha256}"
                )
        except Exception:
            # FIX: previously only the HTTP-error path removed the temp file;
            # a checksum mismatch or mid-download failure leaked it on disk.
            if os.path.exists(temp_path):
                os.remove(temp_path)
            raise

        os.rename(temp_path, full_path)

    def __init__(self, model: str):
        """Load (downloading if needed) the named GGUF model.

        Raises
        ------
        ImportError
            If llama_cpp is not installed.
        ValueError
            If `model` is not in `supported_models`.
        """
        if not _llama_imported:
            raise ImportError(
                "llama_cpp is not installed, try `pip install scikit-llm[llama_cpp]`"
            )
        # Serializes inference calls: llama_cpp contexts are not thread-safe.
        self.lock = threading.Lock()
        if model not in supported_models:
            raise ValueError(f"Model {model} is not supported.")
        download_url = supported_models[model]["download_url"]
        sha256 = supported_models[model]["sha256"]
        n_ctx = supported_models[model]["n_ctx"]
        self.supports_system_message = supported_models[model][
            "supports_system_message"
        ]
        if not self.supports_system_message:
            warn(
                f"The model {model} does not support system messages. This may cause issues with some estimators."
            )
        # Include a hash prefix in the filename so a changed checksum in
        # `supported_models` forces a re-download.
        extended_model_name = model + "-" + sha256[:8]
        model_path = self.maybe_download_model(
            extended_model_name, download_url, sha256
        )
        max_gpu_layers = SKLLMConfig.get_gguf_max_gpu_layers()
        verbose = SKLLMConfig.get_gguf_verbose()
        self.model = _Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            verbose=verbose,
            n_gpu_layers=max_gpu_layers,
        )

    def get_chat_completion(self, messages: dict, **kwargs):
        """Run a (deterministic, temperature 0) chat completion.

        System messages are dropped for models that do not support them.
        """
        if not self.supports_system_message:
            messages = [m for m in messages if m["role"] != "system"]
        with self.lock:
            return self.model.create_chat_completion(
                messages, temperature=0.0, **kwargs
            )
@retry(max_retries=3)
def get_chat_completion(
    messages: dict,
    key: str,
    org: str,
    model: str = "gpt-3.5-turbo",
    api="openai",
    json_response=False,
):
    """Gets a chat completion from the OpenAI API.

    Parameters
    ----------
    messages : dict
        input messages to use.
    key : str
        The OPEN AI key to use.
    org : str
        The OPEN AI organization ID to use.
    model : str, optional
        The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
    api : str, optional
        The API to use. Must be one of "openai", "custom_url" or "azure".
        Defaults to "openai".
    json_response : bool, optional
        Whether to request JSON-mode output (only honored for the native
        "openai" API). Defaults to False.

    Returns
    -------
    completion : dict
    """
    if api in ("openai", "custom_url"):
        client = set_credentials(key, org)
    elif api == "azure":
        client = set_azure_credentials(key, org)
    else:
        # FIX: name the offending value instead of a bare "Invalid API".
        raise ValueError(
            f"Invalid API: {api!r}. Must be one of 'openai', 'custom_url', 'azure'."
        )
    model_dict = {"model": model}
    # JSON mode is only supported by the native OpenAI endpoint.
    if json_response and api == "openai":
        model_dict["response_format"] = {"type": "json_object"}
    completion = client.chat.completions.create(
        temperature=0.0, messages=messages, **model_dict
    )
    return completion
@retry(max_retries=3)
def get_embedding(
    text: str,
    key: str,
    org: str,
    model: str = "text-embedding-ada-002",
    api: str = "openai"
):
    """
    Encodes a string and return the embedding for a string.

    Parameters
    ----------
    text : str
        The string (or iterable of strings) to encode.
    key : str
        The OPEN AI key to use.
    org : str
        The OPEN AI organization ID to use.
    model : str, optional
        The model to use. Defaults to "text-embedding-ada-002".
    api: str, optional
        The API to use. Must be one of "openai", "custom_url" or "azure".
        Defaults to "openai".

    Returns
    -------
    emb : list
        The GPT embedding for the string.

    Raises
    ------
    ValueError
        If `api` is not recognized, or an embedding is not a list.
    """
    if api in ("openai", "custom_url"):
        client = set_credentials(key, org)
    elif api == "azure":
        client = set_azure_credentials(key, org)
    else:
        # FIX: an unknown api previously fell through and crashed later
        # with a NameError on `client`.
        raise ValueError(
            f"Invalid API: {api!r}. Must be one of 'openai', 'custom_url', 'azure'."
        )
    # Newlines can degrade embedding quality; flatten them out.
    text = [str(t).replace("\n", " ") for t in text]
    embeddings = []
    emb = client.embeddings.create(input=text, model=model)
    for item in emb.data:
        e = item.embedding
        if not isinstance(e, list):
            # FIX: report the type of the single embedding (was `type(emb)`,
            # the whole response object, which made the error misleading).
            raise ValueError(
                f"Encountered unknown embedding format. Expected list, got {type(e)}"
            )
        embeddings.append(e)
    return embeddings
def create_tuning_job(
    client: Callable,
    model: str,
    training_file: str,
    n_epochs: Optional[int] = None,
    suffix: Optional[str] = None,
):
    """Upload a training file and start an OpenAI fine-tuning job.

    Parameters
    ----------
    client : Callable
        An OpenAI client exposing `files` and `fine_tuning` resources.
    model : str
        Base model to fine-tune.
    training_file : str
        Path to the local JSONL training file; it is deleted after upload.
    n_epochs : Optional[int]
        Number of training epochs (API default when None).
    suffix : Optional[str]
        Suffix for the fine-tuned model's name.

    Returns
    -------
    The created fine-tuning job object.
    """
    # FIX: close the upload handle (it was previously opened inline and leaked).
    with open(training_file, "rb") as f:
        out = client.files.create(file=f, purpose="fine-tune")
    out_id = out.id
    print(f"Created new file. FILE_ID = {out_id}")
    print("Waiting for file to be processed ...")
    while not wait_file_ready(client, out_id):
        sleep(5)
    # delete the training_file after it is uploaded
    os.remove(training_file)
    params = {
        "model": model,
        "training_file": out_id,
    }
    if n_epochs is not None:
        params["hyperparameters"] = {"n_epochs": n_epochs}
    if suffix is not None:
        params["suffix"] = suffix
    return client.fine_tuning.jobs.create(**params)
Current status: {status}" 46 | ) 47 | sleep(check_interval) 48 | 49 | 50 | def delete_file(client: Callable, file_id: str): 51 | client.files.delete(file_id) 52 | 53 | 54 | def wait_file_ready(client: Callable, file_id): 55 | files = client.files.list().data 56 | found = False 57 | for file in files: 58 | if file.id == file_id: 59 | found = True 60 | if file.status == "processed": 61 | return True 62 | elif file.status in ["error", "deleting", "deleted"]: 63 | print(file) 64 | raise RuntimeError( 65 | f"File upload {file_id} failed with status {file.status}" 66 | ) 67 | else: 68 | return False 69 | if not found: 70 | raise RuntimeError(f"File {file_id} not found") 71 | -------------------------------------------------------------------------------- /skllm/llm/gpt/completion.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from skllm.llm.gpt.clients.openai.completion import ( 3 | get_chat_completion as _oai_get_chat_completion, 4 | ) 5 | from skllm.llm.gpt.clients.llama_cpp.completion import ( 6 | get_chat_completion as _llamacpp_get_chat_completion, 7 | ) 8 | from skllm.llm.gpt.utils import split_to_api_and_model 9 | from skllm.config import SKLLMConfig as _Config 10 | 11 | 12 | def get_chat_completion( 13 | messages: dict, 14 | openai_key: str = None, 15 | openai_org: str = None, 16 | model: str = "gpt-3.5-turbo", 17 | json_response: bool = False, 18 | ): 19 | """Gets a chat completion from the OpenAI compatible API.""" 20 | api, model = split_to_api_and_model(model) 21 | if api == "gguf": 22 | return _llamacpp_get_chat_completion(messages, model) 23 | else: 24 | url = _Config.get_gpt_url() 25 | if api == "openai" and url is not None: 26 | warnings.warn( 27 | f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`." 
28 | ) 29 | elif api == "custom_url" and url is None: 30 | raise ValueError( 31 | "You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url()`." 32 | ) 33 | return _oai_get_chat_completion( 34 | messages, 35 | openai_key, 36 | openai_org, 37 | model, 38 | api=api, 39 | json_response=json_response, 40 | ) 41 | -------------------------------------------------------------------------------- /skllm/llm/gpt/embedding.py: -------------------------------------------------------------------------------- 1 | from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding 2 | from skllm.llm.gpt.utils import split_to_api_and_model 3 | 4 | def get_embedding( 5 | text: str, 6 | key: str, 7 | org: str, 8 | model: str = "text-embedding-ada-002", 9 | ): 10 | """ 11 | Encodes a string and return the embedding for a string. 12 | 13 | Parameters 14 | ---------- 15 | text : str 16 | The string to encode. 17 | key : str 18 | The OPEN AI key to use. 19 | org : str 20 | The OPEN AI organization ID to use. 21 | model : str, optional 22 | The model to use. Defaults to "text-embedding-ada-002". 23 | 24 | Returns 25 | ------- 26 | emb : list 27 | The GPT embedding for the string. 
28 | """ 29 | api, model = split_to_api_and_model(model) 30 | if api == ("gpt4all"): 31 | raise ValueError("GPT4All is not supported for embeddings") 32 | return _oai_get_embedding(text, key, org, model, api=api) 33 | -------------------------------------------------------------------------------- /skllm/llm/gpt/mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List, Any, Dict, Mapping 2 | from skllm.config import SKLLMConfig as _Config 3 | from skllm.llm.gpt.completion import get_chat_completion 4 | from skllm.llm.gpt.embedding import get_embedding 5 | from skllm.llm.base import ( 6 | BaseClassifierMixin, 7 | BaseEmbeddingMixin, 8 | BaseTextCompletionMixin, 9 | BaseTunableMixin, 10 | ) 11 | from skllm.llm.gpt.clients.openai.tuning import ( 12 | create_tuning_job, 13 | await_results, 14 | delete_file, 15 | ) 16 | import uuid 17 | from skllm.llm.gpt.clients.openai.credentials import ( 18 | set_credentials as _set_credentials_openai, 19 | ) 20 | from skllm.utils import extract_json_key 21 | import numpy as np 22 | from tqdm import tqdm 23 | import json 24 | 25 | 26 | def construct_message(role: str, content: str) -> dict: 27 | """Constructs a message for the OpenAI API. 28 | 29 | Parameters 30 | ---------- 31 | role : str 32 | The role of the message. Must be one of "system", "user", or "assistant". 33 | content : str 34 | The content of the message. 35 | 36 | Returns 37 | ------- 38 | message : dict 39 | """ 40 | if role not in ("system", "user", "assistant"): 41 | raise ValueError("Invalid role") 42 | return {"role": role, "content": content} 43 | 44 | 45 | def _build_clf_example( 46 | x: str, y: str, system_msg="You are a text classification model." 
47 | ): 48 | sample = { 49 | "messages": [ 50 | {"role": "system", "content": system_msg}, 51 | {"role": "user", "content": x}, 52 | {"role": "assistant", "content": y}, 53 | ] 54 | } 55 | return json.dumps(sample) 56 | 57 | 58 | class GPTMixin: 59 | """ 60 | A mixin class that provides OpenAI key and organization to other classes. 61 | """ 62 | 63 | _prefer_json_output = False 64 | 65 | def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> None: 66 | """ 67 | Set the OpenAI key and organization. 68 | """ 69 | 70 | self.key = key 71 | self.org = org 72 | 73 | def _get_openai_key(self) -> str: 74 | """ 75 | Get the OpenAI key from the class or the config file. 76 | 77 | Returns 78 | ------- 79 | key: str 80 | """ 81 | key = self.key 82 | if key is None: 83 | key = _Config.get_openai_key() 84 | if ( 85 | hasattr(self, "model") 86 | and isinstance(self.model, str) 87 | and self.model.startswith("gguf::") 88 | ): 89 | key = "gguf" 90 | if key is None: 91 | raise RuntimeError("OpenAI key was not found") 92 | return key 93 | 94 | def _get_openai_org(self) -> str: 95 | """ 96 | Get the OpenAI organization ID from the class or the config file. 97 | 98 | Returns 99 | ------- 100 | org: str 101 | """ 102 | org = self.org 103 | if org is None: 104 | org = _Config.get_openai_org() 105 | if org is None: 106 | raise RuntimeError("OpenAI organization was not found") 107 | return org 108 | 109 | 110 | class GPTTextCompletionMixin(GPTMixin, BaseTextCompletionMixin): 111 | def _get_chat_completion( 112 | self, 113 | model: str, 114 | messages: Union[str, List[Dict[str, str]]], 115 | system_message: Optional[str] = None, 116 | **kwargs: Any, 117 | ): 118 | """Gets a chat completion from the OpenAI API. 119 | 120 | Parameters 121 | ---------- 122 | model : str 123 | The model to use. 124 | messages : Union[str, List[Dict[str, str]]] 125 | input messages to use. 126 | system_message : Optional[str] 127 | A system message to use. 
128 | **kwargs : Any 129 | placeholder. 130 | 131 | Returns 132 | ------- 133 | completion : dict 134 | """ 135 | msgs = [] 136 | if system_message is not None: 137 | msgs.append(construct_message("system", system_message)) 138 | if isinstance(messages, str): 139 | msgs.append(construct_message("user", messages)) 140 | else: 141 | for message in messages: 142 | msgs.append(construct_message(message["role"], message["content"])) 143 | completion = get_chat_completion( 144 | msgs, 145 | self._get_openai_key(), 146 | self._get_openai_org(), 147 | model, 148 | json_response=self._prefer_json_output, 149 | ) 150 | return completion 151 | 152 | def _convert_completion_to_str(self, completion: Mapping[str, Any]): 153 | if hasattr(completion, "__getitem__"): 154 | return str(completion["choices"][0]["message"]["content"]) 155 | return str(completion.choices[0].message.content) 156 | 157 | 158 | class GPTClassifierMixin(GPTTextCompletionMixin, BaseClassifierMixin): 159 | _prefer_json_output = True 160 | 161 | def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> Any: 162 | """Extracts the label from a completion. 163 | 164 | Parameters 165 | ---------- 166 | label : Mapping[str, Any] 167 | The label to extract. 168 | 169 | Returns 170 | ------- 171 | label : str 172 | """ 173 | try: 174 | if hasattr(completion, "__getitem__"): 175 | label = extract_json_key( 176 | completion["choices"][0]["message"]["content"], "label" 177 | ) 178 | else: 179 | label = extract_json_key(completion.choices[0].message.content, "label") 180 | except Exception as e: 181 | print(completion) 182 | print(f"Could not extract the label from the completion: {str(e)}") 183 | label = "" 184 | return label 185 | 186 | 187 | class GPTEmbeddingMixin(GPTMixin, BaseEmbeddingMixin): 188 | def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: 189 | """Gets embeddings from the OpenAI compatible API. 190 | 191 | Parameters 192 | ---------- 193 | text : str 194 | The text to embed. 
195 | model : str 196 | The model to use. 197 | batch_size : int, optional 198 | The batch size to use. Defaults to 1. 199 | 200 | Returns 201 | ------- 202 | embedding : List[List[float]] 203 | """ 204 | embeddings = [] 205 | print("Batch size:", self.batch_size) 206 | for i in tqdm(range(0, len(text), self.batch_size)): 207 | batch = text[i : i + self.batch_size].tolist() 208 | embeddings.extend( 209 | get_embedding( 210 | batch, 211 | self._get_openai_key(), 212 | self._get_openai_org(), 213 | self.model, 214 | ) 215 | ) 216 | 217 | return embeddings 218 | 219 | 220 | # for now this works only with OpenAI 221 | class GPTTunableMixin(BaseTunableMixin): 222 | _supported_tunable_models = [ 223 | "gpt-3.5-turbo-0125", 224 | "gpt-3.5-turbo", 225 | "gpt-4o-mini-2024-07-18", 226 | "gpt-4o-mini", 227 | ] 228 | 229 | def _build_label(self, label: str): 230 | return json.dumps({"label": label}) 231 | 232 | def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str): 233 | self.base_model = base_model 234 | self.n_epochs = n_epochs 235 | self.custom_suffix = custom_suffix 236 | if base_model not in self._supported_tunable_models: 237 | raise ValueError( 238 | f"Model {base_model} is not supported. Supported models are" 239 | f" {self._supported_tunable_models}" 240 | ) 241 | 242 | def _tune(self, X, y): 243 | if self.base_model.startswith(("azure::", "gpt4all")): 244 | raise ValueError( 245 | "Tuning is not supported for this model. Please use a different model." 246 | ) 247 | file_uuid = str(uuid.uuid4()) 248 | filename = f"skllm_{file_uuid}.jsonl" 249 | with open(filename, "w+") as f: 250 | for xi, yi in zip(X, y): 251 | prompt = self._get_prompt(xi) 252 | if not isinstance(prompt["messages"], str): 253 | raise ValueError( 254 | "Incompatible prompt. Use a prompt with a single message." 
255 | ) 256 | f.write( 257 | _build_clf_example( 258 | prompt["messages"], 259 | self._build_label(yi), 260 | prompt["system_message"], 261 | ) 262 | ) 263 | f.write("\n") 264 | client = _set_credentials_openai(self._get_openai_key(), self._get_openai_org()) 265 | job = create_tuning_job( 266 | client, 267 | self.base_model, 268 | filename, 269 | self.n_epochs, 270 | self.custom_suffix, 271 | ) 272 | print(f"Created new tuning job. JOB_ID = {job.id}") 273 | job = await_results(client, job.id) 274 | self.openai_model = job.fine_tuned_model 275 | self.model = self.openai_model # openai_model is probably not required anymore 276 | delete_file(client, job.training_file) 277 | print(f"Finished training.") 278 | -------------------------------------------------------------------------------- /skllm/llm/gpt/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | SUPPORTED_APIS = ["openai", "azure", "gguf", "custom_url"] 4 | 5 | 6 | def split_to_api_and_model(model: str) -> Tuple[str, str]: 7 | if "::" not in model: 8 | return "openai", model 9 | for api in SUPPORTED_APIS: 10 | if model.startswith(f"{api}::"): 11 | return api, model[len(api) + 2 :] 12 | raise ValueError(f"Unsupported API: {model.split('::')[0]}") 13 | -------------------------------------------------------------------------------- /skllm/llm/vertex/completion.py: -------------------------------------------------------------------------------- 1 | from skllm.utils import retry 2 | from vertexai.language_models import ChatModel, TextGenerationModel 3 | from vertexai.generative_models import GenerativeModel, GenerationConfig 4 | 5 | 6 | @retry(max_retries=3) 7 | def get_completion(model: str, text: str): 8 | if model.startswith("text-"): 9 | model_instance = TextGenerationModel.from_pretrained(model) 10 | else: 11 | model_instance = TextGenerationModel.get_tuned_model(model) 12 | response = model_instance.predict(text, temperature=0.0) 13 
| return response.text 14 | 15 | 16 | @retry(max_retries=3) 17 | def get_completion_chat_mode(model: str, context: str, text: str): 18 | model_instance = ChatModel.from_pretrained(model) 19 | chat = model_instance.start_chat(context=context) 20 | response = chat.send_message(text, temperature=0.0) 21 | return response.text 22 | 23 | 24 | @retry(max_retries=3) 25 | def get_completion_chat_gemini(model: str, context: str, text: str): 26 | model_instance = GenerativeModel(model, system_instruction=context) 27 | response = model_instance.generate_content( 28 | text, generation_config=GenerationConfig(temperature=0.0) 29 | ) 30 | return response.text 31 | -------------------------------------------------------------------------------- /skllm/llm/vertex/mixin.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, List, Any, Dict 2 | from skllm.llm.base import ( 3 | BaseClassifierMixin, 4 | BaseEmbeddingMixin, 5 | BaseTextCompletionMixin, 6 | BaseTunableMixin, 7 | ) 8 | from skllm.llm.vertex.tuning import tune 9 | from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion, get_completion_chat_gemini 10 | from skllm.utils import extract_json_key 11 | import numpy as np 12 | import pandas as pd 13 | 14 | 15 | class VertexMixin: 16 | pass 17 | 18 | 19 | class VertexTextCompletionMixin(BaseTextCompletionMixin): 20 | def _get_chat_completion( 21 | self, 22 | model: str, 23 | messages: Union[str, List[Dict[str, str]]], 24 | system_message: Optional[str], 25 | examples: Optional[List] = None, 26 | ) -> str: 27 | if examples is not None: 28 | raise NotImplementedError( 29 | "Examples API is not yet supported for Vertex AI." 
30 | ) 31 | if not isinstance(messages, str): 32 | raise ValueError("Only messages as strings are supported.") 33 | if model.startswith("chat-"): 34 | completion = get_completion_chat_mode(model, system_message, messages) 35 | elif model.startswith("gemini-"): 36 | completion = get_completion_chat_gemini(model, system_message, messages) 37 | else: 38 | completion = get_completion(model, messages) 39 | return str(completion) 40 | 41 | def _convert_completion_to_str(self, completion: str) -> str: 42 | return completion 43 | 44 | 45 | class VertexClassifierMixin(BaseClassifierMixin, VertexTextCompletionMixin): 46 | def _extract_out_label(self, completion: str, **kwargs) -> Any: 47 | """Extracts the label from a completion. 48 | 49 | Parameters 50 | ---------- 51 | label : Mapping[str, Any] 52 | The label to extract. 53 | 54 | Returns 55 | ------- 56 | label : str 57 | """ 58 | try: 59 | label = extract_json_key(str(completion), "label") 60 | except Exception as e: 61 | print(completion) 62 | print(f"Could not extract the label from the completion: {str(e)}") 63 | label = "" 64 | return label 65 | 66 | 67 | class VertexEmbeddingMixin(BaseEmbeddingMixin): 68 | def _get_embeddings(self, text: np.ndarray) -> List[List[float]]: 69 | raise NotImplementedError("Embeddings are not yet supported for Vertex AI.") 70 | 71 | 72 | class VertexTunableMixin(BaseTunableMixin): 73 | _supported_tunable_models = ["text-bison@002"] 74 | 75 | def _set_hyperparameters(self, base_model: str, n_update_steps: int, **kwargs): 76 | self.verify_model_is_supported(base_model) 77 | self.base_model = base_model 78 | self.n_update_steps = n_update_steps 79 | 80 | def verify_model_is_supported(self, model: str): 81 | if model not in self._supported_tunable_models: 82 | raise ValueError( 83 | f"Model {model} is not supported. 
Supported models are" 84 | f" {self._supported_tunable_models}" 85 | ) 86 | 87 | def _tune(self, X: Any, y: Any): 88 | df = pd.DataFrame({"input_text": X, "output_text": y}) 89 | job = tune(self.base_model, df, self.n_update_steps)._job 90 | tuned_model = job.result() 91 | self.tuned_model_ = tuned_model._model_resource_name 92 | self.model = tuned_model._model_resource_name 93 | return self 94 | -------------------------------------------------------------------------------- /skllm/llm/vertex/tuning.py: -------------------------------------------------------------------------------- 1 | from pandas import DataFrame 2 | from vertexai.language_models import TextGenerationModel 3 | 4 | 5 | def tune(model: str, data: DataFrame, train_steps: int = 100): 6 | model = TextGenerationModel.from_pretrained(model) 7 | model.tune_model( 8 | training_data=data, 9 | train_steps=train_steps, 10 | tuning_job_location="europe-west4", # the only supported training location atm 11 | tuned_model_location="us-central1", # the only supported deployment location atm 12 | ) 13 | return model 14 | -------------------------------------------------------------------------------- /skllm/memory/__init__.py: -------------------------------------------------------------------------------- 1 | from skllm.memory._sklearn_nn import SklearnMemoryIndex 2 | -------------------------------------------------------------------------------- /skllm/memory/_annoy.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from typing import Any, List 4 | 5 | try: 6 | from annoy import AnnoyIndex 7 | except (ImportError, ModuleNotFoundError): 8 | AnnoyIndex = None 9 | 10 | from numpy import ndarray 11 | 12 | from skllm.memory.base import _BaseMemoryIndex 13 | 14 | 15 | class AnnoyMemoryIndex(_BaseMemoryIndex): 16 | """Memory index using Annoy. 
17 | 18 | Parameters 19 | ---------- 20 | dim : int 21 | dimensionality of the vectors 22 | metric : str, optional 23 | metric to use, by default "euclidean" 24 | """ 25 | 26 | def __init__(self, dim: int = -1, metric: str = "euclidean", **kwargs: Any) -> None: 27 | if AnnoyIndex is None: 28 | raise ImportError( 29 | "Annoy is not installed. Please install annoy by running `pip install" 30 | " scikit-llm[annoy]`." 31 | ) 32 | self.metric = metric 33 | self.dim = dim 34 | self.built = False 35 | self._index = None 36 | self._counter = 0 37 | 38 | def add(self, vector: ndarray) -> None: 39 | """Adds a vector to the index. 40 | 41 | Parameters 42 | ---------- 43 | vector : ndarray 44 | vector to add to the index 45 | """ 46 | if self.built: 47 | raise RuntimeError("Cannot add vectors after index is built.") 48 | if self.dim < 0: 49 | raise ValueError("Dimensionality must be positive.") 50 | if not self._index: 51 | self._index = AnnoyIndex(self.dim, self.metric) 52 | id = self._counter 53 | self._index.add_item(id, vector) 54 | self._counter += 1 55 | 56 | def build(self) -> None: 57 | """Builds the index. 58 | 59 | No new vectors can be added after building. 60 | """ 61 | if self.dim < 0: 62 | raise ValueError("Dimensionality must be positive.") 63 | self._index.build(-1) 64 | self.built = True 65 | 66 | def retrieve(self, vectors: ndarray, k: int) -> List[List[int]]: 67 | """Retrieves the k nearest neighbors for each vector. 
68 | 69 | Parameters 70 | ---------- 71 | vectors : ndarray 72 | vectors to retrieve nearest neighbors for 73 | k : int 74 | number of nearest neighbors to retrieve 75 | 76 | Returns 77 | ------- 78 | List 79 | ids of retrieved nearest neighbors 80 | """ 81 | if not self.built: 82 | raise RuntimeError("Cannot retrieve vectors before the index is built.") 83 | return [ 84 | self._index.get_nns_by_vector(v, k, search_k=-1, include_distances=False) 85 | for v in vectors 86 | ] 87 | 88 | def __getstate__(self) -> dict: 89 | """Returns the state of the object. To store the actual annoy index, it 90 | has to be written to a temporary file. 91 | 92 | Returns 93 | ------- 94 | dict 95 | state of the object 96 | """ 97 | state = self.__dict__.copy() 98 | 99 | # save index to temporary file 100 | with tempfile.NamedTemporaryFile(delete=False) as tmp: 101 | temp_filename = tmp.name 102 | self._index.save(temp_filename) 103 | 104 | # read bytes from the file 105 | with open(temp_filename, "rb") as tmp: 106 | index_bytes = tmp.read() 107 | 108 | # store bytes representation in state 109 | state["_index"] = index_bytes 110 | 111 | # remove temporary file 112 | os.remove(temp_filename) 113 | 114 | return state 115 | 116 | def __setstate__(self, state: dict) -> None: 117 | """Sets the state of the object. It restores the annoy index from the 118 | bytes representation. 
119 | 120 | Parameters 121 | ---------- 122 | state : dict 123 | state of the object 124 | """ 125 | self.__dict__.update(state) 126 | # restore index from bytes 127 | with tempfile.NamedTemporaryFile(delete=False) as tmp: 128 | temp_filename = tmp.name 129 | tmp.write(self._index) 130 | 131 | self._index = AnnoyIndex(self.dim, self.metric) 132 | self._index.load(temp_filename) 133 | 134 | # remove temporary file 135 | os.remove(temp_filename) 136 | -------------------------------------------------------------------------------- /skllm/memory/_sklearn_nn.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import numpy as np 4 | from sklearn.neighbors import NearestNeighbors 5 | 6 | from skllm.memory.base import _BaseMemoryIndex 7 | 8 | 9 | class SklearnMemoryIndex(_BaseMemoryIndex): 10 | """Memory index using Sklearn's NearestNeighbors. 11 | 12 | Parameters 13 | ---------- 14 | dim : int 15 | dimensionality of the vectors 16 | metric : str, optional 17 | metric to use, by default "euclidean" 18 | """ 19 | 20 | def __init__(self, dim: int = -1, metric: str = "euclidean", **kwargs: Any) -> None: 21 | self._index = NearestNeighbors(metric=metric, **kwargs) 22 | self.metric = metric 23 | self.dim = dim 24 | self.built = False 25 | self.data = [] 26 | 27 | def add(self, vector: np.ndarray) -> None: 28 | """Adds a vector to the index. 29 | 30 | Parameters 31 | ---------- 32 | vector : np.ndarray 33 | vector to add to the index 34 | """ 35 | if self.built: 36 | raise RuntimeError("Cannot add vectors after index is built.") 37 | self.data.append(vector) 38 | 39 | def build(self) -> None: 40 | """Builds the index. 41 | 42 | No new vectors can be added after building. 
43 | """ 44 | data_matrix = np.array(self.data) 45 | self._index.fit(data_matrix) 46 | self.built = True 47 | 48 | def retrieve(self, vectors: np.ndarray, k: int) -> List[List[int]]: 49 | """Retrieves the k nearest neighbors for each vector. 50 | 51 | Parameters 52 | ---------- 53 | vectors : np.ndarray 54 | vectors to retrieve nearest neighbors for 55 | k : int 56 | number of nearest neighbors to retrieve 57 | 58 | Returns 59 | ------- 60 | List 61 | ids of retrieved nearest neighbors 62 | """ 63 | if not self.built: 64 | raise RuntimeError("Cannot retrieve vectors before the index is built.") 65 | _, indices = self._index.kneighbors(vectors, n_neighbors=k) 66 | return indices.tolist() 67 | -------------------------------------------------------------------------------- /skllm/memory/base.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, List, Type 3 | 4 | from numpy import ndarray 5 | 6 | 7 | class _BaseMemoryIndex(ABC): 8 | @abstractmethod 9 | def add(self, id: Any, vector: ndarray): 10 | """Adds a vector to the index. 11 | 12 | Parameters 13 | ---------- 14 | id : Any 15 | identifier for the vector 16 | vector : ndarray 17 | vector to add to the index 18 | """ 19 | pass 20 | 21 | @abstractmethod 22 | def retrieve(self, vectors: ndarray, k: int) -> List: 23 | """Retrieves the k nearest neighbors for each vector. 24 | 25 | Parameters 26 | ---------- 27 | vectors : ndarray 28 | vectors to retrieve nearest neighbors for 29 | k : int 30 | number of nearest neighbors to retrieve 31 | 32 | Returns 33 | ------- 34 | List 35 | ids of retrieved nearest neighbors 36 | """ 37 | pass 38 | 39 | @abstractmethod 40 | def build(self) -> None: 41 | """Builds the index. 42 | 43 | All build parameters should be passed to the constructor. 
44 | """ 45 | pass 46 | 47 | 48 | class IndexConstructor: 49 | def __init__(self, index: Type[_BaseMemoryIndex], **kwargs: Any) -> None: 50 | self.index = index 51 | self.kwargs = kwargs 52 | 53 | def __call__(self) -> _BaseMemoryIndex: 54 | return self.index(**self.kwargs) 55 | -------------------------------------------------------------------------------- /skllm/models/_base/classifier.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List, Optional, Union 2 | from abc import ABC, abstractmethod 3 | from sklearn.base import ( 4 | BaseEstimator as _SklBaseEstimator, 5 | ClassifierMixin as _SklClassifierMixin, 6 | ) 7 | import warnings 8 | import numpy as np 9 | import pandas as pd 10 | from tqdm import tqdm 11 | from concurrent.futures import ThreadPoolExecutor 12 | import random 13 | from collections import Counter 14 | from skllm.llm.base import ( 15 | BaseClassifierMixin as _BaseClassifierMixin, 16 | BaseTunableMixin as _BaseTunableMixin, 17 | ) 18 | from skllm.utils import to_numpy as _to_numpy 19 | from skllm.prompts.templates import ( 20 | ZERO_SHOT_CLF_PROMPT_TEMPLATE, 21 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE, 22 | FEW_SHOT_CLF_PROMPT_TEMPLATE, 23 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE, 24 | ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE, 25 | ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE, 26 | COT_CLF_PROMPT_TEMPLATE, 27 | COT_MLCLF_PROMPT_TEMPLATE, 28 | ) 29 | from skllm.prompts.builders import ( 30 | build_zero_shot_prompt_slc, 31 | build_zero_shot_prompt_mlc, 32 | build_few_shot_prompt_slc, 33 | build_few_shot_prompt_mlc, 34 | ) 35 | from skllm.memory.base import IndexConstructor 36 | from skllm.memory._sklearn_nn import SklearnMemoryIndex 37 | from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer 38 | from skllm.utils import re_naive_json_extractor 39 | import json 40 | 41 | _TRAINING_SAMPLE_PROMPT_TEMPLATE = """ 42 | Sample input: 43 | ```{x}``` 44 | s 45 | Sample target: {label} 46 | """ 47 | 
48 | 49 | class SingleLabelMixin: 50 | """Mixin class for single label classification.""" 51 | 52 | def validate_prediction(self, label: Any) -> str: 53 | """ 54 | Validates a prediction. 55 | 56 | Parameters 57 | ---------- 58 | label : str 59 | The label to validate. 60 | 61 | Returns 62 | ------- 63 | str 64 | The validated label. 65 | """ 66 | if label not in self.classes_: 67 | label = str(label).replace("'", "").replace('"', "") 68 | if label not in self.classes_: # try again 69 | label = self._get_default_label() 70 | return label 71 | 72 | def _extract_labels(self, y: Any) -> List[str]: 73 | """ 74 | Return the class labels as a list. 75 | 76 | Parameters 77 | ---------- 78 | y : Any 79 | 80 | Returns 81 | ------- 82 | List[str] 83 | """ 84 | if isinstance(y, (pd.Series, np.ndarray)): 85 | labels = y.tolist() 86 | else: 87 | labels = y 88 | return labels 89 | 90 | 91 | class MultiLabelMixin: 92 | """Mixin class for multi label classification.""" 93 | 94 | def validate_prediction(self, label: Any) -> List[str]: 95 | """ 96 | Validates a prediction. 97 | 98 | Parameters 99 | ---------- 100 | label : Any 101 | The label to validate. 102 | 103 | Returns 104 | ------- 105 | List[str] 106 | The validated label. 107 | """ 108 | if not isinstance(label, list): 109 | label = [] 110 | filtered_labels = [] 111 | for l in label: 112 | if l in self.classes_ and l: 113 | if l not in filtered_labels: 114 | filtered_labels.append(l) 115 | elif la := l.replace("'", "").replace('"', "") in self.classes_: 116 | if la not in filtered_labels: 117 | filtered_labels.append(la) 118 | else: 119 | default_label = self._get_default_label() 120 | if not ( 121 | self.default_label == "Random" and default_label in filtered_labels 122 | ): 123 | filtered_labels.append(default_label) 124 | filtered_labels.extend([""] * self.max_labels) 125 | return filtered_labels[: self.max_labels] 126 | 127 | def _extract_labels(self, y) -> List[str]: 128 | """Extracts the labels into a list. 
# NOTE(review): the `def _extract_labels(...)` line lives just before this
# chunk (upstream it is a mixin method); only its docstring tail and body are
# visible here, so the signature below is a reconstruction — confirm upstream.
def _extract_labels(self, y) -> List[str]:
    """Flatten the targets into a single flat list of labels.

    Elements of ``y`` may be individual labels or lists of labels
    (multi-label targets); one level of nesting is expanded.

    Parameters
    ----------
    y : Any

    Returns
    -------
    List[str]
    """
    flat = []
    for target in y:
        if isinstance(target, list):
            flat.extend(target)
        else:
            flat.append(target)
    return flat


class BaseClassifier(ABC, _SklBaseEstimator, _SklClassifierMixin):
    """Scikit-learn compatible base class for all LLM-backed classifiers."""

    system_msg = "You are a text classifier."

    def __init__(
        self,
        model: Optional[str],  # model can initially be None for tunable estimators
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        **kwargs,
    ):
        # A concrete classifier must also inherit an LLM backend mixin ...
        if not isinstance(self, _BaseClassifierMixin):
            raise TypeError(
                "Classifier must be mixed with a skllm.llm.base.BaseClassifierMixin class"
            )
        # ... and exactly one of the label-mode mixins.
        if not isinstance(self, (SingleLabelMixin, MultiLabelMixin)):
            raise TypeError(
                "Classifier must be mixed with a SingleLabelMixin or MultiLabelMixin class"
            )

        self.model = model
        if not isinstance(default_label, str):
            raise TypeError("default_label must be a string")
        self.default_label = default_label
        # max_labels is only meaningful for multi-label estimators.
        if isinstance(self, MultiLabelMixin):
            if not isinstance(max_labels, int):
                raise TypeError("max_labels must be an integer")
            if max_labels < 2:
                raise ValueError("max_labels must be greater than 1")
            self.max_labels = max_labels
        if prompt_template is not None and not isinstance(prompt_template, str):
            raise TypeError("prompt_template must be a string or None")
        self.prompt_template = prompt_template

    def _predict_single(self, x: Any) -> Any:
        """Classify a single sample and return a validated label."""
        prompt_dict = self._get_prompt(x)
        # _get_chat_completion / _extract_out_label are inherited from the LLM mixin
        raw = self._get_chat_completion(model=self.model, **prompt_dict)
        label = self._extract_out_label(raw)
        # validate_prediction is inherited from the single/multi label mixin
        return self.validate_prediction(label)

    @abstractmethod
    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        pass

    def fit(
        self,
        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
        y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
    ):
        """
        Fits the model to the given data.

        Parameters
        ----------
        X : Optional[Union[np.ndarray, pd.Series, List[str]]]
            Training data
        y : Union[np.ndarray, pd.Series, List[str], List[List[str]]]
            Training labels

        Returns
        -------
        BaseClassifier
            self
        """
        X = _to_numpy(X)
        self.classes_, self.probabilities_ = self._get_unique_targets(y)
        return self

    def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = 1):
        """
        Predicts the class of each input.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data to predict the class of.
        num_workers : int
            number of workers to use for multithreaded prediction, default 1

        Returns
        -------
        np.ndarray
            The predicted classes as a numpy array.
        """
        X = _to_numpy(X)

        if num_workers > 1:
            warnings.warn(
                "Passing num_workers to predict is temporary and will be removed in the future."
            )
            with ThreadPoolExecutor(max_workers=num_workers) as executor:
                predictions = list(
                    tqdm(executor.map(self._predict_single, X), total=len(X))
                )
        else:
            predictions = [self._predict_single(sample) for sample in tqdm(X)]

        return np.array(predictions)

    def _get_unique_targets(self, y: Any):
        """Return the unique classes and their empirical frequencies."""
        labels = self._extract_labels(y)
        counts = Counter(labels)
        total = sum(counts.values())
        classes = list(counts.keys())
        probs = [count / total for count in counts.values()]
        return classes, probs

    def _get_default_label(self):
        """Returns the default label based on the default_label argument."""
        if self.default_label == "Random":
            # sample a label proportionally to the training frequencies
            return random.choices(self.classes_, self.probabilities_)[0]
        return self.default_label


class BaseZeroShotClassifier(BaseClassifier):
    """Zero-shot classifier: classifies from label names alone."""

    def _get_prompt_template(self) -> str:
        """Returns the prompt template to use for a single input."""
        if self.prompt_template is not None:
            return self.prompt_template
        if isinstance(self, SingleLabelMixin):
            return ZERO_SHOT_CLF_PROMPT_TEMPLATE
        return ZERO_SHOT_MLCLF_PROMPT_TEMPLATE

    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        template = self._get_prompt_template()
        if isinstance(self, SingleLabelMixin):
            prompt = build_zero_shot_prompt_slc(x, repr(self.classes_), template=template)
        else:
            prompt = build_zero_shot_prompt_mlc(
                x, repr(self.classes_), self.max_labels, template=template
            )
        return {"messages": prompt, "system_message": self.system_msg}
class BaseCoTClassifier(BaseClassifier):
    """Chain-of-thought classifier: the model returns a JSON object with a
    label and an explanation; predictions are ``[label, explanation]`` pairs."""

    def _get_prompt_template(self) -> str:
        """Returns the prompt template to use for a single input."""
        if self.prompt_template is not None:
            return self.prompt_template
        elif isinstance(self, SingleLabelMixin):
            return COT_CLF_PROMPT_TEMPLATE
        return COT_MLCLF_PROMPT_TEMPLATE

    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        if isinstance(self, SingleLabelMixin):
            prompt = build_zero_shot_prompt_slc(
                x, repr(self.classes_), template=self._get_prompt_template()
            )
        else:
            prompt = build_zero_shot_prompt_mlc(
                x,
                repr(self.classes_),
                self.max_labels,
                template=self._get_prompt_template(),
            )
        return {"messages": prompt, "system_message": self.system_msg}

    def _predict_single(self, x: Any) -> Any:
        """Classify a single sample, returning ``[label, explanation]``."""
        prompt_dict = self._get_prompt(x)
        # this will be inherited from the LLM
        completion = self._get_chat_completion(model=self.model, **prompt_dict)
        completion = self._convert_completion_to_str(completion)
        try:
            as_dict = json.loads(re_naive_json_extractor(completion))
            label = as_dict["label"]
            explanation = str(as_dict["explanation"])
        except Exception:  # fix: the bound exception was never used
            # malformed / non-JSON completion -> fall back to a sentinel
            label = "None"
            explanation = "Explanation is not available."
        # this will be inherited from the sl/ml mixin
        prediction = self.validate_prediction(label)
        return [prediction, explanation]


class BaseFewShotClassifier(BaseClassifier):
    """Few-shot classifier: the full training set is embedded in the prompt."""

    def _get_prompt_template(self) -> str:
        """Returns the prompt template to use for a single input."""
        if self.prompt_template is not None:
            return self.prompt_template
        elif isinstance(self, SingleLabelMixin):
            return FEW_SHOT_CLF_PROMPT_TEMPLATE
        return FEW_SHOT_MLCLF_PROMPT_TEMPLATE

    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        training_data = []
        for xt, yt in zip(*self.training_data_):
            training_data.append(
                _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
            )

        training_data_str = "\n".join(training_data)

        if isinstance(self, SingleLabelMixin):
            prompt = build_few_shot_prompt_slc(
                x,
                repr(self.classes_),
                training_data=training_data_str,
                template=self._get_prompt_template(),
            )
        else:
            prompt = build_few_shot_prompt_mlc(
                x,
                repr(self.classes_),
                training_data=training_data_str,
                max_cats=self.max_labels,
                template=self._get_prompt_template(),
            )
        return {"messages": prompt, "system_message": "You are a text classifier."}

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
    ):
        """
        Fits the model to the given data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            Training data
        y : Union[np.ndarray, pd.Series, List[str]]
            Training labels

        Returns
        -------
        BaseFewShotClassifier
            self

        Raises
        ------
        ValueError
            If X and y have different lengths.
        """
        if len(X) != len(y):  # fix: idiomatic comparison (was `not len(X) == len(y)`)
            raise ValueError("X and y must have the same length.")
        X = _to_numpy(X)
        y = _to_numpy(y)
        self.training_data_ = (X, y)
        self.classes_, self.probabilities_ = self._get_unique_targets(y)
        return self
class BaseDynamicFewShotClassifier(BaseClassifier):
    """Few-shot classifier that, for each query, retrieves the nearest
    training examples of every class from a vector index."""

    def __init__(
        self,
        model: str,
        default_label: str = "Random",
        n_examples: int = 3,
        memory_index: Optional[IndexConstructor] = None,
        vectorizer: Optional[_BaseVectorizer] = None,  # fix: was typed `_BaseVectorizer = None`
        prompt_template: Optional[str] = None,
        metric: str = "euclidean",
    ):
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
        )
        self.vectorizer = vectorizer
        self.memory_index = memory_index
        self.n_examples = n_examples
        self.metric = metric
        if isinstance(self, MultiLabelMixin):
            raise TypeError("Multi-label classification is not supported")

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str]],
    ):
        """
        Fits the model to the given data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            Training data
        y : Union[np.ndarray, pd.Series, List[str]]
            Training labels

        Returns
        -------
        BaseDynamicFewShotClassifier
            self
        """
        if not self.vectorizer:
            raise ValueError("Vectorizer must be set")
        X = _to_numpy(X)
        y = _to_numpy(y)
        self.embedding_model_ = self.vectorizer.fit(X)
        self.classes_, self.probabilities_ = self._get_unique_targets(y)

        # One vector index per class, built over that class's partition of X.
        self.data_ = {}
        for cls in self.classes_:
            print(f"Building index for class `{cls}` ...")
            partition = X[y == cls]
            embeddings = self.embedding_model_.transform(partition)
            if self.memory_index is not None:
                index = self.memory_index()
                index.dim = embeddings.shape[1]
            else:
                index = SklearnMemoryIndex(embeddings.shape[1], metric=self.metric)
            for embedding in embeddings:
                index.add(embedding)
            index.build()
            self.data_[cls] = {"partition": partition, "index": index}

        return self

    def _get_prompt_template(self) -> str:
        """Returns the prompt template to use for a single input."""
        if self.prompt_template is not None:
            return self.prompt_template
        return FEW_SHOT_CLF_PROMPT_TEMPLATE

    def _reorder_examples(self, examples):
        """Interleave class-grouped examples round-robin so that classes
        alternate in the prompt instead of appearing as contiguous blocks."""
        n_classes = len(self.classes_)
        n_examples = self.n_examples
        # Fix: when any class partition has fewer than n_examples samples the
        # retrieval returns fewer neighbours, and the index arithmetic below
        # would go out of range; fall back to the grouped order instead.
        if len(examples) != n_classes * n_examples:
            return list(examples)
        return [
            examples[cls * n_examples + i]
            for i in range(n_examples)
            for cls in range(n_classes)
        ]

    def _get_prompt(self, x: str) -> dict:
        """
        Generates the prompt for the given input.

        Parameters
        ----------
        x : str
            sample to classify

        Returns
        -------
        dict
            final prompt
        """
        embedding = self.embedding_model_.transform([x])
        training_data = []
        for cls in self.classes_:
            index = self.data_[cls]["index"]
            partition = self.data_[cls]["partition"]
            # never request more neighbours than the partition holds
            neighbors = index.retrieve(embedding, min(self.n_examples, len(partition)))
            neighbors = [partition[i] for i in neighbors[0]]
            training_data.extend(
                _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
                for neighbor in neighbors
            )

        training_data_str = "\n".join(self._reorder_examples(training_data))

        msg = build_few_shot_prompt_slc(
            x=x,
            training_data=training_data_str,
            labels=repr(self.classes_),
            template=self._get_prompt_template(),
        )

        return {"messages": msg, "system_message": "You are a text classifier."}
class BaseTunableClassifier(BaseClassifier):
    """Classifier that fine-tunes the underlying model during ``fit``."""

    def fit(
        self,
        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
        y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
    ):
        """
        Fits the model to the given data.

        Parameters
        ----------
        X : Optional[Union[np.ndarray, pd.Series, List[str]]]
            Training data
        y : Union[np.ndarray, pd.Series, List[str], List[List[str]]]
            Training labels

        Returns
        -------
        BaseTunableClassifier
            self
        """
        if not isinstance(self, _BaseTunableMixin):
            raise TypeError(
                "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
            )
        super().fit(X, y)
        # _tune is provided by the tunable backend mixin
        self._tune(X, y)
        return self

    def _get_prompt_template(self) -> str:
        """Returns the prompt template to use for a single input."""
        if self.prompt_template is not None:
            return self.prompt_template
        if isinstance(self, SingleLabelMixin):
            return ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE
        return ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE

    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        template = self._get_prompt_template()
        if isinstance(self, SingleLabelMixin):
            prompt = build_zero_shot_prompt_slc(x, repr(self.classes_), template=template)
        else:
            prompt = build_zero_shot_prompt_mlc(
                x, repr(self.classes_), self.max_labels, template=template
            )
        return {"messages": prompt, "system_message": "You are a text classifier."}


# --- skllm/models/_base/tagger.py ---
from typing import Any, Union, List, Optional, Dict
from abc import abstractmethod, ABC
from numpy import ndarray
from tqdm import tqdm
import numpy as np
import pandas as pd
from skllm.utils import to_numpy as _to_numpy
from sklearn.base import (
    BaseEstimator as _SklBaseEstimator,
    TransformerMixin as _SklTransformerMixin,
)

from skllm.utils.rendering import display_ner
from skllm.utils.xml import filter_xml_tags, filter_unwanted_entities, json_to_xml
from skllm.prompts.builders import build_ner_prompt
from skllm.prompts.templates import (
    NER_SYSTEM_MESSAGE_TEMPLATE,
    NER_SYSTEM_MESSAGE_SPARSE,
    EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
    EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE,
)
from skllm.utils import re_naive_json_extractor
import json
from concurrent.futures import ThreadPoolExecutor


class BaseTagger(ABC, _SklBaseEstimator, _SklTransformerMixin):
    """Scikit-learn compatible base class for LLM-backed taggers."""

    # number of threads used by transform()
    num_workers = 1

    def fit(self, X: Any, y: Any = None):
        """
        Fits the model to the data. Usually a no-op.

        Parameters
        ----------
        X : Any
            training data
        y : Any
            training outputs

        Returns
        -------
        self
            BaseTagger
        """
        return self

    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """Alias of :meth:`transform` for the predictor-style API."""
        return self.transform(X)

    def fit_transform(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
    ) -> ndarray:
        """Fit (no-op) and transform in one call."""
        return self.fit(X, y).transform(X)

    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """
        Transforms the input data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data to predict the class of.

        Returns
        -------
        List[str]
        """
        X = _to_numpy(X)
        # fix: removed dead `predictions = []` that was immediately overwritten
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            predictions = list(
                tqdm(executor.map(self._predict_single, X), total=len(X))
            )
        return np.asarray(predictions)

    def _predict_single(self, x: Any) -> Any:
        """Tag a single sample and return the completion as a string."""
        prompt_dict = self._get_prompt(x)
        # this will be inherited from the LLM
        prediction = self._get_chat_completion(model=self.model, **prompt_dict)
        prediction = self._convert_completion_to_str(prediction)
        return prediction

    @abstractmethod
    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        pass


class ExplainableNER(BaseTagger):
    """NER tagger that asks the model for per-entity reasoning and converts
    the output to XML-tagged text."""

    # entity name -> definition; must be provided by the subclass/instance
    entities: Optional[Dict[str, str]] = None
    sparse_output = True
    _allowed_tags = ["entity", "not_entity"]

    display_predictions = False

    def fit(self, X: Any, y: Any = None):
        """Expand the entity dictionary into the prompt representation."""
        if self.entities is None:
            # fix: previously this failed with an opaque AttributeError
            raise ValueError("entities must be provided")
        self.expanded_entities_ = [
            {"entity": name, "definition": definition}
            for name, definition in self.entities.items()
        ]
        return self

    def transform(self, X: ndarray | pd.Series | List[str]):
        """Tag the inputs; with sparse_output, rebuild XML from the JSON output."""
        predictions = super().transform(X)
        if self.sparse_output:
            json_predictions = [
                re_naive_json_extractor(p, expected_output="array") for p in predictions
            ]
            predictions = []
            attributes = ["reasoning", "tag", "value"]
            for x, p in zip(X, json_predictions):
                p_json = json.loads(p)
                predictions.append(
                    json_to_xml(
                        x,
                        p_json,
                        "entity",
                        "not_entity",
                        value_key="value",
                        attributes=attributes,
                    )
                )

        # keep only allowed tags and known entities
        predictions = [
            filter_unwanted_entities(
                filter_xml_tags(p, self._allowed_tags), self.entities
            )
            for p in predictions
        ]
        if self.display_predictions:
            print("Displaying predictions...")
            display_ner(predictions, self.entities)
        return predictions

    def _get_prompt(self, x: str) -> dict:
        if not hasattr(self, "expanded_entities_"):
            raise ValueError("Model not fitted.")
        # Fix: the system message template was formatted twice (once on
        # selection and again below); it is now formatted exactly once.
        # NOTE(review): the sparse/dense naming looks inverted here
        # (sparse_output selects NER_SYSTEM_MESSAGE_TEMPLATE) — preserved
        # as-is; confirm against the prompt templates.
        system_message_template = (
            NER_SYSTEM_MESSAGE_TEMPLATE
            if self.sparse_output
            else NER_SYSTEM_MESSAGE_SPARSE
        )
        template = (
            EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE
            if self.sparse_output
            else EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
        )
        return {
            "messages": build_ner_prompt(self.expanded_entities_, x, template=template),
            "system_message": system_message_template.format(entities=self.entities.keys()),
        }


# --- skllm/models/_base/text2text.py ---
from typing import Any, Union, List, Optional
from abc import abstractmethod, ABC
from numpy import ndarray
from tqdm import tqdm
import numpy as np
import pandas as pd
from skllm.utils import to_numpy as _to_numpy
from sklearn.base import (
    BaseEstimator as _SklBaseEstimator,
    TransformerMixin as _SklTransformerMixin,
)
from skllm.llm.base import BaseTunableMixin as _BaseTunableMixin
from skllm.prompts.builders import (
    build_focused_summary_prompt,
    build_summary_prompt,
    build_translation_prompt,
)
class BaseText2TextModel(ABC, _SklBaseEstimator, _SklTransformerMixin):
    """Scikit-learn compatible base class for LLM text-to-text models."""

    def fit(self, X: Any, y: Any = None):
        """
        Fits the model to the data. Usually a no-op.

        Parameters
        ----------
        X : Any
            training data
        y : Any
            training outputs

        Returns
        -------
        self
            BaseText2TextModel
        """
        return self

    def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """Alias of :meth:`transform` for the predictor-style API."""
        return self.transform(X)

    def fit_transform(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
    ) -> ndarray:
        """Fit (no-op) and transform in one call."""
        return self.fit(X, y).transform(X)

    def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
        """
        Transforms the input data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data to predict the class of.

        Returns
        -------
        List[str]
        """
        X = _to_numpy(X)
        predictions = []
        for i in tqdm(range(len(X))):
            predictions.append(self._predict_single(X[i]))
        return predictions

    def _predict_single(self, x: Any) -> Any:
        """Run a single sample through the LLM and return the completion string."""
        prompt_dict = self._get_prompt(x)
        # this will be inherited from the LLM
        prediction = self._get_chat_completion(model=self.model, **prompt_dict)
        prediction = self._convert_completion_to_str(prediction)
        return prediction

    @abstractmethod
    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input."""
        pass


class BaseTunableText2TextModel(BaseText2TextModel):
    """Text-to-text model that fine-tunes the underlying model during ``fit``."""

    def fit(
        self,
        X: Union[np.ndarray, pd.Series, List[str]],
        y: Union[np.ndarray, pd.Series, List[str]],
    ):
        """
        Fits the model to the data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            training data
        y : Union[np.ndarray, pd.Series, List[str]]
            training labels

        Returns
        -------
        BaseTunableText2TextModel
            self
        """
        if not isinstance(self, _BaseTunableMixin):
            raise TypeError(
                "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
            )
        self._tune(X, y)
        return self

    def _get_prompt(self, x: str) -> dict:
        """Returns the prompt to use for a single input (raw pass-through)."""
        return {"messages": str(x), "system_message": ""}

    def _predict_single(self, x: str) -> str:
        # the tuned model id is only known after fit/_tune has run
        if self.model is None:
            raise RuntimeError("Model has not been tuned yet")
        return super()._predict_single(x)


class BaseSummarizer(BaseText2TextModel):
    """Text summarizer with an optional focus concept."""

    max_words: int = 15
    focus: Optional[str] = None
    system_message: str = "You are a text summarizer."

    def _get_prompt(self, X: str) -> dict:  # fix: was annotated `-> str` but returns a dict
        if self.focus:
            prompt = build_focused_summary_prompt(X, self.max_words, self.focus)
        else:
            prompt = build_summary_prompt(X, self.max_words)
        return {"messages": prompt, "system_message": self.system_message}

    def transform(
        self, X: Union[ndarray, pd.Series, List[str]], **kwargs: Any
    ) -> ndarray:
        """
        Transforms the input data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data to predict the class of.

        Returns
        -------
        List[str]
        """
        y = super().transform(X, **kwargs)
        if self.focus:
            # strip boilerplate the model emits for focused summaries
            y = np.asarray(
                [
                    i.replace("Mentioned concept is not present in the text.", "")
                    .replace("The general summary is:", "")
                    .strip()
                    for i in y
                ],
                dtype=object,
            )
        return y
class BaseTranslator(BaseText2TextModel):
    """Text translator into a fixed output language."""

    output_language: str = "English"
    system_message = "You are a text translator."

    def _get_prompt(self, X: str) -> dict:  # fix: was annotated `-> str` but returns a dict
        prompt = build_translation_prompt(X, self.output_language)
        return {"messages": prompt, "system_message": self.system_message}

    def transform(
        self, X: Union[ndarray, pd.Series, List[str]], **kwargs: Any
    ) -> ndarray:
        """
        Transforms the input data.

        Parameters
        ----------
        X : Union[np.ndarray, pd.Series, List[str]]
            The input data to predict the class of.

        Returns
        -------
        List[str]
        """
        y = super().transform(X, **kwargs)
        # strip boilerplate markers the model sometimes emits
        y = np.asarray(
            [i.replace("[Translated text:]", "").replace("```", "").strip() for i in y],
            dtype=object,
        )
        return y


# --- skllm/models/_base/vectorizer.py ---
from typing import Any, List, Optional, Union
import numpy as np
import pandas as pd
from skllm.utils import to_numpy as _to_numpy
from skllm.llm.base import BaseEmbeddingMixin
from sklearn.base import (
    BaseEstimator as _SklBaseEstimator,
    TransformerMixin as _SklTransformerMixin,
)


class BaseVectorizer(_SklBaseEstimator, _SklTransformerMixin):
    """
    A base vectorization/embedding class.

    Parameters
    ----------
    model : str
        The embedding model to use.
    batch_size : int, optional
        Number of samples sent to the backend per request, by default 1.
    """

    def __init__(self, model: str, batch_size: int = 1):
        # the concrete vectorizer must provide _get_embeddings via a backend mixin
        if not isinstance(self, BaseEmbeddingMixin):
            raise TypeError(
                "Vectorizer must be mixed with skllm.llm.base.BaseEmbeddingMixin."
            )
        self.model = model
        if not isinstance(batch_size, int):
            raise TypeError("batch_size must be an integer")
        if batch_size < 1:
            raise ValueError("batch_size must be greater than 0")
        self.batch_size = batch_size

    def fit(self, X: Any = None, y: Any = None, **kwargs):
        """
        Does nothing. Needed only for sklearn compatibility.

        Parameters
        ----------
        X : Any, optional
        y : Any, optional
        kwargs : dict, optional

        Returns
        -------
        self : BaseVectorizer
        """
        return self

    def transform(
        self, X: Optional[Union[np.ndarray, pd.Series, List[str]]]
    ) -> np.ndarray:
        """
        Transforms a list of strings into a list of GPT embeddings.
        This is modelled to function as the sklearn transform method

        Parameters
        ----------
        X : Optional[Union[np.ndarray, pd.Series, List[str]]]
            The input array of strings to transform into GPT embeddings.

        Returns
        -------
        embeddings : np.ndarray
        """
        X = _to_numpy(X)
        embeddings = self._get_embeddings(X)
        return np.asarray(embeddings)

    def fit_transform(
        self,
        X: Optional[Union[np.ndarray, pd.Series, List[str]]],
        y: Any = None,
        **fit_params,
    ) -> np.ndarray:
        """
        Fits and transforms a list of strings into a list of embeddings.
        This is modelled to function as the sklearn fit_transform method

        Parameters
        ----------
        X : Optional[Union[np.ndarray, pd.Series, List[str]]]
            The input array of strings to transform into embeddings.
        y : Any, optional

        Returns
        -------
        embeddings : np.ndarray
        """
        return self.fit(X, y).transform(X)
from skllm.models._base.classifier import (
    BaseFewShotClassifier,
    BaseDynamicFewShotClassifier,
    SingleLabelMixin,
    MultiLabelMixin,
)
from skllm.llm.anthropic.mixin import ClaudeClassifierMixin
from skllm.models.gpt.vectorization import GPTVectorizer
from skllm.models._base.vectorizer import BaseVectorizer
from skllm.memory.base import IndexConstructor
from typing import Optional


class FewShotClaudeClassifier(
    BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
):
    """Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Few-shot text classifier using Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class MultiLabelFewShotClaudeClassifier(
    BaseFewShotClassifier, ClaudeClassifierMixin, MultiLabelMixin
):
    """Few-shot text classifier using Anthropic's Claude API for multi-label classification tasks."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Multi-label few-shot text classifier using Anthropic's Claude API.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        """
        super().__init__(
            model=model,
            default_label=default_label,
            max_labels=max_labels,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class DynamicFewShotClaudeClassifier(
    BaseDynamicFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
):
    """
    Dynamic few-shot text classifier using Anthropic's Claude API for
    single-label classification tasks with dynamic example selection using GPT embeddings.
    """

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        n_examples: int = 3,
        memory_index: Optional[IndexConstructor] = None,
        vectorizer: Optional[BaseVectorizer] = None,
        metric: Optional[str] = "euclidean",
        **kwargs,
    ):
        """
        Dynamic few-shot text classifier using Anthropic's Claude API.
        For each sample, N closest examples are retrieved from the memory.

        Parameters
        ----------
        model : str, optional
            model to use, by default "claude-3-haiku-20240307"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config
        n_examples : int, optional
            number of closest examples per class to be retrieved, by default 3
        memory_index : Optional[IndexConstructor], optional
            custom memory index, for details check `skllm.memory` submodule
        vectorizer : Optional[BaseVectorizer], optional
            scikit-llm vectorizer; if None, `GPTVectorizer` is used
        metric : Optional[str], optional
            metric used for similarity search, by default "euclidean"
        """
        if vectorizer is None:
            # NOTE(review): the Anthropic `key` is forwarded to an OpenAI-backed
            # embedding vectorizer here — confirm that this is intentional.
            vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key)
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            n_examples=n_examples,
            memory_index=memory_index,
            vectorizer=vectorizer,
            metric=metric,
        )
        self._set_keys(key)


# --- skllm/models/anthropic/classification/zero_shot.py ---
from skllm.models._base.classifier import (
    SingleLabelMixin as _SingleLabelMixin,
    MultiLabelMixin as _MultiLabelMixin,
    BaseZeroShotClassifier as _BaseZeroShotClassifier,
    BaseCoTClassifier as _BaseCoTClassifier,
)
from skllm.llm.anthropic.mixin import ClaudeClassifierMixin as _ClaudeClassifierMixin
from typing import Optional
class ZeroShotClaudeClassifier(
    _BaseZeroShotClassifier, _ClaudeClassifierMixin, _SingleLabelMixin
):
    """Zero-shot text classifier using Anthropic Claude models for single-label classification."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Zero-shot text classifier using Anthropic Claude models.

        Parameters
        ----------
        model : str, optional
            Model to use, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
        prompt_template : Optional[str], optional
            Custom prompt template to use, by default None.
        key : Optional[str], optional
            Estimator-specific API key; if None, retrieved from the global config, by default None.
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class CoTClaudeClassifier(_BaseCoTClassifier, _ClaudeClassifierMixin, _SingleLabelMixin):
    """Chain-of-thought text classifier using Anthropic Claude models for single-label classification."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Chain-of-thought text classifier using Anthropic Claude models.

        Parameters
        ----------
        model : str, optional
            Model to use, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
        prompt_template : Optional[str], optional
            Custom prompt template to use, by default None.
        key : Optional[str], optional
            Estimator-specific API key; if None, retrieved from the global config, by default None.
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)


class MultiLabelZeroShotClaudeClassifier(
    _BaseZeroShotClassifier, _ClaudeClassifierMixin, _MultiLabelMixin
):
    """Zero-shot text classifier using Anthropic Claude models for multi-label classification."""

    def __init__(
        self,
        model: str = "claude-3-haiku-20240307",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        **kwargs,
    ):
        """
        Multi-label zero-shot text classifier using Anthropic Claude models.

        Parameters
        ----------
        model : str, optional
            Model to use, by default "claude-3-haiku-20240307".
        default_label : str, optional
            Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
        max_labels : Optional[int], optional
            Maximum number of labels per sample, by default 5.
        prompt_template : Optional[str], optional
            Custom prompt template to use, by default None.
        key : Optional[str], optional
            Estimator-specific API key; if None, retrieved from the global config, by default None.
        """
        super().__init__(
            model=model,
            default_label=default_label,
            max_labels=max_labels,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key)
112 | """ 113 | super().__init__( 114 | model=model, 115 | default_label=default_label, 116 | max_labels=max_labels, 117 | prompt_template=prompt_template, 118 | **kwargs, 119 | ) 120 | self._set_keys(key) -------------------------------------------------------------------------------- /skllm/models/anthropic/tagging/ner.py: -------------------------------------------------------------------------------- 1 | from skllm.models._base.tagger import ExplainableNER as _ExplainableNER 2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin 3 | from typing import Optional, Dict 4 | 5 | 6 | class AnthropicExplainableNER(_ExplainableNER, _ClaudeTextCompletionMixin): 7 | """Named Entity Recognition model using Anthropic's Claude API for explainable entity extraction.""" 8 | 9 | def __init__( 10 | self, 11 | entities: Dict[str, str], 12 | display_predictions: bool = False, 13 | sparse_output: bool = True, 14 | model: str = "claude-3-haiku-20240307", 15 | key: Optional[str] = None, 16 | num_workers: int = 1, 17 | ) -> None: 18 | """ 19 | Named entity recognition using Anthropic Claude API. 
20 | 21 | Parameters 22 | ---------- 23 | entities : dict 24 | dictionary of entities to recognize, with keys as entity names and values as descriptions 25 | display_predictions : bool, optional 26 | whether to display predictions, by default False 27 | sparse_output : bool, optional 28 | whether to generate a sparse representation of the predictions, by default True 29 | model : str, optional 30 | model to use, by default "claude-3-haiku-20240307" 31 | key : Optional[str], optional 32 | estimator-specific API key; if None, retrieved from the global config 33 | num_workers : int, optional 34 | number of workers (threads) to use, by default 1 35 | """ 36 | self._set_keys(key) 37 | self.model = model 38 | self.entities = entities 39 | self.display_predictions = display_predictions 40 | self.sparse_output = sparse_output 41 | self.num_workers = num_workers -------------------------------------------------------------------------------- /skllm/models/anthropic/text2text/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/skllm/models/anthropic/text2text/__init__.py -------------------------------------------------------------------------------- /skllm/models/anthropic/text2text/summarization.py: -------------------------------------------------------------------------------- 1 | from skllm.models._base.text2text import BaseSummarizer as _BaseSummarizer 2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin 3 | from typing import Optional 4 | 5 | 6 | class ClaudeSummarizer(_BaseSummarizer, _ClaudeTextCompletionMixin): 7 | """Text summarizer using Anthropic Claude API.""" 8 | 9 | def __init__( 10 | self, 11 | model: str = "claude-3-haiku-20240307", 12 | key: Optional[str] = None, 13 | max_words: int = 15, 14 | focus: Optional[str] = None, 15 | ) -> None: 16 | """ 17 | Initialize the Claude 
summarizer. 18 | 19 | Parameters 20 | ---------- 21 | model : str, optional 22 | Model to use, by default "claude-3-haiku-20240307" 23 | key : Optional[str], optional 24 | Estimator-specific API key; if None, retrieved from global config 25 | max_words : int, optional 26 | Soft limit of the summary length, by default 15 27 | focus : Optional[str], optional 28 | Concept in the text to focus on, by default None 29 | """ 30 | self._set_keys(key) 31 | self.model = model 32 | self.max_words = max_words 33 | self.focus = focus 34 | self.system_message = "You are a text summarizer. Provide concise and accurate summaries." -------------------------------------------------------------------------------- /skllm/models/anthropic/text2text/translation.py: -------------------------------------------------------------------------------- 1 | from skllm.models._base.text2text import BaseTranslator as _BaseTranslator 2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin 3 | from typing import Optional 4 | 5 | 6 | class ClaudeTranslator(_BaseTranslator, _ClaudeTextCompletionMixin): 7 | """Text translator using Anthropic Claude API.""" 8 | 9 | default_output = "Translation is unavailable." 10 | 11 | def __init__( 12 | self, 13 | model: str = "claude-3-haiku-20240307", 14 | key: Optional[str] = None, 15 | output_language: str = "English", 16 | ) -> None: 17 | """ 18 | Initialize the Claude translator. 19 | 20 | Parameters 21 | ---------- 22 | model : str, optional 23 | Model to use, by default "claude-3-haiku-20240307" 24 | key : Optional[str], optional 25 | Estimator-specific API key; if None, retrieved from global config 26 | output_language : str, optional 27 | Target language, by default "English" 28 | """ 29 | self._set_keys(key) 30 | self.model = model 31 | self.output_language = output_language 32 | self.system_message = ( 33 | "You are a professional translator. 
Provide accurate translations " 34 | "while maintaining the original meaning and tone of the text." 35 | ) -------------------------------------------------------------------------------- /skllm/models/gpt/classification/few_shot.py: -------------------------------------------------------------------------------- 1 | from skllm.models._base.classifier import ( 2 | BaseFewShotClassifier, 3 | BaseDynamicFewShotClassifier, 4 | SingleLabelMixin, 5 | MultiLabelMixin, 6 | ) 7 | from skllm.llm.gpt.mixin import GPTClassifierMixin 8 | from skllm.models.gpt.vectorization import GPTVectorizer 9 | from skllm.models._base.vectorizer import BaseVectorizer 10 | from skllm.memory.base import IndexConstructor 11 | from typing import Optional 12 | 13 | 14 | class FewShotGPTClassifier(BaseFewShotClassifier, GPTClassifierMixin, SingleLabelMixin): 15 | def __init__( 16 | self, 17 | model: str = "gpt-3.5-turbo", 18 | default_label: str = "Random", 19 | prompt_template: Optional[str] = None, 20 | key: Optional[str] = None, 21 | org: Optional[str] = None, 22 | **kwargs, 23 | ): 24 | """ 25 | Few-shot text classifier using OpenAI/GPT API-compatible models. 
26 | 27 | Parameters 28 | ---------- 29 | model : str, optional 30 | model to use, by default "gpt-3.5-turbo" 31 | default_label : str, optional 32 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random" 33 | prompt_template : Optional[str], optional 34 | custom prompt template to use, by default None 35 | key : Optional[str], optional 36 | estimator-specific API key; if None, retrieved from the global config, by default None 37 | org : Optional[str], optional 38 | estimator-specific ORG key; if None, retrieved from the global config, by default None 39 | """ 40 | super().__init__( 41 | model=model, 42 | default_label=default_label, 43 | prompt_template=prompt_template, 44 | **kwargs, 45 | ) 46 | self._set_keys(key, org) 47 | 48 | 49 | class MultiLabelFewShotGPTClassifier( 50 | BaseFewShotClassifier, GPTClassifierMixin, MultiLabelMixin 51 | ): 52 | def __init__( 53 | self, 54 | model: str = "gpt-3.5-turbo", 55 | default_label: str = "Random", 56 | max_labels: Optional[int] = 5, 57 | prompt_template: Optional[str] = None, 58 | key: Optional[str] = None, 59 | org: Optional[str] = None, 60 | **kwargs, 61 | ): 62 | """ 63 | Multi-label few-shot text classifier using OpenAI/GPT API-compatible models. 
64 | 65 | Parameters 66 | ---------- 67 | model : str, optional 68 | model to use, by default "gpt-3.5-turbo" 69 | default_label : str, optional 70 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random" 71 | max_labels : Optional[int], optional 72 | maximum labels per sample, by default 5 73 | prompt_template : Optional[str], optional 74 | custom prompt template to use, by default None 75 | key : Optional[str], optional 76 | estimator-specific API key; if None, retrieved from the global config, by default None 77 | org : Optional[str], optional 78 | estimator-specific ORG key; if None, retrieved from the global config, by default None 79 | """ 80 | super().__init__( 81 | model=model, 82 | default_label=default_label, 83 | max_labels=max_labels, 84 | prompt_template=prompt_template, 85 | **kwargs, 86 | ) 87 | self._set_keys(key, org) 88 | 89 | 90 | class DynamicFewShotGPTClassifier( 91 | BaseDynamicFewShotClassifier, GPTClassifierMixin, SingleLabelMixin 92 | ): 93 | def __init__( 94 | self, 95 | model: str = "gpt-3.5-turbo", 96 | default_label: str = "Random", 97 | prompt_template: Optional[str] = None, 98 | key: Optional[str] = None, 99 | org: Optional[str] = None, 100 | n_examples: int = 3, 101 | memory_index: Optional[IndexConstructor] = None, 102 | vectorizer: Optional[BaseVectorizer] = None, 103 | metric: Optional[str] = "euclidean", 104 | **kwargs, 105 | ): 106 | """ 107 | Dynamic few-shot text classifier using OpenAI/GPT API-compatible models. 108 | For each sample, N closest examples are retrieved from the memory. 
109 | 110 | Parameters 111 | ---------- 112 | model : str, optional 113 | model to use, by default "gpt-3.5-turbo" 114 | default_label : str, optional 115 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random" 116 | prompt_template : Optional[str], optional 117 | custom prompt template to use, by default None 118 | key : Optional[str], optional 119 | estimator-specific API key; if None, retrieved from the global config, by default None 120 | org : Optional[str], optional 121 | estimator-specific ORG key; if None, retrieved from the global config, by default None 122 | n_examples : int, optional 123 | number of closest examples per class to be retrieved, by default 3 124 | memory_index : Optional[IndexConstructor], optional 125 | custom memory index, for details check `skllm.memory` submodule, by default None 126 | vectorizer : Optional[BaseVectorizer], optional 127 | scikit-llm vectorizer; if None, `GPTVectorizer` is used, by default None 128 | metric : Optional[str], optional 129 | metric used for similarity search, by default "euclidean" 130 | """ 131 | if vectorizer is None: 132 | vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key, org=org) 133 | super().__init__( 134 | model=model, 135 | default_label=default_label, 136 | prompt_template=prompt_template, 137 | n_examples=n_examples, 138 | memory_index=memory_index, 139 | vectorizer=vectorizer, 140 | metric=metric, 141 | ) 142 | self._set_keys(key, org) 143 | -------------------------------------------------------------------------------- /skllm/models/gpt/classification/tunable.py: -------------------------------------------------------------------------------- 1 | from skllm.llm.gpt.mixin import ( 2 | GPTClassifierMixin as _GPTClassifierMixin, 3 | GPTTunableMixin as _GPTTunableMixin, 4 | ) 5 | from skllm.models._base.classifier import ( 6 | BaseTunableClassifier as _BaseTunableClassifier, 7 | SingleLabelMixin as _SingleLabelMixin, 
class _TunableClassifier(_BaseTunableClassifier, _GPTClassifierMixin, _GPTTunableMixin):
    """Internal base combining the tunable classifier with the GPT mixins."""

    pass


class GPTClassifier(_TunableClassifier, _SingleLabelMixin):
    """Tunable single-label text classifier based on a fine-tunable GPT model."""

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        default_label: str = "Random",
        key: Optional[str] = None,
        org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
        prompt_template: Optional[str] = None,
    ):
        """
        Tunable GPT-based text classifier.

        Parameters
        ----------
        base_model : str, optional
            base model to use, by default "gpt-3.5-turbo-0613"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        n_epochs : Optional[int], optional
            number of epochs; if None, determined automatically; by default None
        custom_suffix : Optional[str], optional
            custom suffix of the tuned model, used for naming purposes only, by default "skllm"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        """
        # The concrete model is produced by tuning, hence model=None here.
        super().__init__(
            model=None, default_label=default_label, prompt_template=prompt_template
        )
        self._set_keys(key, org)
        self._set_hyperparameters(base_model, n_epochs, custom_suffix)


class MultiLabelGPTClassifier(_TunableClassifier, _MultiLabelMixin):
    """Tunable multi-label text classifier based on a fine-tunable GPT model."""

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        default_label: str = "Random",
        key: Optional[str] = None,
        org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
        prompt_template: Optional[str] = None,
        max_labels: Optional[int] = 5,
    ):
        """
        Tunable multi-label GPT-based text classifier.

        Parameters
        ----------
        base_model : str, optional
            base model to use, by default "gpt-3.5-turbo-0613"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        n_epochs : Optional[int], optional
            number of epochs; if None, determined automatically; by default None
        custom_suffix : Optional[str], optional
            custom suffix of the tuned model, used for naming purposes only, by default "skllm"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        """
        # The concrete model is produced by tuning, hence model=None here.
        super().__init__(
            model=None,
            default_label=default_label,
            prompt_template=prompt_template,
            max_labels=max_labels,
        )
        self._set_keys(key, org)
        self._set_hyperparameters(base_model, n_epochs, custom_suffix)


# ===== skllm/models/gpt/classification/zero_shot.py =====

from skllm.models._base.classifier import (
    SingleLabelMixin as _SingleLabelMixin,
    MultiLabelMixin as _MultiLabelMixin,
    BaseZeroShotClassifier as _BaseZeroShotClassifier,
    BaseCoTClassifier as _BaseCoTClassifier,
)
from skllm.llm.gpt.mixin import GPTClassifierMixin as _GPTClassifierMixin
from typing import Optional
class ZeroShotGPTClassifier(
    _BaseZeroShotClassifier, _GPTClassifierMixin, _SingleLabelMixin
):
    """Zero-shot text classifier using OpenAI/GPT API-compatible models for single-label classification."""

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        org: Optional[str] = None,
        **kwargs,
    ):
        """
        Zero-shot text classifier using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "gpt-3.5-turbo"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key, org)


class CoTGPTClassifier(_BaseCoTClassifier, _GPTClassifierMixin, _SingleLabelMixin):
    """Chain-of-thought text classifier using OpenAI/GPT API-compatible models for single-label classification."""

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        org: Optional[str] = None,
        **kwargs,
    ):
        """
        Chain-of-thought text classifier using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "gpt-3.5-turbo"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key, org)


class MultiLabelZeroShotGPTClassifier(
    _BaseZeroShotClassifier, _GPTClassifierMixin, _MultiLabelMixin
):
    """Zero-shot text classifier using OpenAI/GPT API-compatible models for multi-label classification."""

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        default_label: str = "Random",
        max_labels: Optional[int] = 5,
        prompt_template: Optional[str] = None,
        key: Optional[str] = None,
        org: Optional[str] = None,
        **kwargs,
    ):
        """
        Multi-label zero-shot text classifier using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "gpt-3.5-turbo"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        """
        super().__init__(
            model=model,
            default_label=default_label,
            max_labels=max_labels,
            prompt_template=prompt_template,
            **kwargs,
        )
        self._set_keys(key, org)


# ===== skllm/models/gpt/tagging/ner.py =====

from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
from typing import Optional, Dict


class GPTExplainableNER(_ExplainableNER, _GPTTextCompletionMixin):
    """Named Entity Recognition model using OpenAI/GPT API-compatible models for explainable entity extraction."""

    def __init__(
        self,
        entities: Dict[str, str],
        display_predictions: bool = False,
        sparse_output: bool = True,
        model: str = "gpt-4o",
        key: Optional[str] = None,
        org: Optional[str] = None,
        num_workers: int = 1,
    ) -> None:
        """
        Named entity recognition using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        entities : dict
            dictionary of entities to recognize, with keys as entity names and values as descriptions
        display_predictions : bool, optional
            whether to display predictions, by default False
        sparse_output : bool, optional
            whether to generate a sparse representation of the predictions, by default True
        model : str, optional
            model to use, by default "gpt-4o"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        num_workers : int, optional
            number of workers (threads) to use, by default 1
        """
        self._set_keys(key, org)
        self.model = model
        self.entities = entities
        self.display_predictions = display_predictions
        self.sparse_output = sparse_output
        self.num_workers = num_workers


# ===== skllm/models/gpt/text2text/summarization.py =====

from skllm.models._base.text2text import BaseSummarizer as _BaseSummarizer
from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
from typing import Optional


class GPTSummarizer(_BaseSummarizer, _GPTTextCompletionMixin):
    """Text summarizer using OpenAI/GPT API-compatible models."""

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        key: Optional[str] = None,
        org: Optional[str] = None,
        max_words: int = 15,
        focus: Optional[str] = None,
    ) -> None:
        """
        Text summarizer using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "gpt-3.5-turbo"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        max_words : int, optional
            soft limit of the summary length, by default 15
        focus : Optional[str], optional
            concept in the text to focus on, by default None
        """
        self._set_keys(key, org)
        self.model = model
        self.max_words = max_words
        self.focus = focus


# ===== skllm/models/gpt/text2text/translation.py =====

from skllm.models._base.text2text import BaseTranslator as _BaseTranslator
from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
from typing import Optional


class GPTTranslator(_BaseTranslator, _GPTTextCompletionMixin):
    """Text translator using OpenAI/GPT API-compatible models."""

    # Returned when the model fails to produce a usable translation.
    default_output = "Translation is unavailable."

    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        key: Optional[str] = None,
        org: Optional[str] = None,
        output_language: str = "English",
    ) -> None:
        """
        Text translator using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "gpt-3.5-turbo"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        output_language : str, optional
            target language, by default "English"
        """
        self._set_keys(key, org)
        self.model = model
        self.output_language = output_language


# ===== skllm/models/gpt/text2text/tunable.py =====

from skllm.llm.gpt.mixin import (
    GPTTunableMixin as _GPTTunableMixin,
    GPTTextCompletionMixin as _GPTTextCompletionMixin,
)
from skllm.models._base.text2text import (
    BaseTunableText2TextModel as _BaseTunableText2TextModel,
)
from typing import Optional


class TunableGPTText2Text(
    _BaseTunableText2TextModel, _GPTTextCompletionMixin, _GPTTunableMixin
):
    """Tunable GPT-based text-to-text model."""

    def __init__(
        self,
        base_model: str = "gpt-3.5-turbo-0613",
        key: Optional[str] = None,
        org: Optional[str] = None,
        n_epochs: Optional[int] = None,
        custom_suffix: Optional[str] = "skllm",
    ):
        """
        Tunable GPT-based text-to-text model.

        Parameters
        ----------
        base_model : str, optional
            base model to use, by default "gpt-3.5-turbo-0613"
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        n_epochs : Optional[int], optional
            number of epochs; if None, determined automatically; by default None
        custom_suffix : Optional[str], optional
            custom suffix of the tuned model, used for naming purposes only, by default "skllm"
        """
        # The concrete model is produced by tuning, hence model is unset here.
        self.model = None
        self._set_keys(key, org)
        self._set_hyperparameters(base_model, n_epochs, custom_suffix)


# ===== skllm/models/gpt/vectorization.py =====

from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
from skllm.llm.gpt.mixin import GPTEmbeddingMixin as _GPTEmbeddingMixin
from typing import Optional


class GPTVectorizer(_BaseVectorizer, _GPTEmbeddingMixin):
    """Text vectorizer using OpenAI/GPT API-compatible embedding models."""

    def __init__(
        self,
        model: str = "text-embedding-3-small",
        batch_size: int = 1,
        key: Optional[str] = None,
        org: Optional[str] = None,
    ):
        """
        Text vectorizer using OpenAI/GPT API-compatible models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "text-embedding-3-small"
        batch_size : int, optional
            number of samples per request, by default 1
        key : Optional[str], optional
            estimator-specific API key; if None, retrieved from the global config, by default None
        org : Optional[str], optional
            estimator-specific ORG key; if None, retrieved from the global config, by default None
        """
        # FIX: the docstring previously claimed "text-embedding-ada-002" as the
        # default, which contradicted the actual signature default above.
        super().__init__(model=model, batch_size=batch_size)
        self._set_keys(key, org)


# ===== skllm/models/vertex/classification/tunable.py =====

from skllm.models._base.classifier import (
    BaseTunableClassifier as _BaseTunableClassifier,
    SingleLabelMixin as _SingleLabelMixin,
    MultiLabelMixin as _MultiLabelMixin,
)
from skllm.llm.vertex.mixin import (
    VertexClassifierMixin as _VertexClassifierMixin,
    VertexTunableMixin as _VertexTunableMixin,
)
from typing import Optional


class _TunableClassifier(
    _BaseTunableClassifier, _VertexClassifierMixin, _VertexTunableMixin
):
    """Internal base combining the tunable classifier with the Vertex mixins."""

    pass


class VertexClassifier(_TunableClassifier, _SingleLabelMixin):
    """Tunable single-label text classifier based on Vertex AI models."""

    def __init__(
        self,
        base_model: str = "text-bison@002",
        n_update_steps: int = 1,
        default_label: str = "Random",
    ):
        """
        Tunable Vertex-based text classifier.

        Parameters
        ----------
        base_model : str, optional
            base model to use, by default "text-bison@002"
        n_update_steps : int, optional
            number of epochs, by default 1
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        """
        self._set_hyperparameters(base_model=base_model, n_update_steps=n_update_steps)
        # The concrete model is produced by tuning, hence model=None here.
        super().__init__(
            model=None,
            default_label=default_label,
        )


# ===== skllm/models/vertex/classification/zero_shot.py =====

from skllm.llm.vertex.mixin import VertexClassifierMixin as _VertexClassifierMixin
from skllm.models._base.classifier import (
    BaseZeroShotClassifier as _BaseZeroShotClassifier,
    SingleLabelMixin as _SingleLabelMixin,
    MultiLabelMixin as _MultiLabelMixin,
)
from typing import Optional


class ZeroShotVertexClassifier(
    _BaseZeroShotClassifier, _SingleLabelMixin, _VertexClassifierMixin
):
    """Zero-shot text classifier using Vertex AI models for single-label classification."""

    def __init__(
        self,
        model: str = "text-bison@002",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        **kwargs,
    ):
        """
        Zero-shot text classifier using Vertex AI models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "text-bison@002"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            **kwargs,
        )


class MultiLabelZeroShotVertexClassifier(
    _BaseZeroShotClassifier, _MultiLabelMixin, _VertexClassifierMixin
):
    """Zero-shot text classifier using Vertex AI models for multi-label classification."""

    def __init__(
        self,
        model: str = "text-bison@002",
        default_label: str = "Random",
        prompt_template: Optional[str] = None,
        max_labels: Optional[int] = 5,
        **kwargs,
    ):
        """
        Multi-label zero-shot text classifier using Vertex AI models.

        Parameters
        ----------
        model : str, optional
            model to use, by default "text-bison@002"
        default_label : str, optional
            default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
        prompt_template : Optional[str], optional
            custom prompt template to use, by default None
        max_labels : Optional[int], optional
            maximum labels per sample, by default 5
        """
        super().__init__(
            model=model,
            default_label=default_label,
            prompt_template=prompt_template,
            max_labels=max_labels,
            **kwargs,
        )
/skllm/models/vertex/text2text/tunable.py: -------------------------------------------------------------------------------- 1 | from skllm.llm.vertex.mixin import ( 2 | VertexTunableMixin as _VertexTunableMixin, 3 | VertexTextCompletionMixin as _VertexTextCompletionMixin, 4 | ) 5 | from skllm.models._base.text2text import ( 6 | BaseTunableText2TextModel as _BaseTunableText2TextModel, 7 | ) 8 | 9 | 10 | class TunableVertexText2Text( 11 | _BaseTunableText2TextModel, _VertexTextCompletionMixin, _VertexTunableMixin 12 | ): 13 | def __init__( 14 | self, 15 | base_model: str = "text-bison@002", 16 | n_update_steps: int = 1, 17 | ): 18 | """ 19 | Tunable Vertex-based text-to-text model. 20 | 21 | Parameters 22 | ---------- 23 | base_model : str, optional 24 | base model to use, by default "text-bison@002" 25 | n_update_steps : int, optional 26 | number of epochs, by default 1 27 | """ 28 | self.model = None 29 | self._set_hyperparameters(base_model=base_model, n_update_steps=n_update_steps) 30 | -------------------------------------------------------------------------------- /skllm/prompts/builders.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional 2 | 3 | from skllm.prompts.templates import ( 4 | FEW_SHOT_CLF_PROMPT_TEMPLATE, 5 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE, 6 | FOCUSED_SUMMARY_PROMPT_TEMPLATE, 7 | SUMMARY_PROMPT_TEMPLATE, 8 | TRANSLATION_PROMPT_TEMPLATE, 9 | ZERO_SHOT_CLF_PROMPT_TEMPLATE, 10 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE, 11 | EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE, 12 | ) 13 | 14 | # TODO add validators 15 | 16 | 17 | def build_zero_shot_prompt_slc( 18 | x: str, labels: str, template: str = ZERO_SHOT_CLF_PROMPT_TEMPLATE 19 | ) -> str: 20 | """Builds a prompt for zero-shot single-label classification. 
21 | 22 | Parameters 23 | ---------- 24 | x : str 25 | sample to classify 26 | labels : str 27 | candidate labels in a list-like representation 28 | template : str 29 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE 30 | 31 | Returns 32 | ------- 33 | str 34 | prepared prompt 35 | """ 36 | return template.format(x=x, labels=labels) 37 | 38 | 39 | def build_few_shot_prompt_slc( 40 | x: str, 41 | labels: str, 42 | training_data: str, 43 | template: str = FEW_SHOT_CLF_PROMPT_TEMPLATE, 44 | ) -> str: 45 | """Builds a prompt for zero-shot single-label classification. 46 | 47 | Parameters 48 | ---------- 49 | x : str 50 | sample to classify 51 | labels : str 52 | candidate labels in a list-like representation 53 | training_data : str 54 | training data to be used for few-shot learning 55 | template : str 56 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE 57 | 58 | Returns 59 | ------- 60 | str 61 | prepared prompt 62 | """ 63 | return template.format(x=x, labels=labels, training_data=training_data) 64 | 65 | 66 | def build_few_shot_prompt_mlc( 67 | x: str, 68 | labels: str, 69 | training_data: str, 70 | max_cats: Union[int, str], 71 | template: str = FEW_SHOT_MLCLF_PROMPT_TEMPLATE, 72 | ) -> str: 73 | """Builds a prompt for few-shot single-label classification. 
74 | 75 | Parameters 76 | ---------- 77 | x : str 78 | sample to classify 79 | labels : str 80 | candidate labels in a list-like representation 81 | max_cats : Union[int,str] 82 | maximum number of categories to assign 83 | training_data : str 84 | training data to be used for few-shot learning 85 | template : str 86 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE 87 | 88 | Returns 89 | ------- 90 | str 91 | prepared prompt 92 | """ 93 | return template.format( 94 | x=x, labels=labels, training_data=training_data, max_cats=max_cats 95 | ) 96 | 97 | 98 | def build_zero_shot_prompt_mlc( 99 | x: str, 100 | labels: str, 101 | max_cats: Union[int, str], 102 | template: str = ZERO_SHOT_MLCLF_PROMPT_TEMPLATE, 103 | ) -> str: 104 | """Builds a prompt for zero-shot multi-label classification. 105 | 106 | Parameters 107 | ---------- 108 | x : str 109 | sample to classify 110 | labels : str 111 | candidate labels in a list-like representation 112 | max_cats : Union[int,str] 113 | maximum number of categories to assign 114 | template : str 115 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_MLCLF_PROMPT_TEMPLATE 116 | 117 | Returns 118 | ------- 119 | str 120 | prepared prompt 121 | """ 122 | return template.format(x=x, labels=labels, max_cats=max_cats) 123 | 124 | 125 | def build_summary_prompt( 126 | x: str, max_words: Union[int, str], template: str = SUMMARY_PROMPT_TEMPLATE 127 | ) -> str: 128 | """Builds a prompt for text summarization. 
129 | 130 | Parameters 131 | ---------- 132 | x : str 133 | sample to summarize 134 | max_words : Union[int,str] 135 | maximum number of words to use in the summary 136 | template : str 137 | prompt template to use, must contain placeholders for all variables, by default SUMMARY_PROMPT_TEMPLATE 138 | 139 | Returns 140 | ------- 141 | str 142 | prepared prompt 143 | """ 144 | return template.format(x=x, max_words=max_words) 145 | 146 | 147 | def build_focused_summary_prompt( 148 | x: str, 149 | max_words: Union[int, str], 150 | focus: Union[int, str], 151 | template: str = FOCUSED_SUMMARY_PROMPT_TEMPLATE, 152 | ) -> str: 153 | """Builds a prompt for focused text summarization. 154 | 155 | Parameters 156 | ---------- 157 | x : str 158 | sample to summarize 159 | max_words : Union[int,str] 160 | maximum number of words to use in the summary 161 | focus : Union[int,str] 162 | the topic(s) to focus on 163 | template : str 164 | prompt template to use, must contain placeholders for all variables, by default FOCUSED_SUMMARY_PROMPT_TEMPLATE 165 | 166 | Returns 167 | ------- 168 | str 169 | prepared prompt 170 | """ 171 | return template.format(x=x, max_words=max_words, focus=focus) 172 | 173 | 174 | def build_translation_prompt( 175 | x: str, output_language: str, template: str = TRANSLATION_PROMPT_TEMPLATE 176 | ) -> str: 177 | """Builds a prompt for text translation. 
178 | 179 | Parameters 180 | ---------- 181 | x : str 182 | sample to translate 183 | output_language : str 184 | language to translate to 185 | template : str 186 | prompt template to use, must contain placeholders for all variables, by default TRANSLATION_PROMPT_TEMPLATE 187 | 188 | Returns 189 | ------- 190 | str 191 | prepared prompt 192 | """ 193 | return template.format(x=x, output_language=output_language) 194 | 195 | 196 | def build_ner_prompt( 197 | entities: list, 198 | x: str, 199 | template: str = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE, 200 | ) -> str: 201 | """Builds a prompt for named entity recognition. 202 | 203 | Parameters 204 | ---------- 205 | entities : list 206 | list of entities to recognize 207 | x : str 208 | sample to recognize entities in 209 | template : str, optional 210 | prompt template to use, must contain placeholders for all variables, by default EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE 211 | 212 | Returns 213 | ------- 214 | str 215 | prepared prompt 216 | """ 217 | return template.format(entities=entities, x=x) 218 | -------------------------------------------------------------------------------- /skllm/prompts/templates.py: -------------------------------------------------------------------------------- 1 | ZERO_SHOT_CLF_PROMPT_TEMPLATE = """ 2 | You will be provided with the following information: 3 | 1. An arbitrary text sample. The sample is delimited with triple backticks. 4 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. 5 | 6 | Perform the following tasks: 7 | 1. Identify to which category the provided text belongs to with the highest probability. 8 | 2. Assign the provided text to that category. 9 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON. 
10 | 11 | List of categories: {labels} 12 | 13 | Text sample: ```{x}``` 14 | 15 | Your JSON response: 16 | """ 17 | 18 | COT_CLF_PROMPT_TEMPLATE = """ 19 | You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines: 20 | 21 | 1. The text intended for classification is presented between triple backticks. 22 | 2. The possible categories are enumerated in square brackets, with each category enclosed in single quotes and separated by commas. 23 | 24 | Tasks: 25 | 1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed. 26 | 2. Determine and select the most appropriate category for the text based on your comprehensive justifications. 27 | 3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category. 28 | 29 | Category List: {labels} 30 | 31 | Text Sample: ```{x}``` 32 | 33 | Provide your JSON response below, ensuring that justifications for all categories are clearly detailed: 34 | """ 35 | 36 | ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE = """ 37 | Classify the following text into one of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`. 38 | Text: ```{x}``` 39 | """ 40 | 41 | ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE = """ 42 | Classify the following text into at least 1 but up to {max_cats} of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`. 43 | Text: ```{x}``` 44 | """ 45 | 46 | FEW_SHOT_CLF_PROMPT_TEMPLATE = """ 47 | You will be provided with the following information: 48 | 1. An arbitrary text sample. The sample is delimited with triple backticks. 49 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. 
The categories in the list are enclosed in the single quotes and comma separated. 50 | 3. Examples of text samples and their assigned categories. The examples are delimited with triple backticks. The assigned categories are enclosed in a list-like structure. These examples are to be used as training data. 51 | 52 | Perform the following tasks: 53 | 1. Identify to which category the provided text belongs to with the highest probability. 54 | 2. Assign the provided text to that category. 55 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON. 56 | 57 | List of categories: {labels} 58 | 59 | Training data: 60 | {training_data} 61 | 62 | Text sample: ```{x}``` 63 | 64 | Your JSON response: 65 | """ 66 | 67 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE = """ 68 | You will be provided with the following information: 69 | 1. An arbitrary text sample. The sample is delimited with triple backticks. 70 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. 71 | 3. Examples of text samples and their assigned categories. The examples are delimited with triple backticks. The assigned categories are enclosed in a list-like structure. These examples are to be used as training data. 72 | 73 | Perform the following tasks: 74 | 1. Identify to which category the provided text belongs to with the highest probability. 75 | 2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities. 76 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON. 
77 | 78 | List of categories: {labels} 79 | 80 | Training data: 81 | {training_data} 82 | 83 | Text sample: ```{x}``` 84 | 85 | Your JSON response: 86 | """ 87 | 88 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE = """ 89 | You will be provided with the following information: 90 | 1. An arbitrary text sample. The sample is delimited with triple backticks. 91 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in the single quotes and comma separated. The text sample belongs to at least one category but cannot exceed {max_cats}. 92 | 93 | Perform the following tasks: 94 | 1. Identify to which categories the provided text belongs to with the highest probability. 95 | 2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities. 96 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON. 97 | 98 | List of categories: {labels} 99 | 100 | Text sample: ```{x}``` 101 | 102 | Your JSON response: 103 | """ 104 | 105 | COT_MLCLF_PROMPT_TEMPLATE = """ 106 | You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines: 107 | 108 | 1. The text intended for classification is presented between triple backticks. 109 | 2. The possible categories are enumerated in square brackets, with each category enclosed in quotes and separated by commas. 110 | 111 | Tasks: 112 | 1. Examine the text and provide detailed justifications for the possibility of the text belonging or not belonging to each category listed. 113 | 2. Determine and select at most {max_cats} most appropriate categories for the text based on your comprehensive justifications. 114 | 3. Format your decision into a JSON object containing two keys: `explanation` and `label`. 
The `explanation` should concisely capture the rationale for each category before concluding with the chosen category. The `label` should contain an array of the chosen categories. 115 | 116 | Category List: {labels} 117 | 118 | Text Sample: ```{x}``` 119 | 120 | Provide your JSON response below, ensuring that justifications for all categories are clearly detailed: 121 | """ 122 | 123 | SUMMARY_PROMPT_TEMPLATE = """ 124 | Your task is to generate a summary of the text sample. 125 | Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words. 126 | 127 | Text sample: ```{x}``` 128 | Summarized text: 129 | """ 130 | 131 | FOCUSED_SUMMARY_PROMPT_TEMPLATE = """ 132 | As an input you will receive: 133 | 1. A focus parameter delimited with square brackets. 134 | 2. A single text sample delimited with triple backticks. 135 | 136 | Perform the following actions: 137 | 1. Determine whether there is something in the text that matches focus. Do not output anything. 138 | 2. Summarise the text in at most {max_words} words. 139 | 3. If possible, make the summarisation focused on the concept provided in the focus parameter. Otherwise, provide a general summarisation. Do not state that general summary is provided. 140 | 4. Do not output anything except of the summary. Do not output any text that was not present in the original text. 141 | 5. If no focused summary possible, or the mentioned concept is not present in the text, output "Mentioned concept is not present in the text." and the general summary. Do not state that general summary is provided. 142 | 143 | Focus: [{focus}] 144 | 145 | Text sample: ```{x}``` 146 | 147 | Summarized text: 148 | """ 149 | 150 | TRANSLATION_PROMPT_TEMPLATE = """ 151 | If the original text, delimited by triple backticks, is already in {output_language} language, output the original text. 
152 | Otherwise, translate the original text, delimited by triple backticks, to {output_language} language, and output the translated text only. Do not output any additional information except the translated text. 153 | 154 | Original text: ```{x}``` 155 | Output: 156 | """ 157 | 158 | NER_SYSTEM_MESSAGE_TEMPLATE = """You are an expert in Natural Language Processing. Your task is to identify common Named Entities (NER) in a text provided by the user. 159 | Mark the entities with tags according to the following guidelines: 160 | - Use XML format to tag entities; 161 | - All entities must be enclosed in ... tags; All other text must be enclosed in ... tags; No content should be outside of these tags; 162 | - The tagging operation must be invertible, i.e. the original text must be recoverable from the tagged textl; This is crucial and easy to overlook, double-check this requirement; 163 | - Adjacent entities should be separated into different tags; 164 | - The list of entities is strictly restricted to the following: {entities}. 165 | """ 166 | 167 | NER_SYSTEM_MESSAGE_SPARSE = """You are an expert in Natural Language Processing.""" 168 | 169 | EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only: 170 | {entities} 171 | 172 | For each entity, provide a brief explanation for your choice within an XML comment. Use the following XML tag format for each entity: 173 | 174 | Your reasoning hereENTITY_NAME_UPPERCASEEntity text 175 | 176 | The remaining text must be enclosed in a TEXT tag. 177 | 178 | Focus on the context and meaning of each entity rather than just the exact words. The tags should encompass the entire entity based on its definition and usage in the sentence. It is crucial to base your decision on the description of the entity, not just its name. 
179 | 180 | Format example: 181 | 182 | Input: 183 | ```This text contains some entity and another entity.``` 184 | 185 | Output: 186 | ```xml 187 | This text contains some justificationENTITY1some entity and another another justificationENTITY2entity. 188 | ``` 189 | 190 | Input: 191 | ``` 192 | {x} 193 | ``` 194 | 195 | Output (origina text with tags): 196 | """ 197 | 198 | 199 | EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only: 200 | {entities} 201 | 202 | You must provide the following information for each entity: 203 | - The reasoning of why you tagged the entity as such; Based on the reasoning, a non-expert should be able to evaluate your decision; 204 | - The tag of the entity (uppercase); 205 | - The value of the entity (as it appears in the text). 206 | 207 | Your response should be json formatted using the following schema: 208 | 209 | {{ 210 | "$schema": "http://json-schema.org/draft-04/schema#", 211 | "type": "array", 212 | "items": [ 213 | {{ 214 | "type": "object", 215 | "properties": {{ 216 | "reasoning": {{ 217 | "type": "string" 218 | }}, 219 | "tag": {{ 220 | "type": "string" 221 | }}, 222 | "value": {{ 223 | "type": "string" 224 | }} 225 | }}, 226 | "required": [ 227 | "reasoning", 228 | "tag", 229 | "value" 230 | ] 231 | }} 232 | ] 233 | }} 234 | 235 | 236 | Input: 237 | ``` 238 | {x} 239 | ``` 240 | 241 | Output json: 242 | """ 243 | -------------------------------------------------------------------------------- /skllm/text2text.py: -------------------------------------------------------------------------------- 1 | ## GPT 2 | from skllm.models.gpt.text2text.summarization import GPTSummarizer 3 | from skllm.models.gpt.text2text.translation import GPTTranslator 4 | from skllm.models.gpt.text2text.tunable import TunableGPTText2Text 5 | 6 | ## Vertex 7 | 8 | from skllm.models.vertex.text2text.tunable import 
TunableVertexText2Text 9 | -------------------------------------------------------------------------------- /skllm/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | import numpy as np 4 | import pandas as pd 5 | from functools import wraps 6 | from time import sleep 7 | import re 8 | 9 | def to_numpy(X: Any) -> np.ndarray: 10 | """Converts a pandas Series or list to a numpy array. 11 | 12 | Parameters 13 | ---------- 14 | X : Any 15 | The data to convert to a numpy array. 16 | 17 | Returns 18 | ------- 19 | X : np.ndarray 20 | """ 21 | if isinstance(X, pd.Series): 22 | X = X.to_numpy().astype(object) 23 | elif isinstance(X, list): 24 | X = np.asarray(X, dtype=object) 25 | if isinstance(X, np.ndarray) and len(X.shape) > 1: 26 | # do not squeeze the first dim 27 | X = np.squeeze(X, axis=tuple([i for i in range(1, len(X.shape))])) 28 | return X 29 | 30 | # TODO: replace with re version below 31 | def find_json_in_string(string: str) -> str: 32 | """Finds the JSON object in a string. 33 | 34 | Parameters 35 | ---------- 36 | string : str 37 | The string to search for a JSON object. 38 | 39 | Returns 40 | ------- 41 | json_string : str 42 | """ 43 | start = string.find("{") 44 | end = string.rfind("}") 45 | if start != -1 and end != -1: 46 | json_string = string[start : end + 1] 47 | else: 48 | json_string = "{}" 49 | return json_string 50 | 51 | 52 | 53 | def re_naive_json_extractor(json_string: str, expected_output: str = "object") -> str: 54 | """Finds the first JSON-like object or array in a string using regex. 55 | 56 | Parameters 57 | ---------- 58 | string : str 59 | The string to search for a JSON object or array. 60 | 61 | Returns 62 | ------- 63 | json_string : str 64 | A JSON string if found, otherwise an empty JSON object. 
65 | """ 66 | json_pattern = json_pattern = r'(\{.*\}|\[.*\])' 67 | match = re.search(json_pattern, json_string, re.DOTALL) 68 | if match: 69 | return match.group(0) 70 | else: 71 | return r"{}" if expected_output == "object" else "[]" 72 | 73 | 74 | 75 | 76 | def extract_json_key(json_: str, key: str): 77 | """Extracts JSON key from a string. 78 | 79 | json_ : str 80 | The JSON string to extract the key from. 81 | key : str 82 | The key to extract. 83 | """ 84 | original_json = json_ 85 | for i in range(2): 86 | try: 87 | json_ = original_json.replace("\n", "") 88 | if i == 1: 89 | json_ = json_.replace("'", '"') 90 | json_ = find_json_in_string(json_) 91 | as_json = json.loads(json_) 92 | if key not in as_json.keys(): 93 | raise KeyError("The required key was not found") 94 | return as_json[key] 95 | except Exception: 96 | if i == 0: 97 | continue 98 | return None 99 | 100 | 101 | def retry(max_retries=3): 102 | def decorator(func): 103 | @wraps(func) 104 | def wrapper(*args, **kwargs): 105 | for attempt in range(max_retries): 106 | try: 107 | return func(*args, **kwargs) 108 | except Exception as e: 109 | error_msg = str(e) 110 | error_type = type(e).__name__ 111 | sleep(2**attempt) 112 | err_msg = ( 113 | f"Could not complete the operation after {max_retries} retries:" 114 | f" `{error_type} :: {error_msg}`" 115 | ) 116 | print(err_msg) 117 | raise RuntimeError(err_msg) 118 | 119 | return wrapper 120 | 121 | return decorator 122 | -------------------------------------------------------------------------------- /skllm/utils/rendering.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import html 4 | from typing import Dict, List 5 | 6 | color_palettes = { 7 | "light": [ 8 | "lightblue", 9 | "lightgreen", 10 | "lightcoral", 11 | "lightsalmon", 12 | "lightyellow", 13 | "lightpink", 14 | "lightgray", 15 | "lightcyan", 16 | ], 17 | "dark": [ 18 | "darkblue", 19 | "darkgreen", 20 | "darkred", 21 | 
"darkorange", 22 | "darkgoldenrod", 23 | "darkmagenta", 24 | "darkgray", 25 | "darkcyan", 26 | ], 27 | } 28 | 29 | 30 | def get_random_color(): 31 | return f"#{random.randint(0, 0xFFFFFF):06x}" 32 | 33 | 34 | # def validate_text(input_text, output_text): 35 | # # Verify the original text was not changed (other than addition of tags) 36 | # stripped_output_text = re.sub(r'<<.*?>>', '', output_text) 37 | # stripped_output_text = re.sub(r'<>', '', stripped_output_text) 38 | # if not all(word in stripped_output_text.split() for word in input_text.split()): 39 | # raise ValueError("Original text was altered.") 40 | # return True 41 | 42 | 43 | # TODO: In the future this should probably be replaced with a proper HTML template 44 | def render_ner(output_texts, allowed_entities): 45 | entity_colors = {} 46 | all_entities = [k.upper() for k in allowed_entities.keys()] 47 | 48 | for i, entity in enumerate(all_entities): 49 | if i < len(color_palettes["light"]): 50 | entity_colors[entity] = { 51 | "light": color_palettes["light"][i], 52 | "dark": color_palettes["dark"][i], 53 | } 54 | else: 55 | random_color = get_random_color() 56 | entity_colors[entity] = {"light": random_color, "dark": random_color} 57 | 58 | def replace_match(match): 59 | reasoning, entity, text = match.groups() 60 | entity = entity.upper() 61 | return ( 62 | f'{text}' 64 | ) 65 | 66 | legend_html = "
" 67 | legend_html += "" 76 | legend_html += "Entities: " 77 | for entity in entity_colors.keys(): 78 | description = allowed_entities.get(entity, "No description") 79 | legend_html += ( 80 | f'{entity} ' 82 | ) 83 | legend_html += "

" 84 | 85 | css = "" 100 | 101 | rendered_html = "" 102 | for output_text in output_texts: 103 | none_pattern = re.compile(r"(.*?)") 104 | output_text = none_pattern.sub(r'\1', output_text) 105 | pattern = re.compile(r"(.*?)(.*?)(.*?)") 106 | highlighted_html = pattern.sub(replace_match, output_text) 107 | rendered_html += highlighted_html + "
" 108 | 109 | return css + legend_html + rendered_html 110 | 111 | 112 | def display_ner(output_texts: List[str], allowed_entities: Dict[str, str]): 113 | rendered_html = render_ner(output_texts, allowed_entities) 114 | if is_running_in_jupyter(): 115 | from IPython.display import display, HTML 116 | 117 | display(HTML(rendered_html)) 118 | else: 119 | with open("skllm_ner_output.html", "w") as f: 120 | f.write(rendered_html) 121 | try: 122 | import webbrowser 123 | 124 | webbrowser.open("skllm_ner_output.html") 125 | except Exception: 126 | print( 127 | "Output saved to 'skllm_ner_output.html', please open it in a browser." 128 | ) 129 | 130 | 131 | def is_running_in_jupyter(): 132 | try: 133 | from IPython import get_ipython 134 | 135 | if "IPKernelApp" in get_ipython().config: 136 | return True 137 | except Exception: 138 | return False 139 | return False 140 | -------------------------------------------------------------------------------- /skllm/utils/xml.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def filter_xml_tags(xml_string, tags): 5 | pattern = "|".join(f"<{tag}>.*?" 
for tag in tags) 6 | regex = re.compile(pattern, re.DOTALL) 7 | matches = regex.findall(xml_string) 8 | return "".join(matches) 9 | 10 | 11 | def filter_unwanted_entities(xml_string, allowed_entities): 12 | allowed_values_pattern = "|".join(allowed_entities.keys()) 13 | replacement = r"\3" 14 | pattern = rf"(.*?)(?!{allowed_values_pattern})(.*?)(.*?)" 15 | return re.sub(pattern, replacement, xml_string) 16 | 17 | 18 | def replace_all_at_once(text, replacements): 19 | sorted_keys = sorted(replacements, key=len, reverse=True) 20 | regex = re.compile(r"(" + "|".join(map(re.escape, sorted_keys)) + r")") 21 | return regex.sub(lambda match: replacements[match.group(0)], text) 22 | 23 | 24 | def json_to_xml( 25 | original_text: str, 26 | tags: list, 27 | tag_root: str, 28 | non_tag_root: str, 29 | value_key: str = "value", 30 | attributes: list = None, 31 | ): 32 | 33 | if len(tags) == 0: 34 | return f"<{non_tag_root}>{original_text}" 35 | 36 | if attributes is None: 37 | attributes = tags[0].keys() 38 | 39 | replacements = {} 40 | for item in tags: 41 | value = item.get(value_key, "") 42 | if not value: 43 | continue 44 | 45 | attribute_parts = [] 46 | for attr in attributes: 47 | if attr in item: 48 | attribute_parts.append(f"<{attr}>{item[attr]}") 49 | attribute_str = "".join(attribute_parts) 50 | replacements[value] = f"<{tag_root}>{attribute_str}" 51 | original_text = replace_all_at_once(original_text, replacements) 52 | 53 | parts = re.split(f"(<{tag_root}>.*?)", original_text) 54 | final_text = "" 55 | for part in parts: 56 | if not part.startswith(f"<{tag_root}>"): 57 | final_text += f"<{non_tag_root}>{part}" 58 | else: 59 | final_text += part 60 | return final_text 61 | -------------------------------------------------------------------------------- /skllm/vectorization.py: -------------------------------------------------------------------------------- 1 | from skllm.models.gpt.vectorization import GPTVectorizer 2 | 
-------------------------------------------------------------------------------- /tests/llm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/tests/llm/__init__.py -------------------------------------------------------------------------------- /tests/llm/anthropic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/tests/llm/anthropic/__init__.py -------------------------------------------------------------------------------- /tests/llm/anthropic/test_anthropic_mixins.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | import json 4 | from skllm.llm.anthropic.mixin import ( 5 | ClaudeMixin, 6 | ClaudeTextCompletionMixin, 7 | ClaudeClassifierMixin, 8 | ) 9 | 10 | 11 | class TestClaudeMixin(unittest.TestCase): 12 | def test_ClaudeMixin(self): 13 | mixin = ClaudeMixin() 14 | mixin._set_keys("test_key") 15 | self.assertEqual(mixin._get_claude_key(), "test_key") 16 | 17 | 18 | class TestClaudeTextCompletionMixin(unittest.TestCase): 19 | @patch("skllm.llm.anthropic.mixin.get_chat_completion") 20 | def test_chat_completion_with_valid_params(self, mock_get_chat_completion): 21 | mixin = ClaudeTextCompletionMixin() 22 | mixin._set_keys("test_key") 23 | 24 | mock_get_chat_completion.return_value = { 25 | "content": [ 26 | {"type": "text", "text": "test response"} 27 | ] 28 | } 29 | 30 | completion = mixin._get_chat_completion( 31 | model="claude-3-haiku-20240307", 32 | messages="Hello", 33 | system_message="Test system" 34 | ) 35 | 36 | self.assertEqual( 37 | mixin._convert_completion_to_str(completion), 38 | "test response" 39 | ) 40 | mock_get_chat_completion.assert_called_once() 41 | 42 | 43 | class 
TestClaudeClassifierMixin(unittest.TestCase): 44 | @patch("skllm.llm.anthropic.mixin.get_chat_completion") 45 | def test_extract_out_label_with_valid_completion(self, mock_get_chat_completion): 46 | mixin = ClaudeClassifierMixin() 47 | mixin._set_keys("test_key") 48 | 49 | mock_get_chat_completion.return_value = { 50 | "content": [ 51 | {"type": "text", "text": '{"label":"hello world"}'} 52 | ] 53 | } 54 | 55 | completion = mixin._get_chat_completion( 56 | model="claude-3-haiku-20240307", 57 | messages="Hello", 58 | system_message="World" 59 | ) 60 | self.assertEqual(mixin._extract_out_label(completion), "hello world") 61 | mock_get_chat_completion.assert_called_once() -------------------------------------------------------------------------------- /tests/llm/gpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/tests/llm/gpt/__init__.py -------------------------------------------------------------------------------- /tests/llm/gpt/test_gpt_mixins.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.mock import patch 3 | import json 4 | from skllm.llm.gpt.mixin import ( 5 | construct_message, 6 | _build_clf_example, 7 | GPTMixin, 8 | GPTTextCompletionMixin, 9 | GPTClassifierMixin, 10 | ) 11 | 12 | 13 | class TestGPTMixin(unittest.TestCase): 14 | def test_construct_message(self): 15 | self.assertEqual( 16 | construct_message("user", "Hello"), {"role": "user", "content": "Hello"} 17 | ) 18 | with self.assertRaises(ValueError): 19 | construct_message("invalid_role", "Hello") 20 | 21 | def test_build_clf_example(self): 22 | x = "Hello" 23 | y = "Hi" 24 | system_msg = "You are a text classification model." 
25 | expected_output = json.dumps( 26 | { 27 | "messages": [ 28 | {"role": "system", "content": system_msg}, 29 | {"role": "user", "content": x}, 30 | {"role": "assistant", "content": y}, 31 | ] 32 | } 33 | ) 34 | self.assertEqual(_build_clf_example(x, y, system_msg), expected_output) 35 | 36 | def test_GPTMixin(self): 37 | mixin = GPTMixin() 38 | mixin._set_keys("test_key", "test_org") 39 | self.assertEqual(mixin._get_openai_key(), "test_key") 40 | self.assertEqual(mixin._get_openai_org(), "test_org") 41 | 42 | 43 | class TestGPTTextCompletionMixin(unittest.TestCase): 44 | @patch("skllm.llm.gpt.mixin.get_chat_completion") 45 | def test_chat_completion_with_valid_params(self, mock_get_chat_completion): 46 | # Setup 47 | mixin = GPTTextCompletionMixin() 48 | mixin._set_keys("test_key", "test_org") 49 | mock_get_chat_completion.return_value = { 50 | "choices": [{"message": {"content": "test response"}}] 51 | } 52 | 53 | _ = mixin._get_chat_completion("test-model", "Hello", "World") 54 | 55 | mock_get_chat_completion.assert_called_once() 56 | 57 | 58 | class TestGPTClassifierMixin(unittest.TestCase): 59 | @patch("skllm.llm.gpt.mixin.get_chat_completion") 60 | def test_extract_out_label_with_valid_completion(self, mock_get_chat_completion): 61 | mixin = GPTClassifierMixin() 62 | mixin._set_keys("test_key", "test_org") 63 | mock_get_chat_completion.return_value = { 64 | "choices": [{"message": {"content": '{"label":"hello world"}'}}] 65 | } 66 | res = mixin._get_chat_completion("test-model", "Hello", "World") 67 | self.assertEqual(mixin._extract_out_label(res), "hello world") 68 | mock_get_chat_completion.assert_called_once() 69 | 70 | 71 | if __name__ == "__main__": 72 | unittest.main() 73 | -------------------------------------------------------------------------------- /tests/llm/vertex/__init__.py: -------------------------------------------------------------------------------- 
# -------------------------------------------------------------------------
# tests/llm/vertex/test_vertex_mixins.py
# -------------------------------------------------------------------------
import unittest
from unittest.mock import patch
import json
# Fixed: was `from skllm.llm.vertex.mixin import *`. Wildcard imports are
# discouraged (PEP 8) — they hide which names the tests depend on and can
# shadow local names. Only these two mixins are actually used below.
from skllm.llm.vertex.mixin import VertexTextCompletionMixin, VertexClassifierMixin


class TestVertexTextCompletionMixin(unittest.TestCase):
    """Tests for the chat-completion wrapper of the Vertex text mixin."""

    @patch("skllm.llm.vertex.mixin.get_completion")
    def test_chat_completion_with_valid_params(self, mock_get_chat_completion):
        # Vertex returns the completion as a plain string.
        mixin = VertexTextCompletionMixin()
        mock_get_chat_completion.return_value = "res"
        _ = mixin._get_chat_completion("test-model", "Hello", "World")
        mock_get_chat_completion.assert_called_once()


class TestVertexClassifierMixin(unittest.TestCase):
    """Tests label extraction from a Vertex string completion."""

    @patch("skllm.llm.vertex.mixin.get_completion")
    def test_extract_out_label_with_valid_completion(self, mock_get_chat_completion):
        mixin = VertexClassifierMixin()
        mock_get_chat_completion.return_value = '{"label":"hello world"}'
        res = mixin._get_chat_completion("test-model", "Hello", "World")
        self.assertEqual(mixin._extract_out_label(res), "hello world")
        mock_get_chat_completion.assert_called_once()


if __name__ == "__main__":
    unittest.main()

# -------------------------------------------------------------------------
# tests/test_utils.py
# -------------------------------------------------------------------------
import unittest
import pandas as pd
import numpy as np
from skllm import utils


class TestUtils(unittest.TestCase):
    """Tests for skllm.utils conversion and JSON-extraction helpers."""

    def test_to_numpy(self):
        # Test with pandas Series
        series = pd.Series([1, 2, 3])
        result = utils.to_numpy(series)
        self.assertIsInstance(result, np.ndarray)
        self.assertEqual(result.tolist(), [1, 2, 3])
15 | # Test with list 16 | list_data = [4, 5, 6] 17 | result = utils.to_numpy(list_data) 18 | self.assertIsInstance(result, np.ndarray) 19 | self.assertEqual(result.tolist(), [4, 5, 6]) 20 | 21 | def test_find_json_in_string(self): 22 | # Test with string containing JSON 23 | string = 'Hello {"name": "John", "age": 30} World' 24 | result = utils.find_json_in_string(string) 25 | self.assertEqual(result, '{"name": "John", "age": 30}') 26 | 27 | # Test with string without JSON 28 | string = "Hello World" 29 | result = utils.find_json_in_string(string) 30 | self.assertEqual(result, "{}") 31 | 32 | 33 | if __name__ == "__main__": 34 | unittest.main() 35 | --------------------------------------------------------------------------------