├── .gitignore ├── .travis.yml ├── ADOPTERS.md ├── Dockerfile ├── LICENSE ├── Makefile ├── OWNERS ├── PROJECT ├── README.md ├── build_image.sh ├── cmd └── manager │ └── main.go ├── config ├── crds │ ├── xgboostjob_v1_xgboostjobs.yaml │ └── xgboostjob_v1alpha1_xgboostjob.yaml ├── default │ ├── kustomization.yaml │ ├── manager_auth_proxy_patch.yaml │ ├── manager_image_patch.yaml │ └── manager_prometheus_metrics_patch.yaml ├── manager │ └── manager.yaml ├── rbac │ ├── auth_proxy_role.yaml │ ├── auth_proxy_role_binding.yaml │ ├── auth_proxy_service.yaml │ ├── rbac_role.yaml │ └── rbac_role_binding.yaml └── samples │ ├── lightgbm-dist │ ├── Dockerfile │ ├── README.md │ ├── main.py │ ├── train.py │ ├── utils.py │ └── xgboostjob_v1_lightgbm_dist_training.yaml │ ├── smoke-dist │ ├── Dockerfile │ ├── README.md │ ├── requirements.txt │ ├── tracker.py │ ├── xgboost_smoke_test.py │ ├── xgboostjob_v1_rabit_test.yaml │ └── xgboostjob_v1alpha1_rabit_test.yaml │ └── xgboost-dist │ ├── Dockerfile │ ├── README.md │ ├── build.sh │ ├── local_test.py │ ├── main.py │ ├── predict.py │ ├── requirements.txt │ ├── tracker.py │ ├── train.py │ ├── utils.py │ ├── xgboostjob_v1_iris_predict.yaml │ ├── xgboostjob_v1_iris_train.yaml │ ├── xgboostjob_v1alpha1_iris_predict.yaml │ ├── xgboostjob_v1alpha1_iris_predict_local.yaml │ ├── xgboostjob_v1alpha1_iris_train.yaml │ └── xgboostjob_v1alpha1_iris_train_local.yaml ├── go.mod ├── go.sum ├── hack └── boilerplate.go.txt ├── manifests ├── base │ ├── cluster-role-binding.yaml │ ├── cluster-role.yaml │ ├── crd.yaml │ ├── deployment.yaml │ ├── kustomization.yaml │ ├── params.env │ ├── service-account.yaml │ └── service.yaml └── overlays │ └── kubeflow │ └── kustomization.yaml ├── pkg ├── apis │ ├── addtoscheme_xgboostjob.go │ ├── apis.go │ └── xgboostjob │ │ ├── group.go │ │ └── v1 │ │ ├── constants.go │ │ ├── doc.go │ │ ├── register.go │ │ ├── v1_suite_test.go.back │ │ ├── xgboostjob_types.go │ │ ├── xgboostjob_types_test.go.back │ │ └── 
zz_generated.deepcopy.go ├── controller │ └── v1 │ │ ├── add_xgboostjob.go │ │ ├── controller.go │ │ └── xgboostjob │ │ ├── expectation.go │ │ ├── job.go │ │ ├── pod.go │ │ ├── pod_test.go │ │ ├── service.go │ │ ├── util.go │ │ ├── xgboostjob_controller.go │ │ ├── xgboostjob_controller_suite_test.go.back │ │ └── xgboostjob_controller_test.go.back └── webhook │ └── webhook.go ├── prow_config.yaml └── test └── workflows ├── .gitignore ├── app.yaml ├── components ├── build.jsonnet ├── params.libsonnet └── util.libsonnet ├── environments ├── base.libsonnet └── dockerbuild │ ├── globals.libsonnet │ ├── main.jsonnet │ └── params.libsonnet └── vendor └── kubeflow └── automation ├── parts.yaml ├── prototypes └── release.jsonnet └── release.libsonnet /.gitignore: -------------------------------------------------------------------------------- 1 | # Files created by Gogland IDE 2 | .idea/ 3 | 4 | # Other temporary files 5 | .DS_Store 6 | 7 | # Compiled python files. 8 | *.pyc 9 | 10 | # Emacs temporary files 11 | *~ 12 | 13 | # VIM temporary files. 14 | .swp 15 | 16 | 17 | cover.out 18 | bin/ 19 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: go 2 | 3 | go: 4 | - "1.12" 5 | 6 | go_import_path: github.com/kubeflow/xgboost-operator 7 | 8 | install: 9 | # get coveralls.io support 10 | - go get github.com/mattn/goveralls 11 | # gometalinter is deprecated; migrating to golangci-lint. See https://github.com/alecthomas/gometalinter/issues/590. 12 | - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.16.0 13 | 14 | script: 15 | - export GO111MODULE=on 16 | - go build ./... 17 | # - golangci-lint run ./... 18 | # - goveralls -service=travis-ci -v -package ./... 
-ignore "pkg/client/*/*.go,pkg/client/*/*/*.go,pkg/client/*/*/*/*.go,pkg/client/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*/*.go,pkg/apis/xgboost/v1alpha1/zz_generated.*.go" 19 | -------------------------------------------------------------------------------- /ADOPTERS.md: -------------------------------------------------------------------------------- 1 | # Adopters of XGBoost Operator 2 | 3 | This page contains a list of organizations who are using XGBoost Operator. If you'd like to be included here, please send a pull request which modifies this file. Please keep the list in alphabetical order. 4 | 5 | | Organization | Contact | 6 | | ------------ | ------- | 7 | | [Ant Group](https://www.antgroup.com/) | [Mingjie Tang](https://github.com/merlintang) and [Yuan Tang](https://github.com/terrytangyuan) | 8 | | [Tencent](http://tencent.com/en-us/) | [Ye Yin](https://github.com/hustcat) | 9 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Build the manager binary 2 | FROM golang:1.12.17 as builder 3 | 4 | ENV GO111MODULE=on 5 | 6 | # Copy in the go src 7 | WORKDIR /go/src/github.com/kubeflow/xgboost-operator 8 | 9 | COPY go.mod . 10 | COPY go.sum . 11 | 12 | RUN go mod download 13 | 14 | COPY pkg/ pkg/ 15 | COPY cmd/ cmd/ 16 | COPY config/ config/ 17 | 18 | # Build 19 | RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o manager github.com/kubeflow/xgboost-operator/cmd/manager 20 | 21 | # Copy the controller-manager into a thin image 22 | FROM ubuntu:latest 23 | WORKDIR /root 24 | COPY --from=builder /go/src/github.com/kubeflow/xgboost-operator/manager . 
25 | ENTRYPOINT ["/root/manager"] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | # Image URL to use all building/pushing image targets 3 | IMG ?= controller:latest 4 | 5 | all: test manager 6 | 7 | # Run tests 8 | test: generate fmt vet manifests 9 | go test ./pkg/... ./cmd/... -coverprofile cover.out 10 | 11 | # Build manager binary 12 | manager: generate fmt vet 13 | go build -o bin/manager github.com/kubeflow/xgboost-operator/cmd/manager 14 | 15 | # Run against the configured Kubernetes cluster in ~/.kube/config 16 | run: generate fmt vet 17 | go run ./cmd/manager/main.go 18 | 19 | # Install CRDs into a cluster 20 | install: manifests 21 | kubectl apply -f config/crds 22 | 23 | # Deploy controller in the configured Kubernetes cluster in ~/.kube/config 24 | deploy: manifests 25 | kubectl apply -f config/crds 26 | kustomize build config/default | kubectl apply -f - 27 | 28 | # Generate manifests e.g. CRD, RBAC etc. 29 | manifests: controller-gen 30 | $(CONTROLLER_GEN) crd rbac:roleName=manager-role paths=./pkg/apis/... 
output:crd:artifacts:config=config/crds 31 | 32 | # Run go fmt against code 33 | fmt: 34 | go fmt ./pkg/... ./cmd/... 35 | 36 | # Run go vet against code 37 | vet: 38 | go vet ./pkg/... ./cmd/... 39 | 40 | # Generate code 41 | generate: controller-gen 42 | $(CONTROLLER_GEN) object:headerFile=./hack/boilerplate.go.txt paths="./pkg/apis/... ./pkg/controller/..." 43 | 44 | # Build the docker image 45 | docker-build: test 46 | docker build . -t ${IMG} 47 | @echo "updating kustomize image patch file for manager resource" 48 | sed -i'' -e 's@image: .*@image: '"${IMG}"'@' ./config/default/manager_image_patch.yaml 49 | 50 | # Push the docker image 51 | docker-push: 52 | docker push ${IMG} 53 | 54 | # find or download controller-gen if necessary 55 | controller-gen: 56 | ifeq (, $(shell which controller-gen)) 57 | # Please backport https://github.com/kubernetes-sigs/controller-tools/pull/317 58 | # and build a customize controller-gen to use 59 | go get sigs.k8s.io/controller-tools/cmd/controller-gen@v0.2.2 60 | CONTROLLER_GEN=$(GOBIN)/controller-gen 61 | else 62 | CONTROLLER_GEN=$(shell which controller-gen) 63 | endif 64 | -------------------------------------------------------------------------------- /OWNERS: -------------------------------------------------------------------------------- 1 | approvers: 2 | - merlintang 3 | - terrytangyuan 4 | - Jeffwan 5 | -------------------------------------------------------------------------------- /PROJECT: -------------------------------------------------------------------------------- 1 | version: "1" 2 | domain: kubeflow.org 3 | repo: github.com/kubeflow/xgboost-operator 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # XGBoost Operator 2 | 3 | [![Build Status](https://travis-ci.com/kubeflow/xgboost-operator.svg?branch=master)](https://travis-ci.com/kubeflow/xgboost-operator/) 4 | [![Go Report 
Card](https://goreportcard.com/badge/github.com/kubeflow/xgboost-operator)](https://goreportcard.com/report/github.com/kubeflow/xgboost-operator) 5 | 6 | 7 | ## :warning: **kubeflow/xgboost-operator is not maintained** 8 | 9 | This operator has been merged into [Kubeflow Training Operator](https://github.com/kubeflow/training-operator). This repository is not maintained and has been archived. 10 | 11 | ## Overview 12 | 13 | Incubating project for [XGBoost](https://github.com/dmlc/xgboost) operator. The XGBoost operator makes it easy to run distributed XGBoost job training and batch prediction on Kubernetes cluster. 14 | 15 | The overall design can be found [here]( https://github.com/kubeflow/community/issues/247). 16 | 17 | This repository contains the specification and implementation of `XGBoostJob` custom resource definition. 18 | Using this custom resource, users can create and manage XGBoost jobs like other built-in resources in Kubernetes. 19 | ## Prerequisites 20 | - Kubernetes >= 1.8 21 | - [kubectl](https://kubernetes.io/docs/tasks/tools/install-kubectl) 22 | 23 | ## Install XGBoost Operator 24 | 25 | You can deploy the operator with default settings by running the following commands using [kustomize](https://github.com/kubernetes-sigs/kustomize): 26 | 27 | ```bash 28 | cd manifests 29 | kubectl create namespace kubeflow 30 | kustomize build base | kubectl apply -f - 31 | ``` 32 | 33 | Note that since Kubernetes v1.14, `kustomize` became a subcommand in `kubectl` so you can also run the following command instead: 34 | 35 | ```bash 36 | kubectl kustomize base | kubectl apply -f - 37 | ``` 38 | 39 | ## Build XGBoost Operator 40 | 41 | XGBoost Operator is developed based on [Kubebuilder](https://github.com/kubernetes-sigs/kubebuilder) and [Kubeflow Common](https://github.com/kubeflow/common). 
42 | 43 | You can follow the [installation guide of Kubebuilder](https://book.kubebuilder.io/cronjob-tutorial/running.html) to install XGBoost operator into the Kubernetes cluster. 44 | 45 | You can check whether the XGBoostJob custom resource has been installed via: 46 | ``` 47 | kubectl get crd 48 | ``` 49 | The output should include xgboostjobs.kubeflow.org like the following: 50 | ``` 51 | NAME CREATED AT 52 | xgboostjobs.xgboostjob.kubeflow.org 2021-03-24T22:03:07Z 53 | ``` 54 | If it is not included you can add it as follows: 55 | ``` 56 | ## setup the build enviroment 57 | export GOPATH=$HOME/go 58 | export PATH=$PATH:$GOPATH/bin 59 | export GO111MODULE=on 60 | cd $GOPATH 61 | mkdir src/github.com/kubeflow 62 | cd src/github.com/kubeflow 63 | 64 | ## clone the code 65 | git clone git@github.com:kubeflow/xgboost-operator.git 66 | cd xgboost-operator 67 | 68 | ## build and install xgboost operator 69 | make 70 | make install 71 | make run 72 | ``` 73 | If the XGBoost Job operator can be installed into cluster, you can view the logs likes this 74 | 75 |
76 | Logs 77 | 78 | ``` 79 | {"level":"info","ts":1589406873.090652,"logger":"entrypoint","msg":"setting up client for manager"} 80 | {"level":"info","ts":1589406873.0991302,"logger":"entrypoint","msg":"setting up manager"} 81 | {"level":"info","ts":1589406874.2192929,"logger":"entrypoint","msg":"Registering Components."} 82 | {"level":"info","ts":1589406874.219318,"logger":"entrypoint","msg":"setting up scheme"} 83 | {"level":"info","ts":1589406874.219448,"logger":"entrypoint","msg":"Setting up controller"} 84 | {"level":"info","ts":1589406874.2194738,"logger":"controller","msg":"Running controller in local mode, using kubeconfig file"} 85 | {"level":"info","ts":1589406874.224564,"logger":"controller","msg":"gang scheduling is set: ","gangscheduling":false} 86 | {"level":"info","ts":1589406874.2247412,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="} 87 | {"level":"info","ts":1589406874.224958,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="} 88 | {"level":"info","ts":1589406874.2251048,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="} 89 | {"level":"info","ts":1589406874.225237,"logger":"entrypoint","msg":"setting up webhooks"} 90 | {"level":"info","ts":1589406874.225247,"logger":"entrypoint","msg":"Starting the Cmd."} 91 | {"level":"info","ts":1589406874.32791,"logger":"kubebuilder.controller","msg":"Starting Controller","controller":"xgboostjob-controller"} 92 | {"level":"info","ts":1589406874.430336,"logger":"kubebuilder.controller","msg":"Starting workers","controller":"xgboostjob-controller","worker count":1} 93 | ``` 94 |
95 | 96 | ## Creating a XGBoost Training/Prediction Job 97 | 98 | You can create a XGBoost training or prediction (batch oriented) job by modifying the XGBoostJob config file. 99 | See the distributed XGBoost Job training and prediction [example](https://github.com/kubeflow/xgboost-operator/tree/master/config/samples/xgboost-dist). 100 | You can change the config file and related python file (i.e., train.py or predict.py) 101 | based on your requirement. 102 | 103 | Following the job configuration guild in the example, you can deploy a XGBoost Job to start training or prediction like: 104 | ``` 105 | ## For training job 106 | cat config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml 107 | kubectl create -f config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml 108 | 109 | ## For batch prediction job 110 | cat config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml 111 | kubectl create -f config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml 112 | ``` 113 | 114 | ## Monitor a distributed XGBoost Job 115 | 116 | Once the XGBoost job is created, you should be able to watch how the related pod and service working. 117 | Distributed XGBoost job is trained by synchronizing different worker status via tne Rabit of XGBoost. 118 | You can also monitor the job status. 119 | 120 | ``` 121 | kubectl get -o yaml XGBoostJob/xgboost-dist-iris-test-train 122 | ``` 123 | 124 | Here is the sample output when training job is finished. 125 | 126 |
127 | XGBoost Job Details 128 | 129 | ``` 130 | apiVersion: xgboostjob.kubeflow.org/v1 131 | kind: XGBoostJob 132 | metadata: 133 | annotations: 134 | kubectl.kubernetes.io/last-applied-configuration: | 135 | {"apiVersion":"xgboostjob.kubeflow.org/v1","kind":"XGBoostJob","metadata":{"annotations":{},"name":"xgboost-dist-iris-test-train","namespace":"default"},"spec":{"xgbReplicaSpecs":{"Master":{"replicas":1,"restartPolicy":"Never","template":{"spec":{"containers":[{"args":["--job_type=Train","--xgboost_parameter=objective:multi:softprob,num_class:3","--n_estimators=10","--learning_rate=0.1","--model_path=/tmp/xgboost-model","--model_storage_type=local"],"image":"docker.io/merlintang/xgboost-dist-iris:1.1","imagePullPolicy":"Always","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}},"Worker":{"replicas":2,"restartPolicy":"ExitCode","template":{"spec":{"containers":[{"args":["--job_type=Train","--xgboost_parameter=\"objective:multi:softprob,num_class:3\"","--n_estimators=10","--learning_rate=0.1"],"image":"docker.io/merlintang/xgboost-dist-iris:1.1","imagePullPolicy":"Always","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}}}}} 136 | creationTimestamp: "2021-03-24T22:54:39Z" 137 | generation: 8 138 | name: xgboost-dist-iris-test-train 139 | namespace: default 140 | resourceVersion: "1060393" 141 | selfLink: /apis/xgboostjob.kubeflow.org/v1/namespaces/default/xgboostjobs/xgboost-dist-iris-test-train 142 | uid: 386c9851-7ef8-4928-9dba-2da8829bf048 143 | spec: 144 | RunPolicy: 145 | cleanPodPolicy: None 146 | xgbReplicaSpecs: 147 | Master: 148 | replicas: 1 149 | restartPolicy: Never 150 | template: 151 | metadata: 152 | creationTimestamp: null 153 | spec: 154 | containers: 155 | - args: 156 | - --job_type=Train 157 | - --xgboost_parameter=objective:multi:softprob,num_class:3 158 | - --n_estimators=10 159 | - --learning_rate=0.1 160 | - --model_path=/tmp/xgboost-model 161 | - --model_storage_type=local 
162 | image: docker.io/merlintang/xgboost-dist-iris:1.1 163 | imagePullPolicy: Always 164 | name: xgboostjob 165 | ports: 166 | - containerPort: 9991 167 | name: xgboostjob-port 168 | resources: {} 169 | Worker: 170 | replicas: 2 171 | restartPolicy: ExitCode 172 | template: 173 | metadata: 174 | creationTimestamp: null 175 | spec: 176 | containers: 177 | - args: 178 | - --job_type=Train 179 | - --xgboost_parameter="objective:multi:softprob,num_class:3" 180 | - --n_estimators=10 181 | - --learning_rate=0.1 182 | image: docker.io/merlintang/xgboost-dist-iris:1.1 183 | imagePullPolicy: Always 184 | name: xgboostjob 185 | ports: 186 | - containerPort: 9991 187 | name: xgboostjob-port 188 | resources: {} 189 | status: 190 | completionTime: "2021-03-24T22:54:58Z" 191 | conditions: 192 | - lastTransitionTime: "2021-03-24T22:54:39Z" 193 | lastUpdateTime: "2021-03-24T22:54:39Z" 194 | message: xgboostJob xgboost-dist-iris-test-train is created. 195 | reason: XGBoostJobCreated 196 | status: "True" 197 | type: Created 198 | - lastTransitionTime: "2021-03-24T22:54:39Z" 199 | lastUpdateTime: "2021-03-24T22:54:39Z" 200 | message: XGBoostJob xgboost-dist-iris-test-train is running. 201 | reason: XGBoostJobRunning 202 | status: "False" 203 | type: Running 204 | - lastTransitionTime: "2021-03-24T22:54:58Z" 205 | lastUpdateTime: "2021-03-24T22:54:58Z" 206 | message: XGBoostJob xgboost-dist-iris-test-train is successfully completed. 207 | reason: XGBoostJobSucceeded 208 | status: "True" 209 | type: Succeeded 210 | replicaStatuses: 211 | Master: 212 | succeeded: 1 213 | Worker: 214 | succeeded: 2 215 | ``` 216 | 217 |
218 | 219 | ## Docker Images 220 | 221 | You can use [this Dockerfile](Dockerfile) to build the image yourself: 222 | 223 | Alternatively, you can pull the existing image from Dockerhub [here](https://hub.docker.com/r/kubeflow/xgboost-operator/tags). 224 | 225 | ## Known Issues 226 | 227 | XGBoost and `kubeflow/common` use pointer value in map like `map[commonv1.ReplicaType]*commonv1.ReplicaSpec`. However, `controller-gen` in [controller-tools](https://github.com/kubernetes-sigs/controller-tools) doesn't accept pointers as map values in latest version (v0.3.0), in order to generate crds and deepcopy files, we need to build custom `controller-gen`. You can follow steps below. Then `make generate` can work properly. 228 | 229 | ```shell 230 | git clone --branch v0.2.2 git@github.com:kubernetes-sigs/controller-tools.git 231 | git cherry-pick 71b6e91 232 | go build -o controller-gen cmd/controller-gen/main.go 233 | cp controller-gen /usr/local/bin 234 | ``` 235 | 236 | This can be removed once a newer `controller-gen` released and xgboost can upgrade to corresponding k8s version. 237 | -------------------------------------------------------------------------------- /build_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # A simple script to build the Docker images. 4 | # This is intended to be invoked as a step in Argo to build the docker image. 5 | # 6 | # build_image.sh ${DOCKERFILE} ${IMAGE} ${TAG} 7 | # Note: TAG is not used from workflows, always generated uniquely 8 | set -ex 9 | 10 | DOCKERFILE=$1 11 | CONTEXT_DIR=$(dirname "$DOCKERFILE") 12 | IMAGE=$2 13 | TIMEOUT=30m 14 | 15 | cd $CONTEXT_DIR 16 | TAG=$(git describe --tags --always --dirty) 17 | 18 | gcloud auth activate-service-account --key-file=${GOOGLE_APPLICATION_CREDENTIALS} 19 | 20 | echo "Building ${IMAGE} using gcloud build" 21 | gcloud builds submit --tag=${IMAGE}:${TAG} --project=${GCP_PROJECT} --timeout=${TIMEOUT} . 
22 | echo "Finished building ${IMAGE}:${TAG}" 23 | -------------------------------------------------------------------------------- /cmd/manager/main.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package main 17 | 18 | import ( 19 | "flag" 20 | "os" 21 | 22 | "github.com/kubeflow/xgboost-operator/pkg/apis" 23 | controller "github.com/kubeflow/xgboost-operator/pkg/controller/v1" 24 | "github.com/kubeflow/xgboost-operator/pkg/webhook" 25 | _ "k8s.io/client-go/plugin/pkg/client/auth/gcp" 26 | "sigs.k8s.io/controller-runtime/pkg/client/config" 27 | "sigs.k8s.io/controller-runtime/pkg/manager" 28 | logf "sigs.k8s.io/controller-runtime/pkg/runtime/log" 29 | "sigs.k8s.io/controller-runtime/pkg/runtime/signals" 30 | ) 31 | 32 | func main() { 33 | var metricsAddr string 34 | var mode string 35 | flag.StringVar(&metricsAddr, "metrics-addr", ":8080", "The address the metric endpoint binds to.") 36 | flag.StringVar(&mode, "mode", "local", "The mode in which xgboost-operator to run") 37 | flag.Parse() 38 | logf.SetLogger(logf.ZapLogger(false)) 39 | log := logf.Log.WithName("entrypoint") 40 | 41 | // Get a config to talk to the apiserver 42 | log.Info("setting up client for manager") 43 | cfg, err := config.GetConfig() 44 | if err != nil { 45 | log.Error(err, "unable to set up client config") 46 | os.Exit(1) 47 | } 48 | 49 | // Create a new Cmd to provide 
shared dependencies and start components 50 | log.Info("setting up manager") 51 | mgr, err := manager.New(cfg, manager.Options{MetricsBindAddress: metricsAddr}) 52 | if err != nil { 53 | log.Error(err, "unable to set up overall controller manager") 54 | os.Exit(1) 55 | } 56 | 57 | log.Info("Registering Components.") 58 | 59 | // Setup Scheme for all resources 60 | log.Info("setting up scheme") 61 | if err := apis.AddToScheme(mgr.GetScheme()); err != nil { 62 | log.Error(err, "unable add APIs to scheme") 63 | os.Exit(1) 64 | } 65 | 66 | // Setup all Controllers 67 | log.Info("Setting up controller") 68 | if err := controller.AddToManager(mgr); err != nil { 69 | log.Error(err, "unable to register controllers to the manager") 70 | os.Exit(1) 71 | } 72 | 73 | log.Info("setting up webhooks") 74 | if err := webhook.AddToManager(mgr); err != nil { 75 | log.Error(err, "unable to register webhooks to the manager") 76 | os.Exit(1) 77 | } 78 | 79 | // Start the Cmd 80 | log.Info("Starting the Cmd.") 81 | if err := mgr.Start(signals.SetupSignalHandler()); err != nil { 82 | log.Error(err, "unable to run the manager") 83 | os.Exit(1) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /config/crds/xgboostjob_v1alpha1_xgboostjob.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apiextensions.k8s.io/v1beta1 2 | kind: CustomResourceDefinition 3 | metadata: 4 | creationTimestamp: null 5 | labels: 6 | controller-tools.k8s.io: "1.0" 7 | name: xgboostjobs.kubeflow.org 8 | spec: 9 | group: kubeflow.org 10 | names: 11 | kind: XGBoostJob 12 | singular: xgboostjob 13 | plural: xgboostjobs 14 | scope: Namespaced 15 | validation: 16 | openAPIV3Schema: 17 | properties: 18 | apiVersion: 19 | description: 'APIVersion defines the versioned schema of this representation 20 | of an object. 
Servers should convert recognized schemas to the latest 21 | internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources' 22 | type: string 23 | kind: 24 | description: 'Kind is a string value representing the REST resource this 25 | object represents. Servers may infer this from the endpoint the client 26 | submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds' 27 | type: string 28 | metadata: 29 | type: object 30 | spec: 31 | properties: 32 | activeDeadlineSeconds: 33 | description: Specifies the duration in seconds relative to the startTime 34 | that the job may be active before the system tries to terminate it; 35 | value must be positive integer. 36 | format: int64 37 | type: integer 38 | backoffLimit: 39 | description: Optional number of retries before marking this job failed. 40 | format: int32 41 | type: integer 42 | cleanPodPolicy: 43 | description: CleanPodPolicy defines the policy to kill pods after the 44 | job completes. Default to Running. 45 | type: string 46 | schedulingPolicy: 47 | description: SchedulingPolicy defines the policy related to scheduling, 48 | e.g. gang-scheduling 49 | properties: 50 | minAvailable: 51 | format: int32 52 | type: integer 53 | type: object 54 | ttlSecondsAfterFinished: 55 | description: TTLSecondsAfterFinished is the TTL to clean up jobs. It 56 | may take extra ReconcilePeriod seconds for the cleanup, since reconcile 57 | gets called periodically. Default to infinite. 58 | format: int32 59 | type: integer 60 | xgbReplicaSpecs: 61 | type: object 62 | required: 63 | - xgbReplicaSpecs 64 | type: object 65 | status: 66 | properties: 67 | completionTime: 68 | description: Represents time when the job was completed. It is not guaranteed 69 | to be set in happens-before order across separate operations. It is 70 | represented in RFC3339 form and is in UTC. 
71 | format: date-time 72 | type: string 73 | conditions: 74 | description: Conditions is an array of current observed job conditions. 75 | items: 76 | properties: 77 | lastTransitionTime: 78 | description: Last time the condition transitioned from one status 79 | to another. 80 | format: date-time 81 | type: string 82 | lastUpdateTime: 83 | description: The last time this condition was updated. 84 | format: date-time 85 | type: string 86 | message: 87 | description: A human readable message indicating details about 88 | the transition. 89 | type: string 90 | reason: 91 | description: The reason for the condition's last transition. 92 | type: string 93 | status: 94 | description: Status of the condition, one of True, False, Unknown. 95 | type: string 96 | type: 97 | description: Type of job condition. 98 | type: string 99 | required: 100 | - type 101 | - status 102 | type: object 103 | type: array 104 | lastReconcileTime: 105 | description: Represents last time when the job was reconciled. It is 106 | not guaranteed to be set in happens-before order across separate operations. 107 | It is represented in RFC3339 form and is in UTC. 108 | format: date-time 109 | type: string 110 | replicaStatuses: 111 | description: ReplicaStatuses is map of ReplicaType and ReplicaStatus, 112 | specifies the status of each replica. 113 | type: object 114 | startTime: 115 | description: Represents time when the job was acknowledged by the job 116 | controller. It is not guaranteed to be set in happens-before order 117 | across separate operations. It is represented in RFC3339 form and 118 | is in UTC. 
119 | format: date-time 120 | type: string 121 | required: 122 | - conditions 123 | - replicaStatuses 124 | type: object 125 | version: v1alpha1 126 | status: 127 | acceptedNames: 128 | kind: "" 129 | plural: "" 130 | conditions: [] 131 | storedVersions: [] 132 | -------------------------------------------------------------------------------- /config/default/kustomization.yaml: -------------------------------------------------------------------------------- 1 | # Adds namespace to all resources. 2 | namespace: xgboost-operator-system 3 | 4 | # Value of this field is prepended to the 5 | # names of all resources, e.g. a deployment named 6 | # "wordpress" becomes "alices-wordpress". 7 | # Note that it should also match with the prefix (text before '-') of the namespace 8 | # field above. 9 | namePrefix: xgboost-operator- 10 | 11 | # Labels to add to all resources and selectors. 12 | #commonLabels: 13 | # someName: someValue 14 | 15 | # Each entry in this list must resolve to an existing 16 | # resource definition in YAML. These are the resource 17 | # files that kustomize reads, modifies and emits as a 18 | # YAML string, with resources separated by document 19 | # markers ("---"). 20 | resources: 21 | - ../rbac/rbac_role.yaml 22 | - ../rbac/rbac_role_binding.yaml 23 | - ../manager/manager.yaml 24 | # Comment the following 3 lines if you want to disable 25 | # the auth proxy (https://github.com/brancz/kube-rbac-proxy) 26 | # which protects your /metrics endpoint. 27 | - ../rbac/auth_proxy_service.yaml 28 | - ../rbac/auth_proxy_role.yaml 29 | - ../rbac/auth_proxy_role_binding.yaml 30 | 31 | patches: 32 | - manager_image_patch.yaml 33 | # Protect the /metrics endpoint by putting it behind auth. 34 | # Only one of manager_auth_proxy_patch.yaml and 35 | # manager_prometheus_metrics_patch.yaml should be enabled. 
36 | - manager_auth_proxy_patch.yaml 37 | # If you want your controller-manager to expose the /metrics 38 | # endpoint w/o any authn/z, uncomment the following line and 39 | # comment manager_auth_proxy_patch.yaml. 40 | # Only one of manager_auth_proxy_patch.yaml and 41 | # manager_prometheus_metrics_patch.yaml should be enabled. 42 | #- manager_prometheus_metrics_patch.yaml 43 | 44 | vars: 45 | - name: WEBHOOK_SECRET_NAME 46 | objref: 47 | kind: Secret 48 | name: webhook-server-secret 49 | apiVersion: v1 50 | -------------------------------------------------------------------------------- /config/default/manager_auth_proxy_patch.yaml: -------------------------------------------------------------------------------- 1 | # This patch inject a sidecar container which is a HTTP proxy for the controller manager, 2 | # it performs RBAC authorization against the Kubernetes API using SubjectAccessReviews. 3 | apiVersion: apps/v1 4 | kind: StatefulSet 5 | metadata: 6 | name: controller-manager 7 | namespace: system 8 | spec: 9 | template: 10 | spec: 11 | containers: 12 | - name: kube-rbac-proxy 13 | image: gcr.io/kubebuilder/kube-rbac-proxy:v0.4.0 14 | args: 15 | - "--secure-listen-address=0.0.0.0:8443" 16 | - "--upstream=http://127.0.0.1:8080/" 17 | - "--logtostderr=true" 18 | - "--v=10" 19 | ports: 20 | - containerPort: 8443 21 | name: https 22 | - name: manager 23 | args: 24 | - "--metrics-addr=127.0.0.1:8080" 25 | -------------------------------------------------------------------------------- /config/default/manager_image_patch.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: StatefulSet 3 | metadata: 4 | name: controller-manager 5 | namespace: system 6 | spec: 7 | template: 8 | spec: 9 | containers: 10 | # Change the value of image field below to your controller image URL 11 | - image: IMAGE_URL 12 | name: manager 13 | 
-------------------------------------------------------------------------------- /config/default/manager_prometheus_metrics_patch.yaml: -------------------------------------------------------------------------------- 1 | # This patch enables Prometheus scraping for the manager pod. 2 | apiVersion: apps/v1 3 | kind: StatefulSet 4 | metadata: 5 | name: controller-manager 6 | namespace: system 7 | spec: 8 | template: 9 | metadata: 10 | annotations: 11 | prometheus.io/scrape: 'true' 12 | spec: 13 | containers: 14 | # Expose the prometheus metrics on default port 15 | - name: manager 16 | ports: 17 | - containerPort: 8080 18 | name: metrics 19 | protocol: TCP 20 | -------------------------------------------------------------------------------- /config/manager/manager.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Namespace 3 | metadata: 4 | labels: 5 | control-plane: controller-manager 6 | controller-tools.k8s.io: "1.0" 7 | name: system 8 | --- 9 | apiVersion: v1 10 | kind: Service 11 | metadata: 12 | name: controller-manager-service 13 | namespace: system 14 | labels: 15 | control-plane: controller-manager 16 | controller-tools.k8s.io: "1.0" 17 | spec: 18 | selector: 19 | control-plane: controller-manager 20 | controller-tools.k8s.io: "1.0" 21 | ports: 22 | - port: 443 23 | --- 24 | apiVersion: apps/v1 25 | kind: StatefulSet 26 | metadata: 27 | name: controller-manager 28 | namespace: system 29 | labels: 30 | control-plane: controller-manager 31 | controller-tools.k8s.io: "1.0" 32 | spec: 33 | selector: 34 | matchLabels: 35 | control-plane: controller-manager 36 | controller-tools.k8s.io: "1.0" 37 | serviceName: controller-manager-service 38 | template: 39 | metadata: 40 | labels: 41 | control-plane: controller-manager 42 | controller-tools.k8s.io: "1.0" 43 | spec: 44 | containers: 45 | - command: 46 | - /manager 47 | image: controller:latest 48 | imagePullPolicy: Always 49 | name: manager 50 | env: 
51 | - name: POD_NAMESPACE 52 | valueFrom: 53 | fieldRef: 54 | fieldPath: metadata.namespace 55 | - name: SECRET_NAME 56 | value: $(WEBHOOK_SECRET_NAME) 57 | resources: 58 | limits: 59 | cpu: 100m 60 | memory: 30Mi 61 | requests: 62 | cpu: 100m 63 | memory: 20Mi 64 | ports: 65 | - containerPort: 9876 66 | name: webhook-server 67 | protocol: TCP 68 | volumeMounts: 69 | - mountPath: /tmp/cert 70 | name: cert 71 | readOnly: true 72 | terminationGracePeriodSeconds: 10 73 | volumes: 74 | - name: cert 75 | secret: 76 | defaultMode: 420 77 | secretName: webhook-server-secret 78 | --- 79 | apiVersion: v1 80 | kind: Secret 81 | metadata: 82 | name: webhook-server-secret 83 | namespace: system 84 | -------------------------------------------------------------------------------- /config/rbac/auth_proxy_role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | name: proxy-role 5 | rules: 6 | - apiGroups: ["authentication.k8s.io"] 7 | resources: 8 | - tokenreviews 9 | verbs: ["create"] 10 | - apiGroups: ["authorization.k8s.io"] 11 | resources: 12 | - subjectaccessreviews 13 | verbs: ["create"] 14 | -------------------------------------------------------------------------------- /config/rbac/auth_proxy_role_binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | name: proxy-rolebinding 5 | roleRef: 6 | apiGroup: rbac.authorization.k8s.io 7 | kind: ClusterRole 8 | name: proxy-role 9 | subjects: 10 | - kind: ServiceAccount 11 | name: default 12 | namespace: system 13 | -------------------------------------------------------------------------------- /config/rbac/auth_proxy_service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | annotations: 5 | 
prometheus.io/port: "8443" 6 | prometheus.io/scheme: https 7 | prometheus.io/scrape: "true" 8 | labels: 9 | control-plane: controller-manager 10 | controller-tools.k8s.io: "1.0" 11 | name: controller-manager-metrics-service 12 | namespace: system 13 | spec: 14 | ports: 15 | - name: https 16 | port: 8443 17 | targetPort: https 18 | selector: 19 | control-plane: controller-manager 20 | controller-tools.k8s.io: "1.0" 21 | -------------------------------------------------------------------------------- /config/rbac/rbac_role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | creationTimestamp: null 5 | name: manager-role 6 | rules: 7 | - apiGroups: 8 | - apps 9 | resources: 10 | - deployments 11 | verbs: 12 | - get 13 | - list 14 | - watch 15 | - create 16 | - update 17 | - patch 18 | - delete 19 | - apiGroups: 20 | - apps 21 | resources: 22 | - deployments/status 23 | verbs: 24 | - get 25 | - update 26 | - patch 27 | - apiGroups: 28 | - xgboostjob.kubeflow.org 29 | resources: 30 | - xgboostjobs 31 | verbs: 32 | - get 33 | - list 34 | - watch 35 | - create 36 | - update 37 | - patch 38 | - delete 39 | - apiGroups: 40 | - xgboostjob.kubeflow.org 41 | resources: 42 | - xgboostjobs/status 43 | verbs: 44 | - get 45 | - update 46 | - patch 47 | - apiGroups: 48 | - admissionregistration.k8s.io 49 | resources: 50 | - mutatingwebhookconfigurations 51 | - validatingwebhookconfigurations 52 | verbs: 53 | - get 54 | - list 55 | - watch 56 | - create 57 | - update 58 | - patch 59 | - delete 60 | - apiGroups: 61 | - "" 62 | resources: 63 | - secrets 64 | verbs: 65 | - get 66 | - list 67 | - watch 68 | - create 69 | - update 70 | - patch 71 | - delete 72 | - apiGroups: 73 | - "" 74 | resources: 75 | - services 76 | verbs: 77 | - get 78 | - list 79 | - watch 80 | - create 81 | - update 82 | - patch 83 | - delete 84 | 
-------------------------------------------------------------------------------- /config/rbac/rbac_role_binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | creationTimestamp: null 5 | name: manager-rolebinding 6 | roleRef: 7 | apiGroup: rbac.authorization.k8s.io 8 | kind: ClusterRole 9 | name: manager-role 10 | subjects: 11 | - kind: ServiceAccount 12 | name: default 13 | namespace: system 14 | -------------------------------------------------------------------------------- /config/samples/lightgbm-dist/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM ubuntu:16.04 2 | 3 | ARG CONDA_DIR=/opt/conda 4 | ENV PATH $CONDA_DIR/bin:$PATH 5 | 6 | RUN apt-get update && \ 7 | apt-get install -y --no-install-recommends \ 8 | ca-certificates \ 9 | cmake \ 10 | build-essential \ 11 | gcc \ 12 | g++ \ 13 | git \ 14 | curl && \ 15 | # python environment 16 | curl -sL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o conda.sh && \ 17 | /bin/bash conda.sh -f -b -p $CONDA_DIR && \ 18 | export PATH="$CONDA_DIR/bin:$PATH" && \ 19 | conda config --set always_yes yes --set changeps1 no && \ 20 | # lightgbm 21 | conda install -q -y numpy==1.20.3 scipy==1.6.2 scikit-learn==0.24.2 pandas==1.3.0 && \ 22 | git clone --recursive --branch stable --depth 1 https://github.com/Microsoft/LightGBM && \ 23 | mkdir LightGBM/build && \ 24 | cd LightGBM/build && \ 25 | cmake .. 
&& \ 26 | make -j4 && \ 27 | make install && \ 28 | cd ../python-package && \ 29 | python setup.py install_lib && \ 30 | # clean 31 | apt-get autoremove -y && apt-get clean && \ 32 | conda clean -a -y && \ 33 | rm -rf /usr/local/src/* && \ 34 | rm -rf /LightGBM 35 | 36 | WORKDIR /app 37 | 38 | # Download the example data 39 | RUN mkdir data 40 | ADD https://raw.githubusercontent.com/microsoft/LightGBM/stable/examples/parallel_learning/binary.train data/. 41 | ADD https://raw.githubusercontent.com/microsoft/LightGBM/stable/examples/parallel_learning/binary.test data/. 42 | COPY *.py ./ 43 | 44 | ENTRYPOINT [ "python", "/app/main.py" ] -------------------------------------------------------------------------------- /config/samples/lightgbm-dist/README.md: -------------------------------------------------------------------------------- 1 | ### Distributed Lightgbm Job train 2 | 3 | This folder contains the Dockerfile and Python scripts to run a distributed Lightgbm training using the XGBoost operator. 4 | The code is based on this [example](https://github.com/microsoft/LightGBM/tree/master/examples/parallel_learning) from the official GitHub repository of the library. 5 | 6 | 7 | **Build image** 8 | The default image name and tag are `kubeflow/lightgbm-dist-py-test` and `1.0`, respectively. 9 | 10 | ```shell 11 | docker build -f Dockerfile -t kubeflow/lightgbm-dist-py-test:1.0 ./ 12 | ``` 13 | 14 | **Start the training** 15 | 16 | ``` 17 | kubectl create -f xgboostjob_v1_lightgbm_dist_training.yaml 18 | ``` 19 | 20 | **Look at the job status** 21 | ``` 22 | kubectl get -o yaml XGBoostJob/lightgbm-dist-train-test 23 | ``` 24 | Here is sample output when the job is running. 
The output result like this 25 | 26 | ``` 27 | apiVersion: xgboostjob.kubeflow.org/v1 28 | kind: XGBoostJob 29 | metadata: 30 | annotations: 31 | kubectl.kubernetes.io/last-applied-configuration: | 32 | {"apiVersion":"xgboostjob.kubeflow.org/v1","kind":"XGBoostJob","metadata":{"annotations":{},"name":"lightgbm-dist-train-test","namespace":"default"},"spec":{"xgbReplicaSpecs":{"Master":{"replicas":1,"restartPolicy":"Never","template":{"apiVersion":"v1","kind":"Pod","spec":{"containers":[{"args":["--job_type=Train","--boosting_type=gbdt","--objective=binary","--metric=binary_logloss,auc","--metric_freq=1","--is_training_metric=true","--max_bin=255","--data=data/binary.train","--valid_data=data/binary.test","--num_trees=100","--learning_rate=01","--num_leaves=63","--tree_learner=feature","--feature_fraction=0.8","--bagging_freq=5","--bagging_fraction=0.8","--min_data_in_leaf=50","--min_sum_hessian_in_leaf=50","--is_enable_sparse=true","--use_two_round_loading=false","--is_save_binary_file=false"],"image":"kubeflow/lightgbm-dist-py-test:1.0","imagePullPolicy":"Never","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}},"Worker":{"replicas":2,"restartPolicy":"ExitCode","template":{"apiVersion":"v1","kind":"Pod","spec":{"containers":[{"args":["--job_type=Train","--boosting_type=gbdt","--objective=binary","--metric=binary_logloss,auc","--metric_freq=1","--is_training_metric=true","--max_bin=255","--data=data/binary.train","--valid_data=data/binary.test","--num_trees=100","--learning_rate=01","--num_leaves=63","--tree_learner=feature","--feature_fraction=0.8","--bagging_freq=5","--bagging_fraction=0.8","--min_data_in_leaf=50","--min_sum_hessian_in_leaf=50","--is_enable_sparse=true","--use_two_round_loading=false","--is_save_binary_file=false"],"image":"kubeflow/lightgbm-dist-py-test:1.0","imagePullPolicy":"Never","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}}}}} 33 | creationTimestamp: "2020-10-14T15:31:23Z" 
34 | generation: 7 35 | managedFields: 36 | - apiVersion: xgboostjob.kubeflow.org/v1 37 | fieldsType: FieldsV1 38 | fieldsV1: 39 | f:metadata: 40 | f:annotations: 41 | .: {} 42 | f:kubectl.kubernetes.io/last-applied-configuration: {} 43 | f:spec: 44 | .: {} 45 | f:xgbReplicaSpecs: 46 | .: {} 47 | f:Master: 48 | .: {} 49 | f:replicas: {} 50 | f:restartPolicy: {} 51 | f:template: 52 | .: {} 53 | f:spec: {} 54 | f:Worker: 55 | .: {} 56 | f:replicas: {} 57 | f:restartPolicy: {} 58 | f:template: 59 | .: {} 60 | f:spec: {} 61 | manager: kubectl-client-side-apply 62 | operation: Update 63 | time: "2020-10-14T15:31:23Z" 64 | - apiVersion: xgboostjob.kubeflow.org/v1 65 | fieldsType: FieldsV1 66 | fieldsV1: 67 | f:spec: 68 | f:RunPolicy: 69 | .: {} 70 | f:cleanPodPolicy: {} 71 | f:xgbReplicaSpecs: 72 | f:Master: 73 | f:template: 74 | f:metadata: 75 | .: {} 76 | f:creationTimestamp: {} 77 | f:spec: 78 | f:containers: {} 79 | f:Worker: 80 | f:template: 81 | f:metadata: 82 | .: {} 83 | f:creationTimestamp: {} 84 | f:spec: 85 | f:containers: {} 86 | f:status: 87 | .: {} 88 | f:completionTime: {} 89 | f:conditions: {} 90 | f:replicaStatuses: 91 | .: {} 92 | f:Master: 93 | .: {} 94 | f:succeeded: {} 95 | f:Worker: 96 | .: {} 97 | f:succeeded: {} 98 | manager: main 99 | operation: Update 100 | time: "2020-10-14T15:34:44Z" 101 | name: lightgbm-dist-train-test 102 | namespace: default 103 | resourceVersion: "38923" 104 | selfLink: /apis/xgboostjob.kubeflow.org/v1/namespaces/default/xgboostjobs/lightgbm-dist-train-test 105 | uid: b2b887d0-445b-498b-8852-26c8edc98dc7 106 | spec: 107 | RunPolicy: 108 | cleanPodPolicy: None 109 | xgbReplicaSpecs: 110 | Master: 111 | replicas: 1 112 | restartPolicy: Never 113 | template: 114 | metadata: 115 | creationTimestamp: null 116 | spec: 117 | containers: 118 | - args: 119 | - --job_type=Train 120 | - --boosting_type=gbdt 121 | - --objective=binary 122 | - --metric=binary_logloss,auc 123 | - --metric_freq=1 124 | - --is_training_metric=true 125 | - 
--max_bin=255 126 | - --data=data/binary.train 127 | - --valid_data=data/binary.test 128 | - --num_trees=100 129 | - --learning_rate=01 130 | - --num_leaves=63 131 | - --tree_learner=feature 132 | - --feature_fraction=0.8 133 | - --bagging_freq=5 134 | - --bagging_fraction=0.8 135 | - --min_data_in_leaf=50 136 | - --min_sum_hessian_in_leaf=50 137 | - --is_enable_sparse=true 138 | - --use_two_round_loading=false 139 | - --is_save_binary_file=false 140 | image: kubeflow/lightgbm-dist-py-test:1.0 141 | imagePullPolicy: Never 142 | name: xgboostjob 143 | ports: 144 | - containerPort: 9991 145 | name: xgboostjob-port 146 | resources: {} 147 | Worker: 148 | replicas: 2 149 | restartPolicy: ExitCode 150 | template: 151 | metadata: 152 | creationTimestamp: null 153 | spec: 154 | containers: 155 | - args: 156 | - --job_type=Train 157 | - --boosting_type=gbdt 158 | - --objective=binary 159 | - --metric=binary_logloss,auc 160 | - --metric_freq=1 161 | - --is_training_metric=true 162 | - --max_bin=255 163 | - --data=data/binary.train 164 | - --valid_data=data/binary.test 165 | - --num_trees=100 166 | - --learning_rate=01 167 | - --num_leaves=63 168 | - --tree_learner=feature 169 | - --feature_fraction=0.8 170 | - --bagging_freq=5 171 | - --bagging_fraction=0.8 172 | - --min_data_in_leaf=50 173 | - --min_sum_hessian_in_leaf=50 174 | - --is_enable_sparse=true 175 | - --use_two_round_loading=false 176 | - --is_save_binary_file=false 177 | image: kubeflow/lightgbm-dist-py-test:1.0 178 | imagePullPolicy: Never 179 | name: xgboostjob 180 | ports: 181 | - containerPort: 9991 182 | name: xgboostjob-port 183 | resources: {} 184 | status: 185 | completionTime: "2020-10-14T15:34:44Z" 186 | conditions: 187 | - lastTransitionTime: "2020-10-14T15:31:23Z" 188 | lastUpdateTime: "2020-10-14T15:31:23Z" 189 | message: xgboostJob lightgbm-dist-train-test is created. 
190 | reason: XGBoostJobCreated 191 | status: "True" 192 | type: Created 193 | - lastTransitionTime: "2020-10-14T15:31:23Z" 194 | lastUpdateTime: "2020-10-14T15:31:23Z" 195 | message: XGBoostJob lightgbm-dist-train-test is running. 196 | reason: XGBoostJobRunning 197 | status: "False" 198 | type: Running 199 | - lastTransitionTime: "2020-10-14T15:34:44Z" 200 | lastUpdateTime: "2020-10-14T15:34:44Z" 201 | message: XGBoostJob lightgbm-dist-train-test is successfully completed. 202 | reason: XGBoostJobSucceeded 203 | status: "True" 204 | type: Succeeded 205 | replicaStatuses: 206 | Master: 207 | succeeded: 1 208 | Worker: 209 | succeeded: 2 210 | ``` -------------------------------------------------------------------------------- /config/samples/lightgbm-dist/main.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | import os 14 | import logging 15 | import argparse 16 | 17 | from train import train 18 | 19 | from utils import generate_machine_list_file, generate_train_conf_file 20 | 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(args, extra_args): 26 | 27 | master_addr = os.environ["MASTER_ADDR"] 28 | master_port = os.environ["MASTER_PORT"] 29 | worker_addrs = os.environ["WORKER_ADDRS"] 30 | worker_port = os.environ["WORKER_PORT"] 31 | world_size = int(os.environ["WORLD_SIZE"]) 32 | rank = int(os.environ["RANK"]) 33 | 34 | logger.info( 35 | "extract cluster info from env variables \n" 36 | f"master_addr: {master_addr} \n" 37 | f"master_port: {master_port} \n" 38 | f"worker_addrs: {worker_addrs} \n" 39 | f"worker_port: {worker_port} \n" 40 | f"world_size: {world_size} \n" 41 | f"rank: {rank} \n" 42 | ) 43 | 44 | if args.job_type == "Predict": 45 | logging.info("starting the predict job") 46 | 47 | elif args.job_type == "Train": 48 | logging.info("starting the train job") 49 | logging.info(f"extra args:\n {extra_args}") 50 | machine_list_filepath = generate_machine_list_file( 51 | master_addr, master_port, worker_addrs, worker_port 52 | ) 53 | logging.info(f"machine list generated in: {machine_list_filepath}") 54 | local_port = worker_port if rank else master_port 55 | config_file = generate_train_conf_file( 56 | machine_list_file=machine_list_filepath, 57 | world_size=world_size, 58 | output_model="model.txt", 59 | local_port=local_port, 60 | extra_args=extra_args, 61 | ) 62 | logging.info(f"config generated in: {config_file}") 63 | train(config_file) 64 | logging.info("Finish distributed job") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | 70 | parser.add_argument( 71 | "--job_type", 72 | help="Job type to execute", 73 | choices=["Train", "Predict"], 74 | required=True, 75 | ) 76 | 77 | logging.basicConfig(format="%(message)s") 78 | logging.getLogger().setLevel(logging.INFO) 79 | args, extra_args = 
parser.parse_known_args() 80 | main(args, extra_args) 81 | -------------------------------------------------------------------------------- /config/samples/lightgbm-dist/train.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | 14 | import logging 15 | import subprocess 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def train(train_config_filepath: str): 21 | cmd = ["lightgbm", f"config={train_config_filepath}"] 22 | proc = subprocess.Popen(cmd, stdout=subprocess.PIPE) 23 | line = proc.stdout.readline() 24 | while line: 25 | logger.info((line.decode("utf-8").strip())) 26 | line = proc.stdout.readline() 27 | -------------------------------------------------------------------------------- /config/samples/lightgbm-dist/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import socket
import logging
import tempfile
from time import sleep
from typing import List, Union

logger = logging.getLogger(__name__)


def generate_machine_list_file(
    master_addr: str, master_port: str, worker_addrs: str, worker_port: str
) -> str:
    """Resolve master/worker hostnames and write a LightGBM machine-list file.

    Each line of the produced file is ``<ip> <port>``; the master comes
    first, followed by one line per worker.

    :param master_addr: hostname of the master node.
    :param master_port: port the master listens on.
    :param worker_addrs: comma-separated worker hostnames.
    :param worker_port: port the workers listen on.
    :return: path of the temporary machine-list file.
    :raises ValueError: if a hostname cannot be resolved after retrying.
    """
    logger.info("starting to extract system env")

    def _resolve_ips(master_name, worker_names, max_retries=10, sleep_secs=10):
        # Pods may not be DNS-resolvable yet while the cluster is coming up,
        # so retry a bounded number of times. Iterative, unlike the original
        # recursive retry, so deep retries cannot hit the recursion limit.
        for attempt in range(max_retries + 1):
            try:
                master_ip = socket.gethostbyname(master_name)
                worker_ips = [
                    socket.gethostbyname(name) for name in worker_names.split(",")
                ]
                return master_ip, worker_ips
            except socket.gaierror as ex:
                retryable = "Name or service not known" in str(ex)
                if retryable and attempt < max_retries:
                    sleep(sleep_secs)
                else:
                    # Chain the resolver error (the original discarded it).
                    raise ValueError("Couldn't get address names") from ex

    master_ip, worker_ips = _resolve_ips(master_addr, worker_addrs)

    # mkstemp + close avoids the open-file-object leak of
    # tempfile.NamedTemporaryFile(delete=False).name.
    fd, filename = tempfile.mkstemp()
    os.close(fd)

    with open(filename, "w") as file:
        print(f"{master_ip} {master_port}", file=file)
        for worker_ip in worker_ips:
            print(f"{worker_ip} {worker_port}", file=file)

    return filename


def generate_train_conf_file(
    machine_list_file: str,
    world_size: int,
    output_model: str,
    local_port: Union[int, str],
    extra_args: List[str],
) -> str:
    """Write a LightGBM training config file and return its path.

    ``extra_args`` entries of the form ``--key=value`` become ``key = value``
    config lines; anything else is silently ignored.

    :param machine_list_file: path produced by generate_machine_list_file.
    :param world_size: total number of machines in the job.
    :param output_model: path LightGBM should write the model to.
    :param local_port: port this node listens on.
    :param extra_args: raw ``--key=value`` CLI arguments to forward.
    :return: path of the temporary config file.
    """
    fd, filename = tempfile.mkstemp()
    os.close(fd)

    with open(filename, "w") as file:
        print("task = train", file=file)
        print(f"output_model = {output_model}", file=file)
        print(f"num_machines = {world_size}", file=file)
        print(f"local_listen_port = {local_port}", file=file)
        print(f"machine_list_file = {machine_list_file}", file=file)
        for arg in extra_args:
            # Non-greedy key: split at the FIRST '=' so values may contain
            # '=' themselves; the original greedy `(.+)` split at the last.
            m = re.match(r"--(.+?)=(\S+)", arg)
            if m is not None:
                key, value = m.groups()
                print(f"{key} = {value}", file=file)

    return filename
--boosting_type=gbdt 56 | - --objective=binary 57 | - --metric=binary_logloss,auc 58 | - --metric_freq=1 59 | - --is_training_metric=true 60 | - --max_bin=255 61 | - --data=data/binary.train 62 | - --valid_data=data/binary.test 63 | - --num_trees=100 64 | - --learning_rate=01 65 | - --num_leaves=63 66 | - --tree_learner=feature 67 | - --feature_fraction=0.8 68 | - --bagging_freq=5 69 | - --bagging_fraction=0.8 70 | - --min_data_in_leaf=50 71 | - --min_sum_hessian_in_leaf=50 72 | - --is_enable_sparse=true 73 | - --use_two_round_loading=false 74 | - --is_save_binary_file=false 75 | 76 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/Dockerfile: -------------------------------------------------------------------------------- 1 | # Install python 3.6 2 | FROM python:3.6 3 | 4 | RUN apt-get update 5 | RUN apt-get install -y git make g++ cmake 6 | 7 | RUN mkdir -p /opt/mlkube 8 | 9 | # Download the rabit tracker and xgboost code. 10 | 11 | COPY tracker.py /opt/mlkube/ 12 | COPY requirements.txt /opt/mlkube/ 13 | 14 | # Install requirements 15 | 16 | RUN pip install -r /opt/mlkube/requirements.txt 17 | 18 | # Build XGBoost. 19 | RUN git clone --recursive https://github.com/dmlc/xgboost && \ 20 | cd xgboost && \ 21 | make -j$(nproc) && \ 22 | cd python-package; python setup.py install 23 | 24 | COPY xgboost_smoke_test.py /opt/mlkube/ 25 | 26 | ENTRYPOINT ["python", "/opt/mlkube/xgboost_smoke_test.py"] 27 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/README.md: -------------------------------------------------------------------------------- 1 | ### Distributed send/recv e2e test for xgboost rabit 2 | 3 | This folder containers Dockerfile and distributed send/recv test. 4 | 5 | **Build Image** 6 | 7 | The default image name and tag is `kubeflow/xgboost-dist-rabit-test:1.2`. 8 | You can build the image based on your requirement. 
9 | 10 | ```shell 11 | docker build -f Dockerfile -t kubeflow/xgboost-dist-rabit-test:1.2 ./ 12 | ``` 13 | 14 | **Start and test XGBoost Rabit tracker** 15 | 16 | ``` 17 | kubectl create -f xgboostjob_v1alpha1_rabit_test.yaml 18 | ``` 19 | 20 | **Look at the job status** 21 | ``` 22 | kubectl get -o yaml XGBoostJob/xgboost-dist-test 23 | ``` 24 | Here is sample output when the job is running. The output result like this 25 | ``` 26 | apiVersion: xgboostjob.kubeflow.org/v1alpha1 27 | kind: XGBoostJob 28 | metadata: 29 | creationTimestamp: "2019-06-21T03:32:57Z" 30 | generation: 7 31 | name: xgboost-dist-test 32 | namespace: default 33 | resourceVersion: "258466" 34 | selfLink: /apis/xgboostjob.kubeflow.org/v1alpha1/namespaces/default/xgboostjobs/xgboost-dist-test 35 | uid: 431dc182-93d5-11e9-bbab-080027dfbfe2 36 | spec: 37 | RunPolicy: 38 | cleanPodPolicy: None 39 | xgbReplicaSpecs: 40 | Master: 41 | replicas: 1 42 | restartPolicy: Never 43 | template: 44 | metadata: 45 | creationTimestamp: null 46 | spec: 47 | containers: 48 | - image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 49 | imagePullPolicy: Always 50 | name: xgboostjob 51 | ports: 52 | - containerPort: 9991 53 | name: xgboostjob-port 54 | resources: {} 55 | Worker: 56 | replicas: 2 57 | restartPolicy: Never 58 | template: 59 | metadata: 60 | creationTimestamp: null 61 | spec: 62 | containers: 63 | - image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 64 | imagePullPolicy: Always 65 | name: xgboostjob 66 | ports: 67 | - containerPort: 9991 68 | name: xgboostjob-port 69 | resources: {} 70 | status: 71 | completionTime: "2019-06-21T03:33:03Z" 72 | conditions: 73 | - lastTransitionTime: "2019-06-21T03:32:57Z" 74 | lastUpdateTime: "2019-06-21T03:32:57Z" 75 | message: xgboostJob xgboost-dist-test is created. 
76 | reason: XGBoostJobCreated 77 | status: "True" 78 | type: Created 79 | - lastTransitionTime: "2019-06-21T03:32:57Z" 80 | lastUpdateTime: "2019-06-21T03:32:57Z" 81 | message: XGBoostJob xgboost-dist-test is running. 82 | reason: XGBoostJobRunning 83 | status: "False" 84 | type: Running 85 | - lastTransitionTime: "2019-06-21T03:33:03Z" 86 | lastUpdateTime: "2019-06-21T03:33:03Z" 87 | message: XGBoostJob xgboost-dist-test is successfully completed. 88 | reason: XGBoostJobSucceeded 89 | status: "True" 90 | type: Succeeded 91 | replicaStatuses: 92 | Master: 93 | succeeded: 1 94 | Worker: 95 | succeeded: 2 96 | ``` 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.3 2 | Cython>=0.29.4 3 | requests>=2.21.0 4 | urllib3>=1.21.1 5 | scipy>=1.4.1 6 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/tracker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tracker script for DMLC 3 | Implements the tracker control protocol 4 | - start dmlc jobs 5 | - start ps scheduler and rabit tracker 6 | - help nodes to establish links with each other 7 | Tianqi Chen 8 | -------------------------- 9 | This was taken from 10 | https://github.com/dmlc/dmlc-core/blob/master/tracker/dmlc_tracker/tracker.py 11 | See LICENSE here 12 | https://github.com/dmlc/dmlc-core/blob/master/LICENSE 13 | No code modified or added except for this explanatory comment. 
14 | """ 15 | # pylint: disable=invalid-name, missing-docstring, too-many-arguments 16 | # pylint: disable=too-many-locals 17 | # pylint: disable=too-many-branches, too-many-statements 18 | from __future__ import absolute_import 19 | 20 | import os 21 | import sys 22 | import socket 23 | import struct 24 | import subprocess 25 | import argparse 26 | import time 27 | import logging 28 | from threading import Thread 29 | 30 | 31 | class ExSocket(object): 32 | """ 33 | Extension of socket to handle recv and send of special data 34 | """ 35 | def __init__(self, sock): 36 | self.sock = sock 37 | 38 | def recvall(self, nbytes): 39 | res = [] 40 | nread = 0 41 | while nread < nbytes: 42 | chunk = self.sock.recv(min(nbytes - nread, 1024)) 43 | nread += len(chunk) 44 | res.append(chunk) 45 | return b''.join(res) 46 | 47 | def recvint(self): 48 | return struct.unpack('@i', self.recvall(4))[0] 49 | 50 | def sendint(self, n): 51 | self.sock.sendall(struct.pack('@i', n)) 52 | 53 | def sendstr(self, s): 54 | self.sendint(len(s)) 55 | self.sock.sendall(s.encode()) 56 | 57 | def recvstr(self): 58 | slen = self.recvint() 59 | return self.recvall(slen).decode() 60 | 61 | 62 | # magic number used to verify existence of data 63 | kMagic = 0xff99 64 | 65 | 66 | def get_some_ip(host): 67 | return socket.getaddrinfo(host, None)[0][4][0] 68 | 69 | 70 | def get_family(addr): 71 | return socket.getaddrinfo(addr, None)[0][0] 72 | 73 | 74 | class SlaveEntry(object): 75 | def __init__(self, sock, s_addr): 76 | slave = ExSocket(sock) 77 | self.sock = slave 78 | self.host = get_some_ip(s_addr[0]) 79 | magic = slave.recvint() 80 | assert magic == kMagic, 'invalid magic number=%d from %s' % ( 81 | magic, self.host) 82 | slave.sendint(kMagic) 83 | self.rank = slave.recvint() 84 | self.world_size = slave.recvint() 85 | self.jobid = slave.recvstr() 86 | self.cmd = slave.recvstr() 87 | self.wait_accept = 0 88 | self.port = None 89 | 90 | def decide_rank(self, job_map): 91 | if self.rank >= 0: 92 | 
return self.rank 93 | if self.jobid != 'NULL' and self.jobid in job_map: 94 | return job_map[self.jobid] 95 | return -1 96 | 97 | def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map): 98 | self.rank = rank 99 | nnset = set(tree_map[rank]) 100 | rprev, rnext = ring_map[rank] 101 | self.sock.sendint(rank) 102 | # send parent rank 103 | self.sock.sendint(parent_map[rank]) 104 | # send world size 105 | self.sock.sendint(len(tree_map)) 106 | self.sock.sendint(len(nnset)) 107 | # send the rprev and next link 108 | for r in nnset: 109 | self.sock.sendint(r) 110 | # send prev link 111 | if rprev != -1 and rprev != rank: 112 | nnset.add(rprev) 113 | self.sock.sendint(rprev) 114 | else: 115 | self.sock.sendint(-1) 116 | # send next link 117 | if rnext != -1 and rnext != rank: 118 | nnset.add(rnext) 119 | self.sock.sendint(rnext) 120 | else: 121 | self.sock.sendint(-1) 122 | while True: 123 | ngood = self.sock.recvint() 124 | goodset = set([]) 125 | for _ in range(ngood): 126 | goodset.add(self.sock.recvint()) 127 | assert goodset.issubset(nnset) 128 | badset = nnset - goodset 129 | conset = [] 130 | for r in badset: 131 | if r in wait_conn: 132 | conset.append(r) 133 | self.sock.sendint(len(conset)) 134 | self.sock.sendint(len(badset) - len(conset)) 135 | for r in conset: 136 | self.sock.sendstr(wait_conn[r].host) 137 | self.sock.sendint(wait_conn[r].port) 138 | self.sock.sendint(r) 139 | nerr = self.sock.recvint() 140 | if nerr != 0: 141 | continue 142 | self.port = self.sock.recvint() 143 | rmset = [] 144 | # all connection was successuly setup 145 | for r in conset: 146 | wait_conn[r].wait_accept -= 1 147 | if wait_conn[r].wait_accept == 0: 148 | rmset.append(r) 149 | for r in rmset: 150 | wait_conn.pop(r, None) 151 | self.wait_accept = len(badset) - len(conset) 152 | return rmset 153 | 154 | 155 | class RabitTracker(object): 156 | """ 157 | tracker for rabit 158 | """ 159 | def __init__(self, hostIP, nslave, port=9091, port_end=9999): 160 | sock = 
socket.socket(socket.AF_INET, socket.SOCK_STREAM) 161 | for port in range(port, port_end): 162 | try: 163 | sock.bind((hostIP, port)) 164 | self.port = port 165 | break 166 | except socket.error as e: 167 | if e.errno in [98, 48]: 168 | continue 169 | else: 170 | raise 171 | sock.listen(256) 172 | self.sock = sock 173 | self.hostIP = hostIP 174 | self.thread = None 175 | self.start_time = None 176 | self.end_time = None 177 | self.nslave = nslave 178 | logging.info('start listen on %s:%d', hostIP, self.port) 179 | 180 | def __del__(self): 181 | self.sock.close() 182 | 183 | @staticmethod 184 | def get_neighbor(rank, nslave): 185 | rank = rank + 1 186 | ret = [] 187 | if rank > 1: 188 | ret.append(rank // 2 - 1) 189 | if rank * 2 - 1 < nslave: 190 | ret.append(rank * 2 - 1) 191 | if rank * 2 < nslave: 192 | ret.append(rank * 2) 193 | return ret 194 | 195 | def slave_envs(self): 196 | """ 197 | get enviroment variables for slaves 198 | can be passed in as args or envs 199 | """ 200 | return {'DMLC_TRACKER_URI': self.hostIP, 201 | 'DMLC_TRACKER_PORT': self.port} 202 | 203 | def get_tree(self, nslave): 204 | tree_map = {} 205 | parent_map = {} 206 | for r in range(nslave): 207 | tree_map[r] = self.get_neighbor(r, nslave) 208 | parent_map[r] = (r + 1) // 2 - 1 209 | return tree_map, parent_map 210 | 211 | def find_share_ring(self, tree_map, parent_map, r): 212 | """ 213 | get a ring structure that tends to share nodes with the tree 214 | return a list starting from r 215 | """ 216 | nset = set(tree_map[r]) 217 | cset = nset - set([parent_map[r]]) 218 | if len(cset) == 0: 219 | return [r] 220 | rlst = [r] 221 | cnt = 0 222 | for v in cset: 223 | vlst = self.find_share_ring(tree_map, parent_map, v) 224 | cnt += 1 225 | if cnt == len(cset): 226 | vlst.reverse() 227 | rlst += vlst 228 | return rlst 229 | 230 | def get_ring(self, tree_map, parent_map): 231 | """ 232 | get a ring connection used to recover local data 233 | """ 234 | assert parent_map[0] == -1 235 | rlst = 
self.find_share_ring(tree_map, parent_map, 0) 236 | assert len(rlst) == len(tree_map) 237 | ring_map = {} 238 | nslave = len(tree_map) 239 | for r in range(nslave): 240 | rprev = (r + nslave - 1) % nslave 241 | rnext = (r + 1) % nslave 242 | ring_map[rlst[r]] = (rlst[rprev], rlst[rnext]) 243 | return ring_map 244 | 245 | def get_link_map(self, nslave): 246 | """ 247 | get the link map, this is a bit hacky, call for better algorithm 248 | to place similar nodes together 249 | """ 250 | tree_map, parent_map = self.get_tree(nslave) 251 | ring_map = self.get_ring(tree_map, parent_map) 252 | rmap = {0: 0} 253 | k = 0 254 | for i in range(nslave - 1): 255 | k = ring_map[k][1] 256 | rmap[k] = i + 1 257 | 258 | ring_map_ = {} 259 | tree_map_ = {} 260 | parent_map_ = {} 261 | for k, v in ring_map.items(): 262 | ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]]) 263 | for k, v in tree_map.items(): 264 | tree_map_[rmap[k]] = [rmap[x] for x in v] 265 | for k, v in parent_map.items(): 266 | if k != 0: 267 | parent_map_[rmap[k]] = rmap[v] 268 | else: 269 | parent_map_[rmap[k]] = -1 270 | return tree_map_, parent_map_, ring_map_ 271 | 272 | def accept_slaves(self, nslave): 273 | # set of nodes that finishs the job 274 | shutdown = {} 275 | # set of nodes that is waiting for connections 276 | wait_conn = {} 277 | # maps job id to rank 278 | job_map = {} 279 | # list of workers that is pending to be assigned rank 280 | pending = [] 281 | # lazy initialize tree_map 282 | tree_map = None 283 | 284 | while len(shutdown) != nslave: 285 | fd, s_addr = self.sock.accept() 286 | s = SlaveEntry(fd, s_addr) 287 | if s.cmd == 'print': 288 | msg = s.sock.recvstr() 289 | logging.info(msg.strip()) 290 | continue 291 | if s.cmd == 'shutdown': 292 | assert s.rank >= 0 and s.rank not in shutdown 293 | assert s.rank not in wait_conn 294 | shutdown[s.rank] = s 295 | logging.debug('Recieve %s signal from %d', s.cmd, s.rank) 296 | continue 297 | assert s.cmd == 'start' or s.cmd == 'recover' 298 | # lazily 
initialize the slaves 299 | if tree_map is None: 300 | assert s.cmd == 'start' 301 | if s.world_size > 0: 302 | nslave = s.world_size 303 | tree_map, parent_map, ring_map = self.get_link_map(nslave) 304 | # set of nodes that is pending for getting up 305 | todo_nodes = list(range(nslave)) 306 | else: 307 | assert s.world_size == -1 or s.world_size == nslave 308 | if s.cmd == 'recover': 309 | assert s.rank >= 0 310 | 311 | rank = s.decide_rank(job_map) 312 | # batch assignment of ranks 313 | if rank == -1: 314 | assert len(todo_nodes) != 0 315 | pending.append(s) 316 | if len(pending) == len(todo_nodes): 317 | pending.sort(key=lambda x: x.host) 318 | for s in pending: 319 | rank = todo_nodes.pop(0) 320 | if s.jobid != 'NULL': 321 | job_map[s.jobid] = rank 322 | s.assign_rank(rank, wait_conn, tree_map, parent_map, 323 | ring_map) 324 | if s.wait_accept > 0: 325 | wait_conn[rank] = s 326 | logging.debug('Recieve %s signal from %s; ' 327 | 'assign rank %d', s.cmd, s.host, s.rank) 328 | if len(todo_nodes) == 0: 329 | logging.info('@tracker All of %d nodes getting started', 330 | nslave) 331 | self.start_time = time.time() 332 | else: 333 | s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map) 334 | logging.debug('Recieve %s signal from %d', s.cmd, s.rank) 335 | if s.wait_accept > 0: 336 | wait_conn[rank] = s 337 | 338 | logging.info("worker(ip_address=%s) connected!" 
% get_some_ip(s_addr[0])) 339 | 340 | logging.info('@tracker All nodes finishes job') 341 | self.end_time = time.time() 342 | logging.info('@tracker %s secs between node start and job finish', 343 | str(self.end_time - self.start_time)) 344 | 345 | def start(self, nslave): 346 | def run(): 347 | self.accept_slaves(nslave) 348 | self.thread = Thread(target=run, args=()) 349 | self.thread.setDaemon(True) 350 | self.thread.start() 351 | 352 | def join(self): 353 | while self.thread.isAlive(): 354 | self.thread.join(100) 355 | 356 | 357 | class PSTracker(object): 358 | """ 359 | Tracker module for PS 360 | """ 361 | def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None): 362 | """ 363 | Starts the PS scheduler 364 | """ 365 | self.cmd = cmd 366 | if cmd is None: 367 | return 368 | envs = {} if envs is None else envs 369 | self.hostIP = hostIP 370 | sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM) 371 | for port in range(port, port_end): 372 | try: 373 | sock.bind(('', port)) 374 | self.port = port 375 | sock.close() 376 | break 377 | except socket.error: 378 | continue 379 | env = os.environ.copy() 380 | 381 | env['DMLC_ROLE'] = 'scheduler' 382 | env['DMLC_PS_ROOT_URI'] = str(self.hostIP) 383 | env['DMLC_PS_ROOT_PORT'] = str(self.port) 384 | for k, v in envs.items(): 385 | env[k] = str(v) 386 | self.thread = Thread( 387 | target=(lambda: subprocess.check_call(self.cmd, env=env, 388 | shell=True)), args=()) 389 | self.thread.setDaemon(True) 390 | self.thread.start() 391 | 392 | def join(self): 393 | if self.cmd is not None: 394 | while self.thread.isAlive(): 395 | self.thread.join(100) 396 | 397 | def slave_envs(self): 398 | if self.cmd is None: 399 | return {} 400 | else: 401 | return {'DMLC_PS_ROOT_URI': self.hostIP, 402 | 'DMLC_PS_ROOT_PORT': self.port} 403 | 404 | 405 | def get_host_ip(hostIP=None): 406 | if hostIP is None or hostIP == 'auto': 407 | hostIP = 'ip' 408 | 409 | if hostIP == 'dns': 410 | hostIP = socket.getfqdn() 411 | elif 
hostIP == 'ip': 412 | from socket import gaierror 413 | try: 414 | hostIP = socket.gethostbyname(socket.getfqdn()) 415 | except gaierror: 416 | logging.warn('gethostbyname(socket.getfqdn()) failed... trying on ' 417 | 'hostname()') 418 | hostIP = socket.gethostbyname(socket.gethostname()) 419 | if hostIP.startswith("127."): 420 | s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) 421 | # doesn't have to be reachable 422 | s.connect(('10.255.255.255', 1)) 423 | hostIP = s.getsockname()[0] 424 | return hostIP 425 | 426 | 427 | def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None): 428 | if nserver == 0: 429 | pscmd = None 430 | 431 | envs = {'DMLC_NUM_WORKER': nworker, 432 | 'DMLC_NUM_SERVER': nserver} 433 | hostIP = get_host_ip(hostIP) 434 | 435 | if nserver == 0: 436 | rabit = RabitTracker(hostIP=hostIP, nslave=nworker) 437 | envs.update(rabit.slave_envs()) 438 | rabit.start(nworker) 439 | else: 440 | pserver = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs) 441 | envs.update(pserver.slave_envs()) 442 | fun_submit(nworker, nserver, envs) 443 | 444 | if nserver == 0: 445 | rabit.join() 446 | else: 447 | pserver.join() 448 | 449 | 450 | def start_rabit_tracker(args): 451 | """Standalone function to start rabit tracker. 452 | Parameters 453 | ---------- 454 | args: arguments to start the rabit tracker. 
455 | """ 456 | envs = {'DMLC_NUM_WORKER': args.num_workers, 457 | 'DMLC_NUM_SERVER': args.num_servers} 458 | rabit = RabitTracker(hostIP=get_host_ip(args.host_ip), 459 | nslave=args.num_workers) 460 | envs.update(rabit.slave_envs()) 461 | rabit.start(args.num_workers) 462 | sys.stdout.write('DMLC_TRACKER_ENV_START\n') 463 | # simply write configuration to stdout 464 | for k, v in envs.items(): 465 | sys.stdout.write('%s=%s\n' % (k, str(v))) 466 | sys.stdout.write('DMLC_TRACKER_ENV_END\n') 467 | sys.stdout.flush() 468 | rabit.join() 469 | 470 | 471 | def main(): 472 | """Main function if tracker is executed in standalone mode.""" 473 | parser = argparse.ArgumentParser(description='Rabit Tracker start.') 474 | parser.add_argument('--num-workers', required=True, type=int, 475 | help='Number of worker proccess to be launched.') 476 | parser.add_argument('--num-servers', default=0, type=int, 477 | help='Number of server process to be launched. Only ' 478 | 'used in PS jobs.') 479 | parser.add_argument('--host-ip', default=None, type=str, 480 | help=('Host IP addressed, this is only needed ' + 481 | 'if the host IP cannot be automatically guessed.' 
482 | )) 483 | parser.add_argument('--log-level', default='INFO', type=str, 484 | choices=['INFO', 'DEBUG'], 485 | help='Logging level of the logger.') 486 | args = parser.parse_args() 487 | 488 | fmt = '%(asctime)s %(levelname)s %(message)s' 489 | if args.log_level == 'INFO': 490 | level = logging.INFO 491 | elif args.log_level == 'DEBUG': 492 | level = logging.DEBUG 493 | else: 494 | raise RuntimeError("Unknown logging level %s" % args.log_level) 495 | 496 | logging.basicConfig(format=fmt, level=level) 497 | 498 | if args.num_servers == 0: 499 | start_rabit_tracker(args) 500 | else: 501 | raise RuntimeError("Do not yet support start ps tracker in standalone " 502 | "mode.") 503 | 504 | 505 | if __name__ == "__main__": 506 | main() -------------------------------------------------------------------------------- /config/samples/smoke-dist/xgboost_smoke_test.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2018 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import logging 17 | import os 18 | import xgboost as xgb 19 | import traceback 20 | 21 | from tracker import RabitTracker 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | def extract_xgbooost_cluster_env(): 26 | 27 | logger.info("start to extract system env") 28 | 29 | master_addr = os.environ.get("MASTER_ADDR", "{}") 30 | master_port = int(os.environ.get("MASTER_PORT", "{}")) 31 | rank = int(os.environ.get("RANK", "{}")) 32 | world_size = int(os.environ.get("WORLD_SIZE", "{}")) 33 | 34 | logger.info("extract the rabit env from cluster : %s, port: %d, rank: %d, word_size: %d ", 35 | master_addr, master_port, rank, world_size) 36 | 37 | return master_addr, master_port, rank, world_size 38 | 39 | def setup_rabit_cluster(): 40 | addr, port, rank, world_size = extract_xgbooost_cluster_env() 41 | 42 | rabit_tracker = None 43 | try: 44 | """start to build the network""" 45 | if world_size > 1: 46 | if rank == 0: 47 | logger.info("start the master node") 48 | 49 | rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size, 50 | port=port, port_end=port + 1) 51 | rabit.start(world_size) 52 | rabit_tracker = rabit 53 | logger.info('########### RabitTracker Setup Finished #########') 54 | 55 | envs = [ 56 | 'DMLC_NUM_WORKER=%d' % world_size, 57 | 'DMLC_TRACKER_URI=%s' % addr, 58 | 'DMLC_TRACKER_PORT=%d' % port, 59 | 'DMLC_TASK_ID=%d' % rank 60 | ] 61 | logger.info('##### Rabit rank setup with below envs #####') 62 | for i, env in enumerate(envs): 63 | logger.info(env) 64 | envs[i] = str.encode(env) 65 | 66 | xgb.rabit.init(envs) 67 | logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank()) 68 | 69 | rank = xgb.rabit.get_rank() 70 | s = None 71 | if rank == 0: 72 | s = {'hello world': 100, 2: 3} 73 | 74 | logger.info('@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s))) 75 | s = xgb.rabit.broadcast(s, 0) 76 | 77 | logger.info('@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s))) 78 | 79 | except Exception as e: 80 | logger.error("something wrong 
happen: %s", traceback.format_exc()) 81 | raise e 82 | finally: 83 | if world_size > 1: 84 | xgb.rabit.finalize() 85 | if rabit_tracker: 86 | rabit_tracker.join() 87 | 88 | logger.info("the rabit network testing finished!") 89 | 90 | def main(): 91 | 92 | port = os.environ.get("MASTER_PORT", "{}") 93 | logging.info("MASTER_PORT: %s", port) 94 | 95 | addr = os.environ.get("MASTER_ADDR", "{}") 96 | logging.info("MASTER_ADDR: %s", addr) 97 | 98 | world_size = os.environ.get("WORLD_SIZE", "{}") 99 | logging.info("WORLD_SIZE: %s", world_size) 100 | 101 | rank = os.environ.get("RANK", "{}") 102 | logging.info("RANK: %s", rank) 103 | 104 | setup_rabit_cluster() 105 | 106 | if __name__ == "__main__": 107 | logging.getLogger().setLevel(logging.INFO) 108 | main() 109 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/xgboostjob_v1_rabit_test.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-test" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | spec: 12 | containers: 13 | - name: xgboostjob 14 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 15 | ports: 16 | - containerPort: 9991 17 | name: xgboostjob-port 18 | imagePullPolicy: Always 19 | Worker: 20 | replicas: 2 21 | restartPolicy: Never 22 | template: 23 | spec: 24 | containers: 25 | - name: xgboostjob 26 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 27 | ports: 28 | - containerPort: 9991 29 | name: xgboostjob-port 30 | imagePullPolicy: Always 31 | 32 | -------------------------------------------------------------------------------- /config/samples/smoke-dist/xgboostjob_v1alpha1_rabit_test.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1" 2 | kind: "XGBoostJob" 3 
| metadata: 4 | name: "xgboost-dist-test" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | apiVersion: v1 12 | kind: Pod 13 | spec: 14 | containers: 15 | - name: xgboostjob 16 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 17 | ports: 18 | - containerPort: 9991 19 | name: xgboostjob-port 20 | imagePullPolicy: Always 21 | Worker: 22 | replicas: 2 23 | restartPolicy: Never 24 | template: 25 | apiVersion: v1 26 | kind: Pod 27 | spec: 28 | containers: 29 | - name: xgboostjob 30 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2 31 | ports: 32 | - containerPort: 9991 33 | name: xgboostjob-port 34 | imagePullPolicy: Always 35 | 36 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/Dockerfile: -------------------------------------------------------------------------------- 1 | # Install python 3。6. 2 | FROM python:3.6 3 | 4 | RUN apt-get update 5 | RUN apt-get install -y git make g++ cmake 6 | 7 | RUN mkdir -p /opt/mlkube 8 | 9 | # Download the rabit tracker and xgboost code. 10 | 11 | COPY requirements.txt /opt/mlkube/ 12 | 13 | # Install requirements 14 | 15 | RUN pip install -r /opt/mlkube/requirements.txt 16 | 17 | # Build XGBoost. 
18 | RUN git clone --recursive https://github.com/dmlc/xgboost && \ 19 | cd xgboost && \ 20 | make -j$(nproc) && \ 21 | cd python-package; python setup.py install 22 | 23 | COPY *.py /opt/mlkube/ 24 | 25 | ENTRYPOINT ["python", "/opt/mlkube/main.py"] 26 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/build.sh: -------------------------------------------------------------------------------- 1 | ## build the docker file 2 | docker build -f Dockerfile -t merlintang/xgboost-dist-iris:1.0 ./ 3 | 4 | ## push the docker image into docker.io 5 | docker push merlintang/xgboost-dist-iris:1.0 6 | 7 | ## run the train job 8 | kubectl create -f xgboostjob_v1alpha1_iris_train.yaml 9 | 10 | ## run the predict job 11 | kubectl create -f xgboostjob_v1alpha1_iris_predict.yaml 12 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/local_test.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | """ 14 | this file contains tests for xgboost local train and predict in single machine. 
def test_train_model():
    """
    Train xgboost on this worker's local iris split and return the booster.
    :return: trained model
    """
    rank, world_size, place = 1, 10, "/tmp/data"
    dtrain = read_train_data(rank, world_size, place)

    params = {'max_depth': 2, 'eta': 1, 'silent': 1,
              'objective': 'multi:softprob', 'num_class': 3}

    model = xgb.train(params, dtrain=dtrain)

    assert model is not None

    return model


def test_model_predict(booster):
    """
    Score the held-out split with the given booster and assert precision.
    :return: true if pass the test
    """
    rank, world_size, place = 1, 10, "/tmp/data"
    dtest, y_true = read_predict_data(rank, world_size, place)

    probs = booster.predict(dtest)
    y_pred = np.asarray([np.argmax(row) for row in probs])
    score = precision_score(y_true, y_pred, average='macro')

    assert score > 0.99

    logging.info("Predict accuracy: %f", score)

    return True


def test_upload_model(model, model_path, args):
    # Persist the booster to local storage via the shared utils helper.
    return dump_model(model, type="local", model_path=model_path, args=args)


def test_download_model(model_path, args):
    # Reload the booster that test_upload_model stored.
    return read_model(type="local", model_path=model_path, args=args)


def run_test():
    """Round-trip check: train, store, reload, then validate predictions."""
    args = {}
    model_path = "/tmp/xgboost"

    logging.info("Start the local test")

    trained = test_train_model()
    test_upload_model(trained, model_path, args)
    reloaded = test_download_model(model_path, args)
    test_model_predict(reloaded)

    logging.info("Finish the local test")
def main(args):
    """Dispatch a distributed XGBoost job according to parsed CLI args.

    :param args: argparse namespace with job_type ("Train", "Predict" or
        "All"), model_storage_type ("local" or "oss"), model_path, and the
        training hyper-parameters.
    :raises Exception: if the model storage type is unsupported.
    """
    model_storage_type = args.model_storage_type
    # Validate before doing any work; only two storage backends exist.
    if model_storage_type not in ("local", "oss"):
        raise Exception("Only supports storage types like local and OSS")
    # Use logging like the rest of the module (the original used print).
    logging.info("The storage type is %s", model_storage_type)

    if args.job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)

    elif args.job_type == "Train":
        logging.info("starting the train job")
        model = train(args)

        if model is not None:
            logging.info("finish the model training, and start to dump model ")
            dump_model(model, model_storage_type, args.model_path, args)

    elif args.job_type == "All":
        # TODO(review): this branch only logs today — chaining train and
        # then predict here still needs to be implemented.
        logging.info("starting the train and predict job")

    logging.info("Finish distributed XGBoost job")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--job_type',
        help="Train, Predict, All",
        required=True
    )
    parser.add_argument(
        '--xgboost_parameter',
        help='XGBoost model parameter like: objective, number_class',
    )
    parser.add_argument(
        '--n_estimators',
        help='Number of trees in the model',
        type=int,
        default=1000
    )
    parser.add_argument(
        '--learning_rate',
        help='Learning rate for the model',
        default=0.1
    )
    parser.add_argument(
        '--early_stopping_rounds',
        help='XGBoost argument for stopping early',
        default=50
    )
    parser.add_argument(
        '--model_path',
        help='place to store model',
        default="/tmp/xgboost_model"
    )
    parser.add_argument(
        '--model_storage_type',
        help='place to store the model',
        default="oss"
    )
    parser.add_argument(
        '--oss_param',
        help='oss parameter if you choose the model storage as OSS type',
    )

    logging.basicConfig(format='%(message)s')
    logging.getLogger().setLevel(logging.INFO)
    main_args = parser.parse_args()
    main(main_args)
12 | 13 | from sklearn.metrics import precision_score 14 | 15 | import logging 16 | import numpy as np 17 | 18 | from utils import extract_xgbooost_cluster_env, read_predict_data, read_model 19 | 20 | 21 | def predict(args): 22 | """ 23 | This is the demonstration for the batch prediction 24 | :param args: parameter for model related config 25 | """ 26 | 27 | addr, port, rank, world_size = extract_xgbooost_cluster_env() 28 | 29 | dmatrix, y_test = read_predict_data(rank, world_size, None) 30 | 31 | model_path = args.model_path 32 | storage_type = args.model_storage_type 33 | booster = read_model(storage_type, model_path, args) 34 | 35 | preds = booster.predict(dmatrix) 36 | 37 | best_preds = np.asarray([np.argmax(line) for line in preds]) 38 | score = precision_score(y_test, best_preds, average='macro') 39 | 40 | logging.info("Predict accuracy: %f", score) 41 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16.3 2 | Cython>=0.29.4 3 | requests>=2.21.0 4 | urllib3>=1.21.1 5 | scipy>=1.1.0 6 | joblib>=0.13.2 7 | scikit-learn>=0.20 8 | oss2>=2.7.0 9 | pandas>=0.24.2 -------------------------------------------------------------------------------- /config/samples/xgboost-dist/train.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | 13 | 14 | import logging 15 | import xgboost as xgb 16 | import traceback 17 | 18 | from tracker import RabitTracker 19 | from utils import read_train_data, extract_xgbooost_cluster_env 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def train(args): 25 | """ 26 | :param args: configuration for train job 27 | :return: XGBoost model 28 | """ 29 | addr, port, rank, world_size = extract_xgbooost_cluster_env() 30 | rabit_tracker = None 31 | 32 | try: 33 | """start to build the network""" 34 | if world_size > 1: 35 | if rank == 0: 36 | logger.info("start the master node") 37 | 38 | rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size, 39 | port=port, port_end=port + 1) 40 | rabit.start(world_size) 41 | rabit_tracker = rabit 42 | logger.info('###### RabitTracker Setup Finished ######') 43 | 44 | envs = [ 45 | 'DMLC_NUM_WORKER=%d' % world_size, 46 | 'DMLC_TRACKER_URI=%s' % addr, 47 | 'DMLC_TRACKER_PORT=%d' % port, 48 | 'DMLC_TASK_ID=%d' % rank 49 | ] 50 | logger.info('##### Rabit rank setup with below envs #####') 51 | for i, env in enumerate(envs): 52 | logger.info(env) 53 | envs[i] = str.encode(env) 54 | 55 | xgb.rabit.init(envs) 56 | logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank()) 57 | rank = xgb.rabit.get_rank() 58 | 59 | else: 60 | world_size = 1 61 | logging.info("Start the train in a single node") 62 | 63 | df = read_train_data(rank=rank, num_workers=world_size, path=None) 64 | kwargs = {} 65 | kwargs["dtrain"] = df 66 | kwargs["num_boost_round"] = int(args.n_estimators) 67 | param_xgboost_default = {'max_depth': 2, 'eta': 1, 'silent': 1, 68 | 'objective': 'multi:softprob', 'num_class': 3} 69 | kwargs["params"] = param_xgboost_default 70 | 71 | logging.info("starting to train xgboost at node with rank %d", rank) 72 | bst = xgb.train(**kwargs) 73 | 74 | if rank == 0: 75 | model = bst 76 | else: 77 | model = None 78 | 79 
| logging.info("finish xgboost training at node with rank %d", rank) 80 | 81 | except Exception as e: 82 | logger.error("something wrong happen: %s", traceback.format_exc()) 83 | raise e 84 | finally: 85 | logger.info("xgboost training job finished!") 86 | if world_size > 1: 87 | xgb.rabit.finalize() 88 | if rabit_tracker: 89 | rabit_tracker.join() 90 | 91 | return model 92 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/utils.py: -------------------------------------------------------------------------------- 1 | # Licensed under the Apache License, Version 2.0 (the "License"); 2 | # you may not use this file except in compliance with the License. 3 | # You may obtain a copy of the License at 4 | # 5 | # http://www.apache.org/licenses/LICENSE-2.0 6 | # 7 | # Unless required by applicable law or agreed to in writing, software 8 | # distributed under the License is distributed on an "AS IS" BASIS, 9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | 13 | import logging 14 | import joblib 15 | import xgboost as xgb 16 | import os 17 | import tempfile 18 | import oss2 19 | import json 20 | import pandas as pd 21 | 22 | from sklearn import datasets 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def extract_xgbooost_cluster_env(): 28 | """ 29 | Extract the cluster env from pod 30 | :return: the related cluster env to build rabit 31 | """ 32 | 33 | logger.info("starting to extract system env") 34 | 35 | master_addr = os.environ.get("MASTER_ADDR", "{}") 36 | master_port = int(os.environ.get("MASTER_PORT", "{}")) 37 | rank = int(os.environ.get("RANK", "{}")) 38 | world_size = int(os.environ.get("WORLD_SIZE", "{}")) 39 | 40 | logger.info("extract the Rabit env from cluster :" 41 | " %s, port: %d, rank: %d, word_size: %d ", 42 | master_addr, master_port, rank, world_size) 43 | 44 | return master_addr, master_port, rank, world_size 45 | 46 | 47 | def read_train_data(rank, num_workers, path): 48 | """ 49 | Read file based on the rank of worker. 50 | We use the sklearn.iris data for demonstration 51 | You can extend this to read distributed data source like HDFS, HIVE etc 52 | :param rank: the id of each worker 53 | :param num_workers: total number of workers in this cluster 54 | :param path: the input file name or the place to read the data 55 | :return: XGBoost Dmatrix 56 | """ 57 | iris = datasets.load_iris() 58 | x = iris.data 59 | y = iris.target 60 | 61 | start, end = get_range_data(len(x), rank, num_workers) 62 | x = x[start:end, :] 63 | y = y[start:end] 64 | 65 | x = pd.DataFrame(x) 66 | y = pd.DataFrame(y) 67 | dtrain = xgb.DMatrix(data=x, label=y) 68 | 69 | logging.info("Read data from IRIS data source with range from %d to %d", 70 | start, end) 71 | 72 | return dtrain 73 | 74 | 75 | def read_predict_data(rank, num_workers, path): 76 | """ 77 | Read file based on the rank of worker. 
78 | We use the sklearn.iris data for demonstration 79 | You can extend this to read distributed data source like HDFS, HIVE etc 80 | :param rank: the id of each worker 81 | :param num_workers: total number of workers in this cluster 82 | :param path: the input file name or the place to read the data 83 | :return: XGBoost Dmatrix, and real value 84 | """ 85 | iris = datasets.load_iris() 86 | x = iris.data 87 | y = iris.target 88 | 89 | start, end = get_range_data(len(x), rank, num_workers) 90 | x = x[start:end, :] 91 | y = y[start:end] 92 | x = pd.DataFrame(x) 93 | y = pd.DataFrame(y) 94 | 95 | logging.info("Read data from IRIS datasource with range from %d to %d", 96 | start, end) 97 | 98 | predict = xgb.DMatrix(x, label=y) 99 | 100 | return predict, y 101 | 102 | 103 | def get_range_data(num_row, rank, num_workers): 104 | """ 105 | compute the data range based on the input data size and worker id 106 | :param num_row: total number of dataset 107 | :param rank: the worker id 108 | :param num_workers: total number of workers 109 | :return: begin and end range of input matrix 110 | """ 111 | num_per_partition = int(num_row/num_workers) 112 | 113 | x_start = rank * num_per_partition 114 | x_end = (rank + 1) * num_per_partition 115 | 116 | if x_end > num_row: 117 | x_end = num_row 118 | 119 | return x_start, x_end 120 | 121 | 122 | def dump_model(model, type, model_path, args): 123 | """ 124 | dump the trained model into local place 125 | you can update this function to store the model into a remote place 126 | :param model: the xgboost trained booster 127 | :param type: model storage type 128 | :param model_path: place to store model 129 | :param args: configuration for model storage 130 | :return: True if the dump process success 131 | """ 132 | if model is None: 133 | raise Exception("fail to get the XGBoost train model") 134 | else: 135 | if type == "local": 136 | joblib.dump(model, model_path) 137 | logging.info("Dump model into local place %s", model_path) 138 | 
139 | elif type == "oss": 140 | oss_param = parse_parameters(args.oss_param, ",", ":") 141 | if oss_param is None: 142 | raise Exception("Please config oss parameter to store model") 143 | 144 | oss_param['path'] = args.model_path 145 | dump_model_to_oss(oss_param, model) 146 | logging.info("Dump model into oss place %s", args.model_path) 147 | 148 | return True 149 | 150 | 151 | def read_model(type, model_path, args): 152 | """ 153 | read model from physical storage 154 | :param type: oss or local 155 | :param model_path: place to store the model 156 | :param args: configuration to read model 157 | :return: XGBoost model 158 | """ 159 | 160 | if type == "local": 161 | model = joblib.load(model_path) 162 | logging.info("Read model from local place %s", model_path) 163 | 164 | elif type == "oss": 165 | oss_param = parse_parameters(args.oss_param, ",", ":") 166 | if oss_param is None: 167 | raise Exception("Please config oss to read model") 168 | return False 169 | 170 | oss_param['path'] = args.model_path 171 | 172 | model = read_model_from_oss(oss_param) 173 | logging.info("read model from oss place %s", model_path) 174 | 175 | return model 176 | 177 | 178 | def dump_model_to_oss(oss_parameters, booster): 179 | """ 180 | dump the model to remote OSS disk 181 | :param oss_parameters: oss configuration 182 | :param booster: XGBoost model 183 | :return: True if stored procedure is success 184 | """ 185 | """export model into oss""" 186 | model_fname = os.path.join(tempfile.mkdtemp(), 'model') 187 | text_model_fname = os.path.join(tempfile.mkdtemp(), 'model.text') 188 | feature_importance = os.path.join(tempfile.mkdtemp(), 189 | 'feature_importance.json') 190 | 191 | oss_path = oss_parameters['path'] 192 | logger.info('---- export model ----') 193 | booster.save_model(model_fname) 194 | booster.dump_model(text_model_fname) # format output model 195 | fscore_dict = booster.get_fscore() 196 | with open(feature_importance, 'w') as file: 197 | 
file.write(json.dumps(fscore_dict)) 198 | logger.info('---- chief dump model successfully!') 199 | 200 | if os.path.exists(model_fname): 201 | logger.info('---- Upload Model start...') 202 | 203 | while oss_path[-1] == '/': 204 | oss_path = oss_path[:-1] 205 | 206 | upload_oss(oss_parameters, model_fname, oss_path) 207 | aux_path = oss_path + '_dir/' 208 | upload_oss(oss_parameters, model_fname, aux_path) 209 | upload_oss(oss_parameters, text_model_fname, aux_path) 210 | upload_oss(oss_parameters, feature_importance, aux_path) 211 | else: 212 | raise Exception("fail to generate model") 213 | return False 214 | 215 | return True 216 | 217 | 218 | def upload_oss(kw, local_file, oss_path): 219 | """ 220 | help function to upload a model to oss 221 | :param kw: OSS parameter 222 | :param local_file: local place of model 223 | :param oss_path: remote place of OSS 224 | :return: True if the procedure is success 225 | """ 226 | if oss_path[-1] == '/': 227 | oss_path = '%s%s' % (oss_path, os.path.basename(local_file)) 228 | 229 | auth = oss2.Auth(kw['access_id'], kw['access_key']) 230 | bucket = kw['access_bucket'] 231 | bkt = oss2.Bucket(auth=auth, endpoint=kw['endpoint'], bucket_name=bucket) 232 | 233 | try: 234 | bkt.put_object_from_file(key=oss_path, filename=local_file) 235 | logger.info("upload %s to %s successfully!" 
% 236 | (os.path.abspath(local_file), oss_path)) 237 | except Exception(): 238 | raise ValueError('upload %s to %s failed' % 239 | (os.path.abspath(local_file), oss_path)) 240 | 241 | 242 | def read_model_from_oss(kw): 243 | """ 244 | helper function to read a model from oss 245 | :param kw: OSS parameter 246 | :return: XGBoost booster model 247 | """ 248 | auth = oss2.Auth(kw['access_id'], kw['access_key']) 249 | bucket = kw['access_bucket'] 250 | bkt = oss2.Bucket(auth=auth, endpoint=kw['endpoint'], bucket_name=bucket) 251 | oss_path = kw["path"] 252 | 253 | temp_model_fname = os.path.join(tempfile.mkdtemp(), 'local_model') 254 | try: 255 | bkt.get_object_to_file(key=oss_path, filename=temp_model_fname) 256 | logger.info("success to load model from oss %s", oss_path) 257 | except Exception as e: 258 | logging.error("fail to load model: " + e) 259 | raise Exception("fail to load model from oss %s", oss_path) 260 | 261 | bst = xgb.Booster({'nthread': 2}) # init model 262 | 263 | bst.load_model(temp_model_fname) 264 | 265 | return bst 266 | 267 | 268 | def parse_parameters(input, splitter_between, splitter_in): 269 | """ 270 | helper function parse the input parameter 271 | :param input: the string of configuration like key-value pairs 272 | :param splitter_between: the splitter between config for input string 273 | :param splitter_in: the splitter inside config for input string 274 | :return: key-value pair configuration 275 | """ 276 | 277 | ky_pairs = input.split(splitter_between) 278 | 279 | confs = {} 280 | 281 | for kv in ky_pairs: 282 | conf = kv.split(splitter_in) 283 | key = conf[0].strip(" ") 284 | if key == "objective" or key == "endpoint": 285 | value = conf[1].strip("'") + ":" + conf[2].strip("'") 286 | else: 287 | value = conf[1] 288 | 289 | confs[key] = value 290 | return confs 291 | 292 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml: 
-------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-iris-test-predict" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | spec: 12 | containers: 13 | - name: xgboostjob 14 | image: docker.io/merlintang/xgboost-dist-iris:1.1 15 | ports: 16 | - containerPort: 9991 17 | name: xgboostjob-port 18 | imagePullPolicy: Always 19 | args: 20 | - --job_type=Predict 21 | - --model_path=autoAI/xgb-opt/2 22 | - --model_storage_type=oss 23 | - --oss_param=unknown 24 | Worker: 25 | replicas: 2 26 | restartPolicy: ExitCode 27 | template: 28 | spec: 29 | containers: 30 | - name: xgboostjob 31 | image: docker.io/merlintang/xgboost-dist-iris:1.1 32 | ports: 33 | - containerPort: 9991 34 | name: xgboostjob-port 35 | imagePullPolicy: Always 36 | args: 37 | - --job_type=Predict 38 | - --model_path=autoAI/xgb-opt/2 39 | - --model_storage_type=oss 40 | - --oss_param=unknown 41 | 42 | 43 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-iris-test-train" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | spec: 12 | containers: 13 | - name: xgboostjob 14 | image: docker.io/merlintang/xgboost-dist-iris:1.1 15 | ports: 16 | - containerPort: 9991 17 | name: xgboostjob-port 18 | imagePullPolicy: Always 19 | args: 20 | - --job_type=Train 21 | - --xgboost_parameter=objective:multi:softprob,num_class:3 22 | - --n_estimators=10 23 | - --learning_rate=0.1 24 | - --model_path=/tmp/xgboost-model 25 | - --model_storage_type=local 26 | Worker: 27 | replicas: 2 28 | restartPolicy: ExitCode 29 | 
template: 30 | spec: 31 | containers: 32 | - name: xgboostjob 33 | image: docker.io/merlintang/xgboost-dist-iris:1.1 34 | ports: 35 | - containerPort: 9991 36 | name: xgboostjob-port 37 | imagePullPolicy: Always 38 | args: 39 | - --job_type=Train 40 | - --xgboost_parameter="objective:multi:softprob,num_class:3" 41 | - --n_estimators=10 42 | - --learning_rate=0.1 43 | 44 | 45 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_predict.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-iris-test-predict" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | apiVersion: v1 12 | kind: Pod 13 | spec: 14 | containers: 15 | - name: xgboostjob 16 | image: docker.io/merlintang/xgboost-dist-iris:1.1 17 | ports: 18 | - containerPort: 9991 19 | name: xgboostjob-port 20 | imagePullPolicy: Always 21 | args: 22 | - --job_type=Predict 23 | - --model_path=autoAI/xgb-opt/2 24 | - --model_storage_type=oss 25 | - --oss_param=unknown 26 | Worker: 27 | replicas: 2 28 | restartPolicy: ExitCode 29 | template: 30 | apiVersion: v1 31 | kind: Pod 32 | spec: 33 | containers: 34 | - name: xgboostjob 35 | image: docker.io/merlintang/xgboost-dist-iris:1.1 36 | ports: 37 | - containerPort: 9991 38 | name: xgboostjob-port 39 | imagePullPolicy: Always 40 | args: 41 | - --job_type=Predict 42 | - --model_path=autoAI/xgb-opt/2 43 | - --model_storage_type=oss 44 | - --oss_param=unknown 45 | 46 | 47 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_predict_local.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | 
name: "xgboost-dist-iris-test-predict-local" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | apiVersion: v1 12 | kind: Pod 13 | spec: 14 | volumes: 15 | - name: task-pv-storage 16 | persistentVolumeClaim: 17 | claimName: xgboostlocal 18 | containers: 19 | - name: xgboostjob 20 | image: docker.io/merlintang/xgboost-dist-iris:1.1 21 | volumeMounts: 22 | - name: task-pv-storage 23 | mountPath: /tmp/xgboost_model 24 | ports: 25 | - containerPort: 9991 26 | name: xgboostjob-port 27 | imagePullPolicy: Always 28 | args: 29 | - --job_type=Predict 30 | - --model_path=/tmp/xgboost_model/2 31 | - --model_storage_type=local 32 | Worker: 33 | replicas: 2 34 | restartPolicy: ExitCode 35 | template: 36 | apiVersion: v1 37 | kind: Pod 38 | spec: 39 | volumes: 40 | - name: task-pv-storage 41 | persistentVolumeClaim: 42 | claimName: xgboostlocal 43 | containers: 44 | - name: xgboostjob 45 | image: docker.io/merlintang/xgboost-dist-iris:1.1 46 | volumeMounts: 47 | - name: task-pv-storage 48 | mountPath: /tmp/xgboost_model 49 | ports: 50 | - containerPort: 9991 51 | name: xgboostjob-port 52 | imagePullPolicy: Always 53 | args: 54 | - --job_type=Predict 55 | - --model_path=/tmp/xgboost_model/2 56 | - --model_storage_type=local 57 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_train.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-iris-test-train" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | apiVersion: v1 12 | kind: Pod 13 | spec: 14 | containers: 15 | - name: xgboostjob 16 | image: docker.io/merlintang/xgboost-dist-iris:1.1 17 | ports: 18 | - containerPort: 9991 19 | name: xgboostjob-port 20 | imagePullPolicy: Always 21 | args: 22 | - 
--job_type=Train 23 | - --xgboost_parameter=objective:multi:softprob,num_class:3 24 | - --n_estimators=10 25 | - --learning_rate=0.1 26 | - --model_path=autoAI/xgb-opt/2 27 | - --model_storage_type=oss 28 | - --oss_param=unknown 29 | Worker: 30 | replicas: 2 31 | restartPolicy: ExitCode 32 | template: 33 | apiVersion: v1 34 | kind: Pod 35 | spec: 36 | containers: 37 | - name: xgboostjob 38 | image: docker.io/merlintang/xgboost-dist-iris:1.1 39 | ports: 40 | - containerPort: 9991 41 | name: xgboostjob-port 42 | imagePullPolicy: Always 43 | args: 44 | - --job_type=Train 45 | - --xgboost_parameter="objective:multi:softprob,num_class:3" 46 | - --n_estimators=10 47 | - --learning_rate=0.1 48 | 49 | 50 | -------------------------------------------------------------------------------- /config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_train_local.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1" 2 | kind: "XGBoostJob" 3 | metadata: 4 | name: "xgboost-dist-iris-test-train-local" 5 | spec: 6 | xgbReplicaSpecs: 7 | Master: 8 | replicas: 1 9 | restartPolicy: Never 10 | template: 11 | apiVersion: v1 12 | kind: Pod 13 | spec: 14 | volumes: 15 | - name: task-pv-storage 16 | persistentVolumeClaim: 17 | claimName: xgboostlocal 18 | containers: 19 | - name: xgboostjob 20 | image: docker.io/merlintang/xgboost-dist-iris:1.1 21 | volumeMounts: 22 | - name: task-pv-storage 23 | mountPath: /tmp/xgboost_model 24 | ports: 25 | - containerPort: 9991 26 | name: xgboostjob-port 27 | imagePullPolicy: Always 28 | args: 29 | - --job_type=Train 30 | - --xgboost_parameter=objective:multi:softprob,num_class:3 31 | - --n_estimators=10 32 | - --learning_rate=0.1 33 | - --model_path=/tmp/xgboost_model/2 34 | - --model_storage_type=local 35 | Worker: 36 | replicas: 2 37 | restartPolicy: ExitCode 38 | template: 39 | apiVersion: v1 40 | kind: Pod 41 | spec: 42 | volumes: 43 | - name: task-pv-storage 44 | 
persistentVolumeClaim: 45 | claimName: xgboostlocal 46 | containers: 47 | - name: xgboostjob 48 | image: docker.io/merlintang/xgboost-dist-iris:1.1 49 | volumeMounts: 50 | - name: task-pv-storage 51 | mountPath: /tmp/xgboost_model 52 | ports: 53 | - containerPort: 9991 54 | name: xgboostjob-port 55 | imagePullPolicy: Always 56 | args: 57 | - --job_type=Train 58 | - --xgboost_parameter="objective:multi:softprob,num_class:3" 59 | - --n_estimators=10 60 | - --learning_rate=0.1 61 | - --model_path=/tmp/xgboost_model/2 62 | - --model_storage_type=local 63 | -------------------------------------------------------------------------------- /go.mod: -------------------------------------------------------------------------------- 1 | module github.com/kubeflow/xgboost-operator 2 | 3 | go 1.12 4 | 5 | require ( 6 | cloud.google.com/go v0.39.0 // indirect 7 | github.com/go-logr/zapr v0.1.1 // indirect 8 | github.com/kubeflow/common v0.3.1 9 | github.com/sirupsen/logrus v1.4.2 10 | go.uber.org/atomic v1.4.0 // indirect 11 | go.uber.org/zap v1.10.0 // indirect 12 | google.golang.org/appengine v1.6.0 // indirect 13 | gopkg.in/square/go-jose.v2 v2.3.1 // indirect 14 | k8s.io/api v0.16.9 15 | k8s.io/apimachinery v0.16.9 16 | k8s.io/client-go v0.16.9 17 | sigs.k8s.io/controller-runtime v0.4.0 18 | volcano.sh/volcano v0.4.0 19 | ) 20 | 21 | replace ( 22 | k8s.io/api => k8s.io/api v0.16.9 23 | k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.16.9 24 | k8s.io/apimachinery => k8s.io/apimachinery v0.16.10-beta.0 25 | k8s.io/apiserver => k8s.io/apiserver v0.16.9 26 | k8s.io/cli-runtime => k8s.io/cli-runtime v0.16.9 27 | k8s.io/client-go => k8s.io/client-go v0.16.9 28 | k8s.io/cloud-provider => k8s.io/cloud-provider v0.16.9 29 | k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.16.9 30 | k8s.io/code-generator => k8s.io/code-generator v0.16.10-beta.0 31 | k8s.io/component-base => k8s.io/component-base v0.16.9 32 | k8s.io/cri-api => k8s.io/cri-api v0.16.10-beta.0 33 
| k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.16.9 34 | k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.16.9 35 | k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.16.9 36 | k8s.io/kube-proxy => k8s.io/kube-proxy v0.16.9 37 | k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.16.9 38 | k8s.io/kubectl => k8s.io/kubectl v0.16.9 39 | k8s.io/kubelet => k8s.io/kubelet v0.16.9 40 | k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.16.9 41 | k8s.io/metrics => k8s.io/metrics v0.16.9 42 | k8s.io/node-api => k8s.io/node-api v0.16.9 43 | k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.16.9 44 | k8s.io/sample-cli-plugin => k8s.io/sample-cli-plugin v0.16.9 45 | k8s.io/sample-controller => k8s.io/sample-controller v0.16.9 46 | ) 47 | -------------------------------------------------------------------------------- /hack/boilerplate.go.txt: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ -------------------------------------------------------------------------------- /manifests/base/cluster-role-binding.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRoleBinding 3 | metadata: 4 | name: cluster-role-binding 5 | roleRef: 6 | apiGroup: rbac.authorization.k8s.io 7 | kind: ClusterRole 8 | name: cluster-role 9 | subjects: 10 | - kind: ServiceAccount 11 | name: service-account 12 | -------------------------------------------------------------------------------- /manifests/base/cluster-role.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: rbac.authorization.k8s.io/v1 2 | kind: ClusterRole 3 | metadata: 4 | name: cluster-role 5 | rules: 6 | - apiGroups: 7 | - apps 8 | resources: 9 | - deployments 10 | - deployments/status 11 | verbs: 12 | - get 13 | - list 14 | - watch 15 | - create 16 | - update 17 | - patch 18 | - delete 19 | - apiGroups: 20 | - xgboostjob.kubeflow.org 21 | resources: 22 | - xgboostjobs 23 | - xgboostjobs/status 24 | verbs: 25 | - get 26 | - list 27 | - watch 28 | - create 29 | - update 30 | - patch 31 | - delete 32 | - apiGroups: 33 | - admissionregistration.k8s.io 34 | resources: 35 | - mutatingwebhookconfigurations 36 | - validatingwebhookconfigurations 37 | verbs: 38 | - get 39 | - list 40 | - watch 41 | - create 42 | - update 43 | - patch 44 | - delete 45 | - apiGroups: 46 | - "" 47 | resources: 48 | - configmaps 49 | - endpoints 50 | - events 51 | - namespaces 52 | - persistentvolumeclaims 53 | - pods 54 | - secrets 55 | - services 56 | verbs: 57 | - get 58 | - list 59 | - watch 60 | - create 61 | - update 62 | - patch 63 | - delete 64 | - apiGroups: 65 | - storage.k8s.io 66 | resources: 67 | - storageclasses 68 | verbs: 69 | - get 70 | - list 71 | - watch 72 | - create 73 | - update 74 | - patch 75 | - delete 76 | 
-------------------------------------------------------------------------------- /manifests/base/deployment.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: deployment 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | app: xgboost-operator 10 | template: 11 | metadata: 12 | labels: 13 | app: xgboost-operator 14 | spec: 15 | containers: 16 | - name: xgboost-operator 17 | command: 18 | - /root/manager 19 | - -mode=in-cluster 20 | image: gcr.io/kubeflow-images-public/xgboost-operator:v0.1.0 21 | imagePullPolicy: Always 22 | serviceAccountName: service-account 23 | -------------------------------------------------------------------------------- /manifests/base/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - cluster-role.yaml 5 | - cluster-role-binding.yaml 6 | - crd.yaml 7 | - deployment.yaml 8 | - service-account.yaml 9 | - service.yaml 10 | namespace: kubeflow 11 | namePrefix: xgboost-operator- 12 | configMapGenerator: 13 | - envs: 14 | - params.env 15 | name: xgboost-operator-config 16 | images: 17 | - name: gcr.io/kubeflow-images-public/xgboost-operator 18 | newName: kubeflow/xgboost-operator 19 | newTag: v0.2.0 20 | -------------------------------------------------------------------------------- /manifests/base/params.env: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kubeflow/xgboost-operator/17383e882bd730103d978af310145d0a45437c74/manifests/base/params.env -------------------------------------------------------------------------------- /manifests/base/service-account.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ServiceAccount 3 | metadata: 4 | name: service-account 
5 | -------------------------------------------------------------------------------- /manifests/base/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: service 5 | annotations: 6 | prometheus.io/path: /metrics 7 | prometheus.io/scrape: "true" 8 | prometheus.io/port: "8080" 9 | labels: 10 | app: xgboost-operator 11 | spec: 12 | type: ClusterIP 13 | selector: 14 | app: xgboost-operator 15 | ports: 16 | - port: 443 17 | -------------------------------------------------------------------------------- /manifests/overlays/kubeflow/kustomization.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: kustomize.config.k8s.io/v1beta1 2 | kind: Kustomization 3 | resources: 4 | - ../../base 5 | namespace: kubeflow 6 | commonLabels: 7 | app.kubernetes.io/component: xgboostjob 8 | app.kubernetes.io/name: xgboost-operator 9 | -------------------------------------------------------------------------------- /pkg/apis/addtoscheme_xgboostjob.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package apis 17 | 18 | import ( 19 | "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 20 | ) 21 | 22 | func init() { 23 | // Register the types with the Scheme so the components can map objects to GroupVersionKinds and back 24 | AddToSchemes = append(AddToSchemes, v1.SchemeBuilder.AddToScheme) 25 | } 26 | -------------------------------------------------------------------------------- /pkg/apis/apis.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | // Generate deepcopy for apis 17 | // go:generate go run ../../vendor/k8s.io/code-generator/cmd/deepcopy-gen/main.go -O zz_generated.deepcopy -i ./... -h ../../hack/boilerplate.go.txt 18 | 19 | // Package apis contains Kubernetes API groups. 
20 | package apis 21 | 22 | import ( 23 | "k8s.io/apimachinery/pkg/runtime" 24 | ) 25 | 26 | // AddToSchemes may be used to add all resources defined in the project to a Scheme 27 | var AddToSchemes runtime.SchemeBuilder 28 | 29 | // AddToScheme adds all Resources to the Scheme 30 | func AddToScheme(s *runtime.Scheme) error { 31 | return AddToSchemes.AddToScheme(s) 32 | } 33 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/group.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | // Package xgboostjob contains xgboostjob API versions 17 | package xgboostjob 18 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/constants.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package v1 17 | 18 | const ( 19 | // GroupName is the group name use in this package. 20 | GroupName = "kubeflow.org" 21 | // Kind is the kind name. 22 | Kind = "XGBoostJob" 23 | // GroupVersion is the version. 24 | GroupVersion = "v1" 25 | // Plural is the Plural for XGBoostJob. 26 | Plural = "xgboostjobs" 27 | // Singular is the singular for XGBoostJob. 28 | Singular = "xgboostjob" 29 | // XGBOOSTCRD is the CRD name for XGBoostJob. 30 | XGBoostCRD = "xgboostjobs.kubeflow.org" 31 | 32 | DefaultContainerName = "xgboostjob" 33 | DefaultContainerPortName = "xgboostjob-port" 34 | DefaultPort = 9999 35 | ) 36 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/doc.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | // Package v1 contains API Schema definitions for the xgboostjob v1 API group 17 | // +k8s:openapi-gen=true 18 | // +k8s:deepcopy-gen=package,register 19 | // +k8s:conversion-gen=github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob 20 | // +k8s:defaulter-gen=TypeMeta 21 | // +groupName=xgboostjob.kubeflow.org 22 | package v1 23 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/register.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | // NOTE: Boilerplate only. Ignore this file. 
17 | 18 | // Package v1 contains API Schema definitions for the xgboostjob v1 API group 19 | // +k8s:openapi-gen=true 20 | // +k8s:deepcopy-gen=package,register 21 | // +k8s:conversion-gen=github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob 22 | // +k8s:defaulter-gen=TypeMeta 23 | // +groupName=xgboostjob.kubeflow.org 24 | package v1 25 | 26 | import ( 27 | "k8s.io/apimachinery/pkg/runtime/schema" 28 | "sigs.k8s.io/controller-runtime/pkg/runtime/scheme" 29 | ) 30 | 31 | var ( 32 | // SchemeGroupVersion is group version used to register these objects 33 | SchemeGroupVersion = schema.GroupVersion{Group: "xgboostjob.kubeflow.org", Version: "v1"} 34 | 35 | // SchemeBuilder is used to add go types to the GroupVersionKind scheme 36 | SchemeBuilder = &scheme.Builder{GroupVersion: SchemeGroupVersion} 37 | 38 | // SchemeGroupVersionKind is the GroupVersionKind of the resource. 39 | SchemeGroupVersionKind = SchemeGroupVersion.WithKind(Kind) 40 | 41 | // AddToScheme is required by pkg/client/... 42 | AddToScheme = SchemeBuilder.AddToScheme 43 | ) 44 | 45 | // Resource is required by pkg/client/listers/... 46 | func Resource(resource string) schema.GroupResource { 47 | return SchemeGroupVersion.WithResource(resource).GroupResource() 48 | } 49 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/v1_suite_test.go.back: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package v1 17 | 18 | import ( 19 | "log" 20 | "os" 21 | "path/filepath" 22 | "testing" 23 | 24 | "k8s.io/client-go/kubernetes/scheme" 25 | "k8s.io/client-go/rest" 26 | "sigs.k8s.io/controller-runtime/pkg/client" 27 | "sigs.k8s.io/controller-runtime/pkg/envtest" 28 | ) 29 | 30 | var cfg *rest.Config 31 | var c client.Client 32 | 33 | func TestMain(m *testing.M) { 34 | t := &envtest.Environment{ 35 | CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "..", "config", "crds")}, 36 | } 37 | 38 | err := SchemeBuilder.AddToScheme(scheme.Scheme) 39 | if err != nil { 40 | log.Fatal(err) 41 | } 42 | 43 | if cfg, err = t.Start(); err != nil { 44 | log.Fatal(err) 45 | } 46 | 47 | if c, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}); err != nil { 48 | log.Fatal(err) 49 | } 50 | 51 | code := m.Run() 52 | t.Stop() 53 | os.Exit(code) 54 | } 55 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/xgboostjob_types.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package v1 17 | 18 | import ( 19 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 20 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 21 | ) 22 | 23 | // EDIT THIS FILE! 
THIS IS SCAFFOLDING FOR YOU TO OWN! 24 | // NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized. 25 | 26 | // XGBoostJobSpec defines the desired state of XGBoostJob 27 | type XGBoostJobSpec struct { 28 | // INSERT ADDITIONAL SPEC FIELDS - desired state of cluster 29 | // Important: Run "make" to regenerate code after modifying this file 30 | RunPolicy commonv1.RunPolicy `json:",inline"` 31 | 32 | XGBReplicaSpecs map[commonv1.ReplicaType]*commonv1.ReplicaSpec `json:"xgbReplicaSpecs"` 33 | } 34 | 35 | // XGBoostJobStatus defines the observed state of XGBoostJob 36 | type XGBoostJobStatus struct { 37 | // INSERT ADDITIONAL STATUS FIELD - define observed state of cluster 38 | // Important: Run "make" to regenerate code after modifying this file 39 | commonv1.JobStatus `json:",inline"` 40 | } 41 | 42 | // +genclient 43 | // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object 44 | 45 | // XGBoostJob is the Schema for the xgboostjobs API 46 | // +k8s:openapi-gen=true 47 | type XGBoostJob struct { 48 | metav1.TypeMeta `json:",inline"` 49 | metav1.ObjectMeta `json:"metadata,omitempty"` 50 | 51 | Spec XGBoostJobSpec `json:"spec,omitempty"` 52 | Status XGBoostJobStatus `json:"status,omitempty"` 53 | } 54 | 55 | // +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object 56 | 57 | // XGBoostJobList contains a list of XGBoostJob 58 | type XGBoostJobList struct { 59 | metav1.TypeMeta `json:",inline"` 60 | metav1.ListMeta `json:"metadata,omitempty"` 61 | Items []XGBoostJob `json:"items"` 62 | } 63 | 64 | // XGBoostJobReplicaType is the type for XGBoostJobReplica. 65 | type XGBoostJobReplicaType commonv1.ReplicaType 66 | 67 | const ( 68 | // XGBoostReplicaTypeMaster is the type for master replica. 69 | XGBoostReplicaTypeMaster XGBoostJobReplicaType = "Master" 70 | 71 | // XGBoostReplicaTypeWorker is the type for worker replicas. 
72 | XGBoostReplicaTypeWorker XGBoostJobReplicaType = "Worker" 73 | ) 74 | 75 | func init() { 76 | SchemeBuilder.Register(&XGBoostJob{}, &XGBoostJobList{}) 77 | } 78 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/xgboostjob_types_test.go.back: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package v1 17 | 18 | import ( 19 | "testing" 20 | 21 | "github.com/onsi/gomega" 22 | "golang.org/x/net/context" 23 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 24 | "k8s.io/apimachinery/pkg/types" 25 | ) 26 | 27 | func TestStorageXGBoostJob(t *testing.T) { 28 | key := types.NamespacedName{ 29 | Name: "foo", 30 | Namespace: "default", 31 | } 32 | created := &XGBoostJob{ 33 | ObjectMeta: metav1.ObjectMeta{ 34 | Name: "foo", 35 | Namespace: "default", 36 | }} 37 | g := gomega.NewGomegaWithT(t) 38 | 39 | // Test Create 40 | fetched := &XGBoostJob{} 41 | g.Expect(c.Create(context.TODO(), created)).NotTo(gomega.HaveOccurred()) 42 | 43 | g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred()) 44 | g.Expect(fetched).To(gomega.Equal(created)) 45 | 46 | // Test Updating the Labels 47 | updated := fetched.DeepCopy() 48 | updated.Labels = map[string]string{"hello": "world"} 49 | g.Expect(c.Update(context.TODO(), updated)).NotTo(gomega.HaveOccurred()) 50 | 51 | g.Expect(c.Get(context.TODO(), 
key, fetched)).NotTo(gomega.HaveOccurred()) 52 | g.Expect(fetched).To(gomega.Equal(updated)) 53 | 54 | // Test Delete 55 | g.Expect(c.Delete(context.TODO(), fetched)).NotTo(gomega.HaveOccurred()) 56 | g.Expect(c.Get(context.TODO(), key, fetched)).To(gomega.HaveOccurred()) 57 | } 58 | -------------------------------------------------------------------------------- /pkg/apis/xgboostjob/v1/zz_generated.deepcopy.go: -------------------------------------------------------------------------------- 1 | // +build !ignore_autogenerated 2 | 3 | /* 4 | 5 | Licensed under the Apache License, Version 2.0 (the "License"); 6 | you may not use this file except in compliance with the License. 7 | You may obtain a copy of the License at 8 | 9 | http://www.apache.org/licenses/LICENSE-2.0 10 | 11 | Unless required by applicable law or agreed to in writing, software 12 | distributed under the License is distributed on an "AS IS" BASIS, 13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | See the License for the specific language governing permissions and 15 | limitations under the License. 16 | */ 17 | 18 | // Code generated by controller-gen. DO NOT EDIT. 19 | 20 | package v1 21 | 22 | import ( 23 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 24 | runtime "k8s.io/apimachinery/pkg/runtime" 25 | ) 26 | 27 | // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 28 | func (in *XGBoostJob) DeepCopyInto(out *XGBoostJob) { 29 | *out = *in 30 | out.TypeMeta = in.TypeMeta 31 | in.ObjectMeta.DeepCopyInto(&out.ObjectMeta) 32 | in.Spec.DeepCopyInto(&out.Spec) 33 | in.Status.DeepCopyInto(&out.Status) 34 | } 35 | 36 | // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJob. 
37 | func (in *XGBoostJob) DeepCopy() *XGBoostJob { 38 | if in == nil { 39 | return nil 40 | } 41 | out := new(XGBoostJob) 42 | in.DeepCopyInto(out) 43 | return out 44 | } 45 | 46 | // DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. 47 | func (in *XGBoostJob) DeepCopyObject() runtime.Object { 48 | if c := in.DeepCopy(); c != nil { 49 | return c 50 | } 51 | return nil 52 | } 53 | 54 | // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 55 | func (in *XGBoostJobList) DeepCopyInto(out *XGBoostJobList) { 56 | *out = *in 57 | out.TypeMeta = in.TypeMeta 58 | in.ListMeta.DeepCopyInto(&out.ListMeta) 59 | if in.Items != nil { 60 | in, out := &in.Items, &out.Items 61 | *out = make([]XGBoostJob, len(*in)) 62 | for i := range *in { 63 | (*in)[i].DeepCopyInto(&(*out)[i]) 64 | } 65 | } 66 | } 67 | 68 | // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobList. 69 | func (in *XGBoostJobList) DeepCopy() *XGBoostJobList { 70 | if in == nil { 71 | return nil 72 | } 73 | out := new(XGBoostJobList) 74 | in.DeepCopyInto(out) 75 | return out 76 | } 77 | 78 | // DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object. 79 | func (in *XGBoostJobList) DeepCopyObject() runtime.Object { 80 | if c := in.DeepCopy(); c != nil { 81 | return c 82 | } 83 | return nil 84 | } 85 | 86 | // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 
87 | func (in *XGBoostJobSpec) DeepCopyInto(out *XGBoostJobSpec) { 88 | *out = *in 89 | in.RunPolicy.DeepCopyInto(&out.RunPolicy) 90 | if in.XGBReplicaSpecs != nil { 91 | in, out := &in.XGBReplicaSpecs, &out.XGBReplicaSpecs 92 | *out = make(map[commonv1.ReplicaType]*commonv1.ReplicaSpec, len(*in)) 93 | for key, val := range *in { 94 | var outVal *commonv1.ReplicaSpec 95 | if val == nil { 96 | (*out)[key] = nil 97 | } else { 98 | in, out := &val, &outVal 99 | *out = new(commonv1.ReplicaSpec) 100 | (*in).DeepCopyInto(*out) 101 | } 102 | (*out)[key] = outVal 103 | } 104 | } 105 | } 106 | 107 | // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobSpec. 108 | func (in *XGBoostJobSpec) DeepCopy() *XGBoostJobSpec { 109 | if in == nil { 110 | return nil 111 | } 112 | out := new(XGBoostJobSpec) 113 | in.DeepCopyInto(out) 114 | return out 115 | } 116 | 117 | // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. 118 | func (in *XGBoostJobStatus) DeepCopyInto(out *XGBoostJobStatus) { 119 | *out = *in 120 | in.JobStatus.DeepCopyInto(&out.JobStatus) 121 | } 122 | 123 | // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobStatus. 124 | func (in *XGBoostJobStatus) DeepCopy() *XGBoostJobStatus { 125 | if in == nil { 126 | return nil 127 | } 128 | out := new(XGBoostJobStatus) 129 | in.DeepCopyInto(out) 130 | return out 131 | } 132 | -------------------------------------------------------------------------------- /pkg/controller/v1/add_xgboostjob.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 
5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package v1 17 | 18 | import ( 19 | "github.com/kubeflow/xgboost-operator/pkg/controller/v1/xgboostjob" 20 | ) 21 | 22 | func init() { 23 | // AddToManagerFuncs is a list of functions to create controllers and add them to a manager. 24 | AddToManagerFuncs = append(AddToManagerFuncs, xgboostjob.Add) 25 | } 26 | -------------------------------------------------------------------------------- /pkg/controller/v1/controller.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package v1 17 | 18 | import ( 19 | "sigs.k8s.io/controller-runtime/pkg/manager" 20 | ) 21 | 22 | // AddToManagerFuncs is a list of functions to add all Controllers to the Manager 23 | var AddToManagerFuncs []func(manager.Manager) error 24 | 25 | // AddToManager adds all Controllers to the Manager 26 | func AddToManager(m manager.Manager) error { 27 | for _, f := range AddToManagerFuncs { 28 | if err := f(m); err != nil { 29 | return err 30 | } 31 | } 32 | return nil 33 | } 34 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/expectation.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "fmt" 20 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 21 | "github.com/kubeflow/common/pkg/controller.v1/common" 22 | "github.com/kubeflow/common/pkg/controller.v1/expectation" 23 | "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 24 | "github.com/sirupsen/logrus" 25 | corev1 "k8s.io/api/core/v1" 26 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 27 | utilruntime "k8s.io/apimachinery/pkg/util/runtime" 28 | "sigs.k8s.io/controller-runtime/pkg/event" 29 | "sigs.k8s.io/controller-runtime/pkg/reconcile" 30 | ) 31 | 32 | // satisfiedExpectations returns true if the required adds/dels for the given job have been observed. 
33 | // Add/del counts are established by the controller at sync time, and updated as controllees are observed by the controller 34 | // manager. 35 | func (r *ReconcileXGBoostJob) satisfiedExpectations(xgbJob *v1.XGBoostJob) bool { 36 | satisfied := false 37 | key, err := common.KeyFunc(xgbJob) 38 | if err != nil { 39 | utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", xgbJob, err)) 40 | return false 41 | } 42 | for rtype := range xgbJob.Spec.XGBReplicaSpecs { 43 | // Check the expectations of the pods. 44 | expectationPodsKey := expectation.GenExpectationPodsKey(key, string(rtype)) 45 | satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationPodsKey) 46 | // Check the expectations of the services. 47 | expectationServicesKey := expectation.GenExpectationServicesKey(key, string(rtype)) 48 | satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationServicesKey) 49 | } 50 | return satisfied 51 | } 52 | 53 | // onDependentCreateFunc modify expectations when dependent (pod/service) creation observed. 
54 | func onDependentCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool { 55 | return func(e event.CreateEvent) bool { 56 | xgbr, ok := r.(*ReconcileXGBoostJob) 57 | if !ok { 58 | return true 59 | } 60 | rtype := e.Meta.GetLabels()[commonv1.ReplicaTypeLabel] 61 | if len(rtype) == 0 { 62 | return false 63 | } 64 | 65 | logrus.Info("Update on create function ", xgbr.ControllerName(), " create object ", e.Meta.GetName()) 66 | if controllerRef := metav1.GetControllerOf(e.Meta); controllerRef != nil { 67 | var expectKey string 68 | if _, ok := e.Object.(*corev1.Pod); ok { 69 | expectKey = expectation.GenExpectationPodsKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype) 70 | } 71 | 72 | if _, ok := e.Object.(*corev1.Service); ok { 73 | expectKey = expectation.GenExpectationServicesKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype) 74 | } 75 | xgbr.Expectations.CreationObserved(expectKey) 76 | return true 77 | } 78 | 79 | return true 80 | } 81 | } 82 | 83 | // onDependentDeleteFunc modify expectations when dependent (pod/service) deletion observed. 
84 | func onDependentDeleteFunc(r reconcile.Reconciler) func(event.DeleteEvent) bool { 85 | return func(e event.DeleteEvent) bool { 86 | xgbr, ok := r.(*ReconcileXGBoostJob) 87 | if !ok { 88 | return true 89 | } 90 | 91 | rtype := e.Meta.GetLabels()[commonv1.ReplicaTypeLabel] 92 | if len(rtype) == 0 { 93 | return false 94 | } 95 | 96 | logrus.Info("Update on deleting function ", xgbr.ControllerName(), " delete object ", e.Meta.GetName()) 97 | if controllerRef := metav1.GetControllerOf(e.Meta); controllerRef != nil { 98 | var expectKey string 99 | if _, ok := e.Object.(*corev1.Pod); ok { 100 | expectKey = expectation.GenExpectationPodsKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype) 101 | } 102 | 103 | if _, ok := e.Object.(*corev1.Service); ok { 104 | expectKey = expectation.GenExpectationServicesKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype) 105 | } 106 | 107 | xgbr.Expectations.DeletionObserved(expectKey) 108 | return true 109 | } 110 | 111 | return true 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/job.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "context" 20 | "fmt" 21 | "reflect" 22 | 23 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 24 | commonutil "github.com/kubeflow/common/pkg/util" 25 | logger "github.com/kubeflow/common/pkg/util" 26 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 27 | "github.com/sirupsen/logrus" 28 | corev1 "k8s.io/api/core/v1" 29 | k8sv1 "k8s.io/api/core/v1" 30 | "k8s.io/apimachinery/pkg/api/errors" 31 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 32 | "k8s.io/apimachinery/pkg/types" 33 | "k8s.io/client-go/kubernetes/scheme" 34 | "sigs.k8s.io/controller-runtime/pkg/event" 35 | "sigs.k8s.io/controller-runtime/pkg/reconcile" 36 | ) 37 | 38 | // Reasons for job events. 39 | const ( 40 | FailedDeleteJobReason = "FailedDeleteJob" 41 | SuccessfulDeleteJobReason = "SuccessfulDeleteJob" 42 | // xgboostJobCreatedReason is added in a job when it is created. 43 | xgboostJobCreatedReason = "XGBoostJobCreated" 44 | 45 | xgboostJobSucceededReason = "XGBoostJobSucceeded" 46 | xgboostJobRunningReason = "XGBoostJobRunning" 47 | xgboostJobFailedReason = "XGBoostJobFailed" 48 | xgboostJobRestartingReason = "XGBoostJobRestarting" 49 | ) 50 | 51 | // DeleteJob deletes the job 52 | func (r *ReconcileXGBoostJob) DeleteJob(job interface{}) error { 53 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob) 54 | if !ok { 55 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob) 56 | } 57 | if err := r.Delete(context.Background(), xgboostjob); err != nil { 58 | r.recorder.Eventf(xgboostjob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err) 59 | log.Error(err, "failed to delete job", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name) 60 | return err 61 | } 62 | r.recorder.Eventf(xgboostjob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", xgboostjob.Name) 63 | log.Info("job deleted", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name) 64 | 
return nil 65 | } 66 | 67 | // GetJobFromInformerCache returns the Job from Informer Cache 68 | func (r *ReconcileXGBoostJob) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) { 69 | job := &v1xgboost.XGBoostJob{} 70 | // Default reader for XGBoostJob is cache reader. 71 | err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job) 72 | if err != nil { 73 | if errors.IsNotFound(err) { 74 | log.Error(err, "xgboost job not found", "namespace", namespace, "name", name) 75 | } else { 76 | log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name) 77 | } 78 | return nil, err 79 | } 80 | return job, nil 81 | } 82 | 83 | // GetJobFromAPIClient returns the Job from API server 84 | func (r *ReconcileXGBoostJob) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) { 85 | job := &v1xgboost.XGBoostJob{} 86 | 87 | clientReader, err := getClientReaderFromClient(r.Client) 88 | if err != nil { 89 | return nil, err 90 | } 91 | err = clientReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job) 92 | if err != nil { 93 | if errors.IsNotFound(err) { 94 | log.Error(err, "xgboost job not found", "namespace", namespace, "name", name) 95 | } else { 96 | log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name) 97 | } 98 | return nil, err 99 | } 100 | return job, nil 101 | } 102 | 103 | // UpdateJobStatus updates the job status and job conditions 104 | func (r *ReconcileXGBoostJob) UpdateJobStatus(job interface{}, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, jobStatus *commonv1.JobStatus) error { 105 | xgboostJob, ok := job.(*v1xgboost.XGBoostJob) 106 | if !ok { 107 | return fmt.Errorf("%+v is not a type of xgboostJob", xgboostJob) 108 | } 109 | 110 | for rtype, spec := range replicas { 111 | status := jobStatus.ReplicaStatuses[rtype] 112 | 113 | succeeded := status.Succeeded 114 | expected := 
*(spec.Replicas) - succeeded 115 | running := status.Active 116 | failed := status.Failed 117 | 118 | logrus.Infof("XGBoostJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d", 119 | xgboostJob.Name, rtype, expected, running, succeeded, failed) 120 | 121 | if rtype == commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster) { 122 | if running > 0 { 123 | msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name) 124 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg) 125 | if err != nil { 126 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err) 127 | return err 128 | } 129 | } 130 | // when master is succeed, the job is finished. 131 | if expected == 0 { 132 | msg := fmt.Sprintf("XGBoostJob %s is successfully completed.", xgboostJob.Name) 133 | logrus.Info(msg) 134 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobSucceededReason, msg) 135 | if jobStatus.CompletionTime == nil { 136 | now := metav1.Now() 137 | jobStatus.CompletionTime = &now 138 | } 139 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobSucceeded, xgboostJobSucceededReason, msg) 140 | if err != nil { 141 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err) 142 | return err 143 | } 144 | return nil 145 | } 146 | } 147 | if failed > 0 { 148 | if spec.RestartPolicy == commonv1.RestartPolicyExitCode { 149 | msg := fmt.Sprintf("XGBoostJob %s is restarting because %d %s replica(s) failed.", xgboostJob.Name, failed, rtype) 150 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeWarning, xgboostJobRestartingReason, msg) 151 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRestarting, xgboostJobRestartingReason, msg) 152 | if err != nil { 153 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err) 154 | return err 155 | } 156 | } else { 157 | msg := fmt.Sprintf("XGBoostJob %s is failed because %d %s replica(s) failed.", 
xgboostJob.Name, failed, rtype) 158 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobFailedReason, msg) 159 | if xgboostJob.Status.CompletionTime == nil { 160 | now := metav1.Now() 161 | xgboostJob.Status.CompletionTime = &now 162 | } 163 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobFailed, xgboostJobFailedReason, msg) 164 | if err != nil { 165 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err) 166 | return err 167 | } 168 | } 169 | } 170 | } 171 | 172 | // Some workers are still running, leave a running condition. 173 | msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name) 174 | logger.LoggerForJob(xgboostJob).Infof(msg) 175 | 176 | if err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg); err != nil { 177 | logger.LoggerForJob(xgboostJob).Error(err, "failed to update XGBoost Job conditions") 178 | return err 179 | } 180 | 181 | return nil 182 | } 183 | 184 | // UpdateJobStatusInApiServer updates the job status in to cluster. 185 | func (r *ReconcileXGBoostJob) UpdateJobStatusInApiServer(job interface{}, jobStatus *commonv1.JobStatus) error { 186 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob) 187 | if !ok { 188 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob) 189 | } 190 | 191 | // Job status passed in differs with status in job, update in basis of the passed in one. 192 | if !reflect.DeepEqual(&xgboostjob.Status.JobStatus, jobStatus) { 193 | xgboostjob = xgboostjob.DeepCopy() 194 | xgboostjob.Status.JobStatus = *jobStatus.DeepCopy() 195 | } 196 | 197 | result := r.Update(context.Background(), xgboostjob) 198 | 199 | if result != nil { 200 | logger.LoggerForJob(xgboostjob).Error(result, "failed to update XGBoost Job conditions in the API server") 201 | return result 202 | } 203 | 204 | return nil 205 | } 206 | 207 | // onOwnerCreateFunc modify creation condition. 
208 | func onOwnerCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool { 209 | return func(e event.CreateEvent) bool { 210 | xgboostJob, ok := e.Meta.(*v1xgboost.XGBoostJob) 211 | if !ok { 212 | return true 213 | } 214 | scheme.Scheme.Default(xgboostJob) 215 | msg := fmt.Sprintf("xgboostJob %s is created.", e.Meta.GetName()) 216 | logrus.Info(msg) 217 | //specific the run policy 218 | 219 | if xgboostJob.Spec.RunPolicy.CleanPodPolicy == nil { 220 | xgboostJob.Spec.RunPolicy.CleanPodPolicy = new(commonv1.CleanPodPolicy) 221 | xgboostJob.Spec.RunPolicy.CleanPodPolicy = &defaultCleanPodPolicy 222 | } 223 | 224 | if err := commonutil.UpdateJobConditions(&xgboostJob.Status.JobStatus, commonv1.JobCreated, xgboostJobCreatedReason, msg); err != nil { 225 | log.Error(err, "append job condition error") 226 | return false 227 | } 228 | return true 229 | } 230 | } 231 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/pod.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "context" 20 | "fmt" 21 | "strconv" 22 | "strings" 23 | 24 | "k8s.io/apimachinery/pkg/api/meta" 25 | "sigs.k8s.io/controller-runtime/pkg/client" 26 | 27 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 28 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 29 | corev1 "k8s.io/api/core/v1" 30 | ) 31 | 32 | // GetPodsForJob returns the pods managed by the job. This can be achieved by selecting pods using label key "job-name" 33 | // i.e. all pods created by the job will come with label "job-name" = 34 | func (r *ReconcileXGBoostJob) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) { 35 | job, err := meta.Accessor(obj) 36 | if err != nil { 37 | return nil, err 38 | } 39 | // List all pods to include those that don't match the selector anymore 40 | // but have a ControllerRef pointing to this controller. 41 | podlist := &corev1.PodList{} 42 | err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName()))) 43 | if err != nil { 44 | return nil, err 45 | } 46 | 47 | return convertPodList(podlist.Items), nil 48 | } 49 | 50 | // convertPodList convert pod list to pod point list 51 | func convertPodList(list []corev1.Pod) []*corev1.Pod { 52 | if list == nil { 53 | return nil 54 | } 55 | ret := make([]*corev1.Pod, 0, len(list)) 56 | for i := range list { 57 | ret = append(ret, &list[i]) 58 | } 59 | return ret 60 | } 61 | 62 | // SetPodEnv sets the pod env set for: 63 | // - XGBoost Rabit Tracker and worker 64 | // - LightGBM master and workers 65 | func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { 66 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob) 67 | if !ok { 68 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob) 69 | } 70 | 71 | rank, err := strconv.Atoi(index) 72 | if err != nil { 73 | return err 74 | } 75 | 76 | // Add master offset for worker pods 77 | if strings.ToLower(rtype) == 
strings.ToLower(string(v1xgboost.XGBoostReplicaTypeWorker)) { 78 | masterSpec := xgboostjob.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster)] 79 | masterReplicas := int(*masterSpec.Replicas) 80 | rank += masterReplicas 81 | } 82 | 83 | masterAddr := computeMasterAddr(xgboostjob.Name, strings.ToLower(string(v1xgboost.XGBoostReplicaTypeMaster)), strconv.Itoa(0)) 84 | 85 | masterPort, err := GetPortFromXGBoostJob(xgboostjob, v1xgboost.XGBoostReplicaTypeMaster) 86 | if err != nil { 87 | return err 88 | } 89 | 90 | totalReplicas := computeTotalReplicas(xgboostjob) 91 | 92 | var workerPort int32 93 | var workerAddrs []string 94 | 95 | if totalReplicas > 1 { 96 | workerPortTemp, err := GetPortFromXGBoostJob(xgboostjob, v1xgboost.XGBoostReplicaTypeWorker) 97 | if err != nil { 98 | return err 99 | } 100 | workerPort = workerPortTemp 101 | workerAddrs = make([]string, totalReplicas-1) 102 | for i := range workerAddrs { 103 | workerAddrs[i] = computeMasterAddr(xgboostjob.Name, strings.ToLower(string(v1xgboost.XGBoostReplicaTypeWorker)), strconv.Itoa(i)) 104 | } 105 | } 106 | 107 | for i := range podTemplate.Spec.Containers { 108 | if len(podTemplate.Spec.Containers[i].Env) == 0 { 109 | podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0) 110 | } 111 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 112 | Name: "MASTER_PORT", 113 | Value: strconv.Itoa(int(masterPort)), 114 | }) 115 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 116 | Name: "MASTER_ADDR", 117 | Value: masterAddr, 118 | }) 119 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 120 | Name: "WORLD_SIZE", 121 | Value: strconv.Itoa(int(totalReplicas)), 122 | }) 123 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 124 | Name: "RANK", 125 | Value: strconv.Itoa(rank), 126 | }) 127 | 
podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 128 | Name: "PYTHONUNBUFFERED", 129 | Value: "0", 130 | }) 131 | // This variables are used if it is a LightGBM job 132 | if totalReplicas > 1 { 133 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 134 | Name: "WORKER_PORT", 135 | Value: strconv.Itoa(int(workerPort)), 136 | }) 137 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{ 138 | Name: "WORKER_ADDRS", 139 | Value: strings.Join(workerAddrs, ","), 140 | }) 141 | } 142 | } 143 | 144 | return nil 145 | } 146 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/pod_test.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "testing" 20 | 21 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 22 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 23 | v1 "k8s.io/api/core/v1" 24 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 25 | ) 26 | 27 | func NewXGBoostJobWithMaster(worker int) *v1xgboost.XGBoostJob { 28 | job := NewXGoostJob(worker) 29 | master := int32(1) 30 | masterReplicaSpec := &commonv1.ReplicaSpec{ 31 | Replicas: &master, 32 | Template: NewXGBoostReplicaSpecTemplate(), 33 | } 34 | job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster)] = masterReplicaSpec 35 | return job 36 | } 37 | 38 | func NewXGoostJob(worker int) *v1xgboost.XGBoostJob { 39 | 40 | job := &v1xgboost.XGBoostJob{ 41 | TypeMeta: metav1.TypeMeta{ 42 | Kind: v1xgboost.Kind, 43 | }, 44 | ObjectMeta: metav1.ObjectMeta{ 45 | Name: "test-xgboostjob", 46 | Namespace: metav1.NamespaceDefault, 47 | }, 48 | Spec: v1xgboost.XGBoostJobSpec{ 49 | XGBReplicaSpecs: make(map[commonv1.ReplicaType]*commonv1.ReplicaSpec), 50 | }, 51 | } 52 | 53 | if worker > 0 { 54 | worker := int32(worker) 55 | workerReplicaSpec := &commonv1.ReplicaSpec{ 56 | Replicas: &worker, 57 | Template: NewXGBoostReplicaSpecTemplate(), 58 | } 59 | job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeWorker)] = workerReplicaSpec 60 | } 61 | 62 | return job 63 | } 64 | 65 | func NewXGBoostReplicaSpecTemplate() v1.PodTemplateSpec { 66 | return v1.PodTemplateSpec{ 67 | Spec: v1.PodSpec{ 68 | Containers: []v1.Container{ 69 | v1.Container{ 70 | Name: v1xgboost.DefaultContainerName, 71 | Image: "test-image-for-kubeflow-xgboost-operator:latest", 72 | Args: []string{"Fake", "Fake"}, 73 | Ports: []v1.ContainerPort{ 74 | v1.ContainerPort{ 75 | Name: v1xgboost.DefaultContainerPortName, 76 | ContainerPort: v1xgboost.DefaultPort, 77 | }, 78 | }, 79 | }, 80 | }, 81 | }, 82 | } 83 | } 84 | 85 | func TestClusterSpec(t *testing.T) { 86 
| type tc struct { 87 | job *v1xgboost.XGBoostJob 88 | rt v1xgboost.XGBoostJobReplicaType 89 | index string 90 | expectedClusterSpec map[string]string 91 | } 92 | testCase := []tc{ 93 | tc{ 94 | job: NewXGBoostJobWithMaster(0), 95 | rt: v1xgboost.XGBoostReplicaTypeMaster, 96 | index: "0", 97 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "1", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0"}, 98 | }, 99 | tc{ 100 | job: NewXGBoostJobWithMaster(1), 101 | rt: v1xgboost.XGBoostReplicaTypeMaster, 102 | index: "1", 103 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "2", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0"}, 104 | }, 105 | tc{ 106 | job: NewXGBoostJobWithMaster(2), 107 | rt: v1xgboost.XGBoostReplicaTypeMaster, 108 | index: "0", 109 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"}, 110 | }, 111 | tc{ 112 | job: NewXGBoostJobWithMaster(2), 113 | rt: v1xgboost.XGBoostReplicaTypeWorker, 114 | index: "0", 115 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"}, 116 | }, 117 | tc{ 118 | job: NewXGBoostJobWithMaster(2), 119 | rt: v1xgboost.XGBoostReplicaTypeWorker, 120 | index: "1", 121 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "2", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"}, 122 | }, 123 | } 124 | for _, c := range testCase { 125 | demoTemplateSpec := c.job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(c.rt)].Template 126 | if err := 
SetPodEnv(c.job, &demoTemplateSpec, string(c.rt), c.index); err != nil { 127 | t.Errorf("Failed to set cluster spec: %v", err) 128 | } 129 | actual := demoTemplateSpec.Spec.Containers[0].Env 130 | for _, env := range actual { 131 | if val, ok := c.expectedClusterSpec[env.Name]; ok { 132 | if val != env.Value { 133 | t.Errorf("For name %s Got %s. Expected %s ", env.Name, env.Value, c.expectedClusterSpec[env.Name]) 134 | } 135 | } 136 | } 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/service.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "context" 20 | "fmt" 21 | corev1 "k8s.io/api/core/v1" 22 | "k8s.io/apimachinery/pkg/api/meta" 23 | "sigs.k8s.io/controller-runtime/pkg/client" 24 | ) 25 | 26 | // GetServicesForJob returns the services managed by the job. This can be achieved by selecting services using label key "job-name" 27 | // i.e. 
all services created by the job will come with label "job-name" = 28 | func (r *ReconcileXGBoostJob) GetServicesForJob(obj interface{}) ([]*corev1.Service, error) { 29 | job, err := meta.Accessor(obj) 30 | if err != nil { 31 | return nil, fmt.Errorf("%+v is not a type of XGBoostJob", job) 32 | } 33 | // List all pods to include those that don't match the selector anymore 34 | // but have a ControllerRef pointing to this controller. 35 | serviceList := &corev1.ServiceList{} 36 | err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName()))) 37 | if err != nil { 38 | return nil, err 39 | } 40 | 41 | //TODO support adopting/orphaning 42 | ret := convertServiceList(serviceList.Items) 43 | 44 | return ret, nil 45 | } 46 | 47 | // convertServiceList convert service list to service point list 48 | func convertServiceList(list []corev1.Service) []*corev1.Service { 49 | if list == nil { 50 | return nil 51 | } 52 | ret := make([]*corev1.Service, 0, len(list)) 53 | for i := range list { 54 | ret = append(ret, &list[i]) 55 | } 56 | return ret 57 | } 58 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/util.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "errors" 20 | "fmt" 21 | "os" 22 | "strings" 23 | "time" 24 | 25 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 26 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 27 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 28 | kubeclientset "k8s.io/client-go/kubernetes" 29 | restclientset "k8s.io/client-go/rest" 30 | "sigs.k8s.io/controller-runtime/pkg/client" 31 | volcanoclient "volcano.sh/volcano/pkg/client/clientset/versioned" 32 | ) 33 | 34 | // getClientReaderFromClient try to extract client reader from client, client 35 | // reader reads cluster info from api client. 36 | func getClientReaderFromClient(c client.Client) (client.Reader, error) { 37 | if dr, err := getDelegatingReader(c); err != nil { 38 | return nil, err 39 | } else { 40 | return dr.ClientReader, nil 41 | } 42 | } 43 | 44 | // getDelegatingReader try to extract DelegatingReader from client. 45 | func getDelegatingReader(c client.Client) (*client.DelegatingReader, error) { 46 | dc, ok := c.(*client.DelegatingClient) 47 | if !ok { 48 | return nil, errors.New("cannot convert from Client to DelegatingClient") 49 | } 50 | dr, ok := dc.Reader.(*client.DelegatingReader) 51 | if !ok { 52 | return nil, errors.New("cannot convert from DelegatingClient.Reader to Delegating Reader") 53 | } 54 | return dr, nil 55 | } 56 | 57 | func computeMasterAddr(jobName, rtype, index string) string { 58 | n := jobName + "-" + rtype + "-" + index 59 | return strings.Replace(n, "/", "-", -1) 60 | } 61 | 62 | // GetPortFromXGBoostJob gets the port of xgboost container. 
63 | func GetPortFromXGBoostJob(job *v1xgboost.XGBoostJob, rtype v1xgboost.XGBoostJobReplicaType) (int32, error) { 64 | containers := job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(rtype)].Template.Spec.Containers 65 | for _, container := range containers { 66 | if container.Name == v1xgboost.DefaultContainerName { 67 | ports := container.Ports 68 | for _, port := range ports { 69 | if port.Name == v1xgboost.DefaultContainerPortName { 70 | return port.ContainerPort, nil 71 | } 72 | } 73 | } 74 | } 75 | return -1, fmt.Errorf("failed to found the port") 76 | } 77 | 78 | func computeTotalReplicas(obj metav1.Object) int32 { 79 | job := obj.(*v1xgboost.XGBoostJob) 80 | jobReplicas := int32(0) 81 | 82 | if job.Spec.XGBReplicaSpecs == nil || len(job.Spec.XGBReplicaSpecs) == 0 { 83 | return jobReplicas 84 | } 85 | for _, r := range job.Spec.XGBReplicaSpecs { 86 | if r.Replicas == nil { 87 | continue 88 | } else { 89 | jobReplicas += *r.Replicas 90 | } 91 | } 92 | return jobReplicas 93 | } 94 | 95 | func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, volcanoclient.Interface, error) { 96 | if config == nil { 97 | println("there is an error for the input config") 98 | return nil, nil, nil, nil 99 | } 100 | 101 | kubeClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "xgboostjob-operator")) 102 | if err != nil { 103 | return nil, nil, nil, err 104 | } 105 | 106 | leaderElectionClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "leader-election")) 107 | if err != nil { 108 | return nil, nil, nil, err 109 | } 110 | 111 | volcanoClientSet, err := volcanoclient.NewForConfig(restclientset.AddUserAgent(config, "volcano")) 112 | if err != nil { 113 | return nil, nil, nil, err 114 | } 115 | 116 | return kubeClientSet, leaderElectionClientSet, volcanoClientSet, nil 117 | } 118 | 119 | func homeDir() string { 120 | if h := os.Getenv("HOME"); h != "" { 121 | return h 122 | } 123 | 
return os.Getenv("USERPROFILE") // windows 124 | } 125 | 126 | func isGangSchedulerSet(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool { 127 | for _, spec := range replicas { 128 | if spec.Template.Spec.SchedulerName != "" && spec.Template.Spec.SchedulerName == gangSchedulerName { 129 | return true 130 | } 131 | } 132 | return false 133 | } 134 | 135 | // FakeWorkQueue implements RateLimitingInterface but actually does nothing. 136 | type FakeWorkQueue struct{} 137 | 138 | // Add WorkQueue Add method 139 | func (f *FakeWorkQueue) Add(item interface{}) {} 140 | 141 | // Len WorkQueue Len method 142 | func (f *FakeWorkQueue) Len() int { return 0 } 143 | 144 | // Get WorkQueue Get method 145 | func (f *FakeWorkQueue) Get() (item interface{}, shutdown bool) { return nil, false } 146 | 147 | // Done WorkQueue Done method 148 | func (f *FakeWorkQueue) Done(item interface{}) {} 149 | 150 | // ShutDown WorkQueue ShutDown method 151 | func (f *FakeWorkQueue) ShutDown() {} 152 | 153 | // ShuttingDown WorkQueue ShuttingDown method 154 | func (f *FakeWorkQueue) ShuttingDown() bool { return true } 155 | 156 | // AddAfter WorkQueue AddAfter method 157 | func (f *FakeWorkQueue) AddAfter(item interface{}, duration time.Duration) {} 158 | 159 | // AddRateLimited WorkQueue AddRateLimited method 160 | func (f *FakeWorkQueue) AddRateLimited(item interface{}) {} 161 | 162 | // Forget WorkQueue Forget method 163 | func (f *FakeWorkQueue) Forget(item interface{}) {} 164 | 165 | // NumRequeues WorkQueue NumRequeues method 166 | func (f *FakeWorkQueue) NumRequeues(item interface{}) int { return 0 } 167 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/xgboostjob_controller.go: -------------------------------------------------------------------------------- 1 | /* 2 | Licensed under the Apache License, Version 2.0 (the "License"); 3 | you may not use this file except in compliance with the License. 
4 | You may obtain a copy of the License at 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | Unless required by applicable law or agreed to in writing, software 7 | distributed under the License is distributed on an "AS IS" BASIS, 8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 9 | See the License for the specific language governing permissions and 10 | limitations under the License. 11 | */ 12 | 13 | package xgboostjob 14 | 15 | import ( 16 | "context" 17 | "flag" 18 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1" 19 | "github.com/kubeflow/common/pkg/controller.v1/common" 20 | "github.com/kubeflow/common/pkg/controller.v1/control" 21 | "github.com/kubeflow/common/pkg/controller.v1/expectation" 22 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 23 | corev1 "k8s.io/api/core/v1" 24 | "path/filepath" 25 | 26 | "github.com/sirupsen/logrus" 27 | "k8s.io/apimachinery/pkg/api/errors" 28 | "k8s.io/apimachinery/pkg/runtime" 29 | "k8s.io/apimachinery/pkg/runtime/schema" 30 | "k8s.io/client-go/kubernetes/scheme" 31 | "k8s.io/client-go/rest" 32 | "k8s.io/client-go/tools/clientcmd" 33 | "k8s.io/client-go/tools/record" 34 | "sigs.k8s.io/controller-runtime/pkg/client" 35 | "sigs.k8s.io/controller-runtime/pkg/controller" 36 | "sigs.k8s.io/controller-runtime/pkg/handler" 37 | "sigs.k8s.io/controller-runtime/pkg/manager" 38 | "sigs.k8s.io/controller-runtime/pkg/predicate" 39 | "sigs.k8s.io/controller-runtime/pkg/reconcile" 40 | logf "sigs.k8s.io/controller-runtime/pkg/runtime/log" 41 | "sigs.k8s.io/controller-runtime/pkg/source" 42 | ) 43 | 44 | const ( 45 | controllerName = "xgboostjob-operator" 46 | labelXGBoostJobRole = "xgboostjob-job-role" 47 | // gang scheduler name. 
48 | gangSchedulerName = "kube-batch" 49 | ) 50 | 51 | var ( 52 | defaultTTLseconds = int32(100) 53 | defaultCleanPodPolicy = commonv1.CleanPodPolicyNone 54 | ) 55 | var log = logf.Log.WithName("controller") 56 | 57 | /** 58 | * USER ACTION REQUIRED: This is a scaffold file intended for the user to modify with their own Controller 59 | * business logic. Delete these comments after modifying this file.* 60 | */ 61 | 62 | // Add creates a new XGBoostJob Controller and adds it to the Manager with default RBAC. The Manager will set fields on the Controller 63 | // and Start it when the Manager is Started. 64 | func Add(mgr manager.Manager) error { 65 | return add(mgr, newReconciler(mgr)) 66 | } 67 | 68 | const RecommendedKubeConfigPathEnv = "KUBECONFIG" 69 | 70 | // newReconciler returns a new reconcile.Reconciler 71 | func newReconciler(mgr manager.Manager) reconcile.Reconciler { 72 | 73 | r := &ReconcileXGBoostJob{ 74 | Client: mgr.GetClient(), 75 | scheme: mgr.GetScheme(), 76 | } 77 | 78 | r.recorder = mgr.GetEventRecorderFor(r.ControllerName()) 79 | 80 | var mode string 81 | var kubeconfig *string 82 | var kcfg *rest.Config 83 | if home := homeDir(); home != "" { 84 | kubeconfig = flag.String("kubeconfig_", filepath.Join(home, ".kube", "config"), "(optional) absolute path to the kubeconfig file") 85 | } else { 86 | kubeconfig = flag.String("kubeconfig_", "", "absolute path to the kubeconfig file") 87 | } 88 | flag.Parse() 89 | 90 | mode = flag.Lookup("mode").Value.(flag.Getter).Get().(string) 91 | if mode == "local" { 92 | log.Info("Running controller in local mode, using kubeconfig file") 93 | /// TODO, add the master url and kubeconfigpath with user input 94 | config, err := clientcmd.BuildConfigFromFlags("", *kubeconfig) 95 | if err != nil { 96 | log.Info("Error building kubeconfig: %s", err.Error()) 97 | panic(err.Error()) 98 | } 99 | kcfg = config 100 | } else if mode == "in-cluster" { 101 | log.Info("Running controller in in-cluster mode") 102 | /// TODO, add 
the master url and kubeconfigpath with user input 103 | config, err := rest.InClusterConfig() 104 | if err != nil { 105 | log.Info("Error getting in-cluster kubeconfig") 106 | panic(err.Error()) 107 | } 108 | kcfg = config 109 | } else { 110 | log.Info("Given mode is not valid: ", "mode", mode) 111 | panic("-mode should be either local or in-cluster") 112 | } 113 | 114 | // Create clients. 115 | kubeClientSet, _, volcanoClientSet, err := createClientSets(kcfg) 116 | if err != nil { 117 | log.Info("Error building kubeclientset: %s", err.Error()) 118 | } 119 | 120 | // Create Informer factory 121 | xgboostjob := &v1xgboost.XGBoostJob{} 122 | 123 | gangScheduling := isGangSchedulerSet(xgboostjob.Spec.XGBReplicaSpecs) 124 | 125 | log.Info("gang scheduling is set: ", "gangscheduling", gangScheduling) 126 | 127 | // Initialize common job controller with components we only need. 128 | r.JobController = common.JobController{ 129 | Controller: r, 130 | Expectations: expectation.NewControllerExpectations(), 131 | Config: common.JobControllerConfiguration{EnableGangScheduling: gangScheduling}, 132 | WorkQueue: &FakeWorkQueue{}, 133 | Recorder: r.recorder, 134 | KubeClientSet: kubeClientSet, 135 | VolcanoClientSet: volcanoClientSet, 136 | PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder}, 137 | ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder}, 138 | } 139 | 140 | return r 141 | } 142 | 143 | // add adds a new Controller to mgr with r as the reconcile.Reconciler 144 | func add(mgr manager.Manager, r reconcile.Reconciler) error { 145 | // Create a new controller 146 | c, err := controller.New("xgboostjob-controller", mgr, controller.Options{Reconciler: r}) 147 | if err != nil { 148 | return err 149 | } 150 | 151 | // Watch for changes to XGBoostJob 152 | err = c.Watch(&source.Kind{Type: &v1xgboost.XGBoostJob{}}, &handler.EnqueueRequestForObject{}, 153 | predicate.Funcs{CreateFunc: onOwnerCreateFunc(r)}, 
154 | ) 155 | if err != nil { 156 | return err 157 | } 158 | 159 | //inject watching for xgboostjob related pod 160 | err = c.Watch(&source.Kind{Type: &corev1.Pod{}}, &handler.EnqueueRequestForOwner{ 161 | IsController: true, 162 | OwnerType: &v1xgboost.XGBoostJob{}, 163 | }, 164 | predicate.Funcs{CreateFunc: onDependentCreateFunc(r), DeleteFunc: onDependentDeleteFunc(r)}, 165 | ) 166 | if err != nil { 167 | return err 168 | } 169 | 170 | //inject watching for xgboostjob related service 171 | err = c.Watch(&source.Kind{Type: &corev1.Service{}}, &handler.EnqueueRequestForOwner{ 172 | IsController: true, 173 | OwnerType: &v1xgboost.XGBoostJob{}, 174 | }, 175 | &predicate.Funcs{CreateFunc: onDependentCreateFunc(r), DeleteFunc: onDependentDeleteFunc(r)}, 176 | ) 177 | if err != nil { 178 | return err 179 | } 180 | 181 | return nil 182 | } 183 | 184 | var _ reconcile.Reconciler = &ReconcileXGBoostJob{} 185 | 186 | // ReconcileXGBoostJob reconciles a XGBoostJob object 187 | type ReconcileXGBoostJob struct { 188 | common.JobController 189 | client.Client 190 | scheme *runtime.Scheme 191 | recorder record.EventRecorder 192 | } 193 | 194 | // Reconcile reads that state of the cluster for a XGBoostJob object and makes changes based on the state read 195 | // and what is in the XGBoostJob.Spec 196 | // a Deployment as an example 197 | // Automatically generate RBAC rules to allow the Controller to read and write Deployments 198 | // +kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete 199 | // +kubebuilder:rbac:groups=apps,resources=deployments/status,verbs=get;update;patch 200 | // +kubebuilder:rbac:groups=xgboostjob.kubeflow.org,resources=xgboostjobs,verbs=get;list;watch;create;update;patch;delete 201 | // +kubebuilder:rbac:groups=xgboostjob.kubeflow.org,resources=xgboostjobs/status,verbs=get;update;patch 202 | func (r *ReconcileXGBoostJob) Reconcile(request reconcile.Request) (reconcile.Result, error) { 203 | // Fetch the 
XGBoostJob instance 204 | xgboostjob := &v1xgboost.XGBoostJob{} 205 | err := r.Get(context.Background(), request.NamespacedName, xgboostjob) 206 | if err != nil { 207 | if errors.IsNotFound(err) { 208 | // Object not found, return. Created objects are automatically garbage collected. 209 | // For additional cleanup logic use finalizers. 210 | return reconcile.Result{}, nil 211 | } 212 | // Error reading the object - requeue the request. 213 | return reconcile.Result{}, err 214 | } 215 | 216 | // Check reconcile is required. 217 | needSync := r.satisfiedExpectations(xgboostjob) 218 | 219 | if !needSync || xgboostjob.DeletionTimestamp != nil { 220 | log.Info("reconcile cancelled, job does not need to do reconcile or has been deleted", 221 | "sync", needSync, "deleted", xgboostjob.DeletionTimestamp != nil) 222 | return reconcile.Result{}, nil 223 | } 224 | // Set default priorities for xgboost job 225 | scheme.Scheme.Default(xgboostjob) 226 | 227 | // Use common to reconcile the job related pod and service 228 | err = r.ReconcileJobs(xgboostjob, xgboostjob.Spec.XGBReplicaSpecs, xgboostjob.Status.JobStatus, &xgboostjob.Spec.RunPolicy) 229 | 230 | if err != nil { 231 | logrus.Warnf("Reconcile XGBoost Job error %v", err) 232 | return reconcile.Result{}, err 233 | } 234 | 235 | return reconcile.Result{}, err 236 | } 237 | 238 | func (r *ReconcileXGBoostJob) ControllerName() string { 239 | return controllerName 240 | } 241 | 242 | func (r *ReconcileXGBoostJob) GetAPIGroupVersionKind() schema.GroupVersionKind { 243 | return v1xgboost.SchemeGroupVersionKind 244 | } 245 | 246 | func (r *ReconcileXGBoostJob) GetAPIGroupVersion() schema.GroupVersion { 247 | return v1xgboost.SchemeGroupVersion 248 | } 249 | 250 | func (r *ReconcileXGBoostJob) GetGroupNameLabelValue() string { 251 | return v1xgboost.GroupName 252 | } 253 | 254 | func (r *ReconcileXGBoostJob) GetDefaultContainerName() string { 255 | return v1xgboost.DefaultContainerName 256 | } 257 | 258 | func (r 
*ReconcileXGBoostJob) GetDefaultContainerPortName() string { 259 | return v1xgboost.DefaultContainerPortName 260 | } 261 | 262 | func (r *ReconcileXGBoostJob) GetJobRoleKey() string { 263 | return labelXGBoostJobRole 264 | } 265 | 266 | func (r *ReconcileXGBoostJob) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, 267 | rtype commonv1.ReplicaType, index int) bool { 268 | return string(rtype) == string(v1xgboost.XGBoostReplicaTypeMaster) 269 | } 270 | 271 | // SetClusterSpec sets the cluster spec for the pod 272 | func (r *ReconcileXGBoostJob) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error { 273 | return SetPodEnv(job, podTemplate, rtype, index) 274 | } 275 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/xgboostjob_controller_suite_test.go.back: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | stdlog "log" 20 | "os" 21 | "path/filepath" 22 | "sync" 23 | "testing" 24 | 25 | "github.com/kubeflow/xgboost-operator/pkg/apis" 26 | "github.com/onsi/gomega" 27 | "k8s.io/client-go/kubernetes/scheme" 28 | "k8s.io/client-go/rest" 29 | "sigs.k8s.io/controller-runtime/pkg/envtest" 30 | "sigs.k8s.io/controller-runtime/pkg/manager" 31 | "sigs.k8s.io/controller-runtime/pkg/reconcile" 32 | ) 33 | 34 | var cfg *rest.Config 35 | 36 | func TestMain(m *testing.M) { 37 | t := &envtest.Environment{ 38 | CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "config", "crds")}, 39 | } 40 | apis.AddToScheme(scheme.Scheme) 41 | 42 | var err error 43 | if cfg, err = t.Start(); err != nil { 44 | stdlog.Fatal(err) 45 | } 46 | 47 | code := m.Run() 48 | t.Stop() 49 | os.Exit(code) 50 | } 51 | 52 | // SetupTestReconcile returns a reconcile.Reconcile implementation that delegates to inner and 53 | // writes the request to requests after Reconcile is finished. 
54 | func SetupTestReconcile(inner reconcile.Reconciler) (reconcile.Reconciler, chan reconcile.Request) { 55 | requests := make(chan reconcile.Request) 56 | fn := reconcile.Func(func(req reconcile.Request) (reconcile.Result, error) { 57 | result, err := inner.Reconcile(req) 58 | requests <- req 59 | return result, err 60 | }) 61 | return fn, requests 62 | } 63 | 64 | // StartTestManager adds recFn 65 | func StartTestManager(mgr manager.Manager, g *gomega.GomegaWithT) (chan struct{}, *sync.WaitGroup) { 66 | stop := make(chan struct{}) 67 | wg := &sync.WaitGroup{} 68 | wg.Add(1) 69 | go func() { 70 | defer wg.Done() 71 | g.Expect(mgr.Start(stop)).NotTo(gomega.HaveOccurred()) 72 | }() 73 | return stop, wg 74 | } 75 | -------------------------------------------------------------------------------- /pkg/controller/v1/xgboostjob/xgboostjob_controller_test.go.back: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package xgboostjob 17 | 18 | import ( 19 | "testing" 20 | "time" 21 | 22 | xgboostjobv1 "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1" 23 | "github.com/onsi/gomega" 24 | "golang.org/x/net/context" 25 | appsv1 "k8s.io/api/apps/v1" 26 | apierrors "k8s.io/apimachinery/pkg/api/errors" 27 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" 28 | "k8s.io/apimachinery/pkg/types" 29 | "sigs.k8s.io/controller-runtime/pkg/client" 30 | "sigs.k8s.io/controller-runtime/pkg/manager" 31 | "sigs.k8s.io/controller-runtime/pkg/reconcile" 32 | ) 33 | 34 | var c client.Client 35 | 36 | var expectedRequest = reconcile.Request{NamespacedName: types.NamespacedName{Name: "foo", Namespace: "default"}} 37 | var depKey = types.NamespacedName{Name: "foo-deployment", Namespace: "default"} 38 | 39 | const timeout = time.Second * 5 40 | 41 | func TestReconcile(t *testing.T) { 42 | g := gomega.NewGomegaWithT(t) 43 | instance := &xgboostjobv1.XGBoostJob{ObjectMeta: metav1.ObjectMeta{Name: "foo", Namespace: "default"}} 44 | 45 | // Setup the Manager and Controller. Wrap the Controller Reconcile function so it writes each request to a 46 | // channel when it is finished. 47 | mgr, err := manager.New(cfg, manager.Options{}) 48 | g.Expect(err).NotTo(gomega.HaveOccurred()) 49 | c = mgr.GetClient() 50 | 51 | recFn, requests := SetupTestReconcile(newReconciler(mgr)) 52 | g.Expect(add(mgr, recFn)).NotTo(gomega.HaveOccurred()) 53 | 54 | stopMgr, mgrStopped := StartTestManager(mgr, g) 55 | 56 | defer func() { 57 | close(stopMgr) 58 | mgrStopped.Wait() 59 | }() 60 | 61 | // Create the XGBoostJob object and expect the Reconcile and Deployment to be created 62 | err = c.Create(context.TODO(), instance) 63 | // The instance object may not be a valid object because it might be missing some required fields. 64 | // Please modify the instance object by adding required fields and then remove the following if statement. 
65 | if apierrors.IsInvalid(err) { 66 | t.Logf("failed to create object, got an invalid object error: %v", err) 67 | return 68 | } 69 | g.Expect(err).NotTo(gomega.HaveOccurred()) 70 | defer c.Delete(context.TODO(), instance) 71 | g.Eventually(requests, timeout).Should(gomega.Receive(gomega.Equal(expectedRequest))) 72 | 73 | deploy := &appsv1.Deployment{} 74 | g.Eventually(func() error { return c.Get(context.TODO(), depKey, deploy) }, timeout). 75 | Should(gomega.Succeed()) 76 | 77 | // Delete the Deployment and expect Reconcile to be called for Deployment deletion 78 | g.Expect(c.Delete(context.TODO(), deploy)).NotTo(gomega.HaveOccurred()) 79 | g.Eventually(requests, timeout).Should(gomega.Receive(gomega.Equal(expectedRequest))) 80 | g.Eventually(func() error { return c.Get(context.TODO(), depKey, deploy) }, timeout). 81 | Should(gomega.Succeed()) 82 | 83 | // Manually delete Deployment since GC isn't enabled in the test control plane 84 | g.Eventually(func() error { return c.Delete(context.TODO(), deploy) }, timeout). 85 | Should(gomega.MatchError("deployments.apps \"foo-deployment\" not found")) 86 | 87 | } 88 | -------------------------------------------------------------------------------- /pkg/webhook/webhook.go: -------------------------------------------------------------------------------- 1 | /* 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | */ 15 | 16 | package webhook 17 | 18 | import ( 19 | "sigs.k8s.io/controller-runtime/pkg/manager" 20 | ) 21 | 22 | // AddToManagerFuncs is a list of functions to add all Webhooks to the Manager 23 | var AddToManagerFuncs []func(manager.Manager) error 24 | 25 | // AddToManager adds all Webhooks to the Manager 26 | // +kubebuilder:rbac:groups=admissionregistration.k8s.io,resources=mutatingwebhookconfigurations;validatingwebhookconfigurations,verbs=get;list;watch;create;update;patch;delete 27 | // +kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch;create;update;patch;delete 28 | // +kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;update;patch;delete 29 | func AddToManager(m manager.Manager) error { 30 | for _, f := range AddToManagerFuncs { 31 | if err := f(m); err != nil { 32 | return err 33 | } 34 | } 35 | return nil 36 | } 37 | -------------------------------------------------------------------------------- /prow_config.yaml: -------------------------------------------------------------------------------- 1 | # This file configures the workflows to trigger in our Prow jobs. 
2 | # see https://github.com/kubeflow/testing/blob/master/py/kubeflow/testing/run_e2e_workflow.py 3 | workflows: 4 | - app_dir: kubeflow/xgboost-operator/test/workflows 5 | # these super-short names are required so that identity lengths will be shorter than 64 6 | component: build 7 | name: build 8 | job_types: 9 | - presubmit 10 | params: 11 | registry: "gcr.io/kubeflow-ci" 12 | - app_dir: kubeflow/xgboost-operator/test/workflows 13 | component: build 14 | name: build 15 | job_types: 16 | - postsubmit 17 | params: 18 | registry: "gcr.io/kubeflow-images-public" 19 | -------------------------------------------------------------------------------- /test/workflows/.gitignore: -------------------------------------------------------------------------------- 1 | /.ksonnet/registries 2 | /app.override.yaml 3 | /.ks_environment 4 | -------------------------------------------------------------------------------- /test/workflows/app.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: 0.3.0 2 | environments: 3 | dockerbuild: 4 | destination: 5 | namespace: kubeflow-releasing 6 | server: https://35.226.49.107 7 | k8sVersion: v1.7.0 8 | path: dockerbuild 9 | kind: ksonnet.io/app 10 | name: workflows 11 | registries: 12 | incubator: 13 | gitVersion: 14 | commitSha: ea3408d44c2d8ea4d321364e5533d5c60e74bce0 15 | refSpec: master 16 | protocol: github 17 | uri: github.com/ksonnet/parts/tree/master/incubator 18 | kubeflow: 19 | gitVersion: 20 | commitSha: 5c35580d76092788b089cb447be3f3097cffe60b 21 | refSpec: master 22 | protocol: github 23 | uri: github.com/google/kubeflow/tree/master/kubeflow 24 | version: 0.0.1 25 | -------------------------------------------------------------------------------- /test/workflows/components/build.jsonnet: -------------------------------------------------------------------------------- 1 | local env = std.extVar("__ksonnet/environments"); 2 | local params = 
std.extVar("__ksonnet/params").components.build; 3 | 4 | local k = import "k.libsonnet"; 5 | local release = import "kubeflow/automation/release.libsonnet"; 6 | local updatedParams = params { 7 | extra_args: if params.extra_args == "null" then "" else " " + params.extra_args, 8 | }; 9 | 10 | std.prune(k.core.v1.list.new(release.parts(updatedParams.namespace, updatedParams.name, overrides=updatedParams).release)) 11 | -------------------------------------------------------------------------------- /test/workflows/components/params.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | global: { 3 | // User-defined global parameters; accessible to all component and environments, Ex: 4 | // replicas: 4, 5 | }, 6 | components: { 7 | build: { 8 | bucket: "kubeflow-ci_temp", 9 | cluster: "kubeflow-testing", 10 | dockerfile: "Dockerfile", 11 | dockerfileDir: "kubeflow/xgboost-operator/", 12 | extra_args: "null", 13 | extra_repos: "kubeflow/testing@HEAD", 14 | gcpCredentialsSecretName: "gcp-credentials", 15 | image: "xgboost-operator", 16 | name: "xgboost-operator", 17 | namespace: "kubeflow-ci", 18 | nfsVolumeClaim: "nfs-external", 19 | project: "kubeflow-ci", 20 | prow_env: "REPO_OWNER=kubeflow,REPO_NAME=xgboost-operator,PULL_BASE_SHA=master", 21 | registry: "gcr.io/kubeflow-images-public", 22 | testing_image: "gcr.io/kubeflow-ci/test-worker/test-worker:v20190116-b7abb8d-e3b0c4", 23 | versionTag: "v1.0", 24 | zone: "us-central1-a", 25 | }, 26 | }, 27 | } 28 | -------------------------------------------------------------------------------- /test/workflows/components/util.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | // convert a list of two items into a map representing an environment variable 3 | listToMap:: function(v) 4 | { 5 | name: v[0], 6 | value: v[1], 7 | }, 8 | 9 | // Function to turn comma separated list of prow environment variables into a dictionary. 
10 | parseEnv:: function(v) 11 | local pieces = std.split(v, ","); 12 | if v != "" && std.length(pieces) > 0 then 13 | std.map( 14 | function(i) $.listToMap(std.split(i, "=")), 15 | std.split(v, ",") 16 | ) 17 | else [], 18 | } 19 | -------------------------------------------------------------------------------- /test/workflows/environments/base.libsonnet: -------------------------------------------------------------------------------- 1 | local components = std.extVar("__ksonnet/components"); 2 | components + { 3 | // Insert user-specified overrides here. 4 | } 5 | -------------------------------------------------------------------------------- /test/workflows/environments/dockerbuild/globals.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | } 3 | -------------------------------------------------------------------------------- /test/workflows/environments/dockerbuild/main.jsonnet: -------------------------------------------------------------------------------- 1 | local base = import "base.libsonnet"; 2 | // uncomment if you reference ksonnet-lib 3 | // local k = import "k.libsonnet"; 4 | 5 | base + { 6 | // Insert user-specified overrides here. For example if a component is named \"nginx-deployment\", you might have something like:\n") 7 | // "nginx-deployment"+: k.deployment.mixin.metadata.labels({foo: "bar"}) 8 | } 9 | -------------------------------------------------------------------------------- /test/workflows/environments/dockerbuild/params.libsonnet: -------------------------------------------------------------------------------- 1 | local params = std.extVar("__ksonnet/params"); 2 | local globals = import "globals.libsonnet"; 3 | local envParams = params + { 4 | components +: { 5 | // Insert component parameter overrides here. 
Ex: 6 | // guestbook +: { 7 | // name: "guestbook-dev", 8 | // replicas: params.global.replicas, 9 | // }, 10 | }, 11 | }; 12 | 13 | { 14 | components: { 15 | [x]: envParams.components[x] + globals, for x in std.objectFields(envParams.components) 16 | }, 17 | } 18 | -------------------------------------------------------------------------------- /test/workflows/vendor/kubeflow/automation/parts.yaml: -------------------------------------------------------------------------------- 1 | { 2 | "name": "automation", 3 | "apiVersion": "0.0.1", 4 | "kind": "ksonnet.io/parts", 5 | "description": "Toolset for kubeflow community.\n", 6 | "author": "kubeflow-team ", 7 | "contributors": [ 8 | ], 9 | "repository": { 10 | "type": "git", 11 | "url": "https://github.com/kubeflow/kubeflow" 12 | }, 13 | "bugs": { 14 | "url": "https://github.com/kubeflow/kubeflow/issues" 15 | }, 16 | "keywords": [ 17 | "kubernetes", 18 | "kubeflow", 19 | "release" 20 | ], 21 | "license": "Apache 2.0", 22 | } 23 | -------------------------------------------------------------------------------- /test/workflows/vendor/kubeflow/automation/prototypes/release.jsonnet: -------------------------------------------------------------------------------- 1 | // @apiVersion 0.1 2 | // @name release 3 | // @description image auto release workflow. 4 | // @shortDescription setup auto release of any image in 30 min. 5 | // @param name string Name to give to each of the components 6 | // @param image string Image name to give to released image 7 | // @param dockerfileDir string Path to dockerfile used for releasing image 8 | // @optionalParam cluster string kubeflow-releasing Self explanatory. 9 | // @optionalParam gcpCredentialsSecretName string gcp-credentials Name of GCP credential stored in cluster as a secret. 10 | // @optionalParam namespace string kubeflow-releasing Self explanatory. 11 | // @optionalParam nfsVolumeClaim string nfs-external Volumn name. 
12 | // @optionalParam project string kubeflow-releasing Self explanatory. 13 | // @optionalParam registry string gcr.io/kubeflow-images-public Registry where the image will be pushed to. 14 | // @optionalParam testing_image string gcr.io/kubeflow-releasing/worker:latest The image where we run the workflow. 15 | // @optionalParam zone string us-central1-a GKE zone. 16 | // @optionalParam bucket string kubeflow-releasing-artifacts GCS bucket storing artifacts. 17 | // @optionalParam prow_env string REPO_OWNER=kubeflow,REPO_NAME=kubeflow,PULL_BASE_SHA=master Self explanatory. 18 | // @optionalParam versionTag string latest Tag for released image. 19 | // @optionalParam dockerfile string Dockerfile Name of your dockerfile. 20 | // @optionalParam extra_repos string kubeflow/testing@HEAD Append your repo here (; separated) if your image is not in kubeflow. 21 | // @optionalParam extra_args string null args for your build_image.sh 22 | 23 | local k = import "k.libsonnet"; 24 | local release = import "kubeflow/automation/release.libsonnet"; 25 | local updatedParams = params { 26 | extra_args: if params.extra_args == "null" then "" else " " + params.extra_args, 27 | }; 28 | 29 | std.prune(k.core.v1.list.new(release.parts(updatedParams.namespace, updatedParams.name, overrides=updatedParams).release)) 30 | -------------------------------------------------------------------------------- /test/workflows/vendor/kubeflow/automation/release.libsonnet: -------------------------------------------------------------------------------- 1 | { 2 | // convert a list of two items into a map representing an environment variable 3 | listToMap:: function(v) 4 | { 5 | name: v[0], 6 | value: v[1], 7 | }, 8 | 9 | // Function to turn comma separated list of prow environment variables into a dictionary. 
10 | parseEnv:: function(v) 11 | local pieces = std.split(v, ","); 12 | if v != "" && std.length(pieces) > 0 then 13 | std.map( 14 | function(i) $.listToMap(std.split(i, "=")), 15 | std.split(v, ",") 16 | ) 17 | else [], 18 | 19 | 20 | // Default parameters. 21 | defaultParams:: { 22 | bucket: "kubeflow-releasing-artifacts", 23 | commit: "master", 24 | // Name of the secret containing GCP credentials. 25 | gcpCredentialsSecretName: "kubeflow-testing-credentials", 26 | name: "placeholder", 27 | namespace: "kubeflow-releasing", 28 | // The name of the NFS volume claim to use for test files. 29 | nfsVolumeClaim: "nfs-external", 30 | prow_env: "REPO_OWNER=kubeflow,REPO_NAME=kubeflow,PULL_BASE_SHA=master", 31 | registry: "gcr.io/kubeflow-images-public", 32 | versionTag: "latest", 33 | // The default image to use for the steps in the Argo workflow. 34 | testing_image: "gcr.io/kubeflow-ci/worker:latest", 35 | project: "kubeflow-releasing", 36 | cluster: "kubeflow-releasing", 37 | zone: "us-central1-a", 38 | image: "default-should-not-exist", 39 | dockerfile: "Dockerfile", 40 | extra_repos: "kubeflow/testing@HEAD", 41 | extra_args: "", 42 | }, 43 | 44 | parts(namespace, name, overrides={}):: { 45 | // Workflow to release image. 46 | release:: 47 | local params = $.defaultParams + overrides; 48 | 49 | local namespace = params.namespace; 50 | local testing_image = params.testing_image; 51 | local project = params.project; 52 | local cluster = params.cluster; 53 | local zone = params.zone; 54 | local name = params.name; 55 | 56 | local prow_env = $.parseEnv(params.prow_env); 57 | local bucket = params.bucket; 58 | 59 | local stepsNamespace = name; 60 | // mountPath is the directory where the volume to store the test data 61 | // should be mounted. 62 | local mountPath = "/mnt/" + "test-data-volume"; 63 | // testDir is the root directory for all data for a particular test run. 
64 | local testDir = mountPath + "/" + name; 65 | // outputDir is the directory to sync to GCS to contain the output for this job. 66 | local outputDir = testDir + "/output"; 67 | local artifactsDir = outputDir + "/artifacts"; 68 | // Source directory where all repos should be checked out 69 | local srcRootDir = testDir + "/src"; 70 | // The directory containing the kubeflow/kubeflow repo 71 | local srcDir = srcRootDir + "/kubeflow/kubeflow"; 72 | // The name to use for the volume to use to contain test data. 73 | local dataVolume = "kubeflow-test-volume"; 74 | local kubeflowPy = srcDir; 75 | // The directory within the kubeflow_testing submodule containing 76 | // py scripts to use. 77 | local kubeflowTestingPy = srcRootDir + "/kubeflow/testing/py"; 78 | 79 | // Location where build_image.sh 80 | local imageDir = srcRootDir + "/" + params.dockerfileDir; 81 | 82 | local releaseImage = params.registry + "/" + params.image; 83 | 84 | // Build an Argo template to execute a particular command. 85 | // step_name: Name for the template 86 | // command: List to pass as the container command. 87 | local buildTemplate(step_name, command, env_vars=[], sidecars=[]) = { 88 | name: step_name, 89 | container: { 90 | command: command, 91 | image: testing_image, 92 | env: [ 93 | { 94 | // Add the source directories to the python path. 
95 | name: "PYTHONPATH", 96 | value: kubeflowPy + ":" + kubeflowTestingPy, 97 | }, 98 | { 99 | name: "GOOGLE_APPLICATION_CREDENTIALS", 100 | value: "/secret/gcp-credentials/key.json", 101 | }, 102 | { 103 | name: "GITHUB_TOKEN", 104 | valueFrom: { 105 | secretKeyRef: { 106 | name: "github-token", 107 | key: "github_token", 108 | }, 109 | }, 110 | }, 111 | { 112 | name: "GCP_PROJECT", 113 | value: project, 114 | }, 115 | { 116 | name: "GCP_REGISTRY", 117 | value: params.registry, 118 | }, 119 | ] + prow_env + env_vars, 120 | resources: { 121 | requests: { 122 | memory: "2Gi", 123 | cpu: "2", 124 | }, 125 | limits: { 126 | memory: "32Gi", 127 | cpu: "16", 128 | }, 129 | }, 130 | volumeMounts: [ 131 | { 132 | name: dataVolume, 133 | mountPath: mountPath, 134 | }, 135 | { 136 | name: "github-token", 137 | mountPath: "/secret/github-token", 138 | }, 139 | { 140 | name: "gcp-credentials", 141 | mountPath: "/secret/gcp-credentials", 142 | }, 143 | ], 144 | }, 145 | sidecars: sidecars, 146 | }; // buildTemplate 147 | 148 | 149 | local buildImageTemplate(step_name, imageDir, dockerfile, image) = 150 | buildTemplate( 151 | step_name, 152 | [ 153 | // We need to explicitly specify bash because 154 | // build_image.sh is not in the container its a volume mounted file. 
155 | "/bin/bash", 156 | "-c", 157 | imageDir + "/build_image.sh " 158 | + imageDir + "/" + dockerfile + " " 159 | + image + " " 160 | + params.versionTag 161 | + params.extra_args, 162 | ], 163 | [ 164 | { 165 | name: "DOCKER_HOST", 166 | value: "127.0.0.1", 167 | }, 168 | ], 169 | [{ 170 | name: "dind", 171 | image: "docker:17.10-dind", 172 | securityContext: { 173 | privileged: true, 174 | }, 175 | resources: { 176 | requests: { 177 | memory: "2Gi", 178 | cpu: "2", 179 | }, 180 | limits: { 181 | memory: "32Gi", 182 | cpu: "16", 183 | }, 184 | }, 185 | mirrorVolumeMounts: true, 186 | }], 187 | ); // buildImageTemplate 188 | { 189 | apiVersion: "argoproj.io/v1alpha1", 190 | kind: "Workflow", 191 | metadata: { 192 | name: name, 193 | namespace: namespace, 194 | }, 195 | spec: { 196 | entrypoint: "release", 197 | volumes: [ 198 | { 199 | name: "github-token", 200 | secret: { 201 | secretName: "github-token", 202 | }, 203 | }, 204 | { 205 | name: "gcp-credentials", 206 | secret: { 207 | secretName: params.gcpCredentialsSecretName, 208 | }, 209 | }, 210 | { 211 | name: dataVolume, 212 | persistentVolumeClaim: { 213 | claimName: params.nfsVolumeClaim, 214 | }, 215 | }, 216 | ], // volumes 217 | 218 | // onExit specifies the template that should always run when the workflow completes. 
219 | onExit: "exit-handler", 220 | 221 | templates: [ 222 | { 223 | name: "release", 224 | dag: { 225 | tasks: [ 226 | { 227 | name: "checkout", 228 | template: "checkout", 229 | }, 230 | { 231 | name: "image-build-release", 232 | template: "image-build-release", 233 | dependencies: ["checkout"], 234 | }, 235 | ], // tasks 236 | }, //dag 237 | }, // release 238 | { 239 | name: "exit-handler", 240 | steps: [ 241 | [{ 242 | name: "copy-artifacts", 243 | template: "copy-artifacts", 244 | }], 245 | ], 246 | }, 247 | { 248 | name: "checkout", 249 | container: { 250 | command: [ 251 | "/usr/local/bin/checkout.sh", 252 | ], 253 | args: [ 254 | srcRootDir, 255 | ], 256 | env: prow_env + [{ 257 | name: "EXTRA_REPOS", 258 | value: params.extra_repos, 259 | }], 260 | image: testing_image, 261 | volumeMounts: [ 262 | { 263 | name: dataVolume, 264 | mountPath: mountPath, 265 | }, 266 | ], 267 | }, 268 | }, // checkout 269 | 270 | buildImageTemplate("image-build-release", imageDir, params.dockerfile, releaseImage), 271 | 272 | buildTemplate( 273 | "copy-artifacts", 274 | [ 275 | "python", 276 | "-m", 277 | "kubeflow.testing.prow_artifacts", 278 | "--artifacts_dir=" + outputDir, 279 | "copy_artifacts", 280 | "--bucket=" + bucket, 281 | ] 282 | ), // copy-artifacts 283 | ], // templates 284 | }, 285 | }, // release 286 | }, // parts 287 | } 288 | --------------------------------------------------------------------------------