├── tests
├── __init__.py
├── modules
│ ├── __init__.py
│ ├── text_generation
│ │ ├── __init__.py
│ │ ├── test_peft_config.py
│ │ ├── test_peft_tgis_remote.py
│ │ └── test_text_generation_local.py
│ ├── tokenization
│ │ └── test_regex_sentence_splitter.py
│ └── text_classification
│ │ └── test_sequence_classification.py
├── resources
│ └── __init__.py
├── toolkit
│ ├── __init__.py
│ ├── test_data_type_utils.py
│ ├── test_task_specific_utils.py
│ ├── test_verbalizers.py
│ ├── text_generation
│ │ ├── test_model_run_utils.py
│ │ └── test_tgis_utils.py
│ └── test_data_stream_wrapper.py
├── model_management
│ └── __init__.py
├── fixtures
│ ├── tiny_models
│ │ ├── BloomForCausalLM
│ │ │ ├── pytorch_model.bin
│ │ │ ├── special_tokens_map.json
│ │ │ ├── generation_config.json
│ │ │ ├── tokenizer_config.json
│ │ │ └── config.json
│ │ ├── BertForSequenceClassification
│ │ │ ├── tf_model.h5
│ │ │ ├── pytorch_model.bin
│ │ │ ├── special_tokens_map.json
│ │ │ ├── tokenizer_config.json
│ │ │ └── config.json
│ │ ├── T5ForConditionalGeneration
│ │ │ ├── pytorch_model.bin
│ │ │ ├── generation_config.json
│ │ │ ├── config.json
│ │ │ ├── special_tokens_map.json
│ │ │ └── tokenizer_config.json
│ │ └── README.md
│ └── data_model
│ │ └── sample_objects.py
├── conftest.py
└── data_model
│ └── test_generation.py
├── caikit_nlp
├── toolkit
│ ├── __init__.py
│ ├── text_generation
│ │ └── __init__.py
│ ├── task_specific_utils.py
│ ├── data_type_utils.py
│ ├── trainer_utils.py
│ ├── data_stream_wrapper.py
│ ├── torch_run.py
│ └── verbalizer_utils.py
├── version.py
├── model_management
│ ├── __init__.py
│ └── tgis_auto_finder.py
├── modules
│ ├── tokenization
│ │ ├── __init__.py
│ │ └── regex_sentence_splitter.py
│ ├── text_classification
│ │ ├── __init__.py
│ │ └── sequence_classification.py
│ ├── token_classification
│ │ └── __init__.py
│ ├── __init__.py
│ ├── text_generation
│ │ ├── __init__.py
│ │ └── peft_config.py
│ └── text_embedding
│ │ ├── utils.py
│ │ └── __init__.py
├── resources
│ ├── __init__.py
│ └── pretrained_model
│ │ ├── __init__.py
│ │ ├── hf_auto_seq_classifier.py
│ │ └── hf_auto_seq2seq_lm.py
├── data_model
│ ├── __init__.py
│ └── generation.py
├── config
│ ├── __init__.py
│ └── config.yml
└── __init__.py
├── setup_requirements.txt
├── .whitesource
├── .dockerignore
├── .prettierignore
├── .github
├── dependabot.yml
├── ISSUE_TEMPLATE
│ ├── user_story.md
│ ├── feature_request.md
│ └── bug_report.md
└── workflows
│ ├── build-image.yml
│ ├── publish-library.yml
│ ├── lint-code.yml
│ └── build-library.yml
├── SECURITY.md
├── code-of-conduct.md
├── scripts
├── run_local.sh
├── dump_apis.sh
└── fmt.sh
├── CODEOWNERS
├── .isort.cfg
├── .pre-commit-config.yaml
├── examples
├── kill-text-generation-launcher.sh
├── text-generation-launcher
├── load_and_run_distributed_peft.py
├── compare_local_vs_tgis_models.py
└── evaluate_model.py
├── pyproject.toml
├── runtime_template
└── run_with_gateway.sh
├── tox.ini
├── Dockerfile
├── benchmarks
├── README.md
└── logs
│ └── llama2-7b
│ ├── 20230905_194133.output
│ ├── 20230906_135211.output
│ ├── 20230905_183655.output
│ ├── 20230905_184809.output
│ └── 20230905_191650.output
├── runtime_config.yaml
├── .gitignore
├── prompt_tuning_parameter_selection.md
└── CONTRIBUTING.md
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/resources/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/toolkit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/model_management/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/modules/text_generation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/text_generation/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/setup_requirements.txt:
--------------------------------------------------------------------------------
1 | tox>=4.4.2,<5
2 | build>=0.10.0,<2.0
--------------------------------------------------------------------------------
/.whitesource:
--------------------------------------------------------------------------------
1 | {
2 | "settingsInheritedFrom": "whitesource-config/whitesource-config@master"
3 | }
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | *
2 | !README.md
3 | !LICENSE
4 | !caikit_nlp
5 | !pyproject.toml
6 | !tox.ini
7 | !.git
8 |
--------------------------------------------------------------------------------
/.prettierignore:
--------------------------------------------------------------------------------
1 | venv
2 | docker_build_scripts
3 | htmlcov
4 | reports
5 | .pytest_cache
6 | models
7 | *.md
8 | tests/fixtures/tiny_models
9 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "daily"
7 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Security Policy
2 |
3 | The Caikit project has a [common security policy that can be found here](https://github.com/caikit/community/blob/main/SECURITY.md).
4 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BloomForCausalLM/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caikit/caikit-nlp/HEAD/tests/fixtures/tiny_models/BloomForCausalLM/pytorch_model.bin
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BloomForCausalLM/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": "<s>",
3 | "eos_token": "</s>",
4 | "pad_token": "<pad>",
5 | "unk_token": "<unk>"
6 | }
7 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BertForSequenceClassification/tf_model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caikit/caikit-nlp/HEAD/tests/fixtures/tiny_models/BertForSequenceClassification/tf_model.h5
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/T5ForConditionalGeneration/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caikit/caikit-nlp/HEAD/tests/fixtures/tiny_models/T5ForConditionalGeneration/pytorch_model.bin
--------------------------------------------------------------------------------
/code-of-conduct.md:
--------------------------------------------------------------------------------
1 | # Community Code of Conduct
2 |
3 | The Caikit project [has a common code of conduct that can be found here](https://github.com/caikit/community/blob/main/code-of-conduct.md).
4 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BertForSequenceClassification/pytorch_model.bin:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/caikit/caikit-nlp/HEAD/tests/fixtures/tiny_models/BertForSequenceClassification/pytorch_model.bin
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BloomForCausalLM/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 1,
4 | "eos_token_id": 2,
5 | "pad_token_id": 3,
6 | "transformers_version": "4.27.1"
7 | }
8 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BertForSequenceClassification/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "cls_token": "[CLS]",
3 | "mask_token": "[MASK]",
4 | "pad_token": "[PAD]",
5 | "sep_token": "[SEP]",
6 | "unk_token": "[UNK]"
7 | }
8 |
--------------------------------------------------------------------------------
/caikit_nlp/version.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=unused-import
2 | try:
3 | # Local
4 | from ._version import __version__, __version_tuple__
5 | except ImportError:
6 | __version__ = "unknown"
7 | __version_tuple__ = (0, 0, __version__)
8 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/T5ForConditionalGeneration/generation_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_from_model_config": true,
3 | "bos_token_id": 0,
4 | "decoder_start_token_id": 0,
5 | "eos_token_id": 1,
6 | "pad_token_id": 0,
7 | "transformers_version": "4.27.1"
8 | }
9 |
--------------------------------------------------------------------------------
/scripts/run_local.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | cd $(dirname ${BASH_SOURCE[0]})/..
3 | mkdir -p models
4 |
5 | server=${SERVER:-"http"}
6 |
7 | CONFIG_FILES=runtime_config.yaml \
8 | LOG_LEVEL=${LOG_LEVEL:-debug3} \
9 | LOG_FORMATTER=pretty \
10 | python -m caikit.runtime.${server}_server
11 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | #####################################################
2 | #
3 | # List of approvers for caikit/caikit-nlp repository
4 | #
5 | #####################################################
6 | #
7 | # Learn about CODEOWNERS file format:
8 | # https://help.github.com/en/articles/about-code-owners
9 | #
10 |
11 | * @gkumbhat @evaline-ju @gabe-l-hart
12 |
--------------------------------------------------------------------------------
/.isort.cfg:
--------------------------------------------------------------------------------
1 | [settings]
2 | profile=black
3 | from_first=true
4 | import_heading_future=Future
5 | import_heading_stdlib=Standard
6 | import_heading_thirdparty=Third Party
7 | import_heading_firstparty=First Party
8 | import_heading_localfolder=Local
9 | known_firstparty=alog,aconfig,caikit,caikit_tgis_backend,import_tracker
10 | known_localfolder=caikit_nlp,tests
11 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BloomForCausalLM/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "bos_token": "<s>",
4 | "eos_token": "</s>",
5 | "model_max_length": 1000000000000000019884624838656,
6 | "pad_token": "<pad>",
7 | "padding_side": "left",
8 | "special_tokens_map_file": null,
9 | "tokenizer_class": "BloomTokenizer",
10 | "unk_token": "<unk>"
11 | }
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/mirrors-prettier
3 | rev: v2.1.2
4 | hooks:
5 | - id: prettier
6 | - repo: https://github.com/psf/black
7 | rev: 22.3.0
8 | hooks:
9 | - id: black
10 | exclude: imports
11 | additional_dependencies: ["platformdirs"]
12 | - repo: https://github.com/PyCQA/isort
13 | rev: 5.11.5
14 | hooks:
15 | - id: isort
16 | exclude: imports
17 |
--------------------------------------------------------------------------------
/examples/kill-text-generation-launcher.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Kills the text generation launcher container if it's running
3 | running_container_id=$(docker container ls | grep -i text-gen-server | cut -d " " -f 1)
4 | if [ -z "$running_container_id" ]; then
5 | echo "TGIS container is not running; nothing to do!"
6 | else
7 | echo "Trying to kill TGIS container with id: {$running_container_id}..."
8 | eval "docker stop $running_container_id"
9 | fi
10 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BertForSequenceClassification/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "clean_up_tokenization_spaces": true,
3 | "cls_token": "[CLS]",
4 | "do_basic_tokenize": true,
5 | "do_lower_case": true,
6 | "mask_token": "[MASK]",
7 | "model_max_length": 512,
8 | "never_split": null,
9 | "pad_token": "[PAD]",
10 | "sep_token": "[SEP]",
11 | "strip_accents": null,
12 | "tokenize_chinese_chars": true,
13 | "tokenizer_class": "BertTokenizer",
14 | "unk_token": "[UNK]"
15 | }
16 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/user_story.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: User story
3 | about: A user-oriented story describing a piece of work to do
4 | title: ""
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | ## Description
10 |
11 | As a <type of user>, I want to <perform some action>, so that I can <achieve some goal>
12 |
13 | ## Discussion
14 |
15 | Provide detailed discussion here
16 |
17 | ## Acceptance Criteria
18 |
19 |
20 |
21 | - [ ] Unit tests cover new/changed code
22 | - [ ] Examples build against new/changed code
23 | - [ ] READMEs are updated
24 | - [ ] Type of [semantic version](https://semver.org/) change is identified
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: ""
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | ## Is your feature request related to a problem? Please describe.
10 |
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | ## Describe the solution you'd like
14 |
15 | A clear and concise description of what you want to happen.
16 |
17 | ## Describe alternatives you've considered
18 |
19 | A clear and concise description of any alternative solutions or features you've considered.
20 |
21 | ## Additional context
22 |
23 | Add any other context about the feature request here.
24 |
--------------------------------------------------------------------------------
/caikit_nlp/model_management/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from .tgis_auto_finder import TGISAutoFinder
17 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/tokenization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from .regex_sentence_splitter import RegexSentenceSplitter
17 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/text_classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from .sequence_classification import SequenceClassification
17 |
--------------------------------------------------------------------------------
/scripts/dump_apis.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Make a directory with interfaces
4 | http_interface_dir="generated_interfaces/http"
5 | grpc_interface_dir="generated_interfaces/grpc"
6 | mkdir -p $http_interface_dir
7 | mkdir -p $grpc_interface_dir
8 |
9 | # Run the HTTP server in the background
10 | RUNTIME_LIBRARY=caikit_nlp python -m caikit.runtime.http_server &
11 | http_pid=$!
12 |
13 | # Sleep for a bit and then call it to get the swagger doc
14 | sleep 5
15 | curl http://localhost:8080/openapi.json | jq > $http_interface_dir/openapi.json
16 |
17 | # Kill the HTTP server and wait for it to die
18 | kill -9 $http_pid
19 | wait
20 |
21 | # Dump the gRPC interfaces
22 | RUNTIME_LIBRARY=caikit_nlp python -m caikit.runtime.dump_services $grpc_interface_dir
--------------------------------------------------------------------------------
/caikit_nlp/modules/token_classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from .filtered_span_classification import FilteredSpanClassification
17 |
--------------------------------------------------------------------------------
/caikit_nlp/resources/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Shared resource implementations used in prompt tuning jobs
16 | """
17 |
18 | # Local
19 | from . import pretrained_model
20 |
--------------------------------------------------------------------------------
/scripts/fmt.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | pre-commit run --all-files
4 | RETURN_CODE=$?
5 |
6 | function echoWarning() {
7 | LIGHT_YELLOW='\033[1;33m'
8 | NC='\033[0m' # No Color
9 | echo -e "${LIGHT_YELLOW}${1}${NC}"
10 | }
11 |
12 | if [ "$RETURN_CODE" -ne 0 ]; then
13 | if [ "${CI}" != "true" ]; then
14 | echoWarning "☝️ This appears to have failed, but actually your files have been formatted."
15 | echoWarning "Make a new commit with these changes before making a pull request."
16 | else
17 | echoWarning "This test failed because your code isn't formatted correctly."
18 | echoWarning 'Locally, run `make run fmt`, it will appear to fail, but change files.'
19 | echoWarning "Add the changed files to your commit and this stage will pass."
20 | fi
21 |
22 | exit $RETURN_CODE
23 | fi
24 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Local
15 | from . import (
16 | text_classification,
17 | text_embedding,
18 | text_generation,
19 | token_classification,
20 | tokenization,
21 | )
22 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BertForSequenceClassification/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "./bert/BertForSequenceClassification",
3 | "architectures": [
4 | "BertForSequenceClassification"
5 | ],
6 | "attention_probs_dropout_prob": 0.1,
7 | "classifier_dropout": null,
8 | "hidden_act": "gelu",
9 | "hidden_dropout_prob": 0.1,
10 | "hidden_size": 32,
11 | "initializer_range": 0.02,
12 | "intermediate_size": 37,
13 | "layer_norm_eps": 1e-12,
14 | "max_position_embeddings": 512,
15 | "model_type": "bert",
16 | "num_attention_heads": 4,
17 | "num_hidden_layers": 5,
18 | "pad_token_id": 0,
19 | "position_embedding_type": "absolute",
20 | "torch_dtype": "float32",
21 | "transformers_version": "4.30.2",
22 | "type_vocab_size": 16,
23 | "use_cache": true,
24 | "vocab_size": 1124
25 | }
26 |
--------------------------------------------------------------------------------
/caikit_nlp/data_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Common data model containing all data structures that are passed in and out of blocks.
15 | """
16 |
17 | # Local
18 | from . import generation
19 | from .generation import *
20 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ""
5 | labels: ""
6 | assignees: ""
7 | ---
8 |
9 | ## Describe the bug
10 |
11 | A clear and concise description of what the bug is.
12 |
13 | ## Platform
14 |
15 | Please provide details about the environment you are using, including the following:
16 |
17 | - Interpreter version:
18 | - Library version:
19 |
20 | ## Sample Code
21 |
22 | Please include a minimal sample of the code that will (if possible) reproduce the bug in isolation
23 |
24 | ## Expected behavior
25 |
26 | A clear and concise description of what you expected to happen.
27 |
28 | ## Observed behavior
29 |
30 | What you see happening (error messages, stack traces, etc...)
31 |
32 | ## Additional context
33 |
34 | Add any other context about the problem here.
35 |
--------------------------------------------------------------------------------
/.github/workflows/build-image.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 | branches: ["main", "release-*"]
4 | paths:
5 | - "caikit_nlp"
6 | - "README.md"
7 | - "pyproject.toml"
8 | - "Dockerfile"
9 |
10 | pull_request:
11 |
12 | name: Build Image
13 |
14 | jobs:
15 | build-image:
16 | name: Build Image
17 | runs-on: ubuntu-latest
18 | steps:
19 | - uses: actions/checkout@v3
20 | - name: Reclaim space
21 | run: |
22 | sudo rm -rf /opt/hostedtoolcache
23 | - name: Set up Docker Buildx
24 | uses: docker/setup-buildx-action@v3
25 | - name: Build image
26 | uses: docker/build-push-action@v5
27 | with:
28 | context: .
29 | tags: "caikit-nlp:latest"
30 | load: true
31 | cache-from: type=gha
32 | cache-to: type=gha,mode=max
33 |
--------------------------------------------------------------------------------
/tests/fixtures/data_model/sample_objects.py:
--------------------------------------------------------------------------------
1 | # First Party
2 | from caikit.interfaces.nlp.data_model import FinishReason, GeneratedTextResult
3 |
4 | generated_response = GeneratedTextResult(
5 | generated_text="foo bar",
6 | generated_tokens=2,
7 | finish_reason=FinishReason.STOP_SEQUENCE,
8 | )
9 |
10 | # Add an example of one of each new data model types that you've defined within your extension
11 | # to the list below. These samples are used for verifying that your object is well-aligned with
12 | # caikit serialization interfaces. They will only be used if your extension enables
13 | # protobuf serialization in its config.
14 | #
15 | # NOTE: You do not need to add any samples from other extensions or the caikit library; only
16 | # new types explicitly created in your extension data model.
17 | data_model_samples = [
18 | generated_response,
19 | ]
20 |
--------------------------------------------------------------------------------
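Following the comment in `sample_objects.py` above, here is a minimal sketch (not part of the repo) of how another type defined in this extension could be added to `data_model_samples`; `ExponentialDecayLengthPenalty` and its fields are taken from `caikit_nlp.data_model` as exercised in `tests/data_model/test_generation.py`.

```python
# Sketch only: append another extension-defined data model sample so it is covered
# by the serialization checks described in the comment above.
from caikit_nlp.data_model import ExponentialDecayLengthPenalty

decay_penalty_sample = ExponentialDecayLengthPenalty(start_index=1, decay_factor=0.95)
data_model_samples.append(decay_penalty_sample)
```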
/caikit_nlp/modules/text_generation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from .peft_prompt_tuning import PeftPromptTuning
17 | from .peft_tgis_remote import PeftPromptTuningTGIS
18 | from .text_generation_local import TextGeneration
19 | from .text_generation_tgis import TextGenerationTGIS
20 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/T5ForConditionalGeneration/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "DUMMY_MODELS/t5/T5ForConditionalGeneration",
3 | "architectures": ["T5ForConditionalGeneration"],
4 | "bos_token_id": 0,
5 | "d_ff": 37,
6 | "d_kv": 8,
7 | "d_model": 32,
8 | "decoder_start_token_id": 0,
9 | "dense_act_fn": "relu",
10 | "dropout_rate": 0.1,
11 | "eos_token_id": 1,
12 | "feed_forward_proj": "relu",
13 | "initializer_factor": 0.002,
14 | "is_encoder_decoder": true,
15 | "is_gated_act": false,
16 | "layer_norm_epsilon": 1e-6,
17 | "model_type": "t5",
18 | "num_decoder_layers": 5,
19 | "num_heads": 4,
20 | "num_layers": 5,
21 | "pad_token_id": 0,
22 | "relative_attention_max_distance": 128,
23 | "relative_attention_num_buckets": 8,
24 | "torch_dtype": "float32",
25 | "transformers_version": "4.27.1",
26 | "use_cache": true,
27 | "vocab_size": 1302
28 | }
29 |
--------------------------------------------------------------------------------
/caikit_nlp/resources/pretrained_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Resources holding pretrained models of various types
16 | """
17 |
18 | # Local
19 | from .base import PretrainedModelBase
20 | from .hf_auto_causal_lm import HFAutoCausalLM
21 | from .hf_auto_seq2seq_lm import HFAutoSeq2SeqLM
22 | from .hf_auto_seq_classifier import HFAutoSequenceClassifier
23 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/BloomForCausalLM/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "apply_residual_connection_post_layernorm": false,
3 | "architectures": ["BloomForCausalLM"],
4 | "attention_dropout": 0.1,
5 | "bos_token_id": 1,
6 | "dtype": "float32",
7 | "eos_token_id": 2,
8 | "gradient_checkpointing": false,
9 | "hidden_dropout": 0.1,
10 | "hidden_size": 32,
11 | "id2label": {
12 | "0": "LABEL_0",
13 | "1": "LABEL_1",
14 | "2": "LABEL_2"
15 | },
16 | "initializer_range": 0.02,
17 | "is_decoder": true,
18 | "label2id": {
19 | "LABEL_0": 0,
20 | "LABEL_1": 1,
21 | "LABEL_2": 2
22 | },
23 | "layer_norm_epsilon": 1e-5,
24 | "model_type": "bloom",
25 | "n_head": 4,
26 | "n_layer": 5,
27 | "n_positions": 512,
28 | "pad_token_id": 3,
29 | "pretraining_tp": 1,
30 | "seq_length": 7,
31 | "slow_but_exact": true,
32 | "torch_dtype": "float32",
33 | "transformers_version": "4.27.1",
34 | "type_vocab_size": 16,
35 | "use_cache": true,
36 | "vocab_size": 1024
37 | }
38 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/text_embedding/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | def env_val_to_bool(val):
17 | """Returns the bool value of env var"""
18 | if val is None:
19 | return False
20 | if isinstance(val, bool):
21 | return val
22 |
23 | # For testing env vars for values that mean false (else True!)
24 | return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "")
25 |
--------------------------------------------------------------------------------
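A few illustrative checks (a sketch, not part of the repo) of how `env_val_to_bool` above treats common environment-variable values; only the "false" spellings in the return statement, plus `None` and the empty string, evaluate to `False`.

```python
# Sanity checks for env_val_to_bool as defined above.
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool

assert env_val_to_bool(None) is False      # unset env var
assert env_val_to_bool("") is False        # empty string counts as false
assert env_val_to_bool(" OFF ") is False   # case-insensitive, whitespace stripped
assert env_val_to_bool("0") is False
assert env_val_to_bool(False) is False     # bools pass through unchanged
assert env_val_to_bool("1") is True        # anything not in the "false" list is true
assert env_val_to_bool("yes") is True
```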
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | This sets up global test configs when pytest starts
16 | """
17 |
18 | # Standard
19 | import os
20 |
21 | # First Party
22 | import alog
23 |
24 | # Configure logging from the environment
25 | alog.configure(
26 | default_level=os.environ.get("LOG_LEVEL", "off"),
27 | filters=os.environ.get("LOG_FILTERS", "urllib3:off"),
28 | thread_id=os.environ.get("LOG_THREAD_ID", "") == "true",
29 | )
30 |
--------------------------------------------------------------------------------
/caikit_nlp/config/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Standard
16 | from pathlib import Path
17 |
18 | # First Party
19 | import alog
20 | import caikit
21 |
22 | log = alog.use_channel("CONFIG_INIT")
23 |
24 | # The name for an extension is simply the name of the directory containing its config dir.
25 | extension_name = Path(__file__).parent.parent.name
26 |
27 | MODEL_MANAGER = caikit.core.MODEL_MANAGER
28 |
29 | extract = MODEL_MANAGER.extract
30 | load = MODEL_MANAGER.load
31 | resolve_and_load = MODEL_MANAGER.resolve_and_load
32 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/README.md:
--------------------------------------------------------------------------------
1 | ### Tiny Models for Testing
2 |
3 | The models in this directory were created using the [Transformers utils for creating tiny models](https://github.com/huggingface/transformers/blob/main/utils/create_dummy_models.py).
4 |
5 | To create a new dummy model, all you need to do is clone the transformers repo and run a command like those shown below, which were used to create the artifacts checked in here.
6 |
7 | - Bloom: `python3 utils/create_dummy_models.py --model_types bloom $OUTPUT_DIR`
8 | - T5: `python3 utils/create_dummy_models.py --model_types t5 $OUTPUT_DIR`
9 | - BERT: `python3 utils/create_dummy_models.py --model_types bert $OUTPUT_DIR`
10 |
11 | This will create several dummy models; you most likely want to place the ones you need in this directory and reference them in `__init__.py` for the fixtures.
12 |
13 | Note: If you encounter any strangeness when running the script above, check the version of transformers in your site packages; the script relies on some dynamic imports that may pull from your site packages instead of the cloned source, which can cause problems if the two versions differ significantly.
14 |
--------------------------------------------------------------------------------
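As a quick sanity check (a sketch, not part of the repo), a checked-in tiny model can be loaded directly with `transformers` from the fixtures directory:

```python
# Sketch: load the tiny BloomForCausalLM fixture with transformers.
# The relative path assumes the command is run from the repository root.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("tests/fixtures/tiny_models/BloomForCausalLM")
print(model.config.model_type)   # "bloom"
print(model.num_parameters())    # small enough to load quickly in tests
```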
/.github/workflows/publish-library.yml:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: Publish
16 |
17 | on:
18 | release:
19 | types: [published]
20 |
21 | jobs:
22 | build:
23 | runs-on: ubuntu-latest
24 | steps:
25 | - uses: actions/checkout@v3
26 | - name: Set up Python
27 | uses: actions/setup-python@v3
28 | - name: Build and check package
29 | run: |
30 | pip install tox
31 | tox -e build,twinecheck
32 | - name: Upload package
33 | if: github.event_name == 'release'
34 | uses: pypa/gh-action-pypi-publish@release/v1
35 | with:
36 | password: ${{ secrets.PYPI_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.github/workflows/lint-code.yml:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: Lint and Format
16 |
17 | on:
18 | push:
19 | branches: ["main", "release-*"]
20 | pull_request:
21 | branches: ["main", "release-*"]
22 |
23 | jobs:
24 | build:
25 | runs-on: ubuntu-latest
26 | steps:
27 | - uses: actions/checkout@v3
28 | - name: Set up Python 3.9
29 | uses: actions/setup-python@v4
30 | with:
31 | python-version: 3.9
32 | - name: Install dependencies
33 | run: |
34 | python -m pip install --upgrade pip
35 | python -m pip install -r setup_requirements.txt
36 | - name: Check Formatting
37 | run: tox -e fmt
38 | - name: Run pylint
39 | run: tox -e lint
40 |
--------------------------------------------------------------------------------
/caikit_nlp/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Caikit prompt tuning library
15 | """
16 | # Standard
17 | import os
18 |
19 | # First Party
20 | from caikit.core.model_manager import *
21 |
22 | # Import the model management semantics from the core
23 | import caikit
24 |
25 | # Local
26 | # Import subpackages
27 | from . import config, data_model, model_management
28 | from .config import *
29 | from .data_model import *
30 | from .modules import *
31 | from .resources import *
32 | from .version import __version__, __version_tuple__
33 |
34 | # Configure the library with library-specific configuration file
35 | CONFIG_PATH = os.path.realpath(
36 | os.path.join(os.path.dirname(__file__), "config", "config.yml")
37 | )
38 |
39 | caikit.configure(CONFIG_PATH)
40 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "setuptools>=60",
4 | "setuptools-scm>=8.0"]
5 |
6 | [project]
7 | name = "caikit-nlp"
8 | dynamic = ["version"]
9 | description = "Caikit NLP"
10 | license = {text = "Apache-2.0"}
11 | readme = "README.md"
12 | requires-python = "~=3.9"
13 | classifiers=[
14 | "License :: OSI Approved :: Apache Software License"
15 | ]
16 | dependencies = [
17 | "caikit[runtime-grpc,runtime-http]>=0.26.34,<0.29.0",
18 | "caikit-tgis-backend>=0.1.36,<0.2.0",
19 | # TODO: loosen dependencies
20 | "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
21 | "grpcio-reflection>=1.62.2",
22 | "grpcio-health-checking>=1.62.2",
23 | "accelerate>=1.3.0",
24 | "datasets>=3.3.0",
25 | "huggingface-hub",
26 | "numpy>=1.23.0,<2",
27 | "pandas>=2.2.3",
28 | "scikit-learn>=1.6.1",
29 | "scipy>=1.10.0",
30 | "sentence-transformers>=3.4.0,<3.5.0",
31 | "tokenizers>=0.20.0",
32 | "torch>=2.3.1,<2.9.0",
33 | "tqdm>=4.67.0",
34 | "transformers>=4.48.3,<4.50.0",
35 | "peft==0.14.0",
36 | ]
37 |
38 |
39 | [tool.setuptools.packages.find]
40 | exclude = ["tests", "tests.*"]
41 | namespaces = false
42 |
43 |
44 | [tool.setuptools_scm]
45 | version_file = "caikit_nlp/_version.py"
46 |
47 | [project.urls]
48 | Source = "https://github.com/caikit/caikit-nlp"
49 |
--------------------------------------------------------------------------------
/runtime_template/run_with_gateway.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | ################################################################################
4 | # This script is the entrypoint for the multi-process runtime container that
5 | # runs the REST gateway alongside the grpc runtime.
6 | # The multiprocess management is intended to be handled by `tini`, the tiny
7 | # but valid `init`.
8 | ################################################################################
9 |
10 | set -e
11 |
12 | echo '[STARTING RUNTIME]'
13 | cd /app && python3 -m caikit.runtime.grpc_server &
14 |
15 | RUNTIME_PORT=${RUNTIME_PORT:-8085}
16 |
17 | # If TLS enabled, make an https call, otherwise make an http call
18 | protocol="http"
19 | if [ "${RUNTIME_TLS_SERVER_KEY}" != "" ] && [ "${RUNTIME_TLS_SERVER_CERT}" != "" ]
20 | then
21 | protocol="--cacert $RUNTIME_TLS_SERVER_CERT https"
22 | if [ "${RUNTIME_TLS_CLIENT_CERT}" != "" ]
23 | then
24 | protocol="-k --cert $RUNTIME_TLS_SERVER_CERT --key $RUNTIME_TLS_SERVER_KEY https"
25 | fi
26 | fi
27 |
28 | # Wait for the Runtime to come up before starting the gateway
29 | sleep 3
30 | until $(curl --output /dev/null --silent --fail ${protocol}://localhost:${RUNTIME_PORT}); do
31 | echo '.'
32 | sleep 1
33 | done
34 |
35 | echo '[STARTING GATEWAY]'
36 | PROXY_ENDPOINT="localhost:${RUNTIME_PORT}" SERVE_PORT=${GATEWAY_PORT:-8080} /gateway --swagger_path=/swagger &
37 |
38 | wait -n
39 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/text_embedding/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Text Embedding Module
17 | =====================
18 |
19 | Implements the following tasks:
20 |
21 | 1. EmbeddingTask: Returns an embedding from an input text string
22 | 2. EmbeddingsTasks: EmbeddingTask but with a list of inputs producing a list of outputs
23 | 3. SentenceSimilarityTask: Compare one source sentence to a list of sentences
24 | 4. SentenceSimilarityTasks: SentenceSimilarityTask but with a list of source sentences producing
25 | a list of outputs
26 | 5. RerankTask: Return top_n documents ordered by relevance given a query
27 | 6. RerankTasks: RerankTask but with a list of queries producing a list of outputs
28 |
29 | """
30 |
31 | # Local
32 | from .crossencoder import CrossEncoderModule
33 | from .embedding import EmbeddingModule
34 |
--------------------------------------------------------------------------------
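For orientation, a hypothetical usage sketch of the embedding module exported above; the `bootstrap` and `run_embedding` names are assumptions based on the task list in the docstring, so check the module source for the actual signatures.

```python
# Hypothetical sketch only: method names (bootstrap, run_embedding) are assumed from
# the task list above and may differ from the real module API.
from caikit_nlp.modules.text_embedding import EmbeddingModule

# Wrap a sentence-transformers model (model name here is just an example)
module = EmbeddingModule.bootstrap("sentence-transformers/all-MiniLM-L6-v2")

# EmbeddingTask: one input text -> one embedding vector
result = module.run_embedding(text="Hello world")
print(result)
```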
/tox.ini:
--------------------------------------------------------------------------------
1 | [tox]
2 | envlist = py, lint, fmt
3 |
4 | [testenv]
5 | description = run tests with pytest with coverage
6 | deps =
7 | evaluate # Currently used for sample scripts etc.
8 | pytest==7.1.3
9 | pytest-cov>=2.10.1,<3.0
10 | pytest-html>=3.1.1,<4.0
11 | tf-keras>=2.18.0
12 | wheel>=0.38.4
13 | passenv =
14 | LOG_LEVEL
15 | LOG_FILTERS
16 | LOG_FORMATTER
17 | LOG_THREAD_ID
18 | LOG_CHANNEL_WIDTH
19 | PYTORCH_ENABLE_MPS_FALLBACK
20 | commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests}
21 |
22 | ; Unclear: We probably want to test wheel packaging
23 | ; But! tox will fail when this is set and _any_ interpreter is missing
24 | ; Without this, sdist packaging is tested so that's a start.
25 | package=wheel
26 |
27 | [testenv:fmt]
28 | description = format with pre-commit
29 | deps = pre-commit>=3.0.4,<4.0
30 | commands = ./scripts/fmt.sh
31 | allowlist_externals = ./scripts/fmt.sh
32 | skip_install = True # Skip package install since fmt doesn't need to execute code, for ⚡⚡⚡
33 |
34 | [testenv:lint]
35 | description = lint with pylint
36 | deps = pylint>=2.16.2,<3.0
37 | commands = pylint caikit_nlp
38 |
39 | [testenv:build]
40 | description = build wheel
41 | deps =
42 | build
43 | commands = python -m build
44 | skip_install = True
45 |
46 | [testenv:twinecheck]
47 | description = check wheel
48 | deps =
49 | twine
50 | commands = twine check dist/*
51 | skip_install = True
52 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM registry.access.redhat.com/ubi9/ubi-minimal:latest as base
2 |
3 | RUN microdnf update -y && \
4 | microdnf install -y \
5 | git python-pip && \
6 | pip install --upgrade --no-cache-dir pip wheel && \
7 | microdnf clean all
8 |
9 | FROM base as builder
10 | WORKDIR /build
11 |
12 | RUN pip install --no-cache tox
13 | COPY README.md .
14 | COPY pyproject.toml .
15 | COPY tox.ini .
16 | COPY caikit_nlp caikit_nlp
17 | # .git is required for setuptools-scm get the version
18 | RUN --mount=source=.git,target=.git,type=bind \
19 | --mount=type=cache,target=/root/.cache/pip \
20 | tox -e build
21 |
22 |
23 | FROM base as deploy
24 |
25 | RUN python -m venv --upgrade-deps /opt/caikit/
26 |
27 | ENV VIRTUAL_ENV=/opt/caikit
28 | ENV PATH="$VIRTUAL_ENV/bin:$PATH"
29 |
30 | COPY --from=builder /build/dist/caikit_nlp*.whl /tmp/
31 | RUN --mount=type=cache,target=/root/.cache/pip \
32 | pip install /tmp/caikit_nlp*.whl && \
33 | rm /tmp/caikit_nlp*.whl
34 |
35 | COPY LICENSE /opt/caikit/
36 | COPY README.md /opt/caikit/
37 |
38 | RUN groupadd --system caikit --gid 1001 && \
39 | adduser --system --uid 1001 --gid 0 --groups caikit \
40 | --home-dir /caikit --shell /sbin/nologin \
41 | --comment "Caikit User" caikit
42 |
43 | USER caikit
44 |
45 | ENV RUNTIME_LIBRARY=caikit_nlp
46 | # Optional: use `CONFIG_FILES` and the /caikit/ volume to explicitly provide a configuration file and models
47 | # ENV CONFIG_FILES=/caikit/caikit.yml
48 | VOLUME ["/caikit/"]
49 | WORKDIR /caikit
50 |
51 | CMD ["python"]
52 |
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Caikit NLP Runtime Performance Benchmarks
2 |
3 | Runtime performance benchmarking results for various models on various hardware configurations.
4 |
5 | ## Llama2-7b
6 |
7 | | Date Executed | Hardware | Training Set | Epoch | Precision | Batch Size | Max Source Length | Training Runtime (s) | Samples Per Second | Train Steps Per Second | Loss | Notes |
8 | |---|---|---------------|---|---|:---:|---|------------| --- |---|---|---|
9 | | [2023-09-05](./logs/llama2-7b/20230905_183655.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 4096 | 350 | 21.325 | 0.22 | 1.65 | 4096 is the context size for Llama2 |
10 | | [2023-09-05](./logs/llama2-7b/20230905_184809.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 1024 | 350 | 21.333 | 0.22 | 1.65 | batch size of 7 fails CUDA OOM |
11 | | [2023-09-06](./logs/llama2-7b/20230906_135211.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 6 | 512 | 348 | 21.44 | 0.22 | 1.65 | batch size of 7 fails CUDA OOM |
12 | | [2023-09-05](./logs/llama2-7b/20230905_194133.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 8 | 256 | 356 | 20.939 | 0.16 | 1.70 | batch size of 9 fails CUDA OOM |
13 | | [2023-09-05](./logs/llama2-7b/20230905_191650.output) | 1 x A100 80GB | [Glue / RTE](https://huggingface.co/datasets/glue) | 1 | bfloat16 | 19 | 128 | 254 | 29.332 | 0.09 | 1.94 | batch size of 20 fails CUDA OOM |
14 |
--------------------------------------------------------------------------------
/.github/workflows/build-library.yml:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | name: Build Caikit NLP Library
16 |
17 | on:
18 | push:
19 | branches: ["main", "release-*"]
20 | pull_request:
21 | branches: ["main", "release-*"]
22 |
23 | jobs:
24 | build:
25 | runs-on: ubuntu-latest
26 | strategy:
27 | matrix:
28 | python-version:
29 | # NOTE: TGIS does not support 3.8 or 3.11
30 | - setup: "3.9"
31 | tox: "py39"
32 | - setup: "3.10"
33 | tox: "py310"
34 |
35 | steps:
36 | - uses: actions/checkout@v3
37 | - name: Set up Python ${{ matrix.python-version.setup }}
38 | uses: actions/setup-python@v4
39 | with:
40 | python-version: ${{ matrix.python-version.setup }}
41 | - name: Install dependencies
42 | run: |
43 | python -m pip install --upgrade pip
44 | python -m pip install -r setup_requirements.txt
45 | - name: Build and test with tox
46 | run: tox -e ${{ matrix.python-version.tox }}
47 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/task_specific_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # First Party
16 | from caikit.core.exceptions import error_handler
17 | from caikit.interfaces.nlp.data_model import ClassificationTrainRecord
18 | import alog
19 |
20 | # Local
21 | from ..data_model import GenerationTrainRecord
22 |
23 | log = alog.use_channel("TASK_UTILS")
24 | error = error_handler.get(log)
25 |
26 |
27 | def convert_to_generation_record(train_record):
28 | if isinstance(train_record, GenerationTrainRecord):
29 | return train_record
30 | if isinstance(train_record, ClassificationTrainRecord):
31 | text = train_record.text
32 | labels = ",".join(str(label) for label in train_record.labels)
33 | return GenerationTrainRecord(input=text, output=labels)
34 | error(
35 | "",
36 | TypeError(
37 | "Unsupported instance type. \
38 | Only instances of datamodels ClassificationTrainRecord \
39 | and GenerationTrainRecord are supported"
40 | ),
41 | )
42 |
--------------------------------------------------------------------------------
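A short illustrative usage of `convert_to_generation_record` (a sketch, not from the repo); the field names follow how the function above accesses the records.

```python
# Illustrative only: ClassificationTrainRecord fields (text, labels) are used exactly
# as the function above reads them; GenerationTrainRecord exposes input/output.
from caikit.interfaces.nlp.data_model import ClassificationTrainRecord
from caikit_nlp.data_model import GenerationTrainRecord
from caikit_nlp.toolkit.task_specific_utils import convert_to_generation_record

record = ClassificationTrainRecord(text="great movie", labels=["positive"])
gen_record = convert_to_generation_record(record)
assert isinstance(gen_record, GenerationTrainRecord)
assert gen_record.input == "great movie"
assert gen_record.output == "positive"

# GenerationTrainRecord instances pass through unchanged
same = convert_to_generation_record(GenerationTrainRecord(input="a", output="b"))
assert same.input == "a"
```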
/tests/modules/tokenization/test_regex_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | """Tests for regex sentence splitter
2 | """
3 | # Standard
4 | import os
5 | import tempfile
6 |
7 | # First Party
8 | from caikit.interfaces.nlp.data_model import TokenizationResults
9 |
10 | # Local
11 | from caikit_nlp.modules.tokenization.regex_sentence_splitter import (
12 | RegexSentenceSplitter,
13 | )
14 |
15 | ## Setup ########################################################################
16 |
17 | # Regex sentence splitter model for reusability across tests
18 | REGEX_STR = "[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
19 | SENTENCE_TOKENIZER = RegexSentenceSplitter.bootstrap(REGEX_STR)
20 | DOCUMENT = "What he told me before, I have it in my heart. I am tired of fighting."
21 |
22 | ## Tests ########################################################################
23 |
24 |
25 | def test_bootstrap_and_run():
26 | """Check if we can bootstrap and run regex sentence splitter"""
27 | tokenization_result = SENTENCE_TOKENIZER.run(DOCUMENT)
28 | assert isinstance(tokenization_result, TokenizationResults)
29 | assert len(tokenization_result.results) == 2
30 |
31 |
32 | def test_save_load_and_run_model():
33 | """Check if we can run a saved model successfully"""
34 | with tempfile.TemporaryDirectory() as model_dir:
35 | SENTENCE_TOKENIZER.save(model_dir)
36 | assert os.path.exists(os.path.join(model_dir, "config.yml"))
37 |
38 | new_splitter = RegexSentenceSplitter.load(model_dir)
39 | tokenization_result = new_splitter.run(DOCUMENT)
40 | assert isinstance(tokenization_result, TokenizationResults)
41 | assert len(tokenization_result.results) == 2
42 |
--------------------------------------------------------------------------------
/tests/data_model/test_generation.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Local
16 | from caikit_nlp.data_model import ExponentialDecayLengthPenalty
17 |
18 | ## Setup #########################################################################
19 |
20 | dummy_exponential_decay_length_penalty = ExponentialDecayLengthPenalty(
21 | start_index=1, decay_factor=0.95
22 | )
23 |
24 | ## Tests ########################################################################
25 |
26 | ### Exponential Decay Length Penalty
27 | def test_exponential_decay_length_penalty_all_fields_accessible():
28 | assert dummy_exponential_decay_length_penalty.start_index == 1
29 | assert dummy_exponential_decay_length_penalty.decay_factor == 0.95
30 |
31 |
32 | def test_sampling_parameters_from_proto_and_back():
33 | new = ExponentialDecayLengthPenalty.from_proto(
34 | dummy_exponential_decay_length_penalty.to_proto()
35 | )
36 | assert new.start_index == 1
37 | assert new.decay_factor == 0.95
38 |
39 |
40 | def test_sampling_parameters_from_json_and_back():
41 | new = ExponentialDecayLengthPenalty.from_json(
42 | dummy_exponential_decay_length_penalty.to_json()
43 | )
44 | assert new.start_index == 1
45 | assert new.decay_factor == 0.95
46 |
--------------------------------------------------------------------------------
/tests/toolkit/test_data_type_utils.py:
--------------------------------------------------------------------------------
1 | """Tests for data type related utils, e.g., for interacting with serialized torch types.
2 | """
3 | # Third Party
4 | import pytest
5 | import torch
6 |
7 | # Local
8 | from caikit_nlp.toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype
9 |
10 | ### Tests for converting from strings / types / None -> torch data types
11 |
12 |
13 | def test_get_torch_dtype_from_str():
14 | """Ensure that we can parse a data type from a string."""
15 | assert torch.float32 is get_torch_dtype("float32")
16 |
17 |
18 | def test_get_torch_dtype_from_dtype():
19 | """Ensure that if a data type is provided from pytorch, we simply return it."""
20 | assert torch.float32 is get_torch_dtype(torch.float32)
21 |
22 |
23 | def test_get_torch_dtype_from_bad_type():
24 | """Ensure that if a type we can't coerce to a pytorch dtype is given, we get a TypeError."""
25 | with pytest.raises(TypeError):
26 | assert torch.float32 is get_torch_dtype(100)
27 |
28 |
29 | def test_get_torch_dtype_from_bad_str():
30 | """Ensure that if an invalid type string is given, we get a ValueError."""
31 | with pytest.raises(ValueError):
32 | assert torch.float32 is get_torch_dtype("not a valid attr of pytorch")
33 |
34 |
35 | ### Tests for converting from strings -> torch data types
36 | def test_str_to_torch_dtype():
37 | """Ensure that we can parse a data type from a string."""
38 | assert str_to_torch_dtype("float32") is torch.float32
39 |
40 |
41 | def test_str_to_torch_dtype_invalid_attr():
42 | """Ensure that we raise ValueError if an incorrect type str is provided."""
43 | with pytest.raises(ValueError):
44 | str_to_torch_dtype("not a valid attr of pytorch")
45 |
46 |
47 | def test_str_to_torch_dtype_bad_attr():
48 | """Ensure that we raise ValueError if a non type property of torch is provided."""
49 | with pytest.raises(ValueError):
50 | str_to_torch_dtype("nn")
51 |
--------------------------------------------------------------------------------
/examples/text-generation-launcher:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # This script is primarily meant for illustrative purposes; if we don't have
3 | # the text-generation-launcher command locally available, but we do have a Docker
4 | # container, we add this script onto our path so when the TGIS backend in caikit
5 | # tries to start the server, it runs this script instead.
6 | #
7 | # NOTE:
8 | # - Model ID, directories, etc. are hardcoded to our example; params from the backend,
9 | # e.g., shard configuration, are ignored.
10 | #
11 | # - We need to expose port 3000 (for probes in core distributed), and we forward 8033->50055
12 | # so that our gRPC server is exposed on the expected port for local TGIS.
13 | TGIS_MODEL="${MODEL_NAME:-bigscience/bloom-560m}"
14 | MODEL_DIR="${MODEL_DIR:-models}"
15 | echo "Running TGIS with model: $TGIS_MODEL"
16 |
17 | docker run --rm \
18 | --gpus '"device=0"' \
19 | -p 8090:8090 \
20 | -p 8085:8085 \
21 | -p 8060:8060 \
22 | -p 8087:8087 \
23 | -p 50055:8033 \
24 | -p 3000:3000 \
25 | -v $(pwd)/${MODEL_DIR}:/models \
26 | -v $(pwd)/../runtime_config.yaml:/conf/runtime_config.yaml \
27 | -v $(pwd)/transformers_cache:/shared_model_storage/transformers_cache \
28 | -v $(pwd)/prompt_prefixes:/prompt_prefixes \
29 | -e LOG_LEVEL=debug3 \
30 | -e ACCEPT_LICENSE=true \
31 | -e INFERENCE_PLUGIN_MODEL_MESH_MAX_MODEL_CONCURRENCY=10 \
32 | -e RUNTIME_SERVER_THREAD_POOL_SIZE=10 \
33 | -e INFERENCE_PLUGIN_MODEL_MESH_CAPACITY=28000000000 \
34 | -e INFERENCE_PLUGIN_MODEL_MESH_DEFAULT_MODEL_SIZE=1773741824 \
35 | -e CONFIG_FILES="/conf/runtime_config.yaml" \
36 | -e RUNTIME_LOCAL_MODELS_DIR="/models/" \
37 | -e MAX_BATCH_SIZE=8 \
38 | -e MAX_SEQUENCE_LENGTH=2048 \
39 | -e NUM_GPUS=1 \
40 | -e TRANSFORMERS_CACHE="/shared_model_storage/transformers_cache" \
41 | -e HUGGINGFACE_HUB_CACHE="/shared_model_storage/transformers_cache" \
42 | -e MAX_CONCURRENT_REQUESTS=64 \
43 | -e GATEWAY_PORT=8060 \
44 | -e RUNTIME_PORT=8087 \
45 | -e MODEL_NAME=$TGIS_MODEL \
46 | -e PREFIX_STORE_PATH="/prompt_prefixes" \
47 | --user root \
48 | text-gen-server:server-release_ubi8_py38
49 |
--------------------------------------------------------------------------------
/caikit_nlp/resources/pretrained_model/hf_auto_seq_classifier.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Huggingface auto sequence classifier resource type
16 | """
17 | # Standard
18 | from typing import Callable, Tuple
19 |
20 | # Third Party
21 | from transformers import AutoModelForSequenceClassification
22 | from transformers.models.auto import modeling_auto
23 |
24 | # First Party
25 | from caikit.core.modules import module
26 |
27 | # Local
28 | from .base import PretrainedModelBase
29 |
30 |
31 | @module(
32 | id="6759e891-287b-405b-bd8b-54a4a4d51c23",
33 | name="HF Transformers Auto Sequence Classifier",
34 | version="0.1.0",
35 | )
36 | class HFAutoSequenceClassifier(PretrainedModelBase):
37 | """This resource (module) wraps a handle to a huggingface
38 | AutoModelForSequenceClassification
39 | """
40 |
41 | MODEL_TYPE = AutoModelForSequenceClassification
42 | SUPPORTED_MODEL_TYPES = (
43 | modeling_auto.MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES
44 | )
45 | TASK_TYPE = "SEQ_CLS"
46 | PROMPT_OUTPUT_TYPES = []
47 | MAX_NUM_TRANSFORMERS = 1
48 |
49 | @classmethod
50 | def bootstrap(cls, *args, **kwargs) -> "HFAutoSequenceClassifier":
51 | """Bootstrap from a huggingface model
52 |
53 | See help(PretrainedModelBase)
54 | """
55 | return super().bootstrap(*args, return_dict=True, **kwargs)
56 |
57 | @staticmethod
58 | def tokenize_function(*args, **kwargs) -> Tuple[Callable, bool]:
59 | raise NotImplementedError(
60 | "Tokenize func not implemented for sequence classifier"
61 | )
62 |
--------------------------------------------------------------------------------
/runtime_config.yaml:
--------------------------------------------------------------------------------
1 | # This file's contents configure the TGIS server & caikit
2 | jvm_options: []
3 |
4 | runtime:
5 | library: caikit_nlp
6 | lazy_load_local_models: true
7 | batching:
8 | standalone-model:
9 | size: 0 # Set to batch size for batching
10 |
11 | model_management:
12 | finders:
13 | default:
14 | type: LOCAL
15 | remote_tgis:
16 | type: TGIS-AUTO
17 | config:
18 | test_connection: true
19 | initializers:
20 | default:
21 | type: LOCAL
22 | config:
23 | backend_priority:
24 | - type: TGIS
25 | config:
26 | local:
27 | load_timeout: 120
28 | grpc_port: null
29 | http_port: null
30 | health_poll_delay: 1.0
31 | remote_models:
32 | flan-t5-xl:
33 | hostname: localhost:8033
34 | prompt_dir: tgis_prompts
35 | llama-70b:
36 | hostname: localhost:8034
37 | prompt_dir: tgis_prompts
38 |
39 | connection:
40 | hostname: "foo.{model_id}:1234"
41 | ca_cert_file: null
42 | client_cert_file: null
43 | client_key_file: null
44 |
45 | # Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
46 | embedding:
47 | # Allow models with remote code.
48 | trust_remote_code: false
49 | # Number of times to retry on error. Most deployments should use 0 retries.
50 | retries: 0
51 |   # Batch size for encode(); if <= 0 or invalid, the sentence-transformers default is used
52 | batch_size: 0
53 | # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this
54 | implicit_truncation_errors: true
55 | # Attempt to optimize with PyTorch compile()
56 | pt2_compile: false
57 | # Use IPEX optimize. Works best when used with autocast (bfloat16) below.
58 | ipex: false
59 | # Use autocast in encode with its default dtype (bfloat16)
60 | autocast: false
61 | # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU.
62 | # Otherwise, the default does automatic checks for cuda GPU (else cpu).
63 | device: ""
64 |
--------------------------------------------------------------------------------
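
The same settings can also be supplied programmatically through caikit.configure, as the example scripts in this repo do. A minimal sketch follows; the model id and hostname are placeholders, not values shipped with this config, and the deep-merge of config_dict over the file-based config is assumed:

    # Hypothetical override mirroring the remote_models block above.
    import caikit

    caikit.configure(
        config_dict={
            "model_management": {
                "initializers": {
                    "default": {
                        "config": {
                            "backend_priority": [
                                {
                                    "type": "TGIS",
                                    "config": {
                                        "remote_models": {
                                            # placeholder model id / hostname
                                            "my-flan-t5": {"hostname": "localhost:8033"},
                                        },
                                    },
                                },
                            ],
                        },
                    },
                },
            },
        }
    )
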
/examples/load_and_run_distributed_peft.py:
--------------------------------------------------------------------------------
1 | """This script loads and runs a sample PEFT model as a caikit module using
2 | the TGIS backend.
3 |
4 | In a nutshell, this does the following:
5 | - Check if `text-generation-launcher` is defined; if it isn't, assume we have a Docker image
6 | that is (currently hardcoded to be) able to run TGIS & expose the proper ports, and patch a
7 | wrapper script around it onto our path so that the TGIS backend falls back to leveraging it
8 | - Load the model through caikit
9 | - Run an inference text generation request and dump the (garbage) output back to the console
10 | """
11 | # Standard
12 | from shutil import which
13 | import os
14 | import subprocess
15 | import sys
16 |
17 | # First Party
18 | from caikit.core.module_backend_config import _CONFIGURED_BACKENDS, configure
19 | from caikit_tgis_backend import TGISBackend
20 | import alog
21 | import caikit
22 |
23 | # Local
24 | import caikit_nlp
25 |
26 | alog.configure("debug4")
27 |
28 | PREFIX_PATH = "prompt_prefixes"
29 |
30 | has_text_gen = which("text-generation-launcher")
31 | if not which("text-generation-launcher"):
32 | print("Text generation server command not found; using Docker override")
33 | this_dir = os.path.dirname(os.path.abspath(__file__))
34 | os.environ["PATH"] += ":" + this_dir
35 | assert (
36 | which("text-generation-launcher") is not None
37 | ), "Text generation script not found!"
38 |
39 | # Configure caikit to prioritize TGIS backend
40 | _CONFIGURED_BACKENDS.clear()
41 | # load_timeout: 320
42 | # grpc_port: null
43 | # http_port: 3001
44 | # health_poll_delay: 1.0
45 | caikit.configure(
46 | config_dict={"module_backends": {"priority": [TGISBackend.backend_type]}}
47 | ) # should not be necessary but just in case
48 | configure() # backend configure
49 |
50 | # Load with TGIS backend
51 | prefix_model_path = os.path.join(PREFIX_PATH, "sample_prompt")
52 | my_model = caikit.load(prefix_model_path)
53 | sample_text = "@TommyHilfiger Dramatic shopping exp. ordered 6 jeans same size (30/32) 2 fits / 2 too large / 2 too slim : same brand > different sizing"
54 | sample_output = my_model.run(sample_text)
55 |
56 | print("---------- Model result ----------")
57 | print(sample_output)
58 |
--------------------------------------------------------------------------------
/examples/compare_local_vs_tgis_models.py:
--------------------------------------------------------------------------------
1 | # Utils imports should always come first, because they ensure caikit_nlp
2 | # is added to the syspath if not running inside of a container.
3 | # Standard
4 | import json
5 | import time
6 |
7 | # Third Party
8 | from datasets import load_dataset
9 | from utils import SUPPORTED_DATASETS, load_model
10 | import torch
11 |
12 | # First Party
13 | import caikit
14 |
15 | # Local
16 | from caikit_nlp import data_model
17 | import caikit_nlp
18 |
19 | NUM_SAMPLES_TO_RUN = 100
20 |
21 | model_path = "prompt_prefixes/sample_prompt"
22 | # Grab the test stream from the twitter dataset
23 | test_stream = SUPPORTED_DATASETS["twitter_complaints"].dataset_loader()[2]
24 |
25 | # Load the local block
26 | local_model = load_model(is_distributed=False, model_path=model_path)
27 | # Load the TGIS-backed block; this will kill TGIS if a container is already running, then restart it.
28 | distributed_model = load_model(is_distributed=True, model_path=model_path)
29 |
30 | preds = []
31 | for datum in test_stream:
32 | dis_res = distributed_model.run(datum.input)
33 | local_res = local_model.run(datum.input)
34 | preds.append(
35 | {
36 | "input": datum.input,
37 | "output": datum.output,
38 | "local_prediction": local_res.text.split(":")[-1].strip(),
39 | "distributed_prediction": dis_res.text.split(":")[-1].strip(),
40 | }
41 | )
42 | if len(preds) >= NUM_SAMPLES_TO_RUN:
43 | break
44 |
45 |
46 | with open("preds.json", "w") as f:
47 | json.dump(preds, f, sort_keys=True, indent=4)
48 |
49 | num_matching = 0
50 | num_local_correct = 0
51 | num_distributed_correct = 0
52 | num_mismatching = 0
53 | for x in preds:
54 | if x["output"] == x["local_prediction"]:
55 | num_local_correct += 1
56 |
57 | if x["output"] == x["distributed_prediction"]:
58 | num_distributed_correct += 1
59 |
60 | if x["local_prediction"] == x["distributed_prediction"]:
61 | num_matching += 1
62 | else:
63 | num_mismatching += 1
64 |
65 | print("----- Metrics -----")
66 | print("Num correct [local block via PEFT]: {}".format(num_local_correct))
67 | print("Num correct [distributed via TGIS]: {}".format(num_distributed_correct))
68 | print("Num matching remote / local preds: {}".format(num_matching))
69 | print("Num not matching remote / local preds: {}".format(num_mismatching))
70 |
--------------------------------------------------------------------------------
/tests/toolkit/test_task_specific_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Third Party
16 | import pytest
17 |
18 | # First Party
19 | from caikit.interfaces.nlp.data_model import ClassificationTrainRecord
20 |
21 | # Local
22 | from caikit_nlp.data_model import GenerationTrainRecord
23 | from caikit_nlp.toolkit.task_specific_utils import convert_to_generation_record
24 |
25 |
26 | def test_convert_classification_train_record_to_generation_record():
27 | classification_train_record = ClassificationTrainRecord(
28 | text="foo bar", labels=["label1"]
29 | )
30 | generated_train = convert_to_generation_record(classification_train_record)
31 | assert isinstance(generated_train, GenerationTrainRecord)
32 | assert generated_train.input == "foo bar"
33 | assert generated_train.output == "label1"
34 |
35 |
36 | def test_convert_generation_record_to_generation_record():
37 | generation_train_record = GenerationTrainRecord(input="foo bar", output="label1")
38 | generated_train = convert_to_generation_record(generation_train_record)
39 | assert isinstance(generated_train, GenerationTrainRecord)
40 | assert generated_train.input == generation_train_record.input
41 | assert generated_train.output == generation_train_record.output
42 |
43 |
44 | # When we support integer labels
45 | # def test_convert_classification_train_record_to_generation_record_numeric_labels():
46 | # classification_train_record = dm.ClassificationTrainRecord(
47 | # text="foo bar", labels=[1]
48 | # )
49 | # generated_train = convert_to_generation_record(classification_train_record)
50 | # assert isinstance(generated_train, dm.GenerationTrainRecord)
51 | # assert generated_train.input == classification_train_record.text
52 | # assert generated_train.output == "1"
53 |
54 |
55 | def test_convert_to_generation_record_gives_error_with_unsupported_type():
56 | string_record = "test record"
57 | with pytest.raises(TypeError):
58 | convert_to_generation_record(string_record)
59 |
--------------------------------------------------------------------------------
/caikit_nlp/config/config.yml:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | ## Config elements for module boot-time configurations ##
15 |
16 | # Disallow downloading from arbitrary URLs in model bootstrapping. This defaults
17 | # to false for security, but can be enabled for bootstrapping
18 | allow_downloads: false
19 | torch_dtype: float32
20 |
21 | # Path of folder that will contain all the source prompts
22 | source_prompt_base: ""
23 |
24 | # Path for searching base models from
25 | base_models_dir: ""
26 |
27 | # Whether or not to purge TGIS prompts on model deletion
28 | unload_tgis_prompt_artifacts: false
29 | # Torchrun elastic launch configuration, e.g., for fine tuning on multiple GPUs
30 | master_addr: localhost
31 | master_port: 29550
32 |
33 | training_data_limit:
34 | __default__: -1
35 | # Configuration for PeftPromptTuning module
36 | 6655831b-960a-4dc5-8df4-867026e2cd41:
37 | add_model_name_here: 10000
38 |
39 | # Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32
40 | embedding:
41 | # Allow models with remote code.
42 | trust_remote_code: false
43 | # Number of times to retry on error. Most deployments should use 0 retries.
44 | retries: 0
45 | # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used
46 | batch_size: 0
47 | # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this
48 | implicit_truncation_errors: true
49 | # Attempt to optimize with PyTorch compile()
50 | pt2_compile: false
51 | # Use IPEX optimize. Works best when used with autocast (bfloat16) below.
52 | ipex: false
53 | # Use autocast in encode with its default dtype (bfloat16)
54 | autocast: false
55 | # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU.
56 | # Otherwise, the default does automatic checks for cuda GPU (else cpu).
57 | device: ""
58 |
59 | runtime:
60 | library: caikit_nlp
61 |
62 | # Configure request timeout for TGIS backend (in seconds)
63 | tgis_request_timeout: 60
64 |
--------------------------------------------------------------------------------
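
Values from this file are read at runtime through caikit's get_config(); a tiny sketch, assuming caikit_nlp has been imported so that this config.yml is merged into the active config:

    from caikit import get_config

    import caikit_nlp  # noqa: F401  (importing is assumed to merge this config.yml)

    cfg = get_config()
    print(cfg.torch_dtype)                      # "float32"
    print(cfg.training_data_limit.__default__)  # -1, i.e. no global limit
    print(cfg.embedding.retries)                # 0
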
/caikit_nlp/toolkit/data_type_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Standard
16 | from typing import Optional, Union
17 |
18 | # Third Party
19 | import torch
20 |
21 | # First Party
22 | from caikit import get_config
23 | from caikit.core.exceptions import error_handler
24 | import alog
25 |
26 | log = alog.use_channel("DATA_UTIL")
27 | error = error_handler.get(log)
28 |
29 |
30 | def str_to_torch_dtype(dtype_str: str) -> torch.dtype:
31 | """Given a string representation of a Torch data type, convert it to the actual torch dtype.
32 |
33 | Args:
34 | dtype_str: String representation of Torch dtype to be used; this should be an attr
35 | of the torch library whose value is a dtype.
36 |
37 | Returns:
38 | torch.dtype
39 | Data type of the Torch class being used.
40 | """
41 | dt = getattr(torch, dtype_str, None)
42 | if not isinstance(dt, torch.dtype):
43 | error("", ValueError(f"Unrecognized data type: {dtype_str}"))
44 | return dt
45 |
46 |
47 | def get_torch_dtype(dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype:
48 | """Get the Torch data type to be used for interacting with a model.
49 |
50 | Args:
51 | dtype: Optional[Union[str, torch.dtype]]
52 | If dtype is a torch.dtype, returns it; if it's a string, grab it from the Torch lib.
53 | If None is provided, fall back to the default type in config, which can be
54 | overridden via environment variable.
55 |
56 | Returns:
57 | torch.dtype
58 | Torch data type to be used.
59 | """
60 | error.type_check("", torch.dtype, str, dtype=dtype, allow_none=True)
61 | # If a Torch dtype is passed, nothing to do
62 | if isinstance(dtype, torch.dtype):
63 | return dtype
64 | # If None/empty str was provided, fall back to config / env var override
65 | if not dtype:
66 | return str_to_torch_dtype(get_config().torch_dtype)
67 | # Otherwise convert it from a string
68 | return str_to_torch_dtype(dtype)
69 |
--------------------------------------------------------------------------------
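
A short usage sketch for the two helpers above; the behavior shown is taken directly from the docstrings and the unit tests earlier in this dump:

    import torch

    from caikit_nlp.toolkit.data_type_utils import get_torch_dtype, str_to_torch_dtype

    assert str_to_torch_dtype("float16") is torch.float16
    assert get_torch_dtype(torch.bfloat16) is torch.bfloat16  # dtypes pass straight through
    assert get_torch_dtype("float32") is torch.float32        # strings resolve via getattr(torch, ...)
    # get_torch_dtype(None) falls back to get_config().torch_dtype (float32 in config.yml)
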
/tests/fixtures/tiny_models/T5ForConditionalGeneration/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 |   "additional_special_tokens": [
3 |     "<extra_id_0>",
4 |     "<extra_id_1>",
5 |     "<extra_id_2>",
6 |     "<extra_id_3>",
7 |     "<extra_id_4>",
8 |     "<extra_id_5>",
9 |     "<extra_id_6>",
10 |     "<extra_id_7>",
11 |     "<extra_id_8>",
12 |     "<extra_id_9>",
13 |     "<extra_id_10>",
14 |     "<extra_id_11>",
15 |     "<extra_id_12>",
16 |     "<extra_id_13>",
17 |     "<extra_id_14>",
18 |     "<extra_id_15>",
19 |     "<extra_id_16>",
20 |     "<extra_id_17>",
21 |     "<extra_id_18>",
22 |     "<extra_id_19>",
23 |     "<extra_id_20>",
24 |     "<extra_id_21>",
25 |     "<extra_id_22>",
26 |     "<extra_id_23>",
27 |     "<extra_id_24>",
28 |     "<extra_id_25>",
29 |     "<extra_id_26>",
30 |     "<extra_id_27>",
31 |     "<extra_id_28>",
32 |     "<extra_id_29>",
33 |     "<extra_id_30>",
34 |     "<extra_id_31>",
35 |     "<extra_id_32>",
36 |     "<extra_id_33>",
37 |     "<extra_id_34>",
38 |     "<extra_id_35>",
39 |     "<extra_id_36>",
40 |     "<extra_id_37>",
41 |     "<extra_id_38>",
42 |     "<extra_id_39>",
43 |     "<extra_id_40>",
44 |     "<extra_id_41>",
45 |     "<extra_id_42>",
46 |     "<extra_id_43>",
47 |     "<extra_id_44>",
48 |     "<extra_id_45>",
49 |     "<extra_id_46>",
50 |     "<extra_id_47>",
51 |     "<extra_id_48>",
52 |     "<extra_id_49>",
53 |     "<extra_id_50>",
54 |     "<extra_id_51>",
55 |     "<extra_id_52>",
56 |     "<extra_id_53>",
57 |     "<extra_id_54>",
58 |     "<extra_id_55>",
59 |     "<extra_id_56>",
60 |     "<extra_id_57>",
61 |     "<extra_id_58>",
62 |     "<extra_id_59>",
63 |     "<extra_id_60>",
64 |     "<extra_id_61>",
65 |     "<extra_id_62>",
66 |     "<extra_id_63>",
67 |     "<extra_id_64>",
68 |     "<extra_id_65>",
69 |     "<extra_id_66>",
70 |     "<extra_id_67>",
71 |     "<extra_id_68>",
72 |     "<extra_id_69>",
73 |     "<extra_id_70>",
74 |     "<extra_id_71>",
75 |     "<extra_id_72>",
76 |     "<extra_id_73>",
77 |     "<extra_id_74>",
78 |     "<extra_id_75>",
79 |     "<extra_id_76>",
80 |     "<extra_id_77>",
81 |     "<extra_id_78>",
82 |     "<extra_id_79>",
83 |     "<extra_id_80>",
84 |     "<extra_id_81>",
85 |     "<extra_id_82>",
86 |     "<extra_id_83>",
87 |     "<extra_id_84>",
88 |     "<extra_id_85>",
89 |     "<extra_id_86>",
90 |     "<extra_id_87>",
91 |     "<extra_id_88>",
92 |     "<extra_id_89>",
93 |     "<extra_id_90>",
94 |     "<extra_id_91>",
95 |     "<extra_id_92>",
96 |     "<extra_id_93>",
97 |     "<extra_id_94>",
98 |     "<extra_id_95>",
99 |     "<extra_id_96>",
100 |     "<extra_id_97>",
101 |     "<extra_id_98>",
102 |     "<extra_id_99>"
103 |   ],
104 |   "eos_token": "</s>",
105 |   "pad_token": "<pad>",
106 |   "unk_token": "<unk>"
107 | }
108 |
--------------------------------------------------------------------------------
/tests/fixtures/tiny_models/T5ForConditionalGeneration/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "additional_special_tokens": [
3 |     "<extra_id_0>",
4 |     "<extra_id_1>",
5 |     "<extra_id_2>",
6 |     "<extra_id_3>",
7 |     "<extra_id_4>",
8 |     "<extra_id_5>",
9 |     "<extra_id_6>",
10 |     "<extra_id_7>",
11 |     "<extra_id_8>",
12 |     "<extra_id_9>",
13 |     "<extra_id_10>",
14 |     "<extra_id_11>",
15 |     "<extra_id_12>",
16 |     "<extra_id_13>",
17 |     "<extra_id_14>",
18 |     "<extra_id_15>",
19 |     "<extra_id_16>",
20 |     "<extra_id_17>",
21 |     "<extra_id_18>",
22 |     "<extra_id_19>",
23 |     "<extra_id_20>",
24 |     "<extra_id_21>",
25 |     "<extra_id_22>",
26 |     "<extra_id_23>",
27 |     "<extra_id_24>",
28 |     "<extra_id_25>",
29 |     "<extra_id_26>",
30 |     "<extra_id_27>",
31 |     "<extra_id_28>",
32 |     "<extra_id_29>",
33 |     "<extra_id_30>",
34 |     "<extra_id_31>",
35 |     "<extra_id_32>",
36 |     "<extra_id_33>",
37 |     "<extra_id_34>",
38 |     "<extra_id_35>",
39 |     "<extra_id_36>",
40 |     "<extra_id_37>",
41 |     "<extra_id_38>",
42 |     "<extra_id_39>",
43 |     "<extra_id_40>",
44 |     "<extra_id_41>",
45 |     "<extra_id_42>",
46 |     "<extra_id_43>",
47 |     "<extra_id_44>",
48 |     "<extra_id_45>",
49 |     "<extra_id_46>",
50 |     "<extra_id_47>",
51 |     "<extra_id_48>",
52 |     "<extra_id_49>",
53 |     "<extra_id_50>",
54 |     "<extra_id_51>",
55 |     "<extra_id_52>",
56 |     "<extra_id_53>",
57 |     "<extra_id_54>",
58 |     "<extra_id_55>",
59 |     "<extra_id_56>",
60 |     "<extra_id_57>",
61 |     "<extra_id_58>",
62 |     "<extra_id_59>",
63 |     "<extra_id_60>",
64 |     "<extra_id_61>",
65 |     "<extra_id_62>",
66 |     "<extra_id_63>",
67 |     "<extra_id_64>",
68 |     "<extra_id_65>",
69 |     "<extra_id_66>",
70 |     "<extra_id_67>",
71 |     "<extra_id_68>",
72 |     "<extra_id_69>",
73 |     "<extra_id_70>",
74 |     "<extra_id_71>",
75 |     "<extra_id_72>",
76 |     "<extra_id_73>",
77 |     "<extra_id_74>",
78 |     "<extra_id_75>",
79 |     "<extra_id_76>",
80 |     "<extra_id_77>",
81 |     "<extra_id_78>",
82 |     "<extra_id_79>",
83 |     "<extra_id_80>",
84 |     "<extra_id_81>",
85 |     "<extra_id_82>",
86 |     "<extra_id_83>",
87 |     "<extra_id_84>",
88 |     "<extra_id_85>",
89 |     "<extra_id_86>",
90 |     "<extra_id_87>",
91 |     "<extra_id_88>",
92 |     "<extra_id_89>",
93 |     "<extra_id_90>",
94 |     "<extra_id_91>",
95 |     "<extra_id_92>",
96 |     "<extra_id_93>",
97 |     "<extra_id_94>",
98 |     "<extra_id_95>",
99 |     "<extra_id_96>",
100 |     "<extra_id_97>",
101 |     "<extra_id_98>",
102 |     "<extra_id_99>"
103 |   ],
104 |   "eos_token": "</s>",
105 |   "extra_ids": 100,
106 |   "model_max_length": 512,
107 |   "pad_token": "<pad>",
108 |   "special_tokens_map_file": null,
109 |   "tokenizer_class": "T5Tokenizer",
110 |   "unk_token": "<unk>"
111 | }
112 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/trainer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Contains toolkit functionality for huggingface Trainer"""
15 | # Standard
16 | from datetime import datetime
17 |
18 | # Third Party
19 | import torch
20 |
21 | # First Party
22 | from caikit import get_config
23 | from caikit.core.data_model import DataStream
24 | from caikit.core.exceptions import error_handler
25 | import alog
26 |
27 | log = alog.use_channel("TRNR_UTILS")
28 | error = error_handler.get(log)
29 |
30 |
31 | def validate_training_data(train_stream: DataStream, model_name: str, module_id: str):
32 |
33 | global_default = get_config().training_data_limit.__default__
34 | module_default = (
35 | get_config()
36 | .training_data_limit.get(module_id, {})
37 | .get("__default__", global_default)
38 | )
39 |
40 | max_num_examples = (
41 | get_config()
42 | .training_data_limit.get(module_id, {})
43 | .get(model_name, module_default)
44 | )
45 |
46 | if max_num_examples > 0:
47 | train_stream_size = len(train_stream)
48 | error.value_check(
49 | "",
50 | train_stream_size <= max_num_examples,
51 | "Number of examples ({}) exceeds the maximum number of examples allowed "
52 | "({}) for this model",
53 | train_stream_size,
54 | max_num_examples,
55 | )
56 |
57 |
58 | def log_step(state, logs):
59 | if state.epoch is not None:
60 | logs["epoch"] = round(state.epoch, 2)
61 |
62 | # Get Rank
63 | if torch.distributed.is_initialized():
64 | rank = torch.distributed.get_rank()
65 | else:
66 | rank = 0
67 |
68 | if "loss" in logs:
69 | if state.epoch is not None:
70 | logs["epoch"] = round(state.epoch, 2)
71 |
72 | log.debug(
73 | "process rank: {} loss: {} step: {}".format(
74 | rank, float(logs["loss"]), state.global_step
75 | )
76 | )
77 | output = {
78 | "epoch": float(logs["epoch"]),
79 | "step": state.global_step,
80 | "value": float(logs["loss"]),
81 | "timestamp": datetime.isoformat(datetime.now()),
82 | }
83 | state.log_history.append(output)
84 | else:
85 | output = {**logs, **{"step": state.global_step}}
86 | state.log_history.append(output)
87 |
88 | return state
89 |
--------------------------------------------------------------------------------
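
A sketch of how validate_training_data reads the training_data_limit section of caikit_nlp/config/config.yml shown above. DataStream.from_iterable and the merged caikit_nlp config are assumed, and the module id / model name are the placeholder values from that config file:

    # Sketch only; assumes caikit_nlp's config.yml has been merged into the caikit config.
    from caikit.core.data_model import DataStream

    import caikit_nlp  # noqa: F401  (importing is assumed to merge the library config)
    from caikit_nlp.toolkit.trainer_utils import validate_training_data

    train_stream = DataStream.from_iterable(["a", "b", "c"])  # three toy examples

    # The placeholder entry in config.yml allows up to 10000 examples for this
    # module id / model name pair, so a three-example stream passes the check.
    validate_training_data(
        train_stream,
        model_name="add_model_name_here",
        module_id="6655831b-960a-4dc5-8df4-867026e2cd41",
    )
    # If the configured limit were smaller than len(train_stream), error.value_check
    # above would raise instead of returning.
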
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | data/classification_data/
11 | models/
12 | tmp_models/
13 | models_to_upload/
14 | upload_models/
15 | downloaded_models/
16 | .config
17 | .jupyter
18 | .Python
19 | .bash_history
20 | .python_history
21 | .local/
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 | *_pb2.py
39 | *_pb2_grpc.py
40 | .local/
41 | .bash_history
42 | .python_history
43 |
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 |
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 |
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 | .pytest_cache/
65 |
66 | # Translations
67 | *.mo
68 | *.pot
69 |
70 | # Django stuff:
71 | *.log
72 | local_settings.py
73 | db.sqlite3
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # pyenv
92 | .python-version
93 | .ipython
94 | # vscode
95 | .vscode/
96 | # celery beat schedule file
97 | celerybeat-schedule
98 |
99 | # SageMath parsed files
100 | *.sage.py
101 |
102 | # Environments
103 | .env
104 | .venv
105 | env/
106 | venv/
107 | ENV/
108 | env.bak/
109 | venv.bak/
110 |
111 | # shell environment
112 | .inputrc
113 | .profile
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 |
128 | # IntelliJ
129 | .idea/
130 |
131 | # Mac OS
132 | .DS_Store
133 | __MACOSX/
134 | generic_perf_benchmark
135 | .local
136 | .keras
137 | .bash_history
138 | perf_test_artifacts
139 | tests/fixtures/downloaded
140 | .mycontrib
141 | mlruns
142 | # Pylint
143 | .pylint.d
144 |
145 | # Make artifacts
146 | .version
147 | .location
148 | build
149 | generate_protos
150 | GITHUB_TOKEN
151 | reports
152 | dry_run.txt
153 |
154 | ### Artifacts from training scripts
155 | # Path to MPT exported models
156 | saved_outputs
157 | # Path used by sampled commands for checkpoints etc
158 | teacher_prompts_t5-base
159 |
160 | # Runtime artifacts
161 | eligible_modules.json
162 | ineligible_modules.json
163 | modules.json
164 |
165 | # Other artifacts from demo scripts, e.g., TGIS integration scripts
166 | prompt_prefixes
167 | sample_prompt
168 | transformers_cache
169 | generated_interfaces
170 | /caikit_nlp/_version.py
171 |
--------------------------------------------------------------------------------
/caikit_nlp/data_model/generation.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Data structures for text generation representations
15 | """
16 | # Standard
17 | from enum import Enum
18 | from typing import List
19 |
20 | # First Party
21 | from caikit.core import DataObjectBase
22 |
23 | # First party
24 | import alog
25 | import caikit
26 |
27 | log = alog.use_channel("DATAM")
28 |
29 |
30 | class PromptOutputModelType(Enum):
31 | ENCODER = "ENCODER"
32 | DECODER = "DECODER"
33 |
34 |
35 | @caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
36 | class GenerationTrainRecord(DataObjectBase):
37 | input: str
38 | output: str
39 |
40 |
41 | @caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
42 | class TuningConfig(DataObjectBase):
43 |     # If prompt_tuning_init_text is not provided, then random initialization would
44 |     # be used, but since random init is not currently supported, we keep the init
45 |     # text required
46 | num_virtual_tokens: int
47 | # TODO: Move all _init_ params to separate object
48 | prompt_tuning_init_text: str
49 | prompt_tuning_init_method: str
50 | # could be: `RANDOM`, `TEXT`, `ONLY_SOURCE_SHARED` and `AVERAGE_SOURCE`
51 | #
52 | prompt_tuning_init_source_model: str # this maps to prompt_tuning_init_state_dict_path in MPT
53 |     # which is the path pointing to the state dict of the model to be used for initialization
54 | #
55 | # token_dim: int # Optional - dimension of the virtual tokens.
56 | #
57 | output_model_types: List[str]
58 | # this replaces `num_transformer_submodules`
59 |     # option and can take the values encoder, decoder. Each
60 |     # selected resource type will only provide certain possibilities;
61 |     # for example, causal-lm models will only provide decoder as an option.
62 | # If None provided, then we will use defaults for that model_type
63 | # num_transformer_submodules: int # Optional - The number of transformer submodules in the
64 | # base transformer model.
65 | # 1 for all decoder-only models and 2 for encoder-decoder models.
66 | # If 1 is used for encoder-decoder models, the prompt will be used for the encoder only.
67 | #
68 | # num_attention_heads: int # Optional - The number of attention heads in the
69 | # base transformer model
70 | #
71 | # num_layers: int # Optional - The number of layers in the base transformer model
72 | #
73 | # encoder_hidden_size: int # Optional - The hidden size of the prompt encoder.
74 |
75 |
76 | @caikit.core.dataobject(package="caikit_data_model.caikit_nlp")
77 | class ExponentialDecayLengthPenalty(DataObjectBase):
78 | start_index: int
79 | decay_factor: float
80 |
--------------------------------------------------------------------------------
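
A minimal construction sketch for the data objects above; the field values are arbitrary, and TuningConfig is assumed to be re-exported from caikit_nlp.data_model alongside the other classes:

    # Sketch only; values are arbitrary.
    from caikit_nlp.data_model import (
        ExponentialDecayLengthPenalty,
        GenerationTrainRecord,
        TuningConfig,
    )

    record = GenerationTrainRecord(input="@bank my card was declined again", output="complaint")
    penalty = ExponentialDecayLengthPenalty(start_index=10, decay_factor=1.05)
    tuning_config = TuningConfig(
        num_virtual_tokens=8,
        prompt_tuning_init_text="Classify the sentiment of the tweet",
        prompt_tuning_init_method="TEXT",
        output_model_types=["DECODER"],
    )
    # Like the other caikit data objects, these round-trip through JSON / proto:
    assert GenerationTrainRecord.from_json(record.to_json()).output == "complaint"
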
/tests/modules/text_classification/test_sequence_classification.py:
--------------------------------------------------------------------------------
1 | """Tests for sequence classification module
2 | """
3 | # Standard
4 | import tempfile
5 |
6 | # Third Party
7 | from pytest import approx
8 | import pytest
9 |
10 | # First Party
11 | from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults
12 |
13 | # Local
14 | from caikit_nlp.modules.text_classification import SequenceClassification
15 | from tests.fixtures import SEQ_CLASS_MODEL
16 |
17 | ## Setup ########################################################################
18 |
19 | # Bootstrapped sequence classification model for reusability across tests
20 | # .bootstrap is tested separately in the first test
21 | BOOTSTRAPPED_SEQ_CLASS_MODEL = SequenceClassification.bootstrap(SEQ_CLASS_MODEL)
22 |
23 | TEXTS = [
24 | "The quick brown fox jumps over the lazy dog.",
25 | "Once upon a time in a land far away",
26 | ]
27 |
28 | ## Tests ########################################################################
29 | # Exact numbers here from the tiny model are not particularly important,
30 | # but we check them here to make sure that the arrays are re-ordered correctly
31 |
32 |
33 | def test_bootstrap_and_run():
34 | """Check if we can bootstrap and run sequence classification models"""
35 | model = SequenceClassification.bootstrap(SEQ_CLASS_MODEL)
36 | classification_result = model.run(TEXTS[0])
37 | assert isinstance(classification_result, ClassificationResults)
38 | assert len(classification_result.results) == 2 # 2 labels
39 |
40 | assert isinstance(classification_result.results[0], ClassificationResult)
41 | assert classification_result.results[0].label == "LABEL_0"
42 | assert approx(classification_result.results[0].score) == 0.49526197
43 | assert classification_result.results[1].label == "LABEL_1"
44 | assert approx(classification_result.results[1].score) == 0.50473803
45 |
46 |
47 | def test_bootstrap_and_run_batch():
48 | """Check if we can bootstrap and run_batch sequence classification models"""
49 | classification_result_list = BOOTSTRAPPED_SEQ_CLASS_MODEL.run_batch(TEXTS)
50 | assert len(classification_result_list) == 2
51 |
52 | first_result = classification_result_list[0]
53 | assert isinstance(first_result, ClassificationResults)
54 | assert first_result.results[0].label == "LABEL_0"
55 | assert approx(first_result.results[0].score) == 0.49526197
56 | assert first_result.results[1].label == "LABEL_1"
57 | assert classification_result_list[1].results[0].label == "LABEL_0"
58 |
59 |
60 | def test_load_save_and_run_model():
61 | """Check if we can load and run a saved model successfully"""
62 | with tempfile.TemporaryDirectory() as model_dir:
63 | BOOTSTRAPPED_SEQ_CLASS_MODEL.save(model_dir)
64 | new_model = SequenceClassification.load(model_dir)
65 | classification_result = new_model.run(TEXTS[0])
66 | assert isinstance(classification_result, ClassificationResults)
67 | assert len(classification_result.results) == 2 # 2 labels
68 |
69 | assert isinstance(classification_result.results[0], ClassificationResult)
70 | assert classification_result.results[0].label == "LABEL_0"
71 | assert approx(classification_result.results[0].score) == 0.49526197
72 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/data_stream_wrapper.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """This module contains utility classes for wrapping data streams as Torch datasets efficiently.
15 | Caikit modules leveraging such wrappers can use them to internally leverage common approaches
16 | and objects for training / evaluating PyTorch models built around DataStreams, e.g., PyTorch
17 | DataLoaders, with minimal boilerplate.
18 | """
19 |
20 | # Third Party
21 | from torch.utils.data import IterableDataset
22 |
23 | # First Party
24 | from caikit.core.exceptions import error_handler
25 | import alog
26 |
27 | log = alog.use_channel("PEFT_PROMPT")
28 | error = error_handler.get(log)
29 |
30 |
31 | class SimpleIterableStreamWrapper(IterableDataset):
32 | """DataStream wrapper as an iterable PyTorch dataset; we use this to add
33 |     compatibility with PyTorch data loaders.
34 | """
35 |
36 | def __init__(self, stream, shuffle, buffer_size=None):
37 | error.type_check("", bool, shuffle=shuffle)
38 | error.type_check(
39 | "", int, buffer_size=buffer_size, allow_none=True
40 | )
41 | self.stream = stream
42 | self.shuffle = shuffle
43 | self.buffer_size = buffer_size
44 | # Load the whole data set in memory
45 | if self.shuffle and buffer_size is None:
46 | self.buffer_size = len(stream)
47 | log.debug("Shuffling enabled? {}".format(self.shuffle))
48 | log.debug("Shuffling buffer size: {}".format(self.buffer_size))
49 |
50 | def __iter__(self):
51 |
52 |         # FIXME: We are currently not handling the case where we have to work with
53 |         # multiple workers, so duplicate data will currently get processed by
54 |         # each worker.
55 | if self.shuffle:
56 | log.debug4("Reshuffling training data!")
57 | return iter(self.stream.shuffle(self.buffer_size))
58 | return iter(self.stream)
59 | # worker_info = get_worker_info()
60 | # if worker_info is None: # single-process data loading, return the full iterator
61 | # if self.shuffle:
62 | # log.debug4("Reshuffling training data!")
63 | # return iter(self.stream.shuffle(self.buffer_size))
64 | # return iter(self.stream)
65 |
66 | # When num_workers > 0, each worker process will have a different copy of
67 | # the dataset object, so we configure each copy independently to avoid
68 | # having duplicate data returned from each worker
69 | # else: # in a worker process
70 | # # split workload
71 | # per_worker = int(
72 | # math.ceil((self.end - self.start) / float(worker_info.num_workers))
73 | # )
74 | # worker_id = worker_info.id
75 | # iter_start = self.start + worker_id * per_worker
76 | # iter_end = min(iter_start + per_worker, self.end)
77 | # return iter(range(iter_start, iter_end))
78 |
79 | def __len__(self):
80 | return len(self.stream)
81 |
--------------------------------------------------------------------------------
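
A short sketch of feeding a DataStream to a PyTorch DataLoader via the wrapper above; DataStream.from_iterable is assumed, and shuffling is disabled to keep the example minimal:

    # Sketch only.
    from caikit.core.data_model import DataStream
    from torch.utils.data import DataLoader

    from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper

    stream = DataStream.from_iterable(["one", "two", "three", "four"])
    wrapped = SimpleIterableStreamWrapper(stream=stream, shuffle=False)
    loader = DataLoader(wrapped, batch_size=2)
    for batch in loader:
        print(batch)  # e.g. ['one', 'two'] then ['three', 'four']
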
/benchmarks/logs/llama2-7b/20230905_194133.output:
--------------------------------------------------------------------------------
1 | (tuning) [gpu_user@gpu6120 caikit-nlp]$ ./ft_job.sh
2 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions
3 | _warnings.warn(
4 | is still in the BETA phase and subject to change!
5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions
6 | _warnings.warn(
7 | Existing model directory found; purging it now.
8 | Experiment Configuration
9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b]
10 | |- Inferred Model Resource Type: []
11 | - Dataset: [glue/rte]
12 | - Number of Epochs: [1]
13 | - Learning Rate: [2e-05]
14 | - Batch Size: [8]
15 | - Output Directory: [/tmp/tu/output/tuning/llama27b]
16 | - Maximum source sequence length: [256]
17 | - Maximum target sequence length: [1024]
18 | - Gradient accumulation steps: [16]
19 | - Enable evaluation: [False]
20 | - Evaluation metrics: [['rouge']]
21 | - Torch dtype to use for training: [bfloat16]
22 | [Loading the dataset...]
23 | 2023-09-05T19:40:43.686785 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
24 | 2023-09-05T19:40:43.702480 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
25 | [Loading the base model resource...]
26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.73s/it]
27 | [Starting the training...]
28 | 2023-09-05T19:41:33.062266 [PEFT_:DBUG] Shuffling enabled? True
29 | 2023-09-05T19:41:33.062427 [PEFT_:DBUG] Shuffling buffer size: 7470
30 | TRAINING ARGS: {
31 | "output_dir": "/tmp",
32 | "per_device_train_batch_size": 8,
33 | "per_device_eval_batch_size": 8,
34 | "num_train_epochs": 1,
35 | "seed": 73,
36 | "do_eval": false,
37 | "learning_rate": 2e-05,
38 | "weight_decay": 0.01,
39 | "save_total_limit": 3,
40 | "push_to_hub": false,
41 | "no_cuda": false,
42 | "remove_unused_columns": false,
43 | "dataloader_pin_memory": false,
44 | "gradient_accumulation_steps": 16,
45 | "eval_accumulation_steps": 16,
46 | "bf16": true
47 | }
48 | 0%| | 0/58 [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
49 | {'train_runtime': 356.7428, 'train_samples_per_second': 20.939, 'train_steps_per_second': 0.163, 'train_loss': 1.7029038790998787, 'epoch': 0.99}
50 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 58/58 [05:56<00:00, 6.15s/it]
51 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.86s/it]
52 | Using sep_token, but it is not set yet.
53 | [Training Complete]
--------------------------------------------------------------------------------
/benchmarks/logs/llama2-7b/20230906_135211.output:
--------------------------------------------------------------------------------
1 | (tuning) [gpu_user@gpu5530 caikit-nlp]$ ./ft_job.sh
2 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions
3 | _warnings.warn(
4 | is still in the BETA phase and subject to change!
5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions
6 | _warnings.warn(
7 | Existing model directory found; purging it now.
8 | Experiment Configuration
9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b]
10 | |- Inferred Model Resource Type: []
11 | - Dataset: [glue/rte]
12 | - Number of Epochs: [1]
13 | - Learning Rate: [2e-05]
14 | - Batch Size: [6]
15 | - Output Directory: [/tmp/tu/output/tuning/llama27b]
16 | - Maximum source sequence length: [512]
17 | - Maximum target sequence length: [1024]
18 | - Gradient accumulation steps: [16]
19 | - Enable evaluation: [False]
20 | - Evaluation metrics: [['rouge']]
21 | - Torch dtype to use for training: [bfloat16]
22 | [Loading the dataset...]
23 | 2023-09-06T13:51:21.128309 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
24 | 2023-09-06T13:51:21.146717 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
25 | [Loading the base model resource...]
26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.79s/it]
27 | [Starting the training...]
28 | 2023-09-06T13:52:11.307381 [PEFT_:DBUG] Shuffling enabled? True
29 | 2023-09-06T13:52:11.307508 [PEFT_:DBUG] Shuffling buffer size: 7470
30 | TRAINING ARGS: {
31 | "output_dir": "/tmp",
32 | "per_device_train_batch_size": 6,
33 | "per_device_eval_batch_size": 6,
34 | "num_train_epochs": 1,
35 | "seed": 73,
36 | "do_eval": false,
37 | "learning_rate": 2e-05,
38 | "weight_decay": 0.01,
39 | "save_total_limit": 3,
40 | "push_to_hub": false,
41 | "no_cuda": false,
42 | "remove_unused_columns": false,
43 | "dataloader_pin_memory": false,
44 | "gradient_accumulation_steps": 16,
45 | "eval_accumulation_steps": 16,
46 | "bf16": true
47 | }
48 | 0%| | 0/77 [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
49 | {'train_runtime': 348.4812, 'train_samples_per_second': 21.436, 'train_steps_per_second': 0.221, 'train_loss': 1.6495626870687905, 'epoch': 0.99}
50 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [05:48<00:00, 4.53s/it]
51 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.61s/it]
52 | Using sep_token, but it is not set yet.
53 | [Training Complete]
--------------------------------------------------------------------------------
/benchmarks/logs/llama2-7b/20230905_183655.output:
--------------------------------------------------------------------------------
1 | (tuning) [gpu_user@gpu5480 caikit-nlp]$ ./ft_job.sh
2 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions
3 | _warnings.warn(
4 | is still in the BETA phase and subject to change!
5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions
6 | _warnings.warn(
7 | Existing model directory found; purging it now.
8 | Experiment Configuration
9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b]
10 | |- Inferred Model Resource Type: []
11 | - Dataset: [glue/rte]
12 | - Number of Epochs: [1]
13 | - Learning Rate: [2e-05]
14 | - Batch Size: [6]
15 | - Output Directory: [/tmp/tu/output/tuning/llama27b]
16 | - Maximum source sequence length: [4096]
17 | - Maximum target sequence length: [1024]
18 | - Gradient accumulation steps: [16]
19 | - Enable evaluation: [False]
20 | - Evaluation metrics: [['rouge']]
21 | - Torch dtype to use for training: [bfloat16]
22 | [Loading the dataset...]
23 | 2023-09-05T18:36:55.174106 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
24 | 2023-09-05T18:36:55.192203 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
25 | [Loading the base model resource...]
26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.89s/it]
27 | [Starting the training...]
28 | 2023-09-05T18:37:47.419502 [PEFT_:DBUG] Shuffling enabled? True
29 | 2023-09-05T18:37:47.419666 [PEFT_:DBUG] Shuffling buffer size: 7470
30 | TRAINING ARGS: {
31 | "output_dir": "/tmp",
32 | "per_device_train_batch_size": 6,
33 | "per_device_eval_batch_size": 6,
34 | "num_train_epochs": 1,
35 | "seed": 73,
36 | "do_eval": false,
37 | "learning_rate": 2e-05,
38 | "weight_decay": 0.01,
39 | "save_total_limit": 3,
40 | "push_to_hub": false,
41 | "no_cuda": false,
42 | "remove_unused_columns": false,
43 | "dataloader_pin_memory": false,
44 | "gradient_accumulation_steps": 16,
45 | "eval_accumulation_steps": 16,
46 | "bf16": true
47 | }
48 | 0%| | 0/77 [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
49 | {'train_runtime': 350.2997, 'train_samples_per_second': 21.325, 'train_steps_per_second': 0.22, 'train_loss': 1.6495626870687905, 'epoch': 0.99}
50 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [05:50<00:00, 4.55s/it]
51 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.79s/it]
52 | Using sep_token, but it is not set yet.
53 | [Training Complete]
54 |
55 |
--------------------------------------------------------------------------------
/benchmarks/logs/llama2-7b/20230905_184809.output:
--------------------------------------------------------------------------------
1 | (tuning) [gpu_user@gpu5480 caikit-nlp]$ ./ft_job.sh
2 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions
3 | _warnings.warn(
4 | is still in the BETA phase and subject to change!
5 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions
6 | _warnings.warn(
7 | Existing model directory found; purging it now.
8 | Experiment Configuration
9 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b]
10 | |- Inferred Model Resource Type: []
11 | - Dataset: [glue/rte]
12 | - Number of Epochs: [1]
13 | - Learning Rate: [2e-05]
14 | - Batch Size: [6]
15 | - Output Directory: [/tmp/tu/output/tuning/llama27b]
16 | - Maximum source sequence length: [1024]
17 | - Maximum target sequence length: [1024]
18 | - Gradient accumulation steps: [16]
19 | - Enable evaluation: [False]
20 | - Evaluation metrics: [['rouge']]
21 | - Torch dtype to use for training: [bfloat16]
22 | [Loading the dataset...]
23 | 2023-09-05T18:47:18.075310 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
24 | 2023-09-05T18:47:18.093371 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/dataset_info.json
25 | [Loading the base model resource...]
26 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.76s/it]
27 | [Starting the training...]
28 | 2023-09-05T18:48:09.755222 [PEFT_:DBUG] Shuffling enabled? True
29 | 2023-09-05T18:48:09.755357 [PEFT_:DBUG] Shuffling buffer size: 7470
30 | TRAINING ARGS: {
31 | "output_dir": "/tmp",
32 | "per_device_train_batch_size": 6,
33 | "per_device_eval_batch_size": 6,
34 | "num_train_epochs": 1,
35 | "seed": 73,
36 | "do_eval": false,
37 | "learning_rate": 2e-05,
38 | "weight_decay": 0.01,
39 | "save_total_limit": 3,
40 | "push_to_hub": false,
41 | "no_cuda": false,
42 | "remove_unused_columns": false,
43 | "dataloader_pin_memory": false,
44 | "gradient_accumulation_steps": 16,
45 | "eval_accumulation_steps": 16,
46 | "bf16": true
47 | }
48 | 0%| | 0/77 [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
49 | {'train_runtime': 350.165, 'train_samples_per_second': 21.333, 'train_steps_per_second': 0.22, 'train_loss': 1.6495626870687905, 'epoch': 0.99}
50 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 77/77 [05:50<00:00, 4.55s/it]
51 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.83s/it]
52 | Using sep_token, but it is not set yet.
53 | [Training Complete]
54 |
55 |
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/torch_run.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """This toolkit utility contains functions to operate on torch distribution
16 |
17 | NOTE: Content of this file are heavily influenced by torch/distributed/run.py
18 |
19 | Ref: https://github.com/pytorch/pytorch/blob/main/torch/distributed/run.py
20 | """
21 |
22 | # Standard
23 | import os
24 |
25 | # Third Party
26 | from torch import cuda
27 | from torch.distributed.launcher.api import LaunchConfig
28 | import torch.distributed as dist
29 |
30 | # First Party
31 | import alog
32 |
33 | log = alog.use_channel("TRCH_RN")
34 |
35 |
36 | def initialize_torch_distribution(world_size, rank=0, backend="gloo|nccl"):
37 |
38 | if dist.is_available():
39 | log.debug(
40 | "Initializing process group - backend %s, rank %d, world size %d",
41 | backend,
42 | rank,
43 | world_size,
44 | )
45 | dist.init_process_group(backend=backend, world_size=world_size, rank=rank)
46 |
47 |
48 | def determine_local_world_size():
49 | """Function to automatically deduce the world size based on
50 | available processors.
51 |
52 | NOTE: This function will try to use ALL gpus accessible to it
53 | """
54 |
55 | if cuda.is_available():
56 | num_proc = cuda.device_count()
57 | log.info("Cuda devices available! Using %d devices.", num_proc)
58 | return num_proc
59 | # Fall back to using the OS cpu count
60 | # TODO: Calibrate this to some reasonable default...
61 | num_proc = os.cpu_count()
62 | log.info("Cuda devices NOT available! Using CPU %d processes.", num_proc)
63 | return num_proc
64 |
65 |
66 | def get_torch_elastic_launch_config(
67 | master_addr: str,
68 | master_port: str,
69 | start_method: str = "spawn",
70 | max_restarts=3,
71 | ) -> LaunchConfig:
72 |
73 | # Constants; we assume everything executes on the same node
74 | min_nodes = 1
75 | max_nodes = 1
76 | rdzv_configs = {"rank": 0}
77 |
78 | nproc_per_node = determine_local_world_size()
79 |
80 | if "OMP_NUM_THREADS" not in os.environ and nproc_per_node > 1:
81 | omp_num_threads = 1
82 | log.warning(
83 | "\n*****************************************\n"
84 | "Setting OMP_NUM_THREADS environment variable for each process to be "
85 | "%s in default, to avoid your system being overloaded, "
86 | "please further tune the variable for optimal performance in "
87 | "your application as needed. \n"
88 | "*****************************************",
89 | omp_num_threads,
90 | )
91 | # This env variable will be passed down to the subprocesses
92 | os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
93 |
94 | return LaunchConfig(
95 | min_nodes=min_nodes,
96 | max_nodes=max_nodes,
97 | nproc_per_node=nproc_per_node,
98 | start_method=start_method,
99 | rdzv_backend="static",
100 | rdzv_endpoint=f"{master_addr}:{master_port}",
101 | rdzv_configs=rdzv_configs,
102 | max_restarts=max_restarts,
103 | )
104 |
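105 | # Usage sketch: the LaunchConfig built above is intended to be handed to torch's
106 | # elastic launcher, roughly as follows. The entrypoint `train_fn` and the
107 | # address/port values are illustrative assumptions, not part of this module.
108 | #
109 | #   from torch.distributed.launcher.api import elastic_launch
110 | #
111 | #   launch_config = get_torch_elastic_launch_config("127.0.0.1", "29500")
112 | #   elastic_launch(launch_config, train_fn)(training_args)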
--------------------------------------------------------------------------------
/tests/toolkit/test_verbalizers.py:
--------------------------------------------------------------------------------
1 | """Tests for verbalizer substitution utils.
2 | """
3 | # Third Party
4 | import pytest
5 |
6 | # Local
7 | from caikit_nlp.toolkit.verbalizer_utils import is_valid_verbalizer, render_verbalizer
8 | import caikit_nlp
9 |
10 | SAMPLE_DM = caikit_nlp.data_model.GenerationTrainRecord(
11 | input="my input text", output="my output text"
12 | )
13 | SAMPLE_DICT = SAMPLE_DM.to_dict()
14 |
15 | ### Happy path rendering cases
16 | def test_is_valid_verbalizer():
17 | """Ensure that when we have happy verbalizers, it's easy to check."""
18 | assert is_valid_verbalizer("{{input}}") == True
19 | assert is_valid_verbalizer("Input: {{input}}") == True
20 | assert is_valid_verbalizer("Input: {{input}} Output: {{output}}") == True
21 |
22 |
23 | def test_raw_verbalizer_is_preserved():
24 | """Ensure that if no placeholders are provided, the raw string is returned."""
25 | # NOTE: you would never want to do this because it means your input string is hardcoded!
26 | # However, we leave the responsibility of checking this to the verbalizer validator so that
27 | # we can validate up front, rather than checking lazily at train time etc.
28 | verbalizer_template = "text: foo label: bar"
29 | expected_render_result = verbalizer_template
30 | # Check that we get the same behavior from both the DM / Dict based verbalizer replacement
31 | assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result
32 | assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result
33 |
34 |
35 | def test_verbalizer_substitution():
36 | """Ensure that if placeholders are used, class attrs are rendered into placeholders."""
37 | verbalizer_template = "text: {{input}} label: {{output}}"
38 | expected_render_result = "text: {} label: {}".format(
39 | SAMPLE_DM.input, SAMPLE_DM.output
40 | )
41 | assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result
42 | assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result
43 |
44 |
45 | def test_invalid_value_verbalizer_substitution():
46 | """Ensure that if we try to render a placeholder that is a bad property name, we skip it."""
47 | verbalizer_template = "text: {{foo bar}}"
48 | expected_render_result = verbalizer_template
49 | assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result
50 | assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result
51 |
52 |
53 | def test_empty_verbalizer_substitution():
54 | """Ensure that if we try to render an empty placeholder, we skip it."""
55 | verbalizer_template = "text: {{}} label: {{}}"
56 | expected_render_result = verbalizer_template
57 | assert render_verbalizer(verbalizer_template, SAMPLE_DM) == expected_render_result
58 | assert render_verbalizer(verbalizer_template, SAMPLE_DICT) == expected_render_result
59 |
60 |
61 | ### sad path rendering cases
62 | def test_is_invalid_verbalizer():
63 | """Ensure that when we have happy verbalizers, it's easy to check."""
64 | assert is_valid_verbalizer(100) == False
65 | assert is_valid_verbalizer("") == False
66 | assert is_valid_verbalizer("source") == False
67 | assert is_valid_verbalizer("{{this is not a valid placeholder}}") == False
68 |
69 |
70 | def test_invalid_attribute_verbalizer_substitution():
71 | """Ensure that if we try to render a placeholder that doesn't exist on our DM, we fail."""
72 | verbalizer_template = "text: {{sadness}}"
73 | with pytest.raises(ValueError):
74 | render_verbalizer(verbalizer_template, SAMPLE_DM)
75 |
76 |
77 | def test_invalid_key_verbalizer_substitution():
78 | """Ensure that if we try to render a placeholder that doesn't exist on our dict, we fail."""
79 | verbalizer_template = "text: {{sadness}}"
80 | with pytest.raises(ValueError):
81 | render_verbalizer(verbalizer_template, SAMPLE_DICT)
82 |
--------------------------------------------------------------------------------
/tests/modules/text_generation/test_peft_config.py:
--------------------------------------------------------------------------------
1 | # Standard
2 | from unittest.mock import Mock
3 |
4 | # Third Party
5 | from peft import PromptTuningConfig
6 | import pytest
7 |
8 | # Local
9 | from caikit_nlp.data_model import TuningConfig
10 | from caikit_nlp.modules.text_generation import TextGeneration
11 | from caikit_nlp.modules.text_generation.peft_config import (
12 | TuningType,
13 | get_peft_config,
14 | resolve_base_model,
15 | )
16 | from caikit_nlp.resources.pretrained_model import HFAutoSeq2SeqLM
17 | from tests.fixtures import (
18 | SEQ2SEQ_LM_MODEL,
19 | TINY_MODELS_DIR,
20 | causal_lm_dummy_model,
21 | causal_lm_train_kwargs,
22 | seq2seq_lm_dummy_model,
23 | seq2seq_lm_train_kwargs,
24 | temp_config,
25 | )
26 |
27 |
28 | @pytest.mark.parametrize(
29 | "train_kwargs,dummy_model",
30 | [
31 | (
32 | "seq2seq_lm_train_kwargs",
33 | "seq2seq_lm_dummy_model",
34 | ),
35 | ("causal_lm_train_kwargs", "causal_lm_dummy_model"),
36 | ],
37 | )
38 | def test_get_peft_config(train_kwargs, dummy_model, request):
39 | # Fixtures can't be called directly or passed to mark parametrize;
40 | # Currently, passing the fixture by name and retrieving it through
41 | # the request is the 'right' way to do this.
42 | train_kwargs = request.getfixturevalue(train_kwargs)
43 | dummy_model = request.getfixturevalue(dummy_model)
44 |
45 | # Define some sample values for testing
46 | tuning_type = TuningType.PROMPT_TUNING
47 | tuning_config = TuningConfig(
48 | num_virtual_tokens=8,
49 | prompt_tuning_init_method="TEXT",
50 | prompt_tuning_init_text="Hello world",
51 | )
52 | dummy_resource = train_kwargs["base_model"]
53 |
54 | # Call the function being tested
55 | task_type, output_model_types, peft_config, tuning_type = get_peft_config(
56 | tuning_type,
57 | tuning_config,
58 | dummy_resource,
59 | dummy_model,
60 | "float32",
61 | "{{input}}",
62 | )
63 |
64 | # Add assertions to validate the behavior of the function
65 | assert task_type == dummy_resource.TASK_TYPE
66 | assert output_model_types == dummy_resource.PROMPT_OUTPUT_TYPES
67 | assert tuning_type == TuningType.PROMPT_TUNING
68 |
69 | # Validation for type & important fields in the peft config
70 | assert isinstance(peft_config, PromptTuningConfig)
71 | assert peft_config.num_virtual_tokens == tuning_config.num_virtual_tokens
72 | assert peft_config.task_type == dummy_resource.TASK_TYPE
73 | assert peft_config.prompt_tuning_init == tuning_config.prompt_tuning_init_method
74 | assert peft_config.prompt_tuning_init_text == tuning_config.prompt_tuning_init_text
75 |
76 |
77 | def test_resolve_model_with_invalid_path_raises():
78 | """Test passing invalid path to resolve_model function raises"""
79 |
80 | invalid_base_model = "path/../../important"
81 | with pytest.raises(ValueError):
82 | resolve_base_model(invalid_base_model, None, "foo")
83 |
84 |
85 | def test_resolve_model_with_valid_folder_path():
86 | """Test passing valid folder path to resolve_model function works"""
87 |
88 | model = resolve_base_model(SEQ2SEQ_LM_MODEL, TextGeneration, "float32")
89 |
90 | assert isinstance(model, HFAutoSeq2SeqLM)
91 |
92 |
93 | def test_resolve_model_works_preloaded_model():
94 |
95 | base_model = HFAutoSeq2SeqLM.bootstrap(SEQ2SEQ_LM_MODEL)
96 | resolved_model = resolve_base_model(base_model, TextGeneration, "float32")
97 | assert isinstance(resolved_model, HFAutoSeq2SeqLM)
98 |
99 |
100 | def test_resolve_model_with_different_base_path_works():
101 |
102 | base_model_name = "T5ForConditionalGeneration"
103 | with temp_config(base_models_dir=TINY_MODELS_DIR):
104 | resolved_model = resolve_base_model(base_model_name, TextGeneration, "float32")
105 | assert isinstance(resolved_model, HFAutoSeq2SeqLM)
106 |
--------------------------------------------------------------------------------
/tests/toolkit/text_generation/test_model_run_utils.py:
--------------------------------------------------------------------------------
1 | # Third Party
2 | import pytest
3 |
4 | # First Party
5 | from caikit.core.data_model.producer import ProducerId
6 | from caikit.interfaces.nlp.data_model import GeneratedTextResult
7 |
8 | # Local
9 | from caikit_nlp.toolkit.text_generation.model_run_utils import generate_text_func
10 | from tests.fixtures import (
11 | causal_lm_dummy_model,
12 | causal_lm_train_kwargs,
13 | seq2seq_lm_dummy_model,
14 | seq2seq_lm_train_kwargs,
15 | )
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "model_fixture", ["seq2seq_lm_dummy_model", "causal_lm_dummy_model"]
20 | )
21 | @pytest.mark.parametrize(
22 | "serialization_method,expected_type",
23 | [
24 | ("to_dict", dict),
25 | ("to_json", str),
26 | ("to_proto", GeneratedTextResult._proto_class),
27 | ],
28 | )
29 | def test_generate_text_func_serialization_json(
30 | request,
31 | model_fixture,
32 | serialization_method,
33 | expected_type,
34 | ):
35 | model = request.getfixturevalue(model_fixture)
36 | generated_text = generate_text_func(
37 | model=model.model,
38 | tokenizer=model.tokenizer,
39 | producer_id=ProducerId("TextGeneration", "0.1.0"),
40 | eos_token="<\n>",
41 | text="What is the boiling point of liquid Nitrogen?",
42 | )
43 |
44 | serialized = getattr(generated_text, serialization_method)()
45 | assert isinstance(serialized, expected_type)
46 |
47 |
48 | @pytest.mark.parametrize("causal_model_fixture", ["causal_lm_dummy_model"])
49 | def test_generate_text_func_preserve_input_causal_lm(request, causal_model_fixture):
50 | """For Causal LM task types, setting preserve_inout_text to True
51 | will result in input text in model prediction. Setting to False will
52 | strip the input text from model prediction.
53 | """
54 | input_text = "What is the boiling point of liquid Nitrogen?"
55 | causal_model = request.getfixturevalue(causal_model_fixture)
56 | # assert type(causal_model.model) == False
57 | generated_text = generate_text_func(
58 | model=causal_model.model,
59 | tokenizer=causal_model.tokenizer,
60 | producer_id=ProducerId("TextGeneration", "0.1.0"),
61 | eos_token="<\n>",
62 | text=input_text,
63 | preserve_input_text=True,
64 | task_type="CAUSAL_LM",
65 | )
66 | assert input_text in generated_text.generated_text
67 | generated_text = generate_text_func(
68 | model=causal_model.model,
69 | tokenizer=causal_model.tokenizer,
70 | producer_id=ProducerId("TextGeneration", "0.1.0"),
71 | eos_token="<\n>",
72 | text=input_text,
73 | preserve_input_text=False,
74 | task_type="CAUSAL_LM",
75 | )
76 | assert input_text not in generated_text.generated_text
77 |
78 |
79 | @pytest.mark.parametrize("seq_model_fixture", ["seq2seq_lm_dummy_model"])
80 | def test_generate_text_func_preserve_input(request, seq_model_fixture):
81 | """For Seq2Seq LM task types, setting preserve_inout_text to True
82 | or False should not change predictions.
83 | """
84 | input_text = "What is the boiling point of liquid Nitrogen?"
85 | seq_model = request.getfixturevalue(seq_model_fixture)
86 | # assert type(causal_model.model) == False
87 | generated_text = generate_text_func(
88 | model=seq_model.model,
89 | tokenizer=seq_model.tokenizer,
90 | producer_id=ProducerId("TextGeneration", "0.1.0"),
91 | eos_token="<\n>",
92 | text=input_text,
93 | preserve_input_text=True,
94 | task_type="SEQ_2_SEQ_LM",
95 | )
96 | before_pred = generated_text.generated_text
97 | generated_text = generate_text_func(
98 | model=seq_model.model,
99 | tokenizer=seq_model.tokenizer,
100 | producer_id=ProducerId("TextGeneration", "0.1.0"),
101 | eos_token="<\n>",
102 | text=input_text,
103 | preserve_input_text=False,
104 | task_type="SEQ_2_SEQ_LM",
105 | )
106 | after_pred = generated_text.generated_text
107 | assert before_pred == after_pred
108 |
--------------------------------------------------------------------------------
/tests/toolkit/test_data_stream_wrapper.py:
--------------------------------------------------------------------------------
1 | """Tests for wrapper helpers for making DataStreams play nicely with PyTorch DataLoaders.
2 | """
3 | # Standard
4 | from unittest import mock
5 |
6 | # Third Party
7 | from torch.utils.data._utils import worker
8 | import pytest
9 |
10 | # First Party
11 | from caikit.core.data_model import DataStream
12 |
13 | # Local
14 | from caikit_nlp.toolkit.data_stream_wrapper import SimpleIterableStreamWrapper
15 | from tests.fixtures import requires_determinism
16 |
17 | # Sample data to load via PyTorch
18 | SAMPLE_DATA = [{"label": "foo"}, {"label": "foo"}, {"label": "bar"}, {"label": "bar"}]
19 | SAMPLE_STREAM = DataStream.from_iterable(SAMPLE_DATA)
20 | NUM_CYCLES = 10
21 |
22 |
23 | def test_without_shuffling():
24 | """Ensure that we can build a datastream & load it in a data loader without shuffling."""
25 | test_results = []
26 | # Get the IDs of all objects in the stream
27 | get_stream_id_order = lambda s: [id(datum) for datum in s]
28 | # Compare the data stream at two different iteration points; here, we
29 | # produce True if two streams have the same objects in the same order
30 | have_same_id_order = lambda id_set1, id_set2: all(
31 | [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)]
32 | ) and len(id_set1) == len(id_set2)
33 |
34 | # NOTE - shuffling is disabled here, so the iteration order should be stable across epochs
35 | wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False)
36 | # Cycle through NUM_CYCLES times & ensure that the order of our objects does not change
37 | initialize_order = get_stream_id_order(wrapper)
38 | for _ in range(NUM_CYCLES):
39 | cycle_ids = get_stream_id_order(wrapper)
40 | test_res = have_same_id_order(initialize_order, cycle_ids)
41 | test_results.append(test_res)
42 | assert all(test_results)
43 |
44 |
45 | def test_shuffle_full_buffer(requires_determinism):
46 | """Ensure that we can build a datastream & shuffle it all in memory on each iteration."""
47 | test_results = []
48 | # Get the IDs of all objects in the stream
49 | get_stream_id_order = lambda s: [id(datum) for datum in s]
50 | # Compare the data stream at two different iteration points; here, we
51 | # produce True if two streams have the same objects in the same order
52 | have_same_id_order = lambda id_set1, id_set2: all(
53 | [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)]
54 | ) and len(id_set1) == len(id_set2)
55 |
56 | # NOTE - a buffer size equal to the stream length shuffles the entire stream in memory
57 | wrapper = SimpleIterableStreamWrapper(
58 | stream=SAMPLE_STREAM, shuffle=True, buffer_size=len(SAMPLE_STREAM)
59 | )
60 | # Cycle through NUM_CYCLES times & ensure that the order of our objects DOES change sometimes
61 | initialize_order = get_stream_id_order(wrapper)
62 | for _ in range(NUM_CYCLES):
63 | cycle_ids = get_stream_id_order(wrapper)
64 | test_res = have_same_id_order(initialize_order, cycle_ids)
65 | test_results.append(test_res)
66 | assert not all(test_results)
67 |
68 |
69 | # def test_iter_with_multi_worker(requires_determinism):
70 | # """Ensure that we are able to iterate properly over data in case of workers
71 | # managed by torch"""
72 |
73 | # test_results = []
74 | # # Get the IDs of all objects in the stream
75 | # get_stream_id_order = lambda s: [id(datum) for datum in s]
76 | # # Compare the data stream at two different iteration points; here, we
77 | # # produce True if two streams have the same objects in the same order
78 | # have_same_id_order = lambda id_set1, id_set2: all(
79 | # [datum1 == datum2 for datum1, datum2 in zip(id_set1, id_set2)]
80 | # ) and len(id_set1) == len(id_set2)
81 |
82 | # dummy_worker_info = worker.WorkerInfo(
83 | # id=1,
84 | # num_workers=2,
85 | # seed=7,
86 | # )
87 |
88 | # with mock.patch.object(worker, '_worker_info', dummy_worker_info):
89 | # wrapper = SimpleIterableStreamWrapper(stream=SAMPLE_STREAM, shuffle=False)
90 | # initialize_order = get_stream_id_order(wrapper)
91 | # for _ in range(NUM_CYCLES):
92 | # cycle_ids = get_stream_id_order(wrapper)
93 | # test_res = have_same_id_order(initialize_order, cycle_ids)
94 | # test_results.append(test_res)
95 | # assert not all(test_results)
96 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/tokenization/regex_sentence_splitter.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Module that provides capability to split documents into sentences via regex"""
15 |
16 | # Standard
17 | import os
18 | import re
19 |
20 | # First Party
21 | from caikit.core.exceptions import error_handler
22 | from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module
23 | from caikit.interfaces.nlp.data_model import Token, TokenizationResults
24 | from caikit.interfaces.nlp.tasks import TokenizationTask
25 | import alog
26 |
27 | log = alog.use_channel("RGX_SNT_SPLT")
28 | error = error_handler.get(log)
29 |
30 |
31 | @module(
32 | id="1e04e21b-7009-499e-abdd-41e5984c2e7d",
33 | name="Regex Sentence Splitter",
34 | version="0.1.0",
35 | task=TokenizationTask,
36 | )
37 | class RegexSentenceSplitter(ModuleBase):
38 | # pylint: disable=anomalous-backslash-in-string
39 | """Use python regexes to split document into sentences.
40 |
41 | Sample regex string:
42 | [^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)
43 | """
44 |
45 | def __init__(self, regex_str: str):
46 | """Construct a RegexSentenceSplitter object
47 | by compiling the input regex string into python regex object
48 | that can be used later on for detection.
49 |
50 | Args:
51 | regex_str: str
52 | String containing a pattern that can be compiled with the python re
53 | module
54 | """
55 | super().__init__()
56 | error.type_check("", str, regex_str=regex_str)
57 | self.regex_str = regex_str
58 | self.regex = re.compile(self.regex_str)
59 |
60 | @classmethod
61 | def bootstrap(cls, regex_str) -> "RegexSentenceSplitter":
62 | """Bootstrap a a RegexSentenceSplitter object
63 |
64 | Args:
65 | regex_str: str
66 | String containing a pattern that can be compiled with the python re
67 | module
68 | """
69 | return cls(regex_str)
70 |
71 | def save(self, model_path: str):
72 | """Save model in target path
73 |
74 | Args:
75 | model_path: str
76 | Path to store model artifact(s)
77 | """
78 | module_saver = ModuleSaver(
79 | self,
80 | model_path=model_path,
81 | )
82 | with module_saver:
83 | config_options = {"regex_str": self.regex_str}
84 | module_saver.update_config(config_options)
85 |
86 | @classmethod
87 | def load(cls, model_path: str) -> "RegexSentenceSplitter":
88 | """Load a regex sentence splitter model.
89 |
90 | Args:
91 | model_path: str
92 | Path to the model to be loaded.
93 |
94 | Returns:
95 | RegexSentenceSplitter
96 | Instance of this class built from the on disk model.
97 | """
98 | config = ModuleConfig.load(os.path.abspath(model_path))
99 | return cls(regex_str=config.regex_str)
100 |
101 | def run(self, text: str) -> TokenizationResults:
102 | """Run sentence splitting regex on input text.
103 |
104 | Args:
105 | text: str
106 | Document to run sentence splitting on.
107 |
108 | Returns:
109 | TokenizationResults
110 | TokenizationResults object containing tokens where each token
111 | corresponds to a detected sentence.
112 | """
113 |
114 | error.type_check("", str, text=text)
115 |
116 | matches = self.regex.finditer(text)
117 | tokens = []
118 | for match in matches:
119 | token = Token(start=match.start(), end=match.end(), text=match.group())
120 | tokens.append(token)
121 |
122 | return TokenizationResults(results=tokens)
123 |
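124 | # Usage sketch: bootstrap a splitter with the sample pattern from the class
125 | # docstring and run it over a document. The input text is illustrative only.
126 | #
127 | #   splitter = RegexSentenceSplitter.bootstrap(
128 | #       r"[^.!?\s][^.!?\n]*(?:[.!?](?!['\"]?\s|$)[^.!?]*)*[.!?]?['\"]?(?=\s|$)"
129 | #   )
130 | #   result = splitter.run("First sentence. Second sentence!")
131 | #   # result.results holds one Token(start, end, text) per detected sentence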
--------------------------------------------------------------------------------
/caikit_nlp/toolkit/verbalizer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # Standard
15 | import re
16 |
17 | # First Party
18 | from caikit.core.exceptions import error_handler
19 | import alog
20 |
21 | log = alog.use_channel("VERBALIZER_UTIL")
22 | error = error_handler.get(log)
23 |
24 |
25 | def is_valid_verbalizer(verbalizer_template: str) -> bool:
26 | """Given a verbalizer template, determine if it's valid or not. We say a verbalizer
27 | template is valid if and only if we have at least one renderable field.
28 |
29 | Args:
30 | verbalizer_template: str
31 | Tentative verbalizer template to be used in text generation.
32 | Returns:
33 | bool:
34 | True if this is a valid verbalizer str with at least one renderable placeholder.
35 | """
36 | if not isinstance(verbalizer_template, str):
37 | return False
38 | return re.search(r"{{([_a-zA-Z0-9]+)}}", verbalizer_template) is not None
39 |
40 |
41 | def render_verbalizer(verbalizer_template: str, source_object) -> str:
42 | """Given a verbalizer template and a source object, replace all templates with keys or
43 | attributes from the source object. Templates are expected to follow Python variable name
44 | allowed chars, i.e., alphanumeric with underscores; if they don't, they're skipped. The
45 | contents of the template will be replaced with either keys on the source object, or attribute
46 | values.
47 |
48 | Templates should be in double brackets with no whitespace, e.g., {{label}}. Consider the
49 | following examples.
50 |
51 | Examples:
52 | 1. [Dictionary based]
53 | verbalizer_template = "Hello {{label}}"
54 | source_object = {"label": "world"}
55 | -> replace {{label}} with source_object["label"], producing "Hello world"
56 | returns: "Hello world"
57 | 2. [Object based]
58 | verbalizer_template = "Source: {{input}} Target: {{output}}"
59 | source_object = GenerationTrainRecord(input="machine", output="learning")
60 | -> replace {{input}} with getattr(source_object, "input") and replace
61 | {{output}} with getattr(source_object, "output").
62 | returns: "Source: machine Target: learning"
63 |
64 | NOTE: This function will throw ValueError if you try to grab a key or property
65 | that is invalid.
66 |
67 | Args:
68 | verbalizer_template: str
69 | Verbalizer that we want to render object values into.
70 |
71 | Returns:
72 | str
73 | Verbalizer string with placeholders rendered.
74 | """
75 | is_dict = isinstance(source_object, dict)
76 |
77 | def replace_text(match_obj: re.Match):
78 | captured_groups = match_obj.groups()
79 | if len(captured_groups) != 1:
80 | error(
81 | "",
82 | ValueError(
83 | "Unexpectedly captured multiple groups in verbalizer rendering"
84 | ),
85 | )
86 |
87 | index_object = captured_groups[0]
88 | if is_dict:
89 | if index_object not in source_object:
90 | error(
91 | "",
92 | ValueError(
93 | f"Requested template string '{index_object}' is not a valid key in dict"
94 | ),
95 | )
96 | return source_object[index_object]
97 |
98 | if not hasattr(source_object, index_object):
99 | error(
100 | "",
101 | ValueError(
102 | f"Requested template string '{index_object}' is not a valid property of type",
103 | ),
104 | )
105 | return getattr(source_object, index_object)
106 |
107 | return re.sub(r"{{([_a-zA-Z0-9]+)}}", replace_text, verbalizer_template)
108 |
--------------------------------------------------------------------------------
/caikit_nlp/model_management/tgis_auto_finder.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | The TGISAutoFinder implements the ModelFinder interface to provide automatic
16 | discovery of text-generation models that can be auto-configured to run against
17 | a remote TGIS model.
18 | """
19 | # Standard
20 | from typing import Optional
21 |
22 | # First Party
23 | from caikit.core import MODEL_MANAGER, error_handler
24 | from caikit.core.model_management import ModelFinderBase, model_finder_factory
25 | from caikit.core.modules import ModuleConfig
26 | from caikit_tgis_backend import TGISBackend
27 | import aconfig
28 | import alog
29 |
30 | # Local
31 | from ..modules.text_generation import TextGenerationTGIS
32 |
33 | log = alog.use_channel("TGIS_FND")
34 | error = error_handler.get(log)
35 |
36 |
37 | class TGISAutoFinder(ModelFinderBase):
38 | __doc__ = __doc__
39 |
40 | name = "TGIS-AUTO"
41 |
42 | # Constants for the keys of the config blob
43 | _LOCAL_INITIALIZER_NAME_KEY = "local_initializer_name"
44 | _TGIS_BACKEND_PRIORITY_KEY = "tgis_backend_priority"
45 |
46 | def __init__(self, config: aconfig.Config, instance_name: str = ""):
47 | """Initialize from the model finder factory config
48 |
49 | Config schema:
50 |
51 | local_initializer_name:
52 | type: string
53 | default: "default"
54 | description: The name within the initializers config for the LOCAL
55 | initializer that will hold the tgis backend to use
56 |
57 | tgis_backend_priority:
58 | type: integer
59 | description: Index within the backend_priority list for the TGIS
60 | backend to use. If not set, the first TGIS backend found will be
61 | used.
62 |
63 | Args:
64 | config (aconfig.Config): The configuration blob from caikit's
65 | model_management factory construction
66 | instance_name (str): The name of this finder instance
67 | """
68 | local_initializer_name = config.get(self._LOCAL_INITIALIZER_NAME_KEY, "default")
69 | tgis_backend_priority = config.get(self._TGIS_BACKEND_PRIORITY_KEY)
70 | error.type_check(
71 | "", str, local_initializer_name=local_initializer_name
72 | )
73 | error.type_check(
74 | "",
75 | int,
76 | tgis_backend_priority=tgis_backend_priority,
77 | allow_none=True,
78 | )
79 |
80 | # Extract the TGIS backend instance
81 | local_initializer = MODEL_MANAGER.get_initializer(local_initializer_name)
82 | backends = local_initializer.backends
83 | if tgis_backend_priority is not None:
84 | error.value_check(
85 | "",
86 | 0 <= tgis_backend_priority < len(backends),
87 | "Invalid {}: {}",
88 | self._TGIS_BACKEND_PRIORITY_KEY,
89 | tgis_backend_priority,
90 | )
91 | self._tgis_backend = backends[tgis_backend_priority]
92 | error.value_check(
93 | "",
94 | self._tgis_backend.backend_type == TGISBackend.backend_type,
95 | "Index {} is not a TGIS backend",
96 | tgis_backend_priority,
97 | )
98 | else:
99 | tgis_backend = None
100 | for backend in backends:
101 | if backend.backend_type == TGISBackend.backend_type:
102 | tgis_backend = backend
103 | break
104 | error.value_check(
105 | "",
106 | tgis_backend is not None,
107 | "No TGIS backend found!",
108 | )
109 | self._tgis_backend = tgis_backend
110 |
111 | def find_model(
112 | self,
113 | model_path: str,
114 | **kwargs,
115 | ) -> Optional[ModuleConfig]:
116 | """Find the model if"""
117 |
118 | # Get a connection to this model in tgis
119 | log.debug2("Attempting to setup TGIS client for %s", model_path)
120 | if self._tgis_backend.get_connection(model_id=model_path) is None:
121 | log.debug2("TGIS cannot connect to model %s", model_path)
122 | return None
123 |
124 | # If connection is ok, set up the module config to point to the remote
125 | # TGIS text generation module
126 | cfg = ModuleConfig(
127 | {
128 | "module_id": TextGenerationTGIS.MODULE_ID,
129 | "module_class": TextGenerationTGIS.MODULE_CLASS,
130 | "name": TextGenerationTGIS.MODULE_NAME,
131 | "version": TextGenerationTGIS.MODULE_VERSION,
132 | "model_name": model_path,
133 | }
134 | )
135 | # Set a special indicator in the module config to use the backend that
136 | # this finder found. This will override the backend found by the local
137 | # initializer.
138 | cfg.tgis_backend = self._tgis_backend
139 | return cfg
140 |
141 |
142 | model_finder_factory.register(TGISAutoFinder)
143 |
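144 | # Configuration sketch: because the finder is registered with caikit's
145 | # model_finder_factory above, it can be selected by name in the model_management
146 | # config. The keys under `config` mirror this class's config schema; the
147 | # surrounding structure is an assumption to be checked against caikit's
148 | # model_management configuration docs.
149 | #
150 | #   model_management:
151 | #     finders:
152 | #       tgis_auto:
153 | #         type: TGIS-AUTO
154 | #         config:
155 | #           local_initializer_name: default
156 | #           tgis_backend_priority: 0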
--------------------------------------------------------------------------------
/prompt_tuning_parameter_selection.md:
--------------------------------------------------------------------------------
1 | # Prompt Tuning Parameters
2 |
3 | ## Parameters for selection of training algorithm
4 | - `base_model`
5 | - **Description**: Path to the base model or `caikit.Resource` object of the base model to be used for tuning. A model-name string may also be provided. In this case, the Transformers API will automatically load it up from Hugging Face model cache if the model is locally available. If it is not available, the model may be downloaded by setting the `ALLOW_DOWNLOADS` environment variable to `true`.
6 | - **Accepted values**:
7 | - The model needs to be of type causal-lm or seq2seq, thus loadable via HuggingFace `AutoModelForCausalLM` or `AutoModelForSeq2SeqLM` loading methods.
8 | - `tuning_type`: (`str` or `caikit_nlp.modules.text_generation.TuningType`)
9 | - Type of Peft Tuning config which we would like to build.
10 | - **Accepted values**: `PROMPT_TUNING` and `MULTITASK_PROMPT_TUNING`
11 | - **Default**: `PROMPT_TUNING`
12 | - `num_epochs`: (int)
13 | - The number of epochs is the number of complete passes through the training dataset.
14 | - Tuning quality depends heavily on the number of epochs.
15 | - **Expose to end user recommendation**: True
16 | - **Accepted values**: any positive int
17 | - **Default**: 20
18 | - `device`: (str)
19 | - Training device to be used; can be `cpu`, `cuda`, or a CUDA-specific device name.
20 | - **Expose to end user recommendation**: False
21 | - `lr` / `learning_rate`: (float) - The parameter name will soon be changed to make it more intuitive.
22 | - Learning rate to be used for training
23 | - **Expose to end user recommendation**: True
24 | - `accumulate_steps`:
25 | - Number of steps to be used for gradient accumulation. Gradient accumulation refers to a method of collecting gradients for a configured number of steps instead of updating the model variables at every step, and then applying the accumulated update to the model variables. This can be used as a tool to overcome smaller batch size limitations. Often also referred to in conjunction with "effective batch size".
26 | - **Expose to end user recommendation**: True
27 | - `verbalizer`
28 | - Verbalizer template to be used for formatting data at train and inference time. This template may use `{{}}` placeholders to indicate where fields from the data model `GenerationTrainRecord` must be rendered.
29 | - **Default**: "{{input}}", i.e., the raw text.
30 | - **Expose to end user recommendation**: True
31 | - `batch_size`:
32 | - The batch size is a number of samples processed before the model is updated.
33 | - **Default**: 8
34 | - **Expose to end user recommendation**: True
35 | - `max_source_length`:
36 | - Max length of input sequences being considered.
37 | - **Default**: 256.
38 | - **Expose to end user recommendation**: True
39 | - `max_target_length`:
40 | - Max length of target sequences being predicted.
41 | - **Default**: 128
42 | - **Expose to end user recommendation**: True
43 | - `tuning_config.num_virtual_tokens`:
44 | - Number of virtual tokens to be used for training. In prompt tuning we are essentially learning the embedded representations for soft prompts, known as virtual tokens, via backpropagation for a specific task(s) while keeping the rest of the model fixed. `num_virtual_tokens` is the number of these virtual tokens to learn.
45 | - **Expose to end user recommendation**: True (default to be set by application layer end)
46 | - This should also correspond to the available source prompt, if one exists, i.e., a user needs to select the number of virtual tokens to match the available source prompt if they want to use MPT source prompts.
47 |
48 | - `tuning_config.prompt_tuning_init_method`:
49 | - Could be: `RANDOM`, `TEXT`, `ONLY_SOURCE_SHARED` and `AVERAGE_SOURCE`
50 | - `TEXT` requires `tuning_config.prompt_tuning_init_text` to be set
51 | - `ONLY_SOURCE_SHARED` and `AVERAGE_SOURCE` require `tuning_config.prompt_tuning_init_source_model` to be set and a source prompt model to be available for the given `base_model`
52 | - **Default**: `RANDOM`
53 | - **Expose to end user recommendation**: True
54 | - Only `RANDOM`, `TEXT`, and `AVERAGE_SOURCE` should be exposed, where `AVERAGE_SOURCE` is only applicable when the tuning method is `MULTITASK_PROMPT_TUNING`
55 | - `tuning_config.prompt_tuning_init_text`:
56 | - Initialization text to be used **IF** `tuning_config.prompt_tuning_init_method` is set to `TEXT` otherwise this will be ignored.
57 | - **Default**: No default.
58 | - **Expose to end user recommendation**: True (if `TEXT` init_method is exposed to customers)
59 | - `tuning_config.prompt_tuning_init_source_model`:
60 | - Path pointing to the source prompt model. This path is relative to `config.source_prompt_base` (or the `SOURCE_PROMPT_BASE` env variable).
61 | - The source model selection needs to correspond to the `base_model`.
62 | - There may be cases where we have multiple source prompts available for a given model, in which case their selection criteria need to be determined.
63 | - **Default** Would depend on the `base_model`. If `MULTITASK_PROMPT_TUNING` is not selected as the tuning type, then this will be ignored.
64 | - `tuning_config.output_model_types`: `List(str)`
65 | - A list containing the strings `ENCODER` and/or `DECODER`.
66 | - Acceptable values for types of models:
67 | - CausalLM: `["DECODER"]`
68 | - Seq2Seq: `["ENCODER"]`, `["DECODER"]`, `["ENCODER", "DECODER"]`
69 | - **Default**:
70 | - CausalLM: `[DECODER]`
71 | - Seq2Seq: `[ENCODER]`
72 | - **Expose to end user recommendation**: False
73 |
74 | - `torch_dtype`: (str)
75 | - Datatype to use for training of the underlying text generation model. If no value is provided, we pull from torch_dtype in config. If an in memory resource is provided which does not match the specified data type, the model underpinning the resource will be converted in place to the correct torch dtype.
76 | - **Expose to end user recommendation**: False
77 | - Recommended to be configured at environment or server configuration level.
78 | - `silence_progress_bars` (bool)
79 | - Toggle to control the progress bar for training. This is only relevant to the "python user experience" and doesn't apply to training via the caikit runtime.
80 | - **Expose to end user recommendation**: False
81 |
--------------------------------------------------------------------------------
/benchmarks/logs/llama2-7b/20230905_191650.output:
--------------------------------------------------------------------------------
1 | (tuning) [gpu_user@gpu6120 caikit-nlp]$ ./ft_job.sh
2 | The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
3 | 0it [00:00, ?it/s]
4 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/errors/__init__.py:29: DeprecationWarning: The caikit.toolkit.errors package has moved to caikit.core.exceptions
5 | _warnings.warn(
6 | is still in the BETA phase and subject to change!
7 | /u/gpu_user/.conda/envs/tuning/lib/python3.9/site-packages/caikit/core/toolkit/error_handler.py:29: DeprecationWarning: The caikit.toolkit.error_handler package has moved to caikit.core.exceptions
8 | _warnings.warn(
9 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4.20k/4.20k [00:00<00:00, 4.16MB/s]
10 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.60k/6.60k [00:00<00:00, 5.36MB/s]
11 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6.27k/6.27k [00:00<00:00, 5.52MB/s]
12 | Existing model directory found; purging it now.
13 | Experiment Configuration
14 | - Model Name: [/tmp/tu/huggingface/hub/models--llama-2-7b]
15 | |- Inferred Model Resource Type: []
16 | - Dataset: [glue/rte]
17 | - Number of Epochs: [1]
18 | - Learning Rate: [2e-05]
19 | - Batch Size: [19]
20 | - Output Directory: [/tmp/tu/output/tuning/llama27b]
21 | - Maximum source sequence length: [128]
22 | - Maximum target sequence length: [1024]
23 | - Gradient accumulation steps: [16]
24 | - Enable evaluation: [False]
25 | - Evaluation metrics: [['rouge']]
26 | - Torch dtype to use for training: [bfloat16]
27 | [Loading the dataset...]
28 | Downloading builder script: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.8k/28.8k [00:00<00:00, 15.9MB/s]
29 | Downloading metadata: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28.7k/28.7k [00:00<00:00, 26.9MB/s]
30 | Downloading readme: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 27.9k/27.9k [00:00<00:00, 22.1MB/s]
31 | Downloading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 697k/697k [00:00<00:00, 12.0MB/s]
32 | Generating train split: 0%| | 0/2490 [00:00, ? examples/s]2023-09-05T19:16:00.306639 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad.incomplete/glue-train-00000-00000-of-NNNNN.arrow
33 | Generating train split: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2490/2490 [00:00<00:00, 5375.17 examples/s]
34 | Generating validation split: 0%| | 0/277 [00:00, ? examples/s]2023-09-05T19:16:00.770379 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad.incomplete/glue-validation-00000-00000-of-NNNNN.arrow
35 | Generating validation split: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 277/277 [00:00<00:00, 28629.00 examples/s]
36 | Generating test split: 0%| | 0/3000 [00:00, ? examples/s]2023-09-05T19:16:00.780343 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad.incomplete/glue-test-00000-00000-of-NNNNN.arrow
37 | Generating test split: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3000/3000 [00:00<00:00, 35352.71 examples/s]
38 | 2023-09-05T19:16:00.866002 [fsspe:DBUG] open file: /u/gpu_user/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad.incomplete/dataset_info.json
39 | [Loading the base model resource...]
40 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00, 2.75s/it]
41 | [Starting the training...]
42 | 2023-09-05T19:16:50.992041 [PEFT_:DBUG] Shuffling enabled? True
43 | 2023-09-05T19:16:50.992203 [PEFT_:DBUG] Shuffling buffer size: 7470
44 | TRAINING ARGS: {
45 | "output_dir": "/tmp",
46 | "per_device_train_batch_size": 19,
47 | "per_device_eval_batch_size": 19,
48 | "num_train_epochs": 1,
49 | "seed": 73,
50 | "do_eval": false,
51 | "learning_rate": 2e-05,
52 | "weight_decay": 0.01,
53 | "save_total_limit": 3,
54 | "push_to_hub": false,
55 | "no_cuda": false,
56 | "remove_unused_columns": false,
57 | "dataloader_pin_memory": false,
58 | "gradient_accumulation_steps": 16,
59 | "eval_accumulation_steps": 16,
60 | "bf16": true
61 | }
62 | 0%| | 0/24 [00:00, ?it/s]You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
63 | {'train_runtime': 254.6707, 'train_samples_per_second': 29.332, 'train_steps_per_second': 0.094, 'train_loss': 1.93836243947347, 'epoch': 0.97}
64 | 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24/24 [04:14<00:00, 10.61s/it]
65 | Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:03<00:00, 1.60s/it]
66 | Using sep_token, but it is not set yet.
67 | [Training Complete]
68 |
--------------------------------------------------------------------------------
/tests/toolkit/text_generation/test_tgis_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Tests for tgis_utils
16 | """
17 |
18 | # Standard
19 | from typing import Iterable, Optional, Type
20 |
21 | # Third Party
22 | import fastapi
23 | import grpc
24 | import grpc._channel
25 | import pytest
26 |
27 | # First Party
28 | from caikit.core.data_model import ProducerId
29 | from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
30 | from caikit.interfaces.runtime.data_model import RuntimeServerContextType
31 | from caikit_tgis_backend.protobufs import generation_pb2
32 |
33 | # Local
34 | from caikit_nlp.toolkit.text_generation import tgis_utils
35 | from tests.fixtures import TestServicerContext
36 |
37 | ## Helpers #####################################################################
38 |
39 |
40 | class MockTgisClient:
41 | """Mock of a TGIS client that doesn't actually call anything"""
42 |
43 | def __init__(
44 | self,
45 | status_code: Optional[grpc.StatusCode],
46 | error_message: str = "Yikes",
47 | ):
48 | self._status_code = status_code
49 | self._error_message = error_message
50 |
51 | def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
52 | if self._status_code not in [None, grpc.StatusCode.OK]:
53 | raise error_type(
54 | grpc._channel._RPCState(
55 | [], [], [], code=self._status_code, details=self._error_message
56 | ),
57 | *args,
58 | )
59 |
60 | def Generate(
61 | self, request: generation_pb2.BatchedGenerationRequest, **kwargs
62 | ) -> generation_pb2.BatchedGenerationResponse:
63 | self._maybe_raise(grpc._channel._InactiveRpcError)
64 | return generation_pb2.BatchedGenerationResponse()
65 |
66 | def GenerateStream(
67 | self, request: generation_pb2.SingleGenerationRequest, **kwargs
68 | ) -> Iterable[generation_pb2.GenerationResponse]:
69 | self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
70 | yield generation_pb2.GenerationResponse()
71 |
72 | def Tokenize(
73 | self, request: generation_pb2.BatchedTokenizeRequest, **kwargs
74 | ) -> generation_pb2.BatchedTokenizeResponse:
75 | self._maybe_raise(grpc._channel._InactiveRpcError)
76 | return generation_pb2.BatchedTokenizeResponse()
77 |
78 |
79 | ## TGISGenerationClient ########################################################
80 |
81 |
82 | @pytest.mark.parametrize(
83 | "status_code",
84 | [code for code in grpc.StatusCode if code != grpc.StatusCode.OK],
85 | )
86 | @pytest.mark.parametrize(
87 | "method", ["unary_generate", "stream_generate", "unary_tokenize"]
88 | )
89 | def test_TGISGenerationClient_rpc_errors(status_code, method):
90 | """Test that raised errors in downstream RPCs are converted to
91 | CaikitCoreException correctly
92 | """
93 | tgis_client = MockTgisClient(status_code)
94 | gen_client = tgis_utils.TGISGenerationClient(
95 | "foo",
96 | "bar",
97 | tgis_client,
98 | ProducerId("foobar"),
99 | )
100 | with pytest.raises(CaikitCoreException) as context:
101 | kwargs = (
102 | dict(
103 | preserve_input_text=True,
104 | input_tokens=True,
105 | generated_tokens=True,
106 | token_logprobs=True,
107 | token_ranks=True,
108 | max_new_tokens=20,
109 | min_new_tokens=20,
110 | truncate_input_tokens=True,
111 | decoding_method="GREEDY",
112 | top_k=None,
113 | top_p=None,
114 | typical_p=None,
115 | temperature=None,
116 | seed=None,
117 | repetition_penalty=0.5,
118 | max_time=None,
119 | exponential_decay_length_penalty=None,
120 | stop_sequences=["asdf"],
121 | include_stop_sequence=True,
122 | )
123 | if method.endswith("_generate")
124 | else dict()
125 | )
126 | res = getattr(gen_client, method)(text="foobar", **kwargs)
127 | if method.startswith("stream_"):
128 | next(res)
129 |
130 | assert (
131 | context.value.status_code == tgis_utils.GRPC_TO_CAIKIT_CORE_STATUS[status_code]
132 | )
133 | rpc_err = context.value.__context__
134 | assert isinstance(rpc_err, grpc.RpcError)
135 |
136 |
137 | # NOTE: This test is preserved in caikit-nlp despite being duplicated in
138 | # caikit-tgis-backend so that we guarantee that the functionality is accessible
139 | # in a version-compatible way here.
140 | @pytest.mark.parametrize(
141 | argnames=["context", "route_info"],
142 | argvalues=[
143 | (
144 | fastapi.Request(
145 | {
146 | "type": "http",
147 | "headers": [
148 | (tgis_utils.ROUTE_INFO_HEADER_KEY.encode(), b"sometext")
149 | ],
150 | }
151 | ),
152 | "sometext",
153 | ),
154 | (
155 | fastapi.Request(
156 | {"type": "http", "headers": [(b"route-info", b"sometext")]}
157 | ),
158 | None,
159 | ),
160 | (
161 | TestServicerContext({tgis_utils.ROUTE_INFO_HEADER_KEY: "sometext"}),
162 | "sometext",
163 | ),
164 | (
165 | TestServicerContext({"route-info": "sometext"}),
166 | None,
167 | ),
168 | ("should raise ValueError", None),
169 | (None, None),
170 | # Uncertain how to create a grpc.ServicerContext object
171 | ],
172 | )
173 | def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]):
174 | if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
175 | with pytest.raises(TypeError):
176 | tgis_utils.get_route_info(context)
177 | else:
178 | actual_route_info = tgis_utils.get_route_info(context)
179 | assert actual_route_info == route_info
180 |
--------------------------------------------------------------------------------
/tests/modules/text_generation/test_peft_tgis_remote.py:
--------------------------------------------------------------------------------
1 | """Tests for prompt tuning based inference via TGIS backend; note that these tests mock the
2 | TGIS client and do NOT actually start/hit a TGIS server instance.
3 | """
4 | # Standard
5 | from typing import Iterable
6 | from unittest import mock
7 | import os
8 | import tempfile
9 |
10 | # Third Party
11 | import pytest
12 |
13 | # Local
14 | from caikit_nlp.modules.text_generation import PeftPromptTuningTGIS
15 | from tests.fixtures import (
16 | StubTGISClient,
17 | causal_lm_dummy_model,
18 | causal_lm_train_kwargs,
19 | saved_causal_lm_dummy_model,
20 | stub_tgis_backend,
21 | temp_config,
22 | )
23 |
24 | SAMPLE_TEXT = "Hello stub"
25 |
26 |
27 | def test_load_and_run(causal_lm_dummy_model, stub_tgis_backend):
28 | """Ensure we can export an in memory model, load it, and (mock) run it with the right text & prefix ID."""
29 | # Patch our stub backend into caikit so that we don't actually try to start TGIS
30 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}"
31 |
32 | with mock.patch.object(StubTGISClient, "Generate") as mock_gen:
33 | mock_gen.side_effect = StubTGISClient.unary_generate
34 |
35 | # Save the local model & reload it as a TGIS backend distributed module
36 | # Also, save the name of the dir + prompt ID, which is the path TGIS expects for the prefix ID
37 | with tempfile.TemporaryDirectory() as model_dir:
38 | causal_lm_dummy_model.save(model_dir)
39 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend)
40 | model_prompt_dir = os.path.split(model_dir)[-1]
41 |
42 | # Run an inference request, which is wrapped around our mocked Generate call
43 | result = mock_tgis_model.run(
44 | SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50
45 | )
46 | StubTGISClient.validate_unary_generate_response(result)
47 |
48 | stub_generation_request = mock_gen.call_args_list[0].args[0]
49 |
50 | # Validate that our verbalizer carried over correctly & was applied at inference time
51 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer
52 | assert stub_generation_request.requests[
53 | 0
54 | ].text == "hello distributed {}".format(SAMPLE_TEXT)
55 |
56 | # Ensure that our prefix ID matches the expected path based on our tmpdir and config
57 | assert model_prompt_dir == stub_generation_request.prefix_id
58 |
59 |
60 | def test_load_and_tokenize(causal_lm_dummy_model, stub_tgis_backend):
61 | """Ensure we can export an in memory model, load it, and tokenize it"""
62 | # Patch our stub backend into caikit so that we don't actually try to start TGIS
63 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}"
64 |
65 | with mock.patch.object(StubTGISClient, "Tokenize") as mock_gen:
66 | mock_gen.side_effect = StubTGISClient.tokenize
67 |
68 | # Save the local model & reload it as a TGIS backend distributed module
69 | with tempfile.TemporaryDirectory() as model_dir:
70 | causal_lm_dummy_model.save(model_dir)
71 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend)
72 |
73 | result = mock_tgis_model.run_tokenizer(SAMPLE_TEXT)
74 | StubTGISClient.validate_tokenize_response(result)
75 |
76 | # Validate that our verbalizer carried over correctly & was applied at inference time
77 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer
78 |
79 |
80 | def test_load_and_run_stream_out(causal_lm_dummy_model, stub_tgis_backend):
81 | """Ensure we can export an in memory model, load it, and (mock) run output streaming
82 | with the right text & prefix ID."""
83 | # Patch our stub backend into caikit so that we don't actually try to start TGIS
84 | causal_lm_dummy_model.verbalizer = "hello distributed {{input}}"
85 |
86 | with mock.patch.object(StubTGISClient, "GenerateStream") as mock_gen:
87 | mock_gen.side_effect = StubTGISClient.stream_generate
88 |
89 | # Save the local model & reload it as a TGIS backend distributed module
90 | # Also, save the name of the dir + prompt ID, which is the path TGIS expects for the prefix ID
91 | with tempfile.TemporaryDirectory() as model_dir:
92 | causal_lm_dummy_model.save(model_dir)
93 | mock_tgis_model = PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend)
94 | model_prompt_dir = os.path.split(model_dir)[-1]
95 | stub_tgis_backend.load_prompt_artifacts.assert_called_once()
96 |
97 | # Run an inference request, which is wrapped around our mocked GenerateStream call
98 | stream_result = mock_tgis_model.run_stream_out(
99 | SAMPLE_TEXT, preserve_input_text=True, max_new_tokens=200, min_new_tokens=50
100 | )
101 | StubTGISClient.validate_stream_generate_response(stream_result)
102 |
103 | stub_generation_request = mock_gen.call_args_list[0].args[0]
104 |
105 | # Validate that our verbalizer carried over correctly & was applied at inference time
106 | assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer
107 | assert stub_generation_request.request.text == "hello distributed {}".format(
108 | SAMPLE_TEXT
109 | )
110 |
111 | # Ensure that our prefix ID matches the expected path based on our tmpdir and config
112 | assert model_prompt_dir == stub_generation_request.prefix_id
113 |
114 |
115 | def test_purge_prompt_on_del(saved_causal_lm_dummy_model, stub_tgis_backend):
116 | """Test that the prompt artifacts get purged when a model is deleted"""
117 |
118 | # Load the model and make sure the prompt got copied over
119 | mock_tgis_model = PeftPromptTuningTGIS.load(
120 | saved_causal_lm_dummy_model, stub_tgis_backend
121 | )
122 | stub_tgis_backend.load_prompt_artifacts.assert_called_once()
123 |
124 | # Delete the model and make sure the prompt got "removed"
125 | with temp_config(unload_tgis_prompt_artifacts=True):
126 | mock_tgis_model.__del__()
127 | stub_tgis_backend.unload_prompt_artifacts.assert_called_once()
128 | prompt_id = os.path.basename(saved_causal_lm_dummy_model)
129 | stub_tgis_backend.unload_prompt_artifacts.assert_called_with(
130 | mock_tgis_model.base_model_name, prompt_id
131 | )
132 |
133 |
134 | def test_purge_prompt_disabled_on_del(saved_causal_lm_dummy_model, stub_tgis_backend):
135 | """Test that the prompt artifacts are not purged if disabled"""
136 |
137 | # Load the model and make sure the prompt got copied over
138 | mock_tgis_model = PeftPromptTuningTGIS.load(
139 | saved_causal_lm_dummy_model, stub_tgis_backend
140 | )
141 | stub_tgis_backend.load_prompt_artifacts.assert_called_once()
142 |
143 | # Delete the model and make sure the prompt did NOT get "removed"
144 | with temp_config(unload_tgis_prompt_artifacts=False):
145 | mock_tgis_model.__del__()
146 | assert not stub_tgis_backend.unload_prompt_artifacts.called
147 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | 👍🎉 First off, thank you for taking the time to contribute! 🎉👍
4 |
5 | The following is a set of guidelines for contributing. These are just guidelines, not rules. Use your best judgment, and feel free to propose changes to this document in a pull request. Please read the [community contribution guide](https://github.com/caikit/community/blob/main/CONTRIBUTING.md) first for general practices for the Caikit community.
6 |
7 | ## What Should I Know Before I Get Started?
8 |
9 | ### Code of Conduct
10 |
11 | This project adheres to the [Contributor Covenant](./code-of-conduct.md). By participating, you are expected to uphold this code.
12 |
13 | Please report unacceptable behavior to one of the [Code Owners](./CODEOWNERS).
14 |
15 | ### How Do I Start Contributing?
16 |
17 | The below workflow is designed to help you begin your first contribution journey. It will guide you through creating and picking up issues, working through them, having your work reviewed, and then merging.
18 |
19 | Help on open source projects is always welcome and there is always something that can be improved. For example, documentation (like the text you are reading now) can always use improvement, code can always be clarified, variables or functions can always be renamed or commented on, and there is always a need for more test coverage. If you see something that you think should be fixed, take ownership! Here is how you get started:
20 |
21 | ## How Can I Contribute?
22 |
23 | When contributing, it's useful to start by looking at [issues](https://github.com/caikit/caikit-nlp/issues). After picking up an issue, writing code, or updating a document, make a pull request and your work will be reviewed and merged. If you're adding a new feature or find a bug, it's best to [write an issue](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=feature_request.md&title=) first to discuss it with maintainers.
24 |
25 | To contribute to this repo, you'll use the Fork and Pull model common in many open source repositories. For details on this process, check out [The GitHub Workflow
26 | Guide](https://github.com/kubernetes/community/blob/master/contributors/guide/github-workflow.md)
27 | from Kubernetes.
28 |
29 | When your contribution is ready, you can create a pull request. Pull requests are often referred to as "PRs". In general, we follow the standard [GitHub pull request](https://help.github.com/en/articles/about-pull-requests) process. Follow the template to provide details about your pull request to the maintainers.
30 |
31 | Before sending pull requests, make sure your changes pass formatting, linting and unit tests.
32 |
33 | #### Code Review
34 |
35 | Once you've [created a pull request](#how-can-i-contribute), maintainers will review your code and may suggest changes before merging. It will be easier for your pull request to receive reviews if you consider the criteria the reviewers follow while working. Remember to:
36 |
37 | - Run tests locally and ensure they pass
38 | - Follow the project coding conventions
39 | - Write detailed commit messages
40 | - Break large changes into a logical series of smaller patches, which are easy to understand individually and combine to solve a broader issue
41 |
42 | ### Reporting Bugs
43 |
44 | This section guides you through submitting a bug report. Following these guidelines helps maintainers and the community understand your report ✏️, reproduce the behavior 💻, and find related reports 🔎.
45 |
46 | #### How Do I Submit A (Good) Bug Report?
47 |
48 | Bugs are tracked as [GitHub issues using the Bug Report template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=bug_report.md&title=). Create an issue there and provide the information requested in the bug report issue template.
49 |
50 | ### Suggesting Enhancements
51 |
52 | This section guides you through submitting an enhancement suggestion, including completely new features, tools, and minor improvements to existing functionality. Following these guidelines helps maintainers and the community understand your suggestion ✏️ and find related suggestions 🔎.
53 |
54 | #### How Do I Submit A (Good) Enhancement Suggestion?
55 |
56 | Enhancement suggestions are tracked as [GitHub issues using the Feature Request template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=feature_request.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template.
57 |
58 | #### How Do I Submit A (Good) Improvement Item?
59 |
60 | Improvements to existing functionality are tracked as [GitHub issues using the User Story template](https://github.com/caikit/caikit-nlp/issues/new?assignees=&labels=&template=user_story.md&title=). Create an issue and provide the information suggested in the feature requests or user story issue template.
61 |
62 | ## Development
63 |
64 | ### Set up your dev environment
65 |
66 | The following tools are required:
67 |
68 | - [git](https://git-scm.com)
69 | - [python](https://www.python.org) (v3.8+)
70 | - [pip](https://pypi.org/project/pip/) (v23.0+)
71 |
72 | You can set up your dev environment using [tox](https://tox.wiki/en/latest/), an environment orchestrator that sets up environments for, and invokes, builds, unit tests, formatting, linting, etc. Install tox with:
73 |
74 | ```sh
75 | pip install -r setup_requirements.txt
76 | ```
77 |
78 | If you want to manage your own virtual environment instead of using `tox`, you can install `caikit` and all dependencies with:
79 |
80 | ```sh
81 | pip install .
82 | ```
83 |
84 | ### Unit tests
85 |
86 | Unit tests are enforced by the CI system. When making changes, run the tests before pushing to avoid CI issues.
87 |
88 | Running unit tests against all supported Python versions is as simple as:
89 |
90 | ```sh
91 | tox
92 | ```
93 |
94 | Running tests against a single Python version can be done with:
95 |
96 | ```sh
97 | tox -e py
98 | ```
99 |
100 | ### Coding style
101 |
102 | Caikit follows the python [pep8](https://peps.python.org/pep-0008/) coding style. The coding style is enforced by the CI system, and your PR will fail until the style has been applied correctly.
103 |
104 | We use [pre-commit](https://pre-commit.com/) to enforce coding style using [black](https://github.com/psf/black), [prettier](https://github.com/prettier/prettier) and [isort](https://pycqa.github.io/isort/).
105 |
106 | You can invoke formatting with:
107 |
108 | ```sh
109 | tox -e fmt
110 | ```
111 |
112 | In addition, we use [pylint](https://www.pylint.org) to perform static code analysis of the code.
113 |
114 | You can invoke linting with the following command:
115 |
116 | ```sh
117 | tox -e lint
118 | ```
119 |
120 | ## Your First Code Contribution
121 |
122 | Unsure where to begin contributing? You can start by looking through these issues:
123 |
124 | - Issues with the [`good first issue` label](https://github.com/caikit/caikit-nlp/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) - these should only require a few lines of code and are good targets if you're just starting to contribute.
125 | - Issues with the [`help wanted` label](https://github.com/caikit/caikit-nlp/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22) - these range from simple to more complex, but are generally things we want but can't get to in a short time frame.
126 |
127 |
--------------------------------------------------------------------------------
/caikit_nlp/resources/pretrained_model/hf_auto_seq2seq_lm.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """
15 | Huggingface auto seq2seq LM resource type
16 | """
17 | # Standard
18 | from collections.abc import Mapping
19 | from typing import Dict, List, Optional, Union
20 |
21 | # Third Party
22 | from torch.utils.data import IterableDataset
23 | from transformers import (
24 | AutoModelForSeq2SeqLM,
25 | DataCollatorForSeq2Seq,
26 | Seq2SeqTrainer,
27 | Seq2SeqTrainingArguments,
28 | )
29 | from transformers.models.auto import modeling_auto
30 |
31 | # First Party
32 | from caikit.core.exceptions import error_handler
33 | from caikit.core.modules import module
34 | import alog
35 |
36 | # Local
37 | from ...data_model import GenerationTrainRecord, PromptOutputModelType
38 | from ...toolkit.trainer_utils import log_step
39 | from ...toolkit.verbalizer_utils import render_verbalizer
40 | from .base import PretrainedModelBase
41 |
42 | log = alog.use_channel("HFRBAS")
43 | error = error_handler.get(log)
44 |
45 | IGNORE_ID = -100
46 |
47 |
48 | class LoggingTrainer(Seq2SeqTrainer):
49 | # pylint: disable=unused-argument
50 | def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
51 | """
52 | Log `logs` on the various objects watching training.
53 |
54 | Subclass and override this method to inject custom behavior.
55 |
56 | Args:
57 | logs (`Dict[str, float]`):
58 | The values to log.
59 | """
60 | # start_time was added in default trainer log
61 | # https://github.com/huggingface/transformers/pull/34507
62 | self.state = log_step(self.state, logs)
63 | self.control = self.callback_handler.on_log(
64 | self.args, self.state, self.control, logs
65 | )
66 |
67 |
68 | @module(
69 | id="6759e891-287b-405b-bd8b-54a4a4d51c25",
70 | name="HF Transformers Auto Seq2Seq LM",
71 | version="0.1.0",
72 | )
73 | class HFAutoSeq2SeqLM(PretrainedModelBase):
74 | """This resource (module) wraps a handle to a Huggingface
75 | AutoModelForSeq2SeqLM
76 | """
77 |
78 | MODEL_TYPE = AutoModelForSeq2SeqLM
79 | SUPPORTED_MODEL_TYPES = modeling_auto.MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES
80 | TASK_TYPE = "SEQ_2_SEQ_LM"
81 | PROMPT_OUTPUT_TYPES = [PromptOutputModelType.ENCODER]
82 | MAX_NUM_TRANSFORMERS = 2
83 |
84 | @classmethod
85 | def get_num_transformers_submodules(
86 | cls, output_model_types: List[PromptOutputModelType]
87 | ):
88 | """Return number of applicable transformer submodules"""
89 | num_transformer_submodules = 0
90 | if PromptOutputModelType.ENCODER in output_model_types:
91 | num_transformer_submodules += 1
92 | if PromptOutputModelType.DECODER in output_model_types:
93 | num_transformer_submodules += 1
94 | error.value_check(
95 | "", 0 < num_transformer_submodules <= cls.MAX_NUM_TRANSFORMERS
96 | )
97 | return num_transformer_submodules
98 |
99 | def get_trainer(
100 | self,
101 | train_dataset: IterableDataset,
102 | eval_dataset: Union[IterableDataset, None] = None,
103 | optimizers=(None, None),
104 | **kwargs
105 | ):
106 | """
107 | Args:
108 | **kwargs: arguments supported by HF Seq2SeqTrainingArguments:
109 | https://huggingface.co/docs/transformers/v4.30.0/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments
110 |
111 | NOTE: the following parameters are not currently supported:
112 | 1. model_init
113 | 2. compute_metrics
114 | 3. callbacks
115 | 4. preprocess_logits_for_metrics
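
Illustrative usage (a minimal sketch; `resource` and `train_dataset` are
assumed names for a bootstrapped HFAutoSeq2SeqLM resource and a tokenized
IterableDataset, and the extra kwargs are standard Seq2SeqTrainingArguments
fields):

trainer = resource.get_trainer(
train_dataset=train_dataset,
output_dir="/tmp/seq2seq_out",
num_train_epochs=1,
)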
116 | """
117 |
118 | # NOTE: predict_with_generate is incompatible with fsdp
119 | training_args = Seq2SeqTrainingArguments(**kwargs)
120 |
121 | # pylint: disable=duplicate-code
122 | # TODO: Fetch the DataCollator either from a property of this
123 | # class or accept it as an argument.
124 | data_collator = self._get_data_collator(**kwargs)
125 |
126 | trainer_arguments = {
127 | "train_dataset": train_dataset,
128 | "data_collator": data_collator,
129 | "optimizers": optimizers,
130 | "eval_dataset": eval_dataset,
131 | # "generation_max_length": max_target_length,
132 | }
133 |
134 | return LoggingTrainer(self._model, training_args, **trainer_arguments)
135 |
136 | def _get_data_collator(self, **kwargs):
137 | """Function to return appropriate data collator based on resource.
138 |
139 | This implementation uses DataCollatorForSeq2Seq
140 |
141 | Args:
142 | **kwargs:
143 | All the keyword arguments passed to this function
144 | will be filtered down to the ones that are
145 | applicable to the implemented data collator.
146 | Returns:
147 | transformers.DataCollator
148 | """
149 |
150 | applicable_args = ["max_length", "pad_to_multiple_of"]
151 | collator_kwargs = {key: kwargs[key] for key in applicable_args if key in kwargs}
152 |
153 | return DataCollatorForSeq2Seq(
154 | tokenizer=self._tokenizer, model=self._model, **collator_kwargs
155 | )
156 |
157 | @classmethod
158 | def tokenize_function(
159 | cls,
160 | example: Union[GenerationTrainRecord, Mapping],
161 | tokenizer: "AutoTokenizer",
162 | max_source_length: int,
163 | max_target_length: int,
164 | verbalizer: Union[None, str] = None,
165 | task_ids: Union[None, int] = None,
166 | ) -> "BatchEncoding":
167 | """Tokenization function to be used for seq2seq training; this function consumes a
168 | GenerationTrainRecord object and applies the verbalizer to it followed by
169 | the model tokenizer. Finally, we postprocess by ignoring pad tokens in the label IDs.
170 |
171 | Args:
172 | example: Union[GenerationTrainRecord, Mapping]
173 | Training data model object to convert to a form we can learn on.
174 |
175 | Returns:
176 | transformers.tokenization_utils_base.BatchEncoding
177 | encoded tokenization output corresponding to the input example.
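
Illustrative example (hypothetical record and verbalizer, following the
"{{input}}" placeholder convention used elsewhere in this repo):

example = GenerationTrainRecord(input="@foo what a cute dog!", output="no complaint")
verbalizer = "Tweet: {{input}} Label:"
# -> source passed to the tokenizer: "Tweet: @foo what a cute dog! Label:"
# -> labels come from tokenizing the output, with pad token ids replaced by IGNORE_ID (-100)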
178 | """
179 | source, target = cls.decompose_example_io(example)
180 | source = (
181 | source if verbalizer is None else render_verbalizer(verbalizer, example)
182 | )
183 |
184 | model_inputs = tokenizer(
185 | source,
186 | max_length=max_source_length,
187 | padding="max_length",
188 | truncation=True,
189 | )
190 | labels = tokenizer(
191 | target,
192 | max_length=max_target_length,
193 | padding="max_length",
194 | truncation=True,
195 | )
196 |
197 | labels = labels["input_ids"]
198 |
199 | labels = list(
200 | map(lambda x: IGNORE_ID if x == tokenizer.pad_token_id else x, labels)
201 | )
202 | model_inputs["labels"] = labels
203 | if task_ids is not None:
204 | model_inputs["task_ids"] = task_ids
205 |
206 | return model_inputs
207 |
--------------------------------------------------------------------------------
/examples/evaluate_model.py:
--------------------------------------------------------------------------------
1 | """Given a trained model, which was presumably created by run_peft_tuning.py,
2 | load it and run evaluation.
3 | """
4 | # Standard
5 | import argparse
6 | import json
7 | import pathlib
8 |
9 | # Third Party
10 | from tqdm import tqdm
11 | from utils import (
12 | SUPPORTED_DATASETS,
13 | SUPPORTED_METRICS,
14 | configure_random_seed_and_logging,
15 | get_wrapped_evaluate_metric,
16 | is_float,
17 | kill_tgis_container_if_exists,
18 | load_model,
19 | print_colored,
20 | string_to_float,
21 | )
22 |
23 | # Local
24 | from caikit_nlp.toolkit.verbalizer_utils import render_verbalizer
25 |
26 |
27 | def parse_args() -> argparse.Namespace:
28 | """Parse & validate command line arguments.
29 |
30 | Returns:
31 | argparse.Namespace
32 | Parsed arguments to be leveraged for model evaluation.
33 | """
34 | parser = argparse.ArgumentParser(
35 | description="Evaluate a text generation model.",
36 | )
37 | # TODO - Patch text-generation-launcher model var so that we can't mount the wrong model
38 | parser.add_argument(
39 | "--tgis",
40 | help="If enabled, runs inference through TGIS instead the local .run().",
41 | action="store_true",
42 | )
43 | parser.add_argument(
44 | "--model_path",
45 | help="Model to be loaded from disk.",
46 | type=pathlib.Path,
47 | required=True,
48 | )
49 | parser.add_argument(
50 | "--dataset",
51 | help="Dataset to use to train prompt vectors. Options: {}".format(
52 | list(SUPPORTED_DATASETS.keys())
53 | ),
54 | default="twitter_complaints",
55 | )
56 | parser.add_argument(
57 | "--metrics",
58 | help="Metrics to calculate in space delimited list",
59 | default=["accuracy"],
60 | nargs="*",
61 | choices=list(SUPPORTED_METRICS.keys()),
62 | )
63 | parser.add_argument(
64 | "--preds_file",
65 | help="JSON file to dump raw source / target texts to.",
66 | default="model_preds.json",
67 | )
68 | parser.add_argument(
69 | "--max_new_tokens",
70 | help="Maximum number of new tokens to be generated",
71 | type=int,
72 | default=20,
73 | )
74 | parser.add_argument(
75 | "--truncate_input_tokens",
76 | help="Number of allowed input tokens (no truncation=0)",
77 | type=int,
78 | default=0,
79 | )
80 | args = parser.parse_args()
81 | return args
82 |
83 |
84 | def get_model_preds_and_references(
85 | model, validation_stream, max_new_tokens, truncate_input_tokens
86 | ):
87 | """Given a model & a validation stream, run the model against every example in the validation
88 | stream and compare the outputs to the target/output sequence.
89 |
90 | Args:
91 | model
92 | Peft Model to be evaluated (may leverage different backends).
93 | validation_stream: DataStream[GenerationTrainRecord]
94 | Validation stream with labeled targets that we want to compare to our model's
95 | predictions.
96 | max_new_tokens: int
97 | Max number of new tokens to be generated, i.e., output limit
98 | truncate_input_tokens: int
99 | Number of allowed input tokens, i.e., input limit
100 |
101 | Returns:
102 | Tuple(List)
103 | Tuple of 2 lists; the model predictions and the expected output sequences.
104 | """
105 | model_preds = []
106 | targets = []
107 |
108 | for datum in tqdm(validation_stream):
109 | # Local .run() currently prepends the input text to the generated string;
110 | # Ensure that we keep only the text from the first predicted token onward.
111 | raw_model_text = model.run(
112 | datum.input,
113 | max_new_tokens=max_new_tokens,
114 | truncate_input_tokens=truncate_input_tokens,
115 | ).generated_text
116 | parse_pred_text = raw_model_text.split(datum.input)[-1].strip()
117 | model_preds.append(parse_pred_text)
118 | targets.append(datum.output)
119 | return (
120 | model_preds,
121 | targets,
122 | )
123 |
124 |
125 | def export_model_preds(preds_file, predictions, validation_stream, verbalizer):
126 | """Exports a JSON file containing a list of objects, where every object contains:
127 | - source: str - Source string used for generation.
128 | - target: str - Ground truth target label used for generation.
129 | - verbalized_source: str - Source string after model verbalization
130 | - predicted_target: str - Predicted model target.
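
Example entry (illustrative, assuming a "hello distributed {{input}}" verbalizer):
{
"source": "@foo what a cute dog!",
"target": "no complaint",
"verbalized_source": "hello distributed @foo what a cute dog!",
"predicted_target": "no complaint"
}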
131 |
132 | Args:
133 | preds_file: str
134 | Path on disk to JSON file to be written.
135 | predictions: List
136 | Model prediction list, where each predicted text excludes source text as a prefix.
137 | validation_stream: DataStream
138 | Datastream object of GenerationTrainRecord objects used for validation against a model
139 | to generate predictions.
140 | verbalizer: str
141 | Model verbalizer used for generating target predictions.
142 | """
143 | pred_objs = []
144 | for pred, record in zip(predictions, validation_stream):
145 | res = {
146 | "source": record.input,
147 | "target": record.output,
148 | "predicted_target": pred,
149 | }
150 | if verbalizer is not None:
151 | res["verbalized_source"] = render_verbalizer(verbalizer, record)
152 | pred_objs.append(res)
153 |
154 | with open(preds_file, "w") as jfile:
155 | json.dump(pred_objs, jfile, indent=4, sort_keys=True)
156 |
157 |
158 | if __name__ == "__main__":
159 | configure_random_seed_and_logging()
160 | args = parse_args()
161 | metric_funcs = [SUPPORTED_METRICS[metric] for metric in args.metrics]
162 | print_colored("Metrics to be calculated: {}".format(args.metrics))
163 |
164 | # Load the model; this can be a local model, or a distributed TGIS instance
165 | print_colored("Loading the model...")
166 | model = load_model(args.tgis, str(args.model_path))
167 | # Load the validation stream with marked target sequences
168 | print_colored("Grabbing validation data...")
169 | dataset_info = SUPPORTED_DATASETS[args.dataset]
170 | validation_stream = dataset_info.dataset_loader()[1]
171 | if validation_stream is None:
172 | raise ValueError(
173 | "Selected dataset does not have a validation dataset available!"
174 | )
175 |
176 | # Run the data through the model; save the predictions & references
177 | print_colored("Getting model predictions...")
178 | predictions, references = get_model_preds_and_references(
179 | model, validation_stream, args.max_new_tokens, args.truncate_input_tokens
180 | )
181 | print_colored(
182 | "Exporting model preds, source, verbalized source, and ground truth targets to {}".format(
183 | args.preds_file
184 | )
185 | )
186 | export_model_preds(
187 | args.preds_file,
188 | predictions,
189 | validation_stream,
190 | getattr(model, "verbalizer", None),
191 | )
192 |
193 | for metric_func in metric_funcs:
194 | metric_res = metric_func(predictions=predictions, references=references)
195 | print_colored(metric_res)
196 | # If we started a TGIS instance, kill it; otherwise, leave our container alone.
197 | # TODO: This will still look for containers to kill, even if you're running TGIS
198 | # outside of a container through text-generation-server directly. For now, we are
199 | # always running TGIS in a container, so it's ok; the worst that will happen is
200 | # you'll kill somebody else's container.
201 | if args.tgis:
202 | print_colored("Killing containerized TGIS instance...")
203 | kill_tgis_container_if_exists()
204 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/text_classification/sequence_classification.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """This module contains sequence classification, compatible with transformer
15 | SequenceClassification modules. At this time this module is only designed
16 | for inference"""
17 | # Standard
18 | from typing import Dict, List, Union
19 |
20 | # Third Party
21 | import torch
22 |
23 | # First Party
24 | from caikit.core.exceptions import error_handler
25 | from caikit.core.modules import ModuleBase, ModuleLoader, ModuleSaver, module
26 | from caikit.interfaces.nlp.data_model import ClassificationResult, ClassificationResults
27 | from caikit.interfaces.nlp.tasks import TextClassificationTask
28 | import alog
29 |
30 | # Local
31 | from ...resources.pretrained_model.hf_auto_seq_classifier import (
32 | HFAutoSequenceClassifier,
33 | )
34 |
35 | log = alog.use_channel("SEQ_CLASS")
36 | error = error_handler.get(log)
37 |
38 |
39 | @module(
40 | id="d21107ca-d579-4321-aedd-3099a526e0dd",
41 | name="Sequence classification",
42 | version="0.1.0",
43 | task=TextClassificationTask,
44 | )
45 | class SequenceClassification(ModuleBase):
46 |
47 | ################################ Constructor #################################################
48 |
49 | def __init__(
50 | self,
51 | resource: HFAutoSequenceClassifier,
52 | ):
53 | super().__init__()
54 | error.type_check(
55 | "",
56 | HFAutoSequenceClassifier,
57 | resource=resource,
58 | )
59 | self.resource = resource
60 | self.tokenizer = resource.tokenizer
61 | self.model = resource.model
62 |
63 | ################################## API functions #############################################
64 |
65 | def run(self, text: str) -> ClassificationResults:
66 | """Run the sequence classification.
67 | NOTE: This will truncate sequences that are too long for the model
68 |
69 | Args:
70 | text: str
71 | Input string to be classified
72 |
73 | Returns:
74 | ClassificationResults
75 | """
76 | scores_dict = self._get_scores(text)
77 | # Re-organize scores_dict - for one text, this is just the first score
78 | return SequenceClassification._process_predictions(scores_dict, text_idx=0)
79 |
80 | def run_batch(self, texts: List[str]) -> List[ClassificationResults]:
81 | """Run the sequence classification on batch, truncates sequences too long for model
82 |
83 | Args:
84 | texts: List[str]
85 | Input strings to be classified
86 |
87 | Returns:
88 | List[ClassificationResults]
89 | """
90 | scores_dict = self._get_scores(texts)
91 | num_texts = len(texts)
92 |
93 | # Re-organize scores_dict for each example
94 | # We could eventually consider sorting classifications by score,
95 | # but we avoid prescribing that here for now
96 | classification_predictions = []
97 | for text_idx in range(num_texts):
98 | classification_prediction = SequenceClassification._process_predictions(
99 | scores_dict, text_idx
100 | )
101 | classification_predictions.append(classification_prediction)
102 | return classification_predictions
103 |
104 | def save(self, model_path: str):
105 | """Save model in target path
106 |
107 | Args:
108 | model_path: str
109 | Path to store model artifact(s)
110 | """
111 | module_saver = ModuleSaver(
112 | self,
113 | model_path=model_path,
114 | )
115 |
116 | with module_saver:
117 | module_saver.save_module(self.resource, "sequence_classifier")
118 |
119 | @classmethod
120 | def load(cls, model_path: str) -> "SequenceClassification":
121 | """Load a sequence classification model
122 |
123 | Args:
124 | model_path: str
125 | Path to the model to be loaded.
126 |
127 | Returns:
128 | SequenceClassification
129 | Instance of this class built from the on disk model.
130 | """
131 | loader = ModuleLoader(model_path)
132 | resource = loader.load_module("sequence_classifier")
133 | return cls(resource=resource)
134 |
135 | @classmethod
136 | def bootstrap(cls, base_model_path: str) -> "SequenceClassification":
137 | """Bootstrap a HuggingFace transformer-based sequence classification model
138 |
139 | Args:
140 | base_model_path: str
141 | Path to the model to be loaded.
142 | """
143 | # Note: Must provide path to tokenizer if model_name is a path
144 | # for resource use
145 | resource = HFAutoSequenceClassifier.bootstrap(
146 | model_name=base_model_path, tokenizer_name=base_model_path
147 | )
148 | return cls(
149 | resource=resource,
150 | )
151 |
152 | ################################## Private Functions #########################################
153 |
154 | def _get_scores(self, text: Union[str, List[str]]):
155 | """Run tokenizer and model to get scores on text(s)
156 |
157 | Args:
158 | text: Union[str, List[str]]
159 | Input string(s) to be used
160 |
161 | Returns:
162 | scores_dict
163 | Dict with key label, and values as the array of scores,
164 | each corresponding to text(s)
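e.g. (illustrative, for two texts and a two-label model):
{"LABEL_0": [0.91, 0.08], "LABEL_1": [0.09, 0.92]}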
165 | """
166 | # NOTE: no explicit GPU support at this time
167 |
168 | # Apply tokenizer
169 | tokenized_text = self.tokenizer(
170 | text,
171 | padding="max_length",
172 | truncation=True,
173 | return_tensors="pt", # PyTorch
174 | )
175 | # NOTE: no truncation warning given at this time
176 | # difficult to detect if truncation happened since padding also occurs
177 | with torch.no_grad():
178 | logits = self.model(**tokenized_text).logits
179 |
180 | if not self.model.config.id2label:
181 | log.warning(
182 | "",
183 | "No id2label provided in model config. Defaulting to numeric labels",
184 | )
185 |
186 | softmax = torch.nn.Softmax(dim=1)
187 | raw_scores = softmax(logits)
188 | scores = raw_scores.double().numpy()
189 | num_labels = self.model.num_labels
190 | num_texts = 1 # str
191 | if isinstance(text, List):
192 | num_texts = len(text)
193 | error.value_check(
194 | "",
195 | scores.shape == (num_texts, num_labels),
196 | "model logits expected to be of shape (num_texts, num_labels)",
197 | )
198 |
199 | scores_dict = {}
200 | for label_idx in range(num_labels):
201 | if self.model.config.id2label:
202 | label = self.model.config.id2label[label_idx]
203 | else:
204 | label = label_idx
205 | label_scores = scores[:, label_idx]
206 | scores_dict[label] = label_scores
207 | return scores_dict
208 |
209 | @staticmethod
210 | def _process_predictions(scores_dict: Dict, text_idx: int) -> ClassificationResults:
211 | """Process dictionary of label: scores to ClassificationResults
212 |
213 | Args:
214 | scores_dict: Dict
215 | Dict with key label, and values as the array of scores,
216 | each corresponding to text(s)
217 | text_idx: int
218 | Integer index of text in batch
219 |
220 | Returns:
221 | ClassificationResults
222 | """
223 | error.type_check("", Dict, scores_dict=scores_dict)
224 | classification_list = []
225 | for label, score_array in scores_dict.items():
226 | # NOTE: labels are expected to be str, especially for config
227 | classification_list.append(
228 | ClassificationResult(label=str(label), score=score_array[text_idx])
229 | )
230 | return ClassificationResults(results=classification_list)
231 |
--------------------------------------------------------------------------------
/tests/modules/text_generation/test_text_generation_local.py:
--------------------------------------------------------------------------------
1 | """Tests for text-generation module
2 | """
3 |
4 | # Standard
5 | import os
6 | import platform
7 | import tempfile
8 |
9 | # Third Party
10 | import pytest
11 | import torch
12 |
13 | # First Party
14 | from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults
15 | import caikit
16 |
17 | # Local
18 | from caikit_nlp.data_model import GenerationTrainRecord
19 | from caikit_nlp.modules.text_generation import TextGeneration
20 | from caikit_nlp.resources.pretrained_model import HFAutoCausalLM, HFAutoSeq2SeqLM
21 | from tests.fixtures import (
22 | CAUSAL_LM_MODEL,
23 | SEQ2SEQ_LM_MODEL,
24 | disable_wip,
25 | set_cpu_device,
26 | )
27 |
28 | ### Stub Modules
29 |
30 |
31 | def test_bootstrap_and_run_causallm():
32 | """Check if we can bootstrap and run causallm models"""
33 |
34 | model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)
35 |
36 | sample_text = "Hello stub"
37 | generated_text = model.run(sample_text)
38 | assert isinstance(generated_text, GeneratedTextResult)
39 |
40 |
41 | def test_bootstrap_and_run_seq2seq():
42 | """Check if we can bootstrap and run seq2seq models"""
43 |
44 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
45 |
46 | sample_text = "Hello stub"
47 | generated_text = model.run(sample_text)
48 | assert isinstance(generated_text, GeneratedTextResult)
49 |
50 |
51 | def test_bootstrap_and_save_model():
52 | """Check if we can bootstrap and save the model successfully"""
53 |
54 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
55 |
56 | with tempfile.TemporaryDirectory() as model_dir:
57 | model.save(model_dir)
58 | assert os.path.isfile(os.path.join(model_dir, "config.yml"))
59 |
60 |
61 | def test_save_model_can_run():
62 | """Check if the model we bootstrap and save is able to load and run successfully"""
63 | model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL)
64 |
65 | with tempfile.TemporaryDirectory() as model_dir:
66 | model.save(model_dir)
67 | del model
68 | new_model = TextGeneration.load(model_dir)
69 | sample_text = "Hello stub"
70 | generated_text = new_model.run(sample_text)
71 | assert isinstance(generated_text, GeneratedTextResult)
72 |
73 |
74 | ############################## Training ################################
75 |
76 |
77 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
78 | def test_train_model_seq2seq(disable_wip, set_cpu_device):
79 | """Ensure that we can finetune a seq2seq model on some toy data for 1+
80 | steps & run inference."""
81 | train_kwargs = {
82 | "base_model": HFAutoSeq2SeqLM.bootstrap(
83 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
84 | ),
85 | "num_epochs": 1,
86 | "train_stream": caikit.core.data_model.DataStream.from_iterable(
87 | [
88 | GenerationTrainRecord(
89 | input="@foo what a cute dog!", output="no complaint"
90 | ),
91 | GenerationTrainRecord(
92 | input="@bar this is the worst idea ever.", output="complaint"
93 | ),
94 | ]
95 | ),
96 | "torch_dtype": torch.float32,
97 | }
98 | model = TextGeneration.train(**train_kwargs)
99 | assert isinstance(model.model, HFAutoSeq2SeqLM)
100 |
101 | # Ensure that we can get something out of it
102 | pred = model.run("@bar what a cute cat!")
103 | assert isinstance(pred, GeneratedTextResult)
104 |
105 |
106 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
107 | def test_train_model_save_and_load(disable_wip, set_cpu_device):
108 | """Ensure that we are able to save and load a finetuned model and execute inference on it"""
109 | train_kwargs = {
110 | "base_model": HFAutoSeq2SeqLM.bootstrap(
111 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
112 | ),
113 | "num_epochs": 1,
114 | "train_stream": caikit.core.data_model.DataStream.from_iterable(
115 | [
116 | GenerationTrainRecord(
117 | input="@foo what a cute dog!", output="no complaint"
118 | )
119 | ]
120 | ),
121 | "torch_dtype": torch.float32,
122 | }
123 | model = TextGeneration.train(**train_kwargs)
124 | assert isinstance(model.model, HFAutoSeq2SeqLM)
125 | with tempfile.TemporaryDirectory() as model_dir:
126 | model.save(model_dir)
127 | new_model = TextGeneration.load(model_dir)
128 | sample_text = "Hello stub"
129 | generated_text = new_model.run(sample_text)
130 | assert isinstance(generated_text, GeneratedTextResult)
131 |
132 |
133 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
134 | def test_train_model_causallm(disable_wip, set_cpu_device):
135 | """Ensure that we can finetune a causal-lm model on some toy data for 1+
136 | steps & run inference."""
137 | train_kwargs = {
138 | "base_model": HFAutoCausalLM.bootstrap(
139 | model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
140 | ),
141 | "num_epochs": 1,
142 | "train_stream": caikit.core.data_model.DataStream.from_iterable(
143 | [
144 | GenerationTrainRecord(
145 | input="@foo what a cute dog!", output="no complaint"
146 | ),
147 | ]
148 | ),
149 | "torch_dtype": torch.float32,
150 | }
151 | model = TextGeneration.train(**train_kwargs)
152 | assert isinstance(model.model, HFAutoCausalLM)
153 |
154 | # Ensure that we can get something out of it
155 | pred = model.run("@bar what a cute cat!")
156 | assert isinstance(pred, GeneratedTextResult)
157 |
158 |
159 | ############################## Inferencing flags ################################
160 |
161 |
162 | @pytest.mark.skipif(platform.processor() == "arm", reason="ARM training not supported")
163 | def test_train_and_infer_model_causallm(disable_wip, set_cpu_device):
164 | """Ensure that we can finetune a causal-lm model on some toy data for 1+
165 | steps & run inference."""
166 | train_kwargs = {
167 | "base_model": HFAutoCausalLM.bootstrap(
168 | model_name=CAUSAL_LM_MODEL, tokenizer_name=CAUSAL_LM_MODEL
169 | ),
170 | "num_epochs": 1,
171 | "train_stream": caikit.core.data_model.DataStream.from_iterable(
172 | [
173 | GenerationTrainRecord(
174 | input="@foo what a cute dog!", output="no complaint"
175 | ),
176 | ]
177 | ),
178 | "torch_dtype": torch.float32,
179 | }
180 | model = TextGeneration.train(**train_kwargs)
181 | assert isinstance(model.model, HFAutoCausalLM)
182 |
183 | # Ensure that preserve_input_text returns input in output
184 | pred = model.run("@bar what a cute cat!", preserve_input_text=True)
185 | assert "@bar what a cute cat!" in pred.generated_text
186 |
187 | # Ensure that preserve_input_text set to False, removes input from output
188 | pred = model.run("@bar what a cute cat!", preserve_input_text=False)
189 | assert "@bar what a cute cat!" not in pred.generated_text
190 |
191 |
192 | ############################## Error Cases ################################
193 |
194 |
195 | def test_zero_epoch_case(disable_wip):
196 | """Test to ensure 0 epoch training request doesn't explode"""
197 | train_kwargs = {
198 | "base_model": HFAutoSeq2SeqLM.bootstrap(
199 | model_name=SEQ2SEQ_LM_MODEL, tokenizer_name=SEQ2SEQ_LM_MODEL
200 | ),
201 | "num_epochs": 0,
202 | "train_stream": caikit.core.data_model.DataStream.from_iterable(
203 | [
204 | GenerationTrainRecord(
205 | input="@foo what a cute dog!", output="no complaint"
206 | ),
207 | ]
208 | ),
209 | "torch_dtype": torch.float32,
210 | }
211 | model = TextGeneration.train(**train_kwargs)
212 | assert isinstance(model.model, HFAutoSeq2SeqLM)
213 |
214 |
215 | ############################## Run Tokenizer ################################
216 |
217 |
218 | def test_run_tokenizer_edge_cases(disable_wip, set_cpu_device):
219 | """Test tokenizer on edge cases like empty strings and long input."""
220 | model = TextGeneration.bootstrap(CAUSAL_LM_MODEL)
221 |
222 | # Edge case: Empty string
223 | empty_result = model.run_tokenizer("")
224 | assert isinstance(empty_result, TokenizationResults)
225 | assert empty_result.token_count == 0
226 |
227 | # Normal case: short sentence
228 | short_text = "This is a test sentence."
229 | short_result = model.run_tokenizer(short_text)
230 | assert isinstance(short_result, TokenizationResults)
231 | assert short_result.token_count == len(model.model.tokenizer.encode(short_text))
232 |
233 | # Edge case: Long input
234 | long_text = "This is a test sentence. " * 1000
235 | long_result = model.run_tokenizer(long_text)
236 | assert isinstance(long_result, TokenizationResults)
237 | assert long_result.token_count == len(model.model.tokenizer.encode(long_text))
238 |
--------------------------------------------------------------------------------
/caikit_nlp/modules/text_generation/peft_config.py:
--------------------------------------------------------------------------------
1 | # Copyright The Caikit Authors
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # Standard
16 | from enum import Enum
17 | import os
18 | import re
19 |
20 | # Third Party
21 | from peft import MultitaskPromptTuningInit
22 | from transformers import AutoConfig
23 |
24 | # First Party
25 | from caikit import get_config
26 | from caikit.core import error_handler
27 | import alog
28 |
29 | # Local
30 | from ...data_model import PromptOutputModelType
31 | from ...resources.pretrained_model import PretrainedModelBase
32 | from ...toolkit.data_type_utils import get_torch_dtype
33 | from ...toolkit.verbalizer_utils import is_valid_verbalizer
34 |
35 | # NOTE: We do not allow all the methods exposed by MPT / PT, such as `EXACT_SOURCE_TASK`
36 | # since those are for experimental use and would not be useful / applicable
37 | # for end-user use-cases
38 | allowed_tuning_init_methods = [
39 | "TEXT",
40 | "RANDOM",
41 | "ONLY_SOURCE_SHARED",
42 | "AVERAGE_SOURCE_TASKS",
43 | ]
44 |
45 | log = alog.use_channel("PFT_CNFG_TLKT")
46 | error = error_handler.get(log)
47 |
48 | SOURCE_DIR_VALIDATION_REGEX = re.compile(r"^[-a-zA-Z_0-9\/\.]+")
49 | # 🤮 FIXME: This two-dot regex is added to avoid expressions like .. giving
50 | # access to unintended directories. But this is an ugly hack and we need to
51 | # figure out a better solution or a better regex
52 | TWO_DOTS_REGEX = re.compile(r"(\.\.)+")
53 |
54 |
55 | class TuningType(str, Enum):
56 | PROMPT_TUNING = "PROMPT_TUNING"
57 | MULTITASK_PROMPT_TUNING = "MULTITASK_PROMPT_TUNING"
58 | # MULTITASK_PREFIX_TUNING = "MULTITASK_PREFIX_TUNING"
59 | # P_TUNING = "P_TUNING"
60 | # PREFIX_TUNING = "PREFIX_TUNING"
61 | # LORA = "LORA"
62 |
63 |
64 | def resolve_base_model(base_model, cls, torch_dtype):
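"""Resolve a base model passed either as a string name/path or as an already
bootstrapped resource. String inputs are validated, optionally resolved against
the configured base_models_dir, and bootstrapped into the first supported
resource type whose SUPPORTED_MODEL_TYPES matches the model config.
"""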
65 | if isinstance(base_model, str):
66 |
67 | error.value_check(
68 | "",
69 | re.fullmatch(SOURCE_DIR_VALIDATION_REGEX, base_model)
70 | and not re.search(TWO_DOTS_REGEX, base_model),
71 | "invalid characters in base_model name",
72 | )
73 | if get_config().base_models_dir:
74 |
75 | base_model_full_path = os.path.join(
76 | get_config().base_models_dir, base_model
77 | )
78 | if os.path.exists(base_model_full_path):
79 | base_model = base_model_full_path
80 |
81 | model_config = AutoConfig.from_pretrained(
82 | base_model, local_files_only=not get_config().allow_downloads
83 | )
84 |
85 | resource_type = None
86 | for resource in cls.supported_resources:
87 | if model_config.model_type in resource.SUPPORTED_MODEL_TYPES:
88 | resource_type = resource
89 | break
90 |
91 | if not resource_type:
92 | error(
93 | "",
94 | "{} model type is not supported currently!".format(
95 | model_config.model_type
96 | ),
97 | )
98 | log.debug("Bootstrapping base resource [%s]", base_model)
99 | base_model = resource_type.bootstrap(base_model, torch_dtype=torch_dtype)
100 | return base_model
101 |
102 |
103 | def get_peft_config(
104 | tuning_type, tuning_config, base_model, cls, torch_dtype, verbalizer
105 | ):
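"""Validate the tuning type and tuning config against the base model resource
and build the Huggingface PEFT config via cls.create_hf_tuning_config.

Returns:
Tuple(task_type, output_model_types, peft_config, tuning_type)
"""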
106 |
107 | if tuning_type not in TuningType._member_names_:
108 | raise NotImplementedError("{} tuning type not supported!".format(tuning_type))
109 |
110 | if tuning_config.prompt_tuning_init_method:
111 | # NOTE: GK-APR-5-2023
112 | # MultitaskPromptTuningInit and MultitaskPrefixTuningInit are same at the
113 | # time of writing, which is a superset of PromptTuningInit
114 | init_method = tuning_config.prompt_tuning_init_method
115 |
116 | error.value_check(
117 | "",
118 | init_method in allowed_tuning_init_methods,
119 | f"Init method [{init_method}] not in allowed init methods: "
120 | f"[{allowed_tuning_init_methods}]",
121 | )
122 |
123 | init_method = MultitaskPromptTuningInit(init_method)
124 | log.info("Using initialization method [%s]", init_method)
125 |
126 | # If init method provided relates to one that requires source model,
127 | # make sure the source prompt model is provided.
128 | if init_method in [
129 | MultitaskPromptTuningInit.AVERAGE_SOURCE_TASKS,
130 | MultitaskPromptTuningInit.ONLY_SOURCE_SHARED,
131 | ]:
132 | # NOTE: prompt_tuning_init_source_model is currently a path. In the future
133 | # we will replace this with caikit.resources to properly catalog these
134 | error.type_check(
135 | "",
136 | str,
137 | prompt_tuning_init_source_model=tuning_config.prompt_tuning_init_source_model,
138 | )
139 | tuning_config.prompt_tuning_init_source_model = os.path.join(
140 | get_config().source_prompt_base,
141 | tuning_config.prompt_tuning_init_source_model,
142 | )
143 |
144 | error.file_check(
145 | "", tuning_config.prompt_tuning_init_source_model
146 | )
147 | log.debug(
148 | "Validated tuning source prompt [%s]",
149 | tuning_config.prompt_tuning_init_source_model,
150 | )
151 |
152 | error.type_check("", PretrainedModelBase, base_model=base_model)
153 |
154 | # Validate if tuned output model type is compatible with base model or not
155 | if not tuning_config.output_model_types:
156 | output_model_types = base_model.PROMPT_OUTPUT_TYPES
157 | else:
158 | # If the first element is not a PromptOutputModelType, assume the entire
159 | # list needs conversion and convert it
160 | if not isinstance(tuning_config.output_model_types[0], PromptOutputModelType):
161 | output_model_types = []
162 | for output_type in tuning_config.output_model_types:
163 | output_model_types.append(PromptOutputModelType(output_type))
164 | else:
165 | output_model_types = tuning_config.output_model_types
166 | error.value_check(
167 | "",
168 | all(
169 | output_type in base_model.PROMPT_OUTPUT_TYPES
170 | for output_type in output_model_types
171 | ),
172 | "{} not supported for base model type {}".format(
173 | output_model_types, base_model.MODEL_TYPE
174 | ),
175 | )
176 |
177 | error.value_check(
178 | "",
179 | len(output_model_types) <= base_model.MAX_NUM_TRANSFORMERS,
180 | f"Too many output model types. Got {len(output_model_types)}, "
181 | f"maximum {base_model.MAX_NUM_TRANSFORMERS}",
182 | )
183 | # Ensure that our verbalizer is a string and will not render to a hardcoded string
184 | error.value_check(
185 | "",
186 | is_valid_verbalizer(verbalizer),
187 | "Provided verbalizer is an invalid type or has no renderable placeholders",
188 | )
189 |
190 | # NOTE: Base model is a resource at this point
191 | task_type = base_model.TASK_TYPE
192 |
193 | if isinstance(tuning_type, str):
194 | error.value_check(
195 | "",
196 | tuning_type in TuningType._member_names_,
197 | f"Invalid tuning type [{tuning_type}]. Allowed types: "
198 | f"[{TuningType._member_names_}]",
199 | )
200 | tuning_type = TuningType(tuning_type)
201 | error.type_check("", TuningType, tuning_type=tuning_type)
202 |
203 | # Coerce the passed model into a resource; if we have one, this is a noop
204 | # TODO: When splitting up this mono-module, use the configured resource
205 | # type of the concrete class to bootstrap
206 | torch_dtype = get_torch_dtype(torch_dtype)
207 |
208 | # Take tokenizer name/path from the model
209 | tokenizer_name_or_path = base_model.model.config._name_or_path
210 |
211 | # Build the peft config; this is how we determine that we want a sequence classifier.
212 | # If we want more types, we will likely need to map this to data model outputs etc.
213 |
214 | # NOTE: We currently only support TEXT as the init type; this is so we can
215 | # later easily switch to MPT
216 | peft_config = cls.create_hf_tuning_config(
217 | base_model=base_model,
218 | tuning_type=tuning_type,
219 | task_type=task_type,
220 | tokenizer_name_or_path=tokenizer_name_or_path,
221 | tuning_config=tuning_config,
222 | output_model_types=output_model_types,
223 | )
224 |
225 | return task_type, output_model_types, peft_config, tuning_type
226 |
--------------------------------------------------------------------------------