├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── user_story.md ├── dependabot.yml └── workflows │ ├── build-image.yml │ ├── build-library.yml │ ├── lint-code.yml │ └── publish-library.yml ├── .gitignore ├── .isort.cfg ├── .pre-commit-config.yaml ├── .prettierignore ├── .pylintrc ├── .whitesource ├── CODEOWNERS ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── SECURITY.md ├── benchmarks ├── README.md └── logs │ └── llama2-7b │ ├── 20230905_183655.output │ ├── 20230905_184809.output │ ├── 20230905_191650.output │ ├── 20230905_194133.output │ └── 20230906_135211.output ├── caikit_nlp ├── __init__.py ├── config │ ├── __init__.py │ └── config.yml ├── data_model │ ├── __init__.py │ └── generation.py ├── model_management │ ├── __init__.py │ └── tgis_auto_finder.py ├── modules │ ├── __init__.py │ ├── text_classification │ │ ├── __init__.py │ │ └── sequence_classification.py │ ├── text_embedding │ │ ├── __init__.py │ │ ├── crossencoder.py │ │ ├── embedding.py │ │ └── utils.py │ ├── text_generation │ │ ├── __init__.py │ │ ├── peft_config.py │ │ ├── peft_prompt_tuning.py │ │ ├── peft_tgis_remote.py │ │ ├── text_generation_local.py │ │ └── text_generation_tgis.py │ ├── token_classification │ │ ├── __init__.py │ │ └── filtered_span_classification.py │ └── tokenization │ │ ├── __init__.py │ │ └── regex_sentence_splitter.py ├── resources │ ├── __init__.py │ └── pretrained_model │ │ ├── __init__.py │ │ ├── base.py │ │ ├── hf_auto_causal_lm.py │ │ ├── hf_auto_seq2seq_lm.py │ │ └── hf_auto_seq_classifier.py ├── toolkit │ ├── __init__.py │ ├── data_stream_wrapper.py │ ├── data_type_utils.py │ ├── task_specific_utils.py │ ├── text_generation │ │ ├── __init__.py │ │ ├── model_run_utils.py │ │ └── tgis_utils.py │ ├── torch_run.py │ ├── trainer_utils.py │ └── verbalizer_utils.py └── version.py ├── code-of-conduct.md ├── examples ├── Caikit_Getting_Started.ipynb ├── compare_local_vs_tgis_models.py ├── evaluate_model.py ├── kill-text-generation-launcher.sh ├── load_and_run_distributed_peft.py ├── run_fine_tuning.py ├── run_peft_tuning.py ├── text-generation-launcher └── utils.py ├── prompt_tuning_parameter_selection.md ├── pyproject.toml ├── runtime_config.yaml ├── runtime_template └── run_with_gateway.sh ├── scripts ├── dump_apis.sh ├── fmt.sh └── run_local.sh ├── setup_requirements.txt ├── tests ├── __init__.py ├── conftest.py ├── data_model │ └── test_generation.py ├── fixtures │ ├── __init__.py │ ├── data_model │ │ └── sample_objects.py │ └── tiny_models │ │ ├── BertForSequenceClassification │ │ ├── config.json │ │ ├── pytorch_model.bin │ │ ├── special_tokens_map.json │ │ ├── tf_model.h5 │ │ ├── tokenizer.json │ │ ├── tokenizer_config.json │ │ └── vocab.txt │ │ ├── BloomForCausalLM │ │ ├── config.json │ │ ├── generation_config.json │ │ ├── pytorch_model.bin │ │ ├── special_tokens_map.json │ │ ├── tokenizer.json │ │ └── tokenizer_config.json │ │ ├── README.md │ │ └── T5ForConditionalGeneration │ │ ├── config.json │ │ ├── generation_config.json │ │ ├── pytorch_model.bin │ │ ├── special_tokens_map.json │ │ ├── tokenizer.json │ │ └── tokenizer_config.json ├── model_management │ ├── __init__.py │ └── test_tgis_auto_finder.py ├── modules │ ├── __init__.py │ ├── text_classification │ │ └── test_sequence_classification.py │ ├── text_embedding │ │ ├── test_crossencoder.py │ │ └── test_embedding.py │ ├── text_generation │ │ ├── __init__.py │ │ ├── test_peft_config.py │ │ ├── test_peft_prompt_tuning.py │ │ ├── test_peft_tgis_remote.py │ │ ├── 
test_text_generation_local.py │ │ └── test_text_generation_tgis.py │ ├── token_classification │ │ └── test_filtered_span_classification.py │ └── tokenization │ │ └── test_regex_sentence_splitter.py ├── resources │ ├── __init__.py │ └── test_pretrained_model.py └── toolkit │ ├── __init__.py │ ├── test_data_stream_wrapper.py │ ├── test_data_type_utils.py │ ├── test_task_specific_utils.py │ ├── test_verbalizers.py │ └── text_generation │ ├── test_model_run_utils.py │ └── test_tgis_utils.py └── tox.ini /.dockerignore: -------------------------------------------------------------------------------- 1 | * 2 | !README.md 3 | !LICENSE 4 | !caikit_nlp 5 | !pyproject.toml 6 | !tox.ini 7 | !.git 8 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Describe the bug 10 | 11 | A clear and concise description of what the bug is. 12 | 13 | ## Platform 14 | 15 | Please provide details about the environment you are using, including the following: 16 | 17 | - Interpreter version: 18 | - Library version: 19 | 20 | ## Sample Code 21 | 22 | Please include a minimal sample of the code that will (if possible) reproduce the bug in isolation 23 | 24 | ## Expected behavior 25 | 26 | A clear and concise description of what you expected to happen. 27 | 28 | ## Observed behavior 29 | 30 | What you see happening (error messages, stack traces, etc...) 31 | 32 | ## Additional context 33 | 34 | Add any other context about the problem here. 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Is your feature request related to a problem? Please describe. 10 | 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | ## Describe the solution you'd like 14 | 15 | A clear and concise description of what you want to happen. 16 | 17 | ## Describe alternatives you've considered 18 | 19 | A clear and concise description of any alternative solutions or features you've considered. 20 | 21 | ## Additional context 22 | 23 | Add any other context about the feature request here. 
24 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/user_story.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: User story 3 | about: A user-oriented story describing a piece of work to do 4 | title: "" 5 | labels: "" 6 | assignees: "" 7 | --- 8 | 9 | ## Description 10 | 11 | As a , I want to , so that I can 12 | 13 | ## Discussion 14 | 15 | Provide detailed discussion here 16 | 17 | ## Acceptance Criteria 18 | 19 | 20 | 21 | - [ ] Unit tests cover new/changed code 22 | - [ ] Examples build against new/changed code 23 | - [ ] READMEs are updated 24 | - [ ] Type of [semantic version](https://semver.org/) change is identified 25 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/build-image.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | branches: ["main", "release-*"] 4 | paths: 5 | - "caikit_nlp" 6 | - "README.md" 7 | - "pyproject.toml" 8 | - "Dockerfile" 9 | 10 | pull_request: 11 | 12 | name: Build Image 13 | 14 | jobs: 15 | build-image: 16 | name: Build Image 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v3 20 | - name: Reclaim space 21 | run: | 22 | sudo rm -rf /opt/hostedtoolcache 23 | - name: Set up Docker Buildx 24 | uses: docker/setup-buildx-action@v3 25 | - name: Build image 26 | uses: docker/build-push-action@v5 27 | with: 28 | context: . 29 | tags: "caikit-nlp:latest" 30 | load: true 31 | cache-from: type=gha 32 | cache-to: type=gha,mode=max 33 | -------------------------------------------------------------------------------- /.github/workflows/build-library.yml: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: Build Caikit NLP Library 16 | 17 | on: 18 | push: 19 | branches: ["main", "release-*"] 20 | pull_request: 21 | branches: ["main", "release-*"] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | strategy: 27 | matrix: 28 | python-version: 29 | # NOTE: TGIS does not support 3.8 or 3.11 30 | - setup: "3.9" 31 | tox: "py39" 32 | - setup: "3.10" 33 | tox: "py310" 34 | 35 | steps: 36 | - uses: actions/checkout@v3 37 | - name: Set up Python ${{ matrix.python-version.setup }} 38 | uses: actions/setup-python@v4 39 | with: 40 | python-version: ${{ matrix.python-version.setup }} 41 | - name: Install dependencies 42 | run: | 43 | python -m pip install --upgrade pip 44 | python -m pip install -r setup_requirements.txt 45 | - name: Build and test with tox 46 | run: tox -e ${{ matrix.python-version.tox }} 47 | -------------------------------------------------------------------------------- /.github/workflows/lint-code.yml: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | name: Lint and Format 16 | 17 | on: 18 | push: 19 | branches: ["main", "release-*"] 20 | pull_request: 21 | branches: ["main", "release-*"] 22 | 23 | jobs: 24 | build: 25 | runs-on: ubuntu-latest 26 | steps: 27 | - uses: actions/checkout@v3 28 | - name: Set up Python 3.9 29 | uses: actions/setup-python@v4 30 | with: 31 | python-version: 3.9 32 | - name: Install dependencies 33 | run: | 34 | python -m pip install --upgrade pip 35 | python -m pip install -r setup_requirements.txt 36 | - name: Check Formatting 37 | run: tox -e fmt 38 | - name: Run pylint 39 | run: tox -e lint 40 | -------------------------------------------------------------------------------- /.github/workflows/publish-library.yml: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | name: Publish 16 | 17 | on: 18 | release: 19 | types: [published] 20 | 21 | jobs: 22 | build: 23 | runs-on: ubuntu-latest 24 | steps: 25 | - uses: actions/checkout@v3 26 | - name: Set up Python 27 | uses: actions/setup-python@v3 28 | - name: Build and check package 29 | run: | 30 | pip install tox 31 | tox -e build,twinecheck 32 | - name: Upload package 33 | if: github.event_name == 'release' 34 | uses: pypa/gh-action-pypi-publish@release/v1 35 | with: 36 | password: ${{ secrets.PYPI_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | data/classification_data/ 11 | models/ 12 | tmp_models/ 13 | models_to_upload/ 14 | upload_models/ 15 | downloaded_models/ 16 | .config 17 | .jupyter 18 | .Python 19 | .bash_history 20 | .python_history 21 | .local/ 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | *_pb2.py 39 | *_pb2_grpc.py 40 | .local/ 41 | .bash_history 42 | .python_history 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | .ipython 94 | # vscode 95 | .vscode/ 96 | # celery beat schedule file 97 | celerybeat-schedule 98 | 99 | # SageMath parsed files 100 | *.sage.py 101 | 102 | # Environments 103 | .env 104 | .venv 105 | env/ 106 | venv/ 107 | ENV/ 108 | env.bak/ 109 | venv.bak/ 110 | 111 | # shell environment 112 | .inputrc 113 | .profile 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | 128 | # IntelliJ 129 | .idea/ 130 | 131 | # Mac OS 132 | .DS_Store 133 | __MACOSX/ 134 | generic_perf_benchmark 135 | .local 136 | .keras 137 | .bash_history 138 | perf_test_artifacts 139 | tests/fixtures/downloaded 140 | .mycontrib 141 | mlruns 142 | # Pylint 143 | .pylint.d 144 | 145 | # Make artifacts 146 | .version 147 | .location 148 | build 149 | generate_protos 150 | GITHUB_TOKEN 151 | reports 152 | dry_run.txt 153 | 154 | ### Artifacts from training scripts 155 | # Path to MPT exported models 156 | saved_outputs 157 | # Path used by sampled commands for checkpoints etc 158 | teacher_prompts_t5-base 159 | 160 | # Runtime artifacts 161 | eligible_modules.json 162 | ineligible_modules.json 163 | modules.json 164 | 165 | # Other artifacts 
from demo scripts, e.g., TGIS integration scripts 166 | prompt_prefixes 167 | sample_prompt 168 | transformers_cache 169 | generated_interfaces 170 | /caikit_nlp/_version.py 171 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | profile=black 3 | from_first=true 4 | import_heading_future=Future 5 | import_heading_stdlib=Standard 6 | import_heading_thirdparty=Third Party 7 | import_heading_firstparty=First Party 8 | import_heading_localfolder=Local 9 | known_firstparty=alog,aconfig,caikit,caikit_tgis_backend,import_tracker 10 | known_localfolder=caikit_nlp,tests 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/mirrors-prettier 3 | rev: v2.1.2 4 | hooks: 5 | - id: prettier 6 | - repo: https://github.com/psf/black 7 | rev: 22.3.0 8 | hooks: 9 | - id: black 10 | exclude: imports 11 | additional_dependencies: ["platformdirs"] 12 | - repo: https://github.com/PyCQA/isort 13 | rev: 5.11.5 14 | hooks: 15 | - id: isort 16 | exclude: imports 17 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | venv 2 | docker_build_scripts 3 | htmlcov 4 | reports 5 | .pytest_cache 6 | models 7 | *.md 8 | tests/fixtures/tiny_models 9 | -------------------------------------------------------------------------------- /.whitesource: -------------------------------------------------------------------------------- 1 | { 2 | "settingsInheritedFrom": "whitesource-config/whitesource-config@master" 3 | } -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | ##################################################### 2 | # 3 | # List of approvers for caikit/caikit-nlp repository 4 | # 5 | ##################################################### 6 | # 7 | # Learn about CODEOWNERS file format: 8 | # https://help.github.com/en/articles/about-code-owners 9 | # 10 | 11 | * @gkumbhat @evaline-ju @gabe-l-hart 12 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | 👍🎉 First off, thank you for taking the time to contribute! 🎉👍 4 | 5 | The following is a set of guidelines for contributing. These are just guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. Please read the [community contribution guide](https://github.com/caikit/community/blob/main/CONTRIBUTING.md) first for general practices for the Caikit community. 6 | 7 | ## What Should I Know Before I Get Started? 8 | 9 | ### Code of Conduct 10 | 11 | This project adheres to the [Contributor Covenant](./code-of-conduct.md). By participating, you are expected to uphold this code. 12 | 13 | Please report unacceptable behavior to one of the [Code Owners](./CODEOWNERS). 14 | 15 | ### How Do I Start Contributing? 16 | 17 | The below workflow is designed to help you begin your first contribution journey. 
It will guide you through creating and picking up issues, working through them, having your work reviewed, and then merging. 18 | 19 | Help on open source projects is always welcome and there is always something that can be improved. For example, documentation (like the text you are reading now) can always use improvement, code can always be clarified, variables or functions can always be renamed or commented on, and there is always a need for more test coverage. If you see something that you think should be fixed, take ownership! Here is how you get started: 20 | 21 | ## How Can I Contribute? 22 | 23 | When contributing, it's useful to start by looking at [issues](https://github.com/caikit/caikit-nlp/issues). After picking up an issue, writing code, or updating a document, make a pull request and your work will be reviewed and merged. If you're adding a new feature or find a bug, it's best to [write an issue](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=feature_request.md&title=) first to discuss it with maintainers. 24 | 25 | To contribute to this repo, you'll use the Fork and Pull model common in many open source repositories. For details on this process, check out [The GitHub Workflow 26 | Guide](https://github.com/kubernetes/community/blob/master/contributors/guide/github-workflow.md) 27 | from Kubernetes. 28 | 29 | When your contribution is ready, you can create a pull request. Pull requests are often referred to as "PR". In general, we follow the standard [GitHub pull request](https://help.github.com/en/articles/about-pull-requests) process. Follow the template to provide details about your pull request to the maintainers. 30 | 31 | Before sending pull requests, make sure your changes pass formatting, linting and unit tests. 32 | 33 | #### Code Review 34 | 35 | Once you've [created a pull request](#how-can-i-contribute), maintainers will review your code and may make suggestions to fix before merging. It will be easier for your pull request to receive reviews if you consider the criteria the reviewers follow while working. Remember to: 36 | 37 | - Run tests locally and ensure they pass 38 | - Follow the project coding conventions 39 | - Write detailed commit messages 40 | - Break large changes into a logical series of smaller patches, which are easy to understand individually and combine to solve a broader issue 41 | 42 | ### Reporting Bugs 43 | 44 | This section guides you through submitting a bug report. Following these guidelines helps maintainers and the community understand your report ✏️, reproduce the behavior 💻, and find related reports 🔎. 45 | 46 | #### How Do I Submit A (Good) Bug Report? 47 | 48 | Bugs are tracked as [GitHub issues using the Bug Report template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=bug_report.md&title=). Create an issue on that and provide the information suggested in the bug report issue template. 49 | 50 | ### Suggesting Enhancements 51 | 52 | This section guides you through submitting an enhancement suggestion, including completely new features, tools, and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion ✏️ and find related suggestions 🔎 53 | 54 | #### How Do I Submit A (Good) Enhancement Suggestion? 
55 | 56 | Enhancement suggestions are tracked as [GitHub issues using the Feature Request template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=feature_request.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template. 57 | 58 | #### How Do I Submit A (Good) Improvement Item? 59 | 60 | Improvements to existing functionality are tracked as [GitHub issues using the User Story template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=user_story.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template. 61 | 62 | ## Development 63 | 64 | ### Set up your dev environment 65 | 66 | The following tools are required: 67 | 68 | - [git](https://git-scm.com) 69 | - [python](https://www.python.org) (v3.8+) 70 | - [pip](https://pypi.org/project/pip/) (v23.0+) 71 | 72 | You can setup your dev environment using [tox](https://tox.wiki/en/latest/), an environment orchestrator which allows for setting up environments for and invoking builds, unit tests, formatting, linting, etc. Install tox with: 73 | 74 | ```sh 75 | pip install -r setup_requirements.txt 76 | ``` 77 | 78 | If you want to manage your own virtual environment instead of using `tox`, you can install `caikit` and all dependencies with: 79 | 80 | ```sh 81 | pip install . 82 | ``` 83 | 84 | ### Unit tests 85 | 86 | Unit tests are enforced by the CI system. When making changes, run the tests before pushing the changes to avoid CI issues. 87 | 88 | Running unit tests against all supported Python versions is as simple as: 89 | 90 | ```sh 91 | tox 92 | ``` 93 | 94 | Running tests against a single Python version can be done with: 95 | 96 | ```sh 97 | tox -e py 98 | ``` 99 | 100 | ### Coding style 101 | 102 | Caikit follows the python [pep8](https://peps.python.org/pep-0008/) coding style. The coding style is enforced by the CI system, and your PR will fail until the style has been applied correctly. 103 | 104 | We use [pre-commit](https://pre-commit.com/) to enforce coding style using [black](https://github.com/psf/black), [prettier](https://github.com/prettier/prettier) and [isort](https://pycqa.github.io/isort/). 105 | 106 | You can invoke formatting with: 107 | 108 | ```sh 109 | tox -e fmt 110 | ``` 111 | 112 | In addition, we use [pylint](https://www.pylint.org) to perform static code analysis of the code. 113 | 114 | You can invoke the linting with the following command 115 | 116 | ```sh 117 | tox -e lint 118 | ``` 119 | 120 | ## Your First Code Contribution 121 | 122 | Unsure where to begin contributing? You can start by looking through these issues: 123 | 124 | - Issues with the [`good first issue` label](https://github.com/caikit/caikit-nlp/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) - these should only require a few lines of code and are good targets if you're just starting contributing. 125 | - Issues with the [`help wanted` label](https://github.com/caikit/caikit-nlp/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) - these range from simple to more complex, but are generally things we want but can't get to in a short time frame. 
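Whichever issue you pick up, it is worth running the same checks that CI runs before opening your pull request. A minimal sketch, using the `tox` environments described in the Development section above (assumes you have already installed the tools from `setup_requirements.txt`):

```sh
# Formatting (black, prettier, isort)
tox -e fmt
# Static analysis with pylint
tox -e lint
# Unit tests against your local Python version
tox -e py
```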
126 | 127 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM registry.access.redhat.com/ubi9/ubi-minimal:latest as base 2 | 3 | RUN microdnf update -y && \ 4 | microdnf install -y \ 5 | git python-pip && \ 6 | pip install --upgrade --no-cache-dir pip wheel && \ 7 | microdnf clean all 8 | 9 | FROM base as builder 10 | WORKDIR /build 11 | 12 | RUN pip install --no-cache tox 13 | COPY README.md . 14 | COPY pyproject.toml . 15 | COPY tox.ini . 16 | COPY caikit_nlp caikit_nlp 17 | # .git is required for setuptools-scm to get the version 18 | RUN --mount=source=.git,target=.git,type=bind \ 19 | --mount=type=cache,target=/root/.cache/pip \ 20 | tox -e build 21 | 22 | 23 | FROM base as deploy 24 | 25 | RUN python -m venv --upgrade-deps /opt/caikit/ 26 | 27 | ENV VIRTUAL_ENV=/opt/caikit 28 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 29 | 30 | COPY --from=builder /build/dist/caikit_nlp*.whl /tmp/ 31 | RUN --mount=type=cache,target=/root/.cache/pip \ 32 | pip install /tmp/caikit_nlp*.whl && \ 33 | rm /tmp/caikit_nlp*.whl 34 | 35 | COPY LICENSE /opt/caikit/ 36 | COPY README.md /opt/caikit/ 37 | 38 | RUN groupadd --system caikit --gid 1001 && \ 39 | adduser --system --uid 1001 --gid 0 --groups caikit \ 40 | --home-dir /caikit --shell /sbin/nologin \ 41 | --comment "Caikit User" caikit 42 | 43 | USER caikit 44 | 45 | ENV RUNTIME_LIBRARY=caikit_nlp 46 | # Optional: use `CONFIG_FILES` and the /caikit/ volume to explicitly provide a configuration file and models 47 | # ENV CONFIG_FILES=/caikit/caikit.yml 48 | VOLUME ["/caikit/"] 49 | WORKDIR /caikit 50 | 51 | CMD ["python"] 52 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | The Caikit project has a [common security policy that can be found here](https://github.com/caikit/community/blob/main/SECURITY.md). 4 | -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Caikit NLP Runtime Performance Benchmarks 2 | 3 | Runtime performance benchmarking results for various models on various hardware configurations. 
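The runtime columns in the table below appear to follow the Hugging Face `Trainer` training metrics, so the numbers can be cross-checked against each other: samples per second ≈ total training samples processed / training runtime, and train steps per second ≈ samples per second / (per-device batch size × gradient accumulation steps). For example, in the first Llama2-7b row, 21.325 / (6 × 16) ≈ 0.22 steps per second, matching the reported value.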
4 | 5 | ## Llama2-7b 6 | 7 | | Date Executed | Hardware | Training Set | Epoch | Precision | Batch Size | Max Source Length | Training Runtime (s) | Samples Per Second | Train Steps Per Second | Loss | Notes | 8 | |---|---|---------------|---|---|:---:|---|------------| --- |---|---|---| 9 | | [2023-09-05](./logs/llama2-7b/20230905_183655.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 4096 | 350 | 21.325 | 0.22 | 1.65 | 4096 is the context size for Llama2 | 10 | | [2023-09-05](./logs/llama2-7b/20230905_184809.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 1024 | 350 | 21.333 | 0.22 | 1.65 | batch size of 7 fails CUDA OOM | 11 | | [2023-09-06](./logs/llama2-7b/20230906_135211.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 512 | 348 | 21.44 | 0.22 | 1.65 | batch size of 7 fails CUDA OOM | 12 | | [2023-09-05](./logs/llama2-7b/20230905_194133.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 8 | 256 | 356 | 20.939 | 0.16 | 1.70 | batch size of 9 fails CUDA OOM | 13 | | [2023-09-05](./logs/llama2-7b/20230905_191650.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 19 | 128 | 254 | 29.332 | 0.09 | 1.94 | batch size of 20 fails CUDA OOM | 14 | -------------------------------------------------------------------------------- /benchmarks/logs/llama2-7b/20230905_183655.output: -------------------------------------------------------------------------------- 1 | (tuning) [gpu_user@gpu5480 caikit-nlp]$ ./ft_job.sh 2 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions 3 | _warnings.warn( 4 | is still in the BETA phase and subject to change! 5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions 6 | _warnings.warn( 7 | Existing model directory found; purging it now. 8 | Experiment Configuration 9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b] 10 | |- Inferred Model Resource Type: [] 11 | - Dataset: [glue/rte] 12 | - Number of Epochs: [1] 13 | - Learning Rate: [2e-05] 14 | - Batch Size: [6] 15 | - Output Directory: [/tmp/tu/output/tuning/llama27b] 16 | - Maximum source sequence length: [4096] 17 | - Maximum target sequence length: [1024] 18 | - Gradient accumulation steps: [16] 19 | - Enable evaluation: [False] 20 | - Evaluation metrics: [['rouge']] 21 | - Torch dtype to use for training: [bfloat16] 22 | [Loading the dataset...] 23 | 2023-09-05T18:36:55.174106 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 24 | 2023-09-05T18:36:55.192203 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 25 | [Loading the base model resource...] 26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.89s/it] 27 | [Starting the training...] 28 | 2023-09-05T18:37:47.419502 [PEFT_:DBUG] Shuffling enabled? 
True 29 | 2023-09-05T18:37:47.419666 [PEFT_:DBUG] Shuffling buffer size: 7470 30 | TRAINING ARGS: { 31 | "output_dir": "/tmp", 32 | "per_device_train_batch_size": 6, 33 | "per_device_eval_batch_size": 6, 34 | "num_train_epochs": 1, 35 | "seed": 73, 36 | "do_eval": false, 37 | "learning_rate": 2e-05, 38 | "weight_decay": 0.01, 39 | "save_total_limit": 3, 40 | "push_to_hub": false, 41 | "no_cuda": false, 42 | "remove_unused_columns": false, 43 | "dataloader_pin_memory": false, 44 | "gradient_accumulation_steps": 16, 45 | "eval_accumulation_steps": 16, 46 | "bf16": true 47 | } 48 | 0%| | 0/77 [00:00 is still in the BETA phase and subject to change! 5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions 6 | _warnings.warn( 7 | Existing model directory found; purging it now. 8 | Experiment Configuration 9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b] 10 | |- Inferred Model Resource Type: [] 11 | - Dataset: [glue/rte] 12 | - Number of Epochs: [1] 13 | - Learning Rate: [2e-05] 14 | - Batch Size: [6] 15 | - Output Directory: [/tmp/tu/output/tuning/llama27b] 16 | - Maximum source sequence length: [1024] 17 | - Maximum target sequence length: [1024] 18 | - Gradient accumulation steps: [16] 19 | - Enable evaluation: [False] 20 | - Evaluation metrics: [['rouge']] 21 | - Torch dtype to use for training: [bfloat16] 22 | [Loading the dataset...] 23 | 2023-09-05T18:47:18.075310 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 24 | 2023-09-05T18:47:18.093371 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 25 | [Loading the base model resource...] 26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.76s/it] 27 | [Starting the training...] 28 | 2023-09-05T18:48:09.755222 [PEFT_:DBUG] Shuffling enabled? True 29 | 2023-09-05T18:48:09.755357 [PEFT_:DBUG] Shuffling buffer size: 7470 30 | TRAINING ARGS: { 31 | "output_dir": "/tmp", 32 | "per_device_train_batch_size": 6, 33 | "per_device_eval_batch_size": 6, 34 | "num_train_epochs": 1, 35 | "seed": 73, 36 | "do_eval": false, 37 | "learning_rate": 2e-05, 38 | "weight_decay": 0.01, 39 | "save_total_limit": 3, 40 | "push_to_hub": false, 41 | "no_cuda": false, 42 | "remove_unused_columns": false, 43 | "dataloader_pin_memory": false, 44 | "gradient_accumulation_steps": 16, 45 | "eval_accumulation_steps": 16, 46 | "bf16": true 47 | } 48 | 0%| | 0/77 [00:00 is still in the BETA phase and subject to change! 
7 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions 8 | _warnings.warn( 9 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.20k/4.20k [00:00<00:00, 4.16MB/s] 10 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.60k/6.60k [00:00<00:00, 5.36MB/s] 11 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.27k/6.27k [00:00<00:00, 5.52MB/s] 12 | Existing model directory found; purging it now. 13 | Experiment Configuration 14 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b] 15 | |- Inferred Model Resource Type: [] 16 | - Dataset: [glue/rte] 17 | - Number of Epochs: [1] 18 | - Learning Rate: [2e-05] 19 | - Batch Size: [19] 20 | - Output Directory: [/tmp/tu/output/tuning/llama27b] 21 | - Maximum source sequence length: [128] 22 | - Maximum target sequence length: [1024] 23 | - Gradient accumulation steps: [16] 24 | - Enable evaluation: [False] 25 | - Evaluation metrics: [['rouge']] 26 | - Torch dtype to use for training: [bfloat16] 27 | [Loading the dataset...] 28 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.8k/28.8k [00:00<00:00, 15.9MB/s] 29 | Downloading metadata: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.7k/28.7k [00:00<00:00, 26.9MB/s] 30 | Downloading readme: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27.9k/27.9k [00:00<00:00, 22.1MB/s] 31 | Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 697k/697k [00:00<00:00, 12.0MB/s] 32 | Generating train split: 0%| | 0/2490 [00:00 is still in the BETA phase and subject to change! 5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions 6 | _warnings.warn( 7 | Existing model directory found; purging it now. 8 | Experiment Configuration 9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b] 10 | |- Inferred Model Resource Type: [] 11 | - Dataset: [glue/rte] 12 | - Number of Epochs: [1] 13 | - Learning Rate: [2e-05] 14 | - Batch Size: [8] 15 | - Output Directory: [/tmp/tu/output/tuning/llama27b] 16 | - Maximum source sequence length: [256] 17 | - Maximum target sequence length: [1024] 18 | - Gradient accumulation steps: [16] 19 | - Enable evaluation: [False] 20 | - Evaluation metrics: [['rouge']] 21 | - Torch dtype to use for training: [bfloat16] 22 | [Loading the dataset...] 
23 | 2023-09-05T19:40:43.686785 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 24 | 2023-09-05T19:40:43.702480 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 25 | [Loading the base model resource...] 26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.73s/it] 27 | [Starting the training...] 28 | 2023-09-05T19:41:33.062266 [PEFT_:DBUG] Shuffling enabled? True 29 | 2023-09-05T19:41:33.062427 [PEFT_:DBUG] Shuffling buffer size: 7470 30 | TRAINING ARGS: { 31 | "output_dir": "/tmp", 32 | "per_device_train_batch_size": 8, 33 | "per_device_eval_batch_size": 8, 34 | "num_train_epochs": 1, 35 | "seed": 73, 36 | "do_eval": false, 37 | "learning_rate": 2e-05, 38 | "weight_decay": 0.01, 39 | "save_total_limit": 3, 40 | "push_to_hub": false, 41 | "no_cuda": false, 42 | "remove_unused_columns": false, 43 | "dataloader_pin_memory": false, 44 | "gradient_accumulation_steps": 16, 45 | "eval_accumulation_steps": 16, 46 | "bf16": true 47 | } 48 | 0%| | 0/58 [00:00 is still in the BETA phase and subject to change! 5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions 6 | _warnings.warn( 7 | Existing model directory found; purging it now. 8 | Experiment Configuration 9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b] 10 | |- Inferred Model Resource Type: [] 11 | - Dataset: [glue/rte] 12 | - Number of Epochs: [1] 13 | - Learning Rate: [2e-05] 14 | - Batch Size: [6] 15 | - Output Directory: [/tmp/tu/output/tuning/llama27b] 16 | - Maximum source sequence length: [512] 17 | - Maximum target sequence length: [1024] 18 | - Gradient accumulation steps: [16] 19 | - Enable evaluation: [False] 20 | - Evaluation metrics: [['rouge']] 21 | - Torch dtype to use for training: [bfloat16] 22 | [Loading the dataset...] 23 | 2023-09-06T13:51:21.128309 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 24 | 2023-09-06T13:51:21.146717 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json 25 | [Loading the base model resource...] 26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.79s/it] 27 | [Starting the training...] 28 | 2023-09-06T13:52:11.307381 [PEFT_:DBUG] Shuffling enabled? 
True 29 | 2023-09-06T13:52:11.307508 [PEFT_:DBUG] Shuffling buffer size: 7470 30 | TRAINING ARGS: { 31 | "output_dir": "/tmp", 32 | "per_device_train_batch_size": 6, 33 | "per_device_eval_batch_size": 6, 34 | "num_train_epochs": 1, 35 | "seed": 73, 36 | "do_eval": false, 37 | "learning_rate": 2e-05, 38 | "weight_decay": 0.01, 39 | "save_total_limit": 3, 40 | "push_to_hub": false, 41 | "no_cuda": false, 42 | "remove_unused_columns": false, 43 | "dataloader_pin_memory": false, 44 | "gradient_accumulation_steps": 16, 45 | "eval_accumulation_steps": 16, 46 | "bf16": true 47 | } 48 | 0%| | 0/77 [00:00", str, local_initializer_name=local_initializer_name 72 | ) 73 | error.type_check( 74 | "", 75 | int, 76 | tgis_backend_priority=tgis_backend_priority, 77 | allow_none=True, 78 | ) 79 | 80 | # Extract the TGIS backend instance 81 | local_initializer = MODEL_MANAGER.get_initializer(local_initializer_name) 82 | backends = local_initializer.backends 83 | if tgis_backend_priority is not None: 84 | error.value_check( 85 | "", 86 | 0 <= tgis_backend_priority < len(backends), 87 | "Invalid {}: {}", 88 | self._TGIS_BACKEND_PRIORITY_KEY, 89 | tgis_backend_priority, 90 | ) 91 | self._tgis_backend = backends[tgis_backend_priority] 92 | error.value_check( 93 | "", 94 | self._tgis_backend.backend_type == TGISBackend.backend_type, 95 | "Index {} is not a TGIS backend", 96 | tgis_backend_priority, 97 | ) 98 | else: 99 | tgis_backend = None 100 | for backend in backends: 101 | if backend.backend_type == TGISBackend.backend_type: 102 | tgis_backend = backend 103 | break 104 | error.value_check( 105 | "", 106 | tgis_backend is not None, 107 | "No TGIS backend found!", 108 | ) 109 | self._tgis_backend = tgis_backend 110 | 111 | def find_model( 112 | self, 113 | model_path: str, 114 | **kwargs, 115 | ) -> Optional[ModuleConfig]: 116 | """Find the model if""" 117 | 118 | # Get a connection to this model in tgis 119 | log.debug2("Attempting to setup TGIS client for %s", model_path) 120 | if self._tgis_backend.get_connection(model_id=model_path) is None: 121 | log.debug2("TGIS cannot connect to model %s", model_path) 122 | return None 123 | 124 | # If connection is ok, set up the module config to point to the remote 125 | # TGIS text generation module 126 | cfg = ModuleConfig( 127 | { 128 | "module_id": TextGenerationTGIS.MODULE_ID, 129 | "module_class": TextGenerationTGIS.MODULE_CLASS, 130 | "name": TextGenerationTGIS.MODULE_NAME, 131 | "version": TextGenerationTGIS.MODULE_VERSION, 132 | "model_name": model_path, 133 | } 134 | ) 135 | # Set a special indicator in the module config to use the backend that 136 | # this finder found. This will override the backend found by the local 137 | # initializer. 138 | cfg.tgis_backend = self._tgis_backend 139 | return cfg 140 | 141 | 142 | model_finder_factory.register(TGISAutoFinder) 143 | -------------------------------------------------------------------------------- /caikit_nlp/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Local 15 | from . import ( 16 | text_classification, 17 | text_embedding, 18 | text_generation, 19 | token_classification, 20 | tokenization, 21 | ) 22 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Local 16 | from .sequence_classification import SequenceClassification 17 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_classification/sequence_classification.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module contains sequence classification, compatible with transformer 15 | SequenceClassification modules. 
At this time this module is only designed 16 | for inference""" 17 | # Standard 18 | from typing import Dict, List, Union 19 | 20 | # Third Party 21 | import torch 22 | 23 | # First Party 24 | from caikit.core.exceptions import error_handler 25 | from caikit.core.modules import ModuleBase, ModuleLoader, ModuleSaver, module 26 | from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults 27 | from caikit.interfaces.nlp.tasks import TextClassificationTask 28 | import alog 29 | 30 | # Local 31 | from ...resources.pretrained_model.hf_auto_seq_classifier import ( 32 | HFAutoSequenceClassifier, 33 | ) 34 | 35 | log = alog.use_channel("SEQ_CLASS") 36 | error = error_handler.get(log) 37 | 38 | 39 | @module( 40 | id="d21107ca-d579-4321-aedd-3099a526e0dd", 41 | name="Sequence classification", 42 | version="0.1.0", 43 | task=TextClassificationTask, 44 | ) 45 | class SequenceClassification(ModuleBase): 46 | 47 | ################################ Constructor ################################################# 48 | 49 | def __init__( 50 | self, 51 | resource: HFAutoSequenceClassifier, 52 | ): 53 | super().__init__() 54 | error.type_check( 55 | "", 56 | HFAutoSequenceClassifier, 57 | resource=resource, 58 | ) 59 | self.resource = resource 60 | self.tokenizer = resource.tokenizer 61 | self.model = resource.model 62 | 63 | ################################## API functions ############################################# 64 | 65 | def run(self, text: str) -> ClassificationResults: 66 | """Run the sequence classification. 67 | NOTE: This will truncate sequences that are too long for model 68 | 69 | Args: 70 | text: str 71 | Input string to be classified 72 | 73 | Returns: 74 | ClassificationResults 75 | """ 76 | scores_dict = self._get_scores(text) 77 | # Re-organize scores_dict - for one text, this is just the first score 78 | return SequenceClassification._process_predictions(scores_dict, text_idx=0) 79 | 80 | def run_batch(self, texts: List[str]) -> List[ClassificationResults]: 81 | """Run the sequence classification on batch, truncates sequences too long for model 82 | 83 | Args: 84 | text: List[str] 85 | Input strings to be classified 86 | 87 | Returns: 88 | List[ClassificationResults] 89 | """ 90 | scores_dict = self._get_scores(texts) 91 | num_texts = len(texts) 92 | 93 | # Re-organize scores_dict for each example 94 | # We could eventually consider whether or not to sort classifications by scores 95 | # but avoiding this prescription here for now 96 | classification_predictions = [] 97 | for text_idx in range(num_texts): 98 | classification_prediction = SequenceClassification._process_predictions( 99 | scores_dict, text_idx 100 | ) 101 | classification_predictions.append(classification_prediction) 102 | return classification_predictions 103 | 104 | def save(self, model_path: str): 105 | """Save model in target path 106 | 107 | Args: 108 | model_path: str 109 | Path to store model artifact(s) 110 | """ 111 | module_saver = ModuleSaver( 112 | self, 113 | model_path=model_path, 114 | ) 115 | 116 | with module_saver: 117 | module_saver.save_module(self.resource, "sequence_classifier") 118 | 119 | @classmethod 120 | def load(cls, model_path: str) -> "SequenceClassification": 121 | """Load a sequence classification model 122 | 123 | Args: 124 | model_path: str 125 | Path to the model to be loaded. 126 | 127 | Returns: 128 | SequenceClassification 129 | Instance of this class built from the on disk model. 
130 | """ 131 | loader = ModuleLoader(model_path) 132 | resource = loader.load_module("sequence_classifier") 133 | return cls(resource=resource) 134 | 135 | @classmethod 136 | def bootstrap(cls, base_model_path: str) -> "SequenceClassification": 137 | """Bootstrap a HuggingFace transformer-based sequence classification model 138 | 139 | Args: 140 | base_model_path: str 141 | Path to the model to be loaded. 142 | """ 143 | # Note: Must provide path to tokenizer if model_name is a path 144 | # for resource use 145 | resource = HFAutoSequenceClassifier.bootstrap( 146 | model_name=base_model_path, tokenizer_name=base_model_path 147 | ) 148 | return cls( 149 | resource=resource, 150 | ) 151 | 152 | ################################## Private Functions ######################################### 153 | 154 | def _get_scores(self, text: Union[str, List[str]]): 155 | """Run tokenizer and model to get scores on text(s) 156 | 157 | Args: 158 | text: Union[str, List[str]] 159 | Input string(s) to be used 160 | 161 | Returns: 162 | scores_dict 163 | Dict with key label, and values as the array of scores, 164 | each corresponding to text(s) 165 | """ 166 | # NOTE: no explicit GPU support at this time 167 | 168 | # Apply tokenizer 169 | tokenized_text = self.tokenizer( 170 | text, 171 | padding="max_length", 172 | truncation=True, 173 | return_tensors="pt", # PyTorch 174 | ) 175 | # NOTE: no truncation warning given at this time 176 | # difficult to detect if truncation happened since padding also occurs 177 | with torch.no_grad(): 178 | logits = self.model(**tokenized_text).logits 179 | 180 | if not self.model.config.id2label: 181 | log.warning( 182 | "", 183 | "No id2label provided in model config. Defaulting to numeric labels", 184 | ) 185 | 186 | softmax = torch.nn.Softmax(dim=1) 187 | raw_scores = softmax(logits) 188 | scores = raw_scores.double().numpy() 189 | num_labels = self.model.num_labels 190 | num_texts = 1 # str 191 | if isinstance(text, List): 192 | num_texts = len(text) 193 | error.value_check( 194 | "", 195 | scores.shape == (num_texts, num_labels), 196 | "model logits expected to be of shape (num_texts, num_labels)", 197 | ) 198 | 199 | scores_dict = {} 200 | for label_idx in range(num_labels): 201 | if self.model.config.id2label: 202 | label = self.model.config.id2label[label_idx] 203 | else: 204 | label = label_idx 205 | label_scores = scores[:, label_idx] 206 | scores_dict[label] = label_scores 207 | return scores_dict 208 | 209 | @staticmethod 210 | def _process_predictions(scores_dict: Dict, text_idx: int) -> ClassificationResults: 211 | """Process dictionary of label: scores to ClassificationResults 212 | 213 | Args: 214 | scores_dict: Dict 215 | Dict with key label, and values as the array of scores, 216 | each corresponding to text(s) 217 | text_idx: int 218 | Integer index of text in batch 219 | 220 | Returns: 221 | ClassificationResults 222 | """ 223 | error.type_check("", Dict, scores_dict=scores_dict) 224 | classification_list = [] 225 | for label, score_array in scores_dict.items(): 226 | # NOTE: labels are expected to be str, especially for config 227 | classification_list.append( 228 | ClassificationResult(label=str(label), score=score_array[text_idx]) 229 | ) 230 | return ClassificationResults(results=classification_list) 231 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_embedding/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 
3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Text Embedding Module 17 | ===================== 18 | 19 | Implements the following tasks: 20 | 21 | 1. EmbeddingTask: Returns an embedding from an input text string 22 | 2. EmbeddingsTasks: EmbeddingTask but with a list of inputs producing a list of outputs 23 | 3. SentenceSimilarityTask: Compare one source sentence to a list of sentences 24 | 4. SentenceSimilarityTasks: SentenceSimilarityTask but with a list of source sentences producing 25 | a list of outputs 26 | 5. RerankTask: Return top_n documents ordered by relevance given a query 27 | 6. RerankTasks: RerankTask but with a list of queries producing a list of outputs 28 | 29 | """ 30 | 31 | # Local 32 | from .crossencoder import CrossEncoderModule 33 | from .embedding import EmbeddingModule 34 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_embedding/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | def env_val_to_bool(val): 17 | """Returns the bool value of env var""" 18 | if val is None: 19 | return False 20 | if isinstance(val, bool): 21 | return val 22 | 23 | # For testing env vars for values that mean false (else True!) 24 | return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "") 25 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_generation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Local 16 | from .peft_prompt_tuning import PeftPromptTuning 17 | from .peft_tgis_remote import PeftPromptTuningTGIS 18 | from .text_generation_local import TextGeneration 19 | from .text_generation_tgis import TextGenerationTGIS 20 | -------------------------------------------------------------------------------- /caikit_nlp/modules/text_generation/peft_config.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Standard 16 | from enum import Enum 17 | import os 18 | import re 19 | 20 | # Third Party 21 | from peft import MultitaskPromptTuningInit 22 | from transformers import AutoConfig 23 | 24 | # First Party 25 | from caikit import get_config 26 | from caikit.core import error_handler 27 | import alog 28 | 29 | # Local 30 | from ...data_model import PromptOutputModelType 31 | from ...resources.pretrained_model import PretrainedModelBase 32 | from ...toolkit.data_type_utils import get_torch_dtype 33 | from ...toolkit.verbalizer_utils import is_valid_verbalizer 34 | 35 | # NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK` 36 | # since those are for experimental use and would not be useful / applicable 37 | # for end-user use-cases 38 | allowed_tuning_init_methods = [ 39 | "TEXT", 40 | "RANDOM", 41 | "ONLY_SOURCE_SHARED", 42 | "AVERAGE_SOURCE_TASKS", 43 | ] 44 | 45 | log = alog.use_channel("PFT_CNFG_TLKT") 46 | error = error_handler.get(log) 47 | 48 | SOURCE_DIR_VALIDATION_REGEX = re.compile(r"^[-a-zA-Z_0-9\/\.]+") 49 | # 🤮 FIXME: This two dot regex is added as a way to avoid expressions like .. 50 | # giving access to un-intended directories. 
But this is an ugly hack 51 | # and we need to figure out better solution or better regex 52 | TWO_DOTS_REGEX = re.compile(r"(\.\.)+") 53 | 54 | 55 | class TuningType(str, Enum): 56 | PROMPT_TUNING = "PROMPT_TUNING" 57 | MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING" 58 | # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING" 59 | # P_TUNING = "P_TUNING" 60 | # PREFIX_TUNING = "PREFIX_TUNING" 61 | # LORA = "LORA" 62 | 63 | 64 | def resolve_base_model(base_model, cls, torch_dtype): 65 | if isinstance(base_model, str): 66 | 67 | error.value_check( 68 | "", 69 | re.fullmatch(SOURCE_DIR_VALIDATION_REGEX, base_model) 70 | and not re.search(TWO_DOTS_REGEX, base_model), 71 | "invalid characters in base_model name", 72 | ) 73 | if get_config().base_models_dir: 74 | 75 | base_model_full_path = os.path.join( 76 | get_config().base_models_dir, base_model 77 | ) 78 | if os.path.exists(base_model_full_path): 79 | base_model = base_model_full_path 80 | 81 | model_config = AutoConfig.from_pretrained( 82 | base_model, local_files_only=not get_config().allow_downloads 83 | ) 84 | 85 | resource_type = None 86 | for resource in cls.supported_resources: 87 | if model_config.model_type in resource.SUPPORTED_MODEL_TYPES: 88 | resource_type = resource 89 | break 90 | 91 | if not resource_type: 92 | error( 93 | "", 94 | "{} model type is not supported currently!".format( 95 | model_config.model_type 96 | ), 97 | ) 98 | log.debug("Bootstrapping base resource [%s]", base_model) 99 | base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype) 100 | return base_model 101 | 102 | 103 | def get_peft_config( 104 | tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer 105 | ): 106 | 107 | if tuning_type not in TuningType._member_names_: 108 | raise NotImplementedError("{} tuning type not supported!".format(tuning_type)) 109 | 110 | if tuning_config.prompt_tuning_init_method: 111 | # NOTE: GK-APR-5-2023 112 | # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the 113 | # time of writing, which is a superset of PromptTuningInit 114 | init_method = tuning_config.prompt_tuning_init_method 115 | 116 | error.value_check( 117 | "", 118 | init_method in allowed_tuning_init_methods, 119 | f"Init method [{init_method}] not in allowed init methods: " 120 | f"[{allowed_tuning_init_methods}]", 121 | ) 122 | 123 | init_method = MultitaskPromptTuningInit(init_method) 124 | log.info("Using initialization method [%s]", init_method) 125 | 126 | # If init method provided relates to one that requires source model, 127 | # make sure the source prompt model is provided. 128 | if init_method in [ 129 | MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS, 130 | MultitaskPromptTuningInit.ONLY_SOURCE_SHARED, 131 | ]: 132 | # NOTE: prompt_tuning_init_source_model is currently a path. 
In future 133 | # we will replace this with caikit.resources to properly cataloging these 134 | error.type_check( 135 | "", 136 | str, 137 | prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model, 138 | ) 139 | tuning_config.prompt_tuning_init_source_model = os.path.join( 140 | get_config().source_prompt_base, 141 | tuning_config.prompt_tuning_init_source_model, 142 | ) 143 | 144 | error.file_check( 145 | "", tuning_config.prompt_tuning_init_source_model 146 | ) 147 | log.debug( 148 | "Validated tuning source prompt [%s]", 149 | tuning_config.prompt_tuning_init_source_model, 150 | ) 151 | 152 | error.type_check("", PretrainedModelBase, base_model=base_model) 153 | 154 | # Validate if tuned output model type is compatible with base model or not 155 | if not tuning_config.output_model_types: 156 | output_model_types = base_model.PROMPT_OUTPUT_TYPES 157 | else: 158 | # If the first element is not PromptOutputModelType, assume the entire list 159 | # isn't and convert 160 | if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType): 161 | output_model_types = [] 162 | for output_type in tuning_config.output_model_types: 163 | output_model_types.append(PromptOutputModelType(output_type)) 164 | else: 165 | output_model_types = tuning_config.output_model_types 166 | error.value_check( 167 | "", 168 | all( 169 | output_type in base_model.PROMPT_OUTPUT_TYPES 170 | for output_type in output_model_types 171 | ), 172 | "{} not supported for base model type {}".format( 173 | output_model_types, base_model.MODEL_TYPE 174 | ), 175 | ) 176 | 177 | error.value_check( 178 | "", 179 | len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS, 180 | f"Too many output model types. Got {len(output_model_types)}, " 181 | f"maximum {base_model.MAX_NUM_TRANSFORMERS}", 182 | ) 183 | # Ensure that our verbalizer is a string and will not render to a hardcoded string 184 | error.value_check( 185 | "", 186 | is_valid_verbalizer(verbalizer), 187 | "Provided verbalizer is an invalid type or has no renderable placeholders", 188 | ) 189 | 190 | # NOTE: Base model is a resource at this point 191 | task_type = base_model.TASK_TYPE 192 | 193 | if isinstance(tuning_type, str): 194 | error.value_check( 195 | "", 196 | tuning_type in TuningType._member_names_, 197 | f"Invalid tuning type [{tuning_type}]. Allowed types: " 198 | f"[{TuningType._member_names_}]", 199 | ) 200 | tuning_type = TuningType(tuning_type) 201 | error.type_check("", TuningType, tuning_type=tuning_type) 202 | 203 | # Coerce the passed model into a resource; if we have one, this is a noop 204 | # TODO: When splitting up this mono-module, use the configured resource 205 | # type of the concrete class to bootstrap 206 | torch_dtype = get_torch_dtype(torch_dtype) 207 | 208 | # Take tokenizer name/path from the model 209 | tokenizer_name_or_path = base_model.model.config._name_or_path 210 | 211 | # Build the peft config; this is how we determine that we want a sequence classifier. 212 | # If we want more types, we will likely need to map this to data model outputs etc. 
213 | 214 | # NOTE: We currently only support TEXT as init type, this is to later only easily 215 | # switch to MPT 216 | peft_config = cls.create_hf_tuning_config( 217 | base_model=base_model, 218 | tuning_type=tuning_type, 219 | task_type=task_type, 220 | tokenizer_name_or_path=tokenizer_name_or_path, 221 | tuning_config=tuning_config, 222 | output_model_types=output_model_types, 223 | ) 224 | 225 | return task_type, output_model_types, peft_config, tuning_type 226 | -------------------------------------------------------------------------------- /caikit_nlp/modules/token_classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Local 16 | from .filtered_span_classification import FilteredSpanClassification 17 | -------------------------------------------------------------------------------- /caikit_nlp/modules/tokenization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Local 16 | from .regex_sentence_splitter import RegexSentenceSplitter 17 | -------------------------------------------------------------------------------- /caikit_nlp/modules/tokenization/regex_sentence_splitter.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Module that provides capability to split documents into sentences via regex""" 15 | 16 | # Standard 17 | import os 18 | import re 19 | 20 | # First Party 21 | from caikit.core.exceptions import error_handler 22 | from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module 23 | from caikit.interfaces.nlp.data_model import Token, TokenizationResults 24 | from caikit.interfaces.nlp.tasks import TokenizationTask 25 | import alog 26 | 27 | log = alog.use_channel("RGX_SNT_SPLT") 28 | error = error_handler.get(log) 29 | 30 | 31 | @module( 32 | id="1e04e21b-7009-499e-abdd-41e5984c2e7d", 33 | name="Regex Sentence Splitter", 34 | version="0.1.0", 35 | task=TokenizationTask, 36 | ) 37 | class RegexSentenceSplitter(ModuleBase): 38 | # pylint: disable=anomalous-backslash-in-string 39 | """Use python regexes to split document into sentences. 40 | 41 | Sample regex string: 42 | [^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$) 43 | """ 44 | 45 | def __init__(self, regex_str: str): 46 | """Construct a RegexSentenceSplitter object 47 | by compiling the input regex string into python regex object 48 | that can be used later on for detection. 49 | 50 | Args: 51 | regex_str: str 52 | String containing pattern that can be complied with python re 53 | module 54 | """ 55 | super().__init__() 56 | error.type_check("", str, regex_str=regex_str) 57 | self.regex_str = regex_str 58 | self.regex = re.compile(self.regex_str) 59 | 60 | @classmethod 61 | def bootstrap(cls, regex_str) -> "RegexSentenceSplitter": 62 | """Bootstrap a a RegexSentenceSplitter object 63 | 64 | Args: 65 | regex_str: str 66 | String containing pattern that can be complied with python re 67 | module 68 | """ 69 | return cls(regex_str) 70 | 71 | def save(self, model_path: str): 72 | """Save model in target path 73 | 74 | Args: 75 | model_path: str 76 | Path to store model artifact(s) 77 | """ 78 | module_saver = ModuleSaver( 79 | self, 80 | model_path=model_path, 81 | ) 82 | with module_saver: 83 | config_options = {"regex_str": self.regex_str} 84 | module_saver.update_config(config_options) 85 | 86 | @classmethod 87 | def load(cls, model_path: str) -> "RegexSentenceSplitter": 88 | """Load a regex sentence splitter model. 89 | 90 | Args: 91 | model_path: str 92 | Path to the model to be loaded. 93 | 94 | Returns: 95 | RegexSentenceSplitter 96 | Instance of this class built from the on disk model. 97 | """ 98 | config = ModuleConfig.load(os.path.abspath(model_path)) 99 | return cls(regex_str=config.regex_str) 100 | 101 | def run(self, text: str) -> TokenizationResults: 102 | """Run sentence splitting regex on input text. 103 | 104 | Args: 105 | text: str 106 | Document to run sentence splitting on. 107 | 108 | Returns: 109 | TokenizationResults 110 | TokenizationResults object containing tokens where each token 111 | corresponds to a detected sentence. 
112 | """ 113 | 114 | error.type_check("", str, text=text) 115 | 116 | matches = self.regex.finditer(text) 117 | tokens = [] 118 | for match in matches: 119 | token = Token(start=match.start(), end=match.end(), text=match.group()) 120 | tokens.append(token) 121 | 122 | return TokenizationResults(results=tokens) 123 | -------------------------------------------------------------------------------- /caikit_nlp/resources/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Shared resource implementations used in prompt tuning jobs 16 | """ 17 | 18 | # Local 19 | from . import pretrained_model 20 | -------------------------------------------------------------------------------- /caikit_nlp/resources/pretrained_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Resources holding pretrained models of various types 16 | """ 17 | 18 | # Local 19 | from .base import PretrainedModelBase 20 | from .hf_auto_causal_lm import HFAutoCausalLM 21 | from .hf_auto_seq2seq_lm import HFAutoSeq2SeqLM 22 | from .hf_auto_seq_classifier import HFAutoSequenceClassifier 23 | -------------------------------------------------------------------------------- /caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | Huggingface auto causal LM resource type 16 | """ 17 | # Standard 18 | from collections.abc import Mapping 19 | from typing import Dict, List, Optional, Union 20 | 21 | # Third Party 22 | from torch.utils.data import IterableDataset 23 | from transformers import ( 24 | AutoModelForSeq2SeqLM, 25 | DataCollatorForSeq2Seq, 26 | Seq2SeqTrainer, 27 | Seq2SeqTrainingArguments, 28 | ) 29 | from transformers.models.auto import modeling_auto 30 | 31 | # First Party 32 | from caikit.core.exceptions import error_handler 33 | from caikit.core.modules import module 34 | import alog 35 | 36 | # Local 37 | from ...data_model import GenerationTrainRecord, PromptOutputModelType 38 | from ...toolkit.trainer_utils import log_step 39 | from ...toolkit.verbalizer_utils import render_verbalizer 40 | from .base import PretrainedModelBase 41 | 42 | log = alog.use_channel("HFRBAS") 43 | error = error_handler.get(log) 44 | 45 | IGNORE_ID = -100 46 | 47 | 48 | class LoggingTrainer(Seq2SeqTrainer): 49 | # pylint: disable=unused-argument 50 | def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: 51 | """ 52 | Log `logs` on the various objects watching training. 53 | 54 | Subclass and override this method to inject custom behavior. 55 | 56 | Args: 57 | logs (`Dict[str, float]`): 58 | The values to log. 59 | """ 60 | # start_time was added in default trainer log 61 | # https://github.com/huggingface/transformers/pull/34507 62 | self.state = log_step(self.state, logs) 63 | self.control = self.callback_handler.on_log( 64 | self.args, self.state, self.control, logs 65 | ) 66 | 67 | 68 | @module( 69 | id="6759e891-287b-405b-bd8b-54a4a4d51c25", 70 | name="HF Transformers Auto Seq2Seq LM", 71 | version="0.1.0", 72 | ) 73 | class HFAutoSeq2SeqLM(PretrainedModelBase): 74 | """This resource (module) wraps a handle to a Huggingface 75 | AutoModelForSeq2SeqLM 76 | """ 77 | 78 | MODEL_TYPE = AutoModelForSeq2SeqLM 79 | SUPPORTED_MODEL_TYPES = modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES 80 | TASK_TYPE = "SEQ_2_SEQ_LM" 81 | PROMPT_OUTPUT_TYPES = [PromptOutputModelType.ENCODER] 82 | MAX_NUM_TRANSFORMERS = 2 83 | 84 | @classmethod 85 | def get_num_transformers_submodules( 86 | cls, output_model_types: List[PromptOutputModelType] 87 | ): 88 | """Return number of applicable transformer submodules""" 89 | num_transformer_submodules = 0 90 | if PromptOutputModelType.ENCODER in output_model_types: 91 | num_transformer_submodules += 1 92 | if PromptOutputModelType.DECODER in output_model_types: 93 | num_transformer_submodules += 1 94 | error.value_check( 95 | "", 0 < num_transformer_submodules <= cls.MAX_NUM_TRANSFORMERS 96 | ) 97 | return num_transformer_submodules 98 | 99 | def get_trainer( 100 | self, 101 | train_dataset: IterableDataset, 102 | eval_dataset: Union[IterableDataset, None] = None, 103 | optimizers=(None, None), 104 | **kwargs 105 | ): 106 | """ 107 | Args: 108 | *kwargs: arguments supported by HF Seq2SeqTrainingArguments: 109 | https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments 110 | 111 | NOTE: following parameters are not supported currently: 112 | 1. model_init 113 | 2. compute_metrics 114 | 3. callbacks 115 | 4. 
preprocess_logits_for_metrics 116 | """ 117 | 118 | # NOTE: predict_with_generate is incompatible with fsdp 119 | training_args = Seq2SeqTrainingArguments(**kwargs) 120 | 121 | # pylint: disable=duplicate-code 122 | # TODO: Fetch DataCollator either from property of this 123 | # class or fetch it as an argument. 124 | data_collator = self._get_data_collator(**kwargs) 125 | 126 | trainer_arguments = { 127 | "train_dataset": train_dataset, 128 | "data_collator": data_collator, 129 | "optimizers": optimizers, 130 | "eval_dataset": eval_dataset, 131 | # "generation_max_length": max_target_length, 132 | } 133 | 134 | return LoggingTrainer(self._model, training_args, **trainer_arguments) 135 | 136 | def _get_data_collator(self, **kwargs): 137 | """Function to return appropriate data collator based on resource. 138 | 139 | This implementation uses DataCollatorForSeq2Seq 140 | 141 | Args: 142 | **kwargs: 143 | All the keyword arguments passed to this function 144 | will get filtered out to appropriate ones that are 145 | applicable to implemented data collator. 146 | Returns: 147 | transformers.DataCollator 148 | """ 149 | 150 | applicable_args = ["max_length", "pad_to_multiple_of"] 151 | collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs} 152 | 153 | return DataCollatorForSeq2Seq( 154 | tokenizer=self._tokenizer, model=self._model, **collator_kwargs 155 | ) 156 | 157 | @classmethod 158 | def tokenize_function( 159 | cls, 160 | example: Union[GenerationTrainRecord, Mapping], 161 | tokenizer: "AutoTokenizer", 162 | max_source_length: int, 163 | max_target_length: int, 164 | verbalizer: Union[None, str] = None, 165 | task_ids: Union[None, int] = None, 166 | ) -> "BatchEncoding": 167 | """Tokenization function to be used for seq2seq training; this function consumes a 168 | GenerationTrainRecord object and applies the verbalizer to it followed by 169 | the model tokenizer. Finally, we postprocess by ignoring pad tokens in the label IDs. 170 | 171 | Args: 172 | example: Union[GenerationTrainRecord, Mapping] 173 | Training data model object to convert a form we can learn on. 174 | 175 | Returns: 176 | transformers.tokenization_utils_base.BatchEncoding 177 | encoded tokenization output corresponding to the input example. 178 | """ 179 | source, target = cls.decompose_example_io(example) 180 | source = ( 181 | source if verbalizer is None else render_verbalizer(verbalizer, example) 182 | ) 183 | 184 | model_inputs = tokenizer( 185 | source, 186 | max_length=max_source_length, 187 | padding="max_length", 188 | truncation=True, 189 | ) 190 | labels = tokenizer( 191 | target, 192 | max_length=max_target_length, 193 | padding="max_length", 194 | truncation=True, 195 | ) 196 | 197 | labels = labels["input_ids"] 198 | 199 | labels = list( 200 | map(lambda x: IGNORE_ID if x == tokenizer.pad_token_id else x, labels) 201 | ) 202 | model_inputs["labels"] = labels 203 | if task_ids is not None: 204 | model_inputs["task_ids"] = task_ids 205 | 206 | return model_inputs 207 | -------------------------------------------------------------------------------- /caikit_nlp/resources/pretrained_model/hf_auto_seq_classifier.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Huggingface auto sequence classifier resource type 16 | """ 17 | # Standard 18 | from typing import Callable, Tuple 19 | 20 | # Third Party 21 | from transformers import AutoModelForSequenceClassification 22 | from transformers.models.auto import modeling_auto 23 | 24 | # First Party 25 | from caikit.core.modules import module 26 | 27 | # Local 28 | from .base import PretrainedModelBase 29 | 30 | 31 | @module( 32 | id="6759e891-287b-405b-bd8b-54a4a4d51c23", 33 | name="HF Transformers Auto Sequence Classifier", 34 | version="0.1.0", 35 | ) 36 | class HFAutoSequenceClassifier(PretrainedModelBase): 37 | """This resource (module) wraps a handle to a huggingface 38 | AutoModelForSequenceClassification 39 | """ 40 | 41 | MODEL_TYPE = AutoModelForSequenceClassification 42 | SUPPORTED_MODEL_TYPES = ( 43 | modeling_auto.MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES 44 | ) 45 | TASK_TYPE = "SEQ_CLS" 46 | PROMPT_OUTPUT_TYPES = [] 47 | MAX_NUM_TRANSFORMERS = 1 48 | 49 | @classmethod 50 | def bootstrap(cls, *args, **kwargs) -> "HFAutoSequenceClassifier": 51 | """Bootstrap from a huggingface model 52 | 53 | See help(PretrainedModelBase) 54 | """ 55 | return super().bootstrap(*args, return_dict=True, **kwargs) 56 | 57 | @staticmethod 58 | def tokenize_function(*args, **kwargs) -> Tuple[Callable, bool]: 59 | raise NotImplementedError( 60 | "Tokenize func not implemented for sequence classifier" 61 | ) 62 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/caikit_nlp/toolkit/__init__.py -------------------------------------------------------------------------------- /caikit_nlp/toolkit/data_stream_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module contains utility classes for wrapping data streams as Torch datasets efficiently. 15 | Caikit modules leveraging such wrappers can use them to internally leverage common approaches 16 | and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch 17 | DataLoaders, with minimal boilerplate. 
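A minimal usage sketch (illustrative; `train_stream` stands in for any caikit DataStream
of training records):

    from torch.utils.data import DataLoader

    wrapped = SimpleIterableStreamWrapper(train_stream, shuffle=True)
    loader = DataLoader(wrapped, batch_size=8)
    for batch in loader:
        ...  # hand each batch to an ordinary PyTorch training step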
18 | """ 19 | 20 | # Third Party 21 | from torch.utils.data import IterableDataset 22 | 23 | # First Party 24 | from caikit.core.exceptions import error_handler 25 | import alog 26 | 27 | log = alog.use_channel("PEFT_PROMPT") 28 | error = error_handler.get(log) 29 | 30 | 31 | class SimpleIterableStreamWrapper(IterableDataset): 32 | """DataStream wrapper as an iterable PyTorch dataset; we use this to add 33 | compatability with PyTorch data loaders. 34 | """ 35 | 36 | def __init__(self, stream, shuffle, buffer_size=None): 37 | error.type_check("", bool, shuffle=shuffle) 38 | error.type_check( 39 | "", int, buffer_size=buffer_size, allow_none=True 40 | ) 41 | self.stream = stream 42 | self.shuffle = shuffle 43 | self.buffer_size = buffer_size 44 | # Load the whole data set in memory 45 | if self.shuffle and buffer_size is None: 46 | self.buffer_size = len(stream) 47 | log.debug("Shuffling enabled? {}".format(self.shuffle)) 48 | log.debug("Shuffling buffer size: {}".format(self.buffer_size)) 49 | 50 | def __iter__(self): 51 | 52 | # FIXME: We are currently not handling case where we have to work with 53 | # multiple workers, so currently duplicate data will get processed by 54 | # each worker. 55 | if self.shuffle: 56 | log.debug4("Reshuffling training data!") 57 | return iter(self.stream.shuffle(self.buffer_size)) 58 | return iter(self.stream) 59 | # worker_info = get_worker_info() 60 | # if worker_info is None: # single-process data loading, return the full iterator 61 | # if self.shuffle: 62 | # log.debug4("Reshuffling training data!") 63 | # return iter(self.stream.shuffle(self.buffer_size)) 64 | # return iter(self.stream) 65 | 66 | # When num_workers > 0, each worker process will have a different copy of 67 | # the dataset object, so we configure each copy independently to avoid 68 | # having duplicate data returned from each worker 69 | # else: # in a worker process 70 | # # split workload 71 | # per_worker = int( 72 | # math.ceil((self.end - self.start) / float(worker_info.num_workers)) 73 | # ) 74 | # worker_id = worker_info.id 75 | # iter_start = self.start + worker_id * per_worker 76 | # iter_end = min(iter_start + per_worker, self.end) 77 | # return iter(range(iter_start, iter_end)) 78 | 79 | def __len__(self): 80 | return len(self.stream) 81 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/data_type_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Standard 16 | from typing import Optional, Union 17 | 18 | # Third Party 19 | import torch 20 | 21 | # First Party 22 | from caikit import get_config 23 | from caikit.core.exceptions import error_handler 24 | import alog 25 | 26 | log = alog.use_channel("DATA_UTIL") 27 | error = error_handler.get(log) 28 | 29 | 30 | def str_to_torch_dtype(dtype_str: str) -> torch.dtype: 31 | """Given a string representation of a Torch data type, convert it to the actual torch dtype. 32 | 33 | Args: 34 | dtype_str: String representation of Torch dtype to be used; this should be an attr 35 | of the torch library whose value is a dtype. 36 | 37 | Returns: 38 | torch.dtype 39 | Data type of the Torch class being used. 40 | """ 41 | dt = getattr(torch, dtype_str, None) 42 | if not isinstance(dt, torch.dtype): 43 | error("", ValueError(f"Unrecognized data type: {dtype_str}")) 44 | return dt 45 | 46 | 47 | def get_torch_dtype(dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: 48 | """Get the Torch data type to be used for interacting with a model. 49 | 50 | Args: 51 | dtype: Optional[Union[str, torch.dtype]] 52 | If dtype is a torch.dtype, returns it; if it's a string, grab it from the Torch lib. 53 | If None is provided, fall back to the default type in config, which can be 54 | overridden via environment variable. 55 | 56 | Returns: 57 | torch.dtype 58 | Torch data type to be used. 59 | """ 60 | error.type_check("", torch.dtype, str, dtype=dtype, allow_none=True) 61 | # If a Torch dtype is passed, nothing to do 62 | if isinstance(dtype, torch.dtype): 63 | return dtype 64 | # If None/empty str was provided, fall back to config / env var override 65 | if not dtype: 66 | return str_to_torch_dtype(get_config().torch_dtype) 67 | # Otherwise convert it from a string 68 | return str_to_torch_dtype(dtype) 69 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/task_specific_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # First Party 16 | from caikit.core.exceptions import error_handler 17 | from caikit.interfaces.nlp.data_model import ClassificationTrainRecord 18 | import alog 19 | 20 | # Local 21 | from ..data_model import GenerationTrainRecord 22 | 23 | log = alog.use_channel("TASK_UTILS") 24 | error = error_handler.get(log) 25 | 26 | 27 | def convert_to_generation_record(train_record): 28 | if isinstance(train_record, GenerationTrainRecord): 29 | return train_record 30 | if isinstance(train_record, ClassificationTrainRecord): 31 | text = train_record.text 32 | labels = labels = ",".join(str(label) for label in train_record.labels) 33 | return GenerationTrainRecord(input=text, output=labels) 34 | error( 35 | "", 36 | TypeError( 37 | "Unsupported instance type. 
\ 38 | Only instances of datamodels ClassificationTrainRecord \ 39 | and GenerationTrainRecord are supported" 40 | ), 41 | ) 42 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/text_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/caikit_nlp/toolkit/text_generation/__init__.py -------------------------------------------------------------------------------- /caikit_nlp/toolkit/torch_run.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """This toolkit utility contains functions to operate on torch distribution 16 | 17 | NOTE: Content of this file are heavily influenced by torch/distributed/run.py 18 | 19 | Ref: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py 20 | """ 21 | 22 | # Standard 23 | import os 24 | 25 | # Third Party 26 | from torch import cuda 27 | from torch.distributed.launcher.api import LaunchConfig 28 | import torch.distributed as dist 29 | 30 | # First Party 31 | import alog 32 | 33 | log = alog.use_channel("TRCH_RN") 34 | 35 | 36 | def initialize_torch_distribution(world_size, rank=0, backend="gloo|nccl"): 37 | 38 | if dist.is_available(): 39 | log.debug( 40 | "Initializing process group - backend %s, rank %d, world size %d", 41 | backend, 42 | rank, 43 | world_size, 44 | ) 45 | dist.init_process_group(backend=backend, world_size=world_size, rank=rank) 46 | 47 | 48 | def determine_local_world_size(): 49 | """Function to automatically deduce the world size based on 50 | available processors. 51 | 52 | NOTE: This function will try to use ALL gpus accessible to it 53 | """ 54 | 55 | if cuda.is_available(): 56 | num_proc = cuda.device_count() 57 | log.info("Cuda devices available! Using %d devices.", num_proc) 58 | return num_proc 59 | # Fall back to using the OS cpu count 60 | # TODO: Callibrate this to some reasonable default... 61 | num_proc = os.cpu_count() 62 | log.info("Cuda devices NOT available! 
Using CPU %d processes.", num_proc) 63 | return num_proc 64 | 65 | 66 | def get_torch_elastic_launch_config( 67 | master_addr: str, 68 | master_port: str, 69 | start_method: str = "spawn", 70 | max_restarts=3, 71 | ) -> LaunchConfig: 72 | 73 | # Constants; we assume everything executes on the same node 74 | min_nodes = 1 75 | max_nodes = 1 76 | rdzv_configs = {"rank": 0} 77 | 78 | nproc_per_node = determine_local_world_size() 79 | 80 | if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1: 81 | omp_num_threads = 1 82 | log.warning( 83 | "\n*****************************************\n" 84 | "Setting OMP_NUM_THREADS environment variable for each process to be " 85 | "%s in default, to avoid your system being overloaded, " 86 | "please further tune the variable for optimal performance in " 87 | "your application as needed. \n" 88 | "*****************************************", 89 | omp_num_threads, 90 | ) 91 | # This env variable will be passed down to the subprocesses 92 | os.environ["OMP_NUM_THREADS"] = str(omp_num_threads) 93 | 94 | return LaunchConfig( 95 | min_nodes=min_nodes, 96 | max_nodes=max_nodes, 97 | nproc_per_node=nproc_per_node, 98 | start_method=start_method, 99 | rdzv_backend="static", 100 | rdzv_endpoint=f"{master_addr}:{master_port}", 101 | rdzv_configs=rdzv_configs, 102 | max_restarts=max_restarts, 103 | ) 104 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
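# Illustrative call sketch (the model name and module id below are hypothetical placeholders):
#
#   validate_training_data(train_stream, model_name="my-model", module_id="<module uuid>")
#
# is typically invoked before tuning to enforce the configured training-data limits, while
# custom HF trainers (e.g. the LoggingTrainer subclasses in the resources package) route
# their `log` calls through log_step to build a compact loss history.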
14 | """Contains toolkit functionality for huggingface Trainer""" 15 | # Standard 16 | from datetime import datetime 17 | 18 | # Third Party 19 | import torch 20 | 21 | # First Party 22 | from caikit import get_config 23 | from caikit.core.data_model import DataStream 24 | from caikit.core.exceptions import error_handler 25 | import alog 26 | 27 | log = alog.use_channel("TRNR_UTILS") 28 | error = error_handler.get(log) 29 | 30 | 31 | def validate_training_data(train_stream: DataStream, model_name: str, module_id: str): 32 | 33 | global_default = get_config().training_data_limit.__default__ 34 | module_default = ( 35 | get_config() 36 | .training_data_limit.get(module_id, {}) 37 | .get("__default__", global_default) 38 | ) 39 | 40 | max_num_examples = ( 41 | get_config() 42 | .training_data_limit.get(module_id, {}) 43 | .get(model_name, module_default) 44 | ) 45 | 46 | if max_num_examples > 0: 47 | train_stream_size = len(train_stream) 48 | error.value_check( 49 | "", 50 | train_stream_size <= max_num_examples, 51 | "Number of examples ({}) exceeds the maximum number of examples allowed " 52 | "({}) for this model", 53 | train_stream_size, 54 | max_num_examples, 55 | ) 56 | 57 | 58 | def log_step(state, logs): 59 | if state.epoch is not None: 60 | logs["epoch"] = round(state.epoch, 2) 61 | 62 | # Get Rank 63 | if torch.distributed.is_initialized(): 64 | rank = torch.distributed.get_rank() 65 | else: 66 | rank = 0 67 | 68 | if "loss" in logs: 69 | if state.epoch is not None: 70 | logs["epoch"] = round(state.epoch, 2) 71 | 72 | log.debug( 73 | "process rank: {} loss: {} step: {}".format( 74 | rank, float(logs["loss"]), state.global_step 75 | ) 76 | ) 77 | output = { 78 | "epoch": float(logs["epoch"]), 79 | "step": state.global_step, 80 | "value": float(logs["loss"]), 81 | "timestamp": datetime.isoformat(datetime.now()), 82 | } 83 | state.log_history.append(output) 84 | else: 85 | output = {**logs, **{"step": state.global_step}} 86 | state.log_history.append(output) 87 | 88 | return state 89 | -------------------------------------------------------------------------------- /caikit_nlp/toolkit/verbalizer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # Standard 15 | import re 16 | 17 | # First Party 18 | from caikit.core.exceptions import error_handler 19 | import alog 20 | 21 | log = alog.use_channel("VERBALIZER_UTIL") 22 | error = error_handler.get(log) 23 | 24 | 25 | def is_valid_verbalizer(verbalizer_template: str) -> bool: 26 | """Given a verbalizer template, determine if it's valid or not. We say a verbalizer 27 | template is valid if and only if we have at least one renderable field. 28 | 29 | Args: 30 | verbalizer_template: str 31 | Tentative verbalizer template to be used in text generation. 32 | Returns: 33 | bool: 34 | True if this is a valid verbalizer str with at least one renderable placeholder. 
35 | """ 36 | if not isinstance(verbalizer_template, str): 37 | return False 38 | return re.search(r"{{([_a-zA-Z0-9]+)}}", verbalizer_template) is not None 39 | 40 | 41 | def render_verbalizer(verbalizer_template: str, source_object) -> str: 42 | """Given a verbalizer template and a source object, replace all templates with keys or 43 | attributes from the source object. Templates are expected to follow Python variable name 44 | allowed chars, i.e., alphanumeric with underscores; if they don't, they're skipped. The 45 | contents of the template will be replaced with either keys on the source object, or attribute 46 | values. 47 | 48 | Templates should be in double brackets with no whitespace, e.g., {{label}}. Consider the 49 | following examples. 50 | 51 | Examples: 52 | 1. [Dictionary based] 53 | verbalizer_template = "Hello {{label}}" 54 | source_object = {"label": "world"} 55 | -> replace {{label}} with source_object["label"], producing "Hello world" 56 | returns: "Hello world" 57 | 2. [Object based] 58 | verbalizer_template = "Source: {{input}} Target: {{output}}" 59 | source_object = GenerationTrainRecord(input="machine", output="learning") 60 | -> replace {{input}} with getattr(source_object, "source") and replace 61 | {{output}} with getattr(source_object, "target"). 62 | returns: "Source: machine Target: learning" 63 | 64 | NOTE: This function will throw ValueError if you try to grab a key or property 65 | that is invalid. 66 | 67 | Args: 68 | verbalizer_template: str 69 | Verbalizer that we want to render object values into. 70 | 71 | Returns: 72 | str 73 | Verbalizer string with placeholders rendered. 74 | """ 75 | is_dict = isinstance(source_object, dict) 76 | 77 | def replace_text(match_obj: re.Match): 78 | captured_groups = match_obj.groups() 79 | if len(captured_groups) != 1: 80 | error( 81 | "", 82 | ValueError( 83 | "Unexpectedly captured multiple groups in verbalizer rendering" 84 | ), 85 | ) 86 | 87 | index_object = captured_groups[0] 88 | if is_dict: 89 | if index_object not in source_object: 90 | error( 91 | "", 92 | ValueError( 93 | f"Requested template string '{index_object}' is not a valid key in dict" 94 | ), 95 | ) 96 | return source_object[index_object] 97 | 98 | if not hasattr(source_object, index_object): 99 | error( 100 | "", 101 | ValueError( 102 | f"Requested template string '{index_object}' is not a valid property of type", 103 | ), 104 | ) 105 | return getattr(source_object, index_object) 106 | 107 | return re.sub(r"{{([_a-zA-Z0-9]+)}}", replace_text, verbalizer_template) 108 | -------------------------------------------------------------------------------- /caikit_nlp/version.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=unused-import 2 | try: 3 | # Local 4 | from ._version import __version__, __version_tuple__ 5 | except ImportError: 6 | __version__ = "unknown" 7 | version_tuple = (0, 0, __version__) 8 | -------------------------------------------------------------------------------- /code-of-conduct.md: -------------------------------------------------------------------------------- 1 | # Community Code of Conduct 2 | 3 | The Caikit project [has a common code of conduct that can be found here](https://github.com/caikit/community/blob/main/code-of-conduct.md). 
4 | -------------------------------------------------------------------------------- /examples/compare_local_vs_tgis_models.py: -------------------------------------------------------------------------------- 1 | # Utils imports should always come first, because they ensure caikit_nlp 2 | # is added to the syspath if not running inside of a container. 3 | # Standard 4 | import json 5 | import time 6 | 7 | # Third Party 8 | from datasets import load_dataset 9 | from utils import SUPPORTED_DATASETS, load_model 10 | import torch 11 | 12 | # First Party 13 | import caikit 14 | 15 | # Local 16 | from caikit_nlp import data_model 17 | import caikit_nlp 18 | 19 | NUM_SAMPLES_TO_RUN = 100 20 | 21 | model_path = "prompt_prefixes/sample_prompt" 22 | # Grab the test stream from the twitter dataset 23 | test_stream = SUPPORTED_DATASETS["twitter_complaints"].dataset_loader()[2] 24 | 25 | # Load the local block 26 | local_model = load_model(is_distributed=False, model_path=model_path) 27 | # Load the TGIS backed block; this will kill TGIS if a container is already running, then restart. 28 | distributed_model = load_model(is_distributed=True, model_path=model_path) 29 | 30 | preds = [] 31 | for datum in test_stream: 32 | dis_res = distributed_model.run(datum.input) 33 | local_res = local_model.run(datum.input) 34 | preds.append( 35 | { 36 | "input": datum.input, 37 | "output": datum.output, 38 | "local_prediction": local_res.text.split(":")[-1].strip(), 39 | "distributed_prediction": dis_res.text.split(":")[-1].strip(), 40 | } 41 | ) 42 | if len(preds) >= NUM_SAMPLES_TO_RUN: 43 | break 44 | 45 | 46 | with open("preds.json", "w") as f: 47 | json.dump(preds, f, sort_keys=True, indent=4) 48 | 49 | num_matching = 0 50 | num_local_correct = 0 51 | num_distributed_correct = 0 52 | num_mismatching = 0 53 | for x in preds: 54 | if x["output"] == x["local_prediction"]: 55 | num_local_correct += 1 56 | 57 | if x["output"] == x["distributed_prediction"]: 58 | num_distributed_correct += 1 59 | 60 | if x["local_prediction"] == x["distributed_prediction"]: 61 | num_matching += 1 62 | else: 63 | num_mismatching += 1 64 | 65 | print("----- Metrics -----") 66 | print("Num correct [local block via PEFT]: {}".format(num_local_correct)) 67 | print("Num correct [distributed via TGIS]: {}".format(num_distributed_correct)) 68 | print("Num matching remote / local preds: {}".format(num_matching)) 69 | print("Num not matching remote / local preds: {}".format(num_mismatching)) 70 | -------------------------------------------------------------------------------- /examples/evaluate_model.py: -------------------------------------------------------------------------------- 1 | """Given a trained model, which was presumably created by run_peft_tuning.py, 2 | load it and run evaluation. 3 | """ 4 | # Standard 5 | import argparse 6 | import json 7 | import pathlib 8 | 9 | # Third Party 10 | from tqdm import tqdm 11 | from utils import ( 12 | SUPPORTED_DATASETS, 13 | SUPPORTED_METRICS, 14 | configure_random_seed_and_logging, 15 | get_wrapped_evaluate_metric, 16 | is_float, 17 | kill_tgis_container_if_exists, 18 | load_model, 19 | print_colored, 20 | string_to_float, 21 | ) 22 | 23 | # Local 24 | from caikit_nlp.toolkit.verbalizer_utils import render_verbalizer 25 | 26 | 27 | def parse_args() -> argparse.Namespace: 28 | """Parse & validate command line arguments. 29 | 30 | Returns: 31 | argparse.Namespace 32 | Parsed arguments to be leveraged model evaluation. 
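Example invocation (illustrative; the model path and dataset below are placeholders):

            python evaluate_model.py --model_path prompt_prefixes/sample_prompt \
                --dataset twitter_complaints --metrics accuracy --max_new_tokens 20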
33 | """ 34 | parser = argparse.ArgumentParser( 35 | description="Evaluate a text generation model.", 36 | ) 37 | # TODO - Patch text-generation-launcher model var so that we can't mount the wrong model 38 | parser.add_argument( 39 | "--tgis", 40 | help="If enabled, runs inference through TGIS instead the local .run().", 41 | action="store_true", 42 | ) 43 | parser.add_argument( 44 | "--model_path", 45 | help="Model to be loaded from disk.", 46 | type=pathlib.Path, 47 | required=True, 48 | ) 49 | parser.add_argument( 50 | "--dataset", 51 | help="Dataset to use to train prompt vectors. Options: {}".format( 52 | list(SUPPORTED_DATASETS.keys()) 53 | ), 54 | default="twitter_complaints", 55 | ) 56 | parser.add_argument( 57 | "--metrics", 58 | help="Metrics to calculate in space delimited list", 59 | default=["accuracy"], 60 | nargs="*", 61 | choices=list(SUPPORTED_METRICS.keys()), 62 | ) 63 | parser.add_argument( 64 | "--preds_file", 65 | help="JSON file to dump raw source / target texts to.", 66 | default="model_preds.json", 67 | ) 68 | parser.add_argument( 69 | "--max_new_tokens", 70 | help="Maximum number of new tokens to be generated", 71 | type=int, 72 | default=20, 73 | ) 74 | parser.add_argument( 75 | "--truncate_input_tokens", 76 | help="Number of allowed input tokens (no truncation=0)", 77 | type=int, 78 | default=0, 79 | ) 80 | args = parser.parse_args() 81 | return args 82 | 83 | 84 | def get_model_preds_and_references( 85 | model, validation_stream, max_new_tokens, truncate_input_tokens 86 | ): 87 | """Given a model & a validation stream, run the model against every example in the validation 88 | stream and compare the outputs to the target/output sequence. 89 | 90 | Args: 91 | model 92 | Peft Model to be evaluated (may leverage different backends). 93 | validation_stream: DataStream[GenerationTrainRecord] 94 | Validation stream with labeled targets that we want to compare to our model's 95 | predictions. 96 | max_new_tokens: int 97 | Max number of new tokens to be generated, i.e., output limit 98 | truncate_input_tokens: int 99 | Number of allowed input tokens, i.e., input limit 100 | 101 | Returns: 102 | Tuple(List) 103 | Tuple of 2 lists; the model predictions and the expected output sequences. 104 | """ 105 | model_preds = [] 106 | targets = [] 107 | 108 | for datum in tqdm(validation_stream): 109 | # Local .run() currently prepends the input text to the generated string; 110 | # Ensure that we're just splitting the first predicted token & beyond. 111 | raw_model_text = model.run( 112 | datum.input, 113 | max_new_tokens=max_new_tokens, 114 | truncate_input_tokens=truncate_input_tokens, 115 | ).generated_text 116 | parse_pred_text = raw_model_text.split(datum.input)[-1].strip() 117 | model_preds.append(parse_pred_text) 118 | targets.append(datum.output) 119 | return ( 120 | model_preds, 121 | targets, 122 | ) 123 | 124 | 125 | def export_model_preds(preds_file, predictions, validation_stream, verbalizer): 126 | """Exports a JSON file containing a list of objects, where every object contains: 127 | - source: str - Source string used for generation. 128 | - target: str - Ground truth target label used for generation. 129 | - verbalized_source: str - Source string after model verbalization 130 | - predicted_target: str - Predicted model target. 131 | 132 | Args: 133 | preds_file: str 134 | Path on disk to JSON file to be written. 135 | predictions: List 136 | Model prediction list, where each predicted text excludes source text as a prefix. 
137 | validation_stream: DataStream 138 | Datastream object of GenerationTrainRecord objects used for validation against a model 139 | to generate predictions. 140 | verbalizer: str 141 | Model verbalizer used for generating target predictions. 142 | """ 143 | pred_objs = [] 144 | for pred, record in zip(predictions, validation_stream): 145 | res = { 146 | "source": record.input, 147 | "target": record.output, 148 | "predicted_target": pred, 149 | } 150 | if verbalizer is not None: 151 | res["verbalized_source"] = render_verbalizer(verbalizer, record) 152 | pred_objs.append(res) 153 | 154 | with open(preds_file, "w") as jfile: 155 | json.dump(pred_objs, jfile, indent=4, sort_keys=True) 156 | 157 | 158 | if __name__ == "__main__": 159 | configure_random_seed_and_logging() 160 | args = parse_args() 161 | metric_funcs = [SUPPORTED_METRICS[metric] for metric in args.metrics] 162 | print_colored("Metrics to be calculated: {}".format(args.metrics)) 163 | 164 | # Load the model; this can be a local model, or a distributed TGIS instance 165 | print_colored("Loading the model...") 166 | model = load_model(args.tgis, str(args.model_path)) 167 | # Load the validation stream with marked target sequences 168 | print_colored("Grabbing validation data...") 169 | dataset_info = SUPPORTED_DATASETS[args.dataset] 170 | validation_stream = dataset_info.dataset_loader()[1] 171 | if validation_stream is None: 172 | raise ValueError( 173 | "Selected dataset does not have a validation dataset available!" 174 | ) 175 | 176 | # Run the data through the model; save the predictions & references 177 | print_colored("Getting model predictions...") 178 | predictions, references = get_model_preds_and_references( 179 | model, validation_stream, args.max_new_tokens, args.truncate_input_tokens 180 | ) 181 | print_colored( 182 | "Exporting model preds, source, verbalized source, and ground truth targets to {}".format( 183 | args.preds_file 184 | ) 185 | ) 186 | export_model_preds( 187 | args.preds_file, 188 | predictions, 189 | validation_stream, 190 | getattr(model, "verbalizer", None), 191 | ) 192 | 193 | for metric_func in metric_funcs: 194 | metric_res = metric_func(predictions=predictions, references=references) 195 | print_colored(metric_res) 196 | # If we started a TGIS instance, kill it; otherwise, leave our container alone. 197 | # TODO: This will still looks for containers to kill, even if you're running TGIS 198 | # outside of a container through text-generation-server directly. For now, we are 199 | # always running TGIS in a container, so it's ok; the worst that will happen is 200 | # you'll kill somebody else's container. 201 | if args.tgis: 202 | print_colored("Killing containerized TGIS instance...") 203 | kill_tgis_container_if_exists() 204 | -------------------------------------------------------------------------------- /examples/kill-text-generation-launcher.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Kills the text generation launcher container if it's running 3 | running_container_id=$(docker container ls | grep -i text-gen-server | cut -d " " -f 1) 4 | if [ -z "$running_container_id" ]; then 5 | echo "TGIS container is not running; nothing to do!" 6 | else 7 | echo "Trying to kill TGIS container with id: {$running_container_id}..." 
8 | eval "docker stop $running_container_id" 9 | fi 10 | -------------------------------------------------------------------------------- /examples/load_and_run_distributed_peft.py: -------------------------------------------------------------------------------- 1 | """This script loads and runs a sample PEFT model as a caikit module using 2 | the TGIS backend. 3 | 4 | In a nutshell, this does the following: 5 | - Check if `text-generation-launcher` is defined; if it doesn't, assume we have a Docker image 6 | that is (currently hardcoded to be) able to run TGIS & expose the proper ports, and patch a 7 | wrapper script around it onto our path so that the TGIS backend falls back to leveraging it 8 | - Load the model through caikit 9 | - Run an inference text generation request and dump the (garbage) output back to the console 10 | """ 11 | # Standard 12 | from shutil import which 13 | import os 14 | import subprocess 15 | import sys 16 | 17 | # First Party 18 | from caikit.core.module_backend_config import _CONFIGURED_BACKENDS, configure 19 | from caikit_tgis_backend import TGISBackend 20 | import alog 21 | import caikit 22 | 23 | # Local 24 | import caikit_nlp 25 | 26 | alog.configure("debug4") 27 | 28 | PREFIX_PATH = "prompt_prefixes" 29 | 30 | has_text_gen = which("text-generation-launcher") 31 | if not which("text-generation-launcher"): 32 | print("Text generation server command not found; using Docker override") 33 | this_dir = os.path.dirname(os.path.abspath(__file__)) 34 | os.environ["PATH"] += ":" + this_dir 35 | assert ( 36 | which("text-generation-launcher") is not None 37 | ), "Text generation script not found!" 38 | 39 | # Configure caikit to prioritize TGIS backend 40 | _CONFIGURED_BACKENDS.clear() 41 | # load_timeout: 320 42 | # grpc_port: null 43 | # http_port: 3001 44 | # health_poll_delay: 1.0 45 | caikit.configure( 46 | config_dict={"module_backends": {"priority": [TGISBackend.backend_type]}} 47 | ) # should not be necessary but just in case 48 | configure() # backend configure 49 | 50 | # Load with TGIS backend 51 | prefix_model_path = os.path.join(PREFIX_PATH, "sample_prompt") 52 | my_model = caikit.load(prefix_model_path) 53 | sample_text = "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand > different sizing" 54 | sample_output = my_model.run(sample_text) 55 | 56 | print("---------- Model result ----------") 57 | print(sample_output) 58 | -------------------------------------------------------------------------------- /examples/text-generation-launcher: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # This script is primarily meant for illustrative purposes; if we don't have 3 | # the text-generation-launcher command locally available, but we do have a Docker 4 | # container, we add this script onto our path so when the TGIS backend in caikit 5 | # tries to start the server, it runs this script instead. 6 | # 7 | # NOTE: 8 | # - Model ID, directories, etc are hardcoded to our example, params from the backend, 9 | # e.g., shard configuration, are ignored. 10 | # 11 | # - We need to export port 3000 (for probes in core distributed), and we forward 8033->50055 12 | # so that our gRPC server is exposed on the expected port for local TGIS. 
13 | TGIS_MODEL="${MODEL_NAME:-bigscience/bloom-560m}" 14 | MODEL_DIR="${MODEL_DIR:-models}" 15 | echo "Running TGIS with model: $TGIS_MODEL" 16 | 17 | docker run --rm \ 18 | --gpus '"device=0"' \ 19 | -p 8090:8090 \ 20 | -p 8085:8085 \ 21 | -p 8060:8060 \ 22 | -p 8087:8087 \ 23 | -p 50055:8033 \ 24 | -p 3000:3000 \ 25 | -v $(pwd)/${MODEL_DIR}:/models \ 26 | -v $(pwd)/../runtime_config.yaml:/conf/runtime_config.yaml \ 27 | -v $(pwd)/transformers_cache:/shared_model_storage/transformers_cache \ 28 | -v $(pwd)/prompt_prefixes:/prompt_prefixes \ 29 | -e LOG_LEVEL=debug3 \ 30 | -e ACCEPT_LICENSE=true \ 31 | -e INFERENCE_PLUGIN_MODEL_MESH_MAX_MODEL_CONCURRENCY=10 \ 32 | -e RUNTIME_SERVER_THREAD_POOL_SIZE=10 \ 33 | -e INFERENCE_PLUGIN_MODEL_MESH_CAPACITY=28000000000 \ 34 | -e INFERENCE_PLUGIN_MODEL_MESH_DEFAULT_MODEL_SIZE=1773741824 \ 35 | -e CONFIG_FILES="/conf/runtime_config.yaml" \ 36 | -e RUNTIME_LOCAL_MODELS_DIR="/models/" \ 37 | -e MAX_BATCH_SIZE=8 \ 38 | -e MAX_SEQUENCE_LENGTH=2048 \ 39 | -e NUM_GPUS=1 \ 40 | -e TRANSFORMERS_CACHE="/shared_model_storage/transformers_cache" \ 41 | -e HUGGINGFACE_HUB_CACHE="/shared_model_storage/transformers_cache" \ 42 | -e MAX_CONCURRENT_REQUESTS=64 \ 43 | -e GATEWAY_PORT=8060 \ 44 | -e RUNTIME_PORT=8087 \ 45 | -e MODEL_NAME=$TGIS_MODEL \ 46 | -e PREFIX_STORE_PATH="/prompt_prefixes" \ 47 | --user root \ 48 | text-gen-server:server-release_ubi8_py38 49 | -------------------------------------------------------------------------------- /prompt_tuning_parameter_selection.md: -------------------------------------------------------------------------------- 1 | # Prompt Tuning Parameters 2 | 3 | ## Parameters for selection of training algorithm 4 | - `base_model` 5 | - **Description**: Path to the base model or `caikit.Resource` object of the base model to be used for tuning. A model-name string may also be provided. In this case, the Transformers API will automatically load it from the Hugging Face model cache if the model is available locally. If it is not available, the model may be downloaded by setting the `ALLOW_DOWNLOADS` environment variable to `true`. 6 | - **Accepted values**: 7 | - The model needs to be of type causal-lm or seq2seq, and thus loadable via the HuggingFace `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` loading methods. 8 | - `tuning_type`: (`str` or `caikit_nlp.modules.text_generation.TuningType`) 9 | - Type of PEFT tuning config that we would like to build. 10 | - **Accepted values**: `PROMPT_TUNING` and `MULTITASK_PROMPT_TUNING` 11 | - **Default**: `PROMPT_TUNING` 12 | - `num_epochs`: (int) 13 | - The number of epochs is the number of complete passes through the training dataset. 14 | - Tuning quality depends heavily on the number of epochs. 15 | - **Expose to end user recommendation**: True 16 | - Accepted values: any positive int 17 | - **Default**: 20 18 | - `device`: (str) 19 | - Training device to be used; may be `cpu`, `cuda`, or a specific CUDA device name. 20 | - **Expose to end user recommendation**: False 21 | - `lr` / `learning_rate`: (float) The parameter name will soon be changed to make it more intuitive. 22 | - Learning rate to be used for training. 23 | - **Expose to end user recommendation**: True 24 | - `accumulate_steps`: 25 | - Number of steps to be used for gradient accumulation. Gradient accumulation refers to collecting gradients over a configured number of steps and then applying the accumulated update to the model variables, instead of updating them at every step.
This can be used to overcome the limitations of a smaller batch size, and is often referred to in conjunction with the "effective batch size". 26 | - **Expose to end user recommendation**: True 27 | - `verbalizer` 28 | - Verbalizer template to be used for formatting data at train and inference time. This template may use brackets to indicate where fields from the data model GenerationTrainRecord must be rendered. Default: "{{input}}", i.e., the raw text. 29 | - **Default**: "{{input}}", i.e., the raw text. 30 | - **Expose to end user recommendation**: True 31 | - `batch_size`: 32 | - The batch size is the number of samples processed before the model is updated. 33 | - **Default**: 8 34 | - **Expose to end user recommendation**: True 35 | - `max_source_length`: 36 | - Max length of input sequences being considered. 37 | - **Default**: 256 38 | - **Expose to end user recommendation**: True 39 | - `max_target_length`: 40 | - Max length of target sequences being predicted. 41 | - **Default**: 128 42 | - **Expose to end user recommendation**: True 43 | - `tuning_config.num_virtual_tokens`: 44 | - Number of virtual tokens to be used for training. In prompt tuning we are essentially learning the embedded representations for soft prompts, known as virtual tokens, via backpropagation for a specific task (or tasks) while keeping the rest of the model fixed. `num_virtual_tokens` is the number of these virtual tokens, i.e., the length of the learned soft prompt. 45 | - **Expose to end user recommendation**: True (default to be set by the application layer) 46 | - If a source prompt exists, this should also correspond to it, i.e., a user needs to select the number of virtual tokens to match the available source prompt if they want to use MPT source prompts. 47 | 48 | - `tuning_config.prompt_tuning_init_method`: 49 | - Could be: `RANDOM`, `TEXT`, `ONLY_SOURCE_SHARED` or `AVERAGE_SOURCE` 50 | - `TEXT` requires `tuning_config.prompt_tuning_init_text` to be set 51 | - `ONLY_SOURCE_SHARED` and `AVERAGE_SOURCE` require `tuning_config.prompt_tuning_init_source_model` to be set and a source prompt model to be available for the given `base_model` 52 | - **Default**: `RANDOM` 53 | - **Expose to end user recommendation**: True 54 | - Only `RANDOM`, `TEXT` and `AVERAGE_SOURCE` should be exposed, where `AVERAGE_SOURCE` is only applicable when the tuning type is `MULTITASK_PROMPT_TUNING` 55 | - `tuning_config.prompt_tuning_init_text`: 56 | - Initialization text to be used **IF** `tuning_config.prompt_tuning_init_method` is set to `TEXT`; otherwise this will be ignored. 57 | - **Default**: No default. 58 | - **Expose to end user recommendation**: True (if the `TEXT` init_method is exposed to customers) 59 | - `tuning_config.prompt_tuning_init_source_model`: 60 | - Path pointing to the source prompt model. This path is relative to `config.source_prompt_base` (or the `SOURCE_PROMPT_BASE` env variable) 61 | - The source model selection needs to correspond to the `base_model`. 62 | - There may be cases where we have multiple source prompts available for a given model, in which case their selection criteria need to be determined. 63 | - **Default**: Depends on the `base_model`. If `MULTITASK_PROMPT_TUNING` is not selected as the tuning type, then this will be ignored. 64 | - `tuning_config.output_model_types`: `List(str)` 65 | - A list containing the strings `ENCODER` and/or `DECODER`.
66 | - Acceptable values for types of models: 67 | - CausalLM: `["DECODER"]` 68 | - Seq2Seq: `["ENCODER"]`, `["DECODER"]`, `["ENCODER", "DECODER"]` 69 | - **Default**: 70 | - CausalLM: `[DECODER]` 71 | - Seq2Seq: `[ENCODER]` 72 | - **Expose to end user recommendation**: False 73 | 74 | - `torch_dtype`: (str) 75 | - Datatype to use for training of the underlying text generation model. If no value is provided, we pull from torch_dtype in config. If an in memory resource is provided which does not match the specified data type, the model underpinning the resource will be converted in place to the correct torch dtype. 76 | - **Expose to end user recommendation**: False 77 | - Recommended to be configured at environment or server configuration level. 78 | - `silence_progress_bars` (bool) 79 | - Toggle to control progress bar for training. This is relevant to only "python user experience" and doesn't apply training via caikit runtime. 80 | - **Expose to end user recommendation**: False 81 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "setuptools>=60", 4 | "setuptools-scm>=8.0"] 5 | 6 | [project] 7 | name = "caikit-nlp" 8 | dynamic = ["version"] 9 | description = "Caikit NLP" 10 | license = {text = "Apache-2.0"} 11 | readme = "README.md" 12 | requires-python = "~=3.9" 13 | classifiers=[ 14 | "License :: OSI Approved :: Apache Software License" 15 | ] 16 | dependencies = [ 17 | "caikit[runtime-grpc,runtime-http]>=0.26.34,<0.28.0", 18 | "caikit-tgis-backend>=0.1.36,<0.2.0", 19 | # TODO: loosen dependencies 20 | "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking 21 | "grpcio-reflection>=1.62.2", 22 | "grpcio-health-checking>=1.62.2", 23 | "accelerate>=1.3.0", 24 | "datasets>=3.3.0", 25 | "huggingface-hub", 26 | "numpy>=1.23.0", 27 | "pandas>=2.2.3", 28 | "scikit-learn>=1.6.1", 29 | "scipy>=1.10.0", 30 | "sentence-transformers>=3.4.0,<3.5.0", 31 | "tokenizers>=0.20.0", 32 | "torch>=2.3.1,<2.7.0", 33 | "tqdm>=4.67.0", 34 | "transformers>=4.48.3,<4.50.0", 35 | "peft==0.14.0", 36 | ] 37 | 38 | 39 | [tool.setuptools.packages.find] 40 | exclude = ["tests", "tests.*"] 41 | namespaces = false 42 | 43 | 44 | [tool.setuptools_scm] 45 | version_file = "caikit_nlp/_version.py" 46 | 47 | [project.urls] 48 | Source = "https://github.com/caikit/caikit-nlp" 49 | -------------------------------------------------------------------------------- /runtime_config.yaml: -------------------------------------------------------------------------------- 1 | # its contents configure the TGIS server & caikit 2 | jvm_options: [] 3 | 4 | runtime: 5 | library: caikit_nlp 6 | lazy_load_local_models: true 7 | batching: 8 | standalone-model: 9 | size: 0 # Set to batch size for batching 10 | 11 | model_management: 12 | finders: 13 | default: 14 | type: LOCAL 15 | remote_tgis: 16 | type: TGIS-AUTO 17 | config: 18 | test_connection: true 19 | initializers: 20 | default: 21 | type: LOCAL 22 | config: 23 | backend_priority: 24 | - type: TGIS 25 | config: 26 | local: 27 | load_timeout: 120 28 | grpc_port: null 29 | http_port: null 30 | health_poll_delay: 1.0 31 | remote_models: 32 | flan-t5-xl: 33 | hostname: localhost:8033 34 | prompt_dir: tgis_prompts 35 | llama-70b: 36 | hostname: localhost:8034 37 | prompt_dir: tgis_prompts 38 | 39 | connection: 40 | hostname: "foo.{model_id}:1234" 41 | ca_cert_file: null 42 | 
client_cert_file: null 43 | client_key_file: null 44 | 45 | # Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32 46 | embedding: 47 | # Allow models with remote code. 48 | trust_remote_code: false 49 | # Number of times to retry on error. Most deployments should use 0 retries. 50 | retries: 0 51 | # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used 52 | batch_size: 0 53 | # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this 54 | implicit_truncation_errors: true 55 | # Attempt to optimize with PyTorch compile() 56 | pt2_compile: false 57 | # Use IPEX optimize. Works best when used with autocast (bfloat16) below. 58 | ipex: false 59 | # Use autocast in encode with its default dtype (bfloat16) 60 | autocast: false 61 | # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU. 62 | # Otherwise, the default does automatic checks for cuda GPU (else cpu). 63 | device: "" 64 | -------------------------------------------------------------------------------- /runtime_template/run_with_gateway.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ################################################################################ 4 | # This script is the entrypoint for the multi-process runtime container that 5 | # runs the REST gateway alongside the grpc runtime. 6 | # The multiprocess management is intended to be handled by `tini`, the tiny 7 | # but valid `init`. 8 | ################################################################################ 9 | 10 | set -e 11 | 12 | echo '[STARTING RUNTIME]' 13 | cd /app && python3 -m caikit.runtime.grpc_server & 14 | 15 | RUNTIME_PORT=${RUNTIME_PORT:-8085} 16 | 17 | # If TLS enabled, make an https call, otherwise make an http call 18 | protocol="http" 19 | if [ "${RUNTIME_TLS_SERVER_KEY}" != "" ] && [ "${RUNTIME_TLS_SERVER_CERT}" != "" ] 20 | then 21 | protocol="--cacert $RUNTIME_TLS_SERVER_CERT https" 22 | if [ "${RUNTIME_TLS_CLIENT_CERT}" != "" ] 23 | then 24 | protocol="-k --cert $RUNTIME_TLS_SERVER_CERT --key $RUNTIME_TLS_SERVER_KEY https" 25 | fi 26 | fi 27 | 28 | # Wait for the Runtime to come up before starting the gateway 29 | sleep 3 30 | until $(curl --output /dev/null --silent --fail ${protocol}://localhost:${RUNTIME_PORT}); do 31 | echo '.' 32 | sleep 1 33 | done 34 | 35 | echo '[STARTING GATEWAY]' 36 | PROXY_ENDPOINT="localhost:${RUNTIME_PORT}" SERVE_PORT=${GATEWAY_PORT:-8080} /gateway --swagger_path=/swagger & 37 | 38 | wait -n 39 | -------------------------------------------------------------------------------- /scripts/dump_apis.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Make a directory with interfaces 4 | http_interface_dir="generated_interfaces/http" 5 | grpc_interface_dir="generated_interfaces/grpc" 6 | mkdir -p $http_interface_dir 7 | mkdir -p $grpc_interface_dir 8 | 9 | # Run the HTTP server in the background 10 | RUNTIME_LIBRARY=caikit_nlp python -m caikit.runtime.http_server & 11 | http_pid=$! 
12 | 13 | # Sleep for a bit and then call it to get the swagger doc 14 | sleep 5 15 | curl http://localhost:8080/openapi.json | jq > $http_interface_dir/openapi.json 16 | 17 | # Kill the HTTP server and wait for it to die 18 | kill -9 $http_pid 19 | wait 20 | 21 | # Dump the gRPC interfaces 22 | RUNTIME_LIBRARY=caikit_nlp python -m caikit.runtime.dump_services $grpc_interface_dir -------------------------------------------------------------------------------- /scripts/fmt.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | pre-commit run --all-files 4 | RETURN_CODE=$? 5 | 6 | function echoWarning() { 7 | LIGHT_YELLOW='\033[1;33m' 8 | NC='\033[0m' # No Color 9 | echo -e "${LIGHT_YELLOW}${1}${NC}" 10 | } 11 | 12 | if [ "$RETURN_CODE" -ne 0 ]; then 13 | if [ "${CI}" != "true" ]; then 14 | echoWarning "☝️ This appears to have failed, but actually your files have been formatted." 15 | echoWarning "Make a new commit with these changes before making a pull request." 16 | else 17 | echoWarning "This test failed because your code isn't formatted correctly." 18 | echoWarning 'Locally, run `make run fmt`, it will appear to fail, but change files.' 19 | echoWarning "Add the changed files to your commit and this stage will pass." 20 | fi 21 | 22 | exit $RETURN_CODE 23 | fi 24 | -------------------------------------------------------------------------------- /scripts/run_local.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | cd $(dirname ${BASH_SOURCE[0]})/.. 3 | mkdir -p models 4 | 5 | server=${SERVER:-"http"} 6 | 7 | CONFIG_FILES=runtime_config.yaml \ 8 | LOG_LEVEL=${LOG_LEVEL:-debug3} \ 9 | LOG_FORMATTER=pretty \ 10 | python -m caikit.runtime.${server}_server 11 | -------------------------------------------------------------------------------- /setup_requirements.txt: -------------------------------------------------------------------------------- 1 | tox>=4.4.2,<5 2 | build>=0.10.0,<2.0 -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """ 15 | This sets up global test configs when pytest starts 16 | """ 17 | 18 | # Standard 19 | import os 20 | 21 | # First Party 22 | import alog 23 | 24 | # Configure logging from the environment 25 | alog.configure( 26 | default_level=os.environ.get("LOG_LEVEL", "off"), 27 | filters=os.environ.get("LOG_FILTERS", "urllib3:off"), 28 | thread_id=os.environ.get("LOG_THREAD_ID", "") == "true", 29 | ) 30 | -------------------------------------------------------------------------------- /tests/data_model/test_generation.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Local 16 | from caikit_nlp.data_model import ExponentialDecayLengthPenalty 17 | 18 | ## Setup ######################################################################### 19 | 20 | dummy_exponential_decay_length_penalty = ExponentialDecayLengthPenalty( 21 | start_index=1, decay_factor=0.95 22 | ) 23 | 24 | ## Tests ######################################################################## 25 | 26 | ### Exponential Decay Length Penalty 27 | def test_exponential_decay_length_penalty_all_fields_accessible(): 28 | assert dummy_exponential_decay_length_penalty.start_index == 1 29 | assert dummy_exponential_decay_length_penalty.decay_factor == 0.95 30 | 31 | 32 | def test_sampling_parameters_from_proto_and_back(): 33 | new = ExponentialDecayLengthPenalty.from_proto( 34 | dummy_exponential_decay_length_penalty.to_proto() 35 | ) 36 | assert new.start_index == 1 37 | assert new.decay_factor == 0.95 38 | 39 | 40 | def test_sampling_parameters_from_json_and_back(): 41 | new = ExponentialDecayLengthPenalty.from_json( 42 | dummy_exponential_decay_length_penalty.to_json() 43 | ) 44 | assert new.start_index == 1 45 | assert new.decay_factor == 0.95 46 | -------------------------------------------------------------------------------- /tests/fixtures/data_model/sample_objects.py: -------------------------------------------------------------------------------- 1 | # First Party 2 | from caikit.interfaces.nlp.data_model import FinishReason, GeneratedTextResult 3 | 4 | generated_response = GeneratedTextResult( 5 | generated_text="foo bar", 6 | generated_tokens=2, 7 | finish_reason=FinishReason.STOP_SEQUENCE, 8 | ) 9 | 10 | # Add an example of one of each new data model types that you've defined within your extension 11 | # to the list below. These samples are used for verifying that your object is well-aligned with 12 | # with caikit serialization interfaces. They will only be used if your extension enables 13 | # protobuf serialization in its config. 14 | # 15 | # NOTE: You do not need to add any samples from other extensions or the caikit library; only 16 | # new types explicitly created in your extension data model. 
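#
# As a rough illustration (a sketch, not part of the shared harness itself), being
# "well-aligned with caikit serialization interfaces" means each sample listed below
# should survive a round trip through its serialized forms, for example:
#   assert GeneratedTextResult.from_proto(generated_response.to_proto()).generated_text == "foo bar"
#   assert GeneratedTextResult.from_json(generated_response.to_json()).generated_tokens == 2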
17 | data_model_samples = [ 18 | generated_response, 19 | ] 20 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BertForSequenceClassification/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "./bert/BertForSequenceClassification", 3 | "architectures": [ 4 | "BertForSequenceClassification" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "classifier_dropout": null, 8 | "hidden_act": "gelu", 9 | "hidden_dropout_prob": 0.1, 10 | "hidden_size": 32, 11 | "initializer_range": 0.02, 12 | "intermediate_size": 37, 13 | "layer_norm_eps": 1e-12, 14 | "max_position_embeddings": 512, 15 | "model_type": "bert", 16 | "num_attention_heads": 4, 17 | "num_hidden_layers": 5, 18 | "pad_token_id": 0, 19 | "position_embedding_type": "absolute", 20 | "torch_dtype": "float32", 21 | "transformers_version": "4.30.2", 22 | "type_vocab_size": 16, 23 | "use_cache": true, 24 | "vocab_size": 1124 25 | } 26 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BertForSequenceClassification/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/fixtures/tiny_models/BertForSequenceClassification/pytorch_model.bin -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BertForSequenceClassification/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "cls_token": "[CLS]", 3 | "mask_token": "[MASK]", 4 | "pad_token": "[PAD]", 5 | "sep_token": "[SEP]", 6 | "unk_token": "[UNK]" 7 | } 8 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BertForSequenceClassification/tf_model.h5: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/fixtures/tiny_models/BertForSequenceClassification/tf_model.h5 -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BertForSequenceClassification/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "clean_up_tokenization_spaces": true, 3 | "cls_token": "[CLS]", 4 | "do_basic_tokenize": true, 5 | "do_lower_case": true, 6 | "mask_token": "[MASK]", 7 | "model_max_length": 512, 8 | "never_split": null, 9 | "pad_token": "[PAD]", 10 | "sep_token": "[SEP]", 11 | "strip_accents": null, 12 | "tokenize_chinese_chars": true, 13 | "tokenizer_class": "BertTokenizer", 14 | "unk_token": "[UNK]" 15 | } 16 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BloomForCausalLM/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "apply_residual_connection_post_layernorm": false, 3 | "architectures": ["BloomForCausalLM"], 4 | "attention_dropout": 0.1, 5 | "bos_token_id": 1, 6 | "dtype": "float32", 7 | "eos_token_id": 2, 8 | "gradient_checkpointing": false, 9 | "hidden_dropout": 0.1, 10 | "hidden_size": 32, 11 | "id2label": { 12 | "0": "LABEL_0", 13 | "1": "LABEL_1", 14 | "2": "LABEL_2" 15 | }, 16 | "initializer_range": 0.02, 17 | "is_decoder": true, 18 | "label2id": { 
19 | "LABEL_0": 0, 20 | "LABEL_1": 1, 21 | "LABEL_2": 2 22 | }, 23 | "layer_norm_epsilon": 1e-5, 24 | "model_type": "bloom", 25 | "n_head": 4, 26 | "n_layer": 5, 27 | "n_positions": 512, 28 | "pad_token_id": 3, 29 | "pretraining_tp": 1, 30 | "seq_length": 7, 31 | "slow_but_exact": true, 32 | "torch_dtype": "float32", 33 | "transformers_version": "4.27.1", 34 | "type_vocab_size": 16, 35 | "use_cache": true, 36 | "vocab_size": 1024 37 | } 38 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BloomForCausalLM/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 1, 4 | "eos_token_id": 2, 5 | "pad_token_id": 3, 6 | "transformers_version": "4.27.1" 7 | } 8 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BloomForCausalLM/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/fixtures/tiny_models/BloomForCausalLM/pytorch_model.bin -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BloomForCausalLM/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": "", 3 | "eos_token": "", 4 | "pad_token": "", 5 | "unk_token": "" 6 | } 7 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/BloomForCausalLM/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": "", 4 | "eos_token": "", 5 | "model_max_length": 1000000000000000019884624838656, 6 | "pad_token": "", 7 | "padding_side": "left", 8 | "special_tokens_map_file": null, 9 | "tokenizer_class": "BloomTokenizer", 10 | "unk_token": "" 11 | } 12 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/README.md: -------------------------------------------------------------------------------- 1 | ### Tiny Models for Testing 2 | 3 | The models in this directory were created using the [Transformers utils for creating tiny models](https://github.com/huggingface/transformers/blob/main/utils/create_dummy_models.py). 4 | 5 | To create a new dummy model, all you need to do is clone the transformers repo and run a command like those shown below, which were used to create the artifacts checked in here. 6 | 7 | - Bloom: `python3 utils/create_dummy_models.py --model_types bloom $OUTPUT_DIR` 8 | - T5: `python3 utils/create_dummy_models.py --model_types t5 $OUTPUT_DIR` 9 | - BERT: `python3 utils/create_dummy_models.py --model_types bert $OUTPUT_DIR` 10 | 11 | This will create several dummy models; you most likely want to place the ones you need in this directory and leverage them in `__init__.py` for the fixtures. 12 | 13 | Note: If you encounter any strangeness when running the script above, be sure to check the version of transformers in your site packages; there appear to be some dynamic imports under the hood that can pull code from your site packages instead of the cloned source, which can cause problems if the two versions differ significantly.
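As a quick, optional sanity check after copying a new tiny model into this directory, you can try loading it directly with `transformers` before wiring it into the fixtures. This is a minimal sketch; the path below assumes the `BloomForCausalLM` fixture checked in here, and the generated text will be garbage since the weights are random:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed path to one of the checked-in tiny models; adjust to the model you added
path = "tests/fixtures/tiny_models/BloomForCausalLM"

tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

# We only care that loading and generation run end to end without errors
inputs = tokenizer("hello tiny model", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```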
14 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/T5ForConditionalGeneration/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "DUMMY_MODELS/t5/T5ForConditionalGeneration", 3 | "architectures": ["T5ForConditionalGeneration"], 4 | "bos_token_id": 0, 5 | "d_ff": 37, 6 | "d_kv": 8, 7 | "d_model": 32, 8 | "decoder_start_token_id": 0, 9 | "dense_act_fn": "relu", 10 | "dropout_rate": 0.1, 11 | "eos_token_id": 1, 12 | "feed_forward_proj": "relu", 13 | "initializer_factor": 0.002, 14 | "is_encoder_decoder": true, 15 | "is_gated_act": false, 16 | "layer_norm_epsilon": 1e-6, 17 | "model_type": "t5", 18 | "num_decoder_layers": 5, 19 | "num_heads": 4, 20 | "num_layers": 5, 21 | "pad_token_id": 0, 22 | "relative_attention_max_distance": 128, 23 | "relative_attention_num_buckets": 8, 24 | "torch_dtype": "float32", 25 | "transformers_version": "4.27.1", 26 | "use_cache": true, 27 | "vocab_size": 1302 28 | } 29 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/T5ForConditionalGeneration/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "bos_token_id": 0, 4 | "decoder_start_token_id": 0, 5 | "eos_token_id": 1, 6 | "pad_token_id": 0, 7 | "transformers_version": "4.27.1" 8 | } 9 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/T5ForConditionalGeneration/pytorch_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/fixtures/tiny_models/T5ForConditionalGeneration/pytorch_model.bin -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/T5ForConditionalGeneration/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "", 4 | "", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "", 24 | "", 25 | "", 26 | "", 27 | "", 28 | "", 29 | "", 30 | "", 31 | "", 32 | "", 33 | "", 34 | "", 35 | "", 36 | "", 37 | "", 38 | "", 39 | "", 40 | "", 41 | "", 42 | "", 43 | "", 44 | "", 45 | "", 46 | "", 47 | "", 48 | "", 49 | "", 50 | "", 51 | "", 52 | "", 53 | "", 54 | "", 55 | "", 56 | "", 57 | "", 58 | "", 59 | "", 60 | "", 61 | "", 62 | "", 63 | "", 64 | "", 65 | "", 66 | "", 67 | "", 68 | "", 69 | "", 70 | "", 71 | "", 72 | "", 73 | "", 74 | "", 75 | "", 76 | "", 77 | "", 78 | "", 79 | "", 80 | "", 81 | "", 82 | "", 83 | "", 84 | "", 85 | "", 86 | "", 87 | "", 88 | "", 89 | "", 90 | "", 91 | "", 92 | "", 93 | "", 94 | "", 95 | "", 96 | "", 97 | "", 98 | "", 99 | "", 100 | "", 101 | "", 102 | "" 103 | ], 104 | "eos_token": "", 105 | "pad_token": "", 106 | "unk_token": "" 107 | } 108 | -------------------------------------------------------------------------------- /tests/fixtures/tiny_models/T5ForConditionalGeneration/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "additional_special_tokens": [ 3 | "", 4 | "", 5 | "", 6 | "", 7 | "", 8 | "", 9 | "", 10 | "", 11 | "", 12 | "", 13 | "", 14 | "", 15 | "", 16 | "", 17 | "", 
18 | "", 19 | "", 20 | "", 21 | "", 22 | "", 23 | "", 24 | "", 25 | "", 26 | "", 27 | "", 28 | "", 29 | "", 30 | "", 31 | "", 32 | "", 33 | "", 34 | "", 35 | "", 36 | "", 37 | "", 38 | "", 39 | "", 40 | "", 41 | "", 42 | "", 43 | "", 44 | "", 45 | "", 46 | "", 47 | "", 48 | "", 49 | "", 50 | "", 51 | "", 52 | "", 53 | "", 54 | "", 55 | "", 56 | "", 57 | "", 58 | "", 59 | "", 60 | "", 61 | "", 62 | "", 63 | "", 64 | "", 65 | "", 66 | "", 67 | "", 68 | "", 69 | "", 70 | "", 71 | "", 72 | "", 73 | "", 74 | "", 75 | "", 76 | "", 77 | "", 78 | "", 79 | "", 80 | "", 81 | "", 82 | "", 83 | "", 84 | "", 85 | "", 86 | "", 87 | "", 88 | "", 89 | "", 90 | "", 91 | "", 92 | "", 93 | "", 94 | "", 95 | "", 96 | "", 97 | "", 98 | "", 99 | "", 100 | "", 101 | "", 102 | "" 103 | ], 104 | "eos_token": "", 105 | "extra_ids": 100, 106 | "model_max_length": 512, 107 | "pad_token": "", 108 | "special_tokens_map_file": null, 109 | "tokenizer_class": "T5Tokenizer", 110 | "unk_token": "" 111 | } 112 | -------------------------------------------------------------------------------- /tests/model_management/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/model_management/__init__.py -------------------------------------------------------------------------------- /tests/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/modules/__init__.py -------------------------------------------------------------------------------- /tests/modules/text_classification/test_sequence_classification.py: -------------------------------------------------------------------------------- 1 | """Tests for sequence classification module 2 | """ 3 | # Standard 4 | import tempfile 5 | 6 | # Third Party 7 | from pytest import approx 8 | import pytest 9 | 10 | # First Party 11 | from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults 12 | 13 | # Local 14 | from caikit_nlp.modules.text_classification import SequenceClassification 15 | from tests.fixtures import SEQ_CLASS_MODEL 16 | 17 | ## Setup ######################################################################## 18 | 19 | # Bootstrapped sequence classification model for reusability across tests 20 | # .bootstrap is tested separately in the first test 21 | BOOTSTRAPPED_SEQ_CLASS_MODEL = SequenceClassification.bootstrap(SEQ_CLASS_MODEL) 22 | 23 | TEXTS = [ 24 | "The quick brown fox jumps over the lazy dog.", 25 | "Once upon a time in a land far away", 26 | ] 27 | 28 | ## Tests ######################################################################## 29 | # Exact numbers here from the tiny model are not particularly important, 30 | # but we check them here to make sure that the arrays are re-ordered correctly 31 | 32 | 33 | def test_bootstrap_and_run(): 34 | """Check if we can bootstrap and run sequence classification models""" 35 | model = SequenceClassification.bootstrap(SEQ_CLASS_MODEL) 36 | classification_result = model.run(TEXTS[0]) 37 | assert isinstance(classification_result, ClassificationResults) 38 | assert len(classification_result.results) == 2 # 2 labels 39 | 40 | assert isinstance(classification_result.results[0], ClassificationResult) 41 | assert classification_result.results[0].label == "LABEL_0" 42 | assert approx(classification_result.results[0].score) == 
0.49526197 43 | assert classification_result.results[1].label == "LABEL_1" 44 | assert approx(classification_result.results[1].score) == 0.50473803 45 | 46 | 47 | def test_bootstrap_and_run_batch(): 48 | """Check if we can bootstrap and run_batch sequence classification models""" 49 | classification_result_list = BOOTSTRAPPED_SEQ_CLASS_MODEL.run_batch(TEXTS) 50 | assert len(classification_result_list) == 2 51 | 52 | first_result = classification_result_list[0] 53 | assert isinstance(first_result, ClassificationResults) 54 | assert first_result.results[0].label == "LABEL_0" 55 | assert approx(first_result.results[0].score) == 0.49526197 56 | assert first_result.results[1].label == "LABEL_1" 57 | assert classification_result_list[1].results[0].label == "LABEL_0" 58 | 59 | 60 | def test_load_save_and_run_model(): 61 | """Check if we can load and run a saved model successfully""" 62 | with tempfile.TemporaryDirectory() as model_dir: 63 | BOOTSTRAPPED_SEQ_CLASS_MODEL.save(model_dir) 64 | new_model = SequenceClassification.load(model_dir) 65 | classification_result = new_model.run(TEXTS[0]) 66 | assert isinstance(classification_result, ClassificationResults) 67 | assert len(classification_result.results) == 2 # 2 labels 68 | 69 | assert isinstance(classification_result.results[0], ClassificationResult) 70 | assert classification_result.results[0].label == "LABEL_0" 71 | assert approx(classification_result.results[0].score) == 0.49526197 72 | -------------------------------------------------------------------------------- /tests/modules/text_generation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/modules/text_generation/__init__.py -------------------------------------------------------------------------------- /tests/modules/text_generation/test_peft_config.py: -------------------------------------------------------------------------------- 1 | # Standard 2 | from unittest.mock import Mock 3 | 4 | # Third Party 5 | from peft import PromptTuningConfig 6 | import pytest 7 | 8 | # Local 9 | from caikit_nlp.data_model import TuningConfig 10 | from caikit_nlp.modules.text_generation import TextGeneration 11 | from caikit_nlp.modules.text_generation.peft_config import ( 12 | TuningType, 13 | get_peft_config, 14 | resolve_base_model, 15 | ) 16 | from caikit_nlp.resources.pretrained_model import HFAutoSeq2SeqLM 17 | from tests.fixtures import ( 18 | SEQ2SEQ_LM_MODEL, 19 | TINY_MODELS_DIR, 20 | causal_lm_dummy_model, 21 | causal_lm_train_kwargs, 22 | seq2seq_lm_dummy_model, 23 | seq2seq_lm_train_kwargs, 24 | temp_config, 25 | ) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "train_kwargs,dummy_model", 30 | [ 31 | ( 32 | "seq2seq_lm_train_kwargs", 33 | "seq2seq_lm_dummy_model", 34 | ), 35 | ("causal_lm_train_kwargs", "causal_lm_dummy_model"), 36 | ], 37 | ) 38 | def test_get_peft_config(train_kwargs, dummy_model, request): 39 | # Fixtures can't be called directly or passed to mark parametrize; 40 | # Currently, passing the fixture by name and retrieving it through 41 | # the request is the 'right' way to do this. 
42 | train_kwargs = request.getfixturevalue(train_kwargs) 43 | dummy_model = request.getfixturevalue(dummy_model) 44 | 45 | # Define some sample values for testing 46 | tuning_type = TuningType.PROMPT_TUNING 47 | tuning_config = TuningConfig( 48 | num_virtual_tokens=8, 49 | prompt_tuning_init_method="TEXT", 50 | prompt_tuning_init_text="Hello world", 51 | ) 52 | dummy_resource = train_kwargs["base_model"] 53 | 54 | # Call the function being tested 55 | task_type, output_model_types, peft_config, tuning_type = get_peft_config( 56 | tuning_type, 57 | tuning_config, 58 | dummy_resource, 59 | dummy_model, 60 | "float32", 61 | "{{input}}", 62 | ) 63 | 64 | # Add assertions to validate the behavior of the function 65 | assert task_type == dummy_resource.TASK_TYPE 66 | assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES 67 | assert tuning_type == TuningType.PROMPT_TUNING 68 | 69 | # Validation for type & important fields in the peft config 70 | assert isinstance(peft_config, PromptTuningConfig) 71 | assert peft_config.num_virtual_tokens == tuning_config.num_virtual_tokens 72 | assert peft_config.task_type == dummy_resource.TASK_TYPE 73 | assert peft_config.prompt_tuning_init == tuning_config.prompt_tuning_init_method 74 | assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text 75 | 76 | 77 | def test_resolve_model_with_invalid_path_raises(): 78 | """Test passing invalid path to resolve_model function raises""" 79 | 80 | invalid_base_model = "path/../../important" 81 | with pytest.raises(ValueError): 82 | resolve_base_model(invalid_base_model, None, "foo") 83 | 84 | 85 | def test_resolve_model_with_valid_folder_path(): 86 | """Test passing valid folder path to resolve_model function works""" 87 | 88 | model = resolve_base_model(SEQ2SEQ_LM_MODEL, TextGeneration, "float32") 89 | 90 | assert isinstance(model, HFAutoSeq2SeqLM) 91 | 92 | 93 | def test_resolve_model_works_preloaded_model(): 94 | 95 | base_model = HFAutoSeq2SeqLM.bootstrap(SEQ2SEQ_LM_MODEL) 96 | resolved_model = resolve_base_model(base_model, TextGeneration, "float32") 97 | assert isinstance(resolved_model, HFAutoSeq2SeqLM) 98 | 99 | 100 | def test_resolve_model_with_different_base_path_works(): 101 | 102 | base_model_name = "T5ForConditionalGeneration" 103 | with temp_config(base_models_dir=TINY_MODELS_DIR): 104 | resolved_model = resolve_base_model(base_model_name, TextGeneration, "float32") 105 | assert isinstance(resolved_model, HFAutoSeq2SeqLM) 106 | -------------------------------------------------------------------------------- /tests/modules/text_generation/test_peft_tgis_remote.py: -------------------------------------------------------------------------------- 1 | """Tests for prompt tuning based inference via TGIS backend; note that these tests mock the 2 | TGIS client and do NOT actually start/hit a TGIS server instance. 
3 | """ 4 | # Standard 5 | from typing import Iterable 6 | from unittest import mock 7 | import os 8 | import tempfile 9 | 10 | # Third Party 11 | import pytest 12 | 13 | # Local 14 | from caikit_nlp.modules.text_generation import PeftPromptTuningTGIS 15 | from tests.fixtures import ( 16 | StubTGISClient, 17 | causal_lm_dummy_model, 18 | causal_lm_train_kwargs, 19 | saved_causal_lm_dummy_model, 20 | stub_tgis_backend, 21 | temp_config, 22 | ) 23 | 24 | SAMPLE_TEXT = "Hello stub" 25 | 26 | 27 | def test_load_and_run(causal_lm_dummy_model, stub_tgis_backend): 28 | """Ensure we can export an in memory model, load it, and (mock) run it with the right text & prefix ID.""" 29 | # Patch our stub backend into caikit so that we don't actually try to start TGIS 30 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}" 31 | 32 | with mock.patch.object(StubTGISClient, "Generate") as mock_gen: 33 | mock_gen.side_effect = StubTGISClient.unary_generate 34 | 35 | # Save the local model & reload it a TGIS backend distributed module 36 | # Also, save the name of the dir + prompt ID, which is the path TGIS expects for the prefix ID 37 | with tempfile.TemporaryDirectory() as model_dir: 38 | causal_lm_dummy_model.save(model_dir) 39 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend) 40 | model_prompt_dir = os.path.split(model_dir)[-1] 41 | 42 | # Run an inference request, which is wrapped around our mocked Generate call 43 | result = mock_tgis_model.run( 44 | SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50 45 | ) 46 | StubTGISClient.validate_unary_generate_response(result) 47 | 48 | stub_generation_request = mock_gen.call_args_list[0].args[0] 49 | 50 | # Validate that our verbalizer carried over correctly & was applied at inference time 51 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer 52 | assert stub_generation_request.requests[ 53 | 0 54 | ].text == "hello distributed {}".format(SAMPLE_TEXT) 55 | 56 | # Ensure that our prefix ID matches the expected path based on our tmpdir and config 57 | assert model_prompt_dir == stub_generation_request.prefix_id 58 | 59 | 60 | def test_load_and_tokenize(causal_lm_dummy_model, stub_tgis_backend): 61 | """Ensure we can export an in memory model, load it, and tokenize it""" 62 | # Patch our stub backend into caikit so that we don't actually try to start TGIS 63 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}" 64 | 65 | with mock.patch.object(StubTGISClient, "Tokenize") as mock_gen: 66 | mock_gen.side_effect = StubTGISClient.tokenize 67 | 68 | # Save the local model & reload it a TGIS backend distributed module 69 | with tempfile.TemporaryDirectory() as model_dir: 70 | causal_lm_dummy_model.save(model_dir) 71 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend) 72 | 73 | result = mock_tgis_model.run_tokenizer(SAMPLE_TEXT) 74 | StubTGISClient.validate_tokenize_response(result) 75 | 76 | # Validate that our verbalizer carried over correctly & was applied at inference time 77 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer 78 | 79 | 80 | def test_load_and_run_stream_out(causal_lm_dummy_model, stub_tgis_backend): 81 | """Ensure we can export an in memory model, load it, and (mock) run output streaming 82 | with the right text & prefix ID.""" 83 | # Patch our stub backend into caikit so that we don't actually try to start TGIS 84 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}" 85 | 86 | with 
mock.patch.object(StubTGISClient, "GenerateStream") as mock_gen: 87 | mock_gen.side_effect = StubTGISClient.stream_generate 88 | 89 | # Save the local model & reload it a TGIS backend distributed module 90 | # Also, save the name of the dir + prompt ID, which is the path TGIS expects for the prefix ID 91 | with tempfile.TemporaryDirectory() as model_dir: 92 | causal_lm_dummy_model.save(model_dir) 93 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend) 94 | model_prompt_dir = os.path.split(model_dir)[-1] 95 | stub_tgis_backend.load_prompt_artifacts.assert_called_once() 96 | 97 | # Run an inference request, which is wrapped around our mocked GenerateStream call 98 | stream_result = mock_tgis_model.run_stream_out( 99 | SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50 100 | ) 101 | StubTGISClient.validate_stream_generate_response(stream_result) 102 | 103 | stub_generation_request = mock_gen.call_args_list[0].args[0] 104 | 105 | # Validate that our verbalizer carried over correctly & was applied at inference time 106 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer 107 | assert stub_generation_request.request.text == "hello distributed {}".format( 108 | SAMPLE_TEXT 109 | ) 110 | 111 | # Ensure that our prefix ID matches the expected path based on our tmpdir and config 112 | assert model_prompt_dir == stub_generation_request.prefix_id 113 | 114 | 115 | def test_purge_prompt_on_del(saved_causal_lm_dummy_model, stub_tgis_backend): 116 | """Test that the prompt artifacts get purged when a model is deleted""" 117 | 118 | # Load the model and make sure the prompt got copied over 119 | mock_tgis_model = PeftPromptTuningTGIS.load( 120 | saved_causal_lm_dummy_model, stub_tgis_backend 121 | ) 122 | stub_tgis_backend.load_prompt_artifacts.assert_called_once() 123 | 124 | # Delete the model and make sure the prompt got "removed" 125 | with temp_config(unload_tgis_prompt_artifacts=True): 126 | mock_tgis_model.__del__() 127 | stub_tgis_backend.unload_prompt_artifacts.assert_called_once() 128 | prompt_id = os.path.basename(saved_causal_lm_dummy_model) 129 | stub_tgis_backend.unload_prompt_artifacts.assert_called_with( 130 | mock_tgis_model.base_model_name, prompt_id 131 | ) 132 | 133 | 134 | def test_purge_prompt_disabled_on_del(saved_causal_lm_dummy_model, stub_tgis_backend): 135 | """Test that the prompt artifacts are not purged if disabled""" 136 | 137 | # Load the model and make sure the prompt got copied over 138 | mock_tgis_model = PeftPromptTuningTGIS.load( 139 | saved_causal_lm_dummy_model, stub_tgis_backend 140 | ) 141 | stub_tgis_backend.load_prompt_artifacts.assert_called_once() 142 | 143 | # Delete the model and make sure the prompt got "removed" 144 | with temp_config(unload_tgis_prompt_artifacts=False): 145 | mock_tgis_model.__del__() 146 | assert not stub_tgis_backend.unload_prompt_artifacts.called 147 | -------------------------------------------------------------------------------- /tests/modules/text_generation/test_text_generation_local.py: -------------------------------------------------------------------------------- 1 | """Tests for text-generation module 2 | """ 3 | 4 | # Standard 5 | import os 6 | import platform 7 | import tempfile 8 | 9 | # Third Party 10 | import pytest 11 | import torch 12 | 13 | # First Party 14 | from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults 15 | import caikit 16 | 17 | # Local 18 | from caikit_nlp.data_model import GenerationTrainRecord 19 | from 
caikit_nlp.modules.text_generation import TextGeneration 20 | from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM 21 | from tests.fixtures import ( 22 | CAUSAL_LM_MODEL, 23 | SEQ2SEQ_LM_MODEL, 24 | disable_wip, 25 | set_cpu_device, 26 | ) 27 | 28 | ### Stub Modules 29 | 30 | 31 | def test_bootstrap_and_run_causallm(): 32 | """Check if we can bootstrap and run causallm models""" 33 | 34 | model = TextGeneration.bootstrap(CAUSAL_LM_MODEL) 35 | 36 | sample_text = "Hello stub" 37 | generated_text = model.run(sample_text) 38 | assert isinstance(generated_text, GeneratedTextResult) 39 | 40 | 41 | def test_bootstrap_and_run_seq2seq(): 42 | """Check if we can bootstrap and run seq2seq models""" 43 | 44 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) 45 | 46 | sample_text = "Hello stub" 47 | generated_text = model.run(sample_text) 48 | assert isinstance(generated_text, GeneratedTextResult) 49 | 50 | 51 | def test_bootstrap_and_save_model(): 52 | """Check if we can bootstrap and save the model successfully""" 53 | 54 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) 55 | 56 | with tempfile.TemporaryDirectory() as model_dir: 57 | model.save(model_dir) 58 | assert os.path.isfile(os.path.join(model_dir, "config.yml")) 59 | 60 | 61 | def test_save_model_can_run(): 62 | """Check if the model we bootstrap and save is able to load and run successfully""" 63 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) 64 | 65 | with tempfile.TemporaryDirectory() as model_dir: 66 | model.save(model_dir) 67 | del model 68 | new_model = TextGeneration.load(model_dir) 69 | sample_text = "Hello stub" 70 | generated_text = new_model.run(sample_text) 71 | assert isinstance(generated_text, GeneratedTextResult) 72 | 73 | 74 | ############################## Training ################################ 75 | 76 | 77 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") 78 | def test_train_model_seq2seq(disable_wip, set_cpu_device): 79 | """Ensure that we can finetune a seq2seq model on some toy data for 1+ 80 | steps & run inference.""" 81 | train_kwargs = { 82 | "base_model": HFAutoSeq2SeqLM.bootstrap( 83 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL 84 | ), 85 | "num_epochs": 1, 86 | "train_stream": caikit.core.data_model.DataStream.from_iterable( 87 | [ 88 | GenerationTrainRecord( 89 | input="@foo what a cute dog!", output="no complaint" 90 | ), 91 | GenerationTrainRecord( 92 | input="@bar this is the worst idea ever.", output="complaint" 93 | ), 94 | ] 95 | ), 96 | "torch_dtype": torch.float32, 97 | } 98 | model = TextGeneration.train(**train_kwargs) 99 | assert isinstance(model.model, HFAutoSeq2SeqLM) 100 | 101 | # Ensure that we can get something out of it 102 | pred = model.run("@bar what a cute cat!") 103 | assert isinstance(pred, GeneratedTextResult) 104 | 105 | 106 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") 107 | def test_train_model_save_and_load(disable_wip, set_cpu_device): 108 | """Ensure that we are able to save and load a finetuned model and execute inference on it""" 109 | train_kwargs = { 110 | "base_model": HFAutoSeq2SeqLM.bootstrap( 111 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL 112 | ), 113 | "num_epochs": 1, 114 | "train_stream": caikit.core.data_model.DataStream.from_iterable( 115 | [ 116 | GenerationTrainRecord( 117 | input="@foo what a cute dog!", output="no complaint" 118 | ) 119 | ] 120 | ), 121 | "torch_dtype": torch.float32, 122 | } 123 | model = 
TextGeneration.train(**train_kwargs) 124 | assert isinstance(model.model, HFAutoSeq2SeqLM) 125 | with tempfile.TemporaryDirectory() as model_dir: 126 | model.save(model_dir) 127 | new_model = TextGeneration.load(model_dir) 128 | sample_text = "Hello stub" 129 | generated_text = new_model.run(sample_text) 130 | assert isinstance(generated_text, GeneratedTextResult) 131 | 132 | 133 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") 134 | def test_train_model_causallm(disable_wip, set_cpu_device): 135 | """Ensure that we can finetune a causal-lm model on some toy data for 1+ 136 | steps & run inference.""" 137 | train_kwargs = { 138 | "base_model": HFAutoCausalLM.bootstrap( 139 | model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL 140 | ), 141 | "num_epochs": 1, 142 | "train_stream": caikit.core.data_model.DataStream.from_iterable( 143 | [ 144 | GenerationTrainRecord( 145 | input="@foo what a cute dog!", output="no complaint" 146 | ), 147 | ] 148 | ), 149 | "torch_dtype": torch.float32, 150 | } 151 | model = TextGeneration.train(**train_kwargs) 152 | assert isinstance(model.model, HFAutoCausalLM) 153 | 154 | # Ensure that we can get something out of it 155 | pred = model.run("@bar what a cute cat!") 156 | assert isinstance(pred, GeneratedTextResult) 157 | 158 | 159 | ############################## Inferencing flags ################################ 160 | 161 | 162 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported") 163 | def test_train_and_infer_model_causallm(disable_wip, set_cpu_device): 164 | """Ensure that we can finetune a causal-lm model on some toy data for 1+ 165 | steps & run inference.""" 166 | train_kwargs = { 167 | "base_model": HFAutoCausalLM.bootstrap( 168 | model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL 169 | ), 170 | "num_epochs": 1, 171 | "train_stream": caikit.core.data_model.DataStream.from_iterable( 172 | [ 173 | GenerationTrainRecord( 174 | input="@foo what a cute dog!", output="no complaint" 175 | ), 176 | ] 177 | ), 178 | "torch_dtype": torch.float32, 179 | } 180 | model = TextGeneration.train(**train_kwargs) 181 | assert isinstance(model.model, HFAutoCausalLM) 182 | 183 | # Ensure that preserve_input_text returns input in output 184 | pred = model.run("@bar what a cute cat!", preserve_input_text=True) 185 | assert "@bar what a cute cat!" in pred.generated_text 186 | 187 | # Ensure that preserve_input_text set to False, removes input from output 188 | pred = model.run("@bar what a cute cat!", preserve_input_text=False) 189 | assert "@bar what a cute cat!" 
not in pred.generated_text 190 | 191 | 192 | ############################## Error Cases ################################ 193 | 194 | 195 | def test_zero_epoch_case(disable_wip): 196 | """Test to ensure 0 epoch training request doesn't explode""" 197 | train_kwargs = { 198 | "base_model": HFAutoSeq2SeqLM.bootstrap( 199 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL 200 | ), 201 | "num_epochs": 0, 202 | "train_stream": caikit.core.data_model.DataStream.from_iterable( 203 | [ 204 | GenerationTrainRecord( 205 | input="@foo what a cute dog!", output="no complaint" 206 | ), 207 | ] 208 | ), 209 | "torch_dtype": torch.float32, 210 | } 211 | model = TextGeneration.train(**train_kwargs) 212 | assert isinstance(model.model, HFAutoSeq2SeqLM) 213 | 214 | 215 | # ############################## Run Tokenizer ################################ 216 | 217 | 218 | def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device): 219 | """Test tokenizer on edge cases like empty strings and long input.""" 220 | model = TextGeneration.bootstrap(CAUSAL_LM_MODEL) 221 | 222 | # Edge case: Empty string 223 | empty_result = model.run_tokenizer("") 224 | assert isinstance(empty_result, TokenizationResults) 225 | assert empty_result.token_count == 0 226 | 227 | # Normal case: short sentence 228 | short_text = "This is a test sentence." 229 | short_result = model.run_tokenizer(short_text) 230 | assert isinstance(short_result, TokenizationResults) 231 | assert short_result.token_count == len(model.model.tokenizer.encode(short_text)) 232 | 233 | # Edge case: Long input 234 | long_text = "This is a test sentence. " * 1000 235 | long_result = model.run_tokenizer(long_text) 236 | assert isinstance(long_result, TokenizationResults) 237 | assert long_result.token_count == len(model.model.tokenizer.encode(long_text)) 238 | -------------------------------------------------------------------------------- /tests/modules/tokenization/test_regex_sentence_splitter.py: -------------------------------------------------------------------------------- 1 | """Tests for regex sentence splitter 2 | """ 3 | # Standard 4 | import os 5 | import tempfile 6 | 7 | # First Party 8 | from caikit.interfaces.nlp.data_model import TokenizationResults 9 | 10 | # Local 11 | from caikit_nlp.modules.tokenization.regex_sentence_splitter import ( 12 | RegexSentenceSplitter, 13 | ) 14 | 15 | ## Setup ######################################################################## 16 | 17 | # Regex sentence splitter model for reusability across tests 18 | REGEX_STR = "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)" 19 | SENTENCE_TOKENIZER = RegexSentenceSplitter.bootstrap(REGEX_STR) 20 | DOCUMENT = "What he told me before, I have it in my heart. I am tired of fighting." 
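# As a rough illustration of what REGEX_STR matches (a sketch, not exercised directly by
# these tests): applied to DOCUMENT via re.findall(REGEX_STR, DOCUMENT), it should yield
# the two sentence spans
#   "What he told me before, I have it in my heart."
#   "I am tired of fighting."
# which is what the `len(...) == 2` assertions below rely on.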
21 | 22 | ## Tests ######################################################################## 23 | 24 | 25 | def test_bootstrap_and_run(): 26 | """Check if we can bootstrap and run regex sentence splitter""" 27 | tokenization_result = SENTENCE_TOKENIZER.run(DOCUMENT) 28 | assert isinstance(tokenization_result, TokenizationResults) 29 | assert len(tokenization_result.results) == 2 30 | 31 | 32 | def test_save_load_and_run_model(): 33 | """Check if we can run a saved model successfully""" 34 | with tempfile.TemporaryDirectory() as model_dir: 35 | SENTENCE_TOKENIZER.save(model_dir) 36 | assert os.path.exists(os.path.join(model_dir, "config.yml")) 37 | 38 | new_splitter = RegexSentenceSplitter.load(model_dir) 39 | tokenization_result = new_splitter.run(DOCUMENT) 40 | assert isinstance(tokenization_result, TokenizationResults) 41 | assert len(tokenization_result.results) == 2 42 | -------------------------------------------------------------------------------- /tests/resources/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/resources/__init__.py -------------------------------------------------------------------------------- /tests/toolkit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/caikit/caikit-nlp/dc5a205f064cc99b56da9bcb887faaef4d6dc3d9/tests/toolkit/__init__.py -------------------------------------------------------------------------------- /tests/toolkit/test_data_stream_wrapper.py: -------------------------------------------------------------------------------- 1 | """Tests for wrapper helpers for making DataStreams play nicely with PyTorch DataLoaders. 
2 | """ 3 | # Standard 4 | from unittest import mock 5 | 6 | # Third Party 7 | from torch.utils.data._utils import worker 8 | import pytest 9 | 10 | # First Party 11 | from caikit.core.data_model import DataStream 12 | 13 | # Local 14 | from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper 15 | from tests.fixtures import requires_determinism 16 | 17 | # Sample data to load via PyTorch 18 | SAMPLE_DATA = [{"label": "foo"}, {"label": "foo"}, {"label": "bar"}, {"label": "bar"}] 19 | SAMPLE_STREAM = DataStream.from_iterable(SAMPLE_DATA) 20 | NUM_CYCLES = 10 21 | 22 | 23 | def test_without_shuffling(): 24 | """Ensure that we can build a datastream & load it in a data loader without shuffling.""" 25 | test_results = [] 26 | # Get the IDs of all objects in the stream 27 | get_stream_id_order = lambda s: [id(datum) for datum in s] 28 | # Compare the data stream at two different iteration points; here, we 29 | # produce True if two streams have the same objects in the same order 30 | have_same_id_order = lambda id_set1, id_set2: all( 31 | [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)] 32 | ) and len(id_set1) == len(id_set2) 33 | 34 | # NOTE - a buffer size of 1 is a noop; the shuffle operation just gets the current element 35 | wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False) 36 | # Cycle through NUM_CYCLES times & ensure that the order of our objects does not change 37 | initialize_order = get_stream_id_order(wrapper) 38 | for _ in range(NUM_CYCLES): 39 | cycle_ids = get_stream_id_order(wrapper) 40 | test_res = have_same_id_order(initialize_order, cycle_ids) 41 | test_results.append(test_res) 42 | assert all(test_results) 43 | 44 | 45 | def test_shuffle_full_buffer(requires_determinism): 46 | """Ensure that we can build a datastream & shuffle it all in memory on each iteration.""" 47 | test_results = [] 48 | # Get the IDs of all objects in the stream 49 | get_stream_id_order = lambda s: [id(datum) for datum in s] 50 | # Compare the data stream at two different iteration points; here, we 51 | # produce True if two streams have the same objects in the same order 52 | have_same_id_order = lambda id_set1, id_set2: all( 53 | [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)] 54 | ) and len(id_set1) == len(id_set2) 55 | 56 | # NOTE - a buffer size of 1 is a noop; the shuffle operation just gets the current element 57 | wrapper = SimpleIterableStreamWrapper( 58 | stream=SAMPLE_STREAM, shuffle=True, buffer_size=len(SAMPLE_STREAM) 59 | ) 60 | # Cycle through NUM_CYCLES times & ensure that the order of our objects DOES change sometimes 61 | initialize_order = get_stream_id_order(wrapper) 62 | for _ in range(NUM_CYCLES): 63 | cycle_ids = get_stream_id_order(wrapper) 64 | test_res = have_same_id_order(initialize_order, cycle_ids) 65 | test_results.append(test_res) 66 | assert not all(test_results) 67 | 68 | 69 | # def test_iter_with_multi_worker(requires_determinism): 70 | # """Ensure that we are able to iterate properly over data in case of workers 71 | # managed by torch""" 72 | 73 | # test_results = [] 74 | # # Get the IDs of all objects in the stream 75 | # get_stream_id_order = lambda s: [id(datum) for datum in s] 76 | # # Compare the data stream at two different iteration points; here, we 77 | # # produce True if two streams have the same objects in the same order 78 | # have_same_id_order = lambda id_set1, id_set2: all( 79 | # [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)] 80 | # ) and len(id_set1) == 
len(id_set2) 81 | 82 | # dummy_worker_info = worker.WorkerInfo( 83 | # id=1, 84 | # num_workers=2, 85 | # seed=7, 86 | # ) 87 | 88 | # with mock.patch.object(worker, '_worker_info', dummy_worker_info): 89 | # wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False) 90 | # initialize_order = get_stream_id_order(wrapper) 91 | # for _ in range(NUM_CYCLES): 92 | # cycle_ids = get_stream_id_order(wrapper) 93 | # test_res = have_same_id_order(initialize_order, cycle_ids) 94 | # test_results.append(test_res) 95 | # assert not all(test_results) 96 | -------------------------------------------------------------------------------- /tests/toolkit/test_data_type_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for data type related utils, e.g., for interacting with serialized torch types. 2 | """ 3 | # Third Party 4 | import pytest 5 | import torch 6 | 7 | # Local 8 | from caikit_nlp.toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype 9 | 10 | ### Tests for converting from strings / types / None -> torch data types 11 | 12 | 13 | def test_get_torch_dtype_from_str(): 14 | """Ensure that we can parse a data type from a string.""" 15 | assert torch.float32 is get_torch_dtype("float32") 16 | 17 | 18 | def test_get_torch_dtype_from_dtype(): 19 | """Ensure that if a data type is provided from pytorch, we simply return it.""" 20 | assert torch.float32 is get_torch_dtype(torch.float32) 21 | 22 | 23 | def test_get_torch_dtype_from_bad_type(): 24 | """Ensure that if a type we can't coerce to a pytorch dtype is given, we get a TypeError.""" 25 | with pytest.raises(TypeError): 26 | assert torch.float32 is get_torch_dtype(100) 27 | 28 | 29 | def test_get_torch_dtype_from_bad_str(): 30 | """Ensure that if an invalid type string is given, we get a ValueError.""" 31 | with pytest.raises(ValueError): 32 | assert torch.float32 is get_torch_dtype("not a valid attr of pytorch") 33 | 34 | 35 | ### Tests for converting from strings -> torch data types 36 | def test_str_to_torch_dtype(): 37 | """Ensure that we can parse a data type from a string.""" 38 | assert str_to_torch_dtype("float32") is torch.float32 39 | 40 | 41 | def test_str_to_torch_dtype_invalid_attr(): 42 | """Ensure that we raise ValueError if an incorrect type str is provided.""" 43 | with pytest.raises(ValueError): 44 | str_to_torch_dtype("not a valid attr of pytorch") 45 | 46 | 47 | def test_str_to_torch_dtype_bad_attr(): 48 | """Ensure that we raise ValueError if a non type property of torch is provided.""" 49 | with pytest.raises(ValueError): 50 | str_to_torch_dtype("nn") 51 | -------------------------------------------------------------------------------- /tests/toolkit/test_task_specific_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
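# Tests for task-specific utils, e.g., converting classification train records into generation train records.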
14 | 15 | # Third Party 16 | import pytest 17 | 18 | # First Party 19 | from caikit.interfaces.nlp.data_model import ClassificationTrainRecord 20 | 21 | # Local 22 | from caikit_nlp.data_model import GenerationTrainRecord 23 | from caikit_nlp.toolkit.task_specific_utils import convert_to_generation_record 24 | 25 | 26 | def test_convert_classification_train_record_to_generation_record(): 27 |     classification_train_record = ClassificationTrainRecord( 28 |         text="foo bar", labels=["label1"] 29 |     ) 30 |     generated_train = convert_to_generation_record(classification_train_record) 31 |     assert isinstance(generated_train, GenerationTrainRecord) 32 |     assert generated_train.input == "foo bar" 33 |     assert generated_train.output == "label1" 34 | 35 | 36 | def test_convert_generation_record_to_generation_record(): 37 |     generation_train_record = GenerationTrainRecord(input="foo bar", output="label1") 38 |     generated_train = convert_to_generation_record(generation_train_record) 39 |     assert isinstance(generated_train, GenerationTrainRecord) 40 |     assert generated_train.input == generation_train_record.input 41 |     assert generated_train.output == generation_train_record.output 42 | 43 | 44 | # When we support integer labels 45 | # def test_convert_classification_train_record_to_generation_record_numeric_labels(): 46 | #     classification_train_record = dm.ClassificationTrainRecord( 47 | #         text="foo bar", labels=[1] 48 | #     ) 49 | #     generated_train = convert_to_generation_record(classification_train_record) 50 | #     assert isinstance(generated_train, dm.GenerationTrainRecord) 51 | #     assert generated_train.input == classification_train_record.text 52 | #     assert generated_train.output == "1" 53 | 54 | 55 | def test_convert_to_generation_record_gives_error_with_unsupported_type(): 56 |     string_record = "test record" 57 |     with pytest.raises(TypeError): 58 |         convert_to_generation_record(string_record) 59 | -------------------------------------------------------------------------------- /tests/toolkit/test_verbalizers.py: -------------------------------------------------------------------------------- 1 | """Tests for verbalizer substitution utils. 2 | """ 3 | # Third Party 4 | import pytest 5 | 6 | # Local 7 | from caikit_nlp.toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer 8 | import caikit_nlp 9 | 10 | SAMPLE_DM = caikit_nlp.data_model.GenerationTrainRecord( 11 |     input="my input text", output="my output text" 12 | ) 13 | SAMPLE_DICT = SAMPLE_DM.to_dict() 14 | 15 | ### Happy path rendering cases 16 | def test_is_valid_verbalizer(): 17 |     """Ensure that when we have happy verbalizers, it's easy to check.""" 18 |     assert is_valid_verbalizer("{{input}}") == True 19 |     assert is_valid_verbalizer("Input: {{input}}") == True 20 |     assert is_valid_verbalizer("Input: {{input}} Output: {{output}}") == True 21 | 22 | 23 | def test_raw_verbalizer_is_preserved(): 24 |     """Ensure that if no placeholders are provided, the raw string is returned.""" 25 |     # NOTE: you would never want to do this because it means your input string is hardcoded! 26 |     # However, we leave the responsibility of checking this up to the verbalizer validator so that 27 |     # we can validate up front, rather than checking lazily at train time etc.
28 |     verbalizer_template = "text: foo label: bar" 29 |     expected_render_result = verbalizer_template 30 |     # Check that we get the same behavior from both the DM / Dict based verbalizer replacement 31 |     assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result 32 |     assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result 33 | 34 | 35 | def test_verbalizer_substitution(): 36 |     """Ensure that if placeholders are used, class attrs are rendered into placeholders.""" 37 |     verbalizer_template = "text: {{input}} label: {{output}}" 38 |     expected_render_result = "text: {} label: {}".format( 39 |         SAMPLE_DM.input, SAMPLE_DM.output 40 |     ) 41 |     assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result 42 |     assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result 43 | 44 | 45 | def test_invalid_value_verbalizer_substitution(): 46 |     """Ensure that if we try to render a placeholder that is a bad property name, we skip it.""" 47 |     verbalizer_template = "text: {{foo bar}}" 48 |     expected_render_result = verbalizer_template 49 |     assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result 50 |     assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result 51 | 52 | 53 | def test_empty_verbalizer_substitution(): 54 |     """Ensure that if we try to render an empty placeholder, we skip it.""" 55 |     verbalizer_template = "text: {{}} label: {{}}" 56 |     expected_render_result = verbalizer_template 57 |     assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result 58 |     assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result 59 | 60 | 61 | ### Sad path rendering cases 62 | def test_is_invalid_verbalizer(): 63 |     """Ensure that when we have invalid verbalizers, it's easy to check.""" 64 |     assert is_valid_verbalizer(100) == False 65 |     assert is_valid_verbalizer("") == False 66 |     assert is_valid_verbalizer("source") == False 67 |     assert is_valid_verbalizer("{{this is not a valid placeholder}}") == False 68 | 69 | 70 | def test_invalid_attribute_verbalizer_substitution(): 71 |     """Ensure that if we try to render a placeholder that doesn't exist on our DM, we fail.""" 72 |     verbalizer_template = "text: {{sadness}}" 73 |     with pytest.raises(ValueError): 74 |         render_verbalizer(verbalizer_template, SAMPLE_DM) 75 | 76 | 77 | def test_invalid_key_verbalizer_substitution(): 78 |     """Ensure that if we try to render a placeholder that doesn't exist on our dict, we fail.""" 79 |     verbalizer_template = "text: {{sadness}}" 80 |     with pytest.raises(ValueError): 81 |         render_verbalizer(verbalizer_template, SAMPLE_DICT) 82 | -------------------------------------------------------------------------------- /tests/toolkit/text_generation/test_model_run_utils.py: -------------------------------------------------------------------------------- 1 | # Third Party 2 | import pytest 3 | 4 | # First Party 5 | from caikit.core.data_model.producer import ProducerId 6 | from caikit.interfaces.nlp.data_model import GeneratedTextResult 7 | 8 | # Local 9 | from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func 10 | from tests.fixtures import ( 11 |     causal_lm_dummy_model, 12 |     causal_lm_train_kwargs, 13 |     seq2seq_lm_dummy_model, 14 |     seq2seq_lm_train_kwargs, 15 | ) 16 | 17 | 18 | @pytest.mark.parametrize( 19 |     "model_fixture", ["seq2seq_lm_dummy_model", "causal_lm_dummy_model"] 20 | ) 21 | @pytest.mark.parametrize( 22 | 
    "serialization_method,expected_type", 23 |     [ 24 |         ("to_dict", dict), 25 |         ("to_json", str), 26 |         ("to_proto", GeneratedTextResult._proto_class), 27 |     ], 28 | ) 29 | def test_generate_text_func_serialization_json( 30 |     request, 31 |     model_fixture, 32 |     serialization_method, 33 |     expected_type, 34 | ): 35 |     model = request.getfixturevalue(model_fixture) 36 |     generated_text = generate_text_func( 37 |         model=model.model, 38 |         tokenizer=model.tokenizer, 39 |         producer_id=ProducerId("TextGeneration", "0.1.0"), 40 |         eos_token="<\n>", 41 |         text="What is the boiling point of liquid Nitrogen?", 42 |     ) 43 | 44 |     serialized = getattr(generated_text, serialization_method)() 45 |     assert isinstance(serialized, expected_type) 46 | 47 | 48 | @pytest.mark.parametrize("causal_model_fixture", ["causal_lm_dummy_model"]) 49 | def test_generate_text_func_preserve_input_causal_lm(request, causal_model_fixture): 50 |     """For Causal LM task types, setting preserve_input_text to True 51 |     will result in input text in model prediction. Setting to False will 52 |     strip the input text from model prediction. 53 |     """ 54 |     input_text = "What is the boiling point of liquid Nitrogen?" 55 |     causal_model = request.getfixturevalue(causal_model_fixture) 56 |     # assert type(causal_model.model) == False 57 |     generated_text = generate_text_func( 58 |         model=causal_model.model, 59 |         tokenizer=causal_model.tokenizer, 60 |         producer_id=ProducerId("TextGeneration", "0.1.0"), 61 |         eos_token="<\n>", 62 |         text=input_text, 63 |         preserve_input_text=True, 64 |         task_type="CAUSAL_LM", 65 |     ) 66 |     assert input_text in generated_text.generated_text 67 |     generated_text = generate_text_func( 68 |         model=causal_model.model, 69 |         tokenizer=causal_model.tokenizer, 70 |         producer_id=ProducerId("TextGeneration", "0.1.0"), 71 |         eos_token="<\n>", 72 |         text=input_text, 73 |         preserve_input_text=False, 74 |         task_type="CAUSAL_LM", 75 |     ) 76 |     assert input_text not in generated_text.generated_text 77 | 78 | 79 | @pytest.mark.parametrize("seq_model_fixture", ["seq2seq_lm_dummy_model"]) 80 | def test_generate_text_func_preserve_input(request, seq_model_fixture): 81 |     """For Seq2Seq LM task types, setting preserve_input_text to True 82 |     or False should not change predictions. 83 |     """ 84 |     input_text = "What is the boiling point of liquid Nitrogen?" 85 |     seq_model = request.getfixturevalue(seq_model_fixture) 86 |     # assert type(causal_model.model) == False 87 |     generated_text = generate_text_func( 88 |         model=seq_model.model, 89 |         tokenizer=seq_model.tokenizer, 90 |         producer_id=ProducerId("TextGeneration", "0.1.0"), 91 |         eos_token="<\n>", 92 |         text=input_text, 93 |         preserve_input_text=True, 94 |         task_type="SEQ_2_SEQ_LM", 95 |     ) 96 |     before_pred = generated_text.generated_text 97 |     generated_text = generate_text_func( 98 |         model=seq_model.model, 99 |         tokenizer=seq_model.tokenizer, 100 |         producer_id=ProducerId("TextGeneration", "0.1.0"), 101 |         eos_token="<\n>", 102 |         text=input_text, 103 |         preserve_input_text=False, 104 |         task_type="SEQ_2_SEQ_LM", 105 |     ) 106 |     after_pred = generated_text.generated_text 107 |     assert before_pred == after_pred 108 | -------------------------------------------------------------------------------- /tests/toolkit/text_generation/test_tgis_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright The Caikit Authors 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """ 15 | Tests for tgis_utils 16 | """ 17 | 18 | # Standard 19 | from typing import Iterable, Optional, Type 20 | 21 | # Third Party 22 | import fastapi 23 | import grpc 24 | import grpc._channel 25 | import pytest 26 | 27 | # First Party 28 | from caikit.core.data_model import ProducerId 29 | from caikit.core.exceptions.caikit_core_exception import CaikitCoreException 30 | from caikit.interfaces.runtime.data_model import RuntimeServerContextType 31 | from caikit_tgis_backend.protobufs import generation_pb2 32 | 33 | # Local 34 | from caikit_nlp.toolkit.text_generation import tgis_utils 35 | from tests.fixtures import TestServicerContext 36 | 37 | ## Helpers ##################################################################### 38 | 39 | 40 | class MockTgisClient: 41 | """Mock of a TGIS client that doesn't actually call anything""" 42 | 43 | def __init__( 44 | self, 45 | status_code: Optional[grpc.StatusCode], 46 | error_message: str = "Yikes", 47 | ): 48 | self._status_code = status_code 49 | self._error_message = error_message 50 | 51 | def _maybe_raise(self, error_type: Type[grpc.RpcError], *args): 52 | if self._status_code not in [None, grpc.StatusCode.OK]: 53 | raise error_type( 54 | grpc._channel._RPCState( 55 | [], [], [], code=self._status_code, details=self._error_message 56 | ), 57 | *args, 58 | ) 59 | 60 | def Generate( 61 | self, request: generation_pb2.BatchedGenerationRequest, **kwargs 62 | ) -> generation_pb2.BatchedGenerationResponse: 63 | self._maybe_raise(grpc._channel._InactiveRpcError) 64 | return generation_pb2.BatchedGenerationResponse() 65 | 66 | def GenerateStream( 67 | self, request: generation_pb2.SingleGenerationRequest, **kwargs 68 | ) -> Iterable[generation_pb2.GenerationResponse]: 69 | self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None) 70 | yield generation_pb2.GenerationResponse() 71 | 72 | def Tokenize( 73 | self, request: generation_pb2.BatchedTokenizeRequest, **kwargs 74 | ) -> generation_pb2.BatchedTokenizeResponse: 75 | self._maybe_raise(grpc._channel._InactiveRpcError) 76 | return generation_pb2.BatchedTokenizeResponse() 77 | 78 | 79 | ## TGISGenerationClient ######################################################## 80 | 81 | 82 | @pytest.mark.parametrize( 83 | "status_code", 84 | [code for code in grpc.StatusCode if code != grpc.StatusCode.OK], 85 | ) 86 | @pytest.mark.parametrize( 87 | "method", ["unary_generate", "stream_generate", "unary_tokenize"] 88 | ) 89 | def test_TGISGenerationClient_rpc_errors(status_code, method): 90 | """Test that raised errors in downstream RPCs are converted to 91 | CaikitCoreException correctly 92 | """ 93 | tgis_client = MockTgisClient(status_code) 94 | gen_client = tgis_utils.TGISGenerationClient( 95 | "foo", 96 | "bar", 97 | tgis_client, 98 | ProducerId("foobar"), 99 | ) 100 | with pytest.raises(CaikitCoreException) as context: 101 | kwargs = ( 102 | dict( 103 | preserve_input_text=True, 104 | input_tokens=True, 105 | generated_tokens=True, 106 | token_logprobs=True, 107 | token_ranks=True, 108 | max_new_tokens=20, 109 | min_new_tokens=20, 110 | 
truncate_input_tokens=True, 111 | decoding_method="GREEDY", 112 | top_k=None, 113 | top_p=None, 114 | typical_p=None, 115 | temperature=None, 116 | seed=None, 117 | repetition_penalty=0.5, 118 | max_time=None, 119 | exponential_decay_length_penalty=None, 120 | stop_sequences=["asdf"], 121 | include_stop_sequence=True, 122 | ) 123 | if method.endswith("_generate") 124 | else dict() 125 | ) 126 | res = getattr(gen_client, method)(text="foobar", **kwargs) 127 | if method.startswith("stream_"): 128 | next(res) 129 | 130 | assert ( 131 | context.value.status_code == tgis_utils.GRPC_TO_CAIKIT_CORE_STATUS[status_code] 132 | ) 133 | rpc_err = context.value.__context__ 134 | assert isinstance(rpc_err, grpc.RpcError) 135 | 136 | 137 | # NOTE: This test is preserved in caikit-nlp despite being duplicated in 138 | # caikit-tgis-backend so that we guarantee that the functionality is accessible 139 | # in a version-compatible way here. 140 | @pytest.mark.parametrize( 141 | argnames=["context", "route_info"], 142 | argvalues=[ 143 | ( 144 | fastapi.Request( 145 | { 146 | "type": "http", 147 | "headers": [ 148 | (tgis_utils.ROUTE_INFO_HEADER_KEY.encode(), b"sometext") 149 | ], 150 | } 151 | ), 152 | "sometext", 153 | ), 154 | ( 155 | fastapi.Request( 156 | {"type": "http", "headers": [(b"route-info", b"sometext")]} 157 | ), 158 | None, 159 | ), 160 | ( 161 | TestServicerContext({tgis_utils.ROUTE_INFO_HEADER_KEY: "sometext"}), 162 | "sometext", 163 | ), 164 | ( 165 | TestServicerContext({"route-info": "sometext"}), 166 | None, 167 | ), 168 | ("should raise ValueError", None), 169 | (None, None), 170 | # Uncertain how to create a grpc.ServicerContext object 171 | ], 172 | ) 173 | def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]): 174 | if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): 175 | with pytest.raises(TypeError): 176 | tgis_utils.get_route_info(context) 177 | else: 178 | actual_route_info = tgis_utils.get_route_info(context) 179 | assert actual_route_info == route_info 180 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py, lint, fmt 3 | 4 | [testenv] 5 | description = run tests with pytest with coverage 6 | deps = 7 | evaluate # Currently used for sample scripts etc. 8 | pytest==7.1.3 9 | pytest-cov>=2.10.1,<3.0 10 | pytest-html>=3.1.1,<4.0 11 | tf-keras>=2.18.0 12 | wheel>=0.38.4 13 | passenv = 14 | LOG_LEVEL 15 | LOG_FILTERS 16 | LOG_FORMATTER 17 | LOG_THREAD_ID 18 | LOG_CHANNEL_WIDTH 19 | PYTORCH_ENABLE_MPS_FALLBACK 20 | commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} 21 | 22 | ; Unclear: We probably want to test wheel packaging 23 | ; But! tox will fail when this is set and _any_ interpreter is missing 24 | ; Without this, sdist packaging is tested so that's a start. 
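; With this set, tox builds and installs the project as a wheel for the test envs instead of an sdist.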
25 | package=wheel 26 | 27 | [testenv:fmt] 28 | description = format with pre-commit 29 | deps = pre-commit>=3.0.4,<4.0 30 | commands = ./scripts/fmt.sh 31 | allowlist_externals = ./scripts/fmt.sh 32 | skip_install = True # Skip package install since fmt doesn't need to execute code, for ⚡⚡⚡ 33 | 34 | [testenv:lint] 35 | description = lint with pylint 36 | deps = pylint>=2.16.2,<3.0 37 | commands = pylint caikit_nlp 38 | 39 | [testenv:build] 40 | description = build wheel 41 | deps = 42 | build 43 | commands = python -m build 44 | skip_install = True 45 | 46 | [testenv:twinecheck] 47 | description = check wheel 48 | deps = 49 | twine 50 | commands = twine check dist/* 51 | skip_install = True 52 | --------------------------------------------------------------------------------
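A minimal usage sketch (an illustration added here, not a file in the snapshot above): it shows how the regex sentence splitter exercised in tests/modules/tokenization/test_regex_sentence_splitter.py might be driven on its own, assuming caikit_nlp is installed locally and reusing the regex and document strings from that test module.

# Illustrative sketch only; mirrors the bootstrap-and-run flow from the splitter tests above.
from caikit_nlp.modules.tokenization.regex_sentence_splitter import RegexSentenceSplitter

REGEX_STR = "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
DOCUMENT = "What he told me before, I have it in my heart. I am tired of fighting."

splitter = RegexSentenceSplitter.bootstrap(REGEX_STR)
tokenization_result = splitter.run(DOCUMENT)
print(len(tokenization_result.results))  # the tests above expect 2 for this two-sentence document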