├── .dockerignore ├── src └── aihero │ ├── research │ ├── finetuning │ │ ├── __init__.py │ │ ├── callback.py │ │ ├── utils.py │ │ ├── infer.py │ │ └── train.py │ └── __init__.py │ └── __init__.py ├── Dockerfile ├── launch.py ├── LICENSE ├── .github └── workflows │ └── docker-image.yml ├── README.md ├── pyproject.toml ├── .pre-commit-config.yaml └── .gitignore /.dockerignore: -------------------------------------------------------------------------------- 1 | **/venv 2 | .env 3 | -------------------------------------------------------------------------------- /src/aihero/research/finetuning/__init__.py: -------------------------------------------------------------------------------- 1 | """Module `aihero.research.finetuning`.""" 2 | -------------------------------------------------------------------------------- /src/aihero/__init__.py: -------------------------------------------------------------------------------- 1 | """Module `aihero`. NOTE: Parts of this package also exist in other repos.""" 2 | -------------------------------------------------------------------------------- /src/aihero/research/__init__.py: -------------------------------------------------------------------------------- 1 | """Module `aihero.research`. NOTE: Parts of this package also exist in other repos.""" 2 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Use an official Python runtime as a parent image 2 | FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-devel 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt-get update \ 7 | && apt-get -y upgrade --only-upgrade systemd openssl cryptsetup \ 8 | && apt-get install -y \ 9 | bzip2 \ 10 | curl \ 11 | git \ 12 | git-lfs \ 13 | tar \ 14 | vim \ 15 | && apt-get clean autoremove --yes \ 16 | && rm -rf /var/lib/{apt,dpkg,cache,log} 17 | 18 | WORKDIR /home/user 19 | # Upgrade pip and the build tooling; the project and its dependencies are installed from pyproject.toml below 20 | RUN pip install --upgrade pip build 21 | 22 | COPY pyproject.toml /home/user/pyproject.toml 23 | COPY src/aihero /home/user/src/aihero 24 | RUN pip install . 
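# NOTE: launch.py is copied after the dependency install above, so editing launch.py does not invalidate the cached dependency layers.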
25 | 26 | # Run launch.py when the container launches 27 | COPY launch.py /home/user/ 28 | CMD ["python", "launch.py"] 29 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | """run script for fine-tuning a model.""" 2 | import os 3 | 4 | from aihero.research.config.schema import BatchInferenceJob, TrainingJob 5 | from aihero.research.finetuning.infer import BatchInferenceJobRunner 6 | from aihero.research.finetuning.train import TrainingJobRunner 7 | from fire import Fire 8 | 9 | 10 | def train(training_config_file: str = "/mnt/config/training/config.yaml") -> None: 11 | """Run Training.""" 12 | training_config = TrainingJob.load(training_config_file) 13 | TrainingJobRunner(training_config, is_distributed=int(os.getenv("WORLD_SIZE", 1)) > 1).run() 14 | 15 | 16 | def infer(batch_inference_config_file: str = "/mnt/config/batch_inference/config.yaml") -> None: 17 | """Run Batch Inference.""" 18 | batch_inference_config = BatchInferenceJob.load(batch_inference_config_file) 19 | BatchInferenceJobRunner(batch_inference_config).run() 20 | 21 | 22 | if __name__ == "__main__": 23 | Fire({"train": train, "infer": infer}) 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 A.I. Hero, Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/docker-image.yml: -------------------------------------------------------------------------------- 1 | name: Docker Image Creation 2 | on: [push] 3 | jobs: 4 | build: 5 | runs-on: large-runner 6 | timeout-minutes: 5 7 | steps: 8 | - name: Checkout 9 | uses: actions/checkout@v3 10 | - name: Set up QEMU 11 | uses: docker/setup-qemu-action@v3 12 | - name: Set up Docker Buildx 13 | uses: docker/setup-buildx-action@v3 14 | - name: Login to Docker Hub 15 | uses: docker/login-action@v3 16 | with: 17 | username: ${{ secrets.DOCKERHUB_USERNAME }} 18 | password: ${{ secrets.DOCKERHUB_TOKEN }} 19 | - name: Set short git commit SHA 20 | id: short_sha 21 | run: | 22 | echo "SHORT_SHA=$(git rev-parse --short ${{ github.sha }})" >> $GITHUB_ENV 23 | - name: Build and push 24 | uses: docker/build-push-action@v5 25 | with: 26 | context: . 
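# Push the image to Docker Hub tagged with the short commit SHA computed above, reusing a shared registry build cache.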
27 | push: true 28 | tags: aihero/${{ github.event.repository.name }}:${{ env.SHORT_SHA }} 29 | cache-from: type=registry,ref=aihero/${{ github.event.repository.name }}:buildcache 30 | cache-to: type=registry,ref=aihero/${{ github.event.repository.name }}:buildcache,mode=max 31 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fine-Tuning for LLM Research by AI Hero 2 | 3 | This repo contains the code that runs inside the fine-tuning container; the same code can also be run natively. The container image is built and pushed to Docker Hub using GitHub Actions (see below). You can launch fine-tuning jobs using the examples in the `https://github.com/ai-hero/llm-research-examples` project, with data created using the `https://github.com/ai-hero/llm-research-data` project. 4 | 5 | ## Container 6 | The latest container we use for training is `rparundekar/fine_tune_research:{SHORT_SHA_ON_MAIN}`. You can launch jobs using this tag with the `llm-research-examples` project. 7 | 8 | ## For Contributors 9 | To install this library locally: 10 | 11 | ```sh 12 | pip install . 13 | ``` 14 | 15 | ### Building the Docker Image with GitHub Actions 16 | Update the GitHub Actions workflow in the `.github` folder and set the required secrets in GitHub to automatically build the container and push it to the right repository. 17 | 18 | ### Building a Docker Image Manually 19 | Use a tag versioned by date / user as needed. For example: 20 | ```sh 21 | docker build . -t rparundekar/fine_tune_research:20230110_01 22 | docker push rparundekar/fine_tune_research:20230110_01 23 | ``` 24 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "aihero-research-finetuning" 3 | version = "0.3.3" 4 | description = "Framework for open source research on fine-tuning LLMs" 5 | authors = [ 6 | {name = "Rahul Parundekar", email= "rahul@aihero.studio" }, 7 | {name = "Shankar Ganesan", email = "gshankar.87@gmail.com" } 8 | ] 9 | readme = "README.md" 10 | classifiers = [ 11 | "Programming Language :: Python", 12 | "Programming Language :: Python :: 3.9", 13 | "Programming Language :: Python :: 3.10", 14 | "Programming Language :: Python :: 3.11", 15 | "License :: OSI Approved :: MIT License", 16 | ] 17 | dependencies = [ 18 | "accelerate==0.25.0", 19 | "aihero-research-config @ git+https://github.com/ai-hero/llm-research-config.git@main#egg=aihero-research-config", 20 | "bitsandbytes==0.41.3.post2", 21 | "datasets==2.14.6", 22 | "einops==0.7.0", 23 | "fire==0.5.0", 24 | "minio==7.2.0", 25 | "numpy==1.25.2", 26 | "peft==0.7.1", 27 | "pydantic-settings==2.0.3", 28 | "python-dotenv==1.0.1", 29 | "PyYAML==6.0.1", 30 | "scikit-learn==1.4.0", 31 | "scipy==1.11.3", 32 | "transformers==4.36.1", 33 | "trl==0.7.7", 34 | "types-PyYAML==6.0.12.12", 35 | "wandb==0.15.12", 36 | "jsonschema", 37 | ] 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "pytest>=6.2.5", 42 | "black>=22.3", 43 | "mypy>=0.910", 44 | "ruff>=0.0.79", 45 | "blacken-docs>=1.11.0", 46 | "pyupgrade>=2.29.1", 47 | "detect-secrets>=1.2.0", 48 | "tomli>=1.2.3", 49 | "pre-commit>=2.17.0", 50 | ] 51 | 52 | [build-system] 53 | requires = ["setuptools>=61.0"] 54 | build-backend = "setuptools.build_meta" 55 | 56 | 57 | 58 | [tool.pytest.ini_options] 59 | addopts = "-vvv" 60 | testpaths = "src/tests" 61 | 62 | 
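# NOTE: the TrainingJob and BatchInferenceJob schemas imported by launch.py come from the pinned aihero-research-config git dependency above.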
[tool.black] 63 | line_length = 120 64 | target_version = ['py39'] 65 | 66 | [tool.ruff] 67 | exclude = [ 68 | ".venv", 69 | ".git", 70 | "__pycache__", 71 | "build", 72 | "dist", 73 | "venv", 74 | ] 75 | ignore = [] 76 | line-length = 120 77 | select = [ 78 | "D", 79 | "E", 80 | "F", 81 | "I", 82 | "W", 83 | ] 84 | src = ["src/aihero", "src/tests"] 85 | 86 | [mypy] 87 | files = ["src/aihero"] 88 | strict_optional = false 89 | warn_unused_ignores = false 90 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3.9 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.5.0 6 | hooks: 7 | - id: check-ast 8 | - id: check-byte-order-marker 9 | - id: check-case-conflict 10 | - id: check-docstring-first 11 | - id: check-executables-have-shebangs 12 | - id: check-json 13 | - id: check-yaml 14 | exclude: .*_template\.yaml 15 | - id: debug-statements 16 | - id: detect-private-key 17 | - id: end-of-file-fixer 18 | - id: trailing-whitespace 19 | - id: mixed-line-ending 20 | - id: requirements-txt-fixer 21 | 22 | - repo: https://github.com/psf/black 23 | rev: 23.12.1 24 | hooks: 25 | - id: black 26 | language_version: python3.9 27 | 28 | - repo: https://github.com/psf/black 29 | rev: 23.12.1 30 | hooks: 31 | - id: black-jupyter 32 | language_version: python3.9 33 | 34 | - repo: https://github.com/pre-commit/mirrors-mypy 35 | rev: v1.8.0 36 | hooks: 37 | - id: mypy 38 | args: [--strict, --ignore-missing-imports] 39 | additional_dependencies: 40 | - "pydantic>=1.10.4" 41 | 42 | - repo: https://github.com/astral-sh/ruff-pre-commit 43 | # Ruff version. 44 | rev: v0.1.9 45 | hooks: 46 | - id: ruff 47 | args: [--fix] 48 | types_or: [python, pyi, jupyter] 49 | exclude: poc/explore.ipynb 50 | 51 | # - repo: https://github.com/econchick/interrogate 52 | # rev: 1.5.0 53 | # hooks: 54 | # - id: interrogate 55 | # args: [--fail-under=80, --verbose] 56 | 57 | - repo: https://github.com/asottile/pyupgrade 58 | rev: v3.15.0 59 | hooks: 60 | - id: pyupgrade 61 | args: [--py36-plus] 62 | 63 | - repo: https://github.com/asottile/blacken-docs 64 | rev: 1.16.0 65 | hooks: 66 | - id: blacken-docs 67 | additional_dependencies: 68 | - black==22.12.0 69 | 70 | - repo: https://github.com/Yelp/detect-secrets 71 | rev: v1.4.0 72 | hooks: 73 | - id: detect-secrets 74 | name: "detect-secrets" 75 | args: ["--exclude-files", '.*\.ipynb$'] 76 | - id: detect-secrets 77 | name: "detect-secrets-jupyter" 78 | args: 79 | [ 80 | "--exclude-files", 81 | ".*[^i][^p][^y][^n][^b]$", 82 | "--exclude-lines", 83 | '"(hash|id|authorship_tag|image/\w+)":.*', 84 | ] 85 | -------------------------------------------------------------------------------- /src/aihero/research/finetuning/callback.py: -------------------------------------------------------------------------------- 1 | """Custom callback for sampling from a LLM and reporting custom eval to WANDB.""" 2 | import random 3 | from typing import Any 4 | 5 | from datasets import Dataset 6 | from transformers.integrations import WandbCallback 7 | from trl import SFTTrainer 8 | 9 | from aihero.research.finetuning.infer import BatchInferenceWithEval 10 | 11 | MAX_NEW_TOKENS = 512 12 | 13 | 14 | class LLMSampleCB(WandbCallback): # type: ignore 15 | """Callback for sampling from a LLM and reporting custom eval to WANDB.""" 16 | 17 | def __init__( 18 | self: "LLMSampleCB", 19 | trainer: 
SFTTrainer, 20 | task: str, 21 | test_split: Dataset, 22 | num_samples: int = 100, 23 | max_new_tokens: int = MAX_NEW_TOKENS, 24 | log_model: str = "checkpoint", 25 | run_tests_str: str = "", 26 | run_metrics_str: str = "", 27 | ): 28 | """Initialize the callback by extracting a few rows from the test split.""" 29 | super().__init__() 30 | assert task == "completion", "Only completion task supported for now" 31 | self._log_model = log_model 32 | self.batch_inference = BatchInferenceWithEval( 33 | model=trainer.model, 34 | tokenizer=trainer.tokenizer, 35 | task=task, 36 | run_tests_str=run_tests_str, 37 | run_metrics_str=run_metrics_str, 38 | max_new_tokens=max_new_tokens, 39 | ) 40 | 41 | # Sample a few rows from the test split to generate a table of predictions 42 | # for visual inspection a.k.a. spot checking 43 | # Randomly select indices for the samples 44 | if num_samples >= test_split.num_rows: 45 | selected_indices = list(range(0, test_split.num_rows)) 46 | else: 47 | selected_indices = random.sample(range(test_split.num_rows), num_samples) 48 | # Retrieve the selected samples from the dataset 49 | test_split_list = list(test_split) 50 | self.sample_split = [] 51 | for i in selected_indices: 52 | self.sample_split.append(test_split_list[i]) 53 | self.sample_split = Dataset.from_list(self.sample_split) 54 | 55 | def initialize(self: "LLMSampleCB") -> None: 56 | """Generate initial predictions for the sample split and log them to WANDB.""" 57 | self._wandb.init() 58 | 59 | _, (records_table, metrics) = self.batch_inference.run_initial_predictions(self.sample_split) 60 | 61 | # Log the table of sample predictions to W&B 62 | self._wandb.log({"sample_predictions": records_table}) 63 | 64 | # Log the calculated metrics to W&B 65 | self._wandb.log(metrics) 66 | print("LLMSampleCB initialized") 67 | 68 | def on_evaluate(self, args: Any, state: Any, control: Any, **kwargs: dict[str, Any]) -> None: 69 | """Log the sample predictions and metrics to WANDB on eval callback.""" 70 | super().on_evaluate(args, state, control, **kwargs) 71 | 72 | # Generate the table of sample predictions 73 | _, (records_table, metrics) = self.batch_inference.infer(self.sample_split) 74 | 75 | # Log the table of sample predictions to W&B 76 | self._wandb.log({"sample_predictions": records_table}) 77 | 78 | # Log the calculated metrics to W&B 79 | self._wandb.log(metrics) 80 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ruff_cache 2 | .DS_Store 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 163 | #.idea/ 164 | 165 | .DS_Store 166 | -------------------------------------------------------------------------------- /src/aihero/research/finetuning/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for the app. e.g. 
upload and download files from S3.""" 2 | import os 3 | import tarfile 4 | from typing import Any, Generator 5 | 6 | import torch 7 | from datasets import load_dataset, load_from_disk 8 | from minio import Minio, S3Error 9 | from peft.tuners.lora import LoraLayer 10 | from transformers import AutoModelForCausalLM 11 | 12 | 13 | class DatasetMover: 14 | """Utility class for uploading and downloading files from S3.""" 15 | 16 | def _compress_folder(self, folder_path: str, output_filename: str) -> None: 17 | """Compress a folder into a tar.gz file.""" 18 | with tarfile.open(output_filename, "w:gz") as tar: 19 | tar.add(folder_path, arcname=os.path.basename(folder_path)) 20 | 21 | def _upload_to_s3(self, file_name: str, bucket_name: str, object_name: str) -> None: 22 | """Upload a file to S3.""" 23 | try: 24 | # Initialize MinIO client 25 | minio_client = Minio( 26 | os.environ["S3_ENDPOINT"], 27 | access_key=os.environ["S3_ACCESS_KEY_ID"], 28 | secret_key=os.environ["S3_SECRET_ACCESS_KEY"], 29 | region=os.environ["S3_REGION"], 30 | secure=os.environ.get("S3_SECURE", "True").lower() == "true", 31 | ) # Use secure=False if not using https 32 | minio_client.fput_object(bucket_name, object_name, file_name) 33 | print(f"'{file_name}' is successfully uploaded as '{object_name}' to bucket '{bucket_name}'.") 34 | except S3Error as e: 35 | print("Error occurred: ", e) 36 | 37 | def upload(self, folder_path: str, output_filename: str, bucket_name: str) -> None: 38 | """Compress the folder and upload it to S3.""" 39 | self._compress_folder(folder_path, output_filename) 40 | self._upload_to_s3(output_filename, bucket_name, output_filename) 41 | 42 | def _download_from_s3(self, bucket_name: str, object_name: str, file_name: str) -> None: 43 | """Download a file from S3.""" 44 | try: 45 | # Initialize MinIO client 46 | minio_client = Minio( 47 | os.environ["S3_ENDPOINT"], 48 | access_key=os.environ["S3_ACCESS_KEY_ID"], 49 | secret_key=os.environ["S3_SECRET_ACCESS_KEY"], 50 | region=os.environ["S3_REGION"], 51 | secure=os.environ.get("S3_SECURE", "True").lower() == "true", 52 | ) 53 | minio_client.fget_object(bucket_name, object_name, file_name) 54 | print(f"'{object_name}' from bucket '{bucket_name}' is successfully downloaded as '{file_name}'.") 55 | except S3Error as e: 56 | print("Error occurred: ", e) 57 | 58 | def _decompress_folder(self, input_filename: str, output_folder_path: str) -> None: 59 | """Decompress a tar.gz file into a folder.""" 60 | try: 61 | with tarfile.open(input_filename, "r:gz") as tar: 62 | tar.extractall(path=output_folder_path) 63 | print(f"'{input_filename}' is successfully decompressed to '{output_folder_path}'.") 64 | except Exception as e: 65 | print("Error occurred: ", e) 66 | 67 | def download(self, bucket_name: str, object_name: str, output_folder_path: str) -> None: 68 | """Download a tar.gz file from S3 and decompress it into a folder.""" 69 | temp_filename = "temp.tar.gz" 70 | self._download_from_s3(bucket_name, object_name, temp_filename) 71 | self._decompress_folder(temp_filename, output_folder_path) 72 | os.remove(temp_filename) # Clean up the temporary compressed file 73 | 74 | 75 | def dataset_generator( 76 | dataset: str, 77 | split: str = "train", 78 | from_disk: bool = False, 79 | task: str = "text", 80 | bos_token: str = "", 81 | eos_token: str = "", 82 | ) -> Generator[dict[str, Any], dict[str, Any], None]: 83 | """Generate training data by yielding each row in the dataset split.""" 84 | # We assume that the dataset is a HuggingFace dataset, and a DatasetDict 
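# (loaded as a stream from the Hub, or with load_from_disk when from_disk is set)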
85 | # such that the dict has train, val, and test splits. 86 | if from_disk: 87 | ds = load_from_disk(dataset) 88 | ds = ds[split] 89 | # Iterate through the dataset and yield each row 90 | print(f"{ds.num_rows} rows in {split} split") 91 | else: 92 | ds = load_dataset(dataset, streaming=True, split=split) 93 | 94 | for row in iter(ds): 95 | if task == "text": 96 | text = f"{row['text']}" 97 | if not text.startswith(bos_token): 98 | text = f"{bos_token}{text}{eos_token}" 99 | yield {"text": text} 100 | elif task == "completion": 101 | # If the dataset is a 'completion' task dataset, we need to concatenate the prompt and completion 102 | text = f"{row['prompt']}{row['completion']}" 103 | if not text.startswith(bos_token): 104 | text = f"{bos_token}{text}{eos_token}" 105 | yield { 106 | "text": text, 107 | "prompt": row["prompt"], 108 | "completion": row["completion"], 109 | } 110 | else: 111 | raise Exception(f"Unknown task: {task}") 112 | 113 | 114 | def peft_module_casting_to_bf16(model: AutoModelForCausalLM, args: dict[str, str]) -> None: 115 | """Cast the PEFT model to bf16.""" 116 | for name, module in model.named_modules(): 117 | if isinstance(module, LoraLayer): 118 | if args.get("bf16", "false").lower() == "true": 119 | module = module.to(torch.bfloat16) 120 | if "norm" in name: 121 | module = module.to(torch.float32) 122 | if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): 123 | if hasattr(module, "weight"): 124 | if args["bf16"] and module.weight.dtype == torch.float32: 125 | module = module.to(torch.bfloat16) 126 | -------------------------------------------------------------------------------- /src/aihero/research/finetuning/infer.py: -------------------------------------------------------------------------------- 1 | """Module to run batch inference jobs.""" 2 | import os 3 | from pathlib import Path 4 | from tempfile import TemporaryDirectory 5 | from typing import Any, Tuple 6 | 7 | import torch 8 | from datasets import Dataset, DatasetDict, DatasetInfo 9 | from huggingface_hub import login 10 | from tqdm import tqdm 11 | from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, GenerationConfig 12 | from wandb import Table, finish 13 | 14 | from aihero.research.config.schema import BatchInferenceJob 15 | from aihero.research.finetuning.utils import DatasetMover, dataset_generator 16 | 17 | CHECKPOINT_DIR = "/mnt/checkpoint" 18 | DATASET_DIR = "/mnt/dataset" 19 | MAX_NEW_TOKENS = 4096 20 | 21 | if os.environ.get("HF_TOKEN", None): 22 | login(token=os.environ["HF_TOKEN"]) 23 | 24 | 25 | class BatchInferenceJobRunner: 26 | """Class to run a batch inferenc job.""" 27 | 28 | def __init__(self, batch_inference_job: BatchInferenceJob): 29 | """Initialize the training job runner.""" 30 | self.batch_inference_job = batch_inference_job 31 | print("Loading model") 32 | self.model, self.tokenizer = self.load_model() 33 | print("Loading dataset") 34 | self.dataset_dict = self.fetch_dataset() 35 | 36 | # Prep for eval 37 | if self.batch_inference_job.eval: 38 | run_tests_str = self.batch_inference_job.eval.tests or "" 39 | run_metrics_str = self.batch_inference_job.eval.metrics or "" 40 | size = self.batch_inference_job.size or 0 41 | randomize = self.batch_inference_job.randomize or False 42 | else: 43 | run_tests_str = "" 44 | run_metrics_str = "" 45 | size = 0 46 | randomize = False 47 | 48 | self.batch_inference_split = self.dataset_dict["batch_inference"] 49 | if size: 50 | if randomize: 51 | self.batch_inference_split = 
self.batch_inference_split.shuffle() 52 | self.batch_inference_split = self.batch_inference_split.select(range(size)) 53 | self.batch_inference_with_eval = BatchInferenceWithEval( 54 | model=self.model, 55 | tokenizer=self.tokenizer, 56 | task=self.batch_inference_job.task, 57 | run_tests_str=run_tests_str, 58 | run_metrics_str=run_metrics_str, 59 | max_new_tokens=self.batch_inference_job.generator.max_seq_length or MAX_NEW_TOKENS, 60 | ) 61 | 62 | def load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 63 | """Load the model from HuggingFace Hub or S3.""" 64 | use_4bit = self.batch_inference_job.quantized or False 65 | if use_4bit: 66 | # Compute dtype for 4-bit base models 67 | bnb_4bit_compute_dtype = "float16" 68 | # Quantization type (fp4 or nf4) 69 | bnb_4bit_quant_type = "nf4" 70 | # Activate nested quantization for 4-bit base models (double quantization) 71 | use_nested_quant = False 72 | 73 | # Load tokenizer and model with QLoRA configuration 74 | compute_dtype = getattr(torch, bnb_4bit_compute_dtype) 75 | 76 | bnb_config = BitsAndBytesConfig( 77 | load_in_4bit=use_4bit, 78 | bnb_4bit_quant_type=bnb_4bit_quant_type, 79 | bnb_4bit_compute_dtype=compute_dtype, 80 | bnb_4bit_use_double_quant=use_nested_quant, 81 | ) 82 | 83 | # Check GPU compatibility with bfloat16 84 | if compute_dtype == torch.float16 and use_4bit: 85 | major, _ = torch.cuda.get_device_capability() 86 | if major >= 8: 87 | print("=" * 80) 88 | print("Your GPU supports bfloat16: accelerate training with bf16=True") 89 | print("=" * 80) 90 | 91 | if self.batch_inference_job.model.type == "huggingface": 92 | device_map = {"": 0} 93 | 94 | if use_4bit: 95 | # Load base model 96 | model = AutoModelForCausalLM.from_pretrained( 97 | self.batch_inference_job.model.name, 98 | quantization_config=bnb_config, 99 | device_map=device_map, 100 | trust_remote_code=True, 101 | ) 102 | model.config.use_cache = False 103 | model.config.pretraining_tp = 1 104 | else: 105 | model = AutoModelForCausalLM.from_pretrained( 106 | self.batch_inference_job.model.name, 107 | torch_dtype=torch.bfloat16, 108 | use_cache=False, 109 | trust_remote_code=True, 110 | device_map=device_map, 111 | ) 112 | tokenizer = AutoTokenizer.from_pretrained( 113 | self.batch_inference_job.model.name, 114 | trust_remote_code=True, 115 | add_eos_token=False, 116 | add_bos_token=False, 117 | ) 118 | # May need to have some custom padding logic here 119 | special_tokens = {"pad_token": "[PAD]"} 120 | tokenizer.add_special_tokens(special_tokens) 121 | if self.batch_inference_job.tokenizer and self.batch_inference_job.tokenizer.additional_tokens: 122 | tokenizer.add_tokens(self.batch_inference_job.tokenizer.additional_tokens) 123 | tokenizer.padding_side = "right" 124 | model.config.pad_token_id = tokenizer.pad_token_id 125 | model.resize_token_embeddings(len(tokenizer)) 126 | elif self.batch_inference_job.model.type == "s3": 127 | # TODO : Add s3 support 128 | raise NotImplementedError("S3 support not implemented yet") 129 | else: 130 | raise ValueError(f"Unknown base_model_type: {self.batch_inference_job.model.type}") 131 | return model, tokenizer 132 | 133 | def fetch_dataset(self) -> DatasetDict: 134 | """Fetch the dataset from HuggingFace Hub or S3.""" 135 | splits = {} 136 | bos_token = self.tokenizer.bos_token 137 | eos_token = self.tokenizer.eos_token 138 | if self.batch_inference_job.dataset.type == "huggingface": 139 | splits["batch_inference"] = Dataset.from_generator( 140 | dataset_generator, 141 | gen_kwargs={ 142 | "dataset": 
self.batch_inference_job.dataset.name, 143 | "split": "batch_inference", 144 | "task": self.batch_inference_job.task, 145 | "bos_token": bos_token, 146 | "eos_token": eos_token, 147 | }, 148 | ) 149 | elif self.batch_inference_job.dataset.type == "s3": 150 | os.makedirs(DATASET_DIR) 151 | dataset_mover = DatasetMover() 152 | # If the dataset is s3, download it to the local directory 153 | # The path would look like bucket_name/path/to/dataset_name.tar.gz 154 | # local_name would then be = path/to/dataset_name.tar.gz 155 | local_name = self.batch_inference_job.dataset.name[self.batch_inference_job.dataset.name.find("/") + 1 :] 156 | dataset_mover.download( 157 | bucket_name=self.batch_inference_job.dataset.name.split("/")[0], 158 | object_name=f"{local_name}.tar.gz", 159 | output_folder_path=DATASET_DIR, 160 | ) 161 | print(os.listdir(DATASET_DIR)) 162 | print(os.listdir(f"{DATASET_DIR}/{local_name}")) 163 | splits["batch_inference"] = Dataset.from_generator( 164 | dataset_generator, 165 | gen_kwargs={ 166 | "dataset": f"{DATASET_DIR}/{local_name}", 167 | "split": "batch_inference", 168 | "from_disk": True, 169 | "task": self.batch_inference_job.task, 170 | "bos_token": bos_token, 171 | "eos_token": eos_token, 172 | }, 173 | ) 174 | elif self.batch_inference_job.dataset.type == "local": 175 | print("Loading dataset locally: ", os.listdir(self.batch_inference_job.dataset.path)) 176 | splits["train"] = Dataset.from_generator( 177 | dataset_generator, 178 | gen_kwargs={ 179 | "dataset": self.batch_inference_job.dataset.path, 180 | "split": "train", 181 | "from_disk": True, 182 | "task": self.batch_inference_job.task, 183 | "bos_token": bos_token, 184 | "eos_token": eos_token, 185 | }, 186 | ) 187 | try: 188 | splits["val"] = Dataset.from_generator( 189 | dataset_generator, 190 | gen_kwargs={ 191 | "dataset": self.batch_inference_job.dataset.path, 192 | "split": "val", 193 | "from_disk": True, 194 | "task": self.batch_inference_job.task, 195 | "bos_token": bos_token, 196 | "eos_token": eos_token, 197 | }, 198 | ) 199 | except: # pylint: disable=bare-except # noqa: E722 200 | print("Unable to create val dataset") 201 | try: 202 | splits["test"] = Dataset.from_generator( 203 | dataset_generator, 204 | gen_kwargs={ 205 | "dataset": self.batch_inference_job.dataset.path, 206 | "split": "test", 207 | "from_disk": True, 208 | "task": self.batch_inference_job.task, 209 | "bos_token": bos_token, 210 | "eos_token": eos_token, 211 | }, 212 | ) 213 | except: # pylint: disable=bare-except # noqa: E722 214 | print("Unable to create test dataset") 215 | else: 216 | raise ValueError(f"Unknown dataset_type: {self.batch_inference_job.dataset.type}") 217 | 218 | return DatasetDict(splits) 219 | 220 | def infer_on_dataset(self) -> None: 221 | """Generate batch predictions.""" 222 | if not self.batch_inference_with_eval: 223 | return 224 | # Batch inference config 225 | predicted_rows, (_, _) = self.batch_inference_with_eval.infer(self.batch_inference_split) 226 | 227 | with TemporaryDirectory() as temp_dir: 228 | temp_dir_path = Path(temp_dir) 229 | # If you're creating a new dataset from scratch: 230 | dataset_dict = DatasetDict( 231 | {"predictions": Dataset.from_list(predicted_rows)} # Assign the new dataset as the train split 232 | ) 233 | 234 | output_short_name = self.batch_inference_job.dataset.name.split("/")[-1] + "-output" 235 | print(f"Converting {output_short_name} to dataset dict") 236 | 237 | dataset_info = DatasetInfo( 238 | description=f"Contains output for {output_short_name} from batch 
inference", 239 | version="1.0.0", 240 | ) 241 | for split, dataset in dataset_dict.items(): 242 | dataset.dataset_info = dataset_info 243 | dataset_path = (temp_dir_path / output_short_name).as_posix() 244 | dataset_dict.save_to_disk(dataset_path) 245 | 246 | # Compress the folder 247 | print(f"Compressing the folder {dataset_path}") 248 | folder_to_compress = dataset_path 249 | output_tar_file = f"{output_short_name}.tar.gz" 250 | bucket_name = "fine-tuning-research" 251 | print(f"Uploading {output_tar_file} to {bucket_name}") 252 | dataset_mover = DatasetMover() 253 | dataset_mover.upload(folder_to_compress, output_tar_file, bucket_name) 254 | 255 | def run(self) -> None: 256 | """Execute the main inference loop.""" 257 | print("Starting batch inference") 258 | self.infer_on_dataset() 259 | print("Save and Uploading model..") 260 | finish() 261 | 262 | 263 | class BatchInferenceWithEval: 264 | """Batch inference class for generating predictions and running custom tests and metrics.""" 265 | 266 | def __init__( 267 | self, 268 | model: AutoModelForCausalLM, 269 | tokenizer: AutoTokenizer, 270 | task: str, 271 | run_tests_str: str = "", 272 | run_metrics_str: str = "", 273 | max_new_tokens: int = MAX_NEW_TOKENS, 274 | batch_size: int = 8, 275 | ): 276 | """Initialize the batch inference class.""" 277 | self.gen_config = GenerationConfig.from_pretrained(model.name_or_path, max_new_tokens=max_new_tokens) 278 | self.model = model 279 | self.tokenizer = tokenizer 280 | self.task = task 281 | self.run_tests_str = run_tests_str 282 | if run_tests_str and os.environ.get("ALLOW_CUSTOM_TESTS", "false").lower() == "true": 283 | exec(self.run_tests_str, globals()) 284 | else: 285 | self.run_tests_str = "" 286 | self.run_metrics_str = run_metrics_str 287 | if run_metrics_str and os.environ.get("ALLOW_CUSTOM_METRICS", "false").lower() == "true": 288 | exec(self.run_metrics_str, globals()) 289 | else: 290 | self.run_metrics_str = "" 291 | self.initial_predictions: list[str] = [] 292 | 293 | def generate(self, prompt: str) -> Any: 294 | """Generate a completion from a prompt.""" 295 | tokenized_prompt = self.tokenizer(prompt, return_tensors="pt", padding=True)["input_ids"].cuda() 296 | with torch.inference_mode(): 297 | output = self.model.generate( 298 | inputs=tokenized_prompt, generation_config=self.gen_config, pad_token_id=self.tokenizer.eos_token_id 299 | ) 300 | return self.tokenizer.decode(output[0][len(tokenized_prompt[0]) :], skip_special_tokens=True) 301 | 302 | def run_initial_predictions(self, rows: Dataset) -> Tuple[list[dict[str, Any]], Tuple[Table, dict[str, Any]]]: 303 | """Generate initial predictions for the sample split.""" 304 | # Test the provided code if present: 305 | print("Testing custom code, on ground truth if provided") 306 | test_rows = [] 307 | for example in tqdm(rows, leave=False): 308 | prompt = example["prompt"] 309 | actual = example["completion"] 310 | test_rows.append({"prompt": prompt, "actual": actual, "predicted": actual, "initial": actual}) 311 | self.execute_custom_code(test_rows) 312 | 313 | print("Generating initial predictions for sample split") 314 | predicted_rows = [] 315 | for example in tqdm(rows, leave=False): 316 | if self.task == "text": 317 | prompt = example["text"] 318 | else: 319 | prompt = example["prompt"] 320 | if not prompt.startswith(self.tokenizer.bos_token): 321 | prompt = f"{self.tokenizer.bos_token}{prompt}" 322 | predicted = self.generate(prompt=prompt) 323 | self.initial_predictions.append(predicted) 324 | actual = example["completion"] 
325 | predicted_rows.append({"prompt": prompt, "actual": actual, "predicted": predicted, "initial": predicted}) 326 | return predicted_rows, self.execute_custom_code(predicted_rows) 327 | 328 | def infer(self, rows: Dataset) -> Tuple[list[dict[str, Any]], Tuple[Table, dict[str, Any]]]: 329 | """Generate batch predictions.""" 330 | print("Generating predictions for sample split") 331 | predicted_rows = [] 332 | for i, example in tqdm(enumerate(rows), leave=False): 333 | if self.task == "text": 334 | prompt = example["text"] 335 | else: 336 | prompt = example["prompt"] 337 | actual = example["completion"] 338 | if not prompt.startswith(self.tokenizer.bos_token): 339 | prompt = f"{self.tokenizer.bos_token}{prompt}" 340 | predicted = self.generate(prompt=prompt) 341 | row_obj = {"prompt": prompt, "actual": actual, "predicted": predicted} 342 | if self.initial_predictions: 343 | row_obj["initial"] = self.initial_predictions[i] 344 | predicted_rows.append(row_obj) 345 | return predicted_rows, self.execute_custom_code(predicted_rows) 346 | 347 | def execute_custom_code(self, rows: list[dict[str, Any]]) -> Tuple[Table, dict[str, Any]]: 348 | """Execute custom code for tests and metrics.""" 349 | records_table = Table(columns=["prompt", "predicted", "actual", "initial", "test_result", "errors"]) 350 | 351 | # Assuming run_tests_str and run_metrics_str contain your testing and metrics code respectively 352 | 353 | print("Updating records_table with predictions, test results, and errors") 354 | if self.run_tests_str and os.environ.get("ALLOW_CUSTOM_TESTS", "false").lower() == "true": 355 | # Execute dynamic code for tests 356 | print("Running custom tests") 357 | tests, errors = run_tests([row["prompt"] for row in rows], [row["predicted"] for row in rows]) # type: ignore # noqa: F821 358 | else: 359 | print("Skipping custom tests") 360 | tests, errors = [False] * len(rows), [""] * len(rows) 361 | 362 | if self.run_metrics_str and os.environ.get("ALLOW_CUSTOM_METRICS", "false").lower() == "true": 363 | # Execute dynamic code for metrics 364 | print("Running custom metrics") 365 | pts = [row["prompt"] for row in rows] 366 | acts = [row["actual"] for row in rows] 367 | prds = [row["predicted"] for row in rows] 368 | metrics = run_metrics(pts, acts, prds) # type: ignore # noqa: F821 369 | else: 370 | print("Skipping custom metrics") 371 | metrics = {} 372 | 373 | print("Building table") 374 | index = 0 375 | passed = 0 376 | for row in tqdm(rows, leave=False): 377 | test_result = "PASS" if tests[index] else "FAIL" 378 | passed += 1 if test_result == "PASS" else 0 379 | error_message = errors[index] if index < len(errors) else "" 380 | records_table.add_data( 381 | row["prompt"], 382 | row["predicted"], 383 | row["actual"], 384 | row.get("initial", "N/A"), 385 | test_result, 386 | error_message, 387 | ) 388 | index += 1 389 | 390 | metrics["passed"] = passed * 100 / len(rows) 391 | print("Metrics:", metrics) 392 | 393 | return records_table, metrics 394 | -------------------------------------------------------------------------------- /src/aihero/research/finetuning/train.py: -------------------------------------------------------------------------------- 1 | """Launch the training job inside a container.""" 2 | import os 3 | import time 4 | import traceback 5 | from typing import Any, Tuple 6 | 7 | import torch 8 | from datasets import Dataset, DatasetDict 9 | from huggingface_hub import login 10 | from peft import LoraConfig, get_peft_model 11 | from transformers import AutoModelForCausalLM, 
AutoTokenizer, BitsAndBytesConfig, TrainingArguments 12 | from trl import SFTTrainer 13 | from wandb import finish 14 | 15 | from aihero.research.config.schema import TrainingJob 16 | from aihero.research.finetuning.callback import LLMSampleCB 17 | from aihero.research.finetuning.utils import DatasetMover, dataset_generator, peft_module_casting_to_bf16 18 | 19 | CHECKPOINT_DIR = "/mnt/checkpoint" 20 | DATASET_DIR = "/mnt/dataset" 21 | 22 | if os.environ.get("HF_TOKEN", None): 23 | print("Logging in to HuggingFace Hub") 24 | login(token=os.environ["HF_TOKEN"]) 25 | 26 | 27 | class TrainingJobRunner: 28 | """Class to run a training job.""" 29 | 30 | def __init__(self, training_job: TrainingJob, is_distributed: bool = False): 31 | """Initialize the training job runner.""" 32 | self.training_job = training_job 33 | self.is_distributed = is_distributed 34 | print("Is Distributed: ", self.is_distributed) 35 | if self.is_distributed: 36 | backend = "nccl" if torch.cuda.is_available() else "gloo" 37 | import torch.distributed as dist 38 | 39 | dist.init_process_group(backend) 40 | self.local_rank = dist.get_rank() 41 | torch.cuda.set_device(self.local_rank) 42 | else: 43 | self.local_rank = 0 44 | print("Loading model") 45 | self.model, self.tokenizer = self.load_model() 46 | print("Loading dataset") # After model for tokenizer load to work 47 | self.dataset_dict = self.fetch_dataset() 48 | 49 | def load_model(self) -> Tuple[AutoModelForCausalLM, AutoTokenizer]: 50 | """Load the model from HuggingFace Hub or S3.""" 51 | use_4bit = self.training_job.quantized or False 52 | if use_4bit: 53 | # Compute dtype for 4-bit base models 54 | bnb_4bit_compute_dtype = "float16" 55 | # Quantization type (fp4 or nf4) 56 | bnb_4bit_quant_type = "nf4" 57 | # Activate nested quantization for 4-bit base models (double quantization) 58 | use_nested_quant = False 59 | 60 | # Load tokenizer and model with QLoRA configuration 61 | compute_dtype = getattr(torch, bnb_4bit_compute_dtype) 62 | 63 | bnb_config = BitsAndBytesConfig( 64 | load_in_4bit=use_4bit, 65 | bnb_4bit_quant_type=bnb_4bit_quant_type, 66 | bnb_4bit_compute_dtype=compute_dtype, 67 | bnb_4bit_use_double_quant=use_nested_quant, 68 | ) 69 | 70 | # Check GPU compatibility with bfloat16 71 | if compute_dtype == torch.float16 and use_4bit: 72 | major, _ = torch.cuda.get_device_capability() 73 | if major >= 8: 74 | print("=" * 80) 75 | print("Your GPU supports bfloat16: accelerate training with bf16=True") 76 | print("=" * 80) 77 | 78 | device_map = {"": self.local_rank if self.is_distributed else 0} 79 | print("Device map: ", device_map) 80 | 81 | if self.training_job.base.type == "huggingface": 82 | if use_4bit: 83 | # Load base model 84 | model = AutoModelForCausalLM.from_pretrained( 85 | self.training_job.base.name, 86 | quantization_config=bnb_config, 87 | device_map=device_map, 88 | trust_remote_code=True, 89 | ) 90 | model.config.use_cache = False 91 | model.config.pretraining_tp = 1 92 | else: 93 | model = AutoModelForCausalLM.from_pretrained( 94 | self.training_job.base.name, 95 | torch_dtype=torch.bfloat16, 96 | use_cache=False, 97 | trust_remote_code=True, 98 | device_map=device_map, 99 | ) 100 | tokenizer = AutoTokenizer.from_pretrained( 101 | self.training_job.base.name, 102 | trust_remote_code=True, 103 | add_eos_token=False, 104 | add_bos_token=False, 105 | ) 106 | # May need to have some custom padding logic here 107 | special_tokens = {"pad_token": "[PAD]"} 108 | tokenizer.add_special_tokens(special_tokens) 109 | if self.training_job.tokenizer 
and self.training_job.tokenizer.additional_tokens: 110 | tokenizer.add_tokens(self.training_job.tokenizer.additional_tokens) 111 | tokenizer.padding_side = "right" 112 | model.config.pad_token_id = tokenizer.pad_token_id 113 | model.resize_token_embeddings(len(tokenizer)) 114 | 115 | elif self.training_job.base.type == "s3": 116 | # TODO : Add s3 support 117 | raise NotImplementedError("S3 support not implemented yet") 118 | else: 119 | raise ValueError(f"Unknown base_model_type: {self.training_job.base.type}") 120 | return model, tokenizer 121 | 122 | def fetch_dataset(self) -> DatasetDict: 123 | """Fetch the dataset from HuggingFace Hub or S3.""" 124 | if self.training_job.dataset.type not in ["huggingface", "s3", "local"]: 125 | raise ValueError(f"Unknown dataset_type: {self.training_job.dataset.type}") 126 | 127 | if not os.path.exists(DATASET_DIR): 128 | os.makedirs(DATASET_DIR) 129 | if self.local_rank > 0: 130 | while not os.path.exists(f"{DATASET_DIR}/downloading_data.txt"): 131 | print(f"LOCAL RANK {self.local_rank}: Waiting for data download to begin") 132 | time.sleep(5) 133 | 134 | while os.path.exists(f"{DATASET_DIR}/downloading_data.txt"): 135 | print(f"LOCAL RANK {self.local_rank}: Waiting for data to be ready") 136 | time.sleep(5) 137 | 138 | if os.path.exists(f"{DATASET_DIR}/data_abort.txt"): 139 | print(f"LOCAL RANK {self.local_rank}: Data Abort") 140 | raise Exception("Data Abort") 141 | print(f"LOCAL RANK {self.local_rank}: Data ready") 142 | else: 143 | with open(f"{DATASET_DIR}/downloading_data.txt", "w") as f: 144 | f.write("Downloading Data") 145 | print(f"LOCAL RANK {self.local_rank}: loading data") 146 | try: 147 | splits = {} 148 | bos_token = self.tokenizer.bos_token 149 | eos_token = self.tokenizer.eos_token 150 | if self.training_job.dataset.type == "huggingface": 151 | splits["train"] = Dataset.from_generator( 152 | dataset_generator, 153 | gen_kwargs={ 154 | "dataset": self.training_job.dataset.name, 155 | "split": "train", 156 | "task": self.training_job.task, 157 | "bos_token": bos_token, 158 | "eos_token": eos_token, 159 | }, 160 | ) 161 | try: 162 | splits["val"] = Dataset.from_generator( 163 | dataset_generator, 164 | gen_kwargs={ 165 | "dataset": self.training_job.dataset.name, 166 | "split": "val", 167 | "task": self.training_job.task, 168 | "bos_token": bos_token, 169 | "eos_token": eos_token, 170 | }, 171 | ) 172 | except: # pylint: disable=bare-except # noqa: E722 173 | print("Unable to create val dataset") 174 | try: 175 | splits["test"] = Dataset.from_generator( 176 | dataset_generator, 177 | gen_kwargs={ 178 | "dataset": self.training_job.dataset.name, 179 | "split": "test", 180 | "task": self.training_job.task, 181 | "bos_token": bos_token, 182 | "eos_token": eos_token, 183 | }, 184 | ) 185 | except: # pylint: disable=bare-except # noqa: E722 186 | print("Unable to create test dataset") 187 | elif self.training_job.dataset.type == "s3": 188 | dataset_mover = DatasetMover() 189 | # If the dataset is s3, download it to the local directory 190 | # The path would look like bucket_name/path/to/dataset_name.tar.gz 191 | # local_name would then be = path/to/dataset_name.tar.gz 192 | local_name = self.training_job.dataset.name[self.training_job.dataset.name.find("/") + 1 :] 193 | if not os.path.exists(f"{DATASET_DIR}/{local_name}"): 194 | dataset_mover.download( 195 | bucket_name=self.training_job.dataset.name.split("/")[0], 196 | object_name=f"{local_name}.tar.gz", 197 | output_folder_path=DATASET_DIR, 198 | ) 199 | print(os.listdir(DATASET_DIR)) 200 
| print(os.listdir(f"{DATASET_DIR}/{local_name}")) 201 | splits["train"] = Dataset.from_generator( 202 | dataset_generator, 203 | gen_kwargs={ 204 | "dataset": f"{DATASET_DIR}/{local_name}", 205 | "split": "train", 206 | "from_disk": True, 207 | "task": self.training_job.task, 208 | "bos_token": bos_token, 209 | "eos_token": eos_token, 210 | }, 211 | ) 212 | try: 213 | splits["val"] = Dataset.from_generator( 214 | dataset_generator, 215 | gen_kwargs={ 216 | "dataset": f"{DATASET_DIR}/{local_name}", 217 | "split": "val", 218 | "from_disk": True, 219 | "task": self.training_job.task, 220 | "bos_token": bos_token, 221 | "eos_token": eos_token, 222 | }, 223 | ) 224 | except: # pylint: disable=bare-except # noqa: E722 225 | print("Unable to create val dataset") 226 | try: 227 | splits["test"] = Dataset.from_generator( 228 | dataset_generator, 229 | gen_kwargs={ 230 | "dataset": f"{DATASET_DIR}/{local_name}", 231 | "split": "test", 232 | "from_disk": True, 233 | "task": self.training_job.task, 234 | "bos_token": bos_token, 235 | "eos_token": eos_token, 236 | }, 237 | ) 238 | except: # pylint: disable=bare-except # noqa: E722 239 | print("Unable to create test dataset") 240 | 241 | elif self.training_job.dataset.type == "local": 242 | print("Loading dataset locally: ", os.listdir(self.training_job.dataset.path)) 243 | splits["train"] = Dataset.from_generator( 244 | dataset_generator, 245 | gen_kwargs={ 246 | "dataset": self.training_job.dataset.path, 247 | "split": "train", 248 | "from_disk": True, 249 | "task": self.training_job.task, 250 | "bos_token": bos_token, 251 | "eos_token": eos_token, 252 | }, 253 | ) 254 | try: 255 | splits["val"] = Dataset.from_generator( 256 | dataset_generator, 257 | gen_kwargs={ 258 | "dataset": self.training_job.dataset.path, 259 | "split": "val", 260 | "from_disk": True, 261 | "task": self.training_job.task, 262 | "bos_token": bos_token, 263 | "eos_token": eos_token, 264 | }, 265 | ) 266 | except: # pylint: disable=bare-except # noqa: E722 267 | print("Unable to create val dataset") 268 | try: 269 | splits["test"] = Dataset.from_generator( 270 | dataset_generator, 271 | gen_kwargs={ 272 | "dataset": self.training_job.dataset.path, 273 | "split": "test", 274 | "from_disk": True, 275 | "task": self.training_job.task, 276 | "bos_token": bos_token, 277 | "eos_token": eos_token, 278 | }, 279 | ) 280 | except: # pylint: disable=bare-except # noqa: E722 281 | print("Unable to create test dataset") 282 | 283 | print(f"LOCAL RANK {self.local_rank}: data loaded") 284 | if os.path.exists(f"{DATASET_DIR}/downloading_data.txt"): 285 | os.remove(f"{DATASET_DIR}/downloading_data.txt") 286 | return DatasetDict(splits) 287 | except: # pylint: disable=bare-except # noqa: E722 288 | traceback.print_exc() 289 | if os.path.exists(f"{DATASET_DIR}/downloading_data.txt"): 290 | os.remove(f"{DATASET_DIR}/downloading_data.txt") 291 | print("Unable to load dataset") 292 | with open(f"{DATASET_DIR}/data_abort.txt", "w") as f: 293 | f.write("Data Abort") 294 | raise Exception("Data Abort") 295 | 296 | def freeze(self) -> None: 297 | """Freeze the model layers for SFT without PEFT.""" 298 | if self.training_job.freeze: 299 | n_freeze = self.training_job.freeze.n_freeze or 24 300 | 301 | if n_freeze > 0: 302 | module_name: str = "layers" 303 | 304 | def _find_mod(model: AutoModelForCausalLM, module_name: str) -> Any: 305 | for name, mod in model.named_modules(): 306 | if name.endswith(module_name): 307 | return mod 308 | 309 | # freeze layers (disable gradients) 310 | for param in 
self.model.parameters(): 311 | param.requires_grad = False 312 | 313 | # never freeze the head 314 | for param in self.model.lm_head.parameters(): 315 | param.requires_grad = True 316 | 317 | layers = _find_mod(self.model, module_name) 318 | for param in layers[n_freeze:].parameters(): 319 | param.requires_grad = True 320 | 321 | # Freeze embeddings for small memory decrease 322 | if self.training_job.freeze.freeze_embed: 323 | embed_tokens = _find_mod(self.model, "embed_tokens") 324 | embed_tokens.weight.requires_grad_(False) 325 | 326 | def save_model(self) -> None: 327 | """Save the model to a local directory.""" 328 | """Upload the model to HuggingFace Hub or S3.""" 329 | if os.getenv("RANK", "0") != "0": 330 | return 331 | if not self.training_job.output: 332 | return 333 | print("Saving model and tokenizer") 334 | local_name = self.training_job.output.name.split("/")[-1] 335 | self.model.save_pretrained(local_name) 336 | self.tokenizer.save_pretrained(local_name) 337 | print("Model saved at: ", os.listdir(local_name)) 338 | if self.training_job.output.type == "huggingface": 339 | print("Saving model and tokenizer to hf") 340 | self.model.push_to_hub(local_name) 341 | self.tokenizer.push_to_hub(local_name) 342 | elif self.training_job.output.type == "s3": 343 | # TODO : Add s3 support 344 | raise NotImplementedError("S3 support not implemented yet") 345 | 346 | def train(self) -> None: 347 | """Start training the model as defined by the config.""" 348 | # Assumes model is a causal language model 349 | self.model.config.use_cache = False 350 | train_split = self.dataset_dict["train"] 351 | val_split = self.dataset_dict.get("val", None) 352 | test_split = self.dataset_dict.get("test", None) 353 | 354 | if not val_split: 355 | assert not self.training_job.sft.eval_steps, "Eval steps should be 0 if no validation set is provided" 356 | self.training_job.sft.evaluation_strategy = "no" 357 | 358 | # SFT training config 359 | training_arguments_dict = self.training_job.sft.model_dump() 360 | training_arguments_dict["save_total_limit"] = 2 361 | training_arguments_dict["save_strategy"] = self.training_job.sft.evaluation_strategy 362 | training_arguments_dict["load_best_model_at_end"] = True 363 | training_arguments_dict["output_dir"] = CHECKPOINT_DIR 364 | training_arguments = TrainingArguments(**training_arguments_dict) 365 | # PEFT training config 366 | if self.training_job.peft: 367 | lora_config = LoraConfig(**self.training_job.peft.model_dump()) 368 | model = get_peft_model(self.model, lora_config) 369 | if self.training_job.sft.bf16: 370 | peft_module_casting_to_bf16(model, self.training_job.peft.model_dump()) 371 | model.print_trainable_parameters() 372 | peft_config = lora_config 373 | # self.training_job.sft.n_freeze = "all" 374 | elif self.training_job.freeze: 375 | self.freeze() 376 | peft_config = None 377 | 378 | trainer = SFTTrainer( 379 | tokenizer=self.tokenizer, 380 | model=self.model, 381 | train_dataset=train_split, 382 | eval_dataset=val_split, 383 | peft_config=peft_config, 384 | dataset_text_field="text", 385 | max_seq_length=self.training_job.trainer.max_seq_length, 386 | packing=self.training_job.trainer.packing, # Should you combine multiple examples into one sequence? 
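# (packing concatenates multiple short examples into max_seq_length-sized sequences to reduce padding waste)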
387 | args=training_arguments, 388 | ) 389 | 390 | task = self.training_job.dataset.task 391 | if self.training_job.eval and self.training_job.eval.tests: 392 | run_tests_str = self.training_job.eval.tests 393 | else: 394 | run_tests_str = "" 395 | if self.training_job.eval and self.training_job.eval.metrics: 396 | run_metrics_str = self.training_job.eval.metrics 397 | else: 398 | run_metrics_str = "" 399 | if test_split and test_split.num_rows > 0 and task == "completion": 400 | if os.environ.get("WANDB_API_KEY", None): 401 | # we instantiate the W&B callback with the trainer object and the dataset we want to sample from 402 | wandb_callback = LLMSampleCB( 403 | trainer, 404 | task, 405 | test_split, 406 | num_samples=test_split.num_rows if test_split.num_rows < 100 else 100, 407 | max_new_tokens=self.training_job.trainer.max_seq_length, 408 | run_tests_str=run_tests_str, 409 | run_metrics_str=run_metrics_str, 410 | ) 411 | wandb_callback.initialize() 412 | trainer.add_callback(wandb_callback) 413 | 414 | # distributed training config 415 | if trainer.is_fsdp_enabled: 416 | trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") 417 | 418 | print("Starting training") 419 | trainer.train() 420 | 421 | # if test_split and test_split.num_rows > 0: 422 | # trainer.predict(test_split) 423 | 424 | def run(self) -> None: 425 | """Execute the main training loop.""" 426 | print("Starting training") 427 | self.train() 428 | print("Saving model..") 429 | self.save_model() 430 | finish() 431 | --------------------------------------------------------------------------------
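For reference, here is a minimal sketch of the dataset contract that `train.py` and `infer.py` rely on. The toy prompt/completion rows, the `./sample_dataset` path, and the Llama-style `<s>`/`</s>` special tokens are assumptions for illustration only; the import of `dataset_generator` and the `gen_kwargs` mirror how `TrainingJobRunner.fetch_dataset` loads a dataset whose `dataset.type` is `local`.

```python
from datasets import Dataset, DatasetDict

from aihero.research.finetuning.utils import dataset_generator

# Toy prompt/completion pairs; a real dataset would have many more rows.
rows = {
    "prompt": ["Translate to French: cat\n", "Translate to French: dog\n"],
    "completion": ["chat", "chien"],
}
splits = DatasetDict(
    {
        "train": Dataset.from_dict(rows),
        "val": Dataset.from_dict(rows),
        "test": Dataset.from_dict(rows),
    }
)
splits.save_to_disk("./sample_dataset")  # hypothetical local path

# Rebuild the "train" split the same way TrainingJobRunner.fetch_dataset does
# for a local dataset: each row is re-yielded with bos/eos wrapping applied.
train_split = Dataset.from_generator(
    dataset_generator,
    gen_kwargs={
        "dataset": "./sample_dataset",
        "split": "train",
        "from_disk": True,
        "task": "completion",
        "bos_token": "<s>",  # assumed Llama-style special tokens
        "eos_token": "</s>",
    },
)
print(train_split[0]["text"])  # "<s>Translate to French: cat\nchat</s>"
```

Inside the container, the same flow is driven by the Fire CLI in `launch.py` (for example `python launch.py train`), which reads the training job config mounted at `/mnt/config/training/config.yaml` and hands it to `TrainingJobRunner`.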