├── .github
│   ├── assets
│   │   └── badges
│   │       └── .gitkeep
│   ├── dependabot.yml
│   ├── pr-labeler.yml
│   ├── release-drafter.yml
│   └── workflows
│       ├── draft.yml
│       ├── pr_labeler.yml
│       ├── pre_commit_auto_update.yml
│       └── pypi-deploy.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── pyproject.toml
├── requirements-dev.txt
├── skllm
│   ├── __init__.py
│   ├── classification.py
│   ├── config.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── multi_class.py
│   │   ├── multi_label.py
│   │   ├── summarization.py
│   │   └── translation.py
│   ├── llm
│   │   ├── anthropic
│   │   │   ├── completion.py
│   │   │   ├── credentials.py
│   │   │   └── mixin.py
│   │   ├── base.py
│   │   ├── gpt
│   │   │   ├── clients
│   │   │   │   ├── llama_cpp
│   │   │   │   │   ├── completion.py
│   │   │   │   │   └── handler.py
│   │   │   │   └── openai
│   │   │   │       ├── completion.py
│   │   │   │       ├── credentials.py
│   │   │   │       ├── embedding.py
│   │   │   │       └── tuning.py
│   │   │   ├── completion.py
│   │   │   ├── embedding.py
│   │   │   ├── mixin.py
│   │   │   └── utils.py
│   │   └── vertex
│   │       ├── completion.py
│   │       ├── mixin.py
│   │       └── tuning.py
│   ├── memory
│   │   ├── __init__.py
│   │   ├── _annoy.py
│   │   ├── _sklearn_nn.py
│   │   └── base.py
│   ├── models
│   │   ├── _base
│   │   │   ├── classifier.py
│   │   │   ├── tagger.py
│   │   │   ├── text2text.py
│   │   │   └── vectorizer.py
│   │   ├── anthropic
│   │   │   ├── classification
│   │   │   │   ├── few_shot.py
│   │   │   │   └── zero_shot.py
│   │   │   ├── tagging
│   │   │   │   └── ner.py
│   │   │   └── text2text
│   │   │       ├── __init__.py
│   │   │       ├── summarization.py
│   │   │       └── translation.py
│   │   ├── gpt
│   │   │   ├── classification
│   │   │   │   ├── few_shot.py
│   │   │   │   ├── tunable.py
│   │   │   │   └── zero_shot.py
│   │   │   ├── tagging
│   │   │   │   └── ner.py
│   │   │   ├── text2text
│   │   │   │   ├── __init__.py
│   │   │   │   ├── summarization.py
│   │   │   │   ├── translation.py
│   │   │   │   └── tunable.py
│   │   │   └── vectorization.py
│   │   └── vertex
│   │       ├── classification
│   │       │   ├── tunable.py
│   │       │   └── zero_shot.py
│   │       └── text2text
│   │           ├── __init__.py
│   │           └── tunable.py
│   ├── prompts
│   │   ├── builders.py
│   │   └── templates.py
│   ├── text2text.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── rendering.py
│   │   └── xml.py
│   └── vectorization.py
└── tests
    ├── llm
    │   ├── __init__.py
    │   ├── anthropic
    │   │   ├── __init__.py
    │   │   └── test_anthropic_mixins.py
    │   ├── gpt
    │   │   ├── __init__.py
    │   │   └── test_gpt_mixins.py
    │   └── vertex
    │       ├── __init__.py
    │       └── test_vertex_mixins.py
    └── test_utils.py
/.github/assets/badges/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/.github/assets/badges/.gitkeep
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "monthly"
7 | - package-ecosystem: "github-actions"
8 | directory: "/"
9 | schedule:
10 | interval: "monthly"
11 |
--------------------------------------------------------------------------------
/.github/pr-labeler.yml:
--------------------------------------------------------------------------------
1 | feature: ['features/*', 'feature/*', 'feat/*', 'features-*', 'feature-*', 'feat-*']
2 | fix: ['fixes/*', 'fix/*']
3 | chore: ['chore/*']
4 | dependencies: ['update/*']
5 |
--------------------------------------------------------------------------------
/.github/release-drafter.yml:
--------------------------------------------------------------------------------
1 | name-template: "v$RESOLVED_VERSION"
2 | tag-template: "v$RESOLVED_VERSION"
3 | categories:
4 | - title: "🚀 Features"
5 | labels:
6 | - "feature"
7 | - "enhancement"
8 | - title: "🐛 Bug Fixes"
9 | labels:
10 | - "fix"
11 | - "bugfix"
12 | - "bug"
13 | - title: "🧹 Maintenance"
14 | labels:
15 | - "maintenance"
16 | - "dependencies"
17 | - "refactoring"
18 | - "cosmetic"
19 | - "chore"
20 | - title: "📝️ Documentation"
21 | labels:
22 | - "documentation"
23 | - "docs"
24 | change-template: "- $TITLE (#$NUMBER)"
25 | change-title-escapes: '\<*_&' # You can add # and @ to disable mentions
26 | version-resolver:
27 | major:
28 | labels:
29 | - "major"
30 | minor:
31 | labels:
32 | - "minor"
33 | patch:
34 | labels:
35 | - "patch"
36 | default: patch
37 | template: |
38 | ## Changes
39 |
40 | $CHANGES
41 |
--------------------------------------------------------------------------------
/.github/workflows/draft.yml:
--------------------------------------------------------------------------------
1 | # Drafts the next Release notes as Pull Requests are merged (or commits are pushed) into "main" or "master"
2 | name: Draft next release
3 |
4 | on:
5 | push:
6 | branches: [main, "master"]
7 |
8 | permissions:
9 | contents: read
10 |
11 | jobs:
12 | update-release-draft:
13 | permissions:
14 | contents: write
15 | pull-requests: write
16 | runs-on: ubuntu-latest
17 | steps:
18 | - uses: release-drafter/release-drafter@v6
19 | env:
20 | GITHUB_TOKEN: ${{ github.token }}
21 |
--------------------------------------------------------------------------------
/.github/workflows/pr_labeler.yml:
--------------------------------------------------------------------------------
1 | # This workflow will apply the corresponding label on a pull request
2 | name: PR Labeler
3 |
4 | on:
5 | pull_request_target:
6 |
7 | permissions:
8 | contents: read
9 | pull-requests: write
10 |
11 | jobs:
12 | pr-labeler:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: TimonVS/pr-labeler-action@v5
16 | with:
17 | repo-token: ${{ github.token }}
18 |
--------------------------------------------------------------------------------
/.github/workflows/pre_commit_auto_update.yml:
--------------------------------------------------------------------------------
1 | # Run a pre-commit autoupdate every week and open a pull request if needed
2 | name: Pre-commit auto-update
3 |
4 | on:
5 | # At 00:00 on the 1st of every month.
6 | schedule:
7 | - cron: "0 0 1 * *"
8 | workflow_dispatch:
9 |
10 | permissions:
11 | contents: write
12 | pull-requests: write
13 |
14 | jobs:
15 | pre-commit-auto-update:
16 | runs-on: ubuntu-latest
17 | steps:
18 | - name: Checkout
19 | uses: actions/checkout@v4
20 | - name: Set up Python
21 | uses: actions/setup-python@v5
22 | - name: Install pre-commit
23 | run: pip install pre-commit
24 | - name: Run pre-commit
25 | run: pre-commit autoupdate
26 | - name: Set git config
27 | run: |
28 | git config --local user.email "action@github.com"
29 | git config --local user.name "GitHub Action"
30 | - uses: peter-evans/create-pull-request@v6
31 | with:
32 | token: ${{ github.token }}
33 | branch: update/pre-commit-hooks
34 | title: Update pre-commit hooks
35 | commit-message: "Update pre-commit hooks"
36 | body: Update versions of pre-commit hooks to the latest versions.
37 | labels: "dependencies,github_actions"
38 |
--------------------------------------------------------------------------------
/.github/workflows/pypi-deploy.yml:
--------------------------------------------------------------------------------
1 | name: PyPI Deploy
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | deploy:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Checkout code
13 | uses: actions/checkout@v4
14 |
15 | - name: Setup Python
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.10'
19 |
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install build twine
24 |
25 | - name: Build and publish
26 | env:
27 | TWINE_USERNAME: __token__
28 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
29 | run: |
30 | python -m build
31 | twine upload dist/*
32 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | test.py
162 | tmp.ipynb
163 | tmp.py
164 | *.pickle
165 | *.ipynb
166 |
167 | # vscode
168 | .vscode/
169 | tmp2.py
170 | tmp.*
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: trailing-whitespace
6 | - id: end-of-file-fixer
7 | - id: check-docstring-first
8 | - id: check-xml
9 | - id: check-json
10 | - id: check-yaml
11 | - id: check-toml
12 | - id: debug-statements
13 | - id: check-executables-have-shebangs
14 | - id: check-case-conflict
15 | - id: check-added-large-files
16 | - id: detect-private-key
17 | # Formatter for Json and Yaml files
18 | - repo: https://github.com/pre-commit/mirrors-prettier
19 | rev: v3.0.0-alpha.9-for-vscode
20 | hooks:
21 | - id: prettier
22 | types: [json, yaml, toml]
23 | # Formatter for markdown files
24 | - repo: https://github.com/executablebooks/mdformat
25 | rev: 0.7.16
26 | hooks:
27 | - id: mdformat
28 | args: ["--number"]
29 | additional_dependencies:
30 | - mdformat-gfm
31 | - mdformat-tables
32 | - mdformat-frontmatter
33 | - mdformat-black
34 | - mdformat-shfmt
35 | # An extremely fast Python linter, written in Rust
36 | - repo: https://github.com/charliermarsh/ruff-pre-commit
37 | rev: "v0.0.263"
38 | hooks:
39 | - id: ruff
40 | args: [--fix, --exit-non-zero-on-fix]
41 | # Python code formatter
42 | - repo: https://github.com/psf/black
43 | rev: 23.3.0
44 | hooks:
45 | - id: black
46 | args: ["--config", "pyproject.toml"]
47 | # Python's import formatter
48 | - repo: https://github.com/PyCQA/isort
49 | rev: 5.12.0
50 | hooks:
51 | - id: isort
52 | # Formats docstrings to follow PEP 257
53 | - repo: https://github.com/PyCQA/docformatter
54 | rev: v1.6.4
55 | hooks:
56 | - id: docformatter
57 | additional_dependencies: [tomli]
58 | args: ["--in-place", "--config", "pyproject.toml"]
59 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | We as members, contributors, and leaders pledge to make participation in our
6 | community a harassment-free experience for everyone, regardless of age, body
7 | size, visible or invisible disability, ethnicity, sex characteristics, gender
8 | identity and expression, level of experience, education, socio-economic status,
9 | nationality, personal appearance, race, religion, or sexual identity
10 | and orientation.
11 |
12 | We pledge to act and interact in ways that contribute to an open, welcoming,
13 | diverse, inclusive, and healthy community.
14 |
15 | ## Our Standards
16 |
17 | Examples of behavior that contributes to a positive environment for our
18 | community include:
19 |
20 | * Demonstrating empathy and kindness toward other people
21 | * Being respectful of differing opinions, viewpoints, and experiences
22 | * Giving and gracefully accepting constructive feedback
23 | * Accepting responsibility and apologizing to those affected by our mistakes,
24 | and learning from the experience
25 | * Focusing on what is best not just for us as individuals, but for the
26 | overall community
27 |
28 | Examples of unacceptable behavior include:
29 |
30 | * The use of sexualized language or imagery, and sexual attention or
31 | advances of any kind
32 | * Trolling, insulting or derogatory comments, and personal or political attacks
33 | * Public or private harassment
34 | * Publishing others' private information, such as a physical or email
35 | address, without their explicit permission
36 | * Other conduct which could reasonably be considered inappropriate in a
37 | professional setting
38 |
39 | ## Enforcement Responsibilities
40 |
41 | Community leaders are responsible for clarifying and enforcing our standards of
42 | acceptable behavior and will take appropriate and fair corrective action in
43 | response to any behavior that they deem inappropriate, threatening, offensive,
44 | or harmful.
45 |
46 | Community leaders have the right and responsibility to remove, edit, or reject
47 | comments, commits, code, wiki edits, issues, and other contributions that are
48 | not aligned to this Code of Conduct, and will communicate reasons for moderation
49 | decisions when appropriate.
50 |
51 | ## Scope
52 |
53 | This Code of Conduct applies within all community spaces, and also applies when
54 | an individual is officially representing the community in public spaces.
55 | Examples of representing our community include using an official e-mail address,
56 | posting via an official social media account, or acting as an appointed
57 | representative at an online or offline event.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported to the community leaders responsible for enforcement at
63 | https://www.linkedin.com/in/iryna-kondrashchenko-673800155/.
64 | All complaints will be reviewed and investigated promptly and fairly.
65 |
66 | All community leaders are obligated to respect the privacy and security of the
67 | reporter of any incident.
68 |
69 | ## Enforcement Guidelines
70 |
71 | Community leaders will follow these Community Impact Guidelines in determining
72 | the consequences for any action they deem in violation of this Code of Conduct:
73 |
74 | ### 1. Correction
75 |
76 | **Community Impact**: Use of inappropriate language or other behavior deemed
77 | unprofessional or unwelcome in the community.
78 |
79 | **Consequence**: A private, written warning from community leaders, providing
80 | clarity around the nature of the violation and an explanation of why the
81 | behavior was inappropriate. A public apology may be requested.
82 |
83 | ### 2. Warning
84 |
85 | **Community Impact**: A violation through a single incident or series
86 | of actions.
87 |
88 | **Consequence**: A warning with consequences for continued behavior. No
89 | interaction with the people involved, including unsolicited interaction with
90 | those enforcing the Code of Conduct, for a specified period of time. This
91 | includes avoiding interactions in community spaces as well as external channels
92 | like social media. Violating these terms may lead to a temporary or
93 | permanent ban.
94 |
95 | ### 3. Temporary Ban
96 |
97 | **Community Impact**: A serious violation of community standards, including
98 | sustained inappropriate behavior.
99 |
100 | **Consequence**: A temporary ban from any sort of interaction or public
101 | communication with the community for a specified period of time. No public or
102 | private interaction with the people involved, including unsolicited interaction
103 | with those enforcing the Code of Conduct, is allowed during this period.
104 | Violating these terms may lead to a permanent ban.
105 |
106 | ### 4. Permanent Ban
107 |
108 | **Community Impact**: Demonstrating a pattern of violation of community
109 | standards, including sustained inappropriate behavior, harassment of an
110 | individual, or aggression toward or disparagement of classes of individuals.
111 |
112 | **Consequence**: A permanent ban from any sort of public interaction within
113 | the community.
114 |
115 | ## Attribution
116 |
117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118 | version 2.0, available at
119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
120 |
121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct
122 | enforcement ladder](https://github.com/mozilla/diversity).
123 |
124 | [homepage]: https://www.contributor-covenant.org
125 |
126 | For answers to common questions about this code of conduct, see the FAQ at
127 | https://www.contributor-covenant.org/faq. Translations are available at
128 | https://www.contributor-covenant.org/translations.
129 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Scikit-LLM
2 |
3 | Welcome! We appreciate your interest in contributing to Scikit-LLM. Whether you're a developer, designer, writer, or simply passionate about open source, there are several ways you can help improve this project. This document will guide you through the process of contributing to our Python repository.
4 |
5 | ## How Can I Contribute?
6 |
7 | **IMPORTANT:** We are currently preparing the transition to v1.0, which will include major code restructuring. Until then, no new pull requests to the main branch will be approved unless discussed in advance via issues or on Discord!
8 |
9 | There are several ways you can contribute to this project:
10 |
11 | - Bug Fixes: Help us identify and fix issues in the codebase.
12 | - Feature Implementation: Implement new features and enhancements.
13 | - Documentation: Improve the project's documentation, including code comments and README files.
14 | - Testing: Write and improve test cases to ensure the project's quality and reliability.
15 | - Translations: Provide translations for the project's documentation or user interface.
16 | - Bug Reports and Feature Requests: Submit bug reports or suggest new features and improvements.
17 |
18 | **Important:** before contributing, we recommend that you open an issue to discuss your planned changes. This allows us to align our goals, provide guidance, and potentially find other contributors interested in collaborating on the same feature or bug fix.
19 |
20 | > ### Legal Notice
21 | >
22 | > When contributing to this project, you must agree that you have authored 100% of the content, that you have the necessary rights to the content and that the content you contribute may be provided under the project license.
23 |
24 | ## Development dependencies
25 |
26 | In order to install all development dependencies, run the following command:
27 |
28 | ```shell
29 | pip install -r requirements-dev.txt
30 | ```
31 |
32 | To ensure that you follow the development workflow, please set up the pre-commit hooks:
33 |
34 | ```shell
35 | pre-commit install
36 | ```
37 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Iryna Kondrashchenko
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
5 | # Scikit-LLM: Scikit-Learn Meets Large Language Models
6 |
7 | Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks.
8 |
9 | ## Installation 💾
10 |
11 | ```bash
12 | pip install scikit-llm
13 | ```
14 |
15 | ## Support us 🤝
16 |
17 | You can support the project in the following ways:
18 |
19 | - ⭐ Star Scikit-LLM on GitHub (click the star button in the top right corner)
20 | - 💡 Provide your feedback or propose ideas in the [issues](https://github.com/iryna-kondr/scikit-llm/issues) section or [Discord](https://discord.gg/YDAbwuWK7V)
21 | - 📰 Post about Scikit-LLM on LinkedIn or other platforms
22 | - 🔗 Check out our other projects: Dingo, Falcon
23 |
40 | ## Quick Start & Documentation 📚
41 |
42 | Quick start example of zero-shot text classification using GPT:
43 |
44 | ```python
45 | # Import the necessary modules
46 | from skllm.datasets import get_classification_dataset
47 | from skllm.config import SKLLMConfig
48 | from skllm.models.gpt.classification.zero_shot import ZeroShotGPTClassifier
49 |
50 | # Configure the credentials
51 | SKLLMConfig.set_openai_key("")
52 | SKLLMConfig.set_openai_org("")
53 |
54 | # Load a demo dataset
55 | X, y = get_classification_dataset() # labels: positive, negative, neutral
56 |
57 | # Initialize the model and make the predictions
58 | clf = ZeroShotGPTClassifier(model="gpt-4")
59 | clf.fit(X, y)
60 | clf.predict(X)
61 | ```
62 |
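63 | The same estimator API extends to the other paradigms. For example, a dynamic few-shot classifier can be swapped in (a minimal sketch reusing `X` and `y` from above; it assumes the estimator's `n_examples` parameter, which controls how many of the most similar training examples are injected into each prompt):
64 |
65 | ```python
66 | from skllm.models.gpt.classification.few_shot import DynamicFewShotGPTClassifier
67 |
68 | clf = DynamicFewShotGPTClassifier(n_examples=3)
69 | clf.fit(X, y)
70 | clf.predict(X)
71 | ```
72 |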
63 | For more information please refer to the **[documentation](https://skllm.beastbyte.ai)**.
64 |
65 | ## Citation
66 |
67 | You can cite Scikit-LLM using the following BibTeX:
68 |
69 | ```
70 | @software{ScikitLLM,
71 | author = {Iryna Kondrashchenko and Oleh Kostromin},
72 | year = {2023},
73 | publisher = {beastbyte.ai},
74 | address = {Linz, Austria},
75 | title = {Scikit-LLM: Scikit-Learn Meets Large Language Models},
76 | url = {https://github.com/iryna-kondr/scikit-llm}
77 | }
78 | ```
79 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | dependencies = [
7 | "scikit-learn>=1.1.0,<2.0.0",
8 | "pandas>=1.5.0,<3.0.0",
9 | "openai>=1.2.0,<2.0.0",
10 | "tqdm>=4.60.0,<5.0.0",
11 | "google-cloud-aiplatform[pipelines]>=1.27.0,<2.0.0"
12 | ]
13 | name = "scikit-llm"
14 | version = "1.4.1"
15 | authors = [
16 | { name="Oleh Kostromin", email="kostromin97@gmail.com" },
17 | { name="Iryna Kondrashchenko", email="iryna230520@gmail.com" },
18 | ]
19 | description = "Scikit-LLM: Seamlessly integrate powerful language models like ChatGPT into scikit-learn for enhanced text analysis tasks."
20 | readme = "README.md"
21 | license = {text = "MIT"}
22 | requires-python = ">=3.9"
23 | classifiers = [
24 | "Programming Language :: Python :: 3",
25 | "License :: OSI Approved :: MIT License",
26 | "Operating System :: OS Independent",
27 | ]
28 |
29 | [project.optional-dependencies]
30 | gguf = ["llama-cpp-python>=0.2.82,<0.2.83"]
31 | annoy = ["annoy>=1.17.2,<2.0.0"]
32 |
33 | [tool.ruff]
34 | select = [
35 | # pycodestyle
36 | "E",
37 | # pyflakes
38 | "F",
39 | # pydocstyle
40 | "D",
41 | # flake8-bandit
42 | "S",
43 | # pyupgrade
44 | "UP",
45 | # pep8-naming
46 | "N",
47 | ]
48 | # Error E501 (Line too long) is ignored because of docstrings.
49 | ignore = [
50 | "S101",
51 | "S301",
52 | "S311",
53 | "D100",
54 | "D200",
55 | "D203",
56 | "D205",
57 | "D401",
58 | "E501",
59 | "N803",
60 | "N806",
61 | "D104",
62 | ]
63 | extend-exclude = ["tests/*.py", "setup.py"]
64 | target-version = "py39"
65 | force-exclude = true
66 |
67 | [tool.ruff.per-file-ignores]
68 | "__init__.py" = ["E402", "F401", "F403", "F811"]
69 |
70 | [tool.ruff.pydocstyle]
71 | convention = "numpy"
72 |
73 | [tool.mypy]
74 | ignore_missing_imports = true
75 |
76 | [tool.black]
77 | preview = true
78 | target-version = ['py39', 'py310', 'py311']
79 |
80 | [tool.isort]
81 | profile = "black"
82 | filter_files = true
83 | known_first_party = ["skllm", "skllm.*"]
84 | skip = ["__init__.py"]
85 |
86 | [tool.docformatter]
87 | close-quotes-on-newline = true # D209
88 |
89 | [tool.pytest.ini_options]
90 | pythonpath = [
91 | "."
92 | ]
93 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pre-commit
2 | black
3 | isort
4 | ruff
5 | docformatter
6 | interrogate
7 | numpy
8 | pandas
9 |
--------------------------------------------------------------------------------
/skllm/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.4.1'
2 | __author__ = 'Iryna Kondrashchenko, Oleh Kostromin'
3 |
--------------------------------------------------------------------------------
/skllm/classification.py:
--------------------------------------------------------------------------------
1 | ## GPT
2 |
3 | from skllm.models.gpt.classification.zero_shot import (
4 | ZeroShotGPTClassifier,
5 | MultiLabelZeroShotGPTClassifier,
6 | CoTGPTClassifier,
7 | )
8 | from skllm.models.gpt.classification.few_shot import (
9 | FewShotGPTClassifier,
10 | DynamicFewShotGPTClassifier,
11 | MultiLabelFewShotGPTClassifier,
12 | )
13 | from skllm.models.gpt.classification.tunable import (
14 | GPTClassifier as TunableGPTClassifier,
15 | )
16 |
17 | ## Vertex
18 | from skllm.models.vertex.classification.zero_shot import (
19 | ZeroShotVertexClassifier,
20 | MultiLabelZeroShotVertexClassifier,
21 | )
22 | from skllm.models.vertex.classification.tunable import (
23 | VertexClassifier as TunableVertexClassifier,
24 | )
25 |
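26 | # Usage sketch (a minimal, hypothetical snippet; it assumes OpenAI credentials
27 | # have already been configured via `SKLLMConfig`):
28 | #
29 | #   from skllm.classification import ZeroShotGPTClassifier
30 | #
31 | #   clf = ZeroShotGPTClassifier(model="gpt-4")
32 | #   clf.fit(X, y)
33 | #   labels = clf.predict(X)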
--------------------------------------------------------------------------------
/skllm/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Optional
3 |
4 | _OPENAI_KEY_VAR = "SKLLM_CONFIG_OPENAI_KEY"
5 | _OPENAI_ORG_VAR = "SKLLM_CONFIG_OPENAI_ORG"
6 | _AZURE_API_BASE_VAR = "SKLLM_CONFIG_AZURE_API_BASE"
7 | _AZURE_API_VERSION_VAR = "SKLLM_CONFIG_AZURE_API_VERSION"
8 | _GOOGLE_PROJECT = "GOOGLE_CLOUD_PROJECT"
9 | _GPT_URL_VAR = "SKLLM_CONFIG_GPT_URL"
10 | _ANTHROPIC_KEY_VAR = "SKLLM_CONFIG_ANTHROPIC_KEY"
11 | _GGUF_DOWNLOAD_PATH = "SKLLM_CONFIG_GGUF_DOWNLOAD_PATH"
12 | _GGUF_MAX_GPU_LAYERS = "SKLLM_CONFIG_GGUF_MAX_GPU_LAYERS"
13 | _GGUF_VERBOSE = "SKLLM_CONFIG_GGUF_VERBOSE"
14 |
15 |
16 | class SKLLMConfig:
17 | @staticmethod
18 | def set_gpt_key(key: str) -> None:
19 | """Sets the GPT key.
20 |
21 | Parameters
22 | ----------
23 | key : str
24 | GPT key.
25 | """
26 | os.environ[_OPENAI_KEY_VAR] = key
27 |
28 | @staticmethod
29 | def set_gpt_org(key: str) -> None:
29 | """Sets the GPT organization ID.
30 |
31 | Parameters
32 | ----------
33 | key : str
34 | GPT organization ID.
35 | """
36 | os.environ[_OPENAI_ORG_VAR] = key
37 |
38 | @staticmethod
39 | def set_openai_key(key: str) -> None:
40 | """Sets the OpenAI key.
41 |
42 | Parameters
43 | ----------
44 | key : str
45 | OpenAI key.
46 | """
47 | os.environ[_OPENAI_KEY_VAR] = key
48 |
49 | @staticmethod
50 | def get_openai_key() -> Optional[str]:
51 | """Gets the OpenAI key.
52 |
53 | Returns
54 | -------
55 | Optional[str]
56 | OpenAI key.
57 | """
58 | return os.environ.get(_OPENAI_KEY_VAR, None)
59 |
60 | @staticmethod
61 | def set_openai_org(key: str) -> None:
62 | """Sets OpenAI organization ID.
63 |
64 | Parameters
65 | ----------
66 | key : str
67 | OpenAI organization ID.
68 | """
69 | os.environ[_OPENAI_ORG_VAR] = key
70 |
71 | @staticmethod
72 | def get_openai_org() -> str:
73 | """Gets the OpenAI organization ID.
74 |
75 | Returns
76 | -------
77 | str
78 | OpenAI organization ID.
79 | """
80 | return os.environ.get(_OPENAI_ORG_VAR, "")
81 |
82 | @staticmethod
83 | def get_azure_api_base() -> str:
84 | """Gets the API base for Azure.
85 |
86 | Returns
87 | -------
88 | str
89 | URL to be used as the base for the Azure API.
90 | """
91 | base = os.environ.get(_AZURE_API_BASE_VAR, None)
92 | if base is None:
93 | raise RuntimeError("Azure API base is not set")
94 | return base
95 |
96 | @staticmethod
97 | def set_azure_api_base(base: str) -> None:
98 | """Set the API base for Azure.
99 |
100 | Parameters
101 | ----------
102 | base : str
103 | URL to be used as the base for the Azure API.
104 | """
105 | os.environ[_AZURE_API_BASE_VAR] = base
106 |
107 | @staticmethod
108 | def set_azure_api_version(ver: str) -> None:
109 | """Set the API version for Azure.
110 |
111 | Parameters
112 | ----------
113 | ver : str
114 | Azure API version.
115 | """
116 | os.environ[_AZURE_API_VERSION_VAR] = ver
117 |
118 | @staticmethod
119 | def get_azure_api_version() -> str:
120 | """Gets the API version for Azure.
121 |
122 | Returns
123 | -------
124 | str
125 | Azure API version.
126 | """
127 | return os.environ.get(_AZURE_API_VERSION_VAR, "2023-05-15")
128 |
129 | @staticmethod
130 | def get_google_project() -> Optional[str]:
131 | """Gets the Google Cloud project ID.
132 |
133 | Returns
134 | -------
135 | Optional[str]
136 | Google Cloud project ID.
137 | """
138 | return os.environ.get(_GOOGLE_PROJECT, None)
139 |
140 | @staticmethod
141 | def set_google_project(project: str) -> None:
142 | """Sets the Google Cloud project ID.
143 |
144 | Parameters
145 | ----------
146 | project : str
147 | Google Cloud project ID.
148 | """
149 | os.environ[_GOOGLE_PROJECT] = project
150 |
151 | @staticmethod
152 | def set_gpt_url(url: str):
153 | """Sets the GPT URL.
154 |
155 | Parameters
156 | ----------
157 | url : str
158 | GPT URL.
159 | """
160 | os.environ[_GPT_URL_VAR] = url
161 |
162 | @staticmethod
163 | def get_gpt_url() -> Optional[str]:
164 | """Gets the GPT URL.
165 |
166 | Returns
167 | -------
168 | Optional[str]
169 | GPT URL.
170 | """
171 | return os.environ.get(_GPT_URL_VAR, None)
172 |
173 | @staticmethod
174 | def set_anthropic_key(key: str) -> None:
175 | """Sets the Anthropic key.
176 |
177 | Parameters
178 | ----------
179 | key : str
180 | Anthropic key.
181 | """
182 | os.environ[_ANTHROPIC_KEY_VAR] = key
183 |
184 | @staticmethod
185 | def get_anthropic_key() -> Optional[str]:
186 | """Gets the Anthropic key.
187 |
188 | Returns
189 | -------
190 | Optional[str]
191 | Anthropic key.
192 | """
193 | return os.environ.get(_ANTHROPIC_KEY_VAR, None)
194 |
195 | @staticmethod
196 | def reset_gpt_url():
197 | """Resets the GPT URL."""
198 | os.environ.pop(_GPT_URL_VAR, None)
199 |
200 | @staticmethod
201 | def get_gguf_download_path() -> str:
202 | """Gets the path to store the downloaded GGUF files."""
203 | default_path = os.path.join(os.path.expanduser("~"), ".skllm", "gguf")
204 | return os.environ.get(_GGUF_DOWNLOAD_PATH, default_path)
205 |
206 | @staticmethod
207 | def get_gguf_max_gpu_layers() -> int:
208 | """Gets the maximum number of layers to use for the GGUF model."""
209 | return int(os.environ.get(_GGUF_MAX_GPU_LAYERS, 0))
210 |
211 | @staticmethod
212 | def set_gguf_max_gpu_layers(n_layers: int):
213 | """Sets the maximum number of layers to use for the GGUF model."""
214 | if not isinstance(n_layers, int):
215 | raise ValueError("n_layers must be an integer")
216 | if n_layers < -1:
217 | n_layers = -1
218 | os.environ[_GGUF_MAX_GPU_LAYERS] = str(n_layers)
219 |
220 | @staticmethod
221 | def set_gguf_verbose(verbose: bool):
222 | """Sets the verbosity of the GGUF model."""
223 | if not isinstance(verbose, bool):
224 | raise ValueError("verbose must be a boolean")
225 | os.environ[_GGUF_VERBOSE] = str(verbose)
226 |
227 | @staticmethod
228 | def get_gguf_verbose() -> bool:
229 | """Gets the verbosity of the GGUF model."""
230 | return os.environ.get(_GGUF_VERBOSE, "False").lower() == "true"
231 |
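232 | # Usage sketch: every setting above is backed by a plain environment variable,
233 | # so values can be provided either through these helpers or directly in the
234 | # environment (the key below is hypothetical):
235 | #
236 | #   from skllm.config import SKLLMConfig
237 | #
238 | #   SKLLMConfig.set_openai_key("sk-...")
239 | #   assert SKLLMConfig.get_openai_key() == "sk-..."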
--------------------------------------------------------------------------------
/skllm/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from skllm.datasets.multi_class import get_classification_dataset
2 | from skllm.datasets.multi_label import get_multilabel_classification_dataset
3 | from skllm.datasets.summarization import get_summarization_dataset
4 | from skllm.datasets.translation import get_translation_dataset
5 |
--------------------------------------------------------------------------------
/skllm/datasets/multi_class.py:
--------------------------------------------------------------------------------
1 | def get_classification_dataset():
2 | X = [
3 | r"I was absolutely blown away by the performances in 'Summer's End'. The acting was top-notch, and the plot had me gripped from start to finish. A truly captivating cinematic experience that I would highly recommend.",
4 | r"The special effects in 'Star Battles: Nebula Conflict' were out of this world. I felt like I was actually in space. The storyline was incredibly engaging and left me wanting more. Excellent film.",
5 | r"'The Lost Symphony' was a masterclass in character development and storytelling. The score was hauntingly beautiful and complimented the intense, emotional scenes perfectly. Kudos to the director and cast for creating such a masterpiece.",
6 | r"I was pleasantly surprised by 'Love in the Time of Cholera'. The romantic storyline was heartwarming and the characters were incredibly realistic. The cinematography was also top-notch. A must-watch for all romance lovers.",
7 | r"I went into 'Marble Street' with low expectations, but I was pleasantly surprised. The suspense was well-maintained throughout, and the twist at the end was something I did not see coming. Bravo!",
8 | r"'The Great Plains' is a touching portrayal of life in rural America. The performances were heartfelt and the scenery was breathtaking. I was moved to tears by the end. It's a story that will stay with me for a long time.",
9 | r"The screenwriting in 'Under the Willow Tree' was superb. The dialogue felt real and the characters were well-rounded. The performances were also fantastic. I haven't enjoyed a movie this much in a while.",
10 | r"'Nightshade' is a brilliant take on the superhero genre. The protagonist was relatable and the villain was genuinely scary. The action sequences were thrilling and the storyline was engaging. I can't wait for the sequel.",
11 | r"The cinematography in 'Awakening' was nothing short of spectacular. The visuals alone are worth the ticket price. The storyline was unique and the performances were solid. An overall fantastic film.",
12 | r"'Eternal Embers' was a cinematic delight. The storytelling was original and the performances were exceptional. The director's vision was truly brought to life on the big screen. A must-see for all movie lovers.",
13 | r"I was thoroughly disappointed with 'Silver Shadows'. The plot was confusing and the performances were lackluster. I wouldn't recommend wasting your time on this one.",
14 | r"'The Darkened Path' was a disaster. The storyline was unoriginal, the acting was wooden and the special effects were laughably bad. Save your money and skip this one.",
15 | r"I had high hopes for 'The Final Frontier', but it failed to deliver. The plot was full of holes and the characters were poorly developed. It was a disappointing experience.",
16 | r"'The Fall of the Phoenix' was a letdown. The storyline was confusing and the characters were one-dimensional. I found myself checking my watch multiple times throughout the movie.",
17 | r"I regret wasting my time on 'Emerald City'. The plot was nonsensical and the performances were uninspired. It was a major disappointment.",
18 | r"I found 'Hollow Echoes' to be a complete mess. The plot was non-existent, the performances were overdone, and the pacing was all over the place. Definitely not worth the hype.",
19 | r"'Underneath the Stars' was a huge disappointment. The storyline was predictable and the acting was mediocre at best. I was expecting so much more.",
20 | r"I was left unimpressed by 'River's Edge'. The plot was convoluted, the characters were uninteresting, and the ending was unsatisfying. It's a pass for me.",
21 | r"The acting in 'Desert Mirage' was subpar, and the plot was boring. I found myself yawning multiple times throughout the movie. Save your time and skip this one.",
22 | r"'Crimson Dawn' was a major letdown. The plot was cliched and the characters were flat. The special effects were also poorly executed. I wouldn't recommend it.",
23 | r"'Remember the Days' was utterly forgettable. The storyline was dull, the performances were bland, and the dialogue was cringeworthy. A big disappointment.",
24 | r"'The Last Frontier' was simply okay. The plot was decent and the performances were acceptable. However, it lacked a certain spark to make it truly memorable.",
25 | r"'Through the Storm' was not bad, but it wasn't great either. The storyline was somewhat predictable, and the characters were somewhat stereotypical. It was an average movie at best.",
26 | r"I found 'After the Rain' to be pretty average. The plot was okay and the performances were decent, but it didn't leave a lasting impression on me.",
27 | r"'Beyond the Horizon' was neither good nor bad. The plot was interesting enough, but the characters were not very well developed. It was an okay watch.",
28 | r"'The Silent Echo' was a mediocre movie. The storyline was passable and the performances were fair, but it didn't stand out in any way.",
29 | r"I thought 'The Scent of Roses' was pretty average. The plot was somewhat engaging, and the performances were okay, but it didn't live up to my expectations.",
30 | r"'Under the Same Sky' was an okay movie. The plot was decent, and the performances were fine, but it lacked depth and originality. It's not a movie I would watch again.",
31 | r"'Chasing Shadows' was fairly average. The plot was not bad, and the performances were passable, but it lacked a certain spark. It was just okay.",
32 | r"'Beneath the Surface' was pretty run-of-the-mill. The plot was decent, the performances were okay, but it wasn't particularly memorable. It was an okay movie.",
33 | ]
34 |
36 | y = (
37 | ["positive" for _ in range(10)]
38 | + ["negative" for _ in range(10)]
39 | + ["neutral" for _ in range(10)]
40 | )
41 |
42 | return X, y
--------------------------------------------------------------------------------
/skllm/datasets/multi_label.py:
--------------------------------------------------------------------------------
1 | def get_multilabel_classification_dataset():
2 | X = [
3 | "The product was of excellent quality, and the packaging was also very good. Highly recommend!",
4 | "The delivery was super fast, but the product did not match the information provided on the website.",
5 | "Great variety of products, but the customer support was quite unresponsive.",
6 | "Affordable prices and an easy-to-use website. A great shopping experience overall.",
7 | "The delivery was delayed, and the packaging was damaged. Not a good experience.",
8 | "Excellent customer support, but the return policy is quite complicated.",
9 | "The product was not as described. However, the return process was easy and quick.",
10 | "Great service and fast delivery. The product was also of high quality.",
11 | "The prices are a bit high. However, the product quality and user experience are worth it.",
12 | "The website provides detailed information about products. The delivery was also very fast."
13 | ]
14 |
15 | y = [
16 | ["Quality", "Packaging"],
17 | ["Delivery", "Product Information"],
18 | ["Product Variety", "Customer Support"],
19 | ["Price", "User Experience"],
20 | ["Delivery", "Packaging"],
21 | ["Customer Support", "Return Policy"],
22 | ["Product Information", "Return Policy"],
23 | ["Service", "Delivery", "Quality"],
24 | ["Price", "Quality", "User Experience"],
25 | ["Product Information", "Delivery"],
26 | ]
27 |
28 | return X, y
--------------------------------------------------------------------------------
/skllm/datasets/summarization.py:
--------------------------------------------------------------------------------
1 | def get_summarization_dataset():
2 | X = [
3 | r"The AI research company, OpenAI, has launched a new language model called GPT-4. This model is the latest in a series of transformer-based AI systems designed to perform complex tasks, such as generating human-like text, translating languages, and answering questions. According to OpenAI, GPT-4 is even more powerful and versatile than its predecessors.",
4 | r"John went to the grocery store in the morning to prepare for a small get-together at his house. He bought fresh apples, juicy oranges, and a bottle of milk. Once back home, he used the apples and oranges to make a delicious fruit salad, which he served to his guests in the evening.",
5 | r"The first Mars rover, named Sojourner, was launched by NASA in 1996. The mission was a part of the Mars Pathfinder project and was a major success. The data Sojourner provided about Martian terrain and atmosphere greatly contributed to our understanding of the Red Planet.",
6 | r"A new study suggests that regular exercise can improve memory and cognitive function in older adults. The study, which monitored the health and habits of 500 older adults, recommends 30 minutes of moderate exercise daily for the best results.",
7 | r"The Eiffel Tower, a globally recognized symbol of Paris and France, was completed in 1889 for the World's Fair. Despite its initial criticism and controversy over its unconventional design, the Eiffel Tower has become a beloved landmark and a symbol of French architectural innovation.",
8 | r"Microsoft has announced a new version of its flagship operating system, Windows. The update, which will be rolled out later this year, features improved security protocols and a redesigned user interface, promising a more streamlined and safer user experience.",
9 | r"The WHO declared a new public health emergency due to an outbreak of a previously unknown virus. As the number of cases grows globally, the organization urges nations to ramp up their disease surveillance and response systems.",
10 | r"The 2024 Olympics have been confirmed to take place in Paris, France. This marks the third time the city will host the games, with previous occasions in 1900 and 1924. Preparations are already underway to make the event a grand spectacle.",
11 | r"Apple has introduced its latest iPhone model. The new device boasts a range of features, including an improved camera system, a faster processor, and a longer battery life. It is set to hit the market later this year.",
12 | r"Scientists working in the Amazon rainforest have discovered a new species of bird. The bird, characterized by its unique bright plumage, has created excitement among the global ornithological community."
13 | ]
14 |
15 | return X
--------------------------------------------------------------------------------
/skllm/datasets/translation.py:
--------------------------------------------------------------------------------
1 | def get_translation_dataset():
2 | X = [
3 | r"Me encanta bailar salsa y bachata. Es una forma divertida de expresarme.",
4 | r"J'ai passé mes dernières vacances en Grèce. Les plages étaient magnifiques.",
5 | (
6 | r"Ich habe gestern ein tolles Buch gelesen. Die Geschichte war fesselnd bis"
7 | r" zum Ende."
8 | ),
9 | (
10 | r"Gosto de cozinhar pratos tradicionais italianos. O espaguete à carbonara"
11 | r" é um dos meus favoritos."
12 | ),
13 | (
14 | r"Mám v plánu letos v létě vyrazit na výlet do Itálie. Doufám, že navštívím"
15 | r" Řím a Benátky."
16 | ),
17 | (
18 | r"Mijn favoriete hobby is fotograferen. Ik hou ervan om mooie momenten vast"
19 | r" te leggen."
20 | ),
21 | ]
22 |
23 | return X
24 |
--------------------------------------------------------------------------------
/skllm/llm/anthropic/completion.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional
2 | from skllm.llm.anthropic.credentials import set_credentials
3 | from skllm.utils import retry
4 |
5 | @retry(max_retries=3)
6 | def get_chat_completion(
7 | messages: List[Dict],
8 | key: str,
9 | model: str = "claude-3-haiku-20240307",
10 | max_tokens: int = 1000,
11 | temperature: float = 0.0,
12 | system: Optional[str] = None,
13 | json_response: bool = False,
14 | ) -> dict:
15 | """
16 | Gets a chat completion from the Anthropic Claude API using the Messages API.
17 |
18 | Parameters
19 | ----------
20 | messages : List[Dict]
21 | Input messages to use.
22 | key : str
23 | The Anthropic API key to use.
24 | model : str, optional
25 | The Claude model to use.
26 | max_tokens : int, optional
27 | Maximum tokens to generate.
28 | temperature : float, optional
29 | Sampling temperature.
30 | system : str, optional
31 | System message to set the assistant's behavior.
32 | json_response : bool, optional
33 | Whether to request a JSON-formatted response. Defaults to False.
34 |
35 | Returns
36 | -------
37 | response : dict
38 | The completion response from the API.
39 | """
40 | if not isinstance(messages, list):
41 | raise TypeError("Messages must be a list")
42 | if not messages:
43 | raise ValueError("Messages list cannot be empty")
44 |
45 | client = set_credentials(key)
46 |
47 | if json_response and system:
48 | system = f"{system.rstrip('.')}. Respond in JSON format."
49 | elif json_response:
50 | system = "Respond in JSON format."
51 |
52 | formatted_messages = [
53 | {
54 | "role": "user", # Explicitly set role to "user"
55 | "content": [
56 | {
57 | "type": "text",
58 | "text": message.get("content", "")
59 | }
60 | ]
61 | }
62 | for message in messages
63 | ]
64 |
65 | response = client.messages.create(
66 | model=model,
67 | max_tokens=max_tokens,
68 | temperature=temperature,
69 | system=system,
70 | messages=formatted_messages,
71 | )
72 | return response
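73 |
74 | # Usage sketch (hypothetical key and prompt; requires a valid Anthropic account):
75 | #
76 | #   response = get_chat_completion(
77 | #       messages=[{"content": "Classify the sentiment: 'great movie'"}],
78 | #       key="sk-ant-...",
79 | #       system="You are a sentiment classifier.",
80 | #   )
81 | #   text = response.content[0].text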
--------------------------------------------------------------------------------
/skllm/llm/anthropic/credentials.py:
--------------------------------------------------------------------------------
1 | from anthropic import Anthropic
2 |
3 |
4 | def set_credentials(key: str) -> Anthropic:
5 | """Create an Anthropic client using the given API key.
6 |
7 | Parameters
8 | ----------
9 | key : str
10 | The Anthropic key to use.
11 |
12 | Returns
13 | -------
14 | Anthropic
15 | The configured client.
16 | """
17 | client = Anthropic(api_key=key)
18 | return client
14 |
--------------------------------------------------------------------------------
/skllm/llm/anthropic/mixin.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Any, List, Dict, Mapping
2 | from skllm.config import SKLLMConfig as _Config
3 | from skllm.llm.anthropic.completion import get_chat_completion
4 | from skllm.utils import extract_json_key
5 | from skllm.llm.base import BaseTextCompletionMixin, BaseClassifierMixin
6 | import json
7 |
8 |
9 | class ClaudeMixin:
10 | """A mixin class that provides Claude API key to other classes."""
11 |
12 | _prefer_json_output = False
13 |
14 | def _set_keys(self, key: Optional[str] = None) -> None:
15 | """Set the Claude API key."""
16 | self.key = key
17 |
18 | def _get_claude_key(self) -> str:
19 | """Get the Claude key from the class or config file."""
20 | key = self.key
21 | if key is None:
22 | key = _Config.get_anthropic_key()
23 | if key is None:
24 | raise RuntimeError("Claude API key was not found")
25 | return key
26 |
27 | class ClaudeTextCompletionMixin(ClaudeMixin, BaseTextCompletionMixin):
28 | """A mixin class that provides text completion capabilities using the Claude API."""
29 |
30 | def _get_chat_completion(
31 | self,
32 | model: str,
33 | messages: Union[str, List[Dict[str, str]]],
34 | system_message: Optional[str] = None,
35 | **kwargs: Any,
36 | ):
37 | """Gets a chat completion from the Anthropic API.
38 |
39 | Parameters
40 | ----------
41 | model : str
42 | The model to use.
43 | messages : Union[str, List[Dict[str, str]]]
44 | input messages to use.
45 | system_message : Optional[str]
46 | A system message to use.
47 | **kwargs : Any
48 | placeholder.
49 |
50 | Returns
51 | -------
52 | completion : dict
53 | """
54 | if isinstance(messages, str):
55 | messages = [{"content": messages}]
56 | elif isinstance(messages, list):
57 | messages = [{"content": msg["content"]} for msg in messages]
58 |
59 | completion = get_chat_completion(
60 | messages=messages,
61 | key=self._get_claude_key(),
62 | model=model,
63 | system=system_message,
64 | json_response=self._prefer_json_output,
65 | **kwargs,
66 | )
67 | return completion
68 |
69 | def _convert_completion_to_str(self, completion: Mapping[str, Any]):
70 | """Converts Claude API completion to string."""
71 | try:
72 | if hasattr(completion, 'content'):
73 | return completion.content[0].text
74 | return completion.get('content', [{}])[0].get('text', '')
75 | except Exception as e:
76 | print(f"Error converting completion to string: {str(e)}")
77 | return ""
78 |
79 | class ClaudeClassifierMixin(ClaudeTextCompletionMixin, BaseClassifierMixin):
80 | """A mixin class that provides classification capabilities using Claude API."""
81 |
82 | _prefer_json_output = True
83 |
84 | def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> str:
85 | """Extracts the label from a Claude API completion."""
86 | try:
87 | content = self._convert_completion_to_str(completion)
88 | if not self._prefer_json_output:
89 | return content.strip()
90 |
91 | # Attempt to parse content as JSON and extract label
92 | try:
93 | data = json.loads(content)
94 | if "label" in data:
95 | return data["label"]
96 | except json.JSONDecodeError:
97 | pass
98 | return ""
99 |
100 | except Exception as e:
101 | print(f"Error extracting label: {str(e)}")
102 | return ""
103 |
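104 | # Integration sketch (an illustrative assumption of how a concrete estimator
105 | # could wire these mixins together; not an exact copy of the estimator code):
106 | #
107 | #   class MyClaudeClassifier(ClaudeClassifierMixin):
108 | #       def classify_one(self, text: str) -> str:
109 | #           completion = self._get_chat_completion(
110 | #               model="claude-3-haiku-20240307", messages=text
111 | #           )
112 | #           return self._extract_out_label(completion)
113 | #
114 | #   clf = MyClaudeClassifier()
115 | #   clf._set_keys("sk-ant-...")  # hypothetical key
116 | #   label = clf.classify_one("I loved this movie!")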
--------------------------------------------------------------------------------
/skllm/llm/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any
3 |
4 |
5 | class BaseTextCompletionMixin(ABC):
6 | @abstractmethod
7 | def _get_chat_completion(self, **kwargs):
8 | """Gets a chat completion from the LLM"""
9 | pass
10 |
11 | @abstractmethod
12 | def _convert_completion_to_str(self, completion: Any):
13 | """Converts a completion object to a string"""
14 | pass
15 |
16 |
17 | class BaseClassifierMixin(BaseTextCompletionMixin):
18 | @abstractmethod
19 | def _extract_out_label(self, completion: Any) -> str:
20 | """Extracts the label from a completion"""
21 | pass
22 |
23 |
24 | class BaseEmbeddingMixin(ABC):
25 | @abstractmethod
26 | def _get_embeddings(self, **kwargs):
27 | """Gets embeddings from the LLM"""
28 | pass
29 |
30 |
31 | class BaseTunableMixin(ABC):
32 | @abstractmethod
33 | def _tune(self, X: Any, y: Any):
34 | pass
35 |
36 | @abstractmethod
37 | def _set_hyperparameters(self, **kwargs):
38 | pass
39 |
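40 | # Implementation sketch (an illustrative subclass, not part of the library):
41 | # a backend mixin only needs to fill in the two abstract hooks.
42 | #
43 | #   class EchoCompletionMixin(BaseTextCompletionMixin):
44 | #       def _get_chat_completion(self, **kwargs):
45 | #           return {"text": str(kwargs.get("messages", ""))}
46 | #
47 | #       def _convert_completion_to_str(self, completion: Any):
48 | #           return completion["text"]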
--------------------------------------------------------------------------------
/skllm/llm/gpt/clients/llama_cpp/completion.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.gpt.clients.llama_cpp.handler import ModelCache, LlamaHandler
2 |
3 |
4 | def get_chat_completion(messages: list, model: str, **kwargs):
5 |
6 | with ModelCache.lock:
7 | handler = ModelCache.get(model)
8 | if handler is None:
9 | handler = LlamaHandler(model)
10 | ModelCache.store(model, handler)
11 |
12 | return handler.get_chat_completion(messages, **kwargs)
13 |
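14 | # Note: handlers are cached per model name, and the cache lock serializes the
15 | # lookup/creation step so concurrent callers share a single loaded model rather
16 | # than loading the same multi-gigabyte GGUF file twice.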
--------------------------------------------------------------------------------
/skllm/llm/gpt/clients/llama_cpp/handler.py:
--------------------------------------------------------------------------------
1 | import threading
2 | import os
3 | import hashlib
4 | import requests
5 | from tqdm import tqdm
7 | from typing import Optional
8 | import tempfile
9 | from skllm.config import SKLLMConfig
10 | from warnings import warn
11 |
12 |
13 | try:
14 | from llama_cpp import Llama as _Llama
15 |
16 | _llama_imported = True
17 | except (ImportError, ModuleNotFoundError):
18 | _llama_imported = False
19 |
20 |
21 | supported_models = {
22 | "llama3-8b-q4": {
23 | "download_url": "https://huggingface.co/QuantFactory/Meta-Llama-3-8B-Instruct-GGUF/resolve/main/Meta-Llama-3-8B-Instruct.Q4_K_M.gguf",
24 | "sha256": "c57380038ea85d8bec586ec2af9c91abc2f2b332d41d6cf180581d7bdffb93c1",
25 | "n_ctx": 8192,
26 | "supports_system_message": True,
27 | },
28 | "gemma2-9b-q4": {
29 | "download_url": "https://huggingface.co/bartowski/gemma-2-9b-it-GGUF/resolve/main/gemma-2-9b-it-Q4_K_M.gguf",
30 | "sha256": "13b2a7b4115bbd0900162edcebe476da1ba1fc24e718e8b40d32f6e300f56dfe",
31 | "n_ctx": 8192,
32 | "supports_system_message": False,
33 | },
34 | "phi3-mini-q4": {
35 | "download_url": "https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-q4.gguf",
36 | "sha256": "8a83c7fb9049a9b2e92266fa7ad04933bb53aa1e85136b7b30f1b8000ff2edef",
37 | "n_ctx": 4096,
38 | "supports_system_message": False,
39 | },
40 | "mistral0.3-7b-q4": {
41 | "download_url": "https://huggingface.co/lmstudio-community/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3-Q4_K_M.gguf",
42 | "sha256": "1270d22c0fbb3d092fb725d4d96c457b7b687a5f5a715abe1e818da303e562b6",
43 | "n_ctx": 32768,
44 | "supports_system_message": False,
45 | },
46 | "gemma2-2b-q6": {
47 | "download_url": "https://huggingface.co/bartowski/gemma-2-2b-it-GGUF/resolve/main/gemma-2-2b-it-Q6_K_L.gguf",
48 | "sha256": "b2ef9f67b38c6e246e593cdb9739e34043d84549755a1057d402563a78ff2254",
49 | "n_ctx": 8192,
50 | "supports_system_message": False,
51 | },
52 | }
53 |
54 |
55 | class LlamaHandler:
56 |
57 | def maybe_download_model(self, model_name, download_url, sha256) -> str:
58 | download_folder = SKLLMConfig.get_gguf_download_path()
59 | os.makedirs(download_folder, exist_ok=True)
60 | model_name = model_name + ".gguf"
61 | model_path = os.path.join(download_folder, model_name)
62 | if not os.path.exists(model_path):
63 |             print("The model `{0}` was not found locally.".format(model_name))
64 | self._download_model(model_name, download_folder, download_url, sha256)
65 | return model_path
66 |
67 | def _download_model(
68 | self, model_filename: str, model_path: str, url: str, expected_sha256: str
69 | ) -> str:
70 | full_path = os.path.join(model_path, model_filename)
71 | temp_file = tempfile.NamedTemporaryFile(delete=False, dir=model_path)
72 | temp_path = temp_file.name
73 | temp_file.close()
74 |
75 | response = requests.get(url, stream=True)
76 |
77 | if response.status_code != 200:
78 | os.remove(temp_path)
79 | raise ValueError(
80 | f"Request failed: HTTP {response.status_code} {response.reason}"
81 | )
82 |
83 | total_size_in_bytes = int(response.headers.get("content-length", 0))
84 | block_size = 1024 * 1024 * 4
85 |
86 | sha256 = hashlib.sha256()
87 |
88 | with (
89 | open(temp_path, "wb") as file,
90 | tqdm(
91 | desc="Downloading {0}: ".format(model_filename),
92 | total=total_size_in_bytes,
93 | unit="iB",
94 | unit_scale=True,
95 | ) as progress_bar,
96 | ):
97 | for data in response.iter_content(block_size):
98 | file.write(data)
99 | sha256.update(data)
100 | progress_bar.update(len(data))
101 |
102 | downloaded_sha256 = sha256.hexdigest()
103 | if downloaded_sha256 != expected_sha256:
104 | raise ValueError(
105 | f"Expected SHA-256 hash {expected_sha256}, but got {downloaded_sha256}"
106 | )
107 |
108 | os.rename(temp_path, full_path)
109 |
110 | def __init__(self, model: str):
111 | if not _llama_imported:
112 | raise ImportError(
113 | "llama_cpp is not installed, try `pip install scikit-llm[llama_cpp]`"
114 | )
115 | self.lock = threading.Lock()
116 | if model not in supported_models:
117 | raise ValueError(f"Model {model} is not supported.")
118 | download_url = supported_models[model]["download_url"]
119 | sha256 = supported_models[model]["sha256"]
120 | n_ctx = supported_models[model]["n_ctx"]
121 | self.supports_system_message = supported_models[model][
122 | "supports_system_message"
123 | ]
124 | if not self.supports_system_message:
125 | warn(
126 | f"The model {model} does not support system messages. This may cause issues with some estimators."
127 | )
128 | extended_model_name = model + "-" + sha256[:8]
129 | model_path = self.maybe_download_model(
130 | extended_model_name, download_url, sha256
131 | )
132 | max_gpu_layers = SKLLMConfig.get_gguf_max_gpu_layers()
133 | verbose = SKLLMConfig.get_gguf_verbose()
134 | self.model = _Llama(
135 | model_path=model_path,
136 | n_ctx=n_ctx,
137 | verbose=verbose,
138 | n_gpu_layers=max_gpu_layers,
139 | )
140 |
141 |     def get_chat_completion(self, messages: list, **kwargs):
142 | if not self.supports_system_message:
143 | messages = [m for m in messages if m["role"] != "system"]
144 | with self.lock:
145 | return self.model.create_chat_completion(
146 | messages, temperature=0.0, **kwargs
147 | )
148 |
149 |
150 | class ModelCache:
151 | lock = threading.Lock()
152 | cache: dict[str, LlamaHandler] = {}
153 |
154 | @classmethod
155 | def get(cls, key) -> Optional[LlamaHandler]:
156 | return cls.cache.get(key, None)
157 |
158 | @classmethod
159 | def store(cls, key, value):
160 | cls.cache[key] = value
161 |
162 | @classmethod
163 | def clear(cls):
164 | cls.cache = {}
165 |
--------------------------------------------------------------------------------
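
ModelCache above is a process-wide registry guarded by a class-level lock; the caller in completion.py performs the check-then-insert while holding the lock, so a handler (and hence the model download) is created at most once per model name. A toy standalone analog of the same pattern (hypothetical; not library code):

    import threading


    class Cache:
        lock = threading.Lock()
        cache: dict = {}

        @classmethod
        def get_or_create(cls, key, factory):
            # check-then-insert under the lock, so `factory` runs at most once per key
            with cls.lock:
                if key not in cls.cache:
                    cls.cache[key] = factory()
                return cls.cache[key]


    h1 = Cache.get_or_create("llama3-8b-q4", object)
    h2 = Cache.get_or_create("llama3-8b-q4", object)
    assert h1 is h2  # the second call reuses the cached instance
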
/skllm/llm/gpt/clients/openai/completion.py:
--------------------------------------------------------------------------------
1 | """Chat-completion helper for the OpenAI and Azure OpenAI backends."""
2 |
3 | from skllm.llm.gpt.clients.openai.credentials import (
4 |     set_azure_credentials,
5 |     set_credentials,
6 | )
7 | from skllm.utils import retry
8 |
9 |
10 | @retry(max_retries=3)
11 | def get_chat_completion(
12 |     messages: list,
13 |     key: str,
14 |     org: str,
15 |     model: str = "gpt-3.5-turbo",
16 |     api: str = "openai",
17 |     json_response: bool = False,
18 | ):
19 | """Gets a chat completion from the OpenAI API.
20 |
21 | Parameters
22 | ----------
23 | messages : dict
24 | input messages to use.
25 | key : str
26 | The OPEN AI key to use.
27 | org : str
28 | The OPEN AI organization ID to use.
29 | model : str, optional
30 | The OPEN AI model to use. Defaults to "gpt-3.5-turbo".
31 | max_retries : int, optional
32 | The maximum number of retries to use. Defaults to 3.
33 | api : str
34 | The API to use. Must be one of "openai" or "azure". Defaults to "openai".
35 |
36 | Returns
37 | -------
38 | completion : dict
39 | """
40 | if api in ("openai", "custom_url"):
41 | client = set_credentials(key, org)
42 | elif api == "azure":
43 | client = set_azure_credentials(key, org)
44 | else:
45 |         raise ValueError(f"Invalid API: {api}")
46 | model_dict = {"model": model}
47 | if json_response and api == "openai":
48 | model_dict["response_format"] = {"type": "json_object"}
49 | completion = client.chat.completions.create(
50 | temperature=0.0, messages=messages, **model_dict
51 | )
52 | return completion
53 |
--------------------------------------------------------------------------------
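
A minimal call sketch (the key and organization below are hypothetical placeholders; a valid OpenAI account is required, and failures are retried up to 3 times by the `retry` decorator):

    from skllm.llm.gpt.clients.openai.completion import get_chat_completion

    messages = [{"role": "user", "content": "Say hello."}]
    completion = get_chat_completion(messages, key="sk-...", org="org-...", model="gpt-3.5-turbo")
    print(completion.choices[0].message.content)
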
/skllm/llm/gpt/clients/openai/credentials.py:
--------------------------------------------------------------------------------
1 | """Helpers for constructing OpenAI and Azure OpenAI clients
2 | from the skllm configuration."""
3 |
4 | from openai import OpenAI, AzureOpenAI
5 | from skllm.config import SKLLMConfig as _Config
6 |
7 |
8 | def set_credentials(key: str, org: str) -> OpenAI:
9 |     """Creates and returns an OpenAI client with the given key and organization.
10 |
11 |     Parameters
12 |     ----------
13 |     key : str
14 |         The OpenAI key to use.
15 |     org : str
16 |         The OpenAI organization ID to use.
17 |     """
18 |     url = _Config.get_gpt_url()
19 |     client = OpenAI(api_key=key, organization=org, base_url=url)
20 |     return client
21 |
22 |
23 | def set_azure_credentials(key: str, org: str) -> AzureOpenAI:
24 |     """Creates and returns an AzureOpenAI client with the given key and organization.
25 |
26 |     Parameters
27 |     ----------
28 |     key : str
29 |         The OpenAI (Azure) key to use.
30 |     org : str
31 |         The OpenAI (Azure) organization ID to use.
32 |     """
33 |     client = AzureOpenAI(
34 |         api_key=key,
35 |         organization=org,
36 |         api_version=_Config.get_azure_api_version(),
37 |         azure_endpoint=_Config.get_azure_api_base(),
38 |     )
39 |     return client
40 |
--------------------------------------------------------------------------------
/skllm/llm/gpt/clients/openai/embedding.py:
--------------------------------------------------------------------------------
1 | """Embedding helper for the OpenAI and Azure OpenAI backends."""
2 | from typing import List, Union
3 | from skllm.llm.gpt.clients.openai.credentials import set_credentials, set_azure_credentials
4 | from skllm.utils import retry
5 |
6 |
7 | @retry(max_retries=3)
8 | def get_embedding(
9 |     text: Union[str, List[str]],
10 |     key: str,
11 |     org: str,
12 |     model: str = "text-embedding-ada-002",
13 |     api: str = "openai",
14 | ):
15 | """
16 | Encodes a string and return the embedding for a string.
17 |
18 | Parameters
19 | ----------
20 | text : str
21 | The string to encode.
22 | key : str
23 | The OPEN AI key to use.
24 | org : str
25 | The OPEN AI organization ID to use.
26 | model : str, optional
27 | The model to use. Defaults to "text-embedding-ada-002".
28 | max_retries : int, optional
29 | The maximum number of retries to use. Defaults to 3.
30 | api: str, optional
31 | The API to use. Must be one of "openai" or "azure". Defaults to "openai".
32 |
33 | Returns
34 | -------
35 | emb : list
36 | The GPT embedding for the string.
37 | """
38 |     if api in ("openai", "custom_url"):
39 |         client = set_credentials(key, org)
40 |     elif api == "azure":
41 |         client = set_azure_credentials(key, org)
42 |     else:
43 |         raise ValueError(f"Invalid API: {api}")
44 |     text = [text] if isinstance(text, str) else text
45 |     text = [str(t).replace("\n", " ") for t in text]
46 |     emb = client.embeddings.create(input=text, model=model)
47 |     embeddings = []
48 |     for data in emb.data:
49 |         if not isinstance(data.embedding, list):
50 |             raise ValueError(f"Unknown embedding format: expected list, got {type(data.embedding)}")
51 |         embeddings.append(data.embedding)
52 |     return embeddings
53 |
--------------------------------------------------------------------------------
/skllm/llm/gpt/clients/openai/tuning.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Callable
2 | from time import sleep
3 | from datetime import datetime
4 | import os
5 |
6 |
7 | def create_tuning_job(
8 | client: Callable,
9 | model: str,
10 | training_file: str,
11 |     n_epochs: Optional[int] = None,
12 | suffix: Optional[str] = None,
13 | ):
14 | out = client.files.create(file=open(training_file, "rb"), purpose="fine-tune")
15 | out_id = out.id
16 | print(f"Created new file. FILE_ID = {out_id}")
17 | print(f"Waiting for file to be processed ...")
18 | while not wait_file_ready(client, out_id):
19 | sleep(5)
20 | # delete the training_file after it is uploaded
21 | os.remove(training_file)
22 | params = {
23 | "model": model,
24 | "training_file": out_id,
25 | }
26 | if n_epochs is not None:
27 | params["hyperparameters"] = {"n_epochs": n_epochs}
28 | if suffix is not None:
29 | params["suffix"] = suffix
30 | return client.fine_tuning.jobs.create(**params)
31 |
32 |
33 | def await_results(client: Callable, job_id: str, check_interval: int = 120):
34 | while True:
35 | job = client.fine_tuning.jobs.retrieve(job_id)
36 | status = job.status
37 | if status == "succeeded":
38 | return job
39 | elif status == "failed" or status == "cancelled":
40 | print(job)
41 | raise RuntimeError(f"Tuning job failed with status {status}")
42 | else:
43 | now = datetime.now()
44 | print(
45 | f"[{now}] Waiting for tuning job to complete. Current status: {status}"
46 | )
47 | sleep(check_interval)
48 |
49 |
50 | def delete_file(client: Callable, file_id: str):
51 | client.files.delete(file_id)
52 |
53 |
54 | def wait_file_ready(client: Callable, file_id):
55 | files = client.files.list().data
56 | found = False
57 | for file in files:
58 | if file.id == file_id:
59 | found = True
60 | if file.status == "processed":
61 | return True
62 | elif file.status in ["error", "deleting", "deleted"]:
63 | print(file)
64 | raise RuntimeError(
65 | f"File upload {file_id} failed with status {file.status}"
66 | )
67 | else:
68 | return False
69 | if not found:
70 | raise RuntimeError(f"File {file_id} not found")
71 |
--------------------------------------------------------------------------------
/skllm/llm/gpt/completion.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from skllm.llm.gpt.clients.openai.completion import (
3 | get_chat_completion as _oai_get_chat_completion,
4 | )
5 | from skllm.llm.gpt.clients.llama_cpp.completion import (
6 | get_chat_completion as _llamacpp_get_chat_completion,
7 | )
8 | from skllm.llm.gpt.utils import split_to_api_and_model
9 | from skllm.config import SKLLMConfig as _Config
10 |
11 |
12 | def get_chat_completion(
13 |     messages: list,
14 |     openai_key: str | None = None,
15 |     openai_org: str | None = None,
16 |     model: str = "gpt-3.5-turbo",
17 |     json_response: bool = False,
18 | ):
19 | """Gets a chat completion from the OpenAI compatible API."""
20 | api, model = split_to_api_and_model(model)
21 | if api == "gguf":
22 | return _llamacpp_get_chat_completion(messages, model)
23 | else:
24 | url = _Config.get_gpt_url()
25 | if api == "openai" and url is not None:
26 | warnings.warn(
27 | f"You are using the OpenAI backend with a custom URL: {url}; did you mean to use the `custom_url` backend?\nTo use the OpenAI backend, please remove the custom URL using `SKLLMConfig.reset_gpt_url()`."
28 | )
29 | elif api == "custom_url" and url is None:
30 | raise ValueError(
31 | "You are using the `custom_url` backend but no custom URL was provided. Please set it using `SKLLMConfig.set_gpt_url()`."
32 | )
33 | return _oai_get_chat_completion(
34 | messages,
35 | openai_key,
36 | openai_org,
37 | model,
38 | api=api,
39 | json_response=json_response,
40 | )
41 |
--------------------------------------------------------------------------------
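
The backend is selected purely from the model-string prefix (see split_to_api_and_model in skllm/llm/gpt/utils.py): no prefix means OpenAI, while "azure::", "custom_url::" and "gguf::" route elsewhere. A usage sketch (key/org values are hypothetical placeholders; the gguf call triggers a local model download on first use):

    from skllm.llm.gpt.completion import get_chat_completion

    messages = [{"role": "user", "content": "Classify: great film"}]

    # local llama.cpp backend; no API key needed
    local = get_chat_completion(messages, model="gguf::llama3-8b-q4")

    # OpenAI backend (the default when the model string has no "api::" prefix)
    remote = get_chat_completion(
        messages, openai_key="sk-...", openai_org="org-...", model="gpt-4o-mini"
    )
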
/skllm/llm/gpt/embedding.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.gpt.clients.openai.embedding import get_embedding as _oai_get_embedding
2 | from skllm.llm.gpt.utils import split_to_api_and_model
3 |
4 | def get_embedding(
5 |     text: str | list,
6 |     key: str,
7 |     org: str,
8 |     model: str = "text-embedding-ada-002",
9 | ):
10 | """
11 | Encodes a string and return the embedding for a string.
12 |
13 | Parameters
14 | ----------
15 | text : str
16 | The string to encode.
17 | key : str
18 | The OPEN AI key to use.
19 | org : str
20 | The OPEN AI organization ID to use.
21 | model : str, optional
22 | The model to use. Defaults to "text-embedding-ada-002".
23 |
24 | Returns
25 | -------
26 | emb : list
27 | The GPT embedding for the string.
28 | """
29 | api, model = split_to_api_and_model(model)
30 |     if api == "gguf":
31 |         raise ValueError("The `gguf` backend is not supported for embeddings")
32 | return _oai_get_embedding(text, key, org, model, api=api)
33 |
--------------------------------------------------------------------------------
/skllm/llm/gpt/mixin.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List, Any, Dict, Mapping
2 | from skllm.config import SKLLMConfig as _Config
3 | from skllm.llm.gpt.completion import get_chat_completion
4 | from skllm.llm.gpt.embedding import get_embedding
5 | from skllm.llm.base import (
6 | BaseClassifierMixin,
7 | BaseEmbeddingMixin,
8 | BaseTextCompletionMixin,
9 | BaseTunableMixin,
10 | )
11 | from skllm.llm.gpt.clients.openai.tuning import (
12 | create_tuning_job,
13 | await_results,
14 | delete_file,
15 | )
16 | import uuid
17 | from skllm.llm.gpt.clients.openai.credentials import (
18 | set_credentials as _set_credentials_openai,
19 | )
20 | from skllm.utils import extract_json_key
21 | import numpy as np
22 | from tqdm import tqdm
23 | import json
24 |
25 |
26 | def construct_message(role: str, content: str) -> dict:
27 | """Constructs a message for the OpenAI API.
28 |
29 | Parameters
30 | ----------
31 | role : str
32 | The role of the message. Must be one of "system", "user", or "assistant".
33 | content : str
34 | The content of the message.
35 |
36 | Returns
37 | -------
38 | message : dict
39 | """
40 | if role not in ("system", "user", "assistant"):
41 | raise ValueError("Invalid role")
42 | return {"role": role, "content": content}
43 |
44 |
45 | def _build_clf_example(
46 | x: str, y: str, system_msg="You are a text classification model."
47 | ):
48 | sample = {
49 | "messages": [
50 | {"role": "system", "content": system_msg},
51 | {"role": "user", "content": x},
52 | {"role": "assistant", "content": y},
53 | ]
54 | }
55 | return json.dumps(sample)
56 |
57 |
58 | class GPTMixin:
59 | """
60 | A mixin class that provides OpenAI key and organization to other classes.
61 | """
62 |
63 | _prefer_json_output = False
64 |
65 | def _set_keys(self, key: Optional[str] = None, org: Optional[str] = None) -> None:
66 | """
67 | Set the OpenAI key and organization.
68 | """
69 |
70 | self.key = key
71 | self.org = org
72 |
73 | def _get_openai_key(self) -> str:
74 | """
75 | Get the OpenAI key from the class or the config file.
76 |
77 | Returns
78 | -------
79 | key: str
80 | """
81 | key = self.key
82 | if key is None:
83 | key = _Config.get_openai_key()
84 | if (
85 | hasattr(self, "model")
86 | and isinstance(self.model, str)
87 | and self.model.startswith("gguf::")
88 | ):
89 | key = "gguf"
90 | if key is None:
91 | raise RuntimeError("OpenAI key was not found")
92 | return key
93 |
94 | def _get_openai_org(self) -> str:
95 | """
96 | Get the OpenAI organization ID from the class or the config file.
97 |
98 | Returns
99 | -------
100 | org: str
101 | """
102 | org = self.org
103 | if org is None:
104 | org = _Config.get_openai_org()
105 | if org is None:
106 | raise RuntimeError("OpenAI organization was not found")
107 | return org
108 |
109 |
110 | class GPTTextCompletionMixin(GPTMixin, BaseTextCompletionMixin):
111 | def _get_chat_completion(
112 | self,
113 | model: str,
114 | messages: Union[str, List[Dict[str, str]]],
115 | system_message: Optional[str] = None,
116 | **kwargs: Any,
117 | ):
118 | """Gets a chat completion from the OpenAI API.
119 |
120 | Parameters
121 | ----------
122 | model : str
123 | The model to use.
124 | messages : Union[str, List[Dict[str, str]]]
125 | input messages to use.
126 | system_message : Optional[str]
127 | A system message to use.
128 |         **kwargs : Any
129 |             Additional keyword arguments (currently unused).
130 |
131 | Returns
132 | -------
133 | completion : dict
134 | """
135 | msgs = []
136 | if system_message is not None:
137 | msgs.append(construct_message("system", system_message))
138 | if isinstance(messages, str):
139 | msgs.append(construct_message("user", messages))
140 | else:
141 | for message in messages:
142 | msgs.append(construct_message(message["role"], message["content"]))
143 | completion = get_chat_completion(
144 | msgs,
145 | self._get_openai_key(),
146 | self._get_openai_org(),
147 | model,
148 | json_response=self._prefer_json_output,
149 | )
150 | return completion
151 |
152 | def _convert_completion_to_str(self, completion: Mapping[str, Any]):
153 | if hasattr(completion, "__getitem__"):
154 | return str(completion["choices"][0]["message"]["content"])
155 | return str(completion.choices[0].message.content)
156 |
157 |
158 | class GPTClassifierMixin(GPTTextCompletionMixin, BaseClassifierMixin):
159 | _prefer_json_output = True
160 |
161 |     def _extract_out_label(self, completion: Mapping[str, Any], **kwargs) -> Any:
162 |         """Extracts the label from a completion.
163 |
164 |         Parameters
165 |         ----------
166 |         completion : Mapping[str, Any]
167 |             The completion to extract the label from.
168 |
169 |         Returns
170 |         -------
171 |         label : str
172 |         """
173 | try:
174 | if hasattr(completion, "__getitem__"):
175 | label = extract_json_key(
176 | completion["choices"][0]["message"]["content"], "label"
177 | )
178 | else:
179 | label = extract_json_key(completion.choices[0].message.content, "label")
180 | except Exception as e:
181 | print(completion)
182 | print(f"Could not extract the label from the completion: {str(e)}")
183 | label = ""
184 | return label
185 |
186 |
187 | class GPTEmbeddingMixin(GPTMixin, BaseEmbeddingMixin):
188 |     def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
189 |         """Gets embeddings from the OpenAI compatible API.
190 |
191 |         The texts are processed in batches of `self.batch_size`,
192 |         using the model specified by `self.model`.
193 |
194 |         Parameters
195 |         ----------
196 |         text : np.ndarray
197 |             The texts to embed.
198 |
199 |         Returns
200 |         -------
201 |         embeddings : List[List[float]]
202 |             The embeddings, one per input text.
203 |         """
204 | embeddings = []
205 | print("Batch size:", self.batch_size)
206 | for i in tqdm(range(0, len(text), self.batch_size)):
207 | batch = text[i : i + self.batch_size].tolist()
208 | embeddings.extend(
209 | get_embedding(
210 | batch,
211 | self._get_openai_key(),
212 | self._get_openai_org(),
213 | self.model,
214 | )
215 | )
216 |
217 | return embeddings
218 |
219 |
220 | # for now this works only with OpenAI
221 | class GPTTunableMixin(BaseTunableMixin):
222 | _supported_tunable_models = [
223 | "gpt-3.5-turbo-0125",
224 | "gpt-3.5-turbo",
225 | "gpt-4o-mini-2024-07-18",
226 | "gpt-4o-mini",
227 | ]
228 |
229 | def _build_label(self, label: str):
230 | return json.dumps({"label": label})
231 |
232 | def _set_hyperparameters(self, base_model: str, n_epochs: int, custom_suffix: str):
233 | self.base_model = base_model
234 | self.n_epochs = n_epochs
235 | self.custom_suffix = custom_suffix
236 | if base_model not in self._supported_tunable_models:
237 | raise ValueError(
238 | f"Model {base_model} is not supported. Supported models are"
239 | f" {self._supported_tunable_models}"
240 | )
241 |
242 | def _tune(self, X, y):
243 | if self.base_model.startswith(("azure::", "gpt4all")):
244 | raise ValueError(
245 | "Tuning is not supported for this model. Please use a different model."
246 | )
247 | file_uuid = str(uuid.uuid4())
248 | filename = f"skllm_{file_uuid}.jsonl"
249 | with open(filename, "w+") as f:
250 | for xi, yi in zip(X, y):
251 | prompt = self._get_prompt(xi)
252 | if not isinstance(prompt["messages"], str):
253 | raise ValueError(
254 | "Incompatible prompt. Use a prompt with a single message."
255 | )
256 | f.write(
257 | _build_clf_example(
258 | prompt["messages"],
259 | self._build_label(yi),
260 | prompt["system_message"],
261 | )
262 | )
263 | f.write("\n")
264 | client = _set_credentials_openai(self._get_openai_key(), self._get_openai_org())
265 | job = create_tuning_job(
266 | client,
267 | self.base_model,
268 | filename,
269 | self.n_epochs,
270 | self.custom_suffix,
271 | )
272 | print(f"Created new tuning job. JOB_ID = {job.id}")
273 | job = await_results(client, job.id)
274 | self.openai_model = job.fine_tuned_model
275 | self.model = self.openai_model # openai_model is probably not required anymore
276 | delete_file(client, job.training_file)
277 | print(f"Finished training.")
278 |
--------------------------------------------------------------------------------
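
For reference, one training line as emitted by _build_clf_example and consumed by create_tuning_job (the texts are illustrative; the real output is a single JSONL line, wrapped here for readability):

    from skllm.llm.gpt.mixin import _build_clf_example

    print(_build_clf_example("I loved this movie", '{"label": "positive"}'))
    # {"messages": [{"role": "system", "content": "You are a text classification model."},
    #               {"role": "user", "content": "I loved this movie"},
    #               {"role": "assistant", "content": "{\"label\": \"positive\"}"}]}
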
/skllm/llm/gpt/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 |
3 | SUPPORTED_APIS = ["openai", "azure", "gguf", "custom_url"]
4 |
5 |
6 | def split_to_api_and_model(model: str) -> Tuple[str, str]:
7 | if "::" not in model:
8 | return "openai", model
9 | for api in SUPPORTED_APIS:
10 | if model.startswith(f"{api}::"):
11 | return api, model[len(api) + 2 :]
12 | raise ValueError(f"Unsupported API: {model.split('::')[0]}")
13 |
--------------------------------------------------------------------------------
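
The prefix convention in action (these hold by construction of split_to_api_and_model):

    from skllm.llm.gpt.utils import split_to_api_and_model

    assert split_to_api_and_model("gpt-4o-mini") == ("openai", "gpt-4o-mini")
    assert split_to_api_and_model("azure::my-deployment") == ("azure", "my-deployment")
    assert split_to_api_and_model("gguf::llama3-8b-q4") == ("gguf", "llama3-8b-q4")
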
/skllm/llm/vertex/completion.py:
--------------------------------------------------------------------------------
1 | from skllm.utils import retry
2 | from vertexai.language_models import ChatModel, TextGenerationModel
3 | from vertexai.generative_models import GenerativeModel, GenerationConfig
4 |
5 |
6 | @retry(max_retries=3)
7 | def get_completion(model: str, text: str):
8 | if model.startswith("text-"):
9 | model_instance = TextGenerationModel.from_pretrained(model)
10 | else:
11 | model_instance = TextGenerationModel.get_tuned_model(model)
12 | response = model_instance.predict(text, temperature=0.0)
13 | return response.text
14 |
15 |
16 | @retry(max_retries=3)
17 | def get_completion_chat_mode(model: str, context: str, text: str):
18 | model_instance = ChatModel.from_pretrained(model)
19 | chat = model_instance.start_chat(context=context)
20 | response = chat.send_message(text, temperature=0.0)
21 | return response.text
22 |
23 |
24 | @retry(max_retries=3)
25 | def get_completion_chat_gemini(model: str, context: str, text: str):
26 | model_instance = GenerativeModel(model, system_instruction=context)
27 | response = model_instance.generate_content(
28 | text, generation_config=GenerationConfig(temperature=0.0)
29 | )
30 | return response.text
31 |
--------------------------------------------------------------------------------
/skllm/llm/vertex/mixin.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, List, Any, Dict
2 | from skllm.llm.base import (
3 | BaseClassifierMixin,
4 | BaseEmbeddingMixin,
5 | BaseTextCompletionMixin,
6 | BaseTunableMixin,
7 | )
8 | from skllm.llm.vertex.tuning import tune
9 | from skllm.llm.vertex.completion import get_completion_chat_mode, get_completion, get_completion_chat_gemini
10 | from skllm.utils import extract_json_key
11 | import numpy as np
12 | import pandas as pd
13 |
14 |
15 | class VertexMixin:
16 | pass
17 |
18 |
19 | class VertexTextCompletionMixin(BaseTextCompletionMixin):
20 | def _get_chat_completion(
21 | self,
22 | model: str,
23 | messages: Union[str, List[Dict[str, str]]],
24 | system_message: Optional[str],
25 | examples: Optional[List] = None,
26 | ) -> str:
27 | if examples is not None:
28 | raise NotImplementedError(
29 | "Examples API is not yet supported for Vertex AI."
30 | )
31 | if not isinstance(messages, str):
32 | raise ValueError("Only messages as strings are supported.")
33 | if model.startswith("chat-"):
34 | completion = get_completion_chat_mode(model, system_message, messages)
35 | elif model.startswith("gemini-"):
36 | completion = get_completion_chat_gemini(model, system_message, messages)
37 | else:
38 | completion = get_completion(model, messages)
39 | return str(completion)
40 |
41 | def _convert_completion_to_str(self, completion: str) -> str:
42 | return completion
43 |
44 |
45 | class VertexClassifierMixin(BaseClassifierMixin, VertexTextCompletionMixin):
46 |     def _extract_out_label(self, completion: str, **kwargs) -> Any:
47 |         """Extracts the label from a completion.
48 |
49 |         Parameters
50 |         ----------
51 |         completion : str
52 |             The completion to extract the label from.
53 |
54 |         Returns
55 |         -------
56 |         label : str
57 |         """
58 | try:
59 | label = extract_json_key(str(completion), "label")
60 | except Exception as e:
61 | print(completion)
62 | print(f"Could not extract the label from the completion: {str(e)}")
63 | label = ""
64 | return label
65 |
66 |
67 | class VertexEmbeddingMixin(BaseEmbeddingMixin):
68 | def _get_embeddings(self, text: np.ndarray) -> List[List[float]]:
69 | raise NotImplementedError("Embeddings are not yet supported for Vertex AI.")
70 |
71 |
72 | class VertexTunableMixin(BaseTunableMixin):
73 | _supported_tunable_models = ["text-bison@002"]
74 |
75 | def _set_hyperparameters(self, base_model: str, n_update_steps: int, **kwargs):
76 | self.verify_model_is_supported(base_model)
77 | self.base_model = base_model
78 | self.n_update_steps = n_update_steps
79 |
80 | def verify_model_is_supported(self, model: str):
81 | if model not in self._supported_tunable_models:
82 | raise ValueError(
83 | f"Model {model} is not supported. Supported models are"
84 | f" {self._supported_tunable_models}"
85 | )
86 |
87 | def _tune(self, X: Any, y: Any):
88 | df = pd.DataFrame({"input_text": X, "output_text": y})
89 | job = tune(self.base_model, df, self.n_update_steps)._job
90 | tuned_model = job.result()
91 | self.tuned_model_ = tuned_model._model_resource_name
92 | self.model = tuned_model._model_resource_name
93 | return self
94 |
--------------------------------------------------------------------------------
/skllm/llm/vertex/tuning.py:
--------------------------------------------------------------------------------
1 | from pandas import DataFrame
2 | from vertexai.language_models import TextGenerationModel
3 |
4 |
5 | def tune(model: str, data: DataFrame, train_steps: int = 100):
6 |     model_instance = TextGenerationModel.from_pretrained(model)
7 |     model_instance.tune_model(
8 |         training_data=data,
9 |         train_steps=train_steps,
10 |         tuning_job_location="europe-west4",  # the only supported training location atm
11 |         tuned_model_location="us-central1",  # the only supported deployment location atm
12 |     )
13 |     return model_instance
14 |
--------------------------------------------------------------------------------
/skllm/memory/__init__.py:
--------------------------------------------------------------------------------
1 | from skllm.memory._sklearn_nn import SklearnMemoryIndex
2 |
--------------------------------------------------------------------------------
/skllm/memory/_annoy.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from typing import Any, List
4 |
5 | try:
6 | from annoy import AnnoyIndex
7 | except (ImportError, ModuleNotFoundError):
8 | AnnoyIndex = None
9 |
10 | from numpy import ndarray
11 |
12 | from skllm.memory.base import _BaseMemoryIndex
13 |
14 |
15 | class AnnoyMemoryIndex(_BaseMemoryIndex):
16 | """Memory index using Annoy.
17 |
18 | Parameters
19 | ----------
20 | dim : int
21 | dimensionality of the vectors
22 | metric : str, optional
23 | metric to use, by default "euclidean"
24 | """
25 |
26 | def __init__(self, dim: int = -1, metric: str = "euclidean", **kwargs: Any) -> None:
27 | if AnnoyIndex is None:
28 | raise ImportError(
29 | "Annoy is not installed. Please install annoy by running `pip install"
30 | " scikit-llm[annoy]`."
31 | )
32 | self.metric = metric
33 | self.dim = dim
34 | self.built = False
35 | self._index = None
36 | self._counter = 0
37 |
38 | def add(self, vector: ndarray) -> None:
39 | """Adds a vector to the index.
40 |
41 | Parameters
42 | ----------
43 | vector : ndarray
44 | vector to add to the index
45 | """
46 | if self.built:
47 | raise RuntimeError("Cannot add vectors after index is built.")
48 | if self.dim < 0:
49 | raise ValueError("Dimensionality must be positive.")
50 |         if not self._index:
51 |             self._index = AnnoyIndex(self.dim, self.metric)
52 |         item_id = self._counter
53 |         self._index.add_item(item_id, vector)
54 |         self._counter += 1
55 |
56 | def build(self) -> None:
57 | """Builds the index.
58 |
59 | No new vectors can be added after building.
60 | """
61 | if self.dim < 0:
62 | raise ValueError("Dimensionality must be positive.")
63 | self._index.build(-1)
64 | self.built = True
65 |
66 | def retrieve(self, vectors: ndarray, k: int) -> List[List[int]]:
67 | """Retrieves the k nearest neighbors for each vector.
68 |
69 | Parameters
70 | ----------
71 | vectors : ndarray
72 | vectors to retrieve nearest neighbors for
73 | k : int
74 | number of nearest neighbors to retrieve
75 |
76 | Returns
77 | -------
78 | List
79 | ids of retrieved nearest neighbors
80 | """
81 | if not self.built:
82 | raise RuntimeError("Cannot retrieve vectors before the index is built.")
83 | return [
84 | self._index.get_nns_by_vector(v, k, search_k=-1, include_distances=False)
85 | for v in vectors
86 | ]
87 |
88 | def __getstate__(self) -> dict:
89 | """Returns the state of the object. To store the actual annoy index, it
90 | has to be written to a temporary file.
91 |
92 | Returns
93 | -------
94 | dict
95 | state of the object
96 | """
97 | state = self.__dict__.copy()
98 |
99 | # save index to temporary file
100 | with tempfile.NamedTemporaryFile(delete=False) as tmp:
101 | temp_filename = tmp.name
102 | self._index.save(temp_filename)
103 |
104 | # read bytes from the file
105 | with open(temp_filename, "rb") as tmp:
106 | index_bytes = tmp.read()
107 |
108 | # store bytes representation in state
109 | state["_index"] = index_bytes
110 |
111 | # remove temporary file
112 | os.remove(temp_filename)
113 |
114 | return state
115 |
116 | def __setstate__(self, state: dict) -> None:
117 | """Sets the state of the object. It restores the annoy index from the
118 | bytes representation.
119 |
120 | Parameters
121 | ----------
122 | state : dict
123 | state of the object
124 | """
125 | self.__dict__.update(state)
126 | # restore index from bytes
127 | with tempfile.NamedTemporaryFile(delete=False) as tmp:
128 | temp_filename = tmp.name
129 | tmp.write(self._index)
130 |
131 | self._index = AnnoyIndex(self.dim, self.metric)
132 | self._index.load(temp_filename)
133 |
134 | # remove temporary file
135 | os.remove(temp_filename)
136 |
--------------------------------------------------------------------------------
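
A small round-trip sketch for the pickling support above (assumes the optional dependency is installed via `pip install scikit-llm[annoy]`):

    import pickle

    import numpy as np

    from skllm.memory._annoy import AnnoyMemoryIndex

    index = AnnoyMemoryIndex(dim=3)
    for v in np.eye(3):  # three toy "embeddings"
        index.add(v)
    index.build()

    # __getstate__/__setstate__ shuttle the index through a temporary file
    restored = pickle.loads(pickle.dumps(index))
    print(restored.retrieve(np.array([[1.0, 0.0, 0.0]]), k=2))  # item 0 first, e.g. [[0, 1]]
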
/skllm/memory/_sklearn_nn.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List
2 |
3 | import numpy as np
4 | from sklearn.neighbors import NearestNeighbors
5 |
6 | from skllm.memory.base import _BaseMemoryIndex
7 |
8 |
9 | class SklearnMemoryIndex(_BaseMemoryIndex):
10 | """Memory index using Sklearn's NearestNeighbors.
11 |
12 | Parameters
13 | ----------
14 | dim : int
15 | dimensionality of the vectors
16 | metric : str, optional
17 | metric to use, by default "euclidean"
18 | """
19 |
20 | def __init__(self, dim: int = -1, metric: str = "euclidean", **kwargs: Any) -> None:
21 | self._index = NearestNeighbors(metric=metric, **kwargs)
22 | self.metric = metric
23 | self.dim = dim
24 | self.built = False
25 | self.data = []
26 |
27 | def add(self, vector: np.ndarray) -> None:
28 | """Adds a vector to the index.
29 |
30 | Parameters
31 | ----------
32 | vector : np.ndarray
33 | vector to add to the index
34 | """
35 | if self.built:
36 | raise RuntimeError("Cannot add vectors after index is built.")
37 | self.data.append(vector)
38 |
39 | def build(self) -> None:
40 | """Builds the index.
41 |
42 | No new vectors can be added after building.
43 | """
44 | data_matrix = np.array(self.data)
45 | self._index.fit(data_matrix)
46 | self.built = True
47 |
48 | def retrieve(self, vectors: np.ndarray, k: int) -> List[List[int]]:
49 | """Retrieves the k nearest neighbors for each vector.
50 |
51 | Parameters
52 | ----------
53 | vectors : np.ndarray
54 | vectors to retrieve nearest neighbors for
55 | k : int
56 | number of nearest neighbors to retrieve
57 |
58 | Returns
59 | -------
60 | List
61 | ids of retrieved nearest neighbors
62 | """
63 | if not self.built:
64 | raise RuntimeError("Cannot retrieve vectors before the index is built.")
65 | _, indices = self._index.kneighbors(vectors, n_neighbors=k)
66 | return indices.tolist()
67 |
--------------------------------------------------------------------------------
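
A minimal usage sketch of the default index (scikit-learn only; no optional dependencies):

    import numpy as np

    from skllm.memory._sklearn_nn import SklearnMemoryIndex

    index = SklearnMemoryIndex()
    for v in np.eye(3):  # three toy "embeddings"
        index.add(v)
    index.build()

    print(index.retrieve(np.array([[0.9, 0.1, 0.0]]), k=2))  # [[0, 1]]
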
/skllm/memory/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Any, List, Type
3 |
4 | from numpy import ndarray
5 |
6 |
7 | class _BaseMemoryIndex(ABC):
8 | @abstractmethod
9 |     def add(self, vector: ndarray):
10 |         """Adds a vector to the index.
11 |
12 |         Ids are assigned implicitly, in insertion order.
13 |
14 |         Parameters
15 |         ----------
16 |         vector : ndarray
17 |             vector to add to the index
18 |         """
19 |         pass
20 |
21 | @abstractmethod
22 | def retrieve(self, vectors: ndarray, k: int) -> List:
23 | """Retrieves the k nearest neighbors for each vector.
24 |
25 | Parameters
26 | ----------
27 | vectors : ndarray
28 | vectors to retrieve nearest neighbors for
29 | k : int
30 | number of nearest neighbors to retrieve
31 |
32 | Returns
33 | -------
34 | List
35 | ids of retrieved nearest neighbors
36 | """
37 | pass
38 |
39 | @abstractmethod
40 | def build(self) -> None:
41 | """Builds the index.
42 |
43 | All build parameters should be passed to the constructor.
44 | """
45 | pass
46 |
47 |
48 | class IndexConstructor:
49 | def __init__(self, index: Type[_BaseMemoryIndex], **kwargs: Any) -> None:
50 | self.index = index
51 | self.kwargs = kwargs
52 |
53 | def __call__(self) -> _BaseMemoryIndex:
54 | return self.index(**self.kwargs)
55 |
--------------------------------------------------------------------------------
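
IndexConstructor defers index creation so that an estimator can build a fresh index per class at fit time (as BaseDynamicFewShotClassifier does in skllm/models/_base/classifier.py). A short sketch:

    from skllm.memory._sklearn_nn import SklearnMemoryIndex
    from skllm.memory.base import IndexConstructor

    constructor = IndexConstructor(SklearnMemoryIndex, metric="cosine")
    index = constructor()  # each call returns a new, empty index
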
/skllm/models/_base/classifier.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Union
2 | from abc import ABC, abstractmethod
3 | from sklearn.base import (
4 | BaseEstimator as _SklBaseEstimator,
5 | ClassifierMixin as _SklClassifierMixin,
6 | )
7 | import warnings
8 | import numpy as np
9 | import pandas as pd
10 | from tqdm import tqdm
11 | from concurrent.futures import ThreadPoolExecutor
12 | import random
13 | from collections import Counter
14 | from skllm.llm.base import (
15 | BaseClassifierMixin as _BaseClassifierMixin,
16 | BaseTunableMixin as _BaseTunableMixin,
17 | )
18 | from skllm.utils import to_numpy as _to_numpy
19 | from skllm.prompts.templates import (
20 | ZERO_SHOT_CLF_PROMPT_TEMPLATE,
21 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
22 | FEW_SHOT_CLF_PROMPT_TEMPLATE,
23 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
24 | ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE,
25 | ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE,
26 | COT_CLF_PROMPT_TEMPLATE,
27 | COT_MLCLF_PROMPT_TEMPLATE,
28 | )
29 | from skllm.prompts.builders import (
30 | build_zero_shot_prompt_slc,
31 | build_zero_shot_prompt_mlc,
32 | build_few_shot_prompt_slc,
33 | build_few_shot_prompt_mlc,
34 | )
35 | from skllm.memory.base import IndexConstructor
36 | from skllm.memory._sklearn_nn import SklearnMemoryIndex
37 | from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
38 | from skllm.utils import re_naive_json_extractor
39 | import json
40 |
41 | _TRAINING_SAMPLE_PROMPT_TEMPLATE = """
42 | Sample input:
43 | ```{x}```
44 |
45 | Sample target: {label}
46 | """
47 |
48 |
49 | class SingleLabelMixin:
50 | """Mixin class for single label classification."""
51 |
52 | def validate_prediction(self, label: Any) -> str:
53 | """
54 | Validates a prediction.
55 |
56 | Parameters
57 | ----------
58 | label : str
59 | The label to validate.
60 |
61 | Returns
62 | -------
63 | str
64 | The validated label.
65 | """
66 | if label not in self.classes_:
67 | label = str(label).replace("'", "").replace('"', "")
68 | if label not in self.classes_: # try again
69 | label = self._get_default_label()
70 | return label
71 |
72 | def _extract_labels(self, y: Any) -> List[str]:
73 | """
74 | Return the class labels as a list.
75 |
76 | Parameters
77 | ----------
78 | y : Any
79 |
80 | Returns
81 | -------
82 | List[str]
83 | """
84 | if isinstance(y, (pd.Series, np.ndarray)):
85 | labels = y.tolist()
86 | else:
87 | labels = y
88 | return labels
89 |
90 |
91 | class MultiLabelMixin:
92 | """Mixin class for multi label classification."""
93 |
94 | def validate_prediction(self, label: Any) -> List[str]:
95 | """
96 | Validates a prediction.
97 |
98 | Parameters
99 | ----------
100 | label : Any
101 | The label to validate.
102 |
103 | Returns
104 | -------
105 | List[str]
106 | The validated label.
107 | """
108 | if not isinstance(label, list):
109 | label = []
110 | filtered_labels = []
111 | for l in label:
112 | if l in self.classes_ and l:
113 | if l not in filtered_labels:
114 | filtered_labels.append(l)
115 |             elif (la := l.replace("'", "").replace('"', "")) in self.classes_:
116 |                 if la not in filtered_labels:
117 |                     filtered_labels.append(la)
118 | else:
119 | default_label = self._get_default_label()
120 | if not (
121 | self.default_label == "Random" and default_label in filtered_labels
122 | ):
123 | filtered_labels.append(default_label)
124 | filtered_labels.extend([""] * self.max_labels)
125 | return filtered_labels[: self.max_labels]
126 |
127 | def _extract_labels(self, y) -> List[str]:
128 | """Extracts the labels into a list.
129 |
130 | Parameters
131 | ----------
132 | y : Any
133 |
134 | Returns
135 | -------
136 | List[str]
137 | """
138 | labels = []
139 | for l in y:
140 | if isinstance(l, list):
141 | for j in l:
142 | labels.append(j)
143 | else:
144 | labels.append(l)
145 | return labels
146 |
147 |
148 | class BaseClassifier(ABC, _SklBaseEstimator, _SklClassifierMixin):
149 | system_msg = "You are a text classifier."
150 |
151 | def __init__(
152 | self,
153 | model: Optional[str], # model can initially be None for tunable estimators
154 | default_label: str = "Random",
155 | max_labels: Optional[int] = 5,
156 | prompt_template: Optional[str] = None,
157 | **kwargs,
158 | ):
159 | if not isinstance(self, _BaseClassifierMixin):
160 | raise TypeError(
161 | "Classifier must be mixed with a skllm.llm.base.BaseClassifierMixin"
162 | " class"
163 | )
164 | if not isinstance(self, (SingleLabelMixin, MultiLabelMixin)):
165 | raise TypeError(
166 | "Classifier must be mixed with a SingleLabelMixin or MultiLabelMixin"
167 | " class"
168 | )
169 |
170 | self.model = model
171 | if not isinstance(default_label, str):
172 | raise TypeError("default_label must be a string")
173 | self.default_label = default_label
174 | if isinstance(self, MultiLabelMixin):
175 | if not isinstance(max_labels, int):
176 | raise TypeError("max_labels must be an integer")
177 | if max_labels < 2:
178 | raise ValueError("max_labels must be greater than 1")
179 | self.max_labels = max_labels
180 | if prompt_template is not None and not isinstance(prompt_template, str):
181 | raise TypeError("prompt_template must be a string or None")
182 | self.prompt_template = prompt_template
183 |
184 | def _predict_single(self, x: Any) -> Any:
185 | prompt_dict = self._get_prompt(x)
186 | # this will be inherited from the LLM
187 | prediction = self._get_chat_completion(model=self.model, **prompt_dict)
188 | prediction = self._extract_out_label(prediction)
189 | # this will be inherited from the sl/ml mixin
190 | prediction = self.validate_prediction(prediction)
191 | return prediction
192 |
193 | @abstractmethod
194 | def _get_prompt(self, x: str) -> dict:
195 | """Returns the prompt to use for a single input."""
196 | pass
197 |
198 | def fit(
199 | self,
200 | X: Optional[Union[np.ndarray, pd.Series, List[str]]],
201 | y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
202 | ):
203 | """
204 | Fits the model to the given data.
205 |
206 | Parameters
207 | ----------
208 | X : Optional[Union[np.ndarray, pd.Series, List[str]]]
209 | Training data
210 | y : Union[np.ndarray, pd.Series, List[str], List[List[str]]]
211 | Training labels
212 |
213 | Returns
214 | -------
215 | BaseClassifier
216 | self
217 | """
218 | X = _to_numpy(X)
219 | self.classes_, self.probabilities_ = self._get_unique_targets(y)
220 | return self
221 |
222 | def predict(self, X: Union[np.ndarray, pd.Series, List[str]], num_workers: int = 1):
223 | """
224 | Predicts the class of each input.
225 |
226 | Parameters
227 | ----------
228 | X : Union[np.ndarray, pd.Series, List[str]]
229 | The input data to predict the class of.
230 |
231 | num_workers : int
232 | number of workers to use for multithreaded prediction, default 1
233 |
234 | Returns
235 | -------
236 | np.ndarray
237 | The predicted classes as a numpy array.
238 | """
239 | X = _to_numpy(X)
240 |
241 | if num_workers > 1:
242 | warnings.warn(
243 | "Passing num_workers to predict is temporary and will be removed in the future."
244 | )
245 | with ThreadPoolExecutor(max_workers=num_workers) as executor:
246 | predictions = list(
247 | tqdm(executor.map(self._predict_single, X), total=len(X))
248 | )
249 | else:
250 | predictions = []
251 | for x in tqdm(X):
252 | predictions.append(self._predict_single(x))
253 |
254 | return np.array(predictions)
255 |
256 | def _get_unique_targets(self, y: Any):
257 | labels = self._extract_labels(y)
258 |
259 | counts = Counter(labels)
260 |
261 | total = sum(counts.values())
262 |
263 | classes, probs = [], []
264 | for l, c in counts.items():
265 | classes.append(l)
266 | probs.append(c / total)
267 |
268 | return classes, probs
269 |
270 | def _get_default_label(self):
271 | """Returns the default label based on the default_label argument."""
272 | if self.default_label == "Random":
273 | return random.choices(self.classes_, self.probabilities_)[0]
274 | else:
275 | return self.default_label
276 |
277 |
278 | class BaseZeroShotClassifier(BaseClassifier):
279 | def _get_prompt_template(self) -> str:
280 | """Returns the prompt template to use for a single input."""
281 | if self.prompt_template is not None:
282 | return self.prompt_template
283 | elif isinstance(self, SingleLabelMixin):
284 | return ZERO_SHOT_CLF_PROMPT_TEMPLATE
285 | return ZERO_SHOT_MLCLF_PROMPT_TEMPLATE
286 |
287 | def _get_prompt(self, x: str) -> dict:
288 | """Returns the prompt to use for a single input."""
289 | if isinstance(self, SingleLabelMixin):
290 | prompt = build_zero_shot_prompt_slc(
291 | x, repr(self.classes_), template=self._get_prompt_template()
292 | )
293 | else:
294 | prompt = build_zero_shot_prompt_mlc(
295 | x,
296 | repr(self.classes_),
297 | self.max_labels,
298 | template=self._get_prompt_template(),
299 | )
300 | return {"messages": prompt, "system_message": self.system_msg}
301 |
302 |
303 | class BaseCoTClassifier(BaseClassifier):
304 | def _get_prompt_template(self) -> str:
305 | """Returns the prompt template to use for a single input."""
306 | if self.prompt_template is not None:
307 | return self.prompt_template
308 | elif isinstance(self, SingleLabelMixin):
309 | return COT_CLF_PROMPT_TEMPLATE
310 | return COT_MLCLF_PROMPT_TEMPLATE
311 |
312 | def _get_prompt(self, x: str) -> dict:
313 | """Returns the prompt to use for a single input."""
314 | if isinstance(self, SingleLabelMixin):
315 | prompt = build_zero_shot_prompt_slc(
316 | x, repr(self.classes_), template=self._get_prompt_template()
317 | )
318 | else:
319 | prompt = build_zero_shot_prompt_mlc(
320 | x,
321 | repr(self.classes_),
322 | self.max_labels,
323 | template=self._get_prompt_template(),
324 | )
325 | return {"messages": prompt, "system_message": self.system_msg}
326 |
327 | def _predict_single(self, x: Any) -> Any:
328 | prompt_dict = self._get_prompt(x)
329 | # this will be inherited from the LLM
330 | completion = self._get_chat_completion(model=self.model, **prompt_dict)
331 | completion = self._convert_completion_to_str(completion)
332 | try:
333 | as_dict = json.loads(re_naive_json_extractor(completion))
334 | label = as_dict["label"]
335 | explanation = str(as_dict["explanation"])
336 | except Exception as e:
337 | label = "None"
338 | explanation = "Explanation is not available."
339 | # this will be inherited from the sl/ml mixin
340 | prediction = self.validate_prediction(label)
341 | return [prediction, explanation]
342 |
343 |
344 | class BaseFewShotClassifier(BaseClassifier):
345 | def _get_prompt_template(self) -> str:
346 | """Returns the prompt template to use for a single input."""
347 | if self.prompt_template is not None:
348 | return self.prompt_template
349 | elif isinstance(self, SingleLabelMixin):
350 | return FEW_SHOT_CLF_PROMPT_TEMPLATE
351 | return FEW_SHOT_MLCLF_PROMPT_TEMPLATE
352 |
353 | def _get_prompt(self, x: str) -> dict:
354 | """Returns the prompt to use for a single input."""
355 | training_data = []
356 | for xt, yt in zip(*self.training_data_):
357 | training_data.append(
358 | _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=xt, label=yt)
359 | )
360 |
361 | training_data_str = "\n".join(training_data)
362 |
363 | if isinstance(self, SingleLabelMixin):
364 | prompt = build_few_shot_prompt_slc(
365 | x,
366 | repr(self.classes_),
367 | training_data=training_data_str,
368 | template=self._get_prompt_template(),
369 | )
370 | else:
371 | prompt = build_few_shot_prompt_mlc(
372 | x,
373 | repr(self.classes_),
374 | training_data=training_data_str,
375 | max_cats=self.max_labels,
376 | template=self._get_prompt_template(),
377 | )
378 | return {"messages": prompt, "system_message": "You are a text classifier."}
379 |
380 | def fit(
381 | self,
382 | X: Union[np.ndarray, pd.Series, List[str]],
383 | y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
384 | ):
385 | """
386 | Fits the model to the given data.
387 |
388 | Parameters
389 | ----------
390 | X : Union[np.ndarray, pd.Series, List[str]]
391 | Training data
392 | y : Union[np.ndarray, pd.Series, List[str]]
393 | Training labels
394 |
395 | Returns
396 | -------
397 | BaseFewShotClassifier
398 | self
399 | """
400 |         if len(X) != len(y):
401 | raise ValueError("X and y must have the same length.")
402 | X = _to_numpy(X)
403 | y = _to_numpy(y)
404 | self.training_data_ = (X, y)
405 | self.classes_, self.probabilities_ = self._get_unique_targets(y)
406 | return self
407 |
408 |
409 | class BaseDynamicFewShotClassifier(BaseClassifier):
410 | def __init__(
411 | self,
412 | model: str,
413 | default_label: str = "Random",
414 | n_examples=3,
415 | memory_index: Optional[IndexConstructor] = None,
416 | vectorizer: _BaseVectorizer = None,
417 | prompt_template: Optional[str] = None,
418 | metric="euclidean",
419 | ):
420 | super().__init__(
421 | model=model,
422 | default_label=default_label,
423 | prompt_template=prompt_template,
424 | )
425 | self.vectorizer = vectorizer
426 | self.memory_index = memory_index
427 | self.n_examples = n_examples
428 | self.metric = metric
429 | if isinstance(self, MultiLabelMixin):
430 | raise TypeError("Multi-label classification is not supported")
431 |
432 | def fit(
433 | self,
434 | X: Union[np.ndarray, pd.Series, List[str]],
435 | y: Union[np.ndarray, pd.Series, List[str]],
436 | ):
437 | """
438 | Fits the model to the given data.
439 |
440 | Parameters
441 | ----------
442 | X : Union[np.ndarray, pd.Series, List[str]]
443 | Training data
444 | y : Union[np.ndarray, pd.Series, List[str]]
445 | Training labels
446 |
447 | Returns
448 | -------
449 | BaseDynamicFewShotClassifier
450 | self
451 | """
452 |
453 | if not self.vectorizer:
454 | raise ValueError("Vectorizer must be set")
455 | X = _to_numpy(X)
456 | y = _to_numpy(y)
457 | self.embedding_model_ = self.vectorizer.fit(X)
458 | self.classes_, self.probabilities_ = self._get_unique_targets(y)
459 |
460 | self.data_ = {}
461 | for cls in self.classes_:
462 | print(f"Building index for class `{cls}` ...")
463 | self.data_[cls] = {}
464 | partition = X[y == cls]
465 | self.data_[cls]["partition"] = partition
466 | embeddings = self.embedding_model_.transform(partition)
467 | if self.memory_index is not None:
468 | index = self.memory_index()
469 | index.dim = embeddings.shape[1]
470 | else:
471 | index = SklearnMemoryIndex(embeddings.shape[1], metric=self.metric)
472 | for embedding in embeddings:
473 | index.add(embedding)
474 | index.build()
475 | self.data_[cls]["index"] = index
476 |
477 | return self
478 |
479 | def _get_prompt_template(self) -> str:
480 | """Returns the prompt template to use for a single input."""
481 | if self.prompt_template is not None:
482 | return self.prompt_template
483 | return FEW_SHOT_CLF_PROMPT_TEMPLATE
484 |
485 | def _reorder_examples(self, examples):
486 | n_classes = len(self.classes_)
487 | n_examples = self.n_examples
488 |
489 | shuffled_list = []
490 |
491 | for i in range(n_examples):
492 | for cls in range(n_classes):
493 | shuffled_list.append(cls * n_examples + i)
494 |
495 | return [examples[i] for i in shuffled_list]
496 |
497 | def _get_prompt(self, x: str) -> dict:
498 | """
499 | Generates the prompt for the given input.
500 |
501 | Parameters
502 | ----------
503 | x : str
504 | sample to classify
505 |
506 | Returns
507 | -------
508 | dict
509 | final prompt
510 | """
511 | embedding = self.embedding_model_.transform([x])
512 | training_data = []
513 | for cls in self.classes_:
514 | index = self.data_[cls]["index"]
515 | partition = self.data_[cls]["partition"]
516 | neighbors = index.retrieve(embedding, min(self.n_examples, len(partition)))
517 | neighbors = [partition[i] for i in neighbors[0]]
518 | training_data.extend(
519 | [
520 | _TRAINING_SAMPLE_PROMPT_TEMPLATE.format(x=neighbor, label=cls)
521 | for neighbor in neighbors
522 | ]
523 | )
524 |
525 | training_data_str = "\n".join(self._reorder_examples(training_data))
526 |
527 | msg = build_few_shot_prompt_slc(
528 | x=x,
529 | training_data=training_data_str,
530 | labels=repr(self.classes_),
531 | template=self._get_prompt_template(),
532 | )
533 |
534 | return {"messages": msg, "system_message": "You are a text classifier."}
535 |
536 |
537 | class BaseTunableClassifier(BaseClassifier):
538 | def fit(
539 | self,
540 | X: Optional[Union[np.ndarray, pd.Series, List[str]]],
541 | y: Union[np.ndarray, pd.Series, List[str], List[List[str]]],
542 | ):
543 | """
544 | Fits the model to the given data.
545 |
546 | Parameters
547 | ----------
548 | X : Optional[Union[np.ndarray, pd.Series, List[str]]]
549 | Training data
550 | y : Union[np.ndarray, pd.Series, List[str], List[List[str]]]
551 | Training labels
552 |
553 | Returns
554 | -------
555 | BaseTunableClassifier
556 | self
557 | """
558 | if not isinstance(self, _BaseTunableMixin):
559 | raise TypeError(
560 | "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
561 | )
562 | super().fit(X, y)
563 | self._tune(X, y)
564 | return self
565 |
566 | def _get_prompt_template(self) -> str:
567 | """Returns the prompt template to use for a single input."""
568 | if self.prompt_template is not None:
569 | return self.prompt_template
570 | elif isinstance(self, SingleLabelMixin):
571 | return ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE
572 | return ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE
573 |
574 | def _get_prompt(self, x: str) -> dict:
575 | """Returns the prompt to use for a single input."""
576 | if isinstance(self, SingleLabelMixin):
577 | prompt = build_zero_shot_prompt_slc(
578 | x, repr(self.classes_), template=self._get_prompt_template()
579 | )
580 | else:
581 | prompt = build_zero_shot_prompt_mlc(
582 | x,
583 | repr(self.classes_),
584 | self.max_labels,
585 | template=self._get_prompt_template(),
586 | )
587 | return {"messages": prompt, "system_message": "You are a text classifier."}
588 |
--------------------------------------------------------------------------------
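
To see how the pieces compose without any API calls, here is a hypothetical offline classifier (not part of the repository): a toy backend satisfying BaseClassifierMixin, combined with BaseClassifier and SingleLabelMixin:

    import json
    from typing import Any

    from skllm.llm.base import BaseClassifierMixin
    from skllm.models._base.classifier import BaseClassifier, SingleLabelMixin


    class EchoBackendMixin(BaseClassifierMixin):
        # toy backend: always answers with the first known class
        def _get_chat_completion(self, model, messages, system_message=None, **kwargs):
            return json.dumps({"label": self.classes_[0]})

        def _convert_completion_to_str(self, completion: Any) -> str:
            return str(completion)

        def _extract_out_label(self, completion: Any) -> str:
            return json.loads(completion)["label"]


    class EchoClassifier(BaseClassifier, EchoBackendMixin, SingleLabelMixin):
        def _get_prompt(self, x: str) -> dict:
            return {"messages": x, "system_message": self.system_msg}


    clf = EchoClassifier(model="offline", default_label="unknown")
    clf.fit(["good movie", "bad movie"], ["positive", "negative"])
    print(clf.predict(["amazing"]))  # ['positive']
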
/skllm/models/_base/tagger.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Union, List, Optional, Dict
2 | from abc import abstractmethod, ABC
3 | from numpy import ndarray
4 | from tqdm import tqdm
5 | import numpy as np
6 | import pandas as pd
7 | from skllm.utils import to_numpy as _to_numpy
8 | from sklearn.base import (
9 | BaseEstimator as _SklBaseEstimator,
10 | TransformerMixin as _SklTransformerMixin,
11 | )
12 |
13 | from skllm.utils.rendering import display_ner
14 | from skllm.utils.xml import filter_xml_tags, filter_unwanted_entities, json_to_xml
15 | from skllm.prompts.builders import build_ner_prompt
16 | from skllm.prompts.templates import (
17 | NER_SYSTEM_MESSAGE_TEMPLATE,
18 | NER_SYSTEM_MESSAGE_SPARSE,
19 | EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
20 | EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE,
21 | )
22 | from skllm.utils import re_naive_json_extractor
23 | import json
24 | from concurrent.futures import ThreadPoolExecutor
25 |
26 | class BaseTagger(ABC, _SklBaseEstimator, _SklTransformerMixin):
27 |
28 | num_workers = 1
29 |
30 | def fit(self, X: Any, y: Any = None):
31 | """
32 | Fits the model to the data. Usually a no-op.
33 |
34 | Parameters
35 | ----------
36 | X : Any
37 | training data
38 | y : Any
39 | training outputs
40 |
41 | Returns
42 | -------
43 | self
44 | BaseTagger
45 | """
46 | return self
47 |
48 | def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
49 | return self.transform(X)
50 |
51 | def fit_transform(
52 | self,
53 | X: Union[np.ndarray, pd.Series, List[str]],
54 | y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
55 | ) -> ndarray:
56 | return self.fit(X, y).transform(X)
57 |
58 | def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
59 | """
60 | Transforms the input data.
61 |
62 | Parameters
63 | ----------
64 | X : Union[np.ndarray, pd.Series, List[str]]
65 | The input data to predict the class of.
66 |
67 | Returns
68 | -------
69 | List[str]
70 | """
71 | X = _to_numpy(X)
72 | predictions = []
73 | with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
74 | predictions = list(tqdm(executor.map(self._predict_single, X), total=len(X)))
75 | return np.asarray(predictions)
76 |
77 | def _predict_single(self, x: Any) -> Any:
78 | prompt_dict = self._get_prompt(x)
79 | # this will be inherited from the LLM
80 | prediction = self._get_chat_completion(model=self.model, **prompt_dict)
81 | prediction = self._convert_completion_to_str(prediction)
82 | return prediction
83 |
84 | @abstractmethod
85 | def _get_prompt(self, x: str) -> dict:
86 | """Returns the prompt to use for a single input."""
87 | pass
88 |
89 |
90 | class ExplainableNER(BaseTagger):
91 | entities: Optional[Dict[str, str]] = None
92 | sparse_output = True
93 | _allowed_tags = ["entity", "not_entity"]
94 |
95 | display_predictions = False
96 |
97 |     def fit(self, X: Any, y: Any = None):
98 |         if self.entities is None:
99 |             raise ValueError("`entities` must be set before fitting.")
100 |         entities = [{"entity": k, "definition": v} for k, v in self.entities.items()]
101 |         self.expanded_entities_ = entities
102 |         return self
103 |
104 | def transform(self, X: ndarray | pd.Series | List[str]):
105 | predictions = super().transform(X)
106 | if self.sparse_output:
107 | json_predictions = [
108 | re_naive_json_extractor(p, expected_output="array") for p in predictions
109 | ]
110 | predictions = []
111 | attributes = ["reasoning", "tag", "value"]
112 | for x, p in zip(X, json_predictions):
113 | p_json = json.loads(p)
114 | predictions.append(
115 | json_to_xml(
116 | x,
117 | p_json,
118 | "entity",
119 | "not_entity",
120 | value_key="value",
121 | attributes=attributes,
122 | )
123 | )
124 |
125 | predictions = [
126 | filter_unwanted_entities(
127 | filter_xml_tags(p, self._allowed_tags), self.entities
128 | )
129 | for p in predictions
130 | ]
131 | if self.display_predictions:
132 | print("Displaying predictions...")
133 | display_ner(predictions, self.entities)
134 | return predictions
135 |
136 | def _get_prompt(self, x: str) -> dict:
137 | if not hasattr(self, "expanded_entities_"):
138 | raise ValueError("Model not fitted.")
139 | system_message = (
140 | NER_SYSTEM_MESSAGE_TEMPLATE
141 | if self.sparse_output
142 | else NER_SYSTEM_MESSAGE_SPARSE
143 | )
144 | template = (
145 | EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE
146 | if self.sparse_output
147 | else EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
148 | )
149 | return {
150 | "messages": build_ner_prompt(self.expanded_entities_, x, template=template),
151 | "system_message": system_message.format(entities=self.entities.keys()),
152 | }
153 |
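154 | # Subclassing sketch (illustrative only, not part of the library): a concrete
155 | # tagger combines this base with an LLM mixin that supplies
156 | # `_get_chat_completion` and `_convert_completion_to_str`, and implements
157 | # `_get_prompt` to return the dict unpacked in `_predict_single`:
158 | #
159 | # >>> class MyTagger(BaseTagger, SomeTextCompletionMixin):  # hypothetical mixin
160 | # ...     model = "some-model"
161 | # ...     def _get_prompt(self, x: str) -> dict:
162 | # ...         return {"messages": f"Tag this text: {x}", "system_message": "You are a tagger."}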
--------------------------------------------------------------------------------
/skllm/models/_base/text2text.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Union, List, Optional
2 | from abc import abstractmethod, ABC
3 | from numpy import ndarray
4 | from tqdm import tqdm
5 | import numpy as np
6 | import pandas as pd
7 | from skllm.utils import to_numpy as _to_numpy
8 | from sklearn.base import (
9 | BaseEstimator as _SklBaseEstimator,
10 | TransformerMixin as _SklTransformerMixin,
11 | )
12 | from skllm.llm.base import BaseTunableMixin as _BaseTunableMixin
13 | from skllm.prompts.builders import (
14 | build_focused_summary_prompt,
15 | build_summary_prompt,
16 | build_translation_prompt,
17 | )
18 |
19 |
20 | class BaseText2TextModel(ABC, _SklBaseEstimator, _SklTransformerMixin):
21 | def fit(self, X: Any, y: Any = None):
22 | """
23 | Fits the model to the data. Usually a no-op.
24 |
25 | Parameters
26 | ----------
27 | X : Any
28 | training data
29 | y : Any
30 | training outputs
31 |
32 | Returns
33 | -------
34 | self
35 | BaseText2TextModel
36 | """
37 | return self
38 |
39 | def predict(self, X: Union[np.ndarray, pd.Series, List[str]]):
40 | return self.transform(X)
41 |
42 | def fit_transform(
43 | self,
44 | X: Union[np.ndarray, pd.Series, List[str]],
45 | y: Optional[Union[np.ndarray, pd.Series, List[str]]] = None,
46 | ) -> ndarray:
47 | return self.fit(X, y).transform(X)
48 |
49 | def transform(self, X: Union[np.ndarray, pd.Series, List[str]]):
50 | """
51 | Transforms the input data.
52 |
53 | Parameters
54 | ----------
55 | X : Union[np.ndarray, pd.Series, List[str]]
56 | The input data to transform.
57 |
58 | Returns
59 | -------
60 | ndarray
61 | """
62 | X = _to_numpy(X)
63 | predictions = []
64 | for x in tqdm(X):
65 | predictions.append(self._predict_single(x))
66 | return np.asarray(predictions)
67 |
68 | def _predict_single(self, x: Any) -> Any:
69 | prompt_dict = self._get_prompt(x)
70 | # this will be inherited from the LLM
71 | prediction = self._get_chat_completion(model=self.model, **prompt_dict)
72 | prediction = self._convert_completion_to_str(prediction)
73 | return prediction
74 |
75 | @abstractmethod
76 | def _get_prompt(self, x: str) -> dict:
77 | """Returns the prompt to use for a single input."""
78 | pass
79 |
80 |
81 | class BaseTunableText2TextModel(BaseText2TextModel):
82 | def fit(
83 | self,
84 | X: Union[np.ndarray, pd.Series, List[str]],
85 | y: Union[np.ndarray, pd.Series, List[str]],
86 | ):
87 | """
88 | Fits the model to the data.
89 |
90 | Parameters
91 | ----------
92 | X : Union[np.ndarray, pd.Series, List[str]]
93 | training data
94 | y : Union[np.ndarray, pd.Series, List[str]]
95 | target texts (training outputs)
96 |
97 | Returns
98 | -------
99 | BaseTunableText2TextModel
100 | self
101 | """
102 | if not isinstance(self, _BaseTunableMixin):
103 | raise TypeError(
104 | "Classifier must be mixed with a skllm.llm.base.BaseTunableMixin class"
105 | )
106 | self._tune(X, y)
107 | return self
108 |
109 | def _get_prompt(self, x: str) -> dict:
110 | """Returns the prompt to use for a single input."""
111 | return {"messages": str(x), "system_message": ""}
112 |
113 | def _predict_single(self, x: str) -> str:
114 | if self.model is None:
115 | raise RuntimeError("Model has not been tuned yet")
116 | return super()._predict_single(x)
117 |
118 |
119 | class BaseSummarizer(BaseText2TextModel):
120 | max_words: int = 15
121 | focus: Optional[str] = None
122 | system_message: str = "You are a text summarizer."
123 |
124 | def _get_prompt(self, x: str) -> dict:
125 | if self.focus:
126 | prompt = build_focused_summary_prompt(x, self.max_words, self.focus)
127 | else:
128 | prompt = build_summary_prompt(x, self.max_words)
129 | return {"messages": prompt, "system_message": self.system_message}
130 |
131 | def transform(
132 | self, X: Union[ndarray, pd.Series, List[str]], **kwargs: Any
133 | ) -> ndarray:
134 | """
135 | Transforms the input data.
136 |
137 | Parameters
138 | ----------
139 | X : Union[np.ndarray, pd.Series, List[str]]
140 | The input data to summarize.
141 |
142 | Returns
143 | -------
144 | ndarray
145 | """
146 | y = super().transform(X, **kwargs)
147 | if self.focus:
148 | # strip boilerplate fragments that the focused summary prompt
149 | # may produce around the actual summary
150 | y = [
151 | i.replace("Mentioned concept is not present in the text.", "")
152 | .replace("The general summary is:", "")
153 | .strip()
154 | for i in y
155 | ]
156 | # always return an ndarray, matching the annotated return type
157 | return np.asarray(y, dtype=object)
158 |
159 |
160 | class BaseTranslator(BaseText2TextModel):
161 | output_language: str = "English"
162 | system_message: str = "You are a text translator."
163 |
164 | def _get_prompt(self, x: str) -> dict:
165 | prompt = build_translation_prompt(x, self.output_language)
166 | return {"messages": prompt, "system_message": self.system_message}
167 |
168 | def transform(
169 | self, X: Union[ndarray, pd.Series, List[str]], **kwargs: Any
170 | ) -> ndarray:
171 | """
172 | Transforms the input data.
173 |
174 | Parameters
175 | ----------
176 | X : Union[np.ndarray, pd.Series, List[str]]
177 | The input data to translate.
178 |
179 | Returns
180 | -------
181 | ndarray
182 | """
183 | y = super().transform(X, **kwargs)
184 | y = np.asarray(
185 | [i.replace("[Translated text:]", "").replace("```", "").strip() for i in y],
186 | dtype=object,
187 | )
188 | return y
189 |
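190 | # Subclassing sketch (illustrative only): a custom text2text model combines this
191 | # base with an LLM mixin and overrides `_get_prompt`; the returned dict is
192 | # unpacked into the mixin-provided `_get_chat_completion` by `_predict_single`.
193 | # `GPTTextCompletionMixin` is the mixin used by the concrete GPT models.
194 | #
195 | # >>> from skllm.llm.gpt.mixin import GPTTextCompletionMixin
196 | # >>> class Paraphraser(BaseText2TextModel, GPTTextCompletionMixin):
197 | # ...     def __init__(self, model="gpt-3.5-turbo", key=None):
198 | # ...         self.model = model
199 | # ...         self._set_keys(key, None)
200 | # ...     def _get_prompt(self, x: str) -> dict:
201 | # ...         return {"messages": f"Paraphrase the following text: {x}",
202 | # ...                 "system_message": "You are a paraphrasing assistant."}
203 | # >>> Paraphraser().fit_transform(["A quick example sentence."])  # requires a valid OpenAI key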
--------------------------------------------------------------------------------
/skllm/models/_base/vectorizer.py:
--------------------------------------------------------------------------------
1 | from typing import Any, List, Optional, Union
2 | import numpy as np
3 | import pandas as pd
4 | from skllm.utils import to_numpy as _to_numpy
5 | from skllm.llm.base import BaseEmbeddingMixin
6 | from sklearn.base import (
7 | BaseEstimator as _SklBaseEstimator,
8 | TransformerMixin as _SklTransformerMixin,
9 | )
10 |
11 |
12 | class BaseVectorizer(_SklBaseEstimator, _SklTransformerMixin):
13 | """
14 | A base vectorization/embedding class.
15 |
16 | Parameters
17 | ----------
18 | model : str
19 | The embedding model to use.
20 | """
21 |
22 | def __init__(self, model: str, batch_size: int = 1):
23 | if not isinstance(self, BaseEmbeddingMixin):
24 | raise TypeError(
25 | "Vectorizer must be mixed with skllm.llm.base.BaseEmbeddingMixin."
26 | )
27 | self.model = model
28 | if not isinstance(batch_size, int):
29 | raise TypeError("batch_size must be an integer")
30 | if batch_size < 1:
31 | raise ValueError("batch_size must be greater than 0")
32 | self.batch_size = batch_size
33 |
34 | def fit(self, X: Any = None, y: Any = None, **kwargs):
35 | """
36 | Does nothing. Needed only for sklearn compatibility.
37 |
38 | Parameters
39 | ----------
40 | X : Any, optional
41 | y : Any, optional
42 | kwargs : dict, optional
43 |
44 | Returns
45 | -------
46 | self : BaseVectorizer
47 | """
48 | return self
49 |
50 | def transform(
51 | self, X: Optional[Union[np.ndarray, pd.Series, List[str]]]
52 | ) -> np.ndarray:
53 | """
54 | Transforms a list of strings into an array of embeddings.
55 | This is modelled to function as the sklearn transform method.
56 |
57 | Parameters
58 | ----------
59 | X : Optional[Union[np.ndarray, pd.Series, List[str]]]
60 | The input array of strings to transform into embeddings.
61 |
62 | Returns
63 | -------
64 | embeddings : np.ndarray
65 | """
66 | X = _to_numpy(X)
67 | embeddings = self._get_embeddings(X)
68 | embeddings = np.asarray(embeddings)
69 | return embeddings
70 |
71 | def fit_transform(
72 | self,
73 | X: Optional[Union[np.ndarray, pd.Series, List[str]]],
74 | y: Any = None,
75 | **fit_params,
76 | ) -> np.ndarray:
77 | """
78 | Fits and transforms a list of strings into a list of embeddings.
79 | This is modelled to function as the sklearn fit_transform method.
80 |
81 | Parameters
82 | ----------
83 | X : Optional[Union[np.ndarray, pd.Series, List[str]]]
84 | The input array of strings to transform into embeddings.
85 | y : Any, optional
86 |
87 | Returns
88 | -------
89 | embeddings : np.ndarray
90 | """
91 | return self.fit(X, y).transform(X)
92 |
--------------------------------------------------------------------------------
/skllm/models/anthropic/classification/few_shot.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.classifier import (
2 | BaseFewShotClassifier,
3 | BaseDynamicFewShotClassifier,
4 | SingleLabelMixin,
5 | MultiLabelMixin,
6 | )
7 | from skllm.llm.anthropic.mixin import ClaudeClassifierMixin
8 | from skllm.models.gpt.vectorization import GPTVectorizer
9 | from skllm.models._base.vectorizer import BaseVectorizer
10 | from skllm.memory.base import IndexConstructor
11 | from typing import Optional
12 |
13 |
14 | class FewShotClaudeClassifier(BaseFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin):
15 | """Few-shot text classifier using Anthropic's Claude API for single-label classification tasks."""
16 |
17 | def __init__(
18 | self,
19 | model: str = "claude-3-haiku-20240307",
20 | default_label: str = "Random",
21 | prompt_template: Optional[str] = None,
22 | key: Optional[str] = None,
23 | **kwargs,
24 | ):
25 | """
26 | Few-shot text classifier using Anthropic's Claude API.
27 |
28 | Parameters
29 | ----------
30 | model : str, optional
31 | model to use, by default "claude-3-haiku-20240307"
32 | default_label : str, optional
33 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies
34 | prompt_template : Optional[str], optional
35 | custom prompt template to use, by default None
36 | key : Optional[str], optional
37 | estimator-specific API key; if None, retrieved from the global config
38 | """
39 | super().__init__(
40 | model=model,
41 | default_label=default_label,
42 | prompt_template=prompt_template,
43 | **kwargs,
44 | )
45 | self._set_keys(key)
46 |
47 |
48 | class MultiLabelFewShotClaudeClassifier(
49 | BaseFewShotClassifier, ClaudeClassifierMixin, MultiLabelMixin
50 | ):
51 | """Few-shot text classifier using Anthropic's Claude API for multi-label classification tasks."""
52 |
53 | def __init__(
54 | self,
55 | model: str = "claude-3-haiku-20240307",
56 | default_label: str = "Random",
57 | max_labels: Optional[int] = 5,
58 | prompt_template: Optional[str] = None,
59 | key: Optional[str] = None,
60 | **kwargs,
61 | ):
62 | """
63 | Multi-label few-shot text classifier using Anthropic's Claude API.
64 |
65 | Parameters
66 | ----------
67 | model : str, optional
68 | model to use, by default "claude-3-haiku-20240307"
69 | default_label : str, optional
70 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies
71 | max_labels : Optional[int], optional
72 | maximum labels per sample, by default 5
73 | prompt_template : Optional[str], optional
74 | custom prompt template to use, by default None
75 | key : Optional[str], optional
76 | estimator-specific API key; if None, retrieved from the global config
77 | """
78 | super().__init__(
79 | model=model,
80 | default_label=default_label,
81 | max_labels=max_labels,
82 | prompt_template=prompt_template,
83 | **kwargs,
84 | )
85 | self._set_keys(key)
86 |
87 |
88 | class DynamicFewShotClaudeClassifier(
89 | BaseDynamicFewShotClassifier, ClaudeClassifierMixin, SingleLabelMixin
90 | ):
91 | """
92 | Dynamic few-shot text classifier using Anthropic's Claude API for
93 | single-label classification tasks with dynamic example selection using GPT embeddings.
94 | """
95 |
96 | def __init__(
97 | self,
98 | model: str = "claude-3-haiku-20240307",
99 | default_label: str = "Random",
100 | prompt_template: Optional[str] = None,
101 | key: Optional[str] = None,
102 | n_examples: int = 3,
103 | memory_index: Optional[IndexConstructor] = None,
104 | vectorizer: Optional[BaseVectorizer] = None,
105 | metric: Optional[str] = "euclidean",
106 | **kwargs,
107 | ):
108 | """
109 | Dynamic few-shot text classifier using Anthropic's Claude API.
110 | For each sample, N closest examples are retrieved from the memory.
111 |
112 | Parameters
113 | ----------
114 | model : str, optional
115 | model to use, by default "claude-3-haiku-20240307"
116 | default_label : str, optional
117 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies
118 | prompt_template : Optional[str], optional
119 | custom prompt template to use, by default None
120 | key : Optional[str], optional
121 | estimator-specific API key; if None, retrieved from the global config
122 | n_examples : int, optional
123 | number of closest examples per class to be retrieved, by default 3
124 | memory_index : Optional[IndexConstructor], optional
125 | custom memory index, for details check `skllm.memory` submodule
126 | vectorizer : Optional[BaseVectorizer], optional
127 | scikit-llm vectorizer; if None, `GPTVectorizer` is used
128 | metric : Optional[str], optional
129 | metric used for similarity search, by default "euclidean"
130 | """
131 | if vectorizer is None:
132 | vectorizer = GPTVectorizer(model="text-embedding-ada-002")  # the OpenAI key is read from the global config; "key" here is the Anthropic key
133 | super().__init__(
134 | model=model,
135 | default_label=default_label,
136 | prompt_template=prompt_template,
137 | n_examples=n_examples,
138 | memory_index=memory_index,
139 | vectorizer=vectorizer,
140 | metric=metric,
141 | )
142 | self._set_keys(key)
143 |
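144 | # Usage sketch (illustrative only): few-shot classifiers embed the labelled
145 | # examples passed to `fit` directly into the prompt. Assumes a valid
146 | # Anthropic API key (here passed explicitly via `key`):
147 | #
148 | # >>> X = ["I loved this movie", "Utterly disappointing", "A masterpiece"]
149 | # >>> y = ["positive", "negative", "positive"]
150 | # >>> clf = FewShotClaudeClassifier(key="<ANTHROPIC_API_KEY>").fit(X, y)
151 | # >>> clf.predict(["One of the best films I have seen"])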
--------------------------------------------------------------------------------
/skllm/models/anthropic/classification/zero_shot.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.classifier import (
2 | SingleLabelMixin as _SingleLabelMixin,
3 | MultiLabelMixin as _MultiLabelMixin,
4 | BaseZeroShotClassifier as _BaseZeroShotClassifier,
5 | BaseCoTClassifier as _BaseCoTClassifier,
6 | )
7 | from skllm.llm.anthropic.mixin import ClaudeClassifierMixin as _ClaudeClassifierMixin
8 | from typing import Optional
9 |
10 |
11 | class ZeroShotClaudeClassifier(
12 | _BaseZeroShotClassifier, _ClaudeClassifierMixin, _SingleLabelMixin
13 | ):
14 | """Zero-shot text classifier using Anthropic Claude models for single-label classification."""
15 |
16 | def __init__(
17 | self,
18 | model: str = "claude-3-haiku-20240307",
19 | default_label: str = "Random",
20 | prompt_template: Optional[str] = None,
21 | key: Optional[str] = None,
22 | **kwargs,
23 | ):
24 | """
25 | Zero-shot text classifier using Anthropic Claude models.
26 |
27 | Parameters
28 | ----------
29 | model : str, optional
30 | Model to use, by default "claude-3-haiku-20240307".
31 | default_label : str, optional
32 | Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
33 | prompt_template : Optional[str], optional
34 | Custom prompt template to use, by default None.
35 | key : Optional[str], optional
36 | Estimator-specific API key; if None, retrieved from the global config, by default None.
37 | """
38 | super().__init__(
39 | model=model,
40 | default_label=default_label,
41 | prompt_template=prompt_template,
42 | **kwargs,
43 | )
44 | self._set_keys(key)
45 |
46 |
47 | class CoTClaudeClassifier(
48 | _BaseCoTClassifier, _ClaudeClassifierMixin, _SingleLabelMixin
49 | ):
50 | """Chain-of-thought text classifier using Anthropic Claude models for single-label classification."""
51 |
52 | def __init__(
53 | self,
54 | model: str = "claude-3-haiku-20240307",
55 | default_label: str = "Random",
56 | prompt_template: Optional[str] = None,
57 | key: Optional[str] = None,
58 | **kwargs,
59 | ):
60 | """
61 | Chain-of-thought text classifier using Anthropic Claude models.
62 |
63 | Parameters
64 | ----------
65 | model : str, optional
66 | Model to use, by default "claude-3-haiku-20240307".
67 | default_label : str, optional
68 | Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
69 | prompt_template : Optional[str], optional
70 | Custom prompt template to use, by default None.
71 | key : Optional[str], optional
72 | Estimator-specific API key; if None, retrieved from the global config, by default None.
73 | """
74 | super().__init__(
75 | model=model,
76 | default_label=default_label,
77 | prompt_template=prompt_template,
78 | **kwargs,
79 | )
80 | self._set_keys(key)
81 |
82 |
83 | class MultiLabelZeroShotClaudeClassifier(
84 | _BaseZeroShotClassifier, _ClaudeClassifierMixin, _MultiLabelMixin
85 | ):
86 | """Zero-shot text classifier using Anthropic Claude models for multi-label classification."""
87 |
88 | def __init__(
89 | self,
90 | model: str = "claude-3-haiku-20240307",
91 | default_label: str = "Random",
92 | max_labels: Optional[int] = 5,
93 | prompt_template: Optional[str] = None,
94 | key: Optional[str] = None,
95 | **kwargs,
96 | ):
97 | """
98 | Multi-label zero-shot text classifier using Anthropic Claude models.
99 |
100 | Parameters
101 | ----------
102 | model : str, optional
103 | Model to use, by default "claude-3-haiku-20240307".
104 | default_label : str, optional
105 | Default label for failed predictions; if "Random", selects randomly based on class frequencies, defaults to "Random".
106 | max_labels : Optional[int], optional
107 | Maximum number of labels per sample, by default 5.
108 | prompt_template : Optional[str], optional
109 | Custom prompt template to use, by default None.
110 | key : Optional[str], optional
111 | Estimator-specific API key; if None, retrieved from the global config, by default None.
112 | """
113 | super().__init__(
114 | model=model,
115 | default_label=default_label,
116 | max_labels=max_labels,
117 | prompt_template=prompt_template,
118 | **kwargs,
119 | )
120 | self._set_keys(key)
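121 |
122 | # Usage sketch (illustrative only): for zero-shot classification the candidate
123 | # labels are the only supervision, so fitting with `X=None` and the label list
124 | # is a common scikit-llm pattern. Assumes a valid Anthropic API key:
125 | #
126 | # >>> clf = ZeroShotClaudeClassifier(key="<ANTHROPIC_API_KEY>")
127 | # >>> clf.fit(None, ["positive", "negative", "neutral"])
128 | # >>> clf.predict(["The service was excellent"])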
--------------------------------------------------------------------------------
/skllm/models/anthropic/tagging/ner.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin
3 | from typing import Optional, Dict
4 |
5 |
6 | class AnthropicExplainableNER(_ExplainableNER, _ClaudeTextCompletionMixin):
7 | """Named Entity Recognition model using Anthropic's Claude API for explainable entity extraction."""
8 |
9 | def __init__(
10 | self,
11 | entities: Dict[str, str],
12 | display_predictions: bool = False,
13 | sparse_output: bool = True,
14 | model: str = "claude-3-haiku-20240307",
15 | key: Optional[str] = None,
16 | num_workers: int = 1,
17 | ) -> None:
18 | """
19 | Named entity recognition using Anthropic Claude API.
20 |
21 | Parameters
22 | ----------
23 | entities : dict
24 | dictionary of entities to recognize, with keys as entity names and values as descriptions
25 | display_predictions : bool, optional
26 | whether to display predictions, by default False
27 | sparse_output : bool, optional
28 | whether to generate a sparse representation of the predictions, by default True
29 | model : str, optional
30 | model to use, by default "claude-3-haiku-20240307"
31 | key : Optional[str], optional
32 | estimator-specific API key; if None, retrieved from the global config
33 | num_workers : int, optional
34 | number of workers (threads) to use, by default 1
35 | """
36 | self._set_keys(key)
37 | self.model = model
38 | self.entities = entities
39 | self.display_predictions = display_predictions
40 | self.sparse_output = sparse_output
41 | self.num_workers = num_workers
--------------------------------------------------------------------------------
/skllm/models/anthropic/text2text/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/skllm/models/anthropic/text2text/__init__.py
--------------------------------------------------------------------------------
/skllm/models/anthropic/text2text/summarization.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.text2text import BaseSummarizer as _BaseSummarizer
2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin
3 | from typing import Optional
4 |
5 |
6 | class ClaudeSummarizer(_BaseSummarizer, _ClaudeTextCompletionMixin):
7 | """Text summarizer using Anthropic Claude API."""
8 |
9 | def __init__(
10 | self,
11 | model: str = "claude-3-haiku-20240307",
12 | key: Optional[str] = None,
13 | max_words: int = 15,
14 | focus: Optional[str] = None,
15 | ) -> None:
16 | """
17 | Initialize the Claude summarizer.
18 |
19 | Parameters
20 | ----------
21 | model : str, optional
22 | Model to use, by default "claude-3-haiku-20240307"
23 | key : Optional[str], optional
24 | Estimator-specific API key; if None, retrieved from global config
25 | max_words : int, optional
26 | Soft limit of the summary length, by default 15
27 | focus : Optional[str], optional
28 | Concept in the text to focus on, by default None
29 | """
30 | self._set_keys(key)
31 | self.model = model
32 | self.max_words = max_words
33 | self.focus = focus
34 | self.system_message = "You are a text summarizer. Provide concise and accurate summaries."
--------------------------------------------------------------------------------
/skllm/models/anthropic/text2text/translation.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.text2text import BaseTranslator as _BaseTranslator
2 | from skllm.llm.anthropic.mixin import ClaudeTextCompletionMixin as _ClaudeTextCompletionMixin
3 | from typing import Optional
4 |
5 |
6 | class ClaudeTranslator(_BaseTranslator, _ClaudeTextCompletionMixin):
7 | """Text translator using Anthropic Claude API."""
8 |
9 | default_output = "Translation is unavailable."
10 |
11 | def __init__(
12 | self,
13 | model: str = "claude-3-haiku-20240307",
14 | key: Optional[str] = None,
15 | output_language: str = "English",
16 | ) -> None:
17 | """
18 | Initialize the Claude translator.
19 |
20 | Parameters
21 | ----------
22 | model : str, optional
23 | Model to use, by default "claude-3-haiku-20240307"
24 | key : Optional[str], optional
25 | Estimator-specific API key; if None, retrieved from global config
26 | output_language : str, optional
27 | Target language, by default "English"
28 | """
29 | self._set_keys(key)
30 | self.model = model
31 | self.output_language = output_language
32 | self.system_message = (
33 | "You are a professional translator. Provide accurate translations "
34 | "while maintaining the original meaning and tone of the text."
35 | )
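36 |
37 | # Usage sketch (illustrative only), assuming a valid Anthropic API key:
38 | #
39 | # >>> translator = ClaudeTranslator(key="<ANTHROPIC_API_KEY>", output_language="French")
40 | # >>> translator.fit_transform(["How are you today?"])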
--------------------------------------------------------------------------------
/skllm/models/gpt/classification/few_shot.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.classifier import (
2 | BaseFewShotClassifier,
3 | BaseDynamicFewShotClassifier,
4 | SingleLabelMixin,
5 | MultiLabelMixin,
6 | )
7 | from skllm.llm.gpt.mixin import GPTClassifierMixin
8 | from skllm.models.gpt.vectorization import GPTVectorizer
9 | from skllm.models._base.vectorizer import BaseVectorizer
10 | from skllm.memory.base import IndexConstructor
11 | from typing import Optional
12 |
13 |
14 | class FewShotGPTClassifier(BaseFewShotClassifier, GPTClassifierMixin, SingleLabelMixin):
15 | def __init__(
16 | self,
17 | model: str = "gpt-3.5-turbo",
18 | default_label: str = "Random",
19 | prompt_template: Optional[str] = None,
20 | key: Optional[str] = None,
21 | org: Optional[str] = None,
22 | **kwargs,
23 | ):
24 | """
25 | Few-shot text classifier using OpenAI/GPT API-compatible models.
26 |
27 | Parameters
28 | ----------
29 | model : str, optional
30 | model to use, by default "gpt-3.5-turbo"
31 | default_label : str, optional
32 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
33 | prompt_template : Optional[str], optional
34 | custom prompt template to use, by default None
35 | key : Optional[str], optional
36 | estimator-specific API key; if None, retrieved from the global config, by default None
37 | org : Optional[str], optional
38 | estimator-specific ORG key; if None, retrieved from the global config, by default None
39 | """
40 | super().__init__(
41 | model=model,
42 | default_label=default_label,
43 | prompt_template=prompt_template,
44 | **kwargs,
45 | )
46 | self._set_keys(key, org)
47 |
48 |
49 | class MultiLabelFewShotGPTClassifier(
50 | BaseFewShotClassifier, GPTClassifierMixin, MultiLabelMixin
51 | ):
52 | def __init__(
53 | self,
54 | model: str = "gpt-3.5-turbo",
55 | default_label: str = "Random",
56 | max_labels: Optional[int] = 5,
57 | prompt_template: Optional[str] = None,
58 | key: Optional[str] = None,
59 | org: Optional[str] = None,
60 | **kwargs,
61 | ):
62 | """
63 | Multi-label few-shot text classifier using OpenAI/GPT API-compatible models.
64 |
65 | Parameters
66 | ----------
67 | model : str, optional
68 | model to use, by default "gpt-3.5-turbo"
69 | default_label : str, optional
70 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
71 | max_labels : Optional[int], optional
72 | maximum labels per sample, by default 5
73 | prompt_template : Optional[str], optional
74 | custom prompt template to use, by default None
75 | key : Optional[str], optional
76 | estimator-specific API key; if None, retrieved from the global config, by default None
77 | org : Optional[str], optional
78 | estimator-specific ORG key; if None, retrieved from the global config, by default None
79 | """
80 | super().__init__(
81 | model=model,
82 | default_label=default_label,
83 | max_labels=max_labels,
84 | prompt_template=prompt_template,
85 | **kwargs,
86 | )
87 | self._set_keys(key, org)
88 |
89 |
90 | class DynamicFewShotGPTClassifier(
91 | BaseDynamicFewShotClassifier, GPTClassifierMixin, SingleLabelMixin
92 | ):
93 | def __init__(
94 | self,
95 | model: str = "gpt-3.5-turbo",
96 | default_label: str = "Random",
97 | prompt_template: Optional[str] = None,
98 | key: Optional[str] = None,
99 | org: Optional[str] = None,
100 | n_examples: int = 3,
101 | memory_index: Optional[IndexConstructor] = None,
102 | vectorizer: Optional[BaseVectorizer] = None,
103 | metric: Optional[str] = "euclidean",
104 | **kwargs,
105 | ):
106 | """
107 | Dynamic few-shot text classifier using OpenAI/GPT API-compatible models.
108 | For each sample, N closest examples are retrieved from the memory.
109 |
110 | Parameters
111 | ----------
112 | model : str, optional
113 | model to use, by default "gpt-3.5-turbo"
114 | default_label : str, optional
115 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
116 | prompt_template : Optional[str], optional
117 | custom prompt template to use, by default None
118 | key : Optional[str], optional
119 | estimator-specific API key; if None, retrieved from the global config, by default None
120 | org : Optional[str], optional
121 | estimator-specific ORG key; if None, retrieved from the global config, by default None
122 | n_examples : int, optional
123 | number of closest examples per class to be retrieved, by default 3
124 | memory_index : Optional[IndexConstructor], optional
125 | custom memory index, for details check `skllm.memory` submodule, by default None
126 | vectorizer : Optional[BaseVectorizer], optional
127 | scikit-llm vectorizer; if None, `GPTVectorizer` is used, by default None
128 | metric : Optional[str], optional
129 | metric used for similarity search, by default "euclidean"
130 | """
131 | if vectorizer is None:
132 | vectorizer = GPTVectorizer(model="text-embedding-ada-002", key=key, org=org)
133 | super().__init__(
134 | model=model,
135 | default_label=default_label,
136 | prompt_template=prompt_template,
137 | n_examples=n_examples,
138 | memory_index=memory_index,
139 | vectorizer=vectorizer,
140 | metric=metric,
141 | )
142 | self._set_keys(key, org)
143 |
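144 | # Usage sketch (illustrative only): the dynamic variant embeds the training
145 | # set at `fit` time and retrieves the `n_examples` nearest neighbours per
146 | # class to build each prompt. Assumes OpenAI keys in the global config:
147 | #
148 | # >>> X = ["Great battery life", "Screen cracked after a week", "Fast shipping"]
149 | # >>> y = ["positive", "negative", "positive"]
150 | # >>> clf = DynamicFewShotGPTClassifier(n_examples=1).fit(X, y)
151 | # >>> clf.predict(["The battery lasts forever"])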
--------------------------------------------------------------------------------
/skllm/models/gpt/classification/tunable.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.gpt.mixin import (
2 | GPTClassifierMixin as _GPTClassifierMixin,
3 | GPTTunableMixin as _GPTTunableMixin,
4 | )
5 | from skllm.models._base.classifier import (
6 | BaseTunableClassifier as _BaseTunableClassifier,
7 | SingleLabelMixin as _SingleLabelMixin,
8 | MultiLabelMixin as _MultiLabelMixin,
9 | )
10 | from typing import Optional
11 |
12 |
13 | class _TunableClassifier(_BaseTunableClassifier, _GPTClassifierMixin, _GPTTunableMixin):
14 | pass
15 |
16 |
17 | class GPTClassifier(_TunableClassifier, _SingleLabelMixin):
18 | def __init__(
19 | self,
20 | base_model: str = "gpt-3.5-turbo-0613",
21 | default_label: str = "Random",
22 | key: Optional[str] = None,
23 | org: Optional[str] = None,
24 | n_epochs: Optional[int] = None,
25 | custom_suffix: Optional[str] = "skllm",
26 | prompt_template: Optional[str] = None,
27 | ):
28 | """
29 | Tunable GPT-based text classifier.
30 |
31 | Parameters
32 | ----------
33 | base_model : str, optional
34 | base model to use, by default "gpt-3.5-turbo-0613"
35 | default_label : str, optional
36 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
37 | key : Optional[str], optional
38 | estimator-specific API key; if None, retrieved from the global config, by default None
39 | org : Optional[str], optional
40 | estimator-specific ORG key; if None, retrieved from the global config, by default None
41 | n_epochs : Optional[int], optional
42 | number of epochs; if None, determined automatically; by default None
43 | custom_suffix : Optional[str], optional
44 | custom suffix of the tuned model, used for naming purposes only, by default "skllm"
45 | prompt_template : Optional[str], optional
46 | custom prompt template to use, by default None
47 | """
48 | super().__init__(
49 | model=None, default_label=default_label, prompt_template=prompt_template
50 | )
51 | self._set_keys(key, org)
52 | self._set_hyperparameters(base_model, n_epochs, custom_suffix)
53 |
54 |
55 | class MultiLabelGPTClassifier(_TunableClassifier, _MultiLabelMixin):
56 | def __init__(
57 | self,
58 | base_model: str = "gpt-3.5-turbo-0613",
59 | default_label: str = "Random",
60 | key: Optional[str] = None,
61 | org: Optional[str] = None,
62 | n_epochs: Optional[int] = None,
63 | custom_suffix: Optional[str] = "skllm",
64 | prompt_template: Optional[str] = None,
65 | max_labels: Optional[int] = 5,
66 | ):
67 | """
68 | Tunable multi-label GPT-based text classifier.
69 |
70 | Parameters
71 | ----------
72 | base_model : str, optional
73 | base model to use, by default "gpt-3.5-turbo-0613"
74 | default_label : str, optional
75 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
76 | key : Optional[str], optional
77 | estimator-specific API key; if None, retrieved from the global config, by default None
78 | org : Optional[str], optional
79 | estimator-specific ORG key; if None, retrieved from the global config, by default None
80 | n_epochs : Optional[int], optional
81 | number of epochs; if None, determined automatically; by default None
82 | custom_suffix : Optional[str], optional
83 | custom suffix of the tuned model, used for naming purposes only, by default "skllm"
84 | prompt_template : Optional[str], optional
85 | custom prompt template to use, by default None
86 | max_labels : Optional[int], optional
87 | maximum labels per sample, by default 5
88 | """
89 | super().__init__(
90 | model=None,
91 | default_label=default_label,
92 | prompt_template=prompt_template,
93 | max_labels=max_labels,
94 | )
95 | self._set_keys(key, org)
96 | self._set_hyperparameters(base_model, n_epochs, custom_suffix)
97 |
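98 | # Usage sketch (illustrative only): unlike the zero-/few-shot estimators,
99 | # `fit` here launches an actual fine-tuning job on the provided data, which
100 | # takes time and incurs API costs. Assumes OpenAI keys in the global config:
101 | #
102 | # >>> clf = GPTClassifier(base_model="gpt-3.5-turbo-0613", n_epochs=1)
103 | # >>> clf.fit(["a positive review", "a negative review"], ["positive", "negative"])
104 | # >>> clf.predict(["another review"])  # served by the tuned model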
--------------------------------------------------------------------------------
/skllm/models/gpt/classification/zero_shot.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.classifier import (
2 | SingleLabelMixin as _SingleLabelMixin,
3 | MultiLabelMixin as _MultiLabelMixin,
4 | BaseZeroShotClassifier as _BaseZeroShotClassifier,
5 | BaseCoTClassifier as _BaseCoTClassifier,
6 | )
7 | from skllm.llm.gpt.mixin import GPTClassifierMixin as _GPTClassifierMixin
8 | from typing import Optional
9 |
10 |
11 | class ZeroShotGPTClassifier(
12 | _BaseZeroShotClassifier, _GPTClassifierMixin, _SingleLabelMixin
13 | ):
14 | def __init__(
15 | self,
16 | model: str = "gpt-3.5-turbo",
17 | default_label: str = "Random",
18 | prompt_template: Optional[str] = None,
19 | key: Optional[str] = None,
20 | org: Optional[str] = None,
21 | **kwargs,
22 | ):
23 | """
24 | Zero-shot text classifier using OpenAI/GPT API-compatible models.
25 |
26 | Parameters
27 | ----------
28 | model : str, optional
29 | model to use, by default "gpt-3.5-turbo"
30 | default_label : str, optional
31 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
32 | prompt_template : Optional[str], optional
33 | custom prompt template to use, by default None
34 | key : Optional[str], optional
35 | estimator-specific API key; if None, retrieved from the global config, by default None
36 | org : Optional[str], optional
37 | estimator-specific ORG key; if None, retrieved from the global config, by default None
38 | """
39 | super().__init__(
40 | model=model,
41 | default_label=default_label,
42 | prompt_template=prompt_template,
43 | **kwargs,
44 | )
45 | self._set_keys(key, org)
46 |
47 |
48 | class CoTGPTClassifier(_BaseCoTClassifier, _GPTClassifierMixin, _SingleLabelMixin):
49 | def __init__(
50 | self,
51 | model: str = "gpt-3.5-turbo",
52 | default_label: str = "Random",
53 | prompt_template: Optional[str] = None,
54 | key: Optional[str] = None,
55 | org: Optional[str] = None,
56 | **kwargs,
57 | ):
58 | """
59 | Chain-of-thought text classifier using OpenAI/GPT API-compatible models.
60 |
61 | Parameters
62 | ----------
63 | model : str, optional
64 | model to use, by default "gpt-3.5-turbo"
65 | default_label : str, optional
66 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
67 | prompt_template : Optional[str], optional
68 | custom prompt template to use, by default None
69 | key : Optional[str], optional
70 | estimator-specific API key; if None, retrieved from the global config, by default None
71 | org : Optional[str], optional
72 | estimator-specific ORG key; if None, retrieved from the global config, by default None
73 | """
74 | super().__init__(
75 | model=model,
76 | default_label=default_label,
77 | prompt_template=prompt_template,
78 | **kwargs,
79 | )
80 | self._set_keys(key, org)
81 |
82 |
83 | class MultiLabelZeroShotGPTClassifier(
84 | _BaseZeroShotClassifier, _GPTClassifierMixin, _MultiLabelMixin
85 | ):
86 | def __init__(
87 | self,
88 | model: str = "gpt-3.5-turbo",
89 | default_label: str = "Random",
90 | max_labels: Optional[int] = 5,
91 | prompt_template: Optional[str] = None,
92 | key: Optional[str] = None,
93 | org: Optional[str] = None,
94 | **kwargs,
95 | ):
96 | """
97 | Multi-label zero-shot text classifier using OpenAI/GPT API-compatible models.
98 |
99 | Parameters
100 | ----------
101 | model : str, optional
102 | model to use, by default "gpt-3.5-turbo"
103 | default_label : str, optional
104 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
105 | max_labels : Optional[int], optional
106 | maximum labels per sample, by default 5
107 | prompt_template : Optional[str], optional
108 | custom prompt template to use, by default None
109 | key : Optional[str], optional
110 | estimator-specific API key; if None, retrieved from the global config, by default None
111 | org : Optional[str], optional
112 | estimator-specific ORG key; if None, retrieved from the global config, by default None
113 | """
114 | super().__init__(
115 | model=model,
116 | default_label=default_label,
117 | max_labels=max_labels,
118 | prompt_template=prompt_template,
119 | **kwargs,
120 | )
121 | self._set_keys(key, org)
122 |
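123 | # Usage sketch (illustrative only): the CoT variant asks the model to justify
124 | # every candidate label before committing to one, which typically improves
125 | # robustness at the cost of extra tokens. Assumes OpenAI keys in the config:
126 | #
127 | # >>> clf = CoTGPTClassifier(model="gpt-3.5-turbo")
128 | # >>> clf.fit(None, ["books", "movies", "music"])
129 | # >>> clf.predict(["A thrilling page-turner from start to finish"])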
--------------------------------------------------------------------------------
/skllm/models/gpt/tagging/ner.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.tagger import ExplainableNER as _ExplainableNER
2 | from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
3 | from typing import Optional, Dict
4 |
5 |
6 | class GPTExplainableNER(_ExplainableNER, _GPTTextCompletionMixin):
7 | def __init__(
8 | self,
9 | entities: Dict[str, str],
10 | display_predictions: bool = False,
11 | sparse_output: bool = True,
12 | model: str = "gpt-4o",
13 | key: Optional[str] = None,
14 | org: Optional[str] = None,
15 | num_workers: int = 1,
16 | ) -> None:
17 | """
18 | Named entity recognition using OpenAI/GPT API-compatible models.
19 |
20 | Parameters
21 | ----------
22 | entities : dict
23 | dictionary of entities to recognize, with keys as entity names and values as descriptions
24 | display_predictions : bool, optional
25 | whether to display predictions, by default False
26 | sparse_output : bool, optional
27 | whether to generate a sparse representation of the predictions, by default True
28 | model : str, optional
29 | model to use, by default "gpt-4o"
30 | key : Optional[str], optional
31 | estimator-specific API key; if None, retrieved from the global config, by default None
32 | org : Optional[str], optional
33 | estimator-specific ORG key; if None, retrieved from the global config, by default None
34 | num_workers : int, optional
35 | number of workers (threads) to use, by default 1
36 | """
37 | self._set_keys(key, org)
38 | self.model = model
39 | self.entities = entities
40 | self.display_predictions = display_predictions
41 | self.sparse_output = sparse_output
42 | self.num_workers = num_workers
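43 |
44 | # Usage sketch (illustrative only): entities are supplied as name -> definition
45 | # pairs; the definitions are inserted into the prompt to guide extraction.
46 | # Assumes OpenAI keys in the global config:
47 | #
48 | # >>> entities = {"PERSON": "name of an individual", "ORG": "name of a company or institution"}
49 | # >>> ner = GPTExplainableNER(entities=entities, display_predictions=True)
50 | # >>> tagged = ner.fit_transform(["Tim Cook is the CEO of Apple."])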
--------------------------------------------------------------------------------
/skllm/models/gpt/text2text/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/skllm/models/gpt/text2text/__init__.py
--------------------------------------------------------------------------------
/skllm/models/gpt/text2text/summarization.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.text2text import BaseSummarizer as _BaseSummarizer
2 | from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
3 | from typing import Optional
4 |
5 |
6 | class GPTSummarizer(_BaseSummarizer, _GPTTextCompletionMixin):
7 | def __init__(
8 | self,
9 | model: str = "gpt-3.5-turbo",
10 | key: Optional[str] = None,
11 | org: Optional[str] = None,
12 | max_words: int = 15,
13 | focus: Optional[str] = None,
14 | ) -> None:
15 | """
16 | Text summarizer using OpenAI/GPT API-compatible models.
17 |
18 | Parameters
19 | ----------
20 | model : str, optional
21 | model to use, by default "gpt-3.5-turbo"
22 | key : Optional[str], optional
23 | estimator-specific API key; if None, retrieved from the global config, by default None
24 | org : Optional[str], optional
25 | estimator-specific ORG key; if None, retrieved from the global config, by default None
26 | max_words : int, optional
27 | soft limit of the summary length, by default 15
28 | focus : Optional[str], optional
29 | concept in the text to focus on, by default None
30 | """
31 | self._set_keys(key, org)
32 | self.model = model
33 | self.max_words = max_words
34 | self.focus = focus
35 |
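36 | # Usage sketch (illustrative only), assuming OpenAI keys in the global config:
37 | #
38 | # >>> summarizer = GPTSummarizer(model="gpt-3.5-turbo", max_words=10)
39 | # >>> summarizer.fit_transform(["A very long review text ..."])  # one short summary per input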
--------------------------------------------------------------------------------
/skllm/models/gpt/text2text/translation.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.text2text import BaseTranslator as _BaseTranslator
2 | from skllm.llm.gpt.mixin import GPTTextCompletionMixin as _GPTTextCompletionMixin
3 | from typing import Optional
4 |
5 |
6 | class GPTTranslator(_BaseTranslator, _GPTTextCompletionMixin):
7 | default_output = "Translation is unavailable."
8 |
9 | def __init__(
10 | self,
11 | model: str = "gpt-3.5-turbo",
12 | key: Optional[str] = None,
13 | org: Optional[str] = None,
14 | output_language: str = "English",
15 | ) -> None:
16 | """
17 | Text translator using OpenAI/GPT API-compatible models.
18 |
19 | Parameters
20 | ----------
21 | model : str, optional
22 | model to use, by default "gpt-3.5-turbo"
23 | key : Optional[str], optional
24 | estimator-specific API key; if None, retrieved from the global config, by default None
25 | org : Optional[str], optional
26 | estimator-specific ORG key; if None, retrieved from the global config, by default None
27 | output_language : str, optional
28 | target language, by default "English"
29 | """
30 | self._set_keys(key, org)
31 | self.model = model
32 | self.output_language = output_language
33 |
--------------------------------------------------------------------------------
/skllm/models/gpt/text2text/tunable.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.gpt.mixin import (
2 | GPTTunableMixin as _GPTTunableMixin,
3 | GPTTextCompletionMixin as _GPTTextCompletionMixin,
4 | )
5 | from skllm.models._base.text2text import (
6 | BaseTunableText2TextModel as _BaseTunableText2TextModel,
7 | )
8 | from typing import Optional
9 |
10 |
11 | class TunableGPTText2Text(
12 | _BaseTunableText2TextModel, _GPTTextCompletionMixin, _GPTTunableMixin
13 | ):
14 | def __init__(
15 | self,
16 | base_model: str = "gpt-3.5-turbo-0613",
17 | key: Optional[str] = None,
18 | org: Optional[str] = None,
19 | n_epochs: Optional[int] = None,
20 | custom_suffix: Optional[str] = "skllm",
21 | ):
22 | """
23 | Tunable GPT-based text-to-text model.
24 |
25 | Parameters
26 | ----------
27 | base_model : str, optional
28 | base model to use, by default "gpt-3.5-turbo-0613"
29 | key : Optional[str], optional
30 | estimator-specific API key; if None, retrieved from the global config, by default None
31 | org : Optional[str], optional
32 | estimator-specific ORG key; if None, retrieved from the global config, by default None
33 | n_epochs : Optional[int], optional
34 | number of epochs; if None, determined automatically; by default None
35 | custom_suffix : Optional[str], optional
36 | custom suffix of the tuned model, used for naming purposes only, by default "skllm"
37 | """
38 | self.model = None
39 | self._set_keys(key, org)
40 | self._set_hyperparameters(base_model, n_epochs, custom_suffix)
41 |
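42 | # Usage sketch (illustrative only): X holds input texts and y the desired
43 | # outputs; `fit` runs a fine-tuning job and `predict` queries the tuned model.
44 | # Assumes OpenAI keys in the global config:
45 | #
46 | # >>> model = TunableGPTText2Text(base_model="gpt-3.5-turbo-0613", n_epochs=1)
47 | # >>> model.fit(["What is the capital of France?"], ["The capital of France is Paris."])
48 | # >>> model.predict(["What is the capital of Spain?"])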
--------------------------------------------------------------------------------
/skllm/models/gpt/vectorization.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.vectorizer import BaseVectorizer as _BaseVectorizer
2 | from skllm.llm.gpt.mixin import GPTEmbeddingMixin as _GPTEmbeddingMixin
3 | from typing import Optional
4 |
5 |
6 | class GPTVectorizer(_BaseVectorizer, _GPTEmbeddingMixin):
7 | def __init__(
8 | self,
9 | model: str = "text-embedding-3-small",
10 | batch_size: int = 1,
11 | key: Optional[str] = None,
12 | org: Optional[str] = None,
13 | ):
14 | """
15 | Text vectorizer using OpenAI/GPT API-compatible models.
16 |
17 | Parameters
18 | ----------
19 | model : str, optional
20 | model to use, by default "text-embedding-3-small"
21 | batch_size : int, optional
22 | number of samples per request, by default 1
23 | key : Optional[str], optional
24 | estimator-specific API key; if None, retrieved from the global config, by default None
25 | org : Optional[str], optional
26 | estimator-specific ORG key; if None, retrieved from the global config, by default None
27 | """
28 | super().__init__(model=model, batch_size=batch_size)
29 | self._set_keys(key, org)
30 |
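31 | # Usage sketch (illustrative only): the vectorizer maps texts to a 2D array of
32 | # embeddings and can be placed in front of any conventional sklearn estimator,
33 | # e.g. inside a Pipeline. Assumes OpenAI keys in the global config:
34 | #
35 | # >>> vectorizer = GPTVectorizer(batch_size=2)
36 | # >>> embeddings = vectorizer.fit_transform(["first text", "second text"])
37 | # >>> embeddings.shape  # (2, embedding_dim); 1536 for "text-embedding-3-small"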
--------------------------------------------------------------------------------
/skllm/models/vertex/classification/tunable.py:
--------------------------------------------------------------------------------
1 | from skllm.models._base.classifier import (
2 | BaseTunableClassifier as _BaseTunableClassifier,
3 | SingleLabelMixin as _SingleLabelMixin,
4 | MultiLabelMixin as _MultiLabelMixin,
5 | )
6 | from skllm.llm.vertex.mixin import (
7 | VertexClassifierMixin as _VertexClassifierMixin,
8 | VertexTunableMixin as _VertexTunableMixin,
9 | )
10 | from typing import Optional
11 |
12 |
13 | class _TunableClassifier(
14 | _BaseTunableClassifier, _VertexClassifierMixin, _VertexTunableMixin
15 | ):
16 | pass
17 |
18 |
19 | class VertexClassifier(_TunableClassifier, _SingleLabelMixin):
20 | def __init__(
21 | self,
22 | base_model: str = "text-bison@002",
23 | n_update_steps: int = 1,
24 | default_label: str = "Random",
25 | ):
26 | """
27 | Tunable Vertex-based text classifier.
28 |
29 | Parameters
30 | ----------
31 | base_model : str, optional
32 | base model to use, by default "text-bison@002"
33 | n_update_steps : int, optional
34 | number of epochs, by default 1
35 | default_label : str, optional
36 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
37 | """
38 | self._set_hyperparameters(base_model=base_model, n_update_steps=n_update_steps)
39 | super().__init__(
40 | model=None,
41 | default_label=default_label,
42 | )
43 |
--------------------------------------------------------------------------------
/skllm/models/vertex/classification/zero_shot.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.vertex.mixin import VertexClassifierMixin as _VertexClassifierMixin
2 | from skllm.models._base.classifier import (
3 | BaseZeroShotClassifier as _BaseZeroShotClassifier,
4 | SingleLabelMixin as _SingleLabelMixin,
5 | MultiLabelMixin as _MultiLabelMixin,
6 | )
7 | from typing import Optional
8 |
9 |
10 | class ZeroShotVertexClassifier(
11 | _BaseZeroShotClassifier, _SingleLabelMixin, _VertexClassifierMixin
12 | ):
13 | def __init__(
14 | self,
15 | model: str = "text-bison@002",
16 | default_label: str = "Random",
17 | prompt_template: Optional[str] = None,
18 | **kwargs,
19 | ):
20 | """
21 | Zero-shot text classifier using Vertex AI models.
22 |
23 | Parameters
24 | ----------
25 | model : str, optional
26 | model to use, by default "text-bison@002"
27 | default_label : str, optional
28 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
29 | prompt_template : Optional[str], optional
30 | custom prompt template to use, by default None
31 | """
32 | super().__init__(
33 | model=model,
34 | default_label=default_label,
35 | prompt_template=prompt_template,
36 | **kwargs,
37 | )
38 |
39 |
40 | class MultiLabelZeroShotVertexClassifier(
41 | _BaseZeroShotClassifier, _MultiLabelMixin, _VertexClassifierMixin
42 | ):
43 | def __init__(
44 | self,
45 | model: str = "text-bison@002",
46 | default_label: str = "Random",
47 | prompt_template: Optional[str] = None,
48 | max_labels: Optional[int] = 5,
49 | **kwargs,
50 | ):
51 | """
52 | Multi-label zero-shot text classifier using Vertex AI models.
53 |
54 | Parameters
55 | ----------
56 | model : str, optional
57 | model to use, by default "text-bison@002"
58 | default_label : str, optional
59 | default label for failed prediction; if "Random" -> selects randomly based on class frequencies, by default "Random"
60 | prompt_template : Optional[str], optional
61 | custom prompt template to use, by default None
62 | max_labels : Optional[int], optional
63 | maximum labels per sample, by default 5
64 | """
65 | super().__init__(
66 | model=model,
67 | default_label=default_label,
68 | prompt_template=prompt_template,
69 | max_labels=max_labels,
70 | **kwargs,
71 | )
72 |
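73 | # Usage sketch (illustrative only): Vertex estimators authenticate through the
74 | # Google Cloud environment (project and credentials configured out of band)
75 | # rather than an estimator-level API key, so no `key` argument is passed:
76 | #
77 | # >>> clf = ZeroShotVertexClassifier(model="text-bison@002")
78 | # >>> clf.fit(None, ["sports", "politics", "technology"])
79 | # >>> clf.predict(["The championship game went into overtime"])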
--------------------------------------------------------------------------------
/skllm/models/vertex/text2text/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BeastByteAI/scikit-llm/5491ec8d1ba5528b560cd115f7a0c93369fb0628/skllm/models/vertex/text2text/__init__.py
--------------------------------------------------------------------------------
/skllm/models/vertex/text2text/tunable.py:
--------------------------------------------------------------------------------
1 | from skllm.llm.vertex.mixin import (
2 | VertexTunableMixin as _VertexTunableMixin,
3 | VertexTextCompletionMixin as _VertexTextCompletionMixin,
4 | )
5 | from skllm.models._base.text2text import (
6 | BaseTunableText2TextModel as _BaseTunableText2TextModel,
7 | )
8 |
9 |
10 | class TunableVertexText2Text(
11 | _BaseTunableText2TextModel, _VertexTextCompletionMixin, _VertexTunableMixin
12 | ):
13 | def __init__(
14 | self,
15 | base_model: str = "text-bison@002",
16 | n_update_steps: int = 1,
17 | ):
18 | """
19 | Tunable Vertex-based text-to-text model.
20 |
21 | Parameters
22 | ----------
23 | base_model : str, optional
24 | base model to use, by default "text-bison@002"
25 | n_update_steps : int, optional
26 | number of epochs, by default 1
27 | """
28 | self.model = None
29 | self._set_hyperparameters(base_model=base_model, n_update_steps=n_update_steps)
30 |
--------------------------------------------------------------------------------
/skllm/prompts/builders.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional
2 |
3 | from skllm.prompts.templates import (
4 | FEW_SHOT_CLF_PROMPT_TEMPLATE,
5 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
6 | FOCUSED_SUMMARY_PROMPT_TEMPLATE,
7 | SUMMARY_PROMPT_TEMPLATE,
8 | TRANSLATION_PROMPT_TEMPLATE,
9 | ZERO_SHOT_CLF_PROMPT_TEMPLATE,
10 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
11 | EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
12 | )
13 |
14 | # TODO add validators
15 |
16 |
17 | def build_zero_shot_prompt_slc(
18 | x: str, labels: str, template: str = ZERO_SHOT_CLF_PROMPT_TEMPLATE
19 | ) -> str:
20 | """Builds a prompt for zero-shot single-label classification.
21 |
22 | Parameters
23 | ----------
24 | x : str
25 | sample to classify
26 | labels : str
27 | candidate labels in a list-like representation
28 | template : str
29 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_CLF_PROMPT_TEMPLATE
30 |
31 | Returns
32 | -------
33 | str
34 | prepared prompt
35 | """
36 | return template.format(x=x, labels=labels)
37 |
38 |
39 | def build_few_shot_prompt_slc(
40 | x: str,
41 | labels: str,
42 | training_data: str,
43 | template: str = FEW_SHOT_CLF_PROMPT_TEMPLATE,
44 | ) -> str:
45 | """Builds a prompt for zero-shot single-label classification.
46 |
47 | Parameters
48 | ----------
49 | x : str
50 | sample to classify
51 | labels : str
52 | candidate labels in a list-like representation
53 | training_data : str
54 | training data to be used for few-shot learning
55 | template : str
56 | prompt template to use, must contain placeholders for all variables, by default FEW_SHOT_CLF_PROMPT_TEMPLATE
57 |
58 | Returns
59 | -------
60 | str
61 | prepared prompt
62 | """
63 | return template.format(x=x, labels=labels, training_data=training_data)
64 |
65 |
66 | def build_few_shot_prompt_mlc(
67 | x: str,
68 | labels: str,
69 | training_data: str,
70 | max_cats: Union[int, str],
71 | template: str = FEW_SHOT_MLCLF_PROMPT_TEMPLATE,
72 | ) -> str:
73 | """Builds a prompt for few-shot single-label classification.
74 |
75 | Parameters
76 | ----------
77 | x : str
78 | sample to classify
79 | labels : str
80 | candidate labels in a list-like representation
81 | training_data : str
82 | training data to be used for few-shot learning
83 | max_cats : Union[int,str]
84 | maximum number of categories to assign
85 | template : str
86 | prompt template to use, must contain placeholders for all variables, by default FEW_SHOT_MLCLF_PROMPT_TEMPLATE
87 |
88 | Returns
89 | -------
90 | str
91 | prepared prompt
92 | """
93 | return template.format(
94 | x=x, labels=labels, training_data=training_data, max_cats=max_cats
95 | )
96 |
97 |
98 | def build_zero_shot_prompt_mlc(
99 | x: str,
100 | labels: str,
101 | max_cats: Union[int, str],
102 | template: str = ZERO_SHOT_MLCLF_PROMPT_TEMPLATE,
103 | ) -> str:
104 | """Builds a prompt for zero-shot multi-label classification.
105 |
106 | Parameters
107 | ----------
108 | x : str
109 | sample to classify
110 | labels : str
111 | candidate labels in a list-like representation
112 | max_cats : Union[int,str]
113 | maximum number of categories to assign
114 | template : str
115 | prompt template to use, must contain placeholders for all variables, by default ZERO_SHOT_MLCLF_PROMPT_TEMPLATE
116 |
117 | Returns
118 | -------
119 | str
120 | prepared prompt
121 | """
122 | return template.format(x=x, labels=labels, max_cats=max_cats)
123 |
124 |
125 | def build_summary_prompt(
126 | x: str, max_words: Union[int, str], template: str = SUMMARY_PROMPT_TEMPLATE
127 | ) -> str:
128 | """Builds a prompt for text summarization.
129 |
130 | Parameters
131 | ----------
132 | x : str
133 | sample to summarize
134 | max_words : Union[int,str]
135 | maximum number of words to use in the summary
136 | template : str
137 | prompt template to use, must contain placeholders for all variables, by default SUMMARY_PROMPT_TEMPLATE
138 |
139 | Returns
140 | -------
141 | str
142 | prepared prompt
143 | """
144 | return template.format(x=x, max_words=max_words)
145 |
146 |
147 | def build_focused_summary_prompt(
148 | x: str,
149 | max_words: Union[int, str],
150 | focus: Union[int, str],
151 | template: str = FOCUSED_SUMMARY_PROMPT_TEMPLATE,
152 | ) -> str:
153 | """Builds a prompt for focused text summarization.
154 |
155 | Parameters
156 | ----------
157 | x : str
158 | sample to summarize
159 | max_words : Union[int,str]
160 | maximum number of words to use in the summary
161 | focus : Union[int,str]
162 | the topic(s) to focus on
163 | template : str
164 | prompt template to use, must contain placeholders for all variables, by default FOCUSED_SUMMARY_PROMPT_TEMPLATE
165 |
166 | Returns
167 | -------
168 | str
169 | prepared prompt
170 | """
171 | return template.format(x=x, max_words=max_words, focus=focus)
172 |
173 |
174 | def build_translation_prompt(
175 | x: str, output_language: str, template: str = TRANSLATION_PROMPT_TEMPLATE
176 | ) -> str:
177 | """Builds a prompt for text translation.
178 |
179 | Parameters
180 | ----------
181 | x : str
182 | sample to translate
183 | output_language : str
184 | language to translate to
185 | template : str
186 | prompt template to use, must contain placeholders for all variables, by default TRANSLATION_PROMPT_TEMPLATE
187 |
188 | Returns
189 | -------
190 | str
191 | prepared prompt
192 | """
193 | return template.format(x=x, output_language=output_language)
194 |
195 |
196 | def build_ner_prompt(
197 | entities: list,
198 | x: str,
199 | template: str = EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE,
200 | ) -> str:
201 | """Builds a prompt for named entity recognition.
202 |
203 | Parameters
204 | ----------
205 | entities : list
206 | list of entities to recognize
207 | x : str
208 | sample to recognize entities in
209 | template : str, optional
210 | prompt template to use, must contain placeholders for all variables, by default EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE
211 |
212 | Returns
213 | -------
214 | str
215 | prepared prompt
216 | """
217 | return template.format(entities=entities, x=x)
218 |
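219 | # Usage sketch (illustrative only): the builders are thin wrappers around
220 | # `str.format`, so a custom template only needs the same placeholders:
221 | #
222 | # >>> custom = "Classify ```{x}``` into one of {labels}. Reply with the label only."
223 | # >>> build_zero_shot_prompt_slc(x="I loved it", labels="['pos', 'neg']", template=custom)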
--------------------------------------------------------------------------------
/skllm/prompts/templates.py:
--------------------------------------------------------------------------------
1 | ZERO_SHOT_CLF_PROMPT_TEMPLATE = """
2 | You will be provided with the following information:
3 | 1. An arbitrary text sample. The sample is delimited with triple backticks.
4 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in single quotes and comma-separated.
5 |
6 | Perform the following tasks:
7 | 1. Identify to which category the provided text belongs with the highest probability.
8 | 2. Assign the provided text to that category.
9 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.
10 |
11 | List of categories: {labels}
12 |
13 | Text sample: ```{x}```
14 |
15 | Your JSON response:
16 | """
17 |
18 | COT_CLF_PROMPT_TEMPLATE = """
19 | You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
20 |
21 | 1. The text intended for classification is presented between triple backticks.
22 | 2. The possible categories are enumerated in square brackets, with each category enclosed in single quotes and separated by commas.
23 |
24 | Tasks:
25 | 1. Examine the text and provide detailed justifications for why the text does or does not belong to each category listed.
26 | 2. Determine and select the most appropriate category for the text based on your comprehensive justifications.
27 | 3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category.
28 |
29 | Category List: {labels}
30 |
31 | Text Sample: ```{x}```
32 |
33 | Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
34 | """
35 |
36 | ZERO_SHOT_CLF_SHORT_PROMPT_TEMPLATE = """
37 | Classify the following text into one of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
38 | Text: ```{x}```
39 | """
40 |
41 | ZERO_SHOT_MLCLF_SHORT_PROMPT_TEMPLATE = """
42 | Classify the following text into at least 1 but up to {max_cats} of the following classes: {labels}. Provide your response in a JSON format containing a single key `label`.
43 | Text: ```{x}```
44 | """
45 |
46 | FEW_SHOT_CLF_PROMPT_TEMPLATE = """
47 | You will be provided with the following information:
48 | 1. An arbitrary text sample. The sample is delimited with triple backticks.
49 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in single quotes and comma-separated.
50 | 3. Examples of text samples and their assigned categories. The examples are delimited with triple backticks. The assigned categories are enclosed in a list-like structure. These examples are to be used as training data.
51 |
52 | Perform the following tasks:
53 | 1. Identify to which category the provided text belongs with the highest probability.
54 | 2. Assign the provided text to that category.
55 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the assigned category. Do not provide any additional information except the JSON.
56 |
57 | List of categories: {labels}
58 |
59 | Training data:
60 | {training_data}
61 |
62 | Text sample: ```{x}```
63 |
64 | Your JSON response:
65 | """
66 |
67 | FEW_SHOT_MLCLF_PROMPT_TEMPLATE = """
68 | You will be provided with the following information:
69 | 1. An arbitrary text sample. The sample is delimited with triple backticks.
70 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in single quotes and comma-separated.
71 | 3. Examples of text samples and their assigned categories. The examples are delimited with triple backticks. The assigned categories are enclosed in a list-like structure. These examples are to be used as training data.
72 |
73 | Perform the following tasks:
74 | 1. Identify to which categories the provided text belongs with the highest probability.
75 | 2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
76 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON.
77 |
78 | List of categories: {labels}
79 |
80 | Training data:
81 | {training_data}
82 |
83 | Text sample: ```{x}```
84 |
85 | Your JSON response:
86 | """
87 |
88 | ZERO_SHOT_MLCLF_PROMPT_TEMPLATE = """
89 | You will be provided with the following information:
90 | 1. An arbitrary text sample. The sample is delimited with triple backticks.
91 | 2. List of categories the text sample can be assigned to. The list is delimited with square brackets. The categories in the list are enclosed in single quotes and comma-separated. The text sample belongs to at least one and at most {max_cats} categories.
92 |
93 | Perform the following tasks:
94 | 1. Identify to which categories the provided text belongs with the highest probability.
95 | 2. Assign the text sample to at least 1 but up to {max_cats} categories based on the probabilities.
96 | 3. Provide your response in a JSON format containing a single key `label` and a value corresponding to the array of assigned categories. Do not provide any additional information except the JSON.
97 |
98 | List of categories: {labels}
99 |
100 | Text sample: ```{x}```
101 |
102 | Your JSON response:
103 | """
104 |
105 | COT_MLCLF_PROMPT_TEMPLATE = """
106 | You are tasked with classifying a given text sample based on a list of potential categories. Please adhere to the following guidelines:
107 |
108 | 1. The text intended for classification is presented between triple backticks.
109 | 2. The possible categories are enumerated in square brackets, with each category enclosed in quotes and separated by commas.
110 |
111 | Tasks:
112 | 1. Examine the text and provide detailed justifications for why the text does or does not belong to each category listed.
113 | 2. Determine and select up to {max_cats} of the most appropriate categories for the text based on your comprehensive justifications.
114 | 3. Format your decision into a JSON object containing two keys: `explanation` and `label`. The `explanation` should concisely capture the rationale for each category before concluding with the chosen category. The `label` should contain an array of the chosen categories.
115 |
116 | Category List: {labels}
117 |
118 | Text Sample: ```{x}```
119 |
120 | Provide your JSON response below, ensuring that justifications for all categories are clearly detailed:
121 | """
122 |
123 | SUMMARY_PROMPT_TEMPLATE = """
124 | Your task is to generate a summary of the text sample.
125 | Summarize the text sample provided below, delimited by triple backticks, in at most {max_words} words.
126 |
127 | Text sample: ```{x}```
128 | Summarized text:
129 | """
130 |
131 | FOCUSED_SUMMARY_PROMPT_TEMPLATE = """
132 | As an input you will receive:
133 | 1. A focus parameter delimited with square brackets.
134 | 2. A single text sample delimited with triple backticks.
135 |
136 | Perform the following actions:
137 | 1. Determine whether there is something in the text that matches the focus. Do not output anything at this step.
138 | 2. Summarize the text in at most {max_words} words.
139 | 3. If possible, focus the summary on the concept provided in the focus parameter. Otherwise, provide a general summary. Do not state that a general summary is provided.
140 | 4. Do not output anything except the summary. Do not output any text that was not present in the original text.
141 | 5. If no focused summary is possible, or the mentioned concept is not present in the text, output "Mentioned concept is not present in the text." followed by the general summary. Do not state that a general summary is provided.
142 |
143 | Focus: [{focus}]
144 |
145 | Text sample: ```{x}```
146 |
147 | Summarized text:
148 | """
149 |
150 | TRANSLATION_PROMPT_TEMPLATE = """
151 | If the original text, delimited by triple backticks, is already in {output_language}, output the original text.
152 | Otherwise, translate the original text, delimited by triple backticks, into {output_language}, and output only the translated text. Do not output any additional information except the translated text.
153 |
154 | Original text: ```{x}```
155 | Output:
156 | """
157 |
158 | NER_SYSTEM_MESSAGE_TEMPLATE = """You are an expert in Natural Language Processing. Your task is to identify common Named Entities (NER) in a text provided by the user.
159 | Mark the entities with tags according to the following guidelines:
160 | - Use XML format to tag entities;
161 | - All entities must be enclosed in <entity>...</entity> tags; All other text must be enclosed in <TEXT>...</TEXT> tags; No content should be outside of these tags;
162 | - The tagging operation must be invertible, i.e. the original text must be recoverable from the tagged text; This is crucial and easy to overlook, double-check this requirement;
163 | - Adjacent entities should be separated into different tags;
164 | - The list of entities is strictly restricted to the following: {entities}.
165 | """
166 |
167 | NER_SYSTEM_MESSAGE_SPARSE = """You are an expert in Natural Language Processing."""
168 |
169 | EXPLAINABLE_NER_DENSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
170 | {entities}
171 |
172 | For each entity, provide a brief explanation for your choice inside the entity's reasoning tag. Use the following XML tag format for each entity:
173 |
174 | <entity><reasoning>Your reasoning here</reasoning><tag>ENTITY_NAME_UPPERCASE</tag><value>Entity text</value></entity>
175 |
176 | The remaining text must be enclosed in a TEXT tag.
177 |
178 | Focus on the context and meaning of each entity rather than just the exact words. The tags should encompass the entire entity based on its definition and usage in the sentence. It is crucial to base your decision on the description of the entity, not just its name.
179 |
180 | Format example:
181 |
182 | Input:
183 | ```This text contains some entity and another entity.```
184 |
185 | Output:
186 | ```xml
187 | <TEXT>This text contains some <entity><reasoning>justification</reasoning><tag>ENTITY1</tag><value>some entity</value></entity> and another <entity><reasoning>another justification</reasoning><tag>ENTITY2</tag><value>entity</value></entity>.</TEXT>
188 | ```
189 |
190 | Input:
191 | ```
192 | {x}
193 | ```
194 |
195 | Output (original text with tags):
196 | """
197 |
198 |
199 | EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE = """You are provided with a text. Your task is to identify and tag all named entities within the text using the following entity types only:
200 | {entities}
201 |
202 | You must provide the following information for each entity:
203 | - The reasoning for why you tagged the entity as such; based on the reasoning, a non-expert should be able to evaluate your decision;
204 | - The tag of the entity (uppercase);
205 | - The value of the entity (as it appears in the text).
206 |
207 | Your response should be JSON-formatted, using the following schema:
208 |
209 | {{
210 | "$schema": "http://json-schema.org/draft-04/schema#",
211 | "type": "array",
212 | "items": [
213 | {{
214 | "type": "object",
215 | "properties": {{
216 | "reasoning": {{
217 | "type": "string"
218 | }},
219 | "tag": {{
220 | "type": "string"
221 | }},
222 | "value": {{
223 | "type": "string"
224 | }}
225 | }},
226 | "required": [
227 | "reasoning",
228 | "tag",
229 | "value"
230 | ]
231 | }}
232 | ]
233 | }}
234 |
235 |
236 | Input:
237 | ```
238 | {x}
239 | ```
240 |
241 | Output JSON:
242 | """
243 |
--------------------------------------------------------------------------------
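
One detail worth calling out: the JSON schema embedded in EXPLAINABLE_NER_SPARSE_PROMPT_TEMPLATE doubles its braces ({{ ... }}) so that str.format substitutes only {entities} and {x} while leaving the schema's braces intact. A self-contained sketch of that mechanism (toy template, not one from the library):

template = 'Respond in JSON: {{"label": ...}}\nClasses: {labels}\nText: ```{x}```'
print(template.format(labels="['spam', 'ham']", x="Win a free prize!"))
# Respond in JSON: {"label": ...}
# Classes: ['spam', 'ham']
# Text: ```Win a free prize!```
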
/skllm/text2text.py:
--------------------------------------------------------------------------------
1 | ## GPT
2 | from skllm.models.gpt.text2text.summarization import GPTSummarizer
3 | from skllm.models.gpt.text2text.translation import GPTTranslator
4 | from skllm.models.gpt.text2text.tunable import TunableGPTText2Text
5 |
6 | ## Vertex
7 |
8 | from skllm.models.vertex.text2text.tunable import TunableVertexText2Text
9 |
--------------------------------------------------------------------------------
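
This module simply re-exports the text-to-text estimators under a flat namespace, giving them a scikit-learn-style interface. A rough usage sketch; the max_words constructor argument is an assumption inferred from build_summary_prompt above, so check the class definition before relying on it:

from skllm.text2text import GPTSummarizer

X = ["A long product review describing the device, battery life, and build quality in detail ..."]
summarizer = GPTSummarizer(max_words=15)  # max_words assumed to mirror build_summary_prompt
summaries = summarizer.fit_transform(X)   # standard scikit-learn transformer call
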
/skllm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Any
3 | import numpy as np
4 | import pandas as pd
5 | from functools import wraps
6 | from time import sleep
7 | import re
8 |
9 | def to_numpy(X: Any) -> np.ndarray:
10 | """Converts a pandas Series or list to a numpy array.
11 |
12 | Parameters
13 | ----------
14 | X : Any
15 | The data to convert to a numpy array.
16 |
17 | Returns
18 | -------
19 | X : np.ndarray
20 | """
21 | if isinstance(X, pd.Series):
22 | X = X.to_numpy().astype(object)
23 | elif isinstance(X, list):
24 | X = np.asarray(X, dtype=object)
25 | if isinstance(X, np.ndarray) and len(X.shape) > 1:
26 | # do not squeeze the first dim
27 | X = np.squeeze(X, axis=tuple([i for i in range(1, len(X.shape))]))
28 | return X
29 |
30 | # TODO: replace with re version below
31 | def find_json_in_string(string: str) -> str:
32 | """Finds the JSON object in a string.
33 |
34 | Parameters
35 | ----------
36 | string : str
37 | The string to search for a JSON object.
38 |
39 | Returns
40 | -------
41 | json_string : str
42 | """
43 | start = string.find("{")
44 | end = string.rfind("}")
45 | if start != -1 and end != -1:
46 | json_string = string[start : end + 1]
47 | else:
48 | json_string = "{}"
49 | return json_string
50 |
51 |
52 |
53 | def re_naive_json_extractor(json_string: str, expected_output: str = "object") -> str:
54 | """Finds the first JSON-like object or array in a string using regex.
55 |
56 | Parameters
57 | ----------
58 | json_string : str
59 | The string to search for a JSON object or array.
60 |
61 | Returns
62 | -------
63 | json_string : str
64 | A JSON string if found, otherwise an empty JSON object or array (per ``expected_output``).
65 | """
66 | json_pattern = r"(\{.*\}|\[.*\])"  # greedy: spans from the first "{" to the last "}" (or "[" to "]")
67 | match = re.search(json_pattern, json_string, re.DOTALL)
68 | if match:
69 | return match.group(0)
70 | else:
71 | return r"{}" if expected_output == "object" else "[]"
72 |
73 |
74 |
75 |
76 | def extract_json_key(json_: str, key: str):
77 | """Extracts JSON key from a string.
78 |
79 | json_ : str
80 | The JSON string to extract the key from.
81 | key : str
82 | The key to extract.
83 | """
84 | original_json = json_
85 | for i in range(2):
86 | try:
87 | json_ = original_json.replace("\n", "")
88 | if i == 1:
89 | json_ = json_.replace("'", '"')
90 | json_ = find_json_in_string(json_)
91 | as_json = json.loads(json_)
92 | if key not in as_json.keys():
93 | raise KeyError("The required key was not found")
94 | return as_json[key]
95 | except Exception:
96 | if i == 0:
97 | continue
98 | return None
99 |
100 | # Decorator factory: re-invoke the wrapped function up to max_retries times with exponential backoff before raising RuntimeError.
101 | def retry(max_retries=3):
102 | def decorator(func):
103 | @wraps(func)
104 | def wrapper(*args, **kwargs):
105 | for attempt in range(max_retries):
106 | try:
107 | return func(*args, **kwargs)
108 | except Exception as e:
109 | error_msg = str(e)
110 | error_type = type(e).__name__
111 | if attempt < max_retries - 1: sleep(2**attempt)  # back off, but skip the pointless sleep after the final attempt
112 | err_msg = (
113 | f"Could not complete the operation after {max_retries} retries:"
114 | f" `{error_type} :: {error_msg}`"
115 | )
116 | print(err_msg)
117 | raise RuntimeError(err_msg)
118 |
119 | return wrapper
120 |
121 | return decorator
122 |
--------------------------------------------------------------------------------
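
A behavior sketch for the helpers above (inputs and outputs are illustrative):

from skllm.utils import extract_json_key, re_naive_json_extractor, retry

raw = 'Sure! Here is the result:\n{"label": "sports"}'
extract_json_key(raw, "label")    # -> "sports"
extract_json_key(raw, "missing")  # -> None once both repair passes fail
re_naive_json_extractor("noise [1, 2] noise", expected_output="array")  # -> "[1, 2]"

@retry(max_retries=2)
def flaky():
    raise ValueError("boom")

# flaky() sleeps 1s after the first failure, then raises RuntimeError.
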
/skllm/utils/rendering.py:
--------------------------------------------------------------------------------
1 | import random
2 | import re
3 | import html
4 | from typing import Dict, List
5 |
6 | color_palettes = {
7 | "light": [
8 | "lightblue",
9 | "lightgreen",
10 | "lightcoral",
11 | "lightsalmon",
12 | "lightyellow",
13 | "lightpink",
14 | "lightgray",
15 | "lightcyan",
16 | ],
17 | "dark": [
18 | "darkblue",
19 | "darkgreen",
20 | "darkred",
21 | "darkorange",
22 | "darkgoldenrod",
23 | "darkmagenta",
24 | "darkgray",
25 | "darkcyan",
26 | ],
27 | }
28 |
29 |
30 | def get_random_color():
31 | return f"#{random.randint(0, 0xFFFFFF):06x}"
32 |
33 |
34 | # def validate_text(input_text, output_text):
35 | # # Verify the original text was not changed (other than addition of tags)
36 | # stripped_output_text = re.sub(r'<<.*?>>', '', output_text)
37 | # stripped_output_text = re.sub(r'<>', '', stripped_output_text)
38 | # if not all(word in stripped_output_text.split() for word in input_text.split()):
39 | # raise ValueError("Original text was altered.")
40 | # return True
41 |
42 |
43 | # TODO: In the future this should probably be replaced with a proper HTML template
44 | def render_ner(output_texts, allowed_entities):
45 | entity_colors = {}
46 | all_entities = [k.upper() for k in allowed_entities.keys()]
47 |
48 | for i, entity in enumerate(all_entities):
49 | if i < len(color_palettes["light"]):
50 | entity_colors[entity] = {
51 | "light": color_palettes["light"][i],
52 | "dark": color_palettes["dark"][i],
53 | }
54 | else:
55 | random_color = get_random_color()
56 | entity_colors[entity] = {"light": random_color, "dark": random_color}
57 |
58 | def replace_match(match):
59 | reasoning, entity, text = match.groups()
60 | entity = entity.upper()
61 | return (
62 | f'<span style="background-color: {entity_colors[entity]["light"]}" '  # assumed markup
63 | f'title="{html.escape(reasoning)}">{text}</span>'
64 | )
65 |
66 | legend_html = "