├── .dockerignore ├── .gitignore ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── INSTALL.md ├── LICENSE ├── Makefile ├── PROJECT ├── README.md ├── assets ├── design │ ├── batch.png │ ├── batchdark.png │ ├── datasetplugin.png │ ├── datasetplugindark.png │ ├── evalandinference.png │ ├── evaldark.png │ ├── finetune.png │ ├── finetunedark.png │ ├── finetuneexdark.png │ ├── finetuneexperiment.png │ ├── finetunejob.png │ └── finetunejobdark.png ├── logo │ ├── Logo_DataTunerX - Horizontal - Color Dark.png │ └── Logo_DataTunerX - Horizontal - Color Light.png └── screenshot │ └── Job_Details.png ├── cmd ├── controller-manager │ └── app │ │ ├── controller_manager.go │ │ └── options │ │ └── options.go └── tuning │ ├── Dockerfile │ ├── README.md │ ├── build_image.sh │ ├── callback.py │ ├── ds_config.json │ ├── parser.py │ ├── prometheus │ ├── __init__.py │ ├── metrics.py │ ├── prometheus.proto │ └── prometheus_pb2.py │ ├── requirements.txt │ ├── template.py │ ├── train.py │ └── trainer.py ├── config └── rbac │ └── role.yaml ├── go.mod ├── go.sum ├── hack └── boilerplate.go.txt ├── internal └── controller │ └── finetune │ ├── finetune_controller.go │ ├── finetuneexperiment_controller.go │ └── finetunejob_controller.go ├── main.go └── pkg ├── config └── config.go ├── domain └── valueobject │ └── err.go ├── events └── events.go └── util ├── generate └── generate.go ├── handlererr └── handler.go ├── label └── label.go └── util.go /.dockerignore: -------------------------------------------------------------------------------- 1 | # More info: https://docs.docker.com/engine/reference/builder/#dockerignore-file 2 | # Ignore build and test binaries. 3 | bin/ 4 | testbin/ 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Binaries for programs and plugins 3 | *.exe 4 | *.exe~ 5 | *.dll 6 | *.so 7 | *.dylib 8 | bin 9 | testbin/* 10 | Dockerfile.cross 11 | 12 | # Test binary, build with `go test -c` 13 | *.test 14 | 15 | # Output of the go coverage tool, specifically when used with LiteIDE 16 | *.out 17 | 18 | # Kubernetes Generated files - skip generated files, except for vendored files 19 | 20 | !vendor/**/zz_generated.* 21 | 22 | # editor and IDE paraphernalia 23 | .idea 24 | *.swp 25 | *.swo 26 | *~ 27 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/CHANGELOG.md -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/CODE_OF_CONDUCT.md -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/CONTRIBUTING.md -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Build the manager binary 2 | FROM golang:1.20 as builder 3 | 4 | WORKDIR /workspace 5 | # Copy the 
project source and download module dependencies 6 | COPY . . 7 | RUN go mod tidy 8 | 9 | RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o manager main.go 10 | 11 | # Package the manager binary into a minimal base image 12 | # (distroless images are an alternative; see https://github.com/GoogleContainerTools/distroless for details) 13 | FROM alpine:3 14 | WORKDIR / 15 | COPY --from=builder /workspace/manager . 16 | 17 | ENTRYPOINT ["/manager"] 18 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | 2 | # DataTunerX Comprehensive Deployment Guide 3 | 4 | This guide provides detailed instructions for deploying DataTunerX in both online and offline environments. Ensure all prerequisites are met before proceeding with the deployment. 5 | 6 | ## Prerequisites 7 | 8 | Before starting, ensure your system meets the following requirements: 9 | 10 | - **Kubernetes v1.19+**: The container orchestration system for automating software deployment, scaling, and management. 11 | - **MinIO** or another S3-compatible object storage: For storing large datasets and models. 12 | - **Harbor** or another container image registry: For securely storing and managing container images. 13 | - **Helm**: The Kubernetes package manager for deploying and managing applications. 14 | 15 | ## Deployment Artifacts 16 | 17 | Required artifacts: 18 | 19 | - `dtx-ctl`: The DataTunerX deployment tool. 20 | - `images-ai.tar`: Optional LLM offline image package (47.1 GB). 21 | - `images.tar`: Optional business component offline image package. 22 | 23 | ## Online Deployment 24 | 25 | ### 1. Download the `dtx-ctl` Tool 26 | 27 | ```bash 28 | wget https://github.com/DataTunerX/dtx-ctl/releases/download/v0.1.0/dtx-ctl.tar.gz 29 | ``` 30 | 31 | ### 2. Deploy DataTunerX 32 | 33 | Deploy with default settings: 34 | 35 | ```bash 36 | dtx-ctl install 37 | ``` 38 | 39 | Or, with custom settings, supplying the target namespace after `-n`: 40 | 41 | ```bash 42 | dtx-ctl install -n <namespace> --set [Your-Custom-Settings] 43 | ``` 44 | 45 | Or, using a configuration file: 46 | 47 | ```bash 48 | dtx-ctl install -f /path/to/your/config.yaml 49 | ``` 50 | 51 | ## Offline Deployment 52 | 53 | Download the `dtx-ctl` tool exactly as in the online deployment, then download and import the offline image packages as follows: 54 | 55 | ### 1. Download the `dtx-ctl` Tool 56 | 57 | ```bash 58 | wget https://github.com/DataTunerX/dtx-ctl/releases/download/v0.1.0/dtx-ctl.tar.gz 59 | ``` 60 | 61 | ### 2. Download Business Component Images 62 | 63 | ```bash 64 | # The download link is valid for 24 hours; if it has expired, open an issue to request a new download package. The URL is quoted so the shell does not interpret the query string. 65 | wget "https://public-download.daocloud.io/datatunerx/v0.1.0/images?e=1708664238&token=MHV7x1flrG19kzrdBNfPPO7JpBjTr__AMGzOtlq1:sZrIxT02pubO4BhPunS3sky3Fss=" 66 | ``` 67 | 68 | ### 3. Download Base AI Images 69 | 70 | ```bash 71 | # The download link is valid for 24 hours; if it has expired, open an issue to request a new download package. 72 | wget "https://public-download.daocloud.io/datatunerx/v0.1.0/images-ai?e=1708594433&token=MHV7x1flrG19kzrdBNfPPO7JpBjTr__AMGzOtlq1:DySesLobN0I7NeCBcYuZ74P8osA=" 73 | ``` 74 |
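Both image packages are large, so an optional integrity check before importing can save time on a failed import. The file names below are the artifact names listed above; adjust them if your downloads were saved under different names:

```bash
# Optional sanity check: listing the archive contents fails fast on a
# corrupted or truncated download, without extracting anything.
tar -tf images.tar > /dev/null && echo "images.tar OK"
tar -tf images-ai.tar > /dev/null && echo "images-ai.tar OK"
```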
75 | ### 4. Extract and Import the Business Image Package 76 | 77 | ```bash 78 | tar -zxf images.tar -C /path/to/unzip 79 | cd /path/to/unzip/images 80 | ``` 81 | 82 | For Docker: 83 | 84 | ```bash 85 | docker load -i /path/to/image.tar 86 | ``` 87 | 88 | For Containerd: 89 | 90 | ```bash 91 | ctr -n k8s.io images import /path/to/image.tar 92 | ``` 93 | 94 | ### 5. Retag Images and Push Them to Your Image Repository 95 | 96 | ```bash 97 | docker tag source_image:tag target_repository/target_image:tag 98 | docker push target_repository/target_image:tag 99 | ``` 100 | 101 | ### 6. Deploy DataTunerX 102 | 103 | Deploy using custom settings to point DataTunerX at your image repository: 104 | 105 | ```bash 106 | dtx-ctl install -n <namespace> --registry=your_registry --repository=your_repository 107 | ``` 108 | 109 | Or, using a configuration file: 110 | 111 | ```bash 112 | dtx-ctl install -f /path/to/your/config.yaml 113 | ``` 114 | 115 | ## Command-Line Reference 116 | 117 | Commands to interact with `dtx-ctl`, including flags and subcommands for installation and management: 118 | 119 | ```bash 120 | # General usage 121 | dtx-ctl [command] 122 | 123 | # Available Commands 124 | completion Generate the autocompletion script for the specified shell 125 | help Help about any command 126 | install Install DataTunerX on Kubernetes 127 | uninstall Uninstall DataTunerX from Kubernetes 128 | 129 | # Flags for installation 130 | --chart-directory string Helm chart directory 131 | --dry-run Simulate an install 132 | --image-file-dir string Specify an image file directory 133 | --image-pull-policy string Image pull policy 134 | --image-pull-secret string Image pull secret 135 | --registry string Container registry 136 | --repository string Container repository 137 | --set stringArray Set helm values 138 | --set-file stringArray Set helm values from files 139 | --set-string stringArray Set helm STRING values 140 | -f, --values strings Specify helm values in a YAML file 141 | --version string Chart version 142 | --wait Wait for installation completion 143 | --wait-duration duration Maximum time to wait for resource readiness 144 | ``` 145 | 146 | Please replace placeholders with actual values and download links as required. 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License.
26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # VERSION defines the project version for the bundle. 2 | # Update this value when you upgrade the version of your project. 3 | # To re-generate a bundle for another specific version without changing the standard setup, you can: 4 | # - use the VERSION as arg of the bundle target (e.g make bundle VERSION=0.0.2) 5 | # - use environment variables to overwrite this value (e.g export VERSION=0.0.2) 6 | VERSION ?= 0.0.1 7 | 8 | # CHANNELS define the bundle channels used in the bundle. 9 | # Add a new line here if you would like to change its default config. (E.g CHANNELS = "candidate,fast,stable") 10 | # To re-generate a bundle for other specific channels without changing the standard setup, you can: 11 | # - use the CHANNELS as arg of the bundle target (e.g make bundle CHANNELS=candidate,fast,stable) 12 | # - use environment variables to overwrite this value (e.g export CHANNELS="candidate,fast,stable") 13 | ifneq ($(origin CHANNELS), undefined) 14 | BUNDLE_CHANNELS := --channels=$(CHANNELS) 15 | endif 16 | 17 | # DEFAULT_CHANNEL defines the default channel used in the bundle. 18 | # Add a new line here if you would like to change its default config. (E.g DEFAULT_CHANNEL = "stable") 19 | # To re-generate a bundle for any other default channel without changing the default setup, you can: 20 | # - use the DEFAULT_CHANNEL as arg of the bundle target (e.g make bundle DEFAULT_CHANNEL=stable) 21 | # - use environment variables to overwrite this value (e.g export DEFAULT_CHANNEL="stable") 22 | ifneq ($(origin DEFAULT_CHANNEL), undefined) 23 | BUNDLE_DEFAULT_CHANNEL := --default-channel=$(DEFAULT_CHANNEL) 24 | endif 25 | BUNDLE_METADATA_OPTS ?= $(BUNDLE_CHANNELS) $(BUNDLE_DEFAULT_CHANNEL) 26 | 27 | # IMAGE_TAG_BASE defines the docker.io namespace and part of the image name for remote images. 28 | # This variable is used to construct full image tags for bundle and catalog images. 29 | # 30 | # For example, running 'make bundle-build bundle-push catalog-build catalog-push' will build and push both 31 | # datatunerx.io/finetune-experiment-controller-bundle:$VERSION and datatunerx.io/finetune-experiment-controller-catalog:$VERSION. 32 | IMAGE_TAG_BASE ?= datatunerx.io/finetune-experiment-controller 33 | 34 | # BUNDLE_IMG defines the image:tag used for the bundle. 35 | # You can use it as an arg. (E.g make bundle-build BUNDLE_IMG=/:) 36 | BUNDLE_IMG ?= $(IMAGE_TAG_BASE)-bundle:v$(VERSION) 37 | 38 | # BUNDLE_GEN_FLAGS are the flags passed to the operator-sdk generate bundle command 39 | BUNDLE_GEN_FLAGS ?= -q --overwrite --version $(VERSION) $(BUNDLE_METADATA_OPTS) 40 | 41 | # USE_IMAGE_DIGESTS defines if images are resolved via tags or digests 42 | # You can enable this value if you would like to use SHA Based Digests 43 | # To enable set flag to true 44 | USE_IMAGE_DIGESTS ?= false 45 | ifeq ($(USE_IMAGE_DIGESTS), true) 46 | BUNDLE_GEN_FLAGS += --use-image-digests 47 | endif 48 | 49 | # Set the Operator SDK version to use. By default, what is installed on the system is used. 50 | # This is useful for CI or a project to utilize a specific version of the operator-sdk toolkit. 51 | OPERATOR_SDK_VERSION ?= v1.31.0 52 | 53 | # Image URL to use all building/pushing image targets 54 | IMG ?= controller:latest 55 | # ENVTEST_K8S_VERSION refers to the version of kubebuilder assets to be downloaded by envtest binary. 
56 | ENVTEST_K8S_VERSION = 1.26.0 57 | 58 | # Get the currently used golang install path (in GOPATH/bin, unless GOBIN is set) 59 | ifeq (,$(shell go env GOBIN)) 60 | GOBIN=$(shell go env GOPATH)/bin 61 | else 62 | GOBIN=$(shell go env GOBIN) 63 | endif 64 | 65 | # Setting SHELL to bash allows bash commands to be executed by recipes. 66 | # Options are set to exit when a recipe line exits non-zero or a piped command fails. 67 | SHELL = /usr/bin/env bash -o pipefail 68 | .SHELLFLAGS = -ec 69 | 70 | .PHONY: all 71 | all: build 72 | 73 | ##@ General 74 | 75 | # The help target prints out all targets with their descriptions organized 76 | # beneath their categories. The categories are represented by '##@' and the 77 | # target descriptions by '##'. The awk commands is responsible for reading the 78 | # entire set of makefiles included in this invocation, looking for lines of the 79 | # file as xyz: ## something, and then pretty-format the target and help. Then, 80 | # if there's a line with ##@ something, that gets pretty-printed as a category. 81 | # More info on the usage of ANSI control characters for terminal formatting: 82 | # https://en.wikipedia.org/wiki/ANSI_escape_code#SGR_parameters 83 | # More info on the awk command: 84 | # http://linuxcommand.org/lc3_adv_awk.php 85 | 86 | .PHONY: help 87 | help: ## Display this help. 88 | @awk 'BEGIN {FS = ":.*##"; printf "\nUsage:\n make \033[36m\033[0m\n"} /^[a-zA-Z_0-9-]+:.*?##/ { printf " \033[36m%-15s\033[0m %s\n", $$1, $$2 } /^##@/ { printf "\n\033[1m%s\033[0m\n", substr($$0, 5) } ' $(MAKEFILE_LIST) 89 | 90 | ##@ Development 91 | 92 | .PHONY: manifests 93 | manifests: controller-gen ## Generate WebhookConfiguration, ClusterRole and CustomResourceDefinition objects. 94 | $(CONTROLLER_GEN) rbac:roleName=manager-role crd webhook paths="./..." output:crd:artifacts:config=config/crd/bases 95 | 96 | .PHONY: generate 97 | generate: controller-gen ## Generate code containing DeepCopy, DeepCopyInto, and DeepCopyObject method implementations. 98 | $(CONTROLLER_GEN) object:headerFile="hack/boilerplate.go.txt" paths="./..." 99 | 100 | .PHONY: fmt 101 | fmt: ## Run go fmt against code. 102 | go fmt ./... 103 | 104 | .PHONY: vet 105 | vet: ## Run go vet against code. 106 | go vet ./... 107 | 108 | .PHONY: test 109 | test: manifests generate fmt vet envtest ## Run tests. 110 | KUBEBUILDER_ASSETS="$(shell $(ENVTEST) use $(ENVTEST_K8S_VERSION) --bin-dir $(LOCALBIN) -p path)" go test ./... -coverprofile cover.out 111 | 112 | ##@ Build 113 | 114 | .PHONY: build 115 | build: manifests generate fmt vet ## Build manager binary. 116 | go build -o bin/manager main.go 117 | 118 | .PHONY: run 119 | run: manifests generate fmt vet ## Run a controller from your host. 120 | go run ./main.go 121 | 122 | # If you wish built the manager image targeting other platforms you can use the --platform flag. 123 | # (i.e. docker build --platform linux/arm64 ). However, you must enable docker buildKit for it. 124 | # More info: https://docs.docker.com/develop/develop-images/build_enhancements/ 125 | .PHONY: docker-build 126 | docker-build: test ## Build docker image with the manager. 127 | docker build -t ${IMG} . 128 | 129 | .PHONY: docker-push 130 | docker-push: ## Push docker image with the manager. 131 | docker push ${IMG} 132 | 133 | # PLATFORMS defines the target platforms for the manager image be build to provide support to multiple 134 | # architectures. (i.e. make docker-buildx IMG=myregistry/mypoperator:0.0.1). 
To use this option you need to: 135 | # - able to use docker buildx . More info: https://docs.docker.com/build/buildx/ 136 | # - have enable BuildKit, More info: https://docs.docker.com/develop/develop-images/build_enhancements/ 137 | # - be able to push the image for your registry (i.e. if you do not inform a valid value via IMG=> then the export will fail) 138 | # To properly provided solutions that supports more than one platform you should use this option. 139 | PLATFORMS ?= linux/arm64,linux/amd64,linux/s390x,linux/ppc64le 140 | .PHONY: docker-buildx 141 | docker-buildx: test ## Build and push docker image for the manager for cross-platform support 142 | # copy existing Dockerfile and insert --platform=${BUILDPLATFORM} into Dockerfile.cross, and preserve the original Dockerfile 143 | sed -e '1 s/\(^FROM\)/FROM --platform=\$$\{BUILDPLATFORM\}/; t' -e ' 1,// s//FROM --platform=\$$\{BUILDPLATFORM\}/' Dockerfile > Dockerfile.cross 144 | - docker buildx create --name project-v3-builder 145 | docker buildx use project-v3-builder 146 | - docker buildx build --push --platform=$(PLATFORMS) --tag ${IMG} -f Dockerfile.cross . 147 | - docker buildx rm project-v3-builder 148 | rm Dockerfile.cross 149 | 150 | ##@ Deployment 151 | 152 | ifndef ignore-not-found 153 | ignore-not-found = false 154 | endif 155 | 156 | .PHONY: install 157 | install: manifests kustomize ## Install CRDs into the K8s cluster specified in ~/.kube/config. 158 | $(KUSTOMIZE) build config/crd | kubectl apply -f - 159 | 160 | .PHONY: uninstall 161 | uninstall: manifests kustomize ## Uninstall CRDs from the K8s cluster specified in ~/.kube/config. Call with ignore-not-found=true to ignore resource not found errors during deletion. 162 | $(KUSTOMIZE) build config/crd | kubectl delete --ignore-not-found=$(ignore-not-found) -f - 163 | 164 | .PHONY: deploy 165 | deploy: manifests kustomize ## Deploy controller to the K8s cluster specified in ~/.kube/config. 166 | cd config/manager && $(KUSTOMIZE) edit set image controller=${IMG} 167 | $(KUSTOMIZE) build config/default | kubectl apply -f - 168 | 169 | .PHONY: undeploy 170 | undeploy: ## Undeploy controller from the K8s cluster specified in ~/.kube/config. Call with ignore-not-found=true to ignore resource not found errors during deletion. 171 | $(KUSTOMIZE) build config/default | kubectl delete --ignore-not-found=$(ignore-not-found) -f - 172 | 173 | ##@ Build Dependencies 174 | 175 | ## Location to install dependencies to 176 | LOCALBIN ?= $(shell pwd)/bin 177 | $(LOCALBIN): 178 | mkdir -p $(LOCALBIN) 179 | 180 | ## Tool Binaries 181 | KUSTOMIZE ?= $(LOCALBIN)/kustomize 182 | CONTROLLER_GEN ?= $(LOCALBIN)/controller-gen 183 | ENVTEST ?= $(LOCALBIN)/setup-envtest 184 | 185 | ## Tool Versions 186 | KUSTOMIZE_VERSION ?= v3.8.7 187 | CONTROLLER_TOOLS_VERSION ?= v0.11.1 188 | 189 | KUSTOMIZE_INSTALL_SCRIPT ?= "https://raw.githubusercontent.com/kubernetes-sigs/kustomize/master/hack/install_kustomize.sh" 190 | .PHONY: kustomize 191 | kustomize: $(KUSTOMIZE) ## Download kustomize locally if necessary. If wrong version is installed, it will be removed before downloading. 192 | $(KUSTOMIZE): $(LOCALBIN) 193 | @if test -x $(LOCALBIN)/kustomize && ! $(LOCALBIN)/kustomize version | grep -q $(KUSTOMIZE_VERSION); then \ 194 | echo "$(LOCALBIN)/kustomize version is not expected $(KUSTOMIZE_VERSION). 
Removing it before installing."; \ 195 | rm -rf $(LOCALBIN)/kustomize; \ 196 | fi 197 | test -s $(LOCALBIN)/kustomize || { curl -Ss $(KUSTOMIZE_INSTALL_SCRIPT) | bash -s -- $(subst v,,$(KUSTOMIZE_VERSION)) $(LOCALBIN); } 198 | 199 | .PHONY: controller-gen 200 | controller-gen: $(CONTROLLER_GEN) ## Download controller-gen locally if necessary. If wrong version is installed, it will be overwritten. 201 | $(CONTROLLER_GEN): $(LOCALBIN) 202 | test -s $(LOCALBIN)/controller-gen && $(LOCALBIN)/controller-gen --version | grep -q $(CONTROLLER_TOOLS_VERSION) || \ 203 | GOBIN=$(LOCALBIN) go install sigs.k8s.io/controller-tools/cmd/controller-gen@$(CONTROLLER_TOOLS_VERSION) 204 | 205 | .PHONY: envtest 206 | envtest: $(ENVTEST) ## Download envtest-setup locally if necessary. 207 | $(ENVTEST): $(LOCALBIN) 208 | test -s $(LOCALBIN)/setup-envtest || GOBIN=$(LOCALBIN) go install sigs.k8s.io/controller-runtime/tools/setup-envtest@latest 209 | 210 | .PHONY: operator-sdk 211 | OPERATOR_SDK ?= $(LOCALBIN)/operator-sdk 212 | operator-sdk: ## Download operator-sdk locally if necessary. 213 | ifeq (,$(wildcard $(OPERATOR_SDK))) 214 | ifeq (, $(shell which operator-sdk 2>/dev/null)) 215 | @{ \ 216 | set -e ;\ 217 | mkdir -p $(dir $(OPERATOR_SDK)) ;\ 218 | OS=$(shell go env GOOS) && ARCH=$(shell go env GOARCH) && \ 219 | curl -sSLo $(OPERATOR_SDK) https://github.com/operator-framework/operator-sdk/releases/download/$(OPERATOR_SDK_VERSION)/operator-sdk_$${OS}_$${ARCH} ;\ 220 | chmod +x $(OPERATOR_SDK) ;\ 221 | } 222 | else 223 | OPERATOR_SDK = $(shell which operator-sdk) 224 | endif 225 | endif 226 | 227 | .PHONY: bundle 228 | bundle: manifests kustomize operator-sdk ## Generate bundle manifests and metadata, then validate generated files. 229 | $(OPERATOR_SDK) generate kustomize manifests -q 230 | cd config/manager && $(KUSTOMIZE) edit set image controller=$(IMG) 231 | $(KUSTOMIZE) build config/manifests | $(OPERATOR_SDK) generate bundle $(BUNDLE_GEN_FLAGS) 232 | $(OPERATOR_SDK) bundle validate ./bundle 233 | 234 | .PHONY: bundle-build 235 | bundle-build: ## Build the bundle image. 236 | docker build -f bundle.Dockerfile -t $(BUNDLE_IMG) . 237 | 238 | .PHONY: bundle-push 239 | bundle-push: ## Push the bundle image. 240 | $(MAKE) docker-push IMG=$(BUNDLE_IMG) 241 | 242 | .PHONY: opm 243 | OPM = ./bin/opm 244 | opm: ## Download opm locally if necessary. 245 | ifeq (,$(wildcard $(OPM))) 246 | ifeq (,$(shell which opm 2>/dev/null)) 247 | @{ \ 248 | set -e ;\ 249 | mkdir -p $(dir $(OPM)) ;\ 250 | OS=$(shell go env GOOS) && ARCH=$(shell go env GOARCH) && \ 251 | curl -sSLo $(OPM) https://github.com/operator-framework/operator-registry/releases/download/v1.23.0/$${OS}-$${ARCH}-opm ;\ 252 | chmod +x $(OPM) ;\ 253 | } 254 | else 255 | OPM = $(shell which opm) 256 | endif 257 | endif 258 | 259 | # A comma-separated list of bundle images (e.g. make catalog-build BUNDLE_IMGS=example.com/operator-bundle:v0.1.0,example.com/operator-bundle:v0.2.0). 260 | # These images MUST exist in a registry and be pull-able. 261 | BUNDLE_IMGS ?= $(BUNDLE_IMG) 262 | 263 | # The image tag given to the resulting catalog image (e.g. make catalog-build CATALOG_IMG=example.com/operator-catalog:v0.2.0). 264 | CATALOG_IMG ?= $(IMAGE_TAG_BASE)-catalog:v$(VERSION) 265 | 266 | # Set CATALOG_BASE_IMG to an existing catalog image tag to add $BUNDLE_IMGS to that image. 
267 | ifneq ($(origin CATALOG_BASE_IMG), undefined) 268 | FROM_INDEX_OPT := --from-index $(CATALOG_BASE_IMG) 269 | endif 270 | 271 | # Build a catalog image by adding bundle images to an empty catalog using the operator package manager tool, 'opm'. 272 | # This recipe invokes 'opm' in 'semver' bundle add mode. For more information on add modes, see: 273 | # https://github.com/operator-framework/community-operators/blob/7f1438c/docs/packaging-operator.md#updating-your-existing-operator 274 | .PHONY: catalog-build 275 | catalog-build: opm ## Build a catalog image. 276 | $(OPM) index add --container-tool docker --mode semver --tag $(CATALOG_IMG) --bundles $(BUNDLE_IMGS) $(FROM_INDEX_OPT) 277 | 278 | # Push the catalog image. 279 | .PHONY: catalog-push 280 | catalog-push: ## Push a catalog image. 281 | $(MAKE) docker-push IMG=$(CATALOG_IMG) 282 | -------------------------------------------------------------------------------- /PROJECT: -------------------------------------------------------------------------------- 1 | # Code generated by tool. DO NOT EDIT. 2 | # This file is used to track the info used to scaffold your project 3 | # and allow the plugins properly work. 4 | # More info: https://book.kubebuilder.io/reference/project-config.html 5 | domain: datatunerx.io 6 | layout: 7 | - go.kubebuilder.io/v4-alpha 8 | multigroup: true 9 | plugins: 10 | manifests.sdk.operatorframework.io/v2: {} 11 | scorecard.sdk.operatorframework.io/v2: {} 12 | projectName: finetune-experiment-controller 13 | repo: github.com/DataTunerX/datatunerx 14 | resources: 15 | - controller: true 16 | domain: datatunerx.io 17 | group: finetune 18 | kind: FinetuneExperiment 19 | version: v1beta1 20 | - controller: true 21 | domain: datatunerx.io 22 | group: finetune 23 | kind: FinetuneJob 24 | version: v1beta1 25 | version: "3" 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![DTX Logo](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/logo/Logo_DataTunerX%20-%20Horizontal%20-%20Color%20Light.png#gh-dark-mode-only) 2 | ![DTX Logo](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/logo/Logo_DataTunerX%20-%20Horizontal%20-%20Color%20Dark.png#gh-light-mode-only) 3 | 4 | ![Kubernetes](https://img.shields.io/badge/kubernetes-%23326ce5.svg?style=flat&logo=kubernetes&logoColor=white) 5 | ![release](https://img.shields.io/badge/version-0.1.0-blue) 6 | ![fine-tuning](https://img.shields.io/badge/fine--tuning-8B3E3) 7 | # Welcome 👋 8 | 9 | ***DataTunerX (DTX)*** is designed as a cloud-native solution integrated with distributed computing frameworks. Leveraging scalable *GPU* resources, it's a platform built for efficient fine-tuning *LLMs* with a focus on practical utility. Its core strength lies in facilitating batch fine-tuning tasks, enabling users to conduct multiple tasks concurrently within a single ***experiment***. ***DTX*** encompasses essential capabilities such as ***dataset management***, ***hyperparameter control***, ***fine-tuning workflows***, ***model management***, ***model evaluation***, ***model comparison inference***, and a ***modular plugin system***. 
10 | 11 | **Technology stack**: 12 | 13 | ***DTX*** is built on cloud-native principles, employing a variety of [*Operators*](https://www.redhat.com/en/topics/containers/what-is-a-kubernetes-operator) that consist of distinct *Custom Resource Definitions (CRDs)* and *Controller* logic. Developed primarily in *Go*, the implementation utilizes the [*operator-sdk*](https://github.com/operator-framework/operator-sdk) toolkit. Operating within a [*Kubernetes (K8s)*](https://github.com/kubernetes/kubernetes) environment, ***DTX*** relies on the operator pattern for *CRD* development and management. Furthermore, ***DTX*** integrates with [*kuberay*](https://github.com/ray-project/kuberay) to harness distributed execution and inference capabilities. 14 | 15 | **Status**: 16 | 17 | *v0.1.0* - Early development phase. See the [CHANGELOG](CHANGELOG.md) for details on recent updates. 18 | 19 | **Quick Demo & More Documentation**: 20 | 21 | - Demo 22 |
23 | *(demo video)* 24 | 25 | 26 |
27 | 28 | - [Documentation](https://github.com/DataTunerX/datatunerx-controller) (COMING SOON) 29 | 30 | **Screenshot**: 31 | 32 | ![**DTX Screenshot**](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/screenshot/Job_Details.png) 33 | 34 | # What Can DTX Do? 💪 35 | 36 | ***DTX*** empowers users with a robust set of features designed for efficient fine-tuning of large language models. Dive into the capabilities that make ***DTX*** a versatile platform: 37 | 38 | ## 1. Dataset Management 🗄️ 39 | Effortlessly manage datasets by supporting both the *S3* protocol (*HTTP* support is coming) and local dataset uploads. Datasets are organized with splits such as test, validation, and training. Additionally, feature mapping enhances flexibility for fine-tuning jobs.
41 | ![FineTune](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetune.png) 42 |
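As a rough illustration, a `Dataset` resource could look like the sketch below. The API group and version are inferred from the controller's imports and the project domain (likely `extension.datatunerx.io/v1beta1`), and every field under `spec` is an illustrative assumption rather than the published schema:

```yaml
# Illustrative sketch only: the spec field names below are assumptions,
# not the actual Dataset schema shipped with DTX.
apiVersion: extension.datatunerx.io/v1beta1   # inferred group/version
kind: Dataset
metadata:
  name: example-dataset
spec:
  source: s3://my-bucket/datasets/example.jsonl   # assumed: S3 location
  splits:                                         # assumed: named splits
    train: train.jsonl
    validation: validation.jsonl
    test: test.jsonl
  featureMapping:                                 # assumed: column mapping
    instruction: question
    response: answer
```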
43 | 44 | ## 2. Fine-Tuning Experiments 🧪 45 | Conduct fine-tuning experiments by creating multiple fine-tuning jobs. Each job can employ different LLMs, datasets, and hyperparameters. Evaluate the fine-tuned models uniformly through the experiment's evaluation unit to identify the best fine-tuning result.
47 | ![FineTune](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetune.png) 48 | ![FineTuneJob](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetunejob.png) 49 | ![FineTuneExperiment](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetuneexperiment.png) 50 |
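`FinetuneExperiment` and `FinetuneJob` are the two CRDs listed in this repository's PROJECT file (group `finetune`, domain `datatunerx.io`, version `v1beta1`), so the `apiVersion`/`kind` below are grounded; the spec fields, however, are assumptions for illustration only, not the published schema:

```yaml
# Minimal sketch of an experiment fanning out into concurrent jobs.
# apiVersion/kind match the project's CRD list; spec fields are assumed.
apiVersion: finetune.datatunerx.io/v1beta1
kind: FinetuneExperiment
metadata:
  name: example-experiment
spec:
  finetuneJobs:                  # assumed: one entry per concurrent job
    - llm: llama2-7b             # assumed reference to an LLM resource
      dataset: example-dataset   # assumed reference to a Dataset resource
      hyperparameter: lora-defaults
    - llm: llama2-7b
      dataset: example-dataset
      hyperparameter: lora-aggressive
  evaluation:                    # assumed: the shared evaluation unit
    metric: bleu
```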
51 | 52 | ## 3. Job Insights 📊 53 | Gain detailed insights into each fine-tuning job within an experiment. Explore job details, logs, and metric visualizations, including learning rate trends, training loss, and more. 54 | 55 | ## 4. Model Repository 🗃️ 56 | Store LLMs in the model repository, facilitating efficient management and deployment of inference services. 57 |
58 | ![FineTune](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetune.png) 59 |
60 | 61 | ## 5. Hyperparameter Group Management 🧰 62 | Utilize a rich parameter configuration system with support for diverse parameters and template-based differentiation. 63 | 64 | ## 6. Inference Services 🚀 65 | Deploy inference services for multiple models simultaneously, enabling straightforward comparison and selection of the best-performing model. 66 | 67 | ## 7. Plugin System 🧩 68 | Leverage the plugin system for datasets and evaluation units, allowing users to integrate specialized datasets and evaluation methods tailored to their unique requirements. 69 | 70 | ## 8. More Coming 🤹‍♀️ 71 | ***DTX*** offers a comprehensive suite of tools, ensuring a seamless fine-tuning experience with flexibility and powerful functionality. Explore each feature to tailor your fine-tuning tasks according to your specific needs. 72 | 73 | # Why DTX? 🤔 74 | 75 | ***DTX*** stands out as the preferred choice for fine-tuning large language models, offering distinct advantages that address critical challenges in natural language processing: 76 | 77 | ## 1. Optimized Resource Utilization 🚀 78 | - **Efficient GPU Integration:** Seamlessly integrates with distributed computing frameworks, ensuring optimal utilization of scalable GPU resources, even in resource-constrained environments. 79 | 80 | ## 2. Streamlined Batch Fine-Tuning 🔄 81 | - **Concurrent Task Execution:** Excels in batch fine-tuning, enabling concurrent execution of multiple tasks within a single experiment. This enhances workflow efficiency and overall productivity. 82 |
83 | ![FineTuneExperiment](https://raw.githubusercontent.com/DataTunerX/datatunerx-controller/main/assets/design/finetuneexperiment.png) 84 |
85 | 86 | ## 3. Robust Feature Set for Varied Needs 🧰 87 | - **Diverse Capabilities:** From dataset management to model management, ***DTX*** provides a comprehensive feature set catering to diverse fine-tuning requirements. 88 | 89 | ## 4. Simplified Experimentation with Lower Entry Barriers 🧪 90 | - **User-Friendly Experimentation:** Empowers users to effortlessly conduct fine-tuning experiments with varying models, datasets, and hyperparameters. This lowers the entry barriers for users with varying skill levels. 91 | 92 | In summary, ***DTX*** strategically addresses challenges in resource optimization, data management, workflow efficiency, and accessibility, making it an ideal solution for efficient natural language processing tasks. 93 | 94 | # Architecture 🏛️ 95 | 96 | This section introduces the architectural design of DataTunerX: an overview of how the system is structured, including details on key components, their interactions, and how they contribute to the system's functionality. 97 | 98 | # Installation 📦 99 | 100 | Detailed instructions on how to install, configure, and run the project are available in the [*INSTALL*](INSTALL.md) document. 101 | 102 | # Usage 🖥️ 103 | 104 | Clear usage instructions, including code snippets where appropriate, are on the way. (COMING SOON) 105 | 106 | # Known issues 🚨 107 | 108 | Significant known shortcomings will be documented here as they are identified. 109 | 110 | # Getting help ❓ 111 | 112 | If you have questions, concerns, or bug reports, please file an issue in this repository's [*Issue Tracker*](https://github.com/DataTunerX/datatunerx-controller/issues). 113 | 114 | # Getting involved 🤝 115 | 116 | We welcome contributions! Check out our [*CONTRIBUTING*](CONTRIBUTING.md) guidelines to get started. Share your feedback, report bugs, or contribute to ongoing discussions. 117 | 118 | ---- 119 | 120 | # Credits and References 🙌 121 | 122 | 1. **Kubernetes (k8s):** 123 | - [*Kubernetes*](https://kubernetes.io/): An open-source container orchestration platform for automating the deployment, scaling, and management of containerized applications. 124 | 125 | 2. **Ray:** 126 | - [*Ray Project*](https://ray.io/): An open-source distributed computing framework that makes it easy to scale and parallelize applications. 127 | 128 | 3. **KubeRay:** 129 | - [*KubeRay*](https://github.com/ray-project/kuberay): An integration of Ray with Kubernetes, enabling efficient distributed computing on Kubernetes clusters. 130 | 131 | 4. **Operator SDK:** 132 | - [*Operator SDK*](https://sdk.operatorframework.io/): A toolkit for building Kubernetes Operators, which are applications that automate the management of custom resources in a Kubernetes cluster. 133 | 134 | 5. **LLaMA-Factory:** 135 | - [*LLaMA-Factory*](https://github.com/hiyouga/LLaMA-Factory): An easy-to-use LLM fine-tuning framework. 136 | 137 | Feel free to explore these projects to deepen your understanding of the technologies and concepts that may have influenced or inspired this project.
138 | -------------------------------------------------------------------------------- /assets/design/batch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/batch.png -------------------------------------------------------------------------------- /assets/design/batchdark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/batchdark.png -------------------------------------------------------------------------------- /assets/design/datasetplugin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/datasetplugin.png -------------------------------------------------------------------------------- /assets/design/datasetplugindark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/datasetplugindark.png -------------------------------------------------------------------------------- /assets/design/evalandinference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/evalandinference.png -------------------------------------------------------------------------------- /assets/design/evaldark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/evaldark.png -------------------------------------------------------------------------------- /assets/design/finetune.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetune.png -------------------------------------------------------------------------------- /assets/design/finetunedark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetunedark.png -------------------------------------------------------------------------------- /assets/design/finetuneexdark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetuneexdark.png -------------------------------------------------------------------------------- /assets/design/finetuneexperiment.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetuneexperiment.png -------------------------------------------------------------------------------- /assets/design/finetunejob.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetunejob.png -------------------------------------------------------------------------------- /assets/design/finetunejobdark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/design/finetunejobdark.png -------------------------------------------------------------------------------- /assets/logo/Logo_DataTunerX - Horizontal - Color Dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/logo/Logo_DataTunerX - Horizontal - Color Dark.png -------------------------------------------------------------------------------- /assets/logo/Logo_DataTunerX - Horizontal - Color Light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/logo/Logo_DataTunerX - Horizontal - Color Light.png -------------------------------------------------------------------------------- /assets/screenshot/Job_Details.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/assets/screenshot/Job_Details.png -------------------------------------------------------------------------------- /cmd/controller-manager/app/controller_manager.go: -------------------------------------------------------------------------------- 1 | package app 2 | 3 | import ( 4 | "fmt" 5 | "os" 6 | 7 | "github.com/DataTunerX/datatunerx/cmd/controller-manager/app/options" 8 | "github.com/DataTunerX/datatunerx/internal/controller/finetune" 9 | "github.com/DataTunerX/datatunerx/pkg/util" 10 | corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1" 11 | extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1" 12 | finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1" 13 | "github.com/DataTunerX/utility-server/logging" 14 | "github.com/go-logr/zapr" 15 | "github.com/open-policy-agent/cert-controller/pkg/rotator" 16 | rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" 17 | "github.com/spf13/pflag" 18 | "k8s.io/apimachinery/pkg/runtime" 19 | "k8s.io/apimachinery/pkg/types" 20 | utilruntime "k8s.io/apimachinery/pkg/util/runtime" 21 | "k8s.io/client-go/kubernetes" 22 | clientgoscheme "k8s.io/client-go/kubernetes/scheme" 23 | _ "k8s.io/client-go/plugin/pkg/client/auth" 24 | ctrl "sigs.k8s.io/controller-runtime" 25 | "sigs.k8s.io/controller-runtime/pkg/healthz" 26 | "sigs.k8s.io/controller-runtime/pkg/manager" 27 | metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server" 28 | "sigs.k8s.io/controller-runtime/pkg/webhook" 29 | //+kubebuilder:scaffold:imports 30 | ) 31 | 32 | const ( 33 | LockName = "datatunerx-lock" 34 | SecretName = "datatunerx-cert" 35 | CaName = "datatunerx-ca" 36 | CaOrganization = "datatunerx" 37 | ServiceName = "datatunerx" 38 | ) 39 | 40 | var ( 41 | scheme = runtime.NewScheme() 42 | ) 43 | 44 | func init() { 45 | utilruntime.Must(clientgoscheme.AddToScheme(scheme)) 46 | utilruntime.Must(finetunev1beta1.AddToScheme(scheme)) 47 | utilruntime.Must(corev1beta1.AddToScheme(scheme)) 48 | 
utilruntime.Must(extensionv1beta1.AddToScheme(scheme)) 49 | utilruntime.Must(rayv1.AddToScheme(scheme)) 50 | //+kubebuilder:scaffold:scheme 51 | } 52 | 53 | func NewControllerManager() (manager.Manager, error) { 54 | opts := options.NewOptions() 55 | flagSet := pflag.NewFlagSet("generic", pflag.ExitOnError) 56 | opts.AddFlags(flagSet) 57 | err := flagSet.Parse(os.Args[1:]) 58 | if err != nil { 59 | logging.ZLogger.Errorf("Error parsing flags: %v", err) 60 | os.Exit(1) 61 | } 62 | logging.ZLogger.Info("Set logger for controller") 63 | ctrl.SetLogger(zapr.NewLogger(logging.ZLogger.GetLogger())) 64 | namespace := util.GetOperatorNamespace() 65 | ctrOption := ctrl.Options{ 66 | Scheme: scheme, 67 | Metrics: metricsserver.Options{ 68 | BindAddress: opts.MetricsAddr, 69 | }, 70 | WebhookServer: webhook.NewServer(webhook.Options{Port: 9443}), 71 | HealthProbeBindAddress: opts.ProbeAddr, 72 | LeaderElection: true, 73 | LeaderElectionID: LockName, 74 | LeaderElectionNamespace: namespace, 75 | } 76 | 77 | mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrOption) 78 | if err != nil { 79 | logging.ZLogger.Errorf("Build controller manager failed: %v", err) 80 | return nil, err 81 | } 82 | setupFinished := make(chan struct{}) 83 | if opts.EnableCertRotator { 84 | logging.ZLogger.Info("Setting up cert rotation") 85 | if err := rotator.AddRotator(mgr, &rotator.CertRotator{ 86 | SecretKey: types.NamespacedName{ 87 | Namespace: namespace, 88 | Name: SecretName, 89 | }, 90 | CAName: CaName, 91 | CAOrganization: CaOrganization, 92 | CertDir: "/tmp/k8s-webhook-server/serving-certs", 93 | DNSName: fmt.Sprintf("%s.%s.svc", ServiceName, namespace), 94 | IsReady: setupFinished, 95 | Webhooks: []rotator.WebhookInfo{ 96 | { 97 | Name: namespace + "-validating-webhook-configuration", 98 | Type: rotator.Validating, 99 | }, 100 | { 101 | Name: namespace + "-mutating-webhook-configuration", 102 | Type: rotator.Mutating, 103 | }, 104 | }, 105 | }); err != nil { 106 | logging.ZLogger.Errorf("Unable to set up cert rotation, %v", err) 107 | os.Exit(1) 108 | } 109 | } else { 110 | close(setupFinished) 111 | } 112 | go func() { 113 | <-setupFinished 114 | if err := (&finetunev1beta1.FinetuneJob{}).SetupWebhookWithManager(mgr); err != nil { 115 | logging.ZLogger.Errorf("Unable to create webhook, %v", err) 116 | os.Exit(1) 117 | 118 | } 119 | if err := (&finetunev1beta1.FinetuneExperiment{}).SetupWebhookWithManager(mgr); err != nil { 120 | logging.ZLogger.Errorf("Unable to create webhook, %v", err) 121 | os.Exit(1) 122 | } 123 | if err := (&corev1beta1.LLM{}).SetupWebhookWithManager(mgr); err != nil { 124 | logging.ZLogger.Errorf("Unable to create webhook, %v", err) 125 | os.Exit(1) 126 | } 127 | if err := (&corev1beta1.Hyperparameter{}).SetupWebhookWithManager(mgr); err != nil { 128 | logging.ZLogger.Errorf("Unable to create webhook, %v", err) 129 | os.Exit(1) 130 | } 131 | if err := (&extensionv1beta1.Dataset{}).SetupWebhookWithManager(mgr); err != nil { 132 | logging.ZLogger.Errorf("Unable to create webhook, %v", err) 133 | os.Exit(1) 134 | } 135 | }() 136 | 137 | if err = (&finetune.FinetuneExperimentReconciler{ 138 | Client: mgr.GetClient(), 139 | Scheme: mgr.GetScheme(), 140 | Log: logging.ZLogger, 141 | }).SetupWithManager(mgr); err != nil { 142 | logging.ZLogger.Errorf("Unable to create FinetuneExperiment controller, %v", err) 143 | return nil, err 144 | } 145 | if err = (&finetune.FinetuneJobReconciler{ 146 | Client: mgr.GetClient(), 147 | Scheme: mgr.GetScheme(), 148 | Log: logging.ZLogger, 149 | 
}).SetupWithManager(mgr); err != nil { 150 | logging.ZLogger.Errorf("Unable to create FinetuneJob controller, %v", err) 151 | return nil, err 152 | } 153 | if err = (&finetune.FinetuneReconciler{ 154 | Log: logging.ZLogger, 155 | Client: mgr.GetClient(), 156 | Scheme: mgr.GetScheme(), 157 | Clientset: kubernetes.NewForConfigOrDie(ctrl.GetConfigOrDie()), 158 | Config: ctrl.GetConfigOrDie(), 159 | }).SetupWithManager(mgr); err != nil { 160 | logging.ZLogger.Errorf("Unable to create Finetune controller, %v", err) 161 | return nil, err 162 | } 163 | //+kubebuilder:scaffold:builder 164 | 165 | if err := mgr.AddHealthzCheck("healthz", healthz.Ping); err != nil { 166 | logging.ZLogger.Errorf("Unable to set up health check: %v", err) 167 | return nil, err 168 | } 169 | if err := mgr.AddReadyzCheck("readyz", healthz.Ping); err != nil { 170 | logging.ZLogger.Errorf("Unable to set up ready check: %v", err) 171 | return nil, err 172 | } 173 | 174 | return mgr, nil 175 | } 176 | -------------------------------------------------------------------------------- /cmd/controller-manager/app/options/options.go: -------------------------------------------------------------------------------- 1 | package options 2 | 3 | import ( 4 | "time" 5 | 6 | "github.com/spf13/pflag" 7 | ) 8 | 9 | const ( 10 | defaultLeaseDuration = 10 * time.Second 11 | defaultRenewDeadline = 10 * time.Second 12 | defaultRetryPeriod = 10 * time.Second 13 | defaultMetricsAddr = ":8080" 14 | defaultProbeAddr = ":8081" 15 | defaultNamespace = "datatunerx-dev" 16 | defaultCertRotator = true 17 | ) 18 | 19 | type Options struct { 20 | LeaderElectLeaseConfig LeaderElectLeaseConfig 21 | MetricsAddr string 22 | ProbeAddr string 23 | EnableCertRotator bool 24 | } 25 | 26 | type LeaderElectLeaseConfig struct { 27 | LeaseDuration time.Duration 28 | RenewDeadline time.Duration 29 | RetryPeriod time.Duration 30 | } 31 | 32 | func NewOptions() *Options { 33 | return &Options{ 34 | LeaderElectLeaseConfig: LeaderElectLeaseConfig{}, 35 | } 36 | } 37 | 38 | func (o *Options) AddFlags(fs *pflag.FlagSet) { 39 | if o == nil { 40 | return 41 | } 42 | fs.StringVar(&o.MetricsAddr, "metrics-bind-address", defaultMetricsAddr, "The address the metric endpoint binds to.") 43 | fs.StringVar(&o.ProbeAddr, "health-probe-bind-address", defaultProbeAddr, "The address the probe endpoint binds to.") 44 | fs.DurationVar(&o.LeaderElectLeaseConfig.LeaseDuration, "lease-duration", defaultLeaseDuration, "The duration that non-leader candidates will wait after observing a leadership renewal until attempting to acquire leadership of a led but unrenewed group.") 45 | fs.DurationVar(&o.LeaderElectLeaseConfig.RenewDeadline, "renew-deadline", defaultRenewDeadline, "Duration the clients should wait between attempting to renew the lease of the lock.") 46 | fs.DurationVar(&o.LeaderElectLeaseConfig.RetryPeriod, "retry-period", defaultRetryPeriod, "The time duration for the client to wait between attempts of acquiring a lock.") 47 | fs.BoolVar(&o.EnableCertRotator, "cert-rotator", defaultCertRotator, "Automatically apply for a certificate for Webhooks.") 48 | } 49 | -------------------------------------------------------------------------------- /cmd/tuning/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM rayproject/ray271-py39-gpu-llama2-7b-inference:20231220 2 | 3 | WORKDIR /tuning 4 | 5 | COPY requirements.txt . 6 | RUN pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple 7 | 8 | COPY . . 
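# Note: no ENTRYPOINT or CMD is set here, so containers keep the base
# image's default entrypoint unless a command is supplied at runtime
# (assumption: the actual training command, e.g. launching train.py, is
# injected by the controller when it creates the Ray job).
# build_image.sh below shows how this image is built and tagged.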
--------------------------------------------------------------------------------
/cmd/tuning/README.md:
--------------------------------------------------------------------------------
1 | # Dataset
2 | 
3 | The training scripts expect a CSV dataset with `instruction` and `response` columns;
4 | other column names can be remapped via the `--columns` JSON argument of `train.py`.
--------------------------------------------------------------------------------
/cmd/tuning/build_image.sh:
--------------------------------------------------------------------------------
1 | docker build . -t rayproject/ray271-llama2-7b-finetune:20231220
--------------------------------------------------------------------------------
/cmd/tuning/callback.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | from typing import TYPE_CHECKING
5 | from datetime import timedelta
6 | 
7 | from transformers import TrainerCallback
8 | from transformers.trainer_utils import has_length, PREFIX_CHECKPOINT_DIR
9 | 
10 | from prometheus.metrics import export_train_metrics, export_eval_metrics
11 | 
12 | if TYPE_CHECKING:
13 |     from transformers import TrainingArguments, TrainerState, TrainerControl
14 | 
15 | LOG_FILE_NAME = "trainer_log.jsonl"
16 | 
17 | 
18 | class LogCallback(TrainerCallback):
19 | 
20 |     def __init__(self, runner=None, metrics_export_address=None, uid=None):
21 |         self.runner = runner
22 |         self.in_training = False
23 |         self.start_time = time.time()
24 |         self.cur_steps = 0
25 |         self.max_steps = 0
26 |         self.elapsed_time = ""
27 |         self.remaining_time = ""
28 |         self.metrics_export_address = metrics_export_address
29 |         self.uid = uid
30 | 
31 |     def timing(self):
32 |         cur_time = time.time()
33 |         elapsed_time = cur_time - self.start_time
34 |         avg_time_per_step = elapsed_time / self.cur_steps if self.cur_steps != 0 else 0
35 |         remaining_time = (self.max_steps - self.cur_steps) * avg_time_per_step
36 |         self.elapsed_time = str(timedelta(seconds=int(elapsed_time)))
37 |         self.remaining_time = str(timedelta(seconds=int(remaining_time)))
38 | 
39 |     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
40 |         r"""
41 |         Event called at the beginning of training.
42 |         """
43 |         if state.is_local_process_zero:
44 |             self.in_training = True
45 |             self.start_time = time.time()
46 |             self.max_steps = state.max_steps
47 |             if os.path.exists(os.path.join(args.output_dir, LOG_FILE_NAME)) and args.overwrite_output_dir:
48 |                 print("Previous log file in this folder will be deleted.")
49 |                 os.remove(os.path.join(args.output_dir, LOG_FILE_NAME))
50 | 
51 |     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
52 |         r"""
53 |         Event called at the end of training.
54 |         """
55 |         if state.is_local_process_zero:
56 |             self.in_training = False
57 |             self.cur_steps = 0
58 |             self.max_steps = 0
59 | 
60 |     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
61 |         r"""
62 |         Event called at the end of a substep during gradient accumulation.
63 |         """
64 |         if state.is_local_process_zero and self.runner is not None and self.runner.aborted:
65 |             control.should_epoch_stop = True
66 |             control.should_training_stop = True
67 | 
68 |     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
69 |         r"""
70 |         Event called at the end of a training step.
71 |         """
72 |         if state.is_local_process_zero:
73 |             self.cur_steps = state.global_step
74 |             self.timing()
75 |             if self.runner is not None and self.runner.aborted:
76 |                 control.should_epoch_stop = True
77 |                 control.should_training_stop = True
78 | 
79 |     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
80 |         r"""
81 |         Event called after an evaluation phase.
82 |         """
83 |         if state.is_local_process_zero and not self.in_training:
84 |             self.cur_steps = 0
85 |             self.max_steps = 0
86 | 
87 |     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", *other, **kwargs):
88 |         r"""
89 |         Event called after a successful prediction.
90 |         """
91 |         if state.is_local_process_zero and not self.in_training:
92 |             self.cur_steps = 0
93 |             self.max_steps = 0
94 | 
95 |     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs) -> None:
96 |         r"""
97 |         Event called after logging the last logs.
98 |         """
99 |         if not state.is_local_process_zero:
100 |             return
101 | 
102 |         print('log_history: ', state.log_history[-1])  # debug: inspect the keys of the latest log entry
103 |         if "eval_loss" in state.log_history[-1].keys():
104 |             eval_log = dict(
105 |                 uid=self.uid,
106 |                 current_steps=self.cur_steps,
107 |                 total_steps=self.max_steps,
108 |                 eval_loss=state.log_history[-1].get("eval_loss", None),
109 |                 eval_perplexity=state.log_history[-1].get("eval_perplexity", None),
110 |                 eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None),
111 |                 eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None),
112 |                 eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None),
113 |                 eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None),
114 |                 epoch=state.log_history[-1].get("epoch", None),
115 |                 percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
116 |                 elapsed_time=self.elapsed_time,
117 |                 remaining_time=self.remaining_time
118 |             )
119 |         else:
120 |             logs = dict(
121 |                 uid=self.uid,
122 |                 current_steps=self.cur_steps,
123 |                 total_steps=self.max_steps,
124 |                 loss=state.log_history[-1].get("loss", None),
125 |                 eval_loss=state.log_history[-1].get("eval_loss", None),
126 |                 val_perplexity=state.log_history[-1].get("eval_perplexity", None),
127 |                 eval_rouge_1=state.log_history[-1].get("eval_rouge-1", None),
128 |                 eval_rouge_2=state.log_history[-1].get("eval_rouge-2", None),
129 |                 eval_rouge_l=state.log_history[-1].get("eval_rouge-l", None),
130 |                 eval_bleu_4=state.log_history[-1].get("eval_bleu-4", None),
131 |                 predict_loss=state.log_history[-1].get("predict_loss", None),
132 |                 reward=state.log_history[-1].get("reward", None),
133 |                 learning_rate=state.log_history[-1].get("learning_rate", None),
134 |                 epoch=state.log_history[-1].get("epoch", None),
135 |                 percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
136 |                 elapsed_time=self.elapsed_time,
137 |                 remaining_time=self.remaining_time
138 |             )
139 |             if self.runner is not None:  # `logs` only exists in this branch, so print here to avoid a NameError
140 |                 print("{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}}}".format(
141 |                     logs["loss"] or 0, logs["learning_rate"] or 0, logs["epoch"] or 0
142 |                 ))
143 | 
144 |         os.makedirs(args.output_dir, exist_ok=True)
145 |         os.makedirs(os.path.join(args.output_dir, 'watch'), exist_ok=True)
146 |         if "eval_loss" in state.log_history[-1].keys():
147 |             with open(os.path.join(args.output_dir, 'watch', "eval_log.jsonl"), "a", encoding="utf-8") as f:
148 |                 f.write(json.dumps(eval_log) + "\n")
149 |             if self.metrics_export_address:
150 |                 export_eval_metrics(self.metrics_export_address, eval_log)
151 |         else:
152 |             with open(os.path.join(args.output_dir, 'watch', "trainer_log.jsonl"), "a", encoding="utf-8") as f:
153 |                 f.write(json.dumps(logs) + "\n")
154 |             if self.metrics_export_address:
155 |                 export_train_metrics(self.metrics_export_address, logs)
156 | 
157 |     def on_prediction_step(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
158 |         r"""
159 |         Event called after a prediction step.
160 |         """
161 |         eval_dataloader = kwargs.pop("eval_dataloader", None)
162 |         if state.is_local_process_zero and has_length(eval_dataloader) and not self.in_training:
163 |             if self.max_steps == 0:
164 |                 self.max_steps = len(eval_dataloader)
165 |             self.cur_steps += 1
166 |             self.timing()
167 | 
--------------------------------------------------------------------------------
/cmd/tuning/ds_config.json:
--------------------------------------------------------------------------------
1 | {
2 |     "train_batch_size": "auto",
3 |     "train_micro_batch_size_per_gpu": "auto",
4 |     "gradient_accumulation_steps": "auto",
5 |     "gradient_clipping": "auto",
6 |     "zero_allow_untested_optimizer": true,
7 |     "fp16": {
8 |         "enabled": "auto"
9 |     },
10 |     "zero_optimization": {
11 |         "stage": 0
12 |     }
13 | }
--------------------------------------------------------------------------------
/cmd/tuning/parser.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | from dataclasses import field, dataclass
4 | from typing import Optional, Dict, Any, Tuple, Literal
5 | 
6 | import transformers
7 | from transformers import Seq2SeqTrainingArguments, HfArgumentParser
8 | 
9 | logger = logging.getLogger(__name__)
10 | 
11 | 
12 | @dataclass
13 | class ModelArguments:
14 |     r"""
15 |     Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
16 | """ 17 | model_name_or_path: str = field( 18 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 19 | ) 20 | cache_dir: Optional[str] = field( 21 | default=None, 22 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 23 | ) 24 | use_fast_tokenizer: Optional[bool] = field( 25 | default=True, 26 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 27 | ) 28 | split_special_tokens: Optional[bool] = field( 29 | default=False, 30 | metadata={"help": "Whether or not the special tokens should be split during the tokenization process."} 31 | ) 32 | use_auth_token: Optional[bool] = field( 33 | default=False, 34 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 35 | ) 36 | model_revision: Optional[str] = field( 37 | default="main", 38 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 39 | ) 40 | quantization_bit: Optional[int] = field( 41 | default=None, 42 | metadata={"help": "The number of bits to quantize the model."} 43 | ) 44 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 45 | default="nf4", 46 | metadata={"help": "Quantization data type to use in int4 training."} 47 | ) 48 | double_quantization: Optional[bool] = field( 49 | default=True, 50 | metadata={"help": "Whether to use double quantization in int4 training or not."} 51 | ) 52 | quantization: Optional[str] = field( 53 | default=None, 54 | metadata={"help": "quantize the model, int4, int8, or None."} 55 | ) 56 | 57 | rope_scaling: Optional[Literal["linear", "dynamic"]] = field( 58 | default=None, 59 | metadata={"help": "Adopt scaled rotary positional embeddings."} 60 | ) 61 | checkpoint_dir: Optional[str] = field( 62 | default=None, 63 | metadata={ 64 | "help": "Path to the directory(s) containing the delta model checkpoints as well as the configurations."} 65 | ) 66 | flash_attn: Optional[bool] = field( 67 | default=False, 68 | metadata={"help": "Enable FlashAttention-2 for faster training."} 69 | ) 70 | shift_attn: Optional[bool] = field( 71 | default=False, 72 | metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."} 73 | ) 74 | reward_model: Optional[str] = field( 75 | default=None, 76 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 77 | ) 78 | plot_loss: Optional[bool] = field( 79 | default=False, 80 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 81 | ) 82 | hf_auth_token: Optional[str] = field( 83 | default=None, 84 | metadata={"help": "Auth token to log in with Hugging Face Hub."} 85 | ) 86 | export_dir: Optional[str] = field( 87 | default=None, 88 | metadata={"help": "Path to the directory to save the exported model."} 89 | ) 90 | 91 | def __post_init__(self): 92 | self.compute_dtype = None 93 | self.model_max_length = None 94 | 95 | if self.split_special_tokens and self.use_fast_tokenizer: 96 | raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") 97 | 98 | if self.checkpoint_dir is not None: # support merging multiple lora weights 99 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] 100 | 101 | if self.quantization_bit is not None: 102 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
103 | 
104 |         if self.quantization is not None:
105 |             assert self.quantization in ["int4", "int8"], "We only accept int4 or int8 quantization."
106 | 
107 |         if self.use_auth_token and self.hf_auth_token is not None:
108 |             from huggingface_hub.hf_api import HfFolder  # lazy load
109 |             HfFolder.save_token(self.hf_auth_token)
110 | 
111 | 
112 | @dataclass
113 | class FinetuningArguments:
114 |     r"""
115 |     Arguments pertaining to which techniques we are going to fine-tune with.
116 |     """
117 |     stage: Optional[Literal["pt", "sft", "rm", "ppo", "dpo"]] = field(
118 |         default="sft",
119 |         metadata={"help": "Which stage will be performed in training."}
120 |     )
121 |     finetuning_type: Optional[Literal["lora", "freeze", "full", "none"]] = field(
122 |         default="lora",
123 |         metadata={"help": "Which fine-tuning method to use."}
124 |     )
125 |     num_layer_trainable: Optional[int] = field(
126 |         default=3,
127 |         metadata={"help": "Number of trainable layers for partial-parameter (freeze) fine-tuning."}
128 |     )
129 |     name_module_trainable: Optional[Literal["mlp", "self_attn", "self_attention"]] = field(
130 |         default="mlp",
131 |         metadata={"help": "Name of trainable modules for partial-parameter (freeze) fine-tuning. \
132 |                   LLaMA choices: [\"mlp\", \"self_attn\"], \
133 |                   BLOOM & Falcon & ChatGLM2 choices: [\"mlp\", \"self_attention\"], \
134 |                   Qwen choices: [\"mlp\", \"attn\"], \
135 |                   Phi-1.5 choices: [\"mlp\", \"mixer\"], \
136 |                   LLaMA-2, Baichuan, InternLM, XVERSE choices: the same as LLaMA."}
137 |     )
138 |     lora_rank: Optional[int] = field(
139 |         default=8,
140 |         metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
141 |     )
142 |     lora_alpha: Optional[float] = field(
143 |         default=32.0,
144 |         metadata={"help": "The scale factor for LoRA fine-tuning (similar to the learning rate)."}
145 |     )
146 |     lora_dropout: Optional[float] = field(
147 |         default=0.1,
148 |         metadata={"help": "Dropout rate for the LoRA fine-tuning."}
149 |     )
150 |     lora_target: Optional[str] = field(
151 |         default=None,
152 |         metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules. \
153 |                   LLaMA choices: [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
154 |                   BLOOM & Falcon & ChatGLM2 choices: [\"query_key_value\", \"self_attention.dense\", \"mlp.dense\"], \
155 |                   Baichuan choices: [\"W_pack\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"], \
156 |                   Qwen choices: [\"c_attn\", \"attn.c_proj\", \"w1\", \"w2\", \"mlp.c_proj\"], \
157 |                   Phi-1.5 choices: [\"Wqkv\", \"out_proj\", \"fc1\", \"fc2\"], \
158 |                   LLaMA-2, InternLM, XVERSE choices: the same as LLaMA."}
159 |     )
160 |     additional_target: Optional[str] = field(
161 |         default=None,
162 |         metadata={
163 |             "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."}
164 |     )
165 |     resume_lora_training: Optional[bool] = field(
166 |         default=True,
167 |         metadata={
168 |             "help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
169 |     )
170 |     ppo_score_norm: Optional[bool] = field(
171 |         default=False,
172 |         metadata={"help": "Use score normalization in PPO training."}
173 |     )
174 |     ppo_logger: Optional[str] = field(
175 |         default=None,
176 |         metadata={"help": "Log with either 'wandb' or 'tensorboard' in PPO training."}
177 |     )
178 |     ppo_target: Optional[float] = field(
179 |         default=6.0,
180 |         metadata={"help": "Target KL value for adaptive KL control in PPO training."}
181 |     )
182 |     dpo_beta: Optional[float] = field(
183 |         default=0.1,
184 |         metadata={"help": "The beta parameter for the DPO loss."}
185 |     )
186 |     upcast_layernorm: Optional[bool] = field(
187 |         default=False,
188 |         metadata={"help": "Whether to upcast the layernorm weights to fp32."}
189 |     )
190 |     neft_alpha: Optional[float] = field(
191 |         default=0,
192 |         metadata={"help": "The alpha parameter to control the noise magnitude in NEFTune."}
193 |     )
194 |     num_workers: Optional[int] = field(
195 |         default=1,
196 |         metadata={"help": "Number of Ray training workers."}
197 |     )
198 |     storage_path: Optional[str] = field(
199 |         default=None,
200 |         metadata={"help": "Path used to store training checkpoints."}
201 |     )
202 |     metrics_export_address: Optional[str] = field(
203 |         default=None,
204 |         metadata={"help": "Address to which training metrics are exported."}
205 |     )
206 |     uid: Optional[str] = field(
207 |         default=None,
208 |         metadata={"help": "UID of the Finetune custom resource."}
209 |     )
210 | 
211 |     def __post_init__(self):
212 |         if isinstance(self.lora_target, str):  # support custom target modules/layers of LoRA
213 |             self.lora_target = [target.strip() for target in self.lora_target.split(",")]
214 | 
215 |         if isinstance(self.additional_target, str):
216 |             self.additional_target = [target.strip() for target in self.additional_target.split(",")]
217 | 
218 |         assert self.finetuning_type in ["lora", "freeze", "full", "none"], "Invalid fine-tuning method."
219 | 220 | if not self.storage_path: 221 | raise ValueError("--storage_path must be specified") 222 | 223 | 224 | @dataclass 225 | class DataArguments: 226 | train_path: Optional[str] = field( 227 | default=None, 228 | metadata={"help": "Path to train dataset"} 229 | ) 230 | 231 | evaluation_path: Optional[str] = field( 232 | default=None, 233 | metadata={"help": "Path to evaluation dataset"} 234 | ) 235 | 236 | columns: Optional[str] = field( 237 | default=None, 238 | metadata={"help": "columns map for dataset"} 239 | ) 240 | block_size: Optional[int] = field( 241 | default=1024, 242 | metadata={"help": "length of input."} 243 | ) 244 | 245 | def __post_init__(self): 246 | if self.train_path is None: 247 | raise ValueError("--train_path must be specified") 248 | 249 | 250 | def get_train_args() -> Tuple[Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments]: 251 | parser = HfArgumentParser((Seq2SeqTrainingArguments, FinetuningArguments, ModelArguments, DataArguments)) 252 | 253 | training_args, finetuning_args, model_args, data_args = parser.parse_args_into_dataclasses() 254 | 255 | training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning 256 | 257 | # Log on each process the small summary: 258 | logger.info( 259 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" 260 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 261 | ) 262 | 263 | # Set seed before initializing model. 264 | transformers.set_seed(training_args.seed) 265 | 266 | return training_args, finetuning_args, model_args, data_args 267 | -------------------------------------------------------------------------------- /cmd/tuning/prometheus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataTunerX/datatunerx/508be30c13e401fa8c67a69ec33b2a25661631d2/cmd/tuning/prometheus/__init__.py -------------------------------------------------------------------------------- /cmd/tuning/prometheus/metrics.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | from datetime import datetime 3 | from urllib.parse import urljoin 4 | from .prometheus_pb2 import ( 5 | WriteRequest, 6 | TimeSeries 7 | ) 8 | import calendar 9 | import logging 10 | import requests 11 | import snappy 12 | 13 | 14 | def dt2ts(dt): 15 | """Converts a datetime object to UTC timestamp 16 | naive datetime will be considered UTC. 
17 | """ 18 | return calendar.timegm(dt.utctimetuple()) 19 | 20 | 21 | def write(address: str, series: List[TimeSeries]): 22 | write_request = WriteRequest() 23 | write_request.timeseries.extend(series) 24 | 25 | uncompressed = write_request.SerializeToString() 26 | compressed = snappy.compress(uncompressed) 27 | 28 | url = urljoin(address, "/api/v1/write") 29 | headers = { 30 | "Content-Encoding": "snappy", 31 | "Content-Type": "application/x-protobuf", 32 | "X-Prometheus-Remote-Write-Version": "0.1.0", 33 | "User-Agent": "metrics-worker" 34 | } 35 | try: 36 | response = requests.post(url, headers=headers, data=compressed) 37 | print(response) 38 | except Exception as e: 39 | print(e) 40 | 41 | 42 | def export_train_metrics(address: str, metrics: Dict): 43 | series = TimeSeries() 44 | label = series.labels.add() 45 | label.name = "__name__" 46 | label.value = "train_metrics" 47 | 48 | label = series.labels.add() 49 | label.name = "uid" 50 | label.value = str(metrics["uid"]) 51 | 52 | label = series.labels.add() 53 | label.name = "total_steps" 54 | label.value = str(metrics.get("total_steps", "")) 55 | 56 | label = series.labels.add() 57 | label.name = "current_steps" 58 | label.value = str(metrics.get("current_steps", "")) 59 | 60 | label = series.labels.add() 61 | label.name = "loss" 62 | label.value = str(metrics.get("loss", "")) 63 | 64 | label = series.labels.add() 65 | label.name = "learning_rate" 66 | label.value = str(metrics.get("learning_rate", "")) 67 | 68 | label = series.labels.add() 69 | label.name = "epoch" 70 | label.value = str(metrics.get("epoch", "")) 71 | 72 | sample = series.samples.add() 73 | sample.value = 1 74 | sample.timestamp = dt2ts(datetime.utcnow()) * 1000 75 | 76 | write(address, [series]) 77 | 78 | 79 | def export_eval_metrics(address: str, metrics: Dict): 80 | series = TimeSeries() 81 | label = series.labels.add() 82 | label.name = "__name__" 83 | label.value = "eval_metrics" 84 | 85 | label = series.labels.add() 86 | label.name = "uid" 87 | label.value = str(metrics["uid"]) 88 | 89 | label = series.labels.add() 90 | label.name = "total_steps" 91 | label.value = str(metrics.get("total_steps", "")) 92 | 93 | label = series.labels.add() 94 | label.name = "current_steps" 95 | label.value = str(metrics.get("current_steps", "")) 96 | 97 | label = series.labels.add() 98 | label.name = "eval_loss" 99 | label.value = str(metrics.get("eval_loss", "")) 100 | 101 | label = series.labels.add() 102 | label.name = "eval_perplexity" 103 | label.value = str(metrics.get("eval_perplexity", "")) 104 | 105 | label = series.labels.add() 106 | label.name = "epoch" 107 | label.value = str(metrics.get("epoch", "")) 108 | 109 | sample = series.samples.add() 110 | sample.value = 1 111 | sample.timestamp = dt2ts(datetime.utcnow()) * 1000 112 | 113 | write(address, [series]) 114 | 115 | 116 | if __name__ == '__main__': 117 | train_metrics = { 118 | "uid": "1", 119 | "current_steps": 10, 120 | "total_steps": 84, 121 | "loss": 3.088, 122 | "learning_rate": 4.404761904761905e-05, 123 | "epoch": 0.71 124 | } 125 | export_train_metrics("http://10.33.1.10:30722", train_metrics) 126 | 127 | eval_metrics = { 128 | "uid": "1", 129 | "current_steps": 10, 130 | "total_steps": 84, 131 | "eval_loss": 3.088, 132 | "eval_perplexity": 4.404761904761905e-05, 133 | "epoch": 0.71 134 | } 135 | export_eval_metrics("http://10.33.1.10:30722", eval_metrics) 136 | -------------------------------------------------------------------------------- /cmd/tuning/prometheus/prometheus.proto: 
-------------------------------------------------------------------------------- 1 | // Copyright 2016 Prometheus Team 2 | // Licensed under the Apache License, Version 2.0 (the "License"); 3 | // you may not use this file except in compliance with the License. 4 | // You may obtain a copy of the License at 5 | // 6 | // http://www.apache.org/licenses/LICENSE-2.0 7 | // 8 | // Unless required by applicable law or agreed to in writing, software 9 | // distributed under the License is distributed on an "AS IS" BASIS, 10 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | // See the License for the specific language governing permissions and 12 | // limitations under the License. 13 | 14 | syntax = "proto3"; 15 | package prometheus; 16 | 17 | option go_package = "prompb"; 18 | 19 | message WriteRequest { 20 | repeated prometheus.TimeSeries timeseries = 1; 21 | } 22 | 23 | message ReadRequest { 24 | repeated Query queries = 1; 25 | } 26 | 27 | message ReadResponse { 28 | // In same order as the request's queries. 29 | repeated QueryResult results = 1; 30 | } 31 | 32 | message Query { 33 | int64 start_timestamp_ms = 1; 34 | int64 end_timestamp_ms = 2; 35 | repeated prometheus.LabelMatcher matchers = 3; 36 | prometheus.ReadHints hints = 4; 37 | } 38 | 39 | message QueryResult { 40 | // Samples within a time series must be ordered by time. 41 | repeated prometheus.TimeSeries timeseries = 1; 42 | } 43 | 44 | message Sample { 45 | double value = 1; 46 | int64 timestamp = 2; 47 | } 48 | 49 | message TimeSeries { 50 | repeated Label labels = 1; 51 | repeated Sample samples = 2; 52 | } 53 | 54 | message Label { 55 | string name = 1; 56 | string value = 2; 57 | } 58 | 59 | message Labels { 60 | repeated Label labels = 1; 61 | } 62 | 63 | // Matcher specifies a rule, which can match or set of labels or not. 64 | message LabelMatcher { 65 | enum Type { 66 | EQ = 0; 67 | NEQ = 1; 68 | RE = 2; 69 | NRE = 3; 70 | } 71 | Type type = 1; 72 | string name = 2; 73 | string value = 3; 74 | } 75 | 76 | message ReadHints { 77 | int64 step_ms = 1; // Query step size in milliseconds. 78 | string func = 2; // String representation of surrounding function or aggregation. 79 | int64 start_ms = 3; // Start time in milliseconds. 80 | int64 end_ms = 4; // End time in milliseconds. 81 | } -------------------------------------------------------------------------------- /cmd/tuning/prometheus/prometheus_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT! 
3 | # source: prometheus.proto 4 | """Generated protocol buffer code.""" 5 | from google.protobuf import descriptor as _descriptor 6 | from google.protobuf import descriptor_pool as _descriptor_pool 7 | from google.protobuf import message as _message 8 | from google.protobuf import reflection as _reflection 9 | from google.protobuf import symbol_database as _symbol_database 10 | # @@protoc_insertion_point(imports) 11 | 12 | _sym_db = _symbol_database.Default() 13 | 14 | 15 | 16 | 17 | DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10prometheus.proto\x12\nprometheus\":\n\x0cWriteRequest\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"1\n\x0bReadRequest\x12\"\n\x07queries\x18\x01 \x03(\x0b\x32\x11.prometheus.Query\"8\n\x0cReadResponse\x12(\n\x07results\x18\x01 \x03(\x0b\x32\x17.prometheus.QueryResult\"\x8f\x01\n\x05Query\x12\x1a\n\x12start_timestamp_ms\x18\x01 \x01(\x03\x12\x18\n\x10\x65nd_timestamp_ms\x18\x02 \x01(\x03\x12*\n\x08matchers\x18\x03 \x03(\x0b\x32\x18.prometheus.LabelMatcher\x12$\n\x05hints\x18\x04 \x01(\x0b\x32\x15.prometheus.ReadHints\"9\n\x0bQueryResult\x12*\n\ntimeseries\x18\x01 \x03(\x0b\x32\x16.prometheus.TimeSeries\"*\n\x06Sample\x12\r\n\x05value\x18\x01 \x01(\x01\x12\x11\n\ttimestamp\x18\x02 \x01(\x03\"T\n\nTimeSeries\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\x12#\n\x07samples\x18\x02 \x03(\x0b\x32\x12.prometheus.Sample\"$\n\x05Label\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"+\n\x06Labels\x12!\n\x06labels\x18\x01 \x03(\x0b\x32\x11.prometheus.Label\"\x82\x01\n\x0cLabelMatcher\x12+\n\x04type\x18\x01 \x01(\x0e\x32\x1d.prometheus.LabelMatcher.Type\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\r\n\x05value\x18\x03 \x01(\t\"(\n\x04Type\x12\x06\n\x02\x45Q\x10\x00\x12\x07\n\x03NEQ\x10\x01\x12\x06\n\x02RE\x10\x02\x12\x07\n\x03NRE\x10\x03\"L\n\tReadHints\x12\x0f\n\x07step_ms\x18\x01 \x01(\x03\x12\x0c\n\x04\x66unc\x18\x02 \x01(\t\x12\x10\n\x08start_ms\x18\x03 \x01(\x03\x12\x0e\n\x06\x65nd_ms\x18\x04 \x01(\x03\x42\x08Z\x06prompbb\x06proto3') 18 | 19 | 20 | 21 | _WRITEREQUEST = DESCRIPTOR.message_types_by_name['WriteRequest'] 22 | _READREQUEST = DESCRIPTOR.message_types_by_name['ReadRequest'] 23 | _READRESPONSE = DESCRIPTOR.message_types_by_name['ReadResponse'] 24 | _QUERY = DESCRIPTOR.message_types_by_name['Query'] 25 | _QUERYRESULT = DESCRIPTOR.message_types_by_name['QueryResult'] 26 | _SAMPLE = DESCRIPTOR.message_types_by_name['Sample'] 27 | _TIMESERIES = DESCRIPTOR.message_types_by_name['TimeSeries'] 28 | _LABEL = DESCRIPTOR.message_types_by_name['Label'] 29 | _LABELS = DESCRIPTOR.message_types_by_name['Labels'] 30 | _LABELMATCHER = DESCRIPTOR.message_types_by_name['LabelMatcher'] 31 | _READHINTS = DESCRIPTOR.message_types_by_name['ReadHints'] 32 | _LABELMATCHER_TYPE = _LABELMATCHER.enum_types_by_name['Type'] 33 | WriteRequest = _reflection.GeneratedProtocolMessageType('WriteRequest', (_message.Message,), { 34 | 'DESCRIPTOR' : _WRITEREQUEST, 35 | '__module__' : 'prometheus_pb2' 36 | # @@protoc_insertion_point(class_scope:prometheus.WriteRequest) 37 | }) 38 | _sym_db.RegisterMessage(WriteRequest) 39 | 40 | ReadRequest = _reflection.GeneratedProtocolMessageType('ReadRequest', (_message.Message,), { 41 | 'DESCRIPTOR' : _READREQUEST, 42 | '__module__' : 'prometheus_pb2' 43 | # @@protoc_insertion_point(class_scope:prometheus.ReadRequest) 44 | }) 45 | _sym_db.RegisterMessage(ReadRequest) 46 | 47 | ReadResponse = _reflection.GeneratedProtocolMessageType('ReadResponse', (_message.Message,), { 48 | 'DESCRIPTOR' : 
_READRESPONSE, 49 | '__module__' : 'prometheus_pb2' 50 | # @@protoc_insertion_point(class_scope:prometheus.ReadResponse) 51 | }) 52 | _sym_db.RegisterMessage(ReadResponse) 53 | 54 | Query = _reflection.GeneratedProtocolMessageType('Query', (_message.Message,), { 55 | 'DESCRIPTOR' : _QUERY, 56 | '__module__' : 'prometheus_pb2' 57 | # @@protoc_insertion_point(class_scope:prometheus.Query) 58 | }) 59 | _sym_db.RegisterMessage(Query) 60 | 61 | QueryResult = _reflection.GeneratedProtocolMessageType('QueryResult', (_message.Message,), { 62 | 'DESCRIPTOR' : _QUERYRESULT, 63 | '__module__' : 'prometheus_pb2' 64 | # @@protoc_insertion_point(class_scope:prometheus.QueryResult) 65 | }) 66 | _sym_db.RegisterMessage(QueryResult) 67 | 68 | Sample = _reflection.GeneratedProtocolMessageType('Sample', (_message.Message,), { 69 | 'DESCRIPTOR' : _SAMPLE, 70 | '__module__' : 'prometheus_pb2' 71 | # @@protoc_insertion_point(class_scope:prometheus.Sample) 72 | }) 73 | _sym_db.RegisterMessage(Sample) 74 | 75 | TimeSeries = _reflection.GeneratedProtocolMessageType('TimeSeries', (_message.Message,), { 76 | 'DESCRIPTOR' : _TIMESERIES, 77 | '__module__' : 'prometheus_pb2' 78 | # @@protoc_insertion_point(class_scope:prometheus.TimeSeries) 79 | }) 80 | _sym_db.RegisterMessage(TimeSeries) 81 | 82 | Label = _reflection.GeneratedProtocolMessageType('Label', (_message.Message,), { 83 | 'DESCRIPTOR' : _LABEL, 84 | '__module__' : 'prometheus_pb2' 85 | # @@protoc_insertion_point(class_scope:prometheus.Label) 86 | }) 87 | _sym_db.RegisterMessage(Label) 88 | 89 | Labels = _reflection.GeneratedProtocolMessageType('Labels', (_message.Message,), { 90 | 'DESCRIPTOR' : _LABELS, 91 | '__module__' : 'prometheus_pb2' 92 | # @@protoc_insertion_point(class_scope:prometheus.Labels) 93 | }) 94 | _sym_db.RegisterMessage(Labels) 95 | 96 | LabelMatcher = _reflection.GeneratedProtocolMessageType('LabelMatcher', (_message.Message,), { 97 | 'DESCRIPTOR' : _LABELMATCHER, 98 | '__module__' : 'prometheus_pb2' 99 | # @@protoc_insertion_point(class_scope:prometheus.LabelMatcher) 100 | }) 101 | _sym_db.RegisterMessage(LabelMatcher) 102 | 103 | ReadHints = _reflection.GeneratedProtocolMessageType('ReadHints', (_message.Message,), { 104 | 'DESCRIPTOR' : _READHINTS, 105 | '__module__' : 'prometheus_pb2' 106 | # @@protoc_insertion_point(class_scope:prometheus.ReadHints) 107 | }) 108 | _sym_db.RegisterMessage(ReadHints) 109 | 110 | if _descriptor._USE_C_DESCRIPTORS == False: 111 | 112 | DESCRIPTOR._options = None 113 | DESCRIPTOR._serialized_options = b'Z\006prompb' 114 | _WRITEREQUEST._serialized_start=32 115 | _WRITEREQUEST._serialized_end=90 116 | _READREQUEST._serialized_start=92 117 | _READREQUEST._serialized_end=141 118 | _READRESPONSE._serialized_start=143 119 | _READRESPONSE._serialized_end=199 120 | _QUERY._serialized_start=202 121 | _QUERY._serialized_end=345 122 | _QUERYRESULT._serialized_start=347 123 | _QUERYRESULT._serialized_end=404 124 | _SAMPLE._serialized_start=406 125 | _SAMPLE._serialized_end=448 126 | _TIMESERIES._serialized_start=450 127 | _TIMESERIES._serialized_end=534 128 | _LABEL._serialized_start=536 129 | _LABEL._serialized_end=572 130 | _LABELS._serialized_start=574 131 | _LABELS._serialized_end=617 132 | _LABELMATCHER._serialized_start=620 133 | _LABELMATCHER._serialized_end=750 134 | _LABELMATCHER_TYPE._serialized_start=710 135 | _LABELMATCHER_TYPE._serialized_end=750 136 | _READHINTS._serialized_start=752 137 | _READHINTS._serialized_end=828 138 | # @@protoc_insertion_point(module_scope) 139 | 
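`prometheus_pb2.py` above is generated from `prometheus.proto`, not hand-written. If the proto is ever edited, the bindings can be regenerated with `protoc`; a sketch, assuming a protoc 3.x release matching the pinned `protobuf==3.19.6` runtime (newer protoc 4.x emits a different module layout):

```bash
# Regenerate the Python bindings next to the .proto file.
cd cmd/tuning/prometheus
protoc --python_out=. prometheus.proto
```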
-------------------------------------------------------------------------------- /cmd/tuning/requirements.txt: -------------------------------------------------------------------------------- 1 | bitsandbytes==0.41.3.post2 2 | datasets==2.14.5 3 | deepspeed==0.12.2 4 | evaluate==0.4.1 5 | peft==0.5.0 6 | protobuf==3.19.6 7 | python-snappy==0.6.1 8 | torch==2.1.0 9 | transformers==4.34.0 10 | 11 | -------------------------------------------------------------------------------- /cmd/tuning/template.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from dataclasses import dataclass 3 | from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union 4 | 5 | 6 | if TYPE_CHECKING: 7 | from transformers import PreTrainedTokenizer 8 | 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | @dataclass 14 | class Template: 15 | 16 | prefix: List[Union[str, Dict[str, str]]] 17 | prompt: List[Union[str, Dict[str, str]]] 18 | system: str 19 | sep: List[Union[str, Dict[str, str]]] 20 | stop_words: List[str] 21 | use_history: bool 22 | efficient_eos: bool 23 | 24 | def encode_oneturn( 25 | self, 26 | tokenizer: "PreTrainedTokenizer", 27 | query: str, 28 | resp: str, 29 | history: Optional[List[Tuple[str, str]]] = None, 30 | system: Optional[str] = None 31 | ) -> Tuple[List[int], List[int]]: 32 | r""" 33 | Returns a single pair of token ids representing prompt and response respectively. 34 | """ 35 | system, history = self._format(query, resp, history, system) 36 | encoded_pairs = self._encode(tokenizer, system, history) 37 | prompt_ids = [] 38 | for query_ids, resp_ids in encoded_pairs[:-1]: 39 | prompt_ids = prompt_ids + query_ids + resp_ids 40 | prompt_ids, answer_ids = prompt_ids + encoded_pairs[-1][0], encoded_pairs[-1][1] 41 | return prompt_ids, answer_ids 42 | 43 | def encode_multiturn( 44 | self, 45 | tokenizer: "PreTrainedTokenizer", 46 | query: str, 47 | resp: str, 48 | history: Optional[List[Tuple[str, str]]] = None, 49 | system: Optional[str] = None 50 | ) -> List[Tuple[List[int], List[int]]]: 51 | r""" 52 | Returns multiple pairs of token ids representing prompts and responses respectively. 53 | """ 54 | system, history = self._format(query, resp, history, system) 55 | encoded_pairs = self._encode(tokenizer, system, history) 56 | return encoded_pairs 57 | 58 | def _format( 59 | self, 60 | query: str, 61 | resp: str, 62 | history: Optional[List[Tuple[str, str]]] = None, 63 | system: Optional[str] = None 64 | ) -> Tuple[str, List[Tuple[str, str]]]: 65 | r""" 66 | Aligns inputs to the standard format. 67 | """ 68 | system = system or self.system # use system if provided 69 | history = history if (history and self.use_history) else [] 70 | history = history + [(query, resp)] 71 | return system, history 72 | 73 | def _get_special_ids( 74 | self, 75 | tokenizer: "PreTrainedTokenizer" 76 | ) -> Tuple[List[int], List[int]]: 77 | if tokenizer.bos_token_id is not None and getattr(tokenizer, "add_bos_token", True): 78 | bos_ids = [tokenizer.bos_token_id] 79 | else: # baichuan, qwen and gpt2 models have no bos token 80 | bos_ids = [] 81 | 82 | if tokenizer.eos_token_id is None: 83 | raise ValueError("EOS token is required.") 84 | 85 | if self.efficient_eos: # used in baichuan, qwen, chatglm, etc. 
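# Note (added for clarity): with efficient_eos no eos id is appended to each response
# here; train.py instead supervises eos through the label mask at turn boundaries and
# appends a single trailing eos to the full sequence.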
86 | eos_ids = [] 87 | else: 88 | eos_ids = [tokenizer.eos_token_id] 89 | 90 | return bos_ids, eos_ids 91 | 92 | def _encode( 93 | self, 94 | tokenizer: "PreTrainedTokenizer", 95 | system: str, 96 | history: List[Tuple[str, str]] 97 | ) -> List[Tuple[List[int], List[int]]]: 98 | r""" 99 | Encodes formatted inputs to pairs of token ids. 100 | Turn 0: bos + prefix + sep + query resp + eos 101 | Turn t: sep + bos + query resp + eos 102 | """ 103 | bos_ids, eos_ids = self._get_special_ids(tokenizer) 104 | sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep) 105 | encoded_pairs = [] 106 | for turn_idx, (query, resp) in enumerate(history): 107 | if turn_idx == 0: 108 | prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.prefix, system=system) 109 | if len(prefix_ids) != 0: # has prefix 110 | prefix_ids = bos_ids + prefix_ids + sep_ids 111 | else: 112 | prefix_ids = bos_ids 113 | else: 114 | prefix_ids = sep_ids + bos_ids 115 | 116 | query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query, idx=str(turn_idx)) 117 | resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp]) 118 | encoded_pairs.append((prefix_ids + query_ids, resp_ids + eos_ids)) 119 | return encoded_pairs 120 | 121 | def _convert_inputs_to_ids( 122 | self, 123 | tokenizer: "PreTrainedTokenizer", 124 | context: List[Union[str, Dict[str, str]]], 125 | system: Optional[str] = None, 126 | query: Optional[str] = None, 127 | idx: Optional[str] = None 128 | ) -> List[int]: 129 | r""" 130 | Converts context to token ids. 131 | """ 132 | 133 | kwargs = dict(add_special_tokens=False) 134 | 135 | token_ids = [] 136 | for elem in context: 137 | if isinstance(elem, str): 138 | elem = elem.replace("{{system}}", system, 1) if system is not None else elem 139 | elem = elem.replace("{{query}}", query, 1) if query is not None else elem 140 | elem = elem.replace("{{idx}}", idx, 1) if idx is not None else elem 141 | if len(elem) != 0: 142 | token_ids = token_ids + tokenizer.encode(elem, **kwargs) 143 | elif isinstance(elem, dict): 144 | token_ids = token_ids + [tokenizer.convert_tokens_to_ids(elem.get("token"))] 145 | else: 146 | raise ValueError("Input must be string or dict[str, str], got {}".format(type(elem))) 147 | 148 | return token_ids 149 | 150 | 151 | @dataclass 152 | class Llama2Template(Template): 153 | 154 | def _encode( 155 | self, 156 | tokenizer: "PreTrainedTokenizer", 157 | system: str, 158 | history: List[Tuple[str, str]] 159 | ) -> List[Tuple[List[int], List[int]]]: 160 | r""" 161 | Encodes formatted inputs to pairs of token ids. 
162 |         Turn 0: bos + prefix + query resp + eos
163 |         Turn t: bos + query resp + eos
164 |         """
165 |         bos_ids, eos_ids = self._get_special_ids(tokenizer)
166 |         encoded_pairs = []
167 |         for turn_idx, (query, resp) in enumerate(history):
168 |             if turn_idx == 0:  # llama2 template has no sep_ids
169 |                 query = self.prefix[0].replace("{{system}}", system) + query
170 |             query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
171 |             resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
172 |             encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
173 |         return encoded_pairs
174 | 
175 | 
176 | templates: Dict[str, Template] = {}
177 | 
178 | 
179 | def register_template(
180 |     name: str,
181 |     prefix: List[Union[str, Dict[str, str]]],
182 |     prompt: List[Union[str, Dict[str, str]]],
183 |     system: str,
184 |     sep: List[Union[str, Dict[str, str]]],
185 |     stop_words: Optional[List[str]] = [],
186 |     use_history: Optional[bool] = True,
187 |     efficient_eos: Optional[bool] = False
188 | ) -> None:
189 |     template_class = Llama2Template if "llama2" in name else Template
190 |     templates[name] = template_class(
191 |         prefix=prefix,
192 |         prompt=prompt,
193 |         system=system,
194 |         sep=sep,
195 |         stop_words=stop_words,
196 |         use_history=use_history,
197 |         efficient_eos=efficient_eos
198 |     )
199 | 
200 | 
201 | def get_template_and_fix_tokenizer(
202 |     name: str,
203 |     tokenizer: "PreTrainedTokenizer"
204 | ) -> Template:
205 |     if tokenizer.eos_token_id is None:
206 |         tokenizer.eos_token = "<|endoftext|>"
207 |         logger.info("Add eos token: {}".format(tokenizer.eos_token))
208 | 
209 |     if tokenizer.pad_token_id is None:
210 |         tokenizer.pad_token = tokenizer.eos_token
211 |         logger.info("Add pad token: {}".format(tokenizer.pad_token))
212 | 
213 |     if name is None:
214 |         return None
215 | 
216 |     template = templates.get(name, None)
217 |     assert template is not None, "Template {} does not exist.".format(name)
218 |     tokenizer.add_special_tokens(
219 |         dict(additional_special_tokens=template.stop_words),
220 |         replace_additional_special_tokens=False
221 |     )
222 |     return template
223 | 
224 | 
225 | r"""
226 | Supports language model inference without histories.
227 | """
228 | register_template(
229 |     name="vanilla",
230 |     prefix=[],
231 |     prompt=[
232 |         "{{query}}"
233 |     ],
234 |     system="",
235 |     sep=[],
236 |     use_history=False
237 | )
238 | 
239 | 
240 | r"""
241 | Default template.
242 | """
243 | register_template(
244 |     name="default",
245 |     prefix=[
246 |         "{{system}}"
247 |     ],
248 |     prompt=[
249 |         "Human: {{query}}\nAssistant: "
250 |     ],
251 |     system=(
252 |         "A chat between a curious user and an artificial intelligence assistant. "
253 |         "The assistant gives helpful, detailed, and polite answers to the user's questions."
254 |     ),
255 |     sep=[
256 |         "\n"
257 |     ]
258 | )
259 | 
260 | 
261 | r"""
262 | Supports: https://huggingface.co/meta-llama/Llama-2-7b-chat-hf
263 |           https://huggingface.co/meta-llama/Llama-2-13b-chat-hf
264 |           https://huggingface.co/meta-llama/Llama-2-70b-chat-hf
265 | """
266 | register_template(
267 |     name="llama2",
268 |     prefix=[
269 |         "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
270 |     ],
271 |     prompt=[
272 |         "[INST] {{query}} [/INST] "
273 |     ],
274 |     system=(
275 |         "You are a helpful, respectful and honest assistant. "
276 |         "Always answer as helpfully as possible, while being safe. "
277 |         "Your answers should not include any harmful, unethical, "
278 |         "racist, sexist, toxic, dangerous, or illegal content. "
279 |         "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
280 |         "If a question does not make any sense, or is not factually coherent, "
281 |         "explain why instead of answering something not correct. "
282 |         "If you don't know the answer to a question, please don't share false information."
283 |     ),
284 |     sep=[]
285 | )
286 | 
287 | 
288 | r"""
289 | Supports: https://huggingface.co/ziqingyang/chinese-alpaca-2-7b
290 |           https://huggingface.co/ziqingyang/chinese-alpaca-2-13b
291 | """
292 | register_template(
293 |     name="llama2_zh",
294 |     prefix=[
295 |         "<<SYS>>\n{{system}}\n<</SYS>>\n\n"
296 |     ],
297 |     prompt=[
298 |         "[INST] {{query}} [/INST] "
299 |     ],
300 |     system="You are a helpful assistant. 你是一个乐于助人的助手。",
301 |     sep=[]
302 | )
303 | 
304 | 
305 | r"""
306 | Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
307 | """
308 | register_template(
309 |     name="alpaca",
310 |     prefix=[
311 |         "{{system}}"
312 |     ],
313 |     prompt=[
314 |         "### Instruction:\n{{query}}\n\n### Response:\n"
315 |     ],
316 |     system=(
317 |         "Below is an instruction that describes a task. "
318 |         "Write a response that appropriately completes the request."
319 |     ),
320 |     sep=[
321 |         "\n\n"
322 |     ]
323 | )
324 | 
325 | 
326 | r"""
327 | Supports: https://huggingface.co/lmsys/vicuna-7b-v1.5
328 |           https://huggingface.co/lmsys/vicuna-13b-v1.5
329 | """
330 | register_template(
331 |     name="vicuna",
332 |     prefix=[
333 |         "{{system}}"
334 |     ],
335 |     prompt=[
336 |         "USER: {{query}} ASSISTANT:"
337 |     ],
338 |     system=(
339 |         "A chat between a curious user and an artificial intelligence assistant. "
340 |         "The assistant gives helpful, detailed, and polite answers to the user's questions."
341 |     ),
342 |     sep=[]
343 | )
344 | 
345 | 
346 | r"""
347 | Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
348 | """
349 | register_template(
350 |     name="belle",
351 |     prefix=[
352 |         "{{system}}"
353 |     ],
354 |     prompt=[
355 |         "Human: {{query}}\n\nBelle: "
356 |     ],
357 |     system="",
358 |     sep=[
359 |         "\n\n"
360 |     ]
361 | )
362 | 
363 | 
364 | r"""
365 | Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
366 |           https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1.1
367 |           https://huggingface.co/IDEA-CCNL/Ziya2-13B-Chat
368 | """
369 | register_template(
370 |     name="ziya",
371 |     prefix=[
372 |         "{{system}}"
373 |     ],
374 |     prompt=[
375 |         {"token": "<human>"},
376 |         ":{{query}}\n",
377 |         {"token": "<bot>"},
378 |         ":"
379 |     ],
380 |     system="",
381 |     sep=[
382 |         "\n"
383 |     ]
384 | )
385 | 
386 | 
387 | r"""
388 | Supports: https://huggingface.co/BAAI/AquilaChat-7B
389 |           https://huggingface.co/BAAI/AquilaChat2-7B
390 |           https://huggingface.co/BAAI/AquilaChat2-34B
391 | """
392 | register_template(
393 |     name="aquila",
394 |     prefix=[
395 |         "{{system}}"
396 |     ],
397 |     prompt=[
398 |         "Human: {{query}}###Assistant:"
399 |     ],
400 |     system=(
401 |         "A chat between a curious human and an artificial intelligence assistant. "
402 |         "The assistant gives helpful, detailed, and polite answers to the human's questions."
403 |     ),
404 |     sep=[
405 |         "###"
406 |     ],
407 |     stop_words=[
408 |         "</s>"
409 |     ],
410 |     efficient_eos=True
411 | )
412 | 
413 | 
414 | r"""
415 | Supports: https://huggingface.co/internlm/internlm-chat-7b
416 |           https://huggingface.co/internlm/internlm-chat-20b
417 | """
418 | register_template(
419 |     name="intern",
420 |     prefix=[
421 |         "{{system}}"
422 |     ],
423 |     prompt=[
424 |         "<|User|>:{{query}}",
425 |         {"token": "<eoh>"},
426 |         "\n<|Bot|>:"
427 |     ],
428 |     system="",
429 |     sep=[
430 |         {"token": "<eoa>"},
431 |         "\n"
432 |     ],
433 |     stop_words=[
434 |         "<eoa>"
435 |     ],
436 |     efficient_eos=True
437 | )
438 | 
439 | 
440 | r"""
441 | Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
442 | """
443 | register_template(
444 |     name="baichuan",
445 |     prefix=[
446 |         "{{system}}"
447 |     ],
448 |     prompt=[
449 |         {"token": "<reserved_102>"},  # user token
450 |         "{{query}}",
451 |         {"token": "<reserved_103>"}  # assistant token
452 |     ],
453 |     system="",
454 |     sep=[],
455 |     efficient_eos=True
456 | )
457 | 
458 | 
459 | r"""
460 | Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
461 |           https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat
462 | """
463 | register_template(
464 |     name="baichuan2",
465 |     prefix=[
466 |         "{{system}}"
467 |     ],
468 |     prompt=[
469 |         {"token": "<reserved_106>"},  # user token
470 |         "{{query}}",
471 |         {"token": "<reserved_107>"}  # assistant token
472 |     ],
473 |     system="",
474 |     sep=[],
475 |     efficient_eos=True
476 | )
477 | 
478 | 
479 | r"""
480 | Supports: https://huggingface.co/HuggingFaceH4/starchat-alpha
481 |           https://huggingface.co/HuggingFaceH4/starchat-beta
482 | """
483 | register_template(
484 |     name="starchat",
485 |     prefix=[
486 |         {"token": "<|system|>"},
487 |         "\n{{system}}",
488 |     ],
489 |     prompt=[
490 |         {"token": "<|user|>"},
491 |         "\n{{query}}",
492 |         {"token": "<|end|>"},
493 |         "\n",
494 |         {"token": "<|assistant|>"}
495 |     ],
496 |     system="",
497 |     sep=[
498 |         {"token": "<|end|>"},
499 |         "\n"
500 |     ],
501 |     stop_words=[
502 |         "<|end|>"
503 |     ],
504 |     efficient_eos=True
505 | )
506 | 
507 | 
508 | r"""
509 | Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
510 |           https://huggingface.co/Qwen/Qwen-14B-Chat
511 | """
512 | register_template(
513 |     name="chatml",
514 |     prefix=[
515 |         {"token": "<|im_start|>"},
516 |         "system\n{{system}}"
517 |     ],
518 |     prompt=[
519 |         {"token": "<|im_start|>"},
520 |         "user\n{{query}}",
521 |         {"token": "<|im_end|>"},
522 |         "\n",
523 |         {"token": "<|im_start|>"},
524 |         "assistant\n"
525 |     ],
526 |     system="You are a helpful assistant.",
527 |     sep=[
528 |         {"token": "<|im_end|>"},
529 |         "\n"
530 |     ],
531 |     stop_words=[
532 |         "<|im_end|>"
533 |     ],
534 |     efficient_eos=True
535 | )
536 | 
537 | 
538 | r"""
539 | Supports: https://huggingface.co/THUDM/chatglm2-6b
540 | """
541 | register_template(
542 |     name="chatglm2",
543 |     prefix=[
544 |         {"token": "[gMASK]"},
545 |         {"token": "sop"},
546 |         "{{system}}"
547 |     ],
548 |     prompt=[
549 |         "[Round {{idx}}]\n\n问:{{query}}\n\n答:"
550 |     ],
551 |     system="",
552 |     sep=[
553 |         "\n\n"
554 |     ],
555 |     efficient_eos=True
556 | )
557 | 
558 | 
559 | r"""
560 | Supports: https://huggingface.co/THUDM/chatglm3-6b
561 | """
562 | register_template(
563 |     name="chatglm3",
564 |     prefix=[
565 |         {"token": "[gMASK]"},
566 |         {"token": "sop"},
567 |         "{{system}}"
568 |     ],
569 |     prompt=[
570 |         {"token": "<|user|>"},
571 |         "\n",
572 |         "{{query}}",
573 |         {"token": "<|assistant|>"}
574 |     ],
575 |     system="",
576 |     sep=[],
577 |     stop_words=[
578 |         "<|user|>",
579 |         "<|observation|>"
580 |     ],
581 |     efficient_eos=True
582 | )
583 | 
584 | 
585 | r"""
586 | Supports: 
https://huggingface.co/openchat/openchat_v3.2_super 587 | """ 588 | register_template( 589 | name="openchat", 590 | prefix=[ 591 | "{{system}}" 592 | ], 593 | prompt=[ 594 | "GPT4 User: {{query}}", 595 | {"token": "<|end_of_turn|>"}, 596 | "GPT4 Assistant:" 597 | ], 598 | system="", 599 | sep=[ 600 | {"token": "<|end_of_turn|>"} 601 | ], 602 | efficient_eos=True 603 | ) 604 | 605 | 606 | r""" 607 | Supports: https://huggingface.co/xverse/XVERSE-7B-Chat 608 | https://huggingface.co/xverse/XVERSE-13B-Chat 609 | """ 610 | register_template( 611 | name="xverse", 612 | prefix=[ 613 | "{{system}}" 614 | ], 615 | prompt=[ 616 | "Human: {{query}}\n\nAssistant: " 617 | ], 618 | system="", 619 | sep=[] 620 | ) 621 | -------------------------------------------------------------------------------- /cmd/tuning/train.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import json 3 | import logging 4 | import math 5 | import os 6 | import sys 7 | from types import MethodType 8 | 9 | import numpy as np 10 | from dataclasses import dataclass 11 | from typing import Sequence, Union, Tuple, Dict, List, Any, Generator, Optional 12 | 13 | import ray 14 | import ray.data 15 | import torch 16 | import evaluate 17 | from pandas import DataFrame 18 | from peft import PeftModel, LoraConfig, TaskType, get_peft_model 19 | from ray.air import ScalingConfig, RunConfig 20 | from ray.train import Checkpoint 21 | from ray.train.huggingface.transformers import prepare_trainer, RayTrainReportCallback 22 | from ray.train.torch import TorchTrainer, get_device, TorchConfig 23 | from torch import nn 24 | from transformers import AutoTokenizer, AutoModel, Trainer, TrainingArguments, Seq2SeqTrainer, DataCollatorForSeq2Seq, \ 25 | AutoConfig, AutoModelForCausalLM, Seq2SeqTrainingArguments, BitsAndBytesConfig 26 | from ray.train.huggingface import TransformersTrainer 27 | from transformers.integrations import is_deepspeed_zero3_enabled 28 | from transformers.tokenization_utils import PreTrainedTokenizer 29 | from datasets import load_dataset, Dataset 30 | 31 | from callback import LogCallback 32 | from parser import get_train_args 33 | from template import get_template_and_fix_tokenizer 34 | from trainer import SFTTrainer 35 | 36 | logging.basicConfig(level=logging.INFO) 37 | # logger = logging.getLogger(__name__) 38 | # formatter = logging.Formatter( 39 | # fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 40 | # datefmt="%m/%d/%Y %H:%M:%S" 41 | # ) 42 | # handler = logging.StreamHandler(sys.stdout) 43 | # handler.setFormatter(formatter) 44 | # 45 | # logger = logging.getLogger() 46 | # logger.setLevel(logging.INFO) 47 | # logger.addHandler(handler) 48 | 49 | cpus_per_worker = 8 50 | IGNORE_INDEX = -100 51 | cutoff_len = 1024 52 | 53 | 54 | def rename_columns(batch: DataFrame, columns): 55 | return batch.rename(columns=columns) 56 | 57 | 58 | def preprocess_dataset( 59 | dataset: Union["Dataset", "IterableDataset"], 60 | tokenizer: "PreTrainedTokenizer", 61 | training_args: "Seq2SeqTrainingArguments" 62 | ) -> Union["Dataset", "IterableDataset"]: 63 | template = get_template_and_fix_tokenizer("llama2", tokenizer) 64 | 65 | def construct_example(examples: Dict[str, List[Any]]) -> Generator[Any, None, None]: 66 | for i in range(len(examples["instruction"])): 67 | query, response = examples["instruction"][i], examples["response"][i] 68 | query = query + "\n" + examples["query"][i] if "query" in examples and examples["query"][i] else query 69 | history = 
examples["history"][i] if "history" in examples else None 70 | system = examples["system"][i] if "system" in examples else None 71 | yield query, response, history, system 72 | 73 | def preprocess_supervised_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]: 74 | # build inputs with format ` X Y ` and labels with format ` ... Y ` 75 | # for multiturn examples, we only mask the prompt part in each prompt-response pair. 76 | # print(examples) 77 | 78 | model_inputs = {"input_ids": [], "attention_mask": [], "labels": []} 79 | 80 | for query, response, history, system in construct_example(examples): 81 | if not (isinstance(query, str) and isinstance(response, str) and query != "" and response != ""): 82 | continue 83 | 84 | input_ids, labels = [], [] 85 | for turn_idx, (source_ids, target_ids) in enumerate(template.encode_multiturn( 86 | tokenizer, query, response, history, system 87 | )): 88 | total_len = len(source_ids) + len(target_ids) 89 | max_source_len = int(cutoff_len * (len(source_ids) / total_len)) 90 | max_target_len = int(cutoff_len * (len(target_ids) / total_len)) 91 | 92 | if len(source_ids) > max_source_len: 93 | source_ids = source_ids[:max_source_len] 94 | if len(target_ids) > max_target_len: 95 | target_ids = target_ids[:max_target_len] 96 | 97 | if turn_idx != 0 and template.efficient_eos: 98 | source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (len(source_ids) - 1) 99 | else: 100 | source_mask = [IGNORE_INDEX] * len(source_ids) 101 | 102 | input_ids += source_ids + target_ids 103 | labels += source_mask + target_ids 104 | 105 | if template.efficient_eos: 106 | input_ids += [tokenizer.eos_token_id] 107 | labels += [tokenizer.eos_token_id] 108 | 109 | if len(input_ids) > cutoff_len: 110 | input_ids = input_ids[:cutoff_len] 111 | labels = labels[:cutoff_len] 112 | 113 | model_inputs["input_ids"].append(input_ids) 114 | model_inputs["attention_mask"].append([1] * len(input_ids)) 115 | model_inputs["labels"].append(labels) 116 | 117 | return model_inputs 118 | 119 | def print_supervised_dataset_example(example): 120 | print("input_ids:\n{}".format(example["input_ids"])) 121 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 122 | print("label_ids:\n{}".format(example["labels"])) 123 | print("labels:\n{}".format( 124 | tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, example["labels"])), skip_special_tokens=False) 125 | )) 126 | 127 | preprocess_func = preprocess_supervised_dataset 128 | print_function = print_supervised_dataset_example 129 | new_dataset = dataset.map_batches(preprocess_func) 130 | if training_args.should_log: 131 | try: 132 | print_function(new_dataset.take(1)[0]) 133 | except StopIteration: 134 | raise RuntimeError("Empty dataset!") 135 | return new_dataset 136 | 137 | 138 | def trainer_init_per_worker(config): 139 | print("--- train_task, pid: ", os.getpid()) 140 | 141 | cuda_visible_device = os.environ["CUDA_VISIBLE_DEVICES"].split(",") 142 | print("CUDA_VISIBLE_DEVICES", os.environ["CUDA_VISIBLE_DEVICES"]) 143 | local_rank = int(os.environ["LOCAL_RANK"]) 144 | print("local_rank:", local_rank) 145 | device_id = cuda_visible_device[local_rank] 146 | print("device_id:", device_id) 147 | os.environ["CUDA_VISIBLE_DEVICES"] = f"{device_id}" 148 | torch.cuda.set_device(int(device_id)) 149 | 150 | # device setting 151 | # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 152 | print("device:", torch.cuda.current_device()) 153 | device_ids = 
torch._utils._get_all_device_indices() 154 | print("device_ids:", device_ids) 155 | if len(device_ids) <= 0: 156 | print("invalid device_ids, exit") 157 | return 158 | 159 | training_args = config.get("training_args", None) 160 | finetuning_args = config.get("finetuning_args", None) 161 | model_args = config.get("model_args", None) 162 | data_args = config.get("data_args", None) 163 | tokenizer = config.get("tokenizer", None) 164 | 165 | # read dataset 166 | train_ds = ray.train.get_dataset_shard("train") 167 | print(f"train_ds: {train_ds}") 168 | 169 | def train_gen(): 170 | for row in train_ds.iter_rows(): 171 | yield row 172 | 173 | train_dataset = Dataset.from_generator(train_gen) 174 | print(train_dataset) 175 | print('------') 176 | print(train_dataset[0]) 177 | 178 | eval_ds = ray.train.get_dataset_shard("evaluation") 179 | print(f"eval_ds: {eval_ds}") 180 | 181 | def eval_gen(): 182 | for row in eval_ds.iter_rows(): 183 | yield row 184 | 185 | eval_dataset = None 186 | evaluation_strategy = "no" 187 | if eval_ds: 188 | eval_dataset = Dataset.from_generator(eval_gen) 189 | print(eval_dataset) 190 | evaluation_strategy = "steps" 191 | 192 | train_ds_len = len(list(train_ds.iter_batches(batch_size=1))) 193 | steps_per_epoch = math.ceil(train_ds_len / training_args.per_device_train_batch_size) 194 | print(f"train_ds_len: {train_ds_len}, steps_per_epoch: {steps_per_epoch}") 195 | 196 | new_training_args = Seq2SeqTrainingArguments( 197 | training_args.output_dir, 198 | logging_steps=10, 199 | save_strategy="no", 200 | evaluation_strategy=evaluation_strategy, 201 | num_train_epochs=training_args.num_train_epochs, 202 | learning_rate=training_args.learning_rate, 203 | weight_decay=training_args.weight_decay, 204 | warmup_steps=training_args.warmup_steps, 205 | per_device_train_batch_size=training_args.per_device_train_batch_size, 206 | per_device_eval_batch_size=training_args.per_device_eval_batch_size, 207 | optim=training_args.optim, 208 | lr_scheduler_type=training_args.lr_scheduler_type, 209 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 210 | push_to_hub=False, 211 | report_to="none", 212 | disable_tqdm=False, # declutter the output a little 213 | fp16=training_args.fp16, 214 | gradient_checkpointing=True, 215 | deepspeed=training_args.deepspeed, 216 | log_level="info", 217 | ) 218 | 219 | print(f"new_training_args: {new_training_args}".replace("\n", " ")) 220 | 221 | config = AutoConfig.from_pretrained(model_args.model_name_or_path) 222 | compute_dtype = getattr(config, "torch_dtype", None) 223 | 224 | if model_args.quantization == "int4": 225 | quantization_config = BitsAndBytesConfig( 226 | load_in_4bit=True, 227 | bnb_4bit_quant_type="nf4", 228 | bnb_4bit_compute_dtype=torch.float16, 229 | bnb_4bit_use_double_quant=False, 230 | ) 231 | elif model_args.quantization == "int8": 232 | quantization_config = BitsAndBytesConfig(load_in_8bit=True) 233 | else: 234 | quantization_config = None 235 | 236 | model = AutoModelForCausalLM.from_pretrained( 237 | model_args.model_name_or_path, 238 | config=config, 239 | torch_dtype=compute_dtype, 240 | quantization_config=quantization_config, 241 | low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), 242 | ) 243 | 244 | if hasattr(model, "enable_input_require_grads"): 245 | model.enable_input_require_grads() 246 | else: 247 | def make_inputs_require_grad(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor): 248 | output.requires_grad_(True) 249 | 250 | 
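# Note (added for clarity): with gradient checkpointing and a mostly frozen base model
# (LoRA), the embedding outputs must require grad for gradients to reach the trainable
# adapters; this hook is the usual fallback when enable_input_require_grads() is absent.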
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 251 | 252 | model.gradient_checkpointing_enable() 253 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 254 | print("Gradient checkpointing enabled.") 255 | 256 | output_layer_name = "lm_head" 257 | 258 | if hasattr(model, output_layer_name): 259 | output_layer = getattr(model, output_layer_name) 260 | if isinstance(output_layer, torch.nn.Linear): 261 | def forward_in_fp32(self, x: torch.Tensor) -> torch.Tensor: 262 | return output_layer.__class__.forward(self, x.to(output_layer.weight.dtype)).to(torch.float32) 263 | 264 | output_layer.forward = MethodType(forward_in_fp32, output_layer) 265 | 266 | target_modules = finetuning_args.lora_target 267 | 268 | lora_config = LoraConfig( 269 | task_type=TaskType.CAUSAL_LM, 270 | inference_mode=False, 271 | r=finetuning_args.lora_rank, 272 | lora_alpha=finetuning_args.lora_alpha, 273 | lora_dropout=finetuning_args.lora_dropout, 274 | target_modules=target_modules, 275 | modules_to_save=finetuning_args.additional_target 276 | ) 277 | model = get_peft_model(model, lora_config) 278 | if id(model.peft_config) != id(model.base_model.peft_config): # https://github.com/huggingface/peft/issues/923 279 | model.base_model.peft_config = model.peft_config 280 | model.train() 281 | 282 | data_collator = DataCollatorForSeq2Seq( 283 | tokenizer=tokenizer, 284 | pad_to_multiple_of=4, # for shift short attention 285 | label_pad_token_id=IGNORE_INDEX 286 | ) 287 | 288 | trainer = SFTTrainer( 289 | model=model, 290 | args=new_training_args, 291 | tokenizer=tokenizer, 292 | data_collator=data_collator, 293 | train_dataset=train_dataset, 294 | eval_dataset=eval_dataset, 295 | callbacks=[LogCallback(metrics_export_address=finetuning_args.metrics_export_address, uid=finetuning_args.uid)], 296 | ) 297 | 298 | trainer = prepare_trainer(trainer) 299 | train_result = trainer.train() 300 | trainer.save_model(training_args.output_dir) 301 | 302 | checkpoint = None 303 | if ray.train.get_context().get_world_rank() == 0: 304 | checkpoint = Checkpoint.from_directory(training_args.output_dir) 305 | ray.train.report(metrics=train_result.metrics, checkpoint=checkpoint) 306 | 307 | 308 | def main(): 309 | print("init") 310 | ray.init() 311 | 312 | training_args, finetuning_args, model_args, data_args = get_train_args() 313 | 314 | print(f"training_args: {training_args}".replace("\n", " ")) 315 | print(finetuning_args) 316 | print(model_args) 317 | print(data_args) 318 | 319 | model_path = model_args.model_name_or_path 320 | use_gpu = True 321 | num_workers = finetuning_args.num_workers 322 | 323 | if data_args.block_size > 0: 324 | global cutoff_len 325 | cutoff_len = data_args.block_size 326 | 327 | # read dataset 328 | print("preprocess_dataset") 329 | columns_map = { 330 | "instruction": "instruction", 331 | "output": "response" 332 | } 333 | if data_args.columns: 334 | print(data_args.columns) 335 | columns_map.update({v: k for k, v in json.loads(data_args.columns).items()}) 336 | 337 | tokenizer = AutoTokenizer.from_pretrained(model_path) 338 | 339 | train_dataset = ray.data.read_csv(data_args.train_path). \ 340 | map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas") 341 | print(train_dataset) 342 | train_dataset = preprocess_dataset(train_dataset, tokenizer, training_args) 343 | 344 | input_datasets = {"train": train_dataset} 345 | 346 | if data_args.evaluation_path: 347 | evaluation_dataset = ray.data.read_csv(data_args.evaluation_path). 
\ 348 | map_batches(rename_columns, fn_args=[columns_map], batch_format="pandas") 349 | print(evaluation_dataset) 350 | evaluation_dataset = preprocess_dataset(evaluation_dataset, tokenizer, training_args) 351 | input_datasets["evaluation"] = evaluation_dataset 352 | 353 | scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu, 354 | resources_per_worker={"GPU": 1, "CPU": cpus_per_worker}, 355 | trainer_resources={"GPU": 0} 356 | ) 357 | 358 | ray_trainer = TorchTrainer( 359 | train_loop_per_worker=trainer_init_per_worker, 360 | train_loop_config={ 361 | "training_args": training_args, 362 | "finetuning_args": finetuning_args, 363 | "model_args": model_args, 364 | "data_args": data_args, 365 | "tokenizer": tokenizer, 366 | }, 367 | scaling_config=scaling_config, 368 | datasets=input_datasets, 369 | run_config=RunConfig( 370 | storage_path=finetuning_args.storage_path, 371 | # checkpoint_config=ray.train.CheckpointConfig( 372 | # num_to_keep=1, 373 | # checkpoint_score_attribute="eval_loss", 374 | # checkpoint_score_order="min", 375 | # ), 376 | ) 377 | ) 378 | result = ray_trainer.fit() 379 | checkpoint_path = result.checkpoint.path 380 | 381 | print(f"result path {checkpoint_path}") 382 | 383 | file_path = "/home/ray/checkpoint_path" 384 | 385 | directory = os.path.dirname(file_path) 386 | if not os.path.exists(directory): 387 | os.makedirs(directory) 388 | with open(file_path, 'w', encoding='utf-8') as file: 389 | file.write(checkpoint_path) 390 | 391 | 392 | if __name__ == '__main__': 393 | main() 394 | -------------------------------------------------------------------------------- /cmd/tuning/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch, time, math 4 | from copy import deepcopy 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | import torch.nn as nn 9 | from torch.utils.data import Dataset 10 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, Callable 11 | from transformers import Seq2SeqTrainer, Trainer 12 | 13 | if TYPE_CHECKING: 14 | from transformers.data.data_collator import DataCollator 15 | from transformers.modeling_utils import PreTrainedModel 16 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase 17 | from transformers.trainer_callback import TrainerCallback 18 | from transformers.training_args import TrainingArguments 19 | 20 | from transformers.generation.configuration_utils import GenerationConfig 21 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 22 | from transformers.utils import logging 23 | from transformers.trainer_utils import EvalPrediction, PredictionOutput, speed_metrics 24 | 25 | IGNORE_INDEX = -100 26 | logger = logging.get_logger(__name__) 27 | 28 | 29 | class GenEvalSeq2SeqTrainer(Seq2SeqTrainer): 30 | def __init__( 31 | self, 32 | model: Union["PreTrainedModel", nn.Module] = None, 33 | args: "TrainingArguments" = None, 34 | data_collator: Optional["DataCollator"] = None, 35 | train_dataset: Optional[Dataset] = None, 36 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 37 | tokenizer: Optional["PreTrainedTokenizerBase"] = None, 38 | model_init: Optional[Callable[[], "PreTrainedModel"]] = None, 39 | compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, 40 | callbacks: Optional[List["TrainerCallback"]] = None, 41 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 42 | 
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 43 | gen_args: Optional[Any] = None, # ADD 44 | ): 45 | self.gen_args = gen_args 46 | 47 | super().__init__( 48 | model=model, 49 | args=args, 50 | data_collator=data_collator, 51 | train_dataset=train_dataset, 52 | eval_dataset=eval_dataset, 53 | tokenizer=tokenizer, 54 | model_init=model_init, 55 | compute_metrics=compute_metrics, 56 | callbacks=callbacks, 57 | optimizers=optimizers, 58 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 59 | ) 60 | # Override self.model.generation_config if a GenerationConfig is specified in args. 61 | # Priority: args.generation_config > model.generation_config > default GenerationConfig. 62 | if self.args.generation_config is not None: 63 | gen_config = self.load_generation_config(self.args.generation_config) 64 | self.model.generation_config = gen_config 65 | print(f'trainer init : tokenizer {tokenizer}') 66 | 67 | def evaluate( 68 | self, 69 | eval_dataset: Optional[Dataset] = None, 70 | ignore_keys: Optional[List[str]] = None, 71 | metric_key_prefix: str = "eval", 72 | **gen_kwargs, 73 | ) -> Dict[str, float]: 74 | 75 | # force left pad 76 | bak_tp = self.tokenizer.padding_side 77 | bak_dc = self.data_collator 78 | print(f'**** evaluate original pad style: {self.tokenizer.padding_side} ****', '\n') 79 | 80 | if self.gen_args is not None and (gen_kwargs is None or len(gen_kwargs) == 0): 81 | logger.info("*" * 5 + "Using Initial Trainer gen_args" + "*" * 5) 82 | gen_kwargs = self.gen_args.copy() 83 | else: 84 | logger.info("*" * 5 + "Using Default Trainer gen_kwargs" + "*" * 5) 85 | gen_kwargs = gen_kwargs.copy() 86 | # Add your own logic here... 87 | # Extend gen_kwargs here, or adjust the arguments in some other way 88 | 89 | # Call the parent class's evaluate method and return its result 90 | self.data_collator = left_collat_fn(self.tokenizer) 91 | # @ the eval_dataset data collator uses right padding 92 | eval_metrics = super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix, 93 | **gen_kwargs) 94 | 95 | self.tokenizer.padding_side = bak_tp 96 | self.data_collator = bak_dc 97 | return eval_metrics 98 | 99 | def prediction_step( 100 | self, 101 | model: nn.Module, 102 | inputs: Dict[str, Union[torch.Tensor, Any]], 103 | prediction_loss_only: bool, 104 | ignore_keys: Optional[List[str]] = None, 105 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 106 | r""" 107 | Removes the prompt part in the generated tokens. 108 | 109 | Subclass and override to inject custom behavior. 110 | """ 111 | labels = inputs["labels"].clone() if "labels" in inputs else None # backup labels 112 | # force left pad 113 | if self.args.predict_with_generate: 114 | self.tokenizer.padding_side = "left" 115 | assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." 
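# With left padding, every prompt ends at the same right-hand edge of the batch, so the
# prompt occupies exactly the first `prompt_len` positions of each generated sequence.
# The alignment below pads or truncates `labels` to that length; after generation, the
# prompt prefix of `generated_tokens` is overwritten with `pad_token_id`, leaving only
# the newly generated response for decoding.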
116 | prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) 117 | if prompt_len > label_len: 118 | inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) 119 | if label_len > prompt_len: 120 | inputs["labels"] = inputs["labels"][:, :prompt_len] # truncate the labels instead of padding the inputs 121 | 122 | loss, generated_tokens, _ = super().prediction_step( 123 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 124 | ) 125 | if generated_tokens is not None and self.args.predict_with_generate: 126 | generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id 127 | generated_tokens = generated_tokens.contiguous() 128 | 129 | return loss, generated_tokens, labels 130 | 131 | def _pad_tensors_to_target_len( 132 | self, 133 | src_tensor: torch.Tensor, 134 | tgt_tensor: torch.Tensor 135 | ) -> torch.Tensor: 136 | r""" 137 | Pads the tensor to the same length as the target tensor. 138 | """ 139 | assert self.tokenizer.pad_token_id is not None, "Pad token is required." 140 | padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) 141 | padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding 142 | return padded_tensor.contiguous() # in contiguous memory 143 | 144 | def save_predictions( 145 | self, 146 | predict_results: "PredictionOutput" 147 | ) -> None: 148 | r""" 149 | Saves model predictions to `output_dir`. 150 | A custom behavior that is not contained in Seq2SeqTrainer. 151 | Custom behavior. 152 | """ 153 | if not self.is_world_process_zero(): 154 | return 155 | 156 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 157 | logger.info(f"Saving prediction results to {output_prediction_file}") 158 | 159 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, 160 | self.tokenizer.pad_token_id) 161 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, 162 | self.tokenizer.pad_token_id) 163 | 164 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) 165 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True, 166 | clean_up_tokenization_spaces=True) 167 | 168 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 169 | res: List[str] = [] 170 | for pred, label in zip(decoded_preds, decoded_labels): 171 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 172 | writer.write("\n".join(res)) 173 | 174 | 175 | class SFTTrainer(Trainer): 176 | def __init__( 177 | self, 178 | model: Union["PreTrainedModel", nn.Module] = None, 179 | args: "TrainingArguments" = None, 180 | data_collator: Optional["DataCollator"] = None, 181 | train_dataset: Optional[Dataset] = None, 182 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, 183 | tokenizer: Optional["PreTrainedTokenizerBase"] = None, 184 | model_init: Optional[Callable[[], "PreTrainedModel"]] = None, 185 | compute_metrics: Optional[Callable[["EvalPrediction"], Dict]] = None, 186 | callbacks: Optional[List["TrainerCallback"]] = None, 187 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), 188 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, 189 | ): 190 | super().__init__( 191 | model=model, 192 | args=args, 193 | data_collator=data_collator, 194 | train_dataset=train_dataset, 195 | 
eval_dataset=eval_dataset, 196 | tokenizer=tokenizer, 197 | model_init=model_init, 198 | compute_metrics=compute_metrics, 199 | callbacks=callbacks, 200 | optimizers=optimizers, 201 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 202 | ) 203 | 204 | # Override self.model.generation_config if a GenerationConfig is specified in args. 205 | # Priority: args.generation_config > model.generation_config > default GenerationConfig. 206 | if self.args.generation_config is not None: 207 | gen_config = self.load_generation_config(self.args.generation_config) 208 | self.model.generation_config = gen_config 209 | 210 | @staticmethod 211 | def load_generation_config(gen_config_arg: Union[str, GenerationConfig]) -> GenerationConfig: 212 | """ 213 | Loads a `~generation.GenerationConfig` from the `Seq2SeqTrainingArguments.generation_config` arguments. 214 | 215 | Args: 216 | gen_config_arg (`str` or [`~generation.GenerationConfig`]): 217 | `Seq2SeqTrainingArguments.generation_config` argument. 218 | 219 | Returns: 220 | A `~generation.GenerationConfig`. 221 | """ 222 | 223 | # GenerationConfig provided, nothing to do 224 | if isinstance(gen_config_arg, GenerationConfig): 225 | return deepcopy(gen_config_arg) 226 | 227 | # str or Path 228 | pretrained_model_name = Path(gen_config_arg) if isinstance(gen_config_arg, str) else gen_config_arg 229 | config_file_name = None 230 | 231 | # Figuring if it is path pointing to a file, pointing to a directory or else a model id or URL 232 | # This step is required in order to determine config_file_name 233 | if pretrained_model_name.is_file(): 234 | config_file_name = pretrained_model_name.name 235 | pretrained_model_name = pretrained_model_name.parent 236 | # dir path 237 | elif pretrained_model_name.is_dir(): 238 | pass 239 | # model id or URL 240 | else: 241 | pretrained_model_name = gen_config_arg 242 | 243 | gen_config = GenerationConfig.from_pretrained(pretrained_model_name, config_file_name) 244 | return gen_config 245 | 246 | def evaluate( 247 | self, 248 | eval_dataset: Optional[Dataset] = None, 249 | ignore_keys: Optional[List[str]] = None, 250 | metric_key_prefix: str = "eval", 251 | **gen_kwargs, 252 | ) -> Dict[str, float]: 253 | """ 254 | Run evaluation and returns metrics. 255 | 256 | The calling script will be responsible for providing a method to compute metrics, as they are task-dependent 257 | (pass it to the init `compute_metrics` argument). 258 | 259 | You can also subclass and override this method to inject custom behavior. 260 | 261 | Args: 262 | eval_dataset (`Dataset`, *optional*): 263 | Pass a dataset if you wish to override `self.eval_dataset`. If it is an [`~datasets.Dataset`], columns 264 | not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` 265 | method. 266 | ignore_keys (`List[str]`, *optional*): 267 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 268 | gathering predictions. 269 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 270 | An optional prefix to be used as the metrics key prefix. For example the metrics "bleu" will be named 271 | "eval_bleu" if the prefix is `"eval"` (default) 272 | max_length (`int`, *optional*): 273 | The maximum target length to use when predicting with the generate method. 274 | num_beams (`int`, *optional*): 275 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 276 | beam search. 
277 | gen_kwargs: 278 | Additional `generate` specific kwargs. 279 | 280 | Returns: 281 | A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The 282 | dictionary also contains the epoch number which comes from the training state. 283 | """ 284 | 285 | gen_kwargs = gen_kwargs.copy() 286 | if ( 287 | gen_kwargs.get("max_length") is None 288 | and gen_kwargs.get("max_new_tokens") is None 289 | and self.args.generation_max_length is not None 290 | ): 291 | gen_kwargs["max_length"] = self.args.generation_max_length 292 | if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: 293 | gen_kwargs["num_beams"] = self.args.generation_num_beams 294 | self._gen_kwargs = gen_kwargs 295 | 296 | # **** Original Code Start **** 297 | self._memory_tracker.start() 298 | 299 | eval_dataloader = self.get_eval_dataloader(eval_dataset) 300 | start_time = time.time() 301 | 302 | eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop 303 | output = eval_loop( 304 | eval_dataloader, 305 | description="Evaluation", 306 | prediction_loss_only=True if self.compute_metrics is None else None, 307 | ignore_keys=ignore_keys, 308 | metric_key_prefix=metric_key_prefix, 309 | ) 310 | 311 | total_batch_size = self.args.eval_batch_size * self.args.world_size 312 | if f"{metric_key_prefix}_jit_compilation_time" in output.metrics: 313 | start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"] 314 | output.metrics.update( 315 | speed_metrics( 316 | metric_key_prefix, 317 | start_time, 318 | num_samples=output.num_samples, 319 | num_steps=math.ceil(output.num_samples / total_batch_size), 320 | ) 321 | ) 322 | # **** Original Code End **** 323 | # @ compute perplexity 324 | if 'eval_loss' in output.metrics.keys(): 325 | mean_loss = output.metrics.get('eval_loss') 326 | perplexity = math.exp(mean_loss) 327 | output.metrics['eval_perplexity'] = perplexity 328 | 329 | # logger.info(output.metrics) 330 | self.log(output.metrics) 331 | # also appended to state.log_history and echoed to the console 332 | 333 | # if DebugOption.TPU_METRICS_DEBUG in self.args.debug: 334 | # # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.) 335 | # xm.master_print(met.metrics_report()) 336 | # @ on_evaluate 337 | self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics) 338 | self._memory_tracker.stop_and_update_metrics(output.metrics) 339 | 340 | return output.metrics 341 | 342 | def predict( 343 | self, 344 | test_dataset: Dataset, 345 | ignore_keys: Optional[List[str]] = None, 346 | metric_key_prefix: str = "test", 347 | **gen_kwargs, 348 | ) -> "PredictionOutput": 349 | """ 350 | Run prediction and returns predictions and potential metrics. 351 | 352 | Depending on the dataset and your use case, your test dataset may contain labels. In that case, this method 353 | will also return metrics, like in `evaluate()`. 354 | 355 | Args: 356 | test_dataset (`Dataset`): 357 | Dataset to run the predictions on. If it is a [`~datasets.Dataset`], columns not accepted by the 358 | `model.forward()` method are automatically removed. Has to implement the method `__len__` 359 | ignore_keys (`List[str]`, *optional*): 360 | A list of keys in the output of your model (if it is a dictionary) that should be ignored when 361 | gathering predictions. 362 | metric_key_prefix (`str`, *optional*, defaults to `"eval"`): 363 | An optional prefix to be used as the metrics key prefix. 
For example the metrics "bleu" will be named 364 | "eval_bleu" if the prefix is `"eval"` (default) 365 | max_length (`int`, *optional*): 366 | The maximum target length to use when predicting with the generate method. 367 | num_beams (`int`, *optional*): 368 | Number of beams for beam search that will be used when predicting with the generate method. 1 means no 369 | beam search. 370 | gen_kwargs: 371 | Additional `generate` specific kwargs. 372 | 373 | 374 | 375 | If your predictions or labels have different sequence lengths (for instance because you're doing dynamic 376 | padding in a token classification task) the predictions will be padded (on the right) to allow for 377 | concatenation into one array. The padding index is -100. 378 | 379 | 380 | 381 | Returns: *NamedTuple* A namedtuple with the following keys: 382 | 383 | - predictions (`np.ndarray`): The predictions on `test_dataset`. 384 | - label_ids (`np.ndarray`, *optional*): The labels (if the dataset contained some). 385 | - metrics (`Dict[str, float]`, *optional*): The potential dictionary of metrics (if the dataset contained 386 | labels). 387 | """ 388 | 389 | gen_kwargs = gen_kwargs.copy() 390 | 391 | # Use legacy argument setting if a) the option is not explicitly passed; and b) the argument is set in the 392 | # training args 393 | if ( 394 | gen_kwargs.get("max_length") is None 395 | and gen_kwargs.get("max_new_tokens") is None 396 | and self.args.generation_max_length is not None 397 | ): 398 | gen_kwargs["max_length"] = self.args.generation_max_length 399 | if gen_kwargs.get("num_beams") is None and self.args.generation_num_beams is not None: 400 | gen_kwargs["num_beams"] = self.args.generation_num_beams 401 | self._gen_kwargs = gen_kwargs 402 | 403 | return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix) 404 | 405 | def prediction_step( 406 | self, 407 | model: nn.Module, 408 | inputs: Dict[str, Union[torch.Tensor, Any]], 409 | prediction_loss_only: bool, 410 | ignore_keys: Optional[List[str]] = None, 411 | **gen_kwargs, 412 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 413 | 414 | if not self.args.predict_with_generate or prediction_loss_only: 415 | return super().prediction_step( 416 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 417 | ) 418 | # ADD 419 | has_labels = "labels" in inputs 420 | inputs = self._prepare_inputs(inputs) 421 | 422 | # Priority (handled in generate): 423 | # non-`None` gen_kwargs > model.generation_config > default GenerationConfig() 424 | if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"): 425 | gen_kwargs = self._gen_kwargs.copy() 426 | if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None: 427 | gen_kwargs.pop("num_beams") 428 | if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None: 429 | gen_kwargs.pop("max_length") 430 | 431 | default_synced_gpus = True if is_deepspeed_zero3_enabled() else False 432 | gen_kwargs["synced_gpus"] = ( 433 | gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus 434 | ) 435 | 436 | generation_inputs = inputs.copy() 437 | # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate 438 | # (otherwise, it would continue generating from the padded `decoder_input_ids`) 439 | if ( 440 | "labels" in generation_inputs 441 | and "decoder_input_ids" in generation_inputs 442 | and generation_inputs["labels"].shape == 
generation_inputs["decoder_input_ids"].shape 443 | ): 444 | generation_inputs = { 445 | k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask") 446 | } 447 | # @ 1 generated_tokens 448 | print(gen_kwargs) 449 | generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs) 450 | 451 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 452 | # TODO: remove this hack when the legacy code that initializes generation_config from a model config is 453 | # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183 454 | if self.model.generation_config._from_model_config: 455 | self.model.generation_config._from_model_config = False 456 | 457 | # Retrieves GenerationConfig from model.generation_config 458 | gen_config = self.model.generation_config 459 | # in case the batch is shorter than max length, the output should be padded 460 | if generated_tokens.shape[-1] < gen_config.max_length: 461 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length) 462 | elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1: 463 | generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1) 464 | 465 | # @ 2 outputs loss 466 | with torch.no_grad(): 467 | if has_labels: 468 | with self.compute_loss_context_manager(): 469 | outputs = model(**inputs) 470 | if self.label_smoother is not None: 471 | loss = self.label_smoother(outputs, inputs["labels"]).mean().detach() 472 | else: 473 | loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach() 474 | else: 475 | loss = None 476 | 477 | if self.args.prediction_loss_only: 478 | return loss, None, None 479 | 480 | if has_labels: 481 | labels = inputs["labels"] 482 | if labels.shape[-1] < gen_config.max_length: 483 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_length) 484 | elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1: 485 | labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1) 486 | else: 487 | labels = None 488 | 489 | return loss, generated_tokens, labels 490 | 491 | def _pad_tensors_to_max_len(self, tensor, max_length): 492 | if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"): 493 | # If PAD token is not defined at least EOS token has to be defined 494 | pad_token_id = ( 495 | self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id 496 | ) 497 | else: 498 | if self.model.config.pad_token_id is not None: 499 | pad_token_id = self.model.config.pad_token_id 500 | else: 501 | raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors") 502 | 503 | padded_tensor = pad_token_id * torch.ones( 504 | (tensor.shape[0], max_length), dtype=tensor.dtype, device=tensor.device 505 | ) 506 | padded_tensor[:, : tensor.shape[-1]] = tensor 507 | return padded_tensor 508 | -------------------------------------------------------------------------------- /config/rbac/role.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | apiVersion: rbac.authorization.k8s.io/v1 3 | kind: ClusterRole 4 | metadata: 5 | creationTimestamp: null 6 | name: manager-role 7 | rules: 8 | - apiGroups: 9 | - finetune.datatunerx.io 10 | 
resources: 11 | - finetuneexperiments 12 | verbs: 13 | - create 14 | - delete 15 | - get 16 | - list 17 | - patch 18 | - update 19 | - watch 20 | - apiGroups: 21 | - finetune.datatunerx.io 22 | resources: 23 | - finetuneexperiments/finalizers 24 | verbs: 25 | - update 26 | - apiGroups: 27 | - finetune.datatunerx.io 28 | resources: 29 | - finetuneexperiments/status 30 | verbs: 31 | - get 32 | - patch 33 | - update 34 | - apiGroups: 35 | - finetune.datatunerx.io 36 | resources: 37 | - finetunejobs 38 | verbs: 39 | - create 40 | - delete 41 | - get 42 | - list 43 | - patch 44 | - update 45 | - watch 46 | - apiGroups: 47 | - finetune.datatunerx.io 48 | resources: 49 | - finetunejobs/finalizers 50 | verbs: 51 | - update 52 | - apiGroups: 53 | - finetune.datatunerx.io 54 | resources: 55 | - finetunejobs/status 56 | verbs: 57 | - get 58 | - patch 59 | - update 60 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/DataTunerX/datatunerx 2 | 3 | go 1.20 4 | 5 | require ( 6 | github.com/DataTunerX/meta-server v0.0.0-20231225093059-13cc8ff65bdc 7 | github.com/DataTunerX/utility-server v0.0.0-20231208092112-6224f8619737 8 | github.com/duke-git/lancet/v2 v2.2.8 9 | github.com/go-logr/zapr v1.2.4 10 | github.com/open-policy-agent/cert-controller v0.10.0 11 | github.com/ray-project/kuberay/ray-operator v1.0.0 12 | github.com/spf13/pflag v1.0.5 13 | github.com/spf13/viper v1.17.0 14 | k8s.io/api v0.28.1 15 | k8s.io/apimachinery v0.28.1 16 | k8s.io/client-go v0.28.1 17 | sigs.k8s.io/controller-runtime v0.16.1 18 | ) 19 | 20 | require ( 21 | github.com/Masterminds/semver/v3 v3.2.0 // indirect 22 | github.com/beorn7/perks v1.0.1 // indirect 23 | github.com/cespare/xxhash/v2 v2.2.0 // indirect 24 | github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect 25 | github.com/emicklei/go-restful/v3 v3.9.0 // indirect 26 | github.com/evanphx/json-patch/v5 v5.6.0 // indirect 27 | github.com/fsnotify/fsnotify v1.6.0 // indirect 28 | github.com/go-logr/logr v1.2.4 // indirect 29 | github.com/go-openapi/jsonpointer v0.19.6 // indirect 30 | github.com/go-openapi/jsonreference v0.20.2 // indirect 31 | github.com/go-openapi/swag v0.22.3 // indirect 32 | github.com/gogo/protobuf v1.3.2 // indirect 33 | github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect 34 | github.com/golang/protobuf v1.5.3 // indirect 35 | github.com/google/gnostic-models v0.6.8 // indirect 36 | github.com/google/go-cmp v0.5.9 // indirect 37 | github.com/google/gofuzz v1.2.0 // indirect 38 | github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect 39 | github.com/google/uuid v1.3.0 // indirect 40 | github.com/hashicorp/hcl v1.0.0 // indirect 41 | github.com/imdario/mergo v0.3.12 // indirect 42 | github.com/josharian/intern v1.0.0 // indirect 43 | github.com/json-iterator/go v1.1.12 // indirect 44 | github.com/magiconair/properties v1.8.7 // indirect 45 | github.com/mailru/easyjson v0.7.7 // indirect 46 | github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect 47 | github.com/mitchellh/mapstructure v1.5.0 // indirect 48 | github.com/moby/spdystream v0.2.0 // indirect 49 | github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect 50 | github.com/modern-go/reflect2 v1.0.2 // indirect 51 | github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 52 | github.com/openshift/api v0.0.0-20211209135129-c58d9f695577 // 
indirect 53 | github.com/pelletier/go-toml/v2 v2.1.0 // indirect 54 | github.com/pkg/errors v0.9.1 // indirect 55 | github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect 56 | github.com/prometheus/client_golang v1.16.0 // indirect 57 | github.com/prometheus/client_model v0.4.0 // indirect 58 | github.com/prometheus/common v0.44.0 // indirect 59 | github.com/prometheus/procfs v0.10.1 // indirect 60 | github.com/sagikazarmark/locafero v0.3.0 // indirect 61 | github.com/sagikazarmark/slog-shim v0.1.0 // indirect 62 | github.com/sirupsen/logrus v1.9.3 // indirect 63 | github.com/sourcegraph/conc v0.3.0 // indirect 64 | github.com/spf13/afero v1.10.0 // indirect 65 | github.com/spf13/cast v1.5.1 // indirect 66 | github.com/stretchr/testify v1.8.4 // indirect 67 | github.com/subosito/gotenv v1.6.0 // indirect 68 | go.uber.org/atomic v1.11.0 // indirect 69 | go.uber.org/multierr v1.11.0 // indirect 70 | go.uber.org/zap v1.26.0 // indirect 71 | golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect 72 | golang.org/x/net v0.17.0 // indirect 73 | golang.org/x/oauth2 v0.12.0 // indirect 74 | golang.org/x/sys v0.13.0 // indirect 75 | golang.org/x/term v0.13.0 // indirect 76 | golang.org/x/text v0.13.0 // indirect 77 | golang.org/x/time v0.3.0 // indirect 78 | gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect 79 | google.golang.org/appengine v1.6.7 // indirect 80 | google.golang.org/protobuf v1.31.0 // indirect 81 | gopkg.in/inf.v0 v0.9.1 // indirect 82 | gopkg.in/ini.v1 v1.67.0 // indirect 83 | gopkg.in/yaml.v2 v2.4.0 // indirect 84 | gopkg.in/yaml.v3 v3.0.1 // indirect 85 | k8s.io/apiextensions-apiserver v0.28.1 // indirect 86 | k8s.io/component-base v0.28.1 // indirect 87 | k8s.io/klog/v2 v2.100.1 // indirect 88 | k8s.io/kube-openapi v0.0.0-20230717233707-2695361300d9 // indirect 89 | k8s.io/utils v0.0.0-20230406110748-d93618cff8a2 // indirect 90 | sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd // indirect 91 | sigs.k8s.io/structured-merge-diff/v4 v4.2.3 // indirect 92 | sigs.k8s.io/yaml v1.3.0 // indirect 93 | ) 94 | -------------------------------------------------------------------------------- /hack/boilerplate.go.txt: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2023. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ -------------------------------------------------------------------------------- /internal/controller/finetune/finetuneexperiment_controller.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2023. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 
6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package finetune 18 | 19 | import ( 20 | "context" 21 | "reflect" 22 | "sort" 23 | "time" 24 | 25 | "github.com/DataTunerX/datatunerx/pkg/util" 26 | "github.com/DataTunerX/datatunerx/pkg/util/handlererr" 27 | finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1" 28 | "github.com/DataTunerX/utility-server/logging" 29 | "k8s.io/apimachinery/pkg/api/errors" 30 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 31 | "k8s.io/apimachinery/pkg/runtime" 32 | "k8s.io/apimachinery/pkg/types" 33 | ctrl "sigs.k8s.io/controller-runtime" 34 | "sigs.k8s.io/controller-runtime/pkg/builder" 35 | "sigs.k8s.io/controller-runtime/pkg/client" 36 | "sigs.k8s.io/controller-runtime/pkg/controller" 37 | "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" 38 | "sigs.k8s.io/controller-runtime/pkg/event" 39 | "sigs.k8s.io/controller-runtime/pkg/handler" 40 | "sigs.k8s.io/controller-runtime/pkg/predicate" 41 | ) 42 | 43 | // FinetuneExperimentReconciler reconciles a FinetuneExperiment object 44 | type FinetuneExperimentReconciler struct { 45 | client.Client 46 | Scheme *runtime.Scheme 47 | Log logging.Logger 48 | } 49 | 50 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetuneexperiments,verbs=get;list;watch;create;update;patch;delete 51 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetuneexperiments/status,verbs=get;update;patch 52 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetuneexperiments/finalizers,verbs=update 53 | 54 | func (r *FinetuneExperimentReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { 55 | r.Log.Infof("Start reconcile finetuneExperiment: %s/%s,", req.Name, req.Namespace) 56 | finetuneExperiment := &finetunev1beta1.FinetuneExperiment{} 57 | if err := r.Get(ctx, req.NamespacedName, finetuneExperiment); err != nil { 58 | if errors.IsNotFound(err) { 59 | r.Log.Infof("FinetuneExperiment resource not found. 
Ignoring since object must be deleted.") 60 | return handlererr.HandlerErr(nil) 61 | } 62 | r.Log.Errorf("Failed get finetuneExperiment: %s/%s, Err: %v", req.Name, req.Namespace, err) 63 | return handlererr.HandlerErr(err) 64 | } 65 | 66 | if finetuneExperiment.GetDeletionTimestamp() != nil { 67 | if controllerutil.ContainsFinalizer(finetuneExperiment, finetunev1beta1.FinetuneGroupFinalizer) { 68 | // todo cleaner 69 | controllerutil.RemoveFinalizer(finetuneExperiment, finetunev1beta1.FinetuneGroupFinalizer) 70 | if err := r.Update(ctx, finetuneExperiment); err != nil { 71 | r.Log.Errorf("Remove finalizer failed: %s/%s, Err: %v", req.Name, req.Namespace, err) 72 | return handlererr.HandlerErr(err) 73 | } 74 | } 75 | return handlererr.HandlerErr(nil) 76 | } 77 | if !controllerutil.ContainsFinalizer(finetuneExperiment, finetunev1beta1.FinetuneGroupFinalizer) { 78 | controllerutil.AddFinalizer(finetuneExperiment, finetunev1beta1.FinetuneGroupFinalizer) 79 | err := r.Update(ctx, finetuneExperiment) 80 | if err != nil { 81 | r.Log.Errorf("Add finalizer failed: %s/%s, %v", req.Name, req.Namespace, err) 82 | return handlererr.HandlerErr(err) 83 | } 84 | } 85 | 86 | if finetuneExperiment.Spec.Pending && finetuneExperiment.Status.State != finetunev1beta1.FinetuneExperimentPending { 87 | for i := range finetuneExperiment.Spec.FinetuneJobs { 88 | finetuneJob := finetuneExperiment.Spec.FinetuneJobs[i] 89 | existFinetuneJob := &finetunev1beta1.FinetuneJob{} 90 | if err := r.Client.Get(ctx, types.NamespacedName{ 91 | Name: finetuneJob.Name, 92 | Namespace: finetuneExperiment.Namespace, 93 | }, existFinetuneJob); err != nil { 94 | if errors.IsNotFound(err) { 95 | r.Log.Infof("FinetuneJob %s/%s not found, continue", finetuneExperiment.Namespace, finetuneJob.Name) 96 | continue 97 | } 98 | return handlererr.HandlerErr(err) 99 | } 100 | if err := r.Client.Delete(ctx, existFinetuneJob); err != nil { 101 | return handlererr.HandlerErr(err) 102 | } 103 | } 104 | finetuneExperiment.Status.JobsStatus = make([]*finetunev1beta1.FinetuneJobStatusSetting, 0) 105 | finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentPending 106 | finetuneExperiment.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05") 107 | if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil { 108 | r.Log.Errorf("Update finetuneExperiment %s/%s status failed", finetuneExperiment.Name, finetuneExperiment.Namespace) 109 | return handlererr.HandlerErr(err) 110 | } 111 | return handlererr.HandlerErr(nil) 112 | } else if finetuneExperiment.Spec.Pending { 113 | return handlererr.HandlerErr(nil) 114 | } 115 | 116 | if finetuneExperiment.Status.State == "" { 117 | finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentProcessing 118 | if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil { 119 | r.Log.Errorf("Update finetuneExperiment %s/%s status failed", finetuneExperiment.Name, finetuneExperiment.Namespace) 120 | return handlererr.HandlerErr(err) 121 | } 122 | } 123 | for i := range finetuneExperiment.Spec.FinetuneJobs { 124 | finetuneJob := finetuneExperiment.Spec.FinetuneJobs[i] 125 | existFinetuneJob := &finetunev1beta1.FinetuneJob{} 126 | if err := r.Client.Get(ctx, types.NamespacedName{ 127 | Name: finetuneJob.Name, 128 | Namespace: finetuneExperiment.Namespace, 129 | }, existFinetuneJob); err != nil { 130 | if errors.IsNotFound(err) { 131 | finetuneJobInstance := &finetunev1beta1.FinetuneJob{} 132 | finetuneJobInstance.Spec = finetuneJob.Spec 133 | finetuneJobInstance.Name = 
finetuneJob.Name 134 | r.Log.Infof("finetuneJob Name: %s", finetuneJobInstance.Name) 135 | finetuneJobInstance.Namespace = finetuneExperiment.Namespace 136 | if err := ctrl.SetControllerReference(finetuneExperiment, finetuneJobInstance, r.Scheme); err != nil { 137 | r.Log.Errorf("SetControllerReference failed finetuneJob: %s/%s, owner finetuneExperiment: %s/%s, err: %v", 138 | finetuneJobInstance.Name, finetuneJobInstance.Namespace, finetuneExperiment.Name, finetuneExperiment.Namespace, err) 139 | return handlererr.HandlerErr(err) 140 | } 141 | if err := r.Client.Create(ctx, finetuneJobInstance); err != nil { 142 | if !errors.IsAlreadyExists(err) { 143 | r.Log.Errorf("Create finetuneJob %s/%s failed: %v", finetuneJobInstance.Name, finetuneJobInstance.Namespace, err) 144 | return handlererr.HandlerErr(err) 145 | } 146 | } 147 | } else { 148 | r.Log.Errorf("Get finetuneJob %s/%s failed: %v", finetuneJob.Name, finetuneExperiment.Namespace, err) 149 | return handlererr.HandlerErr(err) 150 | } 151 | } 152 | } 153 | 154 | success := true 155 | failed := true 156 | for i := range finetuneExperiment.Spec.FinetuneJobs { 157 | finetuneJobInstance := &finetunev1beta1.FinetuneJob{} 158 | if err := r.Client.Get(ctx, types.NamespacedName{Name: finetuneExperiment.Spec.FinetuneJobs[i].Name, Namespace: finetuneExperiment.Namespace}, finetuneJobInstance); err != nil { 159 | r.Log.Errorf("Get finetuneJob %s/%s failed, err: %v", finetuneExperiment.Spec.FinetuneJobs[i].Name, finetuneExperiment.Namespace, err) 160 | return handlererr.HandlerErr(err) 161 | } 162 | if finetuneJobInstance.Status.FinetuneStatus == nil { 163 | finetuneJobInstance.Status.FinetuneStatus = &finetunev1beta1.FinetuneStatus{ 164 | State: finetunev1beta1.FinetuneInit, 165 | } 166 | } 167 | 168 | if finetuneExperiment.Status.JobsStatus == nil { 169 | finetuneExperiment.Status.JobsStatus = make([]*finetunev1beta1.FinetuneJobStatusSetting, len(finetuneExperiment.Spec.FinetuneJobs)) 170 | } 171 | if finetuneExperiment.Status.JobsStatus[i] != nil { 172 | r.Log.Infof("Update finetuneExperiment %s/%s status", finetuneExperiment.Namespace, finetuneExperiment.Name) 173 | if !reflect.DeepEqual(finetuneExperiment.Status.JobsStatus[i].FinetuneJobStatus, finetuneJobInstance.Status) { 174 | finetuneExperiment.Status.JobsStatus[i] = &finetunev1beta1.FinetuneJobStatusSetting{ 175 | Name: finetuneJobInstance.Name, 176 | FinetuneJobStatus: finetuneJobInstance.Status, 177 | } 178 | } 179 | } else { 180 | r.Log.Infof("Set finetuneExperiment %s/%s status", finetuneExperiment.Namespace, finetuneExperiment.Name) 181 | finetuneExperiment.Status.JobsStatus[i] = &finetunev1beta1.FinetuneJobStatusSetting{ 182 | Name: finetuneJobInstance.Name, 183 | FinetuneJobStatus: finetunev1beta1.FinetuneJobStatus{ 184 | State: finetunev1beta1.FinetuneJobInit, 185 | FinetuneStatus: &finetunev1beta1.FinetuneStatus{ 186 | State: finetunev1beta1.FinetuneInit, 187 | }, 188 | }, 189 | } 190 | } 191 | if finetuneJobInstance.Status.State != finetunev1beta1.FinetuneJobSuccessful { 192 | success = false 193 | } 194 | if finetuneJobInstance.Status.State != finetunev1beta1.FinetuneJobFailed { 195 | failed = false 196 | } 197 | } 198 | 199 | if success { 200 | finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentSuccess 201 | jobs := finetuneExperiment.Status.JobsStatus 202 | sort.Slice(jobs, func(i, j int) bool { 203 | return util.ParseScore(jobs[i].FinetuneJobStatus.Result.Score) > util.ParseScore(jobs[j].FinetuneJobStatus.Result.Score) 204 | }) 205 | finetuneJobBestVersion 
:= &finetunev1beta1.FinetuneJob{} 206 | if err := r.Client.Get(ctx, types.NamespacedName{Name: jobs[0].Name, Namespace: finetuneExperiment.Namespace}, finetuneJobBestVersion); err != nil { 207 | r.Log.Errorf("Get finetuneJob %s/%s failed: %v", jobs[0].Name, finetuneExperiment.Namespace, err) 208 | } 209 | finetuneExperiment.Status.BestVersion = &finetunev1beta1.BestVersion{ 210 | Score: jobs[0].FinetuneJobStatus.Result.Score, 211 | Image: jobs[0].FinetuneJobStatus.Result.Image, 212 | LLM: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.LLM, 213 | Hyperparameter: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.Hyperparameter, 214 | Dataset: finetuneJobBestVersion.Spec.FineTune.FinetuneSpec.Dataset, 215 | } 216 | finetuneExperiment.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05") 217 | } else if failed { 218 | finetuneExperiment.Status.State = finetunev1beta1.FinetuneExperimentFailed 219 | finetuneExperiment.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05") 220 | } 221 | 222 | if err := r.Client.Status().Update(ctx, finetuneExperiment); err != nil { 223 | r.Log.Errorf("Update finetuneExperiment %s/%s status failed", finetuneExperiment.Namespace, finetuneExperiment.Name) 224 | return handlererr.HandlerErr(err) 225 | } 226 | return handlererr.HandlerErr(nil) 227 | } 228 | 229 | // SetupWithManager sets up the controller with the Manager. 230 | func (r *FinetuneExperimentReconciler) SetupWithManager(mgr ctrl.Manager) error { 231 | return ctrl.NewControllerManagedBy(mgr). 232 | For(&finetunev1beta1.FinetuneExperiment{}). 233 | Watches(&finetunev1beta1.FinetuneJob{}, 234 | handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneExperiment{}, handler.OnlyControllerOwner()), 235 | builder.WithPredicates(predicate.Funcs{ 236 | UpdateFunc: func(updateEvent event.UpdateEvent) bool { 237 | oldFinetuneJob := updateEvent.ObjectOld.(*finetunev1beta1.FinetuneJob) 238 | newFinetuneJob := updateEvent.ObjectNew.(*finetunev1beta1.FinetuneJob) 239 | if oldFinetuneJob.Status.State != newFinetuneJob.Status.State { 240 | r.Log.Infof("Get finetuneJob %s/%s update event oldStatus: %s, newStatus: %s", oldFinetuneJob.Namespace, oldFinetuneJob.Name, oldFinetuneJob.Status.State, newFinetuneJob.Status.State) 241 | return true 242 | } 243 | return false 244 | }, 245 | CreateFunc: func(createEvent event.CreateEvent) bool { 246 | finetuneJob := createEvent.Object.(*finetunev1beta1.FinetuneJob) 247 | r.Log.Infof("Get finetuneJob %s/%s create event, skip", finetuneJob.Name, finetuneJob.Namespace) 248 | return false 249 | }, 250 | })). 251 | WithOptions(controller.Options{ 252 | CacheSyncTimeout: 10 * time.Second, 253 | MaxConcurrentReconciles: 1}). 254 | Complete(r) 255 | } 256 | -------------------------------------------------------------------------------- /internal/controller/finetune/finetunejob_controller.go: -------------------------------------------------------------------------------- 1 | /* 2 | Copyright 2023. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | */ 16 | 17 | package finetune 18 | 19 | import ( 20 | "context" 21 | "fmt" 22 | "reflect" 23 | "time" 24 | 25 | "github.com/DataTunerX/datatunerx/pkg/config" 26 | "github.com/DataTunerX/datatunerx/pkg/domain/valueobject" 27 | "github.com/DataTunerX/datatunerx/pkg/util" 28 | "github.com/DataTunerX/datatunerx/pkg/util/generate" 29 | "github.com/DataTunerX/datatunerx/pkg/util/handlererr" 30 | corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1" 31 | extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1" 32 | finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1" 33 | "github.com/DataTunerX/utility-server/logging" 34 | "github.com/duke-git/lancet/v2/slice" 35 | rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" 36 | batchv1 "k8s.io/api/batch/v1" 37 | "k8s.io/apimachinery/pkg/api/errors" 38 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 39 | "k8s.io/apimachinery/pkg/runtime" 40 | "k8s.io/apimachinery/pkg/types" 41 | ctrl "sigs.k8s.io/controller-runtime" 42 | "sigs.k8s.io/controller-runtime/pkg/builder" 43 | "sigs.k8s.io/controller-runtime/pkg/client" 44 | "sigs.k8s.io/controller-runtime/pkg/controller" 45 | "sigs.k8s.io/controller-runtime/pkg/controller/controllerutil" 46 | "sigs.k8s.io/controller-runtime/pkg/event" 47 | "sigs.k8s.io/controller-runtime/pkg/handler" 48 | "sigs.k8s.io/controller-runtime/pkg/predicate" 49 | ) 50 | 51 | // FinetuneJobReconciler reconciles a FinetuneJob object 52 | type FinetuneJobReconciler struct { 53 | client.Client 54 | Scheme *runtime.Scheme 55 | Log logging.Logger 56 | } 57 | 58 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunejobs,verbs=get;list;watch;create;update;patch;delete 59 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunejobs/status,verbs=get;update;patch 60 | //+kubebuilder:rbac:groups=finetune.datatunerx.io,resources=finetunejobs/finalizers,verbs=update 61 | 62 | // Reconcile is part of the main kubernetes reconciliation loop which aims to 63 | // move the current state of the cluster closer to the desired state. 64 | // TODO(user): Modify the Reconcile function to compare the state specified by 65 | // the FinetuneJob object against the actual cluster state, and then 66 | // perform operations to make the cluster state reflect the state specified by 67 | // the user. 68 | // 69 | // For more details, check Reconcile and its Result here: 70 | // - https://pkg.go.dev/sigs.k8s.io/controller-runtime@v0.14.1/pkg/reconcile 71 | func (r *FinetuneJobReconciler) Reconcile(ctx context.Context, req ctrl.Request) (ctrl.Result, error) { 72 | // todo(tigerK) This reconcile contains a lot of tested strings that need to be optimised after running through the process 73 | r.Log.Infof("Start reconcile finetuneJob: %s/%s", req.Namespace, req.Name) 74 | finetuneJob := &finetunev1beta1.FinetuneJob{} 75 | if err := r.Get(ctx, req.NamespacedName, finetuneJob); err != nil { 76 | if errors.IsNotFound(err) { 77 | r.Log.Infof("FinetuneJob %s resource not found. 
Ignoring since object must be deleted.", req.NamespacedName) 78 | return handlererr.HandlerErr(nil) 79 | } 80 | r.Log.Errorf("Failed get finetuneJob: %s/%s, Err: %v", req.Namespace, req.Name, err) 81 | return handlererr.HandlerErr(err) 82 | } 83 | if finetuneJob.GetDeletionTimestamp() != nil { 84 | r.Log.Infof("Delete finetuneJob: %s/%s", finetuneJob.Namespace, finetuneJob.Name) 85 | if controllerutil.ContainsFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer) { 86 | if err := r.reconcileCleaner(ctx, finetuneJob); err != nil { 87 | r.Log.Errorf("cleaner failed: %s/%s, Err: %v", finetuneJob.Namespace, finetuneJob.Name, err) 88 | return handlererr.HandlerErr(err) 89 | } 90 | controllerutil.RemoveFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer) 91 | if err := r.Update(ctx, finetuneJob); err != nil { 92 | r.Log.Errorf("Remove finalizer failed: %s/%s, Err: %v", finetuneJob.Namespace, finetuneJob.Name, err) 93 | return handlererr.HandlerErr(err) 94 | } 95 | } 96 | return handlererr.HandlerErr(nil) 97 | } 98 | if !controllerutil.ContainsFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer) { 99 | controllerutil.AddFinalizer(finetuneJob, finetunev1beta1.FinetuneGroupFinalizer) 100 | err := r.Update(ctx, finetuneJob) 101 | if err != nil { 102 | r.Log.Errorf("Add finalizer failed: %s/%s, %v", finetuneJob.Namespace, finetuneJob.Name, err) 103 | return handlererr.HandlerErr(err) 104 | } 105 | } 106 | 107 | if finetuneJob.Status.State == "" { 108 | finetuneJob.Status.State = finetunev1beta1.FinetuneJobInit 109 | if err := r.Client.Status().Update(ctx, finetuneJob); err != nil { 110 | r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJob.Namespace, finetuneJob.Name, err) 111 | return handlererr.HandlerErr(err) 112 | } 113 | } 114 | 115 | if err := r.reconcilePreCondition(ctx, finetuneJob); err != nil { 116 | return handlererr.HandlerErr(err) 117 | } 118 | 119 | existFinetune, err := r.reconcileFinetuneSend(ctx, finetuneJob) 120 | if err != nil { 121 | return handlererr.HandlerErr(err) 122 | } 123 | 124 | if err := r.reconcileByFinetuneStatus(ctx, existFinetune, finetuneJob); err != nil { 125 | return handlererr.HandlerErr(err) 126 | } 127 | 128 | if err := r.reconcileByJobStatus(ctx, finetuneJob, existFinetune); err != nil { 129 | return handlererr.HandlerErr(err) 130 | } 131 | 132 | if err := r.reconcileByRayServiceStatus(ctx, finetuneJob); err != nil { 133 | return handlererr.HandlerErr(err) 134 | } 135 | 136 | if err := r.reconcileByScoringStatus(ctx, finetuneJob); err != nil { 137 | return handlererr.HandlerErr(err) 138 | } 139 | 140 | // Phase IIII of the fine-tuning exercise. 141 | // Check finetune cr status, if finetune cr status is SUCCESSFUL, start next 142 | return handlererr.HandlerErr(nil) 143 | } 144 | 145 | // SetupWithManager sets up the controller with the Manager. 146 | func (r *FinetuneJobReconciler) SetupWithManager(mgr ctrl.Manager) error { 147 | return ctrl.NewControllerManagedBy(mgr). 
148 | For(&finetunev1beta1.FinetuneJob{}, builder.WithPredicates(predicate.Funcs{ 149 | UpdateFunc: func(updateEvent event.UpdateEvent) bool { 150 | oldFinetuneJob := updateEvent.ObjectOld.(*finetunev1beta1.FinetuneJob) 151 | newFinetuneJob := updateEvent.ObjectNew.(*finetunev1beta1.FinetuneJob) 152 | if !reflect.DeepEqual(oldFinetuneJob.Spec, newFinetuneJob.Spec) || 153 | !newFinetuneJob.GetDeletionTimestamp().IsZero() { 154 | return true 155 | } 156 | return false 157 | }, 158 | DeleteFunc: func(deleteEvent event.DeleteEvent) bool { 159 | return false 160 | }, 161 | })). 162 | Watches(&finetunev1beta1.Finetune{}, 163 | handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()), 164 | builder.WithPredicates(predicate.Funcs{ 165 | UpdateFunc: func(updateEvent event.UpdateEvent) bool { 166 | oldFinetune := updateEvent.ObjectOld.(*finetunev1beta1.Finetune) 167 | newFinetune := updateEvent.ObjectNew.(*finetunev1beta1.Finetune) 168 | if oldFinetune.Status.State != newFinetune.Status.State { 169 | r.Log.Infof("Get finetune %s/%s update event oldStatus: %s, newStatus: %s", oldFinetune.Name, oldFinetune.Namespace, oldFinetune.Status.State, newFinetune.Status.State) 170 | return true 171 | } 172 | return false 173 | }, 174 | CreateFunc: func(createEvent event.CreateEvent) bool { 175 | finetune := createEvent.Object.(*finetunev1beta1.Finetune) 176 | r.Log.Infof("Get finetune %s/%s create event, skip", finetune.Name, finetune.Namespace) 177 | return false 178 | }, 179 | })). 180 | Watches(&batchv1.Job{}, 181 | handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()), 182 | builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool { 183 | job := object.(*batchv1.Job) 184 | if job.Status.CompletionTime != nil { 185 | return true 186 | } 187 | return false 188 | }))). 189 | Watches(&rayv1.RayService{}, 190 | handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()), 191 | builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool { 192 | rayService := object.(*rayv1.RayService) 193 | if rayService.Status.ServiceStatus == rayv1.Running { 194 | return true 195 | } 196 | return false 197 | }))). 198 | Watches(&extensionv1beta1.Scoring{}, 199 | handler.EnqueueRequestForOwner(mgr.GetScheme(), mgr.GetRESTMapper(), &finetunev1beta1.FinetuneJob{}, handler.OnlyControllerOwner()), 200 | builder.WithPredicates(predicate.NewPredicateFuncs(func(object client.Object) bool { 201 | scoring := object.(*extensionv1beta1.Scoring) 202 | if scoring.Status.Score != nil { 203 | return true 204 | } 205 | return false 206 | }))). 207 | WithOptions(controller.Options{ 208 | CacheSyncTimeout: 10 * time.Second, 209 | MaxConcurrentReconciles: 1}). 
210 | Complete(r) 211 | } 212 | 213 | func (r *FinetuneJobReconciler) reconcilePreCondition(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error { 214 | preCondition := make(map[string]client.Object, 3) 215 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.LLM] = &corev1beta1.LLM{} 216 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Hyperparameter.HyperparameterRef] = &corev1beta1.Hyperparameter{} 217 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Dataset] = &extensionv1beta1.Dataset{} 218 | for name, obj := range preCondition { 219 | if err := r.Get(ctx, types.NamespacedName{Name: name, Namespace: finetuneJob.Namespace}, obj); err != nil { 220 | r.Log.Errorf("Get %s: %s/%s failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err) 221 | return err 222 | } 223 | switch obj.(type) { 224 | case *corev1beta1.LLM: 225 | llm := obj.(*corev1beta1.LLM) 226 | if llm.Status.ReferenceFinetuneName == nil { 227 | llm.Status.ReferenceFinetuneName = make([]string, 0) 228 | } 229 | llm.Status.ReferenceFinetuneName = slice.AppendIfAbsent(llm.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name) 230 | if err := r.Client.Status().Update(ctx, llm); err != nil { 231 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err) 232 | return err 233 | } 234 | case *extensionv1beta1.Dataset: 235 | dataset := obj.(*extensionv1beta1.Dataset) 236 | if dataset.Status.ReferenceFinetuneName == nil { 237 | dataset.Status.ReferenceFinetuneName = make([]string, 0) 238 | } 239 | dataset.Status.ReferenceFinetuneName = slice.AppendIfAbsent(dataset.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name) 240 | if err := r.Client.Status().Update(ctx, dataset); err != nil { 241 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err) 242 | return err 243 | } 244 | case *corev1beta1.Hyperparameter: 245 | hyperparameter := obj.(*corev1beta1.Hyperparameter) 246 | if hyperparameter.Status.ReferenceFinetuneName == nil { 247 | hyperparameter.Status.ReferenceFinetuneName = make([]string, 0) 248 | } 249 | hyperparameter.Status.ReferenceFinetuneName = slice.AppendIfAbsent(hyperparameter.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name) 250 | if err := r.Client.Status().Update(ctx, hyperparameter); err != nil { 251 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err) 252 | return err 253 | } 254 | } 255 | } 256 | return nil 257 | } 258 | 259 | func (r *FinetuneJobReconciler) reconcileFinetuneSend(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) (*finetunev1beta1.Finetune, error) { 260 | 261 | specFinetuneInstance := generate.GenerateFinetune(finetuneJob) 262 | existFinetuneInstance := &finetunev1beta1.Finetune{} 263 | if err := r.Get(ctx, types.NamespacedName{ 264 | Name: specFinetuneInstance.Name, 265 | Namespace: specFinetuneInstance.Namespace, 266 | }, existFinetuneInstance); err != nil { 267 | if errors.IsNotFound(err) { 268 | r.Log.Infof("Starting to send down finetune: %s/%s.", specFinetuneInstance.Namespace, specFinetuneInstance.Name) 269 | if err := ctrl.SetControllerReference(finetuneJob, specFinetuneInstance, r.Scheme); err != nil { 270 | r.Log.Errorf("For %s/%s set owner %s/%s failed: %v", specFinetuneInstance.Namespace, specFinetuneInstance.Name, finetuneJob.Namespace, finetuneJob.Name, err) 271 | return nil, err 272 | } 273 | if err := r.Client.Create(ctx, 
specFinetuneInstance); err != nil {
274 | if !errors.IsAlreadyExists(err) {
275 | r.Log.Errorf("Create finetune %s/%s failed: %v", specFinetuneInstance.Namespace, specFinetuneInstance.Name, err)
276 | return nil, err
277 | }
278 | }
279 | return nil, valueobject.ErrRecalibrate // requeue until the newly created finetune is observable
280 | }
281 | }
282 | return existFinetuneInstance, nil
283 | }
284 | 
285 | func (r *FinetuneJobReconciler) reconcileByFinetuneStatus(ctx context.Context, finetuneInstance *finetunev1beta1.Finetune, finetuneJobInstance *finetunev1beta1.FinetuneJob) error {
286 | 
287 | if finetuneInstance.Status.State == finetunev1beta1.FinetuneInit || finetuneInstance.Status.State == finetunev1beta1.FinetuneRunning {
288 | r.Log.Infof("Update finetuneJob %s/%s status to %s.", finetuneJobInstance.Namespace, finetuneJobInstance.Name, finetunev1beta1.FinetuneJobFinetune)
289 | finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobFinetune
290 | finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
291 | if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
292 | r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneJobInstance.Name, err)
293 | return err
294 | }
295 | }
296 | 
297 | if finetuneInstance.Status.State == finetunev1beta1.FinetuneSuccessful && finetuneJobInstance.Status.State != finetunev1beta1.FinetuneJobBuildImage {
298 | if finetuneInstance.Status.LLMCheckpoint == nil {
299 | r.Log.Infof("Finetune %s/%s status has no llmCheckpointRef", finetuneInstance.Namespace, finetuneInstance.Name)
300 | return fmt.Errorf("finetune %s/%s status not set llmCheckpointRef", finetuneInstance.Namespace, finetuneInstance.Name)
301 | }
302 | 
303 | llmCheckpoint := &corev1beta1.LLMCheckpoint{}
304 | if err := r.Get(ctx, types.NamespacedName{Name: finetuneInstance.Status.LLMCheckpoint.LLMCheckpointRef, Namespace: finetuneJobInstance.Namespace}, llmCheckpoint); err != nil {
305 | r.Log.Errorf("Get llmCheckpoint %s/%s failed, err: %v", finetuneJobInstance.Namespace, finetuneInstance.Status.LLMCheckpoint.LLMCheckpointRef, err)
306 | return err
307 | }
308 | // Build the llmCheckpoint image via a Kubernetes Job.
309 | 
310 | imageName := fmt.Sprintf("ray271-llama2-7b-finetune-checkpoint-%s", finetuneJobInstance.Name) // image name is currently hardcoded for the llama2-7b pipeline
311 | imageTag := time.Now().Format("20060102")
312 | checkPointFilePath := finetuneInstance.Status.LLMCheckpoint.CheckpointPath
313 | checkPointFilePath = util.RemoveBucketName(checkPointFilePath, config.GetS3Bucket())
314 | buildImageJob := generate.GenerateBuildImageJob(checkPointFilePath, imageName, imageTag, finetuneJobInstance)
315 | if err := ctrl.SetControllerReference(finetuneJobInstance, buildImageJob, r.Scheme); err != nil {
316 | r.Log.Errorf("Set owner failed: %v", err)
317 | return err
318 | }
319 | if err := r.Client.Get(ctx, types.NamespacedName{Name: buildImageJob.Name, Namespace: buildImageJob.Namespace}, buildImageJob); err != nil {
320 | if errors.IsNotFound(err) {
321 | if err := r.Client.Create(ctx, buildImageJob); err != nil {
322 | r.Log.Errorf("Create job %s/%s failed, err: %v", buildImageJob.Namespace, buildImageJob.Name, err)
323 | return err
324 | }
325 | }
326 | }
327 | 
328 | llmCheckpoint.Spec.CheckpointImage = &corev1beta1.CheckpointImage{}
329 | llmImage := fmt.Sprintf("%s/%s/%s:%s", config.GetRegistryUrl(), config.GetRepositoryName(), imageName, imageTag)
330 | llmCheckpoint.Spec.CheckpointImage.Name = &llmImage
331 | llmCheckpoint.Spec.CheckpointImage.CheckPointPath = fmt.Sprintf("/checkpoint/%s", checkPointFilePath)
332 | llmCheckpoint.Spec.CheckpointImage.LLMPath = llmCheckpoint.Spec.Image.Path
333 | if err := r.Client.Update(ctx, llmCheckpoint); err != nil {
334 | r.Log.Errorf("Update llmCheckpoint %s/%s failed: %v", llmCheckpoint.Namespace, llmCheckpoint.Name, err)
335 | return err
336 | }
337 | 
338 | finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobBuildImage
339 | finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
340 | if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
341 | r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneJobInstance.Name, err)
342 | return err
343 | }
344 | }
345 | 
346 | if finetuneInstance.Status.State == finetunev1beta1.FinetuneFailed {
347 | finetuneJobInstance.Status.State = finetunev1beta1.FinetuneJobFailed
348 | finetuneJobInstance.Status.FinetuneStatus = &finetuneInstance.Status
349 | if err := r.Client.Status().Update(ctx, finetuneJobInstance); err != nil {
350 | r.Log.Errorf("Update finetuneJob %s/%s status failed: %v", finetuneJobInstance.Namespace, finetuneJobInstance.Name, err)
351 | return err
352 | }
353 | }
354 | return nil
355 | }
356 | 
357 | func (r *FinetuneJobReconciler) reconcileByJobStatus(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob, finetune *finetunev1beta1.Finetune) error {
358 | 
359 | jobName := fmt.Sprintf("%s-buildimage", finetuneJob.Name)
360 | buildImageJob := &batchv1.Job{}
361 | if err := r.Get(ctx, types.NamespacedName{Namespace: finetuneJob.Namespace, Name: jobName}, buildImageJob); err != nil {
362 | if errors.IsNotFound(err) {
363 | r.Log.Infof("Job %s/%s not found, waiting", finetuneJob.Namespace, jobName)
364 | return valueobject.ErrRecalibrate
365 | }
366 | return err
367 | }
368 | 
369 | if buildImageJob.Status.CompletionTime != nil && finetuneJob.Status.State != finetunev1beta1.FinetuneJobServe {
370 | llmCheckpoint := &corev1beta1.LLMCheckpoint{}
371 | if err := r.Get(ctx, types.NamespacedName{Name: finetune.Status.LLMCheckpoint.LLMCheckpointRef, Namespace: finetuneJob.Namespace}, llmCheckpoint); err != nil {
372 | r.Log.Errorf("Get llmCheckpoint 
%s/%s failed, err: %v", finetuneJob.Namespace, finetune.Status.LLMCheckpoint.LLMCheckpointRef, err)
373 | return err
374 | }
375 | r.Log.Infof("Build image succeeded, start updating llmCheckpoint %s/%s", llmCheckpoint.Namespace, llmCheckpoint.Name)
376 | // todo(tigerK) update llmCheckpoint spec.checkpointimage
377 | r.Log.Infof("Update llmCheckpoint status successful, start creating the serving RayService")
378 | rayServiceName := finetuneJob.Name
379 | importPath := "inference.deployment"
380 | runtimeEnv := "working_dir: file:///home/inference/inference.zip"
381 | deploymentName := "LlamaDeployment"
382 | rayService := generate.GenerateRayService(rayServiceName,
383 | finetuneJob.Namespace, importPath, runtimeEnv, deploymentName,
384 | int32(1), float64(1), finetuneJob, llmCheckpoint)
385 | if err := ctrl.SetControllerReference(finetuneJob, rayService, r.Scheme); err != nil {
386 | r.Log.Errorf("Set owner failed: %v", err)
387 | return err
388 | }
389 | 
390 | if err := r.Client.Get(ctx, types.NamespacedName{Name: rayServiceName, Namespace: finetuneJob.Namespace}, rayService); err != nil {
391 | if errors.IsNotFound(err) {
392 | if err := r.Create(ctx, rayService); err != nil {
393 | r.Log.Errorf("Create rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
394 | return err
395 | }
396 | }
397 | }
398 | 
399 | finetuneJob.Status.State = finetunev1beta1.FinetuneJobServe
400 | finetuneJob.Status.FinetuneStatus = &finetune.Status
401 | finetuneJob.Status.Result = &finetunev1beta1.FinetuneJobResult{
402 | ModelExportResult: true,
403 | Image: *llmCheckpoint.Spec.CheckpointImage.Name,
404 | }
405 | if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
406 | r.Log.Errorf("Update finetuneJob status failed: %v", err)
407 | return err
408 | }
409 | }
410 | return nil
411 | }
412 | 
413 | func (r *FinetuneJobReconciler) reconcileByRayServiceStatus(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error {
414 | rayServiceName := finetuneJob.Name
415 | rayService := &rayv1.RayService{}
416 | if err := r.Get(ctx, types.NamespacedName{
417 | Name: rayServiceName,
418 | Namespace: finetuneJob.Namespace,
419 | }, rayService); err != nil {
420 | r.Log.Errorf("Get rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
421 | return err
422 | }
423 | if finetuneJob.Status.State == finetunev1beta1.FinetuneJobServe && rayService.Status.ServiceStatus == rayv1.Running {
424 | if rayService.Status.ActiveServiceStatus.Applications["default"].Deployments["LlamaDeployment"].Status == "HEALTHY" {
425 | // todo(tigerK) no time for optimisation
426 | //serveNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["serve"]
427 | //dashboardNodePort := rayService.Status.ActiveServiceStatus.RayClusterStatus.Endpoints["dashboard"]
428 | finetuneJob.Status.Result.Serve = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8000")
429 | finetuneJob.Status.Result.Dashboard = fmt.Sprintf("%s.%s.svc:%s", finetuneJob.Name, finetuneJob.Namespace, "8080")
430 | } else {
431 | return valueobject.ErrRecalibrate
432 | }
433 | inferencePath := fmt.Sprintf("http://%s/chat/completions", finetuneJob.Status.Result.Serve)
434 | if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
435 | r.Log.Errorf("Update finetuneJob status failed: %v", err)
436 | return err
437 | }
438 | scoringName := fmt.Sprintf("%s-scoring", finetuneJob.Name)
439 | if finetuneJob.Spec.ScoringPluginConfig == nil { // no plugin configured: fall back to the built-in scoring
440 | 
scoring := generate.GenerateBuiltInScoring(scoringName, finetuneJob.Namespace, inferencePath)
441 | if err := ctrl.SetControllerReference(finetuneJob, scoring, r.Scheme); err != nil {
442 | r.Log.Errorf("Set owner failed: %v", err)
443 | return err
444 | }
445 | if err := r.Create(ctx, scoring); err != nil {
446 | if !errors.IsAlreadyExists(err) {
447 | r.Log.Errorf("Create scoring %s/%s failed: %v", finetuneJob.Namespace, scoringName, err)
448 | return err
449 | }
450 | }
451 | return nil
452 | }
453 | scoring := generate.GeneratePluginScoring(scoringName, finetuneJob.Namespace, finetuneJob.Spec.ScoringPluginConfig.Name, finetuneJob.Spec.ScoringPluginConfig.Parameters, inferencePath)
454 | if err := ctrl.SetControllerReference(finetuneJob, scoring, r.Scheme); err != nil {
455 | r.Log.Errorf("Set owner failed: %v", err)
456 | return err
457 | }
458 | if err := r.Create(ctx, scoring); err != nil {
459 | if !errors.IsAlreadyExists(err) {
460 | r.Log.Errorf("Create scoring %s/%s failed: %v", finetuneJob.Namespace, scoringName, err)
461 | return err
462 | }
463 | }
464 | }
465 | return nil
466 | }
467 | 
468 | func (r *FinetuneJobReconciler) reconcileByScoringStatus(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error {
469 | 
470 | scoringName := fmt.Sprintf("%s-scoring", finetuneJob.Name)
471 | scoring := &extensionv1beta1.Scoring{}
472 | if err := r.Get(ctx, types.NamespacedName{
473 | Name: scoringName,
474 | Namespace: finetuneJob.Namespace,
475 | }, scoring); err != nil {
476 | if errors.IsNotFound(err) {
477 | r.Log.Infof("Scoring %s/%s not found, err: %v", finetuneJob.Namespace, scoringName, err)
478 | return valueobject.ErrRecalibrate
479 | }
480 | r.Log.Errorf("Get scoring %s/%s failed: %v", finetuneJob.Namespace, scoringName, err)
481 | return err
482 | }
483 | 
484 | // todo(tigerK) get scoring result, update finetuneJob status
485 | if scoring.Status.Score != nil {
486 | finetuneJob.Status.State = finetunev1beta1.FinetuneJobSuccessful
487 | finetuneJob.Status.Result.Score = *scoring.Status.Score
488 | finetuneJob.Status.Stats = metav1.Now().Format("2006-01-02 15:04:05")
489 | if err := r.Client.Status().Update(ctx, finetuneJob); err != nil {
490 | r.Log.Errorf("Update finetuneJob status failed: %v", err)
491 | return err
492 | }
493 | rayServiceName := finetuneJob.Name // the serving RayService is torn down once the score is recorded
494 | rayService := &rayv1.RayService{}
495 | if err := r.Get(ctx, types.NamespacedName{
496 | Name: rayServiceName,
497 | Namespace: finetuneJob.Namespace,
498 | }, rayService); err != nil {
499 | if errors.IsNotFound(err) {
500 | return nil
501 | }
502 | r.Log.Errorf("Get rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
503 | return err
504 | }
505 | if err := r.Delete(ctx, rayService); err != nil {
506 | r.Log.Errorf("Delete rayService %s/%s failed: %v", finetuneJob.Namespace, rayServiceName, err)
507 | return err
508 | }
509 | }
510 | return nil
511 | }
512 | 
513 | func (r *FinetuneJobReconciler) reconcileCleaner(ctx context.Context, finetuneJob *finetunev1beta1.FinetuneJob) error {
514 | preCondition := make(map[string]client.Object, 3) // remove this finetune from the referrer lists recorded in reconcilePreCondition
515 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.LLM] = &corev1beta1.LLM{}
516 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Hyperparameter.HyperparameterRef] = &corev1beta1.Hyperparameter{}
517 | preCondition[finetuneJob.Spec.FineTune.FinetuneSpec.Dataset] = &extensionv1beta1.Dataset{}
518 | for name, obj := range preCondition {
519 | if err := r.Get(ctx, types.NamespacedName{Name: name, Namespace: 
finetuneJob.Namespace}, obj); err != nil {
520 | r.Log.Errorf("Get %s: %s/%s failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
521 | return err
522 | }
523 | switch o := obj.(type) {
524 | case *corev1beta1.LLM:
525 | llm := o
526 | if llm.Status.ReferenceFinetuneName == nil {
527 | continue
528 | }
529 | result := slice.IndexOf(llm.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
530 | if result >= 0 { llm.Status.ReferenceFinetuneName = slice.DeleteAt(llm.Status.ReferenceFinetuneName, result) } // guard: skip the delete when the name is absent
531 | if err := r.Client.Status().Update(ctx, llm); err != nil {
532 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
533 | return err
534 | }
535 | case *extensionv1beta1.Dataset:
536 | dataset := o
537 | if dataset.Status.ReferenceFinetuneName == nil {
538 | continue
539 | }
540 | result := slice.IndexOf(dataset.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
541 | if result >= 0 { dataset.Status.ReferenceFinetuneName = slice.DeleteAt(dataset.Status.ReferenceFinetuneName, result) } // guard: skip the delete when the name is absent
542 | if err := r.Client.Status().Update(ctx, dataset); err != nil {
543 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
544 | return err
545 | }
546 | case *corev1beta1.Hyperparameter:
547 | hyperparameter := o
548 | if hyperparameter.Status.ReferenceFinetuneName == nil {
549 | continue
550 | }
551 | result := slice.IndexOf(hyperparameter.Status.ReferenceFinetuneName, finetuneJob.Spec.FineTune.Name)
552 | if result >= 0 { hyperparameter.Status.ReferenceFinetuneName = slice.DeleteAt(hyperparameter.Status.ReferenceFinetuneName, result) } // guard: skip the delete when the name is absent
553 | if err := r.Client.Status().Update(ctx, hyperparameter); err != nil {
554 | r.Log.Errorf("update %s: %s/%s status failed, err: %v", obj.GetObjectKind(), finetuneJob.Namespace, name, err)
555 | return err
556 | }
557 | }
558 | }
559 | return nil
560 | }
561 | 
-------------------------------------------------------------------------------- /main.go: --------------------------------------------------------------------------------
1 | /*
2 | Copyright 2023.
3 | 
4 | Licensed under the Apache License, Version 2.0 (the "License");
5 | you may not use this file except in compliance with the License.
6 | You may obtain a copy of the License at
7 | 
8 | http://www.apache.org/licenses/LICENSE-2.0
9 | 
10 | Unless required by applicable law or agreed to in writing, software
11 | distributed under the License is distributed on an "AS IS" BASIS,
12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | See the License for the specific language governing permissions and
14 | limitations under the License. 
15 | */
16 | 
17 | package main
18 | 
19 | import (
20 | "os"
21 | 
22 | "github.com/DataTunerX/datatunerx/cmd/controller-manager/app"
23 | "github.com/DataTunerX/datatunerx/pkg/config"
24 | "github.com/DataTunerX/utility-server/logging"
25 | ctrl "sigs.k8s.io/controller-runtime"
26 | )
27 | 
28 | func main() {
29 | logging.NewZapLogger(config.GetLevel())
30 | controllerManager, err := app.NewControllerManager()
31 | if err != nil {
32 | logging.ZLogger.Errorf("Create controller manager failed: %v", err) // log the cause instead of exiting silently
33 | os.Exit(1)
34 | }
35 | logging.ZLogger.Infof("Start controller manager")
36 | if err := controllerManager.Start(ctrl.SetupSignalHandler()); err != nil {
37 | logging.ZLogger.Errorf("Problem running manager: %v", err)
38 | os.Exit(1)
39 | }
40 | }
-------------------------------------------------------------------------------- /pkg/config/config.go: --------------------------------------------------------------------------------
1 | package config
2 | 
3 | import "github.com/spf13/viper"
4 | 
5 | var config *viper.Viper
6 | 
7 | func init() {
8 | config = viper.New()
9 | config.AutomaticEnv()
10 | config.BindEnv("level", "LOG_LEVEL")
11 | config.SetDefault("level", "debug")
12 | config.BindEnv("endpoint", "S3_ENDPOINT")
13 | config.BindEnv("accessKey", "S3_ACCESSKEYID")
14 | config.BindEnv("secretkey", "S3_SECRETACCESSKEY")
15 | config.BindEnv("bucket", "S3_BUCKET")
16 | config.BindEnv("secure", "S3_SECURE")
17 | config.BindEnv("registryUrl", "REGISTRY_URL")
18 | config.BindEnv("repositoryName", "REPOSITORY_NAME")
19 | config.BindEnv("userName", "USERNAME")
20 | config.BindEnv("password", "PASSWORD")
21 | config.BindEnv("mountPath", "MOUNT_PATH")
22 | config.BindEnv("baseImage", "BASE_IMAGE")
23 | config.BindEnv("llmUrl", "LLM_URL")
24 | config.BindEnv("metricsExportAddress", "METRICS_EXPORT_ADDRESS")
25 | config.BindEnv("storagePath", "STORAGE_PATH")
26 | config.SetDefault("llmUrl", "/tmp/llama2-7b/")
27 | }
28 | 
29 | func GetS3Endpoint() string {
30 | return config.GetString("endpoint")
31 | }
32 | 
33 | func GetS3AccesskeyId() string {
34 | return config.GetString("accessKey")
35 | }
36 | 
37 | func GetS3ESecretAccessKey() string {
38 | return config.GetString("secretkey")
39 | }
40 | 
41 | func GetS3Bucket() string {
42 | return config.GetString("bucket")
43 | }
44 | 
45 | func GetSecure() string {
46 | return config.GetString("secure")
47 | }
48 | 
49 | func GetLevel() string {
50 | return config.GetString("level")
51 | }
52 | 
53 | func GetUserName() string {
54 | return config.GetString("userName")
55 | }
56 | 
57 | func GetBaseImage() string {
58 | return config.GetString("baseImage")
59 | }
60 | 
61 | func GetPassword() string {
62 | return config.GetString("password")
63 | }
64 | 
65 | func GetRegistryUrl() string {
66 | return config.GetString("registryUrl")
67 | }
68 | 
69 | func GetRepositoryName() string {
70 | return config.GetString("repositoryName")
71 | }
72 | 
73 | func GetMountPath() string {
74 | return config.GetString("mountPath")
75 | }
76 | 
77 | func GetLLMUrl() string {
78 | return config.GetString("llmUrl")
79 | }
80 | 
81 | func GetStoragePath() string {
82 | return config.GetString("storagePath")
83 | }
84 | 
85 | func GetMetricsExportAddress() string {
86 | return config.GetString("metricsExportAddress")
87 | }
-------------------------------------------------------------------------------- /pkg/domain/valueobject/err.go: --------------------------------------------------------------------------------
1 | package valueobject
2 | 
3 | import "errors"
4 | 
5 | var (
6 | ErrRecalibrate = errors.New("waiting for dependent resources")
7 | )
8 | 
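For orientation, a minimal, self-contained sketch of how this `ErrRecalibrate` sentinel is meant to be consumed: a reconcile step returns it while a dependency is still missing, and the caller converts it into a delayed requeue instead of a failure (the concrete mapping lives in `pkg/util/handlererr/handler.go` further below). All names in the sketch are illustrative, not repository code:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// errRecalibrate mirrors valueobject.ErrRecalibrate above.
var errRecalibrate = errors.New("waiting for dependent resources")

// fetchDependency stands in for a client.Get that finds a resource missing.
func fetchDependency(ready bool) error {
	if !ready {
		return errRecalibrate
	}
	return nil
}

func main() {
	if err := fetchDependency(false); errors.Is(err, errRecalibrate) {
		// A controller would return ctrl.Result{RequeueAfter: 10 * time.Second}
		// here rather than treating the miss as a hard error.
		fmt.Printf("dependency not ready, requeue after %v\n", 10*time.Second)
	}
}
```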
-------------------------------------------------------------------------------- /pkg/events/events.go: -------------------------------------------------------------------------------- 1 | package events 2 | 3 | const ( 4 | EventReasonCheckDependentResourcesFailed = "CheckDependentResourcesFailed" 5 | EventReasonCheckDependentResourcesSucceed = "CheckDependentResourcesSucceed" 6 | ) 7 | -------------------------------------------------------------------------------- /pkg/util/generate/generate.go: -------------------------------------------------------------------------------- 1 | package generate 2 | 3 | import ( 4 | "fmt" 5 | 6 | "github.com/DataTunerX/datatunerx/pkg/config" 7 | "github.com/DataTunerX/datatunerx/pkg/util/label" 8 | corev1beta1 "github.com/DataTunerX/meta-server/api/core/v1beta1" 9 | extensionv1beta1 "github.com/DataTunerX/meta-server/api/extension/v1beta1" 10 | finetunev1beta1 "github.com/DataTunerX/meta-server/api/finetune/v1beta1" 11 | rayv1 "github.com/ray-project/kuberay/ray-operator/apis/ray/v1" 12 | batchv1 "k8s.io/api/batch/v1" 13 | corev1 "k8s.io/api/core/v1" 14 | "k8s.io/apimachinery/pkg/api/resource" 15 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 16 | "k8s.io/apimachinery/pkg/util/intstr" 17 | ) 18 | 19 | const ( 20 | // todo llm file path 21 | defaultFinetuneCodePath = "/tmp/llama2-7b/" 22 | 23 | defaultBuildImageJobContainerName = "imagebuild" 24 | defaultBuildImageJobImage = "release.daocloud.io/datatunerx/buildimage:v0.0.1" 25 | ) 26 | 27 | func GenerateFinetune(finetuneJob *finetunev1beta1.FinetuneJob) *finetunev1beta1.Finetune { 28 | finetuneLabel := label.GenerateInstanceLabel(finetuneJob.Name) 29 | finetune := &finetunev1beta1.Finetune{ 30 | ObjectMeta: metav1.ObjectMeta{ 31 | Name: finetuneJob.Spec.FineTune.Name, 32 | Namespace: finetuneJob.Namespace, 33 | Labels: finetuneLabel, 34 | }, 35 | Spec: finetunev1beta1.FinetuneSpec{ 36 | Dataset: finetuneJob.Spec.FineTune.FinetuneSpec.Dataset, 37 | LLM: finetuneJob.Spec.FineTune.FinetuneSpec.LLM, 38 | Hyperparameter: finetuneJob.Spec.FineTune.FinetuneSpec.Hyperparameter, 39 | Image: finetuneJob.Spec.FineTune.FinetuneSpec.Image, 40 | Node: finetuneJob.Spec.FineTune.FinetuneSpec.Node, 41 | }, 42 | } 43 | if finetuneJob.Spec.FineTune.FinetuneSpec.Resource != nil { 44 | finetune.Spec.Resource = finetuneJob.Spec.FineTune.FinetuneSpec.Resource 45 | } 46 | if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Name == "" { 47 | finetune.Spec.Image.Name = config.GetBaseImage() 48 | } 49 | if finetuneJob.Spec.FineTune.FinetuneSpec.Image.Path == "" { 50 | finetune.Spec.Image.Path = defaultFinetuneCodePath 51 | } 52 | return finetune 53 | } 54 | 55 | func GenerateBuildImageJob(filePath, imageName, imageTag string, finetuneJobInstance *finetunev1beta1.FinetuneJob) *batchv1.Job { 56 | privileged := true 57 | directory := corev1.HostPathDirectory 58 | buildImageJobName := fmt.Sprintf("%s-buildimage", finetuneJobInstance.Name) 59 | jobLabel := label.GenerateInstanceLabel(finetuneJobInstance.Name) 60 | return &batchv1.Job{ 61 | ObjectMeta: metav1.ObjectMeta{ 62 | Name: buildImageJobName, 63 | Namespace: finetuneJobInstance.Namespace, 64 | Labels: jobLabel, 65 | }, 66 | Spec: batchv1.JobSpec{ 67 | Template: corev1.PodTemplateSpec{ 68 | Spec: corev1.PodSpec{ 69 | Containers: []corev1.Container{ 70 | { 71 | Name: defaultBuildImageJobContainerName, 72 | Image: defaultBuildImageJobImage, 73 | Env: []corev1.EnvVar{ 74 | { 75 | Name: "S3_ENDPOINT", 76 | Value: config.GetS3Endpoint(), 77 | }, 78 | {Name: "S3_ACCESSKEYID", 79 | Value: 
config.GetS3AccesskeyId(), 80 | }, 81 | { 82 | Name: "S3_SECRETACCESSKEY", 83 | Value: config.GetS3ESecretAccessKey(), 84 | }, 85 | { 86 | Name: "S3_BUCKET", 87 | Value: config.GetS3Bucket(), 88 | }, 89 | { 90 | Name: "S3_FILEPATH", 91 | Value: filePath, 92 | }, 93 | { 94 | Name: "S3_SECURE", 95 | Value: config.GetSecure(), 96 | }, 97 | { 98 | Name: "MOUNT_PATH", 99 | Value: config.GetMountPath(), 100 | }, 101 | { 102 | Name: "REGISTRY_URL", 103 | Value: config.GetRegistryUrl(), 104 | }, 105 | { 106 | Name: "REPOSITORY_NAME", 107 | Value: config.GetRepositoryName(), 108 | }, 109 | { 110 | Name: "USERNAME", 111 | Value: config.GetUserName(), 112 | }, 113 | { 114 | Name: "BASE_IMAGE", 115 | Value: config.GetBaseImage(), 116 | }, 117 | { 118 | Name: "PASSWORD", 119 | Value: config.GetPassword(), 120 | }, 121 | { 122 | Name: "IMAGE_NAME", 123 | Value: imageName, 124 | }, 125 | { 126 | Name: "IMAGE_TAG", 127 | Value: imageTag, 128 | }, 129 | }, 130 | VolumeMounts: []corev1.VolumeMount{ 131 | { 132 | Name: "data", 133 | MountPath: config.GetMountPath(), 134 | }, 135 | }, 136 | SecurityContext: &corev1.SecurityContext{ 137 | Privileged: &privileged, 138 | }, 139 | }, 140 | }, 141 | RestartPolicy: corev1.RestartPolicyNever, 142 | Volumes: []corev1.Volume{ 143 | { 144 | Name: "data", 145 | VolumeSource: corev1.VolumeSource{ 146 | HostPath: &corev1.HostPathVolumeSource{ 147 | Path: "/root/jobdata/", 148 | Type: &directory, 149 | }, 150 | }, 151 | }, 152 | }, 153 | }, 154 | }, 155 | }, 156 | } 157 | 158 | } 159 | 160 | func GenerateRayService(name, namespace, importPath, runtimeEnv, deploymentName string, numReplicas int32, numGpus float64, finetuneJob *finetunev1beta1.FinetuneJob, llmCheckpoint *corev1beta1.LLMCheckpoint) *rayv1.RayService { 161 | // todo(tigerK) hardcode for rubbish 162 | numReplica := &numReplicas 163 | numGpu := &numGpus 164 | enableInTreeAutoscaling := false 165 | workReplicas := int32(1) 166 | minWorkReplicas := int32(1) 167 | maxWorkReplicas := int32(1) 168 | if finetuneJob.Spec.ServeConfig.NodeSelector == nil { 169 | finetuneJob.Spec.ServeConfig.NodeSelector = map[string]string{ 170 | "nvidia.com/gpu": "present", 171 | } 172 | } 173 | return &rayv1.RayService{ 174 | ObjectMeta: metav1.ObjectMeta{ 175 | Name: name, 176 | Namespace: namespace, 177 | }, 178 | Spec: rayv1.RayServiceSpec{ 179 | ServeService: &corev1.Service{ 180 | ObjectMeta: metav1.ObjectMeta{ 181 | Name: finetuneJob.Name, 182 | }, 183 | Spec: corev1.ServiceSpec{ 184 | Ports: []corev1.ServicePort{ 185 | { 186 | Name: "serve", 187 | Port: 8000, 188 | Protocol: corev1.ProtocolTCP, 189 | TargetPort: intstr.FromInt(8000), 190 | }, 191 | }, 192 | Selector: map[string]string{ 193 | "ray.io/node-type": "head", 194 | }, 195 | Type: corev1.ServiceTypeNodePort, 196 | }, 197 | }, 198 | ServeDeploymentGraphSpec: rayv1.ServeDeploymentGraphSpec{ 199 | ImportPath: importPath, 200 | RuntimeEnv: runtimeEnv, 201 | ServeConfigSpecs: []rayv1.ServeConfigSpec{ 202 | { 203 | Name: deploymentName, 204 | NumReplicas: numReplica, 205 | RayActorOptions: rayv1.RayActorOptionSpec{ 206 | NumGpus: numGpu, 207 | }, 208 | }, 209 | }, 210 | }, 211 | RayClusterSpec: rayv1.RayClusterSpec{ 212 | RayVersion: "2.7.1", 213 | EnableInTreeAutoscaling: &enableInTreeAutoscaling, 214 | HeadGroupSpec: rayv1.HeadGroupSpec{ 215 | RayStartParams: map[string]string{ 216 | "dashboard-host": "0.0.0.0", 217 | "num-gpus": "0", 218 | }, 219 | ServiceType: corev1.ServiceTypeNodePort, 220 | Template: corev1.PodTemplateSpec{ 221 | Spec: corev1.PodSpec{ 222 | Containers: 
[]corev1.Container{ 223 | { 224 | Name: fmt.Sprintf("%s-head", finetuneJob.Name), 225 | Image: *llmCheckpoint.Spec.CheckpointImage.Name, 226 | ImagePullPolicy: corev1.PullAlways, 227 | Env: []corev1.EnvVar{ 228 | { 229 | Name: "RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", 230 | Value: "1", 231 | }, 232 | }, 233 | Ports: []corev1.ContainerPort{ 234 | { 235 | Name: "gcs-server", 236 | ContainerPort: 6379, 237 | }, 238 | { 239 | Name: "dashboard", 240 | ContainerPort: 8265, 241 | }, 242 | { 243 | Name: "client", 244 | ContainerPort: 10001, 245 | }, 246 | { 247 | Name: "serve", 248 | ContainerPort: 8000, 249 | }, 250 | }, 251 | Resources: corev1.ResourceRequirements{ 252 | Limits: map[corev1.ResourceName]resource.Quantity{ 253 | corev1.ResourceCPU: resource.MustParse("2"), 254 | corev1.ResourceMemory: resource.MustParse("16Gi"), 255 | }, 256 | Requests: map[corev1.ResourceName]resource.Quantity{ 257 | corev1.ResourceCPU: resource.MustParse("1"), 258 | corev1.ResourceMemory: resource.MustParse("8Gi"), 259 | }, 260 | }, 261 | }, 262 | }, 263 | Tolerations: finetuneJob.Spec.ServeConfig.Tolerations, 264 | NodeSelector: finetuneJob.Spec.ServeConfig.NodeSelector, 265 | }, 266 | }, 267 | }, 268 | WorkerGroupSpecs: []rayv1.WorkerGroupSpec{ 269 | { 270 | Replicas: &workReplicas, 271 | MinReplicas: &minWorkReplicas, 272 | MaxReplicas: &maxWorkReplicas, 273 | GroupName: finetuneJob.Name, 274 | RayStartParams: map[string]string{}, 275 | Template: corev1.PodTemplateSpec{ 276 | Spec: corev1.PodSpec{ 277 | Containers: []corev1.Container{ 278 | { 279 | Name: fmt.Sprintf("%s-work", finetuneJob.Name), 280 | Image: *llmCheckpoint.Spec.CheckpointImage.Name, 281 | ImagePullPolicy: corev1.PullAlways, 282 | Env: []corev1.EnvVar{ 283 | { 284 | Name: "RAY_SERVE_ENABLE_EXPERIMENTAL_STREAMING", 285 | Value: "1", 286 | }, 287 | { 288 | Name: "BASE_MODEL_DIR", 289 | Value: llmCheckpoint.Spec.CheckpointImage.LLMPath, 290 | }, 291 | { 292 | Name: "CHECKPOINT_DIR", 293 | Value: llmCheckpoint.Spec.CheckpointImage.CheckPointPath, 294 | }, 295 | }, 296 | Lifecycle: &corev1.Lifecycle{ 297 | PreStop: &corev1.LifecycleHandler{ 298 | Exec: &corev1.ExecAction{ 299 | Command: []string{ 300 | "/bin/sh", "-c", "ray stop", 301 | }, 302 | }, 303 | }, 304 | }, 305 | Resources: corev1.ResourceRequirements{ 306 | Limits: map[corev1.ResourceName]resource.Quantity{ 307 | corev1.ResourceCPU: resource.MustParse("8"), 308 | corev1.ResourceMemory: resource.MustParse("64Gi"), 309 | "nvidia.com/gpu": resource.MustParse("1"), 310 | }, 311 | Requests: map[corev1.ResourceName]resource.Quantity{ 312 | corev1.ResourceCPU: resource.MustParse("4"), 313 | corev1.ResourceMemory: resource.MustParse("32Gi"), 314 | "nvidia.com/gpu": resource.MustParse("1"), 315 | }, 316 | }, 317 | }, 318 | }, 319 | Tolerations: finetuneJob.Spec.ServeConfig.Tolerations, 320 | NodeSelector: finetuneJob.Spec.ServeConfig.NodeSelector, 321 | }, 322 | }, 323 | }, 324 | }, 325 | }, 326 | }, 327 | } 328 | 329 | } 330 | 331 | func GenerateBuiltInScoring(name, namespace, inference string) *extensionv1beta1.Scoring { 332 | return &extensionv1beta1.Scoring{ 333 | ObjectMeta: metav1.ObjectMeta{ 334 | Name: name, 335 | Namespace: namespace, 336 | }, 337 | Spec: extensionv1beta1.ScoringSpec{ 338 | InferenceService: inference, 339 | }, 340 | } 341 | } 342 | 343 | func GeneratePluginScoring(name, namespace, pluginName, parameters, inference string) *extensionv1beta1.Scoring { 344 | return &extensionv1beta1.Scoring{ 345 | ObjectMeta: metav1.ObjectMeta{ 346 | Name: name, 347 | Namespace: namespace, 
348 | },
349 | Spec: extensionv1beta1.ScoringSpec{
350 | Plugin: &extensionv1beta1.Plugin{
351 | LoadPlugin: true,
352 | Name: pluginName,
353 | Parameters: parameters,
354 | },
355 | InferenceService: inference,
356 | },
357 | }
358 | }
359 | 
-------------------------------------------------------------------------------- /pkg/util/handlererr/handler.go: --------------------------------------------------------------------------------
1 | package handlererr
2 | 
3 | import (
4 | "errors"
5 | "time"
6 | 
7 | "github.com/DataTunerX/datatunerx/pkg/domain/valueobject"
8 | ctrl "sigs.k8s.io/controller-runtime"
9 | )
10 | 
11 | func HandlerErr(err error) (ctrl.Result, error) {
12 | if err != nil {
13 | if errors.Is(err, valueobject.ErrRecalibrate) { // dependency not ready: requeue quietly without surfacing an error
14 | return ctrl.Result{RequeueAfter: 10 * time.Second}, nil
15 | }
16 | return ctrl.Result{RequeueAfter: 30 * time.Second}, err // real failure: requeue later and report the error
17 | }
18 | return ctrl.Result{}, nil
19 | }
-------------------------------------------------------------------------------- /pkg/util/label/label.go: --------------------------------------------------------------------------------
1 | package label
2 | 
3 | const (
4 | LabelInstanceKey = "finetune.datatunerx.io/instance"
5 | LabelDatatunerx = "datatunerx"
6 | LabelFinetuneJob = "finetunejob"
7 | LabelFinetune = "finetune"
8 | LabelFinetuneExperiment = "finetuneexperiment"
9 | LabelComponentKey = "finetune.datatunerx.io/component"
10 | LabelPartOfKey = "finetune.datatunerx.io/part-of"
11 | LabelFinetuneBindingKey = "finetune.datatunerx.io/finetunebinding"
12 | )
13 | 
14 | func GenerateInstanceLabel(instanceName string) map[string]string {
15 | baseLabel := GetBaseLabel()
16 | return MergeLabel(baseLabel, map[string]string{
17 | LabelInstanceKey: instanceName,
18 | LabelComponentKey: LabelFinetuneJob,
19 | })
20 | }
21 | 
22 | func GetBaseLabel() map[string]string {
23 | return map[string]string{
24 | LabelPartOfKey: LabelDatatunerx,
25 | }
26 | }
27 | 
28 | func MergeLabel(baseLabel map[string]string, customLabel map[string]string) map[string]string {
29 | for k, v := range customLabel {
30 | if _, exists := baseLabel[k]; !exists {
31 | baseLabel[k] = v
32 | }
33 | }
34 | return baseLabel
35 | }
36 | 
-------------------------------------------------------------------------------- /pkg/util/util.go: --------------------------------------------------------------------------------
1 | package util
2 | 
3 | import (
4 | "os"
5 | "strconv"
6 | "strings"
7 | 
8 | "github.com/DataTunerX/utility-server/logging"
9 | )
10 | 
11 | func RemoveBucketName(path, bucketName string) string {
12 | parts := strings.Split(path, "/")
13 | if len(parts) > 0 && parts[0] == bucketName {
14 | return strings.Join(parts[1:], "/")
15 | }
16 | return path
17 | }
18 | 
19 | func GenerateName() { // placeholder, currently unimplemented
20 | 
21 | }
22 | 
23 | func ParseScore(s string) int {
24 | score, err := strconv.Atoi(s)
25 | if err != nil {
26 | return 0 // fall back to 0 when the score is not a number
27 | }
28 | return score
29 | }
30 | 
31 | func GetOperatorNamespace() string {
32 | nsBytes, err := os.ReadFile("/var/run/secrets/kubernetes.io/serviceaccount/namespace") // os.ReadFile replaces the deprecated ioutil.ReadFile
33 | if err != nil {
34 | logging.ZLogger.Errorf("unable to read file, %v", err)
35 | if os.IsNotExist(err) {
36 | return "datatunerx-dev" // default namespace when the service account file is absent
37 | }
38 | }
39 | ns := strings.TrimSpace(string(nsBytes))
40 | return ns
41 | }
42 | 
--------------------------------------------------------------------------------
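As a closing illustration, here is a runnable sketch of the path-trimming behavior implemented by `util.RemoveBucketName` above, assuming S3-style `<bucket>/<key>` paths. The function body is inlined so the snippet stands alone, and the bucket and key names are made up for the example:

```go
package main

import (
	"fmt"
	"strings"
)

// removeBucketName inlines util.RemoveBucketName: it strips the leading
// path segment only when that segment equals the bucket name.
func removeBucketName(path, bucketName string) string {
	parts := strings.Split(path, "/")
	if len(parts) > 0 && parts[0] == bucketName {
		return strings.Join(parts[1:], "/")
	}
	return path
}

func main() {
	fmt.Println(removeBucketName("datatunerx/checkpoints/job-1/ckpt-100", "datatunerx"))
	// -> checkpoints/job-1/ckpt-100
	fmt.Println(removeBucketName("other-bucket/checkpoints", "datatunerx"))
	// -> other-bucket/checkpoints (unchanged: the prefix does not match)
}
```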