├── .gitignore
├── .travis.yml
├── ADOPTERS.md
├── Dockerfile
├── LICENSE
├── Makefile
├── OWNERS
├── PROJECT
├── README.md
├── build_image.sh
├── cmd
└── manager
│ └── main.go
├── config
├── crds
│ ├── xgboostjob_v1_xgboostjobs.yaml
│ └── xgboostjob_v1alpha1_xgboostjob.yaml
├── default
│ ├── kustomization.yaml
│ ├── manager_auth_proxy_patch.yaml
│ ├── manager_image_patch.yaml
│ └── manager_prometheus_metrics_patch.yaml
├── manager
│ └── manager.yaml
├── rbac
│ ├── auth_proxy_role.yaml
│ ├── auth_proxy_role_binding.yaml
│ ├── auth_proxy_service.yaml
│ ├── rbac_role.yaml
│ └── rbac_role_binding.yaml
└── samples
│ ├── lightgbm-dist
│ ├── Dockerfile
│ ├── README.md
│ ├── main.py
│ ├── train.py
│ ├── utils.py
│ └── xgboostjob_v1_lightgbm_dist_training.yaml
│ ├── smoke-dist
│ ├── Dockerfile
│ ├── README.md
│ ├── requirements.txt
│ ├── tracker.py
│ ├── xgboost_smoke_test.py
│ ├── xgboostjob_v1_rabit_test.yaml
│ └── xgboostjob_v1alpha1_rabit_test.yaml
│ └── xgboost-dist
│ ├── Dockerfile
│ ├── README.md
│ ├── build.sh
│ ├── local_test.py
│ ├── main.py
│ ├── predict.py
│ ├── requirements.txt
│ ├── tracker.py
│ ├── train.py
│ ├── utils.py
│ ├── xgboostjob_v1_iris_predict.yaml
│ ├── xgboostjob_v1_iris_train.yaml
│ ├── xgboostjob_v1alpha1_iris_predict.yaml
│ ├── xgboostjob_v1alpha1_iris_predict_local.yaml
│ ├── xgboostjob_v1alpha1_iris_train.yaml
│ └── xgboostjob_v1alpha1_iris_train_local.yaml
├── go.mod
├── go.sum
├── hack
└── boilerplate.go.txt
├── manifests
├── base
│ ├── cluster-role-binding.yaml
│ ├── cluster-role.yaml
│ ├── crd.yaml
│ ├── deployment.yaml
│ ├── kustomization.yaml
│ ├── params.env
│ ├── service-account.yaml
│ └── service.yaml
└── overlays
│ └── kubeflow
│ └── kustomization.yaml
├── pkg
├── apis
│ ├── addtoscheme_xgboostjob.go
│ ├── apis.go
│ └── xgboostjob
│ │ ├── group.go
│ │ └── v1
│ │ ├── constants.go
│ │ ├── doc.go
│ │ ├── register.go
│ │ ├── v1_suite_test.go.back
│ │ ├── xgboostjob_types.go
│ │ ├── xgboostjob_types_test.go.back
│ │ └── zz_generated.deepcopy.go
├── controller
│ └── v1
│ │ ├── add_xgboostjob.go
│ │ ├── controller.go
│ │ └── xgboostjob
│ │ ├── expectation.go
│ │ ├── job.go
│ │ ├── pod.go
│ │ ├── pod_test.go
│ │ ├── service.go
│ │ ├── util.go
│ │ ├── xgboostjob_controller.go
│ │ ├── xgboostjob_controller_suite_test.go.back
│ │ └── xgboostjob_controller_test.go.back
└── webhook
│ └── webhook.go
├── prow_config.yaml
└── test
└── workflows
├── .gitignore
├── app.yaml
├── components
├── build.jsonnet
├── params.libsonnet
└── util.libsonnet
├── environments
├── base.libsonnet
└── dockerbuild
│ ├── globals.libsonnet
│ ├── main.jsonnet
│ └── params.libsonnet
└── vendor
└── kubeflow
└── automation
├── parts.yaml
├── prototypes
└── release.jsonnet
└── release.libsonnet
/.gitignore:
--------------------------------------------------------------------------------
1 | # Files created by Gogland IDE
2 | .idea/
3 |
4 | # Other temporary files
5 | .DS_Store
6 |
7 | # Compiled python files.
8 | *.pyc
9 |
10 | # Emacs temporary files
11 | *~
12 |
13 | # VIM temporary files.
14 | .swp
15 |
16 |
17 | cover.out
18 | bin/
19 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: go
2 |
3 | go:
4 | - "1.12"
5 |
6 | go_import_path: github.com/kubeflow/xgboost-operator
7 |
8 | install:
9 | # get coveralls.io support
10 | - go get github.com/mattn/goveralls
11 | # gometalinter is deprecated; migrating to golangci-lint. See https://github.com/alecthomas/gometalinter/issues/590.
12 | - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.16.0
13 |
14 | script:
15 | - export GO111MODULE=on
16 | - go build ./...
17 | # - golangci-lint run ./...
18 | # - goveralls -service=travis-ci -v -package ./... -ignore "pkg/client/*/*.go,pkg/client/*/*/*.go,pkg/client/*/*/*/*.go,pkg/client/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*.go,pkg/client/*/*/*/*/*/*/*.go,pkg/apis/xgboost/v1alpha1/zz_generated.*.go"
19 |
--------------------------------------------------------------------------------
/ADOPTERS.md:
--------------------------------------------------------------------------------
1 | # Adopters of XGBoost Operator
2 |
3 | This page contains a list of organizations who are using XGBoost Operator. If you'd like to be included here, please send a pull request which modifies this file. Please keep the list in alphabetical order.
4 |
5 | | Organization | Contact |
6 | | ------------ | ------- |
7 | | [Ant Group](https://www.antgroup.com/) | [Mingjie Tang](https://github.com/merlintang) and [Yuan Tang](https://github.com/terrytangyuan) |
8 | | [Tencent](http://tencent.com/en-us/) | [Ye Yin](https://github.com/hustcat) |
9 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Build the manager binary
2 | FROM golang:1.12.17 as builder
3 |
4 | ENV GO111MODULE=on
5 |
6 | # Copy in the go src
7 | WORKDIR /go/src/github.com/kubeflow/xgboost-operator
8 |
9 | COPY go.mod .
10 | COPY go.sum .
11 |
12 | RUN go mod download
13 |
14 | COPY pkg/ pkg/
15 | COPY cmd/ cmd/
16 | COPY config/ config/
17 |
18 | # Build
19 | RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -a -o manager github.com/kubeflow/xgboost-operator/cmd/manager
20 |
21 | # Copy the controller-manager into a thin image
22 | FROM ubuntu:latest
23 | WORKDIR /root
24 | COPY --from=builder /go/src/github.com/kubeflow/xgboost-operator/manager .
25 | ENTRYPOINT ["/root/manager"]
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | # Image URL to use all building/pushing image targets
3 | IMG ?= controller:latest
4 |
5 | all: test manager
6 |
7 | # Run tests
8 | test: generate fmt vet manifests
9 | go test ./pkg/... ./cmd/... -coverprofile cover.out
10 |
11 | # Build manager binary
12 | manager: generate fmt vet
13 | go build -o bin/manager github.com/kubeflow/xgboost-operator/cmd/manager
14 |
15 | # Run against the configured Kubernetes cluster in ~/.kube/config
16 | run: generate fmt vet
17 | go run ./cmd/manager/main.go
18 |
19 | # Install CRDs into a cluster
20 | install: manifests
21 | kubectl apply -f config/crds
22 |
23 | # Deploy controller in the configured Kubernetes cluster in ~/.kube/config
24 | deploy: manifests
25 | kubectl apply -f config/crds
26 | kustomize build config/default | kubectl apply -f -
27 |
28 | # Generate manifests e.g. CRD, RBAC etc.
29 | manifests: controller-gen
30 | $(CONTROLLER_GEN) crd rbac:roleName=manager-role paths=./pkg/apis/... output:crd:artifacts:config=config/crds
31 |
32 | # Run go fmt against code
33 | fmt:
34 | go fmt ./pkg/... ./cmd/...
35 |
36 | # Run go vet against code
37 | vet:
38 | go vet ./pkg/... ./cmd/...
39 |
40 | # Generate code
41 | generate: controller-gen
42 | $(CONTROLLER_GEN) object:headerFile=./hack/boilerplate.go.txt paths="./pkg/apis/... ./pkg/controller/..."
43 |
44 | # Build the docker image
45 | docker-build: test
46 | docker build . -t ${IMG}
47 | @echo "updating kustomize image patch file for manager resource"
48 | sed -i'' -e 's@image: .*@image: '"${IMG}"'@' ./config/default/manager_image_patch.yaml
49 |
50 | # Push the docker image
51 | docker-push:
52 | docker push ${IMG}
53 |
54 | # find or download controller-gen if necessary
55 | controller-gen:
56 | ifeq (, $(shell which controller-gen))
57 | # Please backport https://github.com/kubernetes-sigs/controller-tools/pull/317
58 | # and build a customize controller-gen to use
59 | go get sigs.k8s.io/controller-tools/cmd/controller-gen@v0.2.2
60 | CONTROLLER_GEN=$(GOBIN)/controller-gen
61 | else
62 | CONTROLLER_GEN=$(shell which controller-gen)
63 | endif
64 |
--------------------------------------------------------------------------------
/OWNERS:
--------------------------------------------------------------------------------
1 | approvers:
2 | - merlintang
3 | - terrytangyuan
4 | - Jeffwan
5 |
--------------------------------------------------------------------------------
/PROJECT:
--------------------------------------------------------------------------------
1 | version: "1"
2 | domain: kubeflow.org
3 | repo: github.com/kubeflow/xgboost-operator
4 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # XGBoost Operator
2 |
3 | [](https://travis-ci.com/kubeflow/xgboost-operator/)
4 | [](https://goreportcard.com/report/github.com/kubeflow/xgboost-operator)
5 |
6 |
7 | ## :warning: **kubeflow/xgboost-operator is not maintained**
8 |
9 | This operator has been merged into [Kubeflow Training Operator](https://github.com/kubeflow/training-operator). This repository is not maintained and has been archived.
10 |
11 | ## Overview
12 |
13 | Incubating project for [XGBoost](https://github.com/dmlc/xgboost) operator. The XGBoost operator makes it easy to run distributed XGBoost job training and batch prediction on a Kubernetes cluster.
14 |
15 | The overall design can be found [here]( https://github.com/kubeflow/community/issues/247).
16 |
17 | This repository contains the specification and implementation of `XGBoostJob` custom resource definition.
18 | Using this custom resource, users can create and manage XGBoost jobs like other built-in resources in Kubernetes.
19 | ## Prerequisites
20 | - Kubernetes >= 1.8
21 | - [kubectl](https://kubernetes.io/docs/tasks/tools/install-kubectl)
22 |
23 | ## Install XGBoost Operator
24 |
25 | You can deploy the operator with default settings by running the following commands using [kustomize](https://github.com/kubernetes-sigs/kustomize):
26 |
27 | ```bash
28 | cd manifests
29 | kubectl create namespace kubeflow
30 | kustomize build base | kubectl apply -f -
31 | ```
32 |
33 | Note that since Kubernetes v1.14, `kustomize` became a subcommand in `kubectl` so you can also run the following command instead:
34 |
35 | ```bash
36 | kubectl kustomize base | kubectl apply -f -
37 | ```
38 |
39 | ## Build XGBoost Operator
40 |
41 | XGBoost Operator is developed based on [Kubebuilder](https://github.com/kubernetes-sigs/kubebuilder) and [Kubeflow Common](https://github.com/kubeflow/common).
42 |
43 | You can follow the [installation guide of Kubebuilder](https://book.kubebuilder.io/cronjob-tutorial/running.html) to install XGBoost operator into the Kubernetes cluster.
44 |
45 | You can check whether the XGBoostJob custom resource has been installed via:
46 | ```
47 | kubectl get crd
48 | ```
49 | The output should include xgboostjobs.xgboostjob.kubeflow.org like the following:
50 | ```
51 | NAME CREATED AT
52 | xgboostjobs.xgboostjob.kubeflow.org 2021-03-24T22:03:07Z
53 | ```
54 | If it is not included you can add it as follows:
55 | ```
56 | ## setup the build environment
57 | export GOPATH=$HOME/go
58 | export PATH=$PATH:$GOPATH/bin
59 | export GO111MODULE=on
60 | cd $GOPATH
61 | mkdir src/github.com/kubeflow
62 | cd src/github.com/kubeflow
63 |
64 | ## clone the code
65 | git clone git@github.com:kubeflow/xgboost-operator.git
66 | cd xgboost-operator
67 |
68 | ## build and install xgboost operator
69 | make
70 | make install
71 | make run
72 | ```
73 | If the XGBoost Job operator has been installed into the cluster successfully, you can view logs like this:
74 |
75 |
76 | Logs
77 |
78 | ```
79 | {"level":"info","ts":1589406873.090652,"logger":"entrypoint","msg":"setting up client for manager"}
80 | {"level":"info","ts":1589406873.0991302,"logger":"entrypoint","msg":"setting up manager"}
81 | {"level":"info","ts":1589406874.2192929,"logger":"entrypoint","msg":"Registering Components."}
82 | {"level":"info","ts":1589406874.219318,"logger":"entrypoint","msg":"setting up scheme"}
83 | {"level":"info","ts":1589406874.219448,"logger":"entrypoint","msg":"Setting up controller"}
84 | {"level":"info","ts":1589406874.2194738,"logger":"controller","msg":"Running controller in local mode, using kubeconfig file"}
85 | {"level":"info","ts":1589406874.224564,"logger":"controller","msg":"gang scheduling is set: ","gangscheduling":false}
86 | {"level":"info","ts":1589406874.2247412,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="}
87 | {"level":"info","ts":1589406874.224958,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="}
88 | {"level":"info","ts":1589406874.2251048,"logger":"kubebuilder.controller","msg":"Starting EventSource","controller":"xgboostjob-controller","source":"kind source: /, Kind="}
89 | {"level":"info","ts":1589406874.225237,"logger":"entrypoint","msg":"setting up webhooks"}
90 | {"level":"info","ts":1589406874.225247,"logger":"entrypoint","msg":"Starting the Cmd."}
91 | {"level":"info","ts":1589406874.32791,"logger":"kubebuilder.controller","msg":"Starting Controller","controller":"xgboostjob-controller"}
92 | {"level":"info","ts":1589406874.430336,"logger":"kubebuilder.controller","msg":"Starting workers","controller":"xgboostjob-controller","worker count":1}
93 | ```
94 |
95 |
96 | ## Creating a XGBoost Training/Prediction Job
97 |
98 | You can create an XGBoost training or prediction (batch-oriented) job by modifying the XGBoostJob config file.
99 | See the distributed XGBoost Job training and prediction [example](https://github.com/kubeflow/xgboost-operator/tree/master/config/samples/xgboost-dist).
100 | You can change the config file and related python file (i.e., train.py or predict.py)
101 | based on your requirement.
102 |
103 | Following the job configuration guide in the example, you can deploy an XGBoost Job to start training or prediction like:
104 | ```
105 | ## For training job
106 | cat config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml
107 | kubectl create -f config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml
108 |
109 | ## For batch prediction job
110 | cat config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml
111 | kubectl create -f config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml
112 | ```
113 |
114 | ## Monitor a distributed XGBoost Job
115 |
116 | Once the XGBoost job is created, you should be able to watch how the related pods and services are working.
117 | A distributed XGBoost job is trained by synchronizing the status of the different workers via the Rabit tracker of XGBoost.
118 | You can also monitor the job status.
119 |
120 | ```
121 | kubectl get -o yaml XGBoostJob/xgboost-dist-iris-test-train
122 | ```
123 |
124 | Here is the sample output when training job is finished.
125 |
126 |
127 | XGBoost Job Details
128 |
129 | ```
130 | apiVersion: xgboostjob.kubeflow.org/v1
131 | kind: XGBoostJob
132 | metadata:
133 | annotations:
134 | kubectl.kubernetes.io/last-applied-configuration: |
135 | {"apiVersion":"xgboostjob.kubeflow.org/v1","kind":"XGBoostJob","metadata":{"annotations":{},"name":"xgboost-dist-iris-test-train","namespace":"default"},"spec":{"xgbReplicaSpecs":{"Master":{"replicas":1,"restartPolicy":"Never","template":{"spec":{"containers":[{"args":["--job_type=Train","--xgboost_parameter=objective:multi:softprob,num_class:3","--n_estimators=10","--learning_rate=0.1","--model_path=/tmp/xgboost-model","--model_storage_type=local"],"image":"docker.io/merlintang/xgboost-dist-iris:1.1","imagePullPolicy":"Always","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}},"Worker":{"replicas":2,"restartPolicy":"ExitCode","template":{"spec":{"containers":[{"args":["--job_type=Train","--xgboost_parameter=\"objective:multi:softprob,num_class:3\"","--n_estimators=10","--learning_rate=0.1"],"image":"docker.io/merlintang/xgboost-dist-iris:1.1","imagePullPolicy":"Always","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}}}}}
136 | creationTimestamp: "2021-03-24T22:54:39Z"
137 | generation: 8
138 | name: xgboost-dist-iris-test-train
139 | namespace: default
140 | resourceVersion: "1060393"
141 | selfLink: /apis/xgboostjob.kubeflow.org/v1/namespaces/default/xgboostjobs/xgboost-dist-iris-test-train
142 | uid: 386c9851-7ef8-4928-9dba-2da8829bf048
143 | spec:
144 | RunPolicy:
145 | cleanPodPolicy: None
146 | xgbReplicaSpecs:
147 | Master:
148 | replicas: 1
149 | restartPolicy: Never
150 | template:
151 | metadata:
152 | creationTimestamp: null
153 | spec:
154 | containers:
155 | - args:
156 | - --job_type=Train
157 | - --xgboost_parameter=objective:multi:softprob,num_class:3
158 | - --n_estimators=10
159 | - --learning_rate=0.1
160 | - --model_path=/tmp/xgboost-model
161 | - --model_storage_type=local
162 | image: docker.io/merlintang/xgboost-dist-iris:1.1
163 | imagePullPolicy: Always
164 | name: xgboostjob
165 | ports:
166 | - containerPort: 9991
167 | name: xgboostjob-port
168 | resources: {}
169 | Worker:
170 | replicas: 2
171 | restartPolicy: ExitCode
172 | template:
173 | metadata:
174 | creationTimestamp: null
175 | spec:
176 | containers:
177 | - args:
178 | - --job_type=Train
179 | - --xgboost_parameter="objective:multi:softprob,num_class:3"
180 | - --n_estimators=10
181 | - --learning_rate=0.1
182 | image: docker.io/merlintang/xgboost-dist-iris:1.1
183 | imagePullPolicy: Always
184 | name: xgboostjob
185 | ports:
186 | - containerPort: 9991
187 | name: xgboostjob-port
188 | resources: {}
189 | status:
190 | completionTime: "2021-03-24T22:54:58Z"
191 | conditions:
192 | - lastTransitionTime: "2021-03-24T22:54:39Z"
193 | lastUpdateTime: "2021-03-24T22:54:39Z"
194 | message: xgboostJob xgboost-dist-iris-test-train is created.
195 | reason: XGBoostJobCreated
196 | status: "True"
197 | type: Created
198 | - lastTransitionTime: "2021-03-24T22:54:39Z"
199 | lastUpdateTime: "2021-03-24T22:54:39Z"
200 | message: XGBoostJob xgboost-dist-iris-test-train is running.
201 | reason: XGBoostJobRunning
202 | status: "False"
203 | type: Running
204 | - lastTransitionTime: "2021-03-24T22:54:58Z"
205 | lastUpdateTime: "2021-03-24T22:54:58Z"
206 | message: XGBoostJob xgboost-dist-iris-test-train is successfully completed.
207 | reason: XGBoostJobSucceeded
208 | status: "True"
209 | type: Succeeded
210 | replicaStatuses:
211 | Master:
212 | succeeded: 1
213 | Worker:
214 | succeeded: 2
215 | ```
216 |
217 |
218 |
219 | ## Docker Images
220 |
221 | You can use [this Dockerfile](Dockerfile) to build the image yourself:
222 |
223 | Alternatively, you can pull the existing image from Dockerhub [here](https://hub.docker.com/r/kubeflow/xgboost-operator/tags).
224 |
225 | ## Known Issues
226 |
227 | XGBoost and `kubeflow/common` use pointer value in map like `map[commonv1.ReplicaType]*commonv1.ReplicaSpec`. However, `controller-gen` in [controller-tools](https://github.com/kubernetes-sigs/controller-tools) doesn't accept pointers as map values in latest version (v0.3.0), in order to generate crds and deepcopy files, we need to build custom `controller-gen`. You can follow steps below. Then `make generate` can work properly.
228 |
229 | ```shell
230 | git clone --branch v0.2.2 git@github.com:kubernetes-sigs/controller-tools.git
231 | git cherry-pick 71b6e91
232 | go build -o controller-gen cmd/controller-gen/main.go
233 | cp controller-gen /usr/local/bin
234 | ```
235 |
236 | This can be removed once a newer `controller-gen` released and xgboost can upgrade to corresponding k8s version.
237 |
--------------------------------------------------------------------------------
/build_image.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #
3 | # A simple script to build the Docker images.
4 | # This is intended to be invoked as a step in Argo to build the docker image.
5 | #
6 | # build_image.sh ${DOCKERFILE} ${IMAGE} ${TAG}
7 | # Note: TAG is not used from workflows, always generated uniquely
8 | set -ex
9 |
10 | DOCKERFILE=$1
11 | CONTEXT_DIR=$(dirname "$DOCKERFILE")
12 | IMAGE=$2
13 | TIMEOUT=30m
14 |
15 | cd $CONTEXT_DIR
16 | TAG=$(git describe --tags --always --dirty)
17 |
18 | gcloud auth activate-service-account --key-file=${GOOGLE_APPLICATION_CREDENTIALS}
19 |
20 | echo "Building ${IMAGE} using gcloud build"
21 | gcloud builds submit --tag=${IMAGE}:${TAG} --project=${GCP_PROJECT} --timeout=${TIMEOUT} .
22 | echo "Finished building ${IMAGE}:${TAG}"
23 |
--------------------------------------------------------------------------------
/cmd/manager/main.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package main
17 |
18 | import (
19 | "flag"
20 | "os"
21 |
22 | "github.com/kubeflow/xgboost-operator/pkg/apis"
23 | controller "github.com/kubeflow/xgboost-operator/pkg/controller/v1"
24 | "github.com/kubeflow/xgboost-operator/pkg/webhook"
25 | _ "k8s.io/client-go/plugin/pkg/client/auth/gcp"
26 | "sigs.k8s.io/controller-runtime/pkg/client/config"
27 | "sigs.k8s.io/controller-runtime/pkg/manager"
28 | logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"
29 | "sigs.k8s.io/controller-runtime/pkg/runtime/signals"
30 | )
31 |
32 | func main() {
33 | var metricsAddr string
34 | var mode string
35 | flag.StringVar(&metricsAddr, "metrics-addr", ":8080", "The address the metric endpoint binds to.")
36 | flag.StringVar(&mode, "mode", "local", "The mode in which xgboost-operator to run")
37 | flag.Parse()
38 | logf.SetLogger(logf.ZapLogger(false))
39 | log := logf.Log.WithName("entrypoint")
40 |
41 | // Get a config to talk to the apiserver
42 | log.Info("setting up client for manager")
43 | cfg, err := config.GetConfig()
44 | if err != nil {
45 | log.Error(err, "unable to set up client config")
46 | os.Exit(1)
47 | }
48 |
49 | // Create a new Cmd to provide shared dependencies and start components
50 | log.Info("setting up manager")
51 | mgr, err := manager.New(cfg, manager.Options{MetricsBindAddress: metricsAddr})
52 | if err != nil {
53 | log.Error(err, "unable to set up overall controller manager")
54 | os.Exit(1)
55 | }
56 |
57 | log.Info("Registering Components.")
58 |
59 | // Setup Scheme for all resources
60 | log.Info("setting up scheme")
61 | if err := apis.AddToScheme(mgr.GetScheme()); err != nil {
62 | log.Error(err, "unable add APIs to scheme")
63 | os.Exit(1)
64 | }
65 |
66 | // Setup all Controllers
67 | log.Info("Setting up controller")
68 | if err := controller.AddToManager(mgr); err != nil {
69 | log.Error(err, "unable to register controllers to the manager")
70 | os.Exit(1)
71 | }
72 |
73 | log.Info("setting up webhooks")
74 | if err := webhook.AddToManager(mgr); err != nil {
75 | log.Error(err, "unable to register webhooks to the manager")
76 | os.Exit(1)
77 | }
78 |
79 | // Start the Cmd
80 | log.Info("Starting the Cmd.")
81 | if err := mgr.Start(signals.SetupSignalHandler()); err != nil {
82 | log.Error(err, "unable to run the manager")
83 | os.Exit(1)
84 | }
85 | }
86 |
--------------------------------------------------------------------------------
/config/crds/xgboostjob_v1alpha1_xgboostjob.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apiextensions.k8s.io/v1beta1
2 | kind: CustomResourceDefinition
3 | metadata:
4 | creationTimestamp: null
5 | labels:
6 | controller-tools.k8s.io: "1.0"
7 | name: xgboostjobs.kubeflow.org
8 | spec:
9 | group: kubeflow.org
10 | names:
11 | kind: XGBoostJob
12 | singular: xgboostjob
13 | plural: xgboostjobs
14 | scope: Namespaced
15 | validation:
16 | openAPIV3Schema:
17 | properties:
18 | apiVersion:
19 | description: 'APIVersion defines the versioned schema of this representation
20 | of an object. Servers should convert recognized schemas to the latest
21 | internal value, and may reject unrecognized values. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#resources'
22 | type: string
23 | kind:
24 | description: 'Kind is a string value representing the REST resource this
25 | object represents. Servers may infer this from the endpoint the client
26 | submits requests to. Cannot be updated. In CamelCase. More info: https://git.k8s.io/community/contributors/devel/api-conventions.md#types-kinds'
27 | type: string
28 | metadata:
29 | type: object
30 | spec:
31 | properties:
32 | activeDeadlineSeconds:
33 | description: Specifies the duration in seconds relative to the startTime
34 | that the job may be active before the system tries to terminate it;
35 | value must be positive integer.
36 | format: int64
37 | type: integer
38 | backoffLimit:
39 | description: Optional number of retries before marking this job failed.
40 | format: int32
41 | type: integer
42 | cleanPodPolicy:
43 | description: CleanPodPolicy defines the policy to kill pods after the
44 | job completes. Default to Running.
45 | type: string
46 | schedulingPolicy:
47 | description: SchedulingPolicy defines the policy related to scheduling,
48 | e.g. gang-scheduling
49 | properties:
50 | minAvailable:
51 | format: int32
52 | type: integer
53 | type: object
54 | ttlSecondsAfterFinished:
55 | description: TTLSecondsAfterFinished is the TTL to clean up jobs. It
56 | may take extra ReconcilePeriod seconds for the cleanup, since reconcile
57 | gets called periodically. Default to infinite.
58 | format: int32
59 | type: integer
60 | xgbReplicaSpecs:
61 | type: object
62 | required:
63 | - xgbReplicaSpecs
64 | type: object
65 | status:
66 | properties:
67 | completionTime:
68 | description: Represents time when the job was completed. It is not guaranteed
69 | to be set in happens-before order across separate operations. It is
70 | represented in RFC3339 form and is in UTC.
71 | format: date-time
72 | type: string
73 | conditions:
74 | description: Conditions is an array of current observed job conditions.
75 | items:
76 | properties:
77 | lastTransitionTime:
78 | description: Last time the condition transitioned from one status
79 | to another.
80 | format: date-time
81 | type: string
82 | lastUpdateTime:
83 | description: The last time this condition was updated.
84 | format: date-time
85 | type: string
86 | message:
87 | description: A human readable message indicating details about
88 | the transition.
89 | type: string
90 | reason:
91 | description: The reason for the condition's last transition.
92 | type: string
93 | status:
94 | description: Status of the condition, one of True, False, Unknown.
95 | type: string
96 | type:
97 | description: Type of job condition.
98 | type: string
99 | required:
100 | - type
101 | - status
102 | type: object
103 | type: array
104 | lastReconcileTime:
105 | description: Represents last time when the job was reconciled. It is
106 | not guaranteed to be set in happens-before order across separate operations.
107 | It is represented in RFC3339 form and is in UTC.
108 | format: date-time
109 | type: string
110 | replicaStatuses:
111 | description: ReplicaStatuses is map of ReplicaType and ReplicaStatus,
112 | specifies the status of each replica.
113 | type: object
114 | startTime:
115 | description: Represents time when the job was acknowledged by the job
116 | controller. It is not guaranteed to be set in happens-before order
117 | across separate operations. It is represented in RFC3339 form and
118 | is in UTC.
119 | format: date-time
120 | type: string
121 | required:
122 | - conditions
123 | - replicaStatuses
124 | type: object
125 | version: v1alpha1
126 | status:
127 | acceptedNames:
128 | kind: ""
129 | plural: ""
130 | conditions: []
131 | storedVersions: []
132 |
--------------------------------------------------------------------------------
/config/default/kustomization.yaml:
--------------------------------------------------------------------------------
1 | # Adds namespace to all resources.
2 | namespace: xgboost-operator-system
3 |
4 | # Value of this field is prepended to the
5 | # names of all resources, e.g. a deployment named
6 | # "wordpress" becomes "alices-wordpress".
7 | # Note that it should also match with the prefix (text before '-') of the namespace
8 | # field above.
9 | namePrefix: xgboost-operator-
10 |
11 | # Labels to add to all resources and selectors.
12 | #commonLabels:
13 | # someName: someValue
14 |
15 | # Each entry in this list must resolve to an existing
16 | # resource definition in YAML. These are the resource
17 | # files that kustomize reads, modifies and emits as a
18 | # YAML string, with resources separated by document
19 | # markers ("---").
20 | resources:
21 | - ../rbac/rbac_role.yaml
22 | - ../rbac/rbac_role_binding.yaml
23 | - ../manager/manager.yaml
24 | # Comment the following 3 lines if you want to disable
25 | # the auth proxy (https://github.com/brancz/kube-rbac-proxy)
26 | # which protects your /metrics endpoint.
27 | - ../rbac/auth_proxy_service.yaml
28 | - ../rbac/auth_proxy_role.yaml
29 | - ../rbac/auth_proxy_role_binding.yaml
30 |
31 | patches:
32 | - manager_image_patch.yaml
33 | # Protect the /metrics endpoint by putting it behind auth.
34 | # Only one of manager_auth_proxy_patch.yaml and
35 | # manager_prometheus_metrics_patch.yaml should be enabled.
36 | - manager_auth_proxy_patch.yaml
37 | # If you want your controller-manager to expose the /metrics
38 | # endpoint w/o any authn/z, uncomment the following line and
39 | # comment manager_auth_proxy_patch.yaml.
40 | # Only one of manager_auth_proxy_patch.yaml and
41 | # manager_prometheus_metrics_patch.yaml should be enabled.
42 | #- manager_prometheus_metrics_patch.yaml
43 |
44 | vars:
45 | - name: WEBHOOK_SECRET_NAME
46 | objref:
47 | kind: Secret
48 | name: webhook-server-secret
49 | apiVersion: v1
50 |
--------------------------------------------------------------------------------
/config/default/manager_auth_proxy_patch.yaml:
--------------------------------------------------------------------------------
1 | # This patch injects a sidecar container which is an HTTP proxy for the controller manager;
2 | # it performs RBAC authorization against the Kubernetes API using SubjectAccessReviews.
3 | apiVersion: apps/v1
4 | kind: StatefulSet
5 | metadata:
6 | name: controller-manager
7 | namespace: system
8 | spec:
9 | template:
10 | spec:
11 | containers:
12 | - name: kube-rbac-proxy
13 | image: gcr.io/kubebuilder/kube-rbac-proxy:v0.4.0
14 | args:
15 | - "--secure-listen-address=0.0.0.0:8443"
16 | - "--upstream=http://127.0.0.1:8080/"
17 | - "--logtostderr=true"
18 | - "--v=10"
19 | ports:
20 | - containerPort: 8443
21 | name: https
22 | - name: manager
23 | args:
24 | - "--metrics-addr=127.0.0.1:8080"
25 |
--------------------------------------------------------------------------------
/config/default/manager_image_patch.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: StatefulSet
3 | metadata:
4 | name: controller-manager
5 | namespace: system
6 | spec:
7 | template:
8 | spec:
9 | containers:
10 | # Change the value of image field below to your controller image URL
11 | - image: IMAGE_URL
12 | name: manager
13 |
--------------------------------------------------------------------------------
/config/default/manager_prometheus_metrics_patch.yaml:
--------------------------------------------------------------------------------
1 | # This patch enables Prometheus scraping for the manager pod.
2 | apiVersion: apps/v1
3 | kind: StatefulSet
4 | metadata:
5 | name: controller-manager
6 | namespace: system
7 | spec:
8 | template:
9 | metadata:
10 | annotations:
11 | prometheus.io/scrape: 'true'
12 | spec:
13 | containers:
14 | # Expose the prometheus metrics on default port
15 | - name: manager
16 | ports:
17 | - containerPort: 8080
18 | name: metrics
19 | protocol: TCP
20 |
--------------------------------------------------------------------------------
/config/manager/manager.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Namespace
3 | metadata:
4 | labels:
5 | control-plane: controller-manager
6 | controller-tools.k8s.io: "1.0"
7 | name: system
8 | ---
9 | apiVersion: v1
10 | kind: Service
11 | metadata:
12 | name: controller-manager-service
13 | namespace: system
14 | labels:
15 | control-plane: controller-manager
16 | controller-tools.k8s.io: "1.0"
17 | spec:
18 | selector:
19 | control-plane: controller-manager
20 | controller-tools.k8s.io: "1.0"
21 | ports:
22 | - port: 443
23 | ---
24 | apiVersion: apps/v1
25 | kind: StatefulSet
26 | metadata:
27 | name: controller-manager
28 | namespace: system
29 | labels:
30 | control-plane: controller-manager
31 | controller-tools.k8s.io: "1.0"
32 | spec:
33 | selector:
34 | matchLabels:
35 | control-plane: controller-manager
36 | controller-tools.k8s.io: "1.0"
37 | serviceName: controller-manager-service
38 | template:
39 | metadata:
40 | labels:
41 | control-plane: controller-manager
42 | controller-tools.k8s.io: "1.0"
43 | spec:
44 | containers:
45 | - command:
46 | - /manager
47 | image: controller:latest
48 | imagePullPolicy: Always
49 | name: manager
50 | env:
51 | - name: POD_NAMESPACE
52 | valueFrom:
53 | fieldRef:
54 | fieldPath: metadata.namespace
55 | - name: SECRET_NAME
56 | value: $(WEBHOOK_SECRET_NAME)
57 | resources:
58 | limits:
59 | cpu: 100m
60 | memory: 30Mi
61 | requests:
62 | cpu: 100m
63 | memory: 20Mi
64 | ports:
65 | - containerPort: 9876
66 | name: webhook-server
67 | protocol: TCP
68 | volumeMounts:
69 | - mountPath: /tmp/cert
70 | name: cert
71 | readOnly: true
72 | terminationGracePeriodSeconds: 10
73 | volumes:
74 | - name: cert
75 | secret:
76 | defaultMode: 420
77 | secretName: webhook-server-secret
78 | ---
79 | apiVersion: v1
80 | kind: Secret
81 | metadata:
82 | name: webhook-server-secret
83 | namespace: system
84 |
--------------------------------------------------------------------------------
/config/rbac/auth_proxy_role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRole
3 | metadata:
4 | name: proxy-role
5 | rules:
6 | - apiGroups: ["authentication.k8s.io"]
7 | resources:
8 | - tokenreviews
9 | verbs: ["create"]
10 | - apiGroups: ["authorization.k8s.io"]
11 | resources:
12 | - subjectaccessreviews
13 | verbs: ["create"]
14 |
--------------------------------------------------------------------------------
/config/rbac/auth_proxy_role_binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRoleBinding
3 | metadata:
4 | name: proxy-rolebinding
5 | roleRef:
6 | apiGroup: rbac.authorization.k8s.io
7 | kind: ClusterRole
8 | name: proxy-role
9 | subjects:
10 | - kind: ServiceAccount
11 | name: default
12 | namespace: system
13 |
--------------------------------------------------------------------------------
/config/rbac/auth_proxy_service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | annotations:
5 | prometheus.io/port: "8443"
6 | prometheus.io/scheme: https
7 | prometheus.io/scrape: "true"
8 | labels:
9 | control-plane: controller-manager
10 | controller-tools.k8s.io: "1.0"
11 | name: controller-manager-metrics-service
12 | namespace: system
13 | spec:
14 | ports:
15 | - name: https
16 | port: 8443
17 | targetPort: https
18 | selector:
19 | control-plane: controller-manager
20 | controller-tools.k8s.io: "1.0"
21 |
--------------------------------------------------------------------------------
/config/rbac/rbac_role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRole
3 | metadata:
4 | creationTimestamp: null
5 | name: manager-role
6 | rules:
7 | - apiGroups:
8 | - apps
9 | resources:
10 | - deployments
11 | verbs:
12 | - get
13 | - list
14 | - watch
15 | - create
16 | - update
17 | - patch
18 | - delete
19 | - apiGroups:
20 | - apps
21 | resources:
22 | - deployments/status
23 | verbs:
24 | - get
25 | - update
26 | - patch
27 | - apiGroups:
28 | - xgboostjob.kubeflow.org
29 | resources:
30 | - xgboostjobs
31 | verbs:
32 | - get
33 | - list
34 | - watch
35 | - create
36 | - update
37 | - patch
38 | - delete
39 | - apiGroups:
40 | - xgboostjob.kubeflow.org
41 | resources:
42 | - xgboostjobs/status
43 | verbs:
44 | - get
45 | - update
46 | - patch
47 | - apiGroups:
48 | - admissionregistration.k8s.io
49 | resources:
50 | - mutatingwebhookconfigurations
51 | - validatingwebhookconfigurations
52 | verbs:
53 | - get
54 | - list
55 | - watch
56 | - create
57 | - update
58 | - patch
59 | - delete
60 | - apiGroups:
61 | - ""
62 | resources:
63 | - secrets
64 | verbs:
65 | - get
66 | - list
67 | - watch
68 | - create
69 | - update
70 | - patch
71 | - delete
72 | - apiGroups:
73 | - ""
74 | resources:
75 | - services
76 | verbs:
77 | - get
78 | - list
79 | - watch
80 | - create
81 | - update
82 | - patch
83 | - delete
84 |
--------------------------------------------------------------------------------
/config/rbac/rbac_role_binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRoleBinding
3 | metadata:
4 | creationTimestamp: null
5 | name: manager-rolebinding
6 | roleRef:
7 | apiGroup: rbac.authorization.k8s.io
8 | kind: ClusterRole
9 | name: manager-role
10 | subjects:
11 | - kind: ServiceAccount
12 | name: default
13 | namespace: system
14 |
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/Dockerfile:
--------------------------------------------------------------------------------
# Builds the distributed-LightGBM training image for the XGBoost operator
# lightgbm-dist sample: Ubuntu 16.04 + Miniconda, with LightGBM compiled
# from its "stable" branch and its Python package installed.
FROM ubuntu:16.04

# Miniconda install location; its bin dir is prepended to PATH.
ARG CONDA_DIR=/opt/conda
ENV PATH $CONDA_DIR/bin:$PATH

# Single RUN layer: install build deps and Miniconda, pin the Python
# scientific stack, build/install LightGBM, then clean caches and sources
# to keep the final image small.
RUN apt-get update && \
apt-get install -y --no-install-recommends \
ca-certificates \
cmake \
build-essential \
gcc \
g++ \
git \
curl && \
# python environment
curl -sL https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o conda.sh && \
/bin/bash conda.sh -f -b -p $CONDA_DIR && \
export PATH="$CONDA_DIR/bin:$PATH" && \
conda config --set always_yes yes --set changeps1 no && \
# lightgbm
conda install -q -y numpy==1.20.3 scipy==1.6.2 scikit-learn==0.24.2 pandas==1.3.0 && \
git clone --recursive --branch stable --depth 1 https://github.com/Microsoft/LightGBM && \
mkdir LightGBM/build && \
cd LightGBM/build && \
cmake .. && \
make -j4 && \
make install && \
cd ../python-package && \
python setup.py install_lib && \
# clean
apt-get autoremove -y && apt-get clean && \
conda clean -a -y && \
rm -rf /usr/local/src/* && \
rm -rf /LightGBM

WORKDIR /app

# Download the example data
RUN mkdir data
ADD https://raw.githubusercontent.com/microsoft/LightGBM/stable/examples/parallel_learning/binary.train data/.
ADD https://raw.githubusercontent.com/microsoft/LightGBM/stable/examples/parallel_learning/binary.test data/.
# Copy the sample's training/driver scripts into the image.
COPY *.py ./

ENTRYPOINT [ "python", "/app/main.py" ]
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/README.md:
--------------------------------------------------------------------------------
1 | ### Distributed Lightgbm Job train
2 |
3 | This folder contains the Dockerfile and Python scripts to run a distributed LightGBM training job using the XGBoost operator.
4 | The code is based on this [example](https://github.com/microsoft/LightGBM/tree/master/examples/parallel_learning) from the official GitHub repository of the library.
5 |
6 |
7 | **Build image**
8 | The default image name and tag are `kubeflow/lightgbm-dist-py-test` and `1.0` respectively.
9 |
10 | ```shell
11 | docker build -f Dockerfile -t kubeflow/lightgbm-dist-py-test:1.0 ./
12 | ```
13 |
14 | **Start the training**
15 |
16 | ```
17 | kubectl create -f xgboostjob_v1_lightgbm_dist_training.yaml
18 | ```
19 |
20 | **Look at the job status**
21 | ```
22 | kubectl get -o yaml XGBoostJob/lightgbm-dist-train-test
23 | ```
24 | Here is a sample output when the job is running. The output looks like this:
25 |
26 | ```
27 | apiVersion: xgboostjob.kubeflow.org/v1
28 | kind: XGBoostJob
29 | metadata:
30 | annotations:
31 | kubectl.kubernetes.io/last-applied-configuration: |
32 | {"apiVersion":"xgboostjob.kubeflow.org/v1","kind":"XGBoostJob","metadata":{"annotations":{},"name":"lightgbm-dist-train-test","namespace":"default"},"spec":{"xgbReplicaSpecs":{"Master":{"replicas":1,"restartPolicy":"Never","template":{"apiVersion":"v1","kind":"Pod","spec":{"containers":[{"args":["--job_type=Train","--boosting_type=gbdt","--objective=binary","--metric=binary_logloss,auc","--metric_freq=1","--is_training_metric=true","--max_bin=255","--data=data/binary.train","--valid_data=data/binary.test","--num_trees=100","--learning_rate=01","--num_leaves=63","--tree_learner=feature","--feature_fraction=0.8","--bagging_freq=5","--bagging_fraction=0.8","--min_data_in_leaf=50","--min_sum_hessian_in_leaf=50","--is_enable_sparse=true","--use_two_round_loading=false","--is_save_binary_file=false"],"image":"kubeflow/lightgbm-dist-py-test:1.0","imagePullPolicy":"Never","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}},"Worker":{"replicas":2,"restartPolicy":"ExitCode","template":{"apiVersion":"v1","kind":"Pod","spec":{"containers":[{"args":["--job_type=Train","--boosting_type=gbdt","--objective=binary","--metric=binary_logloss,auc","--metric_freq=1","--is_training_metric=true","--max_bin=255","--data=data/binary.train","--valid_data=data/binary.test","--num_trees=100","--learning_rate=01","--num_leaves=63","--tree_learner=feature","--feature_fraction=0.8","--bagging_freq=5","--bagging_fraction=0.8","--min_data_in_leaf=50","--min_sum_hessian_in_leaf=50","--is_enable_sparse=true","--use_two_round_loading=false","--is_save_binary_file=false"],"image":"kubeflow/lightgbm-dist-py-test:1.0","imagePullPolicy":"Never","name":"xgboostjob","ports":[{"containerPort":9991,"name":"xgboostjob-port"}]}]}}}}}}
33 | creationTimestamp: "2020-10-14T15:31:23Z"
34 | generation: 7
35 | managedFields:
36 | - apiVersion: xgboostjob.kubeflow.org/v1
37 | fieldsType: FieldsV1
38 | fieldsV1:
39 | f:metadata:
40 | f:annotations:
41 | .: {}
42 | f:kubectl.kubernetes.io/last-applied-configuration: {}
43 | f:spec:
44 | .: {}
45 | f:xgbReplicaSpecs:
46 | .: {}
47 | f:Master:
48 | .: {}
49 | f:replicas: {}
50 | f:restartPolicy: {}
51 | f:template:
52 | .: {}
53 | f:spec: {}
54 | f:Worker:
55 | .: {}
56 | f:replicas: {}
57 | f:restartPolicy: {}
58 | f:template:
59 | .: {}
60 | f:spec: {}
61 | manager: kubectl-client-side-apply
62 | operation: Update
63 | time: "2020-10-14T15:31:23Z"
64 | - apiVersion: xgboostjob.kubeflow.org/v1
65 | fieldsType: FieldsV1
66 | fieldsV1:
67 | f:spec:
68 | f:RunPolicy:
69 | .: {}
70 | f:cleanPodPolicy: {}
71 | f:xgbReplicaSpecs:
72 | f:Master:
73 | f:template:
74 | f:metadata:
75 | .: {}
76 | f:creationTimestamp: {}
77 | f:spec:
78 | f:containers: {}
79 | f:Worker:
80 | f:template:
81 | f:metadata:
82 | .: {}
83 | f:creationTimestamp: {}
84 | f:spec:
85 | f:containers: {}
86 | f:status:
87 | .: {}
88 | f:completionTime: {}
89 | f:conditions: {}
90 | f:replicaStatuses:
91 | .: {}
92 | f:Master:
93 | .: {}
94 | f:succeeded: {}
95 | f:Worker:
96 | .: {}
97 | f:succeeded: {}
98 | manager: main
99 | operation: Update
100 | time: "2020-10-14T15:34:44Z"
101 | name: lightgbm-dist-train-test
102 | namespace: default
103 | resourceVersion: "38923"
104 | selfLink: /apis/xgboostjob.kubeflow.org/v1/namespaces/default/xgboostjobs/lightgbm-dist-train-test
105 | uid: b2b887d0-445b-498b-8852-26c8edc98dc7
106 | spec:
107 | RunPolicy:
108 | cleanPodPolicy: None
109 | xgbReplicaSpecs:
110 | Master:
111 | replicas: 1
112 | restartPolicy: Never
113 | template:
114 | metadata:
115 | creationTimestamp: null
116 | spec:
117 | containers:
118 | - args:
119 | - --job_type=Train
120 | - --boosting_type=gbdt
121 | - --objective=binary
122 | - --metric=binary_logloss,auc
123 | - --metric_freq=1
124 | - --is_training_metric=true
125 | - --max_bin=255
126 | - --data=data/binary.train
127 | - --valid_data=data/binary.test
128 | - --num_trees=100
129 | - --learning_rate=01
130 | - --num_leaves=63
131 | - --tree_learner=feature
132 | - --feature_fraction=0.8
133 | - --bagging_freq=5
134 | - --bagging_fraction=0.8
135 | - --min_data_in_leaf=50
136 | - --min_sum_hessian_in_leaf=50
137 | - --is_enable_sparse=true
138 | - --use_two_round_loading=false
139 | - --is_save_binary_file=false
140 | image: kubeflow/lightgbm-dist-py-test:1.0
141 | imagePullPolicy: Never
142 | name: xgboostjob
143 | ports:
144 | - containerPort: 9991
145 | name: xgboostjob-port
146 | resources: {}
147 | Worker:
148 | replicas: 2
149 | restartPolicy: ExitCode
150 | template:
151 | metadata:
152 | creationTimestamp: null
153 | spec:
154 | containers:
155 | - args:
156 | - --job_type=Train
157 | - --boosting_type=gbdt
158 | - --objective=binary
159 | - --metric=binary_logloss,auc
160 | - --metric_freq=1
161 | - --is_training_metric=true
162 | - --max_bin=255
163 | - --data=data/binary.train
164 | - --valid_data=data/binary.test
165 | - --num_trees=100
166 | - --learning_rate=01
167 | - --num_leaves=63
168 | - --tree_learner=feature
169 | - --feature_fraction=0.8
170 | - --bagging_freq=5
171 | - --bagging_fraction=0.8
172 | - --min_data_in_leaf=50
173 | - --min_sum_hessian_in_leaf=50
174 | - --is_enable_sparse=true
175 | - --use_two_round_loading=false
176 | - --is_save_binary_file=false
177 | image: kubeflow/lightgbm-dist-py-test:1.0
178 | imagePullPolicy: Never
179 | name: xgboostjob
180 | ports:
181 | - containerPort: 9991
182 | name: xgboostjob-port
183 | resources: {}
184 | status:
185 | completionTime: "2020-10-14T15:34:44Z"
186 | conditions:
187 | - lastTransitionTime: "2020-10-14T15:31:23Z"
188 | lastUpdateTime: "2020-10-14T15:31:23Z"
189 | message: xgboostJob lightgbm-dist-train-test is created.
190 | reason: XGBoostJobCreated
191 | status: "True"
192 | type: Created
193 | - lastTransitionTime: "2020-10-14T15:31:23Z"
194 | lastUpdateTime: "2020-10-14T15:31:23Z"
195 | message: XGBoostJob lightgbm-dist-train-test is running.
196 | reason: XGBoostJobRunning
197 | status: "False"
198 | type: Running
199 | - lastTransitionTime: "2020-10-14T15:34:44Z"
200 | lastUpdateTime: "2020-10-14T15:34:44Z"
201 | message: XGBoostJob lightgbm-dist-train-test is successfully completed.
202 | reason: XGBoostJobSucceeded
203 | status: "True"
204 | type: Succeeded
205 | replicaStatuses:
206 | Master:
207 | succeeded: 1
208 | Worker:
209 | succeeded: 2
210 | ```
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/main.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | import os
14 | import logging
15 | import argparse
16 |
17 | from train import train
18 |
19 | from utils import generate_machine_list_file, generate_train_conf_file
20 |
21 |
22 | logger = logging.getLogger(__name__)
23 |
24 |
def main(args, extra_args):
    """Entrypoint for the distributed LightGBM sample job.

    Reads the cluster topology that the XGBoost operator injects through
    environment variables, then dispatches on ``args.job_type``.

    Args:
        args: Parsed known arguments; only ``job_type`` is used.
        extra_args: Remaining CLI arguments, forwarded verbatim into the
            generated LightGBM training config.

    Raises:
        KeyError: If any expected cluster environment variable is missing.
    """
    master_addr = os.environ["MASTER_ADDR"]
    master_port = os.environ["MASTER_PORT"]
    worker_addrs = os.environ["WORKER_ADDRS"]
    worker_port = os.environ["WORKER_PORT"]
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])

    logger.info(
        "extract cluster info from env variables \n"
        f"master_addr: {master_addr} \n"
        f"master_port: {master_port} \n"
        f"worker_addrs: {worker_addrs} \n"
        f"worker_port: {worker_port} \n"
        f"world_size: {world_size} \n"
        f"rank: {rank} \n"
    )

    # Use the module-level logger consistently (the original mixed it with
    # root-logger logging.info calls).
    if args.job_type == "Predict":
        # Predict path is not implemented in this sample; only logs.
        logger.info("starting the predict job")

    elif args.job_type == "Train":
        logger.info("starting the train job")
        logger.info(f"extra args:\n {extra_args}")
        machine_list_filepath = generate_machine_list_file(
            master_addr, master_port, worker_addrs, worker_port
        )
        logger.info(f"machine list generated in: {machine_list_filepath}")
        # The master (rank 0) listens on master_port; workers on worker_port.
        local_port = worker_port if rank else master_port
        config_file = generate_train_conf_file(
            machine_list_file=machine_list_filepath,
            world_size=world_size,
            output_model="model.txt",
            local_port=local_port,
            extra_args=extra_args,
        )
        logger.info(f"config generated in: {config_file}")
        train(config_file)
    logger.info("Finish distributed job")
65 |
66 |
67 | if __name__ == "__main__":
68 | parser = argparse.ArgumentParser()
69 |
70 | parser.add_argument(
71 | "--job_type",
72 | help="Job type to execute",
73 | choices=["Train", "Predict"],
74 | required=True,
75 | )
76 |
77 | logging.basicConfig(format="%(message)s")
78 | logging.getLogger().setLevel(logging.INFO)
79 | args, extra_args = parser.parse_known_args()
80 | main(args, extra_args)
81 |
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/train.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 |
14 | import logging
15 | import subprocess
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
def train(train_config_filepath: str):
    """Run the LightGBM CLI with the given config file, streaming its output.

    Args:
        train_config_filepath: Path to a LightGBM training config file.

    Each stdout line of the subprocess is forwarded to the module logger.
    The original never waited on the child, leaking a zombie process; the
    context manager closes the pipe and reaps the process on exit.
    """
    cmd = ["lightgbm", f"config={train_config_filepath}"]
    with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
        for raw_line in proc.stdout:
            logger.info(raw_line.decode("utf-8").strip())
    # Best-effort behavior preserved: a failed run is logged, not raised.
    if proc.returncode != 0:
        logger.warning("lightgbm exited with non-zero status %s", proc.returncode)
27 |
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | import re
14 | import socket
15 | import logging
16 | import tempfile
17 | from time import sleep
18 | from typing import List, Union
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
def generate_machine_list_file(
    master_addr: str, master_port: str, worker_addrs: str, worker_port: str
) -> str:
    """Write a LightGBM machine-list file and return its path.

    The file has one ``<ip> <port>`` line per node: the master first, then
    every worker. Host names are resolved to addresses, retrying while pod
    DNS records may still be propagating.

    Args:
        master_addr: Host name of the master pod.
        master_port: Port the master listens on.
        worker_addrs: Comma-separated host names of the worker pods.
        worker_port: Port every worker listens on.

    Returns:
        Path of the generated temporary file.

    Raises:
        ValueError: If name resolution still fails after all retries.
    """
    logger.info("starting to extract system env")

    # Reserve a unique temp path; the context manager closes the handle
    # immediately (the original leaked the open file descriptor).
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        filename = tmp.name

    def _get_ips(
        master_addr_name,
        worker_addr_names,
        max_retries=10,
        sleep_secs=10,
        current_retry=0,
    ):
        # Resolve master and all workers, retrying via bounded recursion.
        # NOTE(review): "Name or service not known" is the glibc wording of
        # EAI_NONAME; other platforms may phrase the error differently.
        try:
            worker_addr_ips = []
            master_addr_ip = socket.gethostbyname(master_addr_name)

            for addr in worker_addr_names.split(","):
                worker_addr_ips.append(socket.gethostbyname(addr))

        except socket.gaierror as ex:
            if "Name or service not known" in str(ex) and current_retry < max_retries:
                sleep(sleep_secs)
                master_addr_ip, worker_addr_ips = _get_ips(
                    master_addr_name,
                    worker_addr_names,
                    max_retries=max_retries,
                    sleep_secs=sleep_secs,
                    current_retry=current_retry + 1,
                )
            else:
                raise ValueError("Couldn't get address names")

        return master_addr_ip, worker_addr_ips

    master_ip, worker_ips = _get_ips(master_addr, worker_addrs)

    with open(filename, "w") as file:
        print(f"{master_ip} {master_port}", file=file)
        for addr in worker_ips:
            print(f"{addr} {worker_port}", file=file)

    return filename
67 |
68 |
def generate_train_conf_file(
    machine_list_file: str,
    world_size: int,
    output_model: str,
    local_port: Union[int, str],
    extra_args: List[str],
) -> str:
    """Write a LightGBM training config file and return its path.

    The fixed distributed-training keys are emitted first; every extra CLI
    argument of the form ``--key=value`` is appended as ``key = value``
    (arguments not matching that shape are silently skipped).
    """
    conf_path = tempfile.NamedTemporaryFile(delete=False).name

    entries = [
        "task = train",
        f"output_model = {output_model}",
        f"num_machines = {world_size}",
        f"local_listen_port = {local_port}",
        f"machine_list_file = {machine_list_file}",
    ]

    arg_pattern = re.compile(r"--(.+)=([^\s]+)")
    for arg in extra_args:
        match = arg_pattern.match(arg)
        if match is not None:
            entries.append("{} = {}".format(*match.groups()))

    with open(conf_path, "w") as conf_file:
        conf_file.write("\n".join(entries) + "\n")

    return conf_path
92 |
--------------------------------------------------------------------------------
/config/samples/lightgbm-dist/xgboostjob_v1_lightgbm_dist_training.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "lightgbm-dist-train-test"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | spec:
12 | containers:
13 | - name: xgboostjob
14 | image: kubeflow/lightgbm-dist-py-test:1.0
15 | ports:
16 | - containerPort: 9991
17 | name: xgboostjob-port
18 | imagePullPolicy: Never
19 | args:
20 | - --job_type=Train
21 | - --boosting_type=gbdt
22 | - --objective=binary
23 | - --metric=binary_logloss,auc
24 | - --metric_freq=1
25 | - --is_training_metric=true
26 | - --max_bin=255
27 | - --data=data/binary.train
28 | - --valid_data=data/binary.test
29 | - --num_trees=100
          - --learning_rate=0.1
31 | - --num_leaves=63
32 | - --tree_learner=feature
33 | - --feature_fraction=0.8
34 | - --bagging_freq=5
35 | - --bagging_fraction=0.8
36 | - --min_data_in_leaf=50
37 | - --min_sum_hessian_in_leaf=50
38 | - --is_enable_sparse=true
39 | - --use_two_round_loading=false
40 | - --is_save_binary_file=false
41 | Worker:
42 | replicas: 2
43 | restartPolicy: ExitCode
44 | template:
45 | spec:
46 | containers:
47 | - name: xgboostjob
48 | image: kubeflow/lightgbm-dist-py-test:1.0
49 | ports:
50 | - containerPort: 9991
51 | name: xgboostjob-port
52 | imagePullPolicy: Never
53 | args:
54 | - --job_type=Train
55 | - --boosting_type=gbdt
56 | - --objective=binary
57 | - --metric=binary_logloss,auc
58 | - --metric_freq=1
59 | - --is_training_metric=true
60 | - --max_bin=255
61 | - --data=data/binary.train
62 | - --valid_data=data/binary.test
63 | - --num_trees=100
          - --learning_rate=0.1
65 | - --num_leaves=63
66 | - --tree_learner=feature
67 | - --feature_fraction=0.8
68 | - --bagging_freq=5
69 | - --bagging_fraction=0.8
70 | - --min_data_in_leaf=50
71 | - --min_sum_hessian_in_leaf=50
72 | - --is_enable_sparse=true
73 | - --use_two_round_loading=false
74 | - --is_save_binary_file=false
75 |
76 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/Dockerfile:
--------------------------------------------------------------------------------
# Install python 3.6
FROM python:3.6

# Run `apt-get update` and `install` in ONE layer: a standalone update layer
# can be cached and later paired with a stale package index.  Clean the
# apt lists afterwards to keep the image smaller.
RUN apt-get update && \
    apt-get install -y git make g++ cmake && \
    rm -rf /var/lib/apt/lists/*

RUN mkdir -p /opt/mlkube

# Download the rabit tracker and xgboost code.

COPY tracker.py /opt/mlkube/
COPY requirements.txt /opt/mlkube/

# Install requirements

RUN pip install -r /opt/mlkube/requirements.txt

# Build XGBoost.
RUN git clone --recursive https://github.com/dmlc/xgboost && \
    cd xgboost && \
    make -j$(nproc) && \
    cd python-package; python setup.py install

COPY xgboost_smoke_test.py /opt/mlkube/

ENTRYPOINT ["python", "/opt/mlkube/xgboost_smoke_test.py"]
27 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/README.md:
--------------------------------------------------------------------------------
1 | ### Distributed send/recv e2e test for xgboost rabit
2 |
3 | This folder contains the Dockerfile and a distributed send/recv test.
4 |
5 | **Build Image**
6 |
7 | The default image name and tag is `kubeflow/xgboost-dist-rabit-test:1.2`.
8 | You can build the image based on your requirement.
9 |
10 | ```shell
11 | docker build -f Dockerfile -t kubeflow/xgboost-dist-rabit-test:1.2 ./
12 | ```
13 |
14 | **Start and test XGBoost Rabit tracker**
15 |
16 | ```
17 | kubectl create -f xgboostjob_v1alpha1_rabit_test.yaml
18 | ```
19 |
20 | **Look at the job status**
21 | ```
22 | kubectl get -o yaml XGBoostJob/xgboost-dist-test
23 | ```
24 | Here is a sample of the output while the job is running:
25 | ```
26 | apiVersion: xgboostjob.kubeflow.org/v1alpha1
27 | kind: XGBoostJob
28 | metadata:
29 | creationTimestamp: "2019-06-21T03:32:57Z"
30 | generation: 7
31 | name: xgboost-dist-test
32 | namespace: default
33 | resourceVersion: "258466"
34 | selfLink: /apis/xgboostjob.kubeflow.org/v1alpha1/namespaces/default/xgboostjobs/xgboost-dist-test
35 | uid: 431dc182-93d5-11e9-bbab-080027dfbfe2
36 | spec:
37 | RunPolicy:
38 | cleanPodPolicy: None
39 | xgbReplicaSpecs:
40 | Master:
41 | replicas: 1
42 | restartPolicy: Never
43 | template:
44 | metadata:
45 | creationTimestamp: null
46 | spec:
47 | containers:
48 | - image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
49 | imagePullPolicy: Always
50 | name: xgboostjob
51 | ports:
52 | - containerPort: 9991
53 | name: xgboostjob-port
54 | resources: {}
55 | Worker:
56 | replicas: 2
57 | restartPolicy: Never
58 | template:
59 | metadata:
60 | creationTimestamp: null
61 | spec:
62 | containers:
63 | - image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
64 | imagePullPolicy: Always
65 | name: xgboostjob
66 | ports:
67 | - containerPort: 9991
68 | name: xgboostjob-port
69 | resources: {}
70 | status:
71 | completionTime: "2019-06-21T03:33:03Z"
72 | conditions:
73 | - lastTransitionTime: "2019-06-21T03:32:57Z"
74 | lastUpdateTime: "2019-06-21T03:32:57Z"
75 | message: xgboostJob xgboost-dist-test is created.
76 | reason: XGBoostJobCreated
77 | status: "True"
78 | type: Created
79 | - lastTransitionTime: "2019-06-21T03:32:57Z"
80 | lastUpdateTime: "2019-06-21T03:32:57Z"
81 | message: XGBoostJob xgboost-dist-test is running.
82 | reason: XGBoostJobRunning
83 | status: "False"
84 | type: Running
85 | - lastTransitionTime: "2019-06-21T03:33:03Z"
86 | lastUpdateTime: "2019-06-21T03:33:03Z"
87 | message: XGBoostJob xgboost-dist-test is successfully completed.
88 | reason: XGBoostJobSucceeded
89 | status: "True"
90 | type: Succeeded
91 | replicaStatuses:
92 | Master:
93 | succeeded: 1
94 | Worker:
95 | succeeded: 2
96 | ```
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.3
2 | Cython>=0.29.4
3 | requests>=2.21.0
4 | urllib3>=1.21.1
5 | scipy>=1.4.1
6 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/tracker.py:
--------------------------------------------------------------------------------
1 | """
2 | Tracker script for DMLC
3 | Implements the tracker control protocol
4 | - start dmlc jobs
5 | - start ps scheduler and rabit tracker
6 | - help nodes to establish links with each other
7 | Tianqi Chen
8 | --------------------------
9 | This was taken from
10 | https://github.com/dmlc/dmlc-core/blob/master/tracker/dmlc_tracker/tracker.py
11 | See LICENSE here
12 | https://github.com/dmlc/dmlc-core/blob/master/LICENSE
13 | No code modified or added except for this explanatory comment.
14 | """
15 | # pylint: disable=invalid-name, missing-docstring, too-many-arguments
16 | # pylint: disable=too-many-locals
17 | # pylint: disable=too-many-branches, too-many-statements
18 | from __future__ import absolute_import
19 |
20 | import os
21 | import sys
22 | import socket
23 | import struct
24 | import subprocess
25 | import argparse
26 | import time
27 | import logging
28 | from threading import Thread
29 |
30 |
class ExSocket(object):
    """
    Extension of socket to handle recv and send of special data
    """

    def __init__(self, sock):
        self.sock = sock

    def recvall(self, nbytes):
        # Keep reading until exactly `nbytes` bytes arrived; a single
        # recv() may return less than requested.
        buf = bytearray()
        while len(buf) < nbytes:
            buf.extend(self.sock.recv(min(nbytes - len(buf), 1024)))
        return bytes(buf)

    def recvint(self):
        # '@i' = native-endian 4-byte signed int.
        (value,) = struct.unpack('@i', self.recvall(4))
        return value

    def sendint(self, n):
        self.sock.sendall(struct.pack('@i', n))

    def sendstr(self, s):
        # Length-prefixed string: the prefix is the CHARACTER count, to
        # stay wire-compatible with the original protocol.
        self.sendint(len(s))
        self.sock.sendall(s.encode())

    def recvstr(self):
        return self.recvall(self.recvint()).decode()
60 |
61 |
62 | # magic number used to verify existence of data
63 | kMagic = 0xff99
64 |
65 |
def get_some_ip(host):
    # Resolve `host` and return the first address string reported by
    # getaddrinfo (sockaddr tuple position 0).
    family, kind, proto, canonname, sockaddr = socket.getaddrinfo(host, None)[0]
    return sockaddr[0]
68 |
69 |
def get_family(addr):
    # Address family (AF_INET / AF_INET6) of the first resolution result.
    first_result = socket.getaddrinfo(addr, None)[0]
    return first_result[0]
72 |
73 |
class SlaveEntry(object):
    """Tracker-side record of one connected worker ("slave").

    Construction performs the rabit handshake on the accepted socket:
    exchange magic numbers, then read the worker's requested rank, world
    size, job id and command string.
    """
    def __init__(self, sock, s_addr):
        slave = ExSocket(sock)
        self.sock = slave
        self.host = get_some_ip(s_addr[0])
        magic = slave.recvint()
        # Reject connections that do not speak the rabit protocol.
        assert magic == kMagic, 'invalid magic number=%d from %s' % (
            magic, self.host)
        slave.sendint(kMagic)
        self.rank = slave.recvint()
        self.world_size = slave.recvint()
        self.jobid = slave.recvstr()
        self.cmd = slave.recvstr()
        # Number of neighbour connections this worker is still waiting to
        # accept; populated by assign_rank().
        self.wait_accept = 0
        self.port = None

    def decide_rank(self, job_map):
        # Reuse an explicitly requested rank, or the rank previously mapped
        # to this job id; -1 means "assign me a fresh rank later".
        if self.rank >= 0:
            return self.rank
        if self.jobid != 'NULL' and self.jobid in job_map:
            return job_map[self.jobid]
        return -1

    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
        """Send the worker its rank plus tree/ring neighbour links, then
        negotiate which neighbours it must connect to versus wait for.

        Returns the list of ranks that can be removed from ``wait_conn``
        (they have accepted all expected connections).
        """
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, rnext = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        self.sock.sendint(len(nnset))
        # send the rprev and next link
        for r in nnset:
            self.sock.sendint(r)
        # send prev link
        if rprev != -1 and rprev != rank:
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link
        if rnext != -1 and rnext != rank:
            nnset.add(rnext)
            self.sock.sendint(rnext)
        else:
            self.sock.sendint(-1)
        # Loop until the worker reports zero connection errors (nerr == 0).
        while True:
            ngood = self.sock.recvint()
            goodset = set([])
            for _ in range(ngood):
                goodset.add(self.sock.recvint())
            assert goodset.issubset(nnset)
            badset = nnset - goodset
            conset = []
            # Neighbours already waiting for connections: tell this worker
            # where to reach them.
            for r in badset:
                if r in wait_conn:
                    conset.append(r)
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                self.sock.sendint(wait_conn[r].port)
                self.sock.sendint(r)
            nerr = self.sock.recvint()
            if nerr != 0:
                continue
            self.port = self.sock.recvint()
            rmset = []
            # all connections were successfully set up
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset
153 |
154 |
class RabitTracker(object):
    """
    tracker for rabit

    Binds a listening socket, computes the tree/ring topology for
    ``nslave`` workers and coordinates their connection setup on a
    background thread.
    """
    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Probe ports in [port, port_end) until one binds.
        for port in range(port, port_end):
            try:
                sock.bind((hostIP, port))
                self.port = port
                break
            except socket.error as e:
                # EADDRINUSE (98 on Linux, 48 on macOS): try the next port.
                if e.errno in [98, 48]:
                    continue
                else:
                    raise
        else:
            # Original code fell through silently, leaving self.port unset
            # and listening on an unbound socket; fail loudly instead.
            raise socket.error(
                'no free port found in range [%d, %d)' % (self.port if False else port, port_end)
                if False else 'no free port to bind the tracker')
        sock.listen(256)
        self.sock = sock
        self.hostIP = hostIP
        self.thread = None
        self.start_time = None
        self.end_time = None
        self.nslave = nslave
        logging.info('start listen on %s:%d', hostIP, self.port)

    def __del__(self):
        self.sock.close()

    @staticmethod
    def get_neighbor(rank, nslave):
        """Neighbours of `rank` in a 1-indexed binary heap: parent and the
        (up to two) children, returned as 0-indexed ranks."""
        rank = rank + 1
        ret = []
        if rank > 1:
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < nslave:
            ret.append(rank * 2 - 1)
        if rank * 2 < nslave:
            ret.append(rank * 2)
        return ret

    def slave_envs(self):
        """
        get enviroment variables for slaves
        can be passed in as args or envs
        """
        return {'DMLC_TRACKER_URI': self.hostIP,
                'DMLC_TRACKER_PORT': self.port}

    def get_tree(self, nslave):
        # Build the binary-tree topology: child -> parent map plus each
        # rank's neighbour list; rank 0 is the root (parent -1).
        tree_map = {}
        parent_map = {}
        for r in range(nslave):
            tree_map[r] = self.get_neighbor(r, nslave)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map

    def find_share_ring(self, tree_map, parent_map, r):
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from r
        """
        nset = set(tree_map[r])
        cset = nset - set([parent_map[r]])
        if len(cset) == 0:
            return [r]
        rlst = [r]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            # Reverse the last subtree's walk so the ring closes back
            # towards r.
            if cnt == len(cset):
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map, parent_map):
        """
        get a ring connection used to recover local data
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map = {}
        nslave = len(tree_map)
        for r in range(nslave):
            rprev = (r + nslave - 1) % nslave
            rnext = (r + 1) % nslave
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map

    def get_link_map(self, nslave):
        """
        get the link map, this is a bit hacky, call for better algorithm
        to place similar nodes together
        """
        tree_map, parent_map = self.get_tree(nslave)
        ring_map = self.get_ring(tree_map, parent_map)
        # Relabel ranks so that walking the ring gives consecutive numbers.
        rmap = {0: 0}
        k = 0
        for i in range(nslave - 1):
            k = ring_map[k][1]
            rmap[k] = i + 1

        ring_map_ = {}
        tree_map_ = {}
        parent_map_ = {}
        for k, v in ring_map.items():
            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
        for k, v in tree_map.items():
            tree_map_[rmap[k]] = [rmap[x] for x in v]
        for k, v in parent_map.items():
            if k != 0:
                parent_map_[rmap[k]] = rmap[v]
            else:
                parent_map_[rmap[k]] = -1
        return tree_map_, parent_map_, ring_map_

    def accept_slaves(self, nslave):
        """Accept worker connections until all `nslave` workers have sent
        their shutdown command; assigns ranks and brokers links."""
        # set of nodes that finishs the job
        shutdown = {}
        # set of nodes that is waiting for connections
        wait_conn = {}
        # maps job id to rank
        job_map = {}
        # list of workers that is pending to be assigned rank
        pending = []
        # lazy initialize tree_map
        tree_map = None

        while len(shutdown) != nslave:
            fd, s_addr = self.sock.accept()
            s = SlaveEntry(fd, s_addr)
            if s.cmd == 'print':
                msg = s.sock.recvstr()
                logging.info(msg.strip())
                continue
            if s.cmd == 'shutdown':
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
                continue
            assert s.cmd == 'start' or s.cmd == 'recover'
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
                if s.world_size > 0:
                    nslave = s.world_size
                tree_map, parent_map, ring_map = self.get_link_map(nslave)
                # set of nodes that is pending for getting up
                todo_nodes = list(range(nslave))
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0

            rank = s.decide_rank(job_map)
            # batch assignment of ranks
            if rank == -1:
                assert len(todo_nodes) != 0
                pending.append(s)
                # Wait until every fresh worker connected, then hand out
                # ranks in a deterministic (host-sorted) order.
                if len(pending) == len(todo_nodes):
                    pending.sort(key=lambda x: x.host)
                    for s in pending:
                        rank = todo_nodes.pop(0)
                        if s.jobid != 'NULL':
                            job_map[s.jobid] = rank
                        s.assign_rank(rank, wait_conn, tree_map, parent_map,
                                      ring_map)
                        if s.wait_accept > 0:
                            wait_conn[rank] = s
                        logging.debug('Recieve %s signal from %s; '
                                      'assign rank %d', s.cmd, s.host, s.rank)
                if len(todo_nodes) == 0:
                    logging.info('@tracker All of %d nodes getting started',
                                 nslave)
                    self.start_time = time.time()
            else:
                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
                if s.wait_accept > 0:
                    wait_conn[rank] = s

            logging.info("worker(ip_address=%s) connected!" % get_some_ip(s_addr[0]))

        logging.info('@tracker All nodes finishes job')
        self.end_time = time.time()
        logging.info('@tracker %s secs between node start and job finish',
                     str(self.end_time - self.start_time))

    def start(self, nslave):
        def run():
            self.accept_slaves(nslave)
        self.thread = Thread(target=run, args=())
        # Thread.setDaemon() is deprecated and isAlive() was removed in
        # Python 3.9; use the daemon attribute / is_alive() instead.
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        while self.thread.is_alive():
            self.thread.join(100)
355 |
356 |
class PSTracker(object):
    """
    Tracker module for PS

    When ``cmd`` is None the tracker is a no-op; otherwise it launches the
    PS scheduler command on a daemon thread with the DMLC scheduler
    environment set.
    """
    def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None):
        """
        Starts the PS scheduler
        """
        self.cmd = cmd
        if cmd is None:
            return
        envs = {} if envs is None else envs
        self.hostIP = hostIP
        # Find a free port: bind then immediately close, keeping the number
        # for the scheduler process (small TOCTOU race inherent here).
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        for port in range(port, port_end):
            try:
                sock.bind(('', port))
                self.port = port
                sock.close()
                break
            except socket.error:
                continue
        env = os.environ.copy()

        env['DMLC_ROLE'] = 'scheduler'
        env['DMLC_PS_ROOT_URI'] = str(self.hostIP)
        env['DMLC_PS_ROOT_PORT'] = str(self.port)
        for k, v in envs.items():
            env[k] = str(v)
        self.thread = Thread(
            target=(lambda: subprocess.check_call(self.cmd, env=env,
                                                  shell=True)), args=())
        # setDaemon()/isAlive() were removed in Python 3.9; use the modern
        # spellings (available since Python 2.6).
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        """Block until the scheduler thread exits (no-op when cmd is None)."""
        if self.cmd is not None:
            while self.thread.is_alive():
                self.thread.join(100)

    def slave_envs(self):
        # Environment the PS workers need to find the scheduler.
        if self.cmd is None:
            return {}
        else:
            return {'DMLC_PS_ROOT_URI': self.hostIP,
                    'DMLC_PS_ROOT_PORT': self.port}
403 |
404 |
def get_host_ip(hostIP=None):
    """Best-effort resolution of the tracker's host IP.

    :param hostIP: None/'auto' -> resolve via DNS lookup of this host;
        'dns' -> fully qualified hostname; 'ip' -> resolved address
        (falling back away from loopback); anything else is returned
        unchanged.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        from socket import gaierror
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except gaierror:
            # logging.warn is deprecated; warning() is the supported API.
            logging.warning('gethostbyname(socket.getfqdn()) failed... '
                            'trying on hostname()')
            hostIP = socket.gethostbyname(socket.gethostname())
        if hostIP.startswith("127."):
            # Loopback is useless to remote workers: discover the address
            # of the outbound interface via a UDP "connect" (no packet is
            # actually sent).
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            # doesn't have to be reachable
            s.connect(('10.255.255.255', 1))
            hostIP = s.getsockname()[0]
    return hostIP
425 |
426 |
def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None):
    """Start the appropriate tracker (rabit when there are no servers,
    otherwise a PS scheduler), hand the resulting environment to
    ``fun_submit``, then wait for the tracker to finish."""
    use_ps = nserver != 0
    if not use_ps:
        pscmd = None

    envs = {'DMLC_NUM_WORKER': nworker,
            'DMLC_NUM_SERVER': nserver}
    hostIP = get_host_ip(hostIP)

    if use_ps:
        tracker = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
        envs.update(tracker.slave_envs())
    else:
        tracker = RabitTracker(hostIP=hostIP, nslave=nworker)
        envs.update(tracker.slave_envs())
        tracker.start(nworker)

    fun_submit(nworker, nserver, envs)
    tracker.join()
448 |
449 |
def start_rabit_tracker(args):
    """Standalone function to start rabit tracker.
    Parameters
    ----------
    args: arguments to start the rabit tracker.
    """
    envs = {'DMLC_NUM_WORKER': args.num_workers,
            'DMLC_NUM_SERVER': args.num_servers}
    tracker = RabitTracker(hostIP=get_host_ip(args.host_ip),
                           nslave=args.num_workers)
    envs.update(tracker.slave_envs())
    tracker.start(args.num_workers)
    # Advertise the tracker configuration on stdout, delimited so that a
    # launcher process can scrape it.
    out = sys.stdout
    out.write('DMLC_TRACKER_ENV_START\n')
    for k, v in envs.items():
        out.write('%s=%s\n' % (k, str(v)))
    out.write('DMLC_TRACKER_ENV_END\n')
    out.flush()
    tracker.join()
469 |
470 |
def main():
    """Main function if tracker is executed in standalone mode."""
    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker proccess to be launched.')
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server process to be launched. Only '
                        'used in PS jobs.')
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP addressed, this is only needed ' +
                              'if the host IP cannot be automatically guessed.'
                              ))
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=['INFO', 'DEBUG'],
                        help='Logging level of the logger.')
    args = parser.parse_args()

    # Map the (argparse-restricted) level name to a logging constant.
    fmt = '%(asctime)s %(levelname)s %(message)s'
    levels = {'INFO': logging.INFO, 'DEBUG': logging.DEBUG}
    if args.log_level not in levels:
        raise RuntimeError("Unknown logging level %s" % args.log_level)

    logging.basicConfig(format=fmt, level=levels[args.log_level])

    if args.num_servers == 0:
        start_rabit_tracker(args)
    else:
        raise RuntimeError("Do not yet support start ps tracker in standalone "
                           "mode.")


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/config/samples/smoke-dist/xgboost_smoke_test.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2018 Google Inc. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import logging
17 | import os
18 | import xgboost as xgb
19 | import traceback
20 |
21 | from tracker import RabitTracker
22 |
logger = logging.getLogger(__name__)

def extract_xgbooost_cluster_env():
    """Read the rabit cluster layout from the pod environment.

    Returns (master_addr, master_port, rank, world_size); the numeric
    fields are parsed with int(), so missing variables (default "{}")
    raise ValueError rather than being silently defaulted.
    """
    logger.info("start to extract system env")

    env = os.environ
    master_addr = env.get("MASTER_ADDR", "{}")
    master_port = int(env.get("MASTER_PORT", "{}"))
    rank = int(env.get("RANK", "{}"))
    world_size = int(env.get("WORLD_SIZE", "{}"))

    logger.info("extract the rabit env from cluster : %s, port: %d, rank: %d, word_size: %d ",
                master_addr, master_port, rank, world_size)

    return master_addr, master_port, rank, world_size
38 |
def setup_rabit_cluster():
    """Smoke test: build the rabit network from the pod env, broadcast a
    small object from rank 0 and log the result on every node.

    Rank 0 additionally runs the RabitTracker in-process; all ranks then
    join the network via xgb.rabit.init().
    """
    addr, port, rank, world_size = extract_xgbooost_cluster_env()

    rabit_tracker = None
    try:
        """start to build the network"""
        if world_size > 1:
            if rank == 0:
                logger.info("start the master node")

                # port_end=port+1 pins the tracker to exactly `port`, which
                # the workers were told about via MASTER_PORT.
                rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
                                     port=port, port_end=port + 1)
                rabit.start(world_size)
                rabit_tracker = rabit
                logger.info('########### RabitTracker Setup Finished #########')

            envs = [
                'DMLC_NUM_WORKER=%d' % world_size,
                'DMLC_TRACKER_URI=%s' % addr,
                'DMLC_TRACKER_PORT=%d' % port,
                'DMLC_TASK_ID=%d' % rank
            ]
            logger.info('##### Rabit rank setup with below envs #####')
            # rabit.init expects byte strings, so encode each entry.
            for i, env in enumerate(envs):
                logger.info(env)
                envs[i] = str.encode(env)

            xgb.rabit.init(envs)
            logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank())

            # Use the rank rabit actually assigned (may differ from the
            # RANK env var after tracker-side assignment).
            rank = xgb.rabit.get_rank()
            s = None
            if rank == 0:
                s = {'hello world': 100, 2: 3}

            logger.info('@node[%d] before-broadcast: s=\"%s\"' % (rank, str(s)))
            s = xgb.rabit.broadcast(s, 0)

            logger.info('@node[%d] after-broadcast: s=\"%s\"' % (rank, str(s)))

    except Exception as e:
        logger.error("something wrong happen: %s", traceback.format_exc())
        raise e
    finally:
        # NOTE(review): if the exception happened before xgb.rabit.init
        # succeeded, finalize() here may itself error — confirm against the
        # xgboost rabit API before relying on this path.
        if world_size > 1:
            xgb.rabit.finalize()
        if rabit_tracker:
            rabit_tracker.join()

    logger.info("the rabit network testing finished!")
89 |
def main():
    """Log the rabit-related environment variables, then run the
    distributed send/recv smoke test."""
    master_port = os.environ.get("MASTER_PORT", "{}")
    logging.info("MASTER_PORT: %s", master_port)

    master_addr = os.environ.get("MASTER_ADDR", "{}")
    logging.info("MASTER_ADDR: %s", master_addr)

    size = os.environ.get("WORLD_SIZE", "{}")
    logging.info("WORLD_SIZE: %s", size)

    node_rank = os.environ.get("RANK", "{}")
    logging.info("RANK: %s", node_rank)

    setup_rabit_cluster()

if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    main()
109 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/xgboostjob_v1_rabit_test.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-test"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | spec:
12 | containers:
13 | - name: xgboostjob
14 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
15 | ports:
16 | - containerPort: 9991
17 | name: xgboostjob-port
18 | imagePullPolicy: Always
19 | Worker:
20 | replicas: 2
21 | restartPolicy: Never
22 | template:
23 | spec:
24 | containers:
25 | - name: xgboostjob
26 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
27 | ports:
28 | - containerPort: 9991
29 | name: xgboostjob-port
30 | imagePullPolicy: Always
31 |
32 |
--------------------------------------------------------------------------------
/config/samples/smoke-dist/xgboostjob_v1alpha1_rabit_test.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-test"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | apiVersion: v1
12 | kind: Pod
13 | spec:
14 | containers:
15 | - name: xgboostjob
16 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
17 | ports:
18 | - containerPort: 9991
19 | name: xgboostjob-port
20 | imagePullPolicy: Always
21 | Worker:
22 | replicas: 2
23 | restartPolicy: Never
24 | template:
25 | apiVersion: v1
26 | kind: Pod
27 | spec:
28 | containers:
29 | - name: xgboostjob
30 | image: docker.io/merlintang/xgboost-dist-rabit-test:1.2
31 | ports:
32 | - containerPort: 9991
33 | name: xgboostjob-port
34 | imagePullPolicy: Always
35 |
36 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/Dockerfile:
--------------------------------------------------------------------------------
# Install python 3.6
FROM python:3.6

# Run `apt-get update` and `install` in ONE layer: a standalone update layer
# can be cached and later paired with a stale package index.  Clean the
# apt lists afterwards to keep the image smaller.
RUN apt-get update && \
    apt-get install -y git make g++ cmake && \
    rm -rf /var/lib/apt/lists/*

RUN mkdir -p /opt/mlkube

# Download the rabit tracker and xgboost code.

COPY requirements.txt /opt/mlkube/

# Install requirements

RUN pip install -r /opt/mlkube/requirements.txt

# Build XGBoost.
RUN git clone --recursive https://github.com/dmlc/xgboost && \
    cd xgboost && \
    make -j$(nproc) && \
    cd python-package; python setup.py install

COPY *.py /opt/mlkube/

ENTRYPOINT ["python", "/opt/mlkube/main.py"]
26 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/build.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Build and push the xgboost-dist sample image, then launch the train and
# predict jobs.  `set -e` stops at the first failing step instead of
# pushing/submitting with a stale or missing image.
set -e

## build the docker file
docker build -f Dockerfile -t merlintang/xgboost-dist-iris:1.0 ./

## push the docker image into docker.io
docker push merlintang/xgboost-dist-iris:1.0

## run the train job
kubectl create -f xgboostjob_v1alpha1_iris_train.yaml

## run the predict job
kubectl create -f xgboostjob_v1alpha1_iris_predict.yaml
12 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/local_test.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | """
14 | this file contains tests for xgboost local train and predict in single machine.
15 | Note: this is not for distributed train and predict test
16 | """
17 | from utils import dump_model, read_model, read_train_data, read_predict_data
18 | import xgboost as xgb
19 | import logging
20 | import numpy as np
21 | from sklearn.metrics import precision_score
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
def test_train_model():
    """
    test xgboost train in a single machine
    :return: trained model
    """
    rank, world_size, place = 1, 10, "/tmp/data"
    dtrain = read_train_data(rank, world_size, place)

    params = {'max_depth': 2, 'eta': 1, 'silent': 1,
              'objective': 'multi:softprob', 'num_class': 3}

    # A trained booster must come back; None means training never ran.
    booster = xgb.train(params, dtrain=dtrain)
    assert booster is not None

    return booster
44 |
45 |
def test_model_predict(booster):
    """
    Test single-node prediction with a trained booster.

    :param booster: trained XGBoost booster
    :return: True when the macro precision clears the 0.99 threshold
    """
    rank = 1
    world_size = 10
    place = "/tmp/data"
    dmatrix, y_test = read_predict_data(rank, world_size, place)

    preds = booster.predict(dmatrix)
    best_preds = np.asarray([np.argmax(line) for line in preds])
    score = precision_score(y_test, best_preds, average='macro')

    # Log before asserting so the measured score is visible even when the
    # threshold check fails (previously the log line was unreachable on
    # failure because it came after the assert).
    logging.info("Predict accuracy: %f", score)

    assert score > 0.99

    return True
65 |
66 |
def test_upload_model(model, model_path, args):
    """Dump the trained model to local storage and return the result."""
    stored = dump_model(model, type="local", model_path=model_path, args=args)
    return stored
70 |
71 |
def test_download_model(model_path, args):
    """Load a previously dumped model back from local storage."""
    restored = read_model(type="local", model_path=model_path, args=args)
    return restored
75 |
76 |
def run_test():
    """End-to-end local test: train, dump, reload, then predict."""
    args = {}
    model_path = "/tmp/xgboost"

    logging.info("Start the local test")

    trained = test_train_model()
    test_upload_model(trained, model_path, args)
    reloaded = test_download_model(model_path, args)
    test_model_predict(reloaded)

    logging.info("Finish the local test")
89 |
90 |
if __name__ == '__main__':

    # Bare message format plus INFO level so the test progress is printed.
    logging.basicConfig(format='%(message)s')
    logging.getLogger().setLevel(logging.INFO)

    run_test()
97 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/main.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | import argparse
14 | import logging
15 |
16 | from train import train
17 | from predict import predict
18 | from utils import dump_model
19 |
20 |
def main(args):
    """
    Dispatch the XGBoost job based on ``args.job_type``.

    :param args: parsed CLI arguments; must provide job_type,
        model_storage_type, model_path and (for OSS) oss_param
    :raises Exception: when the storage type is not "local" or "oss"
    """
    model_storage_type = args.model_storage_type
    if model_storage_type in ("local", "oss"):
        # Use logging instead of print for consistency with the rest
        # of the entry point.
        logging.info("The storage type is %s", model_storage_type)
    else:
        raise Exception("Only supports storage types like local and OSS")

    def _run_train():
        # Train and, on the master (the only node returning a model),
        # persist the result to the configured storage.
        model = train(args)
        if model is not None:
            logging.info("finish the model training, and start to dump model ")
            dump_model(model, model_storage_type, args.model_path, args)

    if args.job_type == "Predict":
        logging.info("starting the predict job")
        predict(args)

    elif args.job_type == "Train":
        logging.info("starting the train job")
        _run_train()

    elif args.job_type == "All":
        # BUG FIX: this branch previously only logged and did nothing.
        # Now it runs the full pipeline: train (with dump) then predict.
        logging.info("starting the train and predict job")
        _run_train()
        predict(args)

    logging.info("Finish distributed XGBoost job")
46 |
47 |
if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--job_type',
        help="Train, Predict, All",
        required=True
    )
    parser.add_argument(
        '--xgboost_parameter',
        help='XGBoost model parameter like: objective, number_class',
    )
    parser.add_argument(
        '--n_estimators',
        help='Number of trees in the model',
        type=int,
        default=1000
    )
    parser.add_argument(
        '--learning_rate',
        help='Learning rate for the model',
        # type=float keeps the parsed value consistent with the default:
        # without it a CLI-supplied value arrives as str while the
        # default is float.
        type=float,
        default=0.1
    )
    parser.add_argument(
        '--early_stopping_rounds',
        help='XGBoost argument for stopping early',
        # Same consistency fix as --learning_rate: parse as int.
        type=int,
        default=50
    )
    parser.add_argument(
        '--model_path',
        help='place to store model',
        default="/tmp/xgboost_model"
    )
    parser.add_argument(
        '--model_storage_type',
        help='place to store the model',
        default="oss"
    )
    parser.add_argument(
        '--oss_param',
        help='oss parameter if you choose the model storage as OSS type',
    )

    logging.basicConfig(format='%(message)s')
    logging.getLogger().setLevel(logging.INFO)
    main_args = parser.parse_args()
    main(main_args)
95 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/predict.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | from sklearn.metrics import precision_score
14 |
15 | import logging
16 | import numpy as np
17 |
18 | from utils import extract_xgbooost_cluster_env, read_predict_data, read_model
19 |
20 |
def predict(args):
    """
    Run batch prediction over this worker's slice of the iris data and
    log the macro precision of the restored model.

    :param args: parameters holding model_path and model_storage_type
    """
    _, _, rank, world_size = extract_xgbooost_cluster_env()

    dmatrix, y_test = read_predict_data(rank, world_size, None)

    booster = read_model(args.model_storage_type, args.model_path, args)

    raw_preds = booster.predict(dmatrix)
    best_preds = np.asarray([np.argmax(row) for row in raw_preds])
    score = precision_score(y_test, best_preds, average='macro')

    logging.info("Predict accuracy: %f", score)
41 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.16.3
2 | Cython>=0.29.4
3 | requests>=2.21.0
4 | urllib3>=1.21.1
5 | scipy>=1.1.0
6 | joblib>=0.13.2
7 | scikit-learn>=0.20
8 | oss2>=2.7.0
9 | pandas>=0.24.2
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/train.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 |
14 | import logging
15 | import xgboost as xgb
16 | import traceback
17 |
18 | from tracker import RabitTracker
19 | from utils import read_train_data, extract_xgbooost_cluster_env
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
def train(args):
    """
    Train an XGBoost model, using Rabit for coordination when the
    cluster has more than one worker.

    :param args: configuration for train job (uses args.n_estimators)
    :return: XGBoost model on the node with rank 0, None on other nodes
    """
    # Cluster topology injected into the pod by the operator.
    addr, port, rank, world_size = extract_xgbooost_cluster_env()
    rabit_tracker = None

    try:
        """start to build the network"""
        if world_size > 1:
            if rank == 0:
                logger.info("start the master node")

                # Rank 0 hosts the RabitTracker that coordinates workers.
                rabit = RabitTracker(hostIP="0.0.0.0", nslave=world_size,
                                     port=port, port_end=port + 1)
                rabit.start(world_size)
                rabit_tracker = rabit
                logger.info('###### RabitTracker Setup Finished ######')

            envs = [
                'DMLC_NUM_WORKER=%d' % world_size,
                'DMLC_TRACKER_URI=%s' % addr,
                'DMLC_TRACKER_PORT=%d' % port,
                'DMLC_TASK_ID=%d' % rank
            ]
            logger.info('##### Rabit rank setup with below envs #####')
            for i, env in enumerate(envs):
                logger.info(env)
                # rabit.init expects the env entries as bytes, not str.
                envs[i] = str.encode(env)

            xgb.rabit.init(envs)
            logger.info('##### Rabit rank = %d' % xgb.rabit.get_rank())
            # From here on use the rank assigned by Rabit.
            rank = xgb.rabit.get_rank()

        else:
            world_size = 1
            logging.info("Start the train in a single node")

        df = read_train_data(rank=rank, num_workers=world_size, path=None)
        kwargs = {}
        kwargs["dtrain"] = df
        kwargs["num_boost_round"] = int(args.n_estimators)
        param_xgboost_default = {'max_depth': 2, 'eta': 1, 'silent': 1,
                                 'objective': 'multi:softprob', 'num_class': 3}
        kwargs["params"] = param_xgboost_default

        logging.info("starting to train xgboost at node with rank %d", rank)
        bst = xgb.train(**kwargs)

        # Only the master node keeps the model; workers return None so
        # callers dump the model exactly once.
        if rank == 0:
            model = bst
        else:
            model = None

        logging.info("finish xgboost training at node with rank %d", rank)

    except Exception as e:
        logger.error("something wrong happen: %s", traceback.format_exc())
        raise e
    finally:
        logger.info("xgboost training job finished!")
        # Tear down the Rabit network and wait for the tracker thread.
        if world_size > 1:
            xgb.rabit.finalize()
            if rabit_tracker:
                rabit_tracker.join()

    return model
92 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/utils.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 | import logging
14 | import joblib
15 | import xgboost as xgb
16 | import os
17 | import tempfile
18 | import oss2
19 | import json
20 | import pandas as pd
21 |
22 | from sklearn import datasets
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
def extract_xgbooost_cluster_env():
    """
    Extract the Rabit cluster configuration from the pod environment.

    Reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE — assumed to be
    injected into the pod by the operator — and converts the numeric
    ones to int.

    :return: tuple (master_addr, master_port, rank, world_size)
    :raises RuntimeError: when a required variable is missing or not an
        integer (previously a missing variable crashed with a cryptic
        ``int("{}")`` ValueError from the placeholder defaults)
    """
    log = logging.getLogger(__name__)
    log.info("starting to extract system env")

    master_addr = os.environ.get("MASTER_ADDR")
    if master_addr is None:
        raise RuntimeError("environment variable MASTER_ADDR is not set")

    numeric = {}
    for name in ("MASTER_PORT", "RANK", "WORLD_SIZE"):
        raw = os.environ.get(name)
        if raw is None:
            raise RuntimeError("environment variable %s is not set" % name)
        try:
            numeric[name] = int(raw)
        except ValueError:
            raise RuntimeError("environment variable %s=%r is not an integer"
                               % (name, raw))

    master_port = numeric["MASTER_PORT"]
    rank = numeric["RANK"]
    world_size = numeric["WORLD_SIZE"]

    log.info("extract the Rabit env from cluster :"
             " %s, port: %d, rank: %d, word_size: %d ",
             master_addr, master_port, rank, world_size)

    return master_addr, master_port, rank, world_size
45 |
46 |
def read_train_data(rank, num_workers, path):
    """
    Load this worker's slice of the training data.

    We use the sklearn iris data for demonstration; extend this to read
    a distributed source such as HDFS or HIVE if needed.
    :param rank: the id of each worker
    :param num_workers: total number of workers in this cluster
    :param path: the input file name or the place to read the data
    :return: XGBoost DMatrix
    """
    iris = datasets.load_iris()

    start, end = get_range_data(len(iris.data), rank, num_workers)
    features = pd.DataFrame(iris.data[start:end, :])
    labels = pd.DataFrame(iris.target[start:end])

    dtrain = xgb.DMatrix(data=features, label=labels)

    logging.info("Read data from IRIS data source with range from %d to %d",
                 start, end)

    return dtrain
73 |
74 |
def read_predict_data(rank, num_workers, path):
    """
    Load this worker's slice of the data for prediction.

    We use the sklearn iris data for demonstration; extend this to read
    a distributed source such as HDFS or HIVE if needed.
    :param rank: the id of each worker
    :param num_workers: total number of workers in this cluster
    :param path: the input file name or the place to read the data
    :return: XGBoost DMatrix and the true labels
    """
    iris = datasets.load_iris()

    start, end = get_range_data(len(iris.data), rank, num_workers)
    features = pd.DataFrame(iris.data[start:end, :])
    labels = pd.DataFrame(iris.target[start:end])

    logging.info("Read data from IRIS datasource with range from %d to %d",
                 start, end)

    dpredict = xgb.DMatrix(features, label=labels)

    return dpredict, labels
101 |
102 |
def get_range_data(num_row, rank, num_workers):
    """
    Compute the half-open data range [start, end) owned by one worker.

    Rows are split evenly by integer division; the last worker
    additionally absorbs the remainder rows so no data is silently
    dropped (previously up to num_workers - 1 trailing rows were never
    assigned to any worker).

    :param num_row: total number of rows in the dataset
    :param rank: the worker id
    :param num_workers: total number of workers
    :return: begin and end range of input matrix
    """
    num_per_partition = int(num_row / num_workers)

    x_start = rank * num_per_partition
    x_end = (rank + 1) * num_per_partition

    # The last worker takes everything up to the end of the dataset.
    if rank == num_workers - 1:
        x_end = num_row

    if x_end > num_row:
        x_end = num_row

    return x_start, x_end
120 |
121 |
def dump_model(model, type, model_path, args):
    """
    Persist the trained model to the configured storage backend.

    You can extend this to store the model in other remote places.
    :param model: the xgboost trained booster
    :param type: model storage type ("local" or "oss")
    :param model_path: place to store model
    :param args: configuration for model storage
    :return: True if the dump process succeeds
    :raises Exception: when model is None or the OSS config is missing
    """
    if model is None:
        raise Exception("fail to get the XGBoost train model")

    if type == "local":
        joblib.dump(model, model_path)
        logging.info("Dump model into local place %s", model_path)
    elif type == "oss":
        oss_param = parse_parameters(args.oss_param, ",", ":")
        if oss_param is None:
            raise Exception("Please config oss parameter to store model")

        oss_param['path'] = args.model_path
        dump_model_to_oss(oss_param, model)
        logging.info("Dump model into oss place %s", args.model_path)

    return True
149 |
150 |
def read_model(type, model_path, args):
    """
    Read a model back from physical storage.

    :param type: "oss" or "local"
    :param model_path: place the model is stored
    :param args: configuration to read model (oss_param for OSS)
    :return: XGBoost model
    :raises Exception: when the OSS config is missing or the storage
        type is unknown
    """
    if type == "local":
        model = joblib.load(model_path)
        logging.info("Read model from local place %s", model_path)

    elif type == "oss":
        oss_param = parse_parameters(args.oss_param, ",", ":")
        if oss_param is None:
            raise Exception("Please config oss to read model")
        # (The old unreachable `return False` after this raise was removed.)

        oss_param['path'] = args.model_path

        model = read_model_from_oss(oss_param)
        logging.info("read model from oss place %s", model_path)

    else:
        # Fail fast with a clear message instead of falling through to a
        # NameError on the return below.
        raise Exception("unknown model storage type: %s" % type)

    return model
176 |
177 |
def dump_model_to_oss(oss_parameters, booster):
    """
    Dump the model and its metadata to remote OSS storage.

    Writes three artifacts to temp directories — the binary model, a
    text dump of the trees, and a JSON file of feature importances —
    then uploads the model to ``path`` and all three files to a
    ``path_dir/`` folder next to it.

    :param oss_parameters: oss configuration (must contain 'path')
    :param booster: XGBoost model
    :return: True if the stored procedure succeeds
    :raises Exception: when the model file was not generated
    """
    """export model into oss"""
    model_fname = os.path.join(tempfile.mkdtemp(), 'model')
    text_model_fname = os.path.join(tempfile.mkdtemp(), 'model.text')
    feature_importance = os.path.join(tempfile.mkdtemp(),
                                      'feature_importance.json')

    oss_path = oss_parameters['path']
    logger.info('---- export model ----')
    booster.save_model(model_fname)
    booster.dump_model(text_model_fname)  # format output model
    fscore_dict = booster.get_fscore()
    with open(feature_importance, 'w') as file:
        file.write(json.dumps(fscore_dict))
    logger.info('---- chief dump model successfully!')

    if not os.path.exists(model_fname):
        # (The old `return False` after this raise was unreachable dead
        # code and has been removed.)
        raise Exception("fail to generate model")

    logger.info('---- Upload Model start...')

    # Normalize the OSS key: strip any trailing slashes.
    while oss_path[-1] == '/':
        oss_path = oss_path[:-1]

    upload_oss(oss_parameters, model_fname, oss_path)
    aux_path = oss_path + '_dir/'
    upload_oss(oss_parameters, model_fname, aux_path)
    upload_oss(oss_parameters, text_model_fname, aux_path)
    upload_oss(oss_parameters, feature_importance, aux_path)

    return True
216 |
217 |
def upload_oss(kw, local_file, oss_path):
    """
    Upload a local file to OSS.

    :param kw: OSS parameters (access_id, access_key, access_bucket,
        endpoint)
    :param local_file: local place of the file to upload
    :param oss_path: destination key; a trailing '/' is treated as a
        directory, so the local basename is appended
    :raises ValueError: when the upload fails
    """
    if oss_path[-1] == '/':
        oss_path = '%s%s' % (oss_path, os.path.basename(local_file))

    auth = oss2.Auth(kw['access_id'], kw['access_key'])
    bucket = kw['access_bucket']
    bkt = oss2.Bucket(auth=auth, endpoint=kw['endpoint'], bucket_name=bucket)

    try:
        bkt.put_object_from_file(key=oss_path, filename=local_file)
        logger.info("upload %s to %s successfully!" %
                    (os.path.abspath(local_file), oss_path))
    except Exception:
        # BUG FIX: the original `except Exception():` tried to catch an
        # *instance*, which is invalid in Python 3 and raised TypeError
        # instead of the intended ValueError on upload failure.
        raise ValueError('upload %s to %s failed' %
                         (os.path.abspath(local_file), oss_path))
240 |
241 |
def read_model_from_oss(kw):
    """
    Read a model back from OSS into an XGBoost booster.

    :param kw: OSS parameters (access_id, access_key, access_bucket,
        endpoint, path)
    :return: XGBoost booster model
    :raises Exception: when the model cannot be downloaded
    """
    auth = oss2.Auth(kw['access_id'], kw['access_key'])
    bucket = kw['access_bucket']
    bkt = oss2.Bucket(auth=auth, endpoint=kw['endpoint'], bucket_name=bucket)
    oss_path = kw["path"]

    temp_model_fname = os.path.join(tempfile.mkdtemp(), 'local_model')
    try:
        bkt.get_object_to_file(key=oss_path, filename=temp_model_fname)
        logger.info("success to load model from oss %s", oss_path)
    except Exception as e:
        # BUG FIX: the original concatenated a str with the exception
        # object ("..." + e), which itself raised TypeError and masked
        # the real error; use lazy %s formatting instead. The re-raise
        # also never formatted its message (comma instead of %).
        logging.error("fail to load model: %s", e)
        raise Exception("fail to load model from oss %s" % oss_path)

    bst = xgb.Booster({'nthread': 2})  # init model

    bst.load_model(temp_model_fname)

    return bst
266 |
267 |
def parse_parameters(input, splitter_between, splitter_in):
    """
    Parse a configuration string of key-value pairs into a dict.

    :param input: the string of configuration like key-value pairs
    :param splitter_between: the splitter between config for input string
    :param splitter_in: the splitter inside config for input string
    :return: key-value pair configuration
    """

    ky_pairs = input.split(splitter_between)

    confs = {}

    for kv in ky_pairs:
        conf = kv.split(splitter_in)
        key = conf[0].strip(" ")
        if key == "objective" or key == "endpoint":
            # These values legitimately contain the inner splitter
            # (e.g. "multi:softprob", "http://host:port"). Rejoin every
            # remaining part instead of keeping only the first two,
            # which silently dropped e.g. an endpoint's port.
            value = splitter_in.join(part.strip("'") for part in conf[1:])
        else:
            value = conf[1]

        confs[key] = value
    return confs
291 |
292 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1_iris_predict.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-predict"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | spec:
12 | containers:
13 | - name: xgboostjob
14 | image: docker.io/merlintang/xgboost-dist-iris:1.1
15 | ports:
16 | - containerPort: 9991
17 | name: xgboostjob-port
18 | imagePullPolicy: Always
19 | args:
20 | - --job_type=Predict
21 | - --model_path=autoAI/xgb-opt/2
22 | - --model_storage_type=oss
23 | - --oss_param=unknown
24 | Worker:
25 | replicas: 2
26 | restartPolicy: ExitCode
27 | template:
28 | spec:
29 | containers:
30 | - name: xgboostjob
31 | image: docker.io/merlintang/xgboost-dist-iris:1.1
32 | ports:
33 | - containerPort: 9991
34 | name: xgboostjob-port
35 | imagePullPolicy: Always
36 | args:
37 | - --job_type=Predict
38 | - --model_path=autoAI/xgb-opt/2
39 | - --model_storage_type=oss
40 | - --oss_param=unknown
41 |
42 |
43 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1_iris_train.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-train"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | spec:
12 | containers:
13 | - name: xgboostjob
14 | image: docker.io/merlintang/xgboost-dist-iris:1.1
15 | ports:
16 | - containerPort: 9991
17 | name: xgboostjob-port
18 | imagePullPolicy: Always
19 | args:
20 | - --job_type=Train
21 | - --xgboost_parameter=objective:multi:softprob,num_class:3
22 | - --n_estimators=10
23 | - --learning_rate=0.1
24 | - --model_path=/tmp/xgboost-model
25 | - --model_storage_type=local
26 | Worker:
27 | replicas: 2
28 | restartPolicy: ExitCode
29 | template:
30 | spec:
31 | containers:
32 | - name: xgboostjob
33 | image: docker.io/merlintang/xgboost-dist-iris:1.1
34 | ports:
35 | - containerPort: 9991
36 | name: xgboostjob-port
37 | imagePullPolicy: Always
38 | args:
39 | - --job_type=Train
40 | - --xgboost_parameter="objective:multi:softprob,num_class:3"
41 | - --n_estimators=10
42 | - --learning_rate=0.1
43 |
44 |
45 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_predict.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-predict"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | apiVersion: v1
12 | kind: Pod
13 | spec:
14 | containers:
15 | - name: xgboostjob
16 | image: docker.io/merlintang/xgboost-dist-iris:1.1
17 | ports:
18 | - containerPort: 9991
19 | name: xgboostjob-port
20 | imagePullPolicy: Always
21 | args:
22 | - --job_type=Predict
23 | - --model_path=autoAI/xgb-opt/2
24 | - --model_storage_type=oss
25 | - --oss_param=unknown
26 | Worker:
27 | replicas: 2
28 | restartPolicy: ExitCode
29 | template:
30 | apiVersion: v1
31 | kind: Pod
32 | spec:
33 | containers:
34 | - name: xgboostjob
35 | image: docker.io/merlintang/xgboost-dist-iris:1.1
36 | ports:
37 | - containerPort: 9991
38 | name: xgboostjob-port
39 | imagePullPolicy: Always
40 | args:
41 | - --job_type=Predict
42 | - --model_path=autoAI/xgb-opt/2
43 | - --model_storage_type=oss
44 | - --oss_param=unknown
45 |
46 |
47 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_predict_local.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-predict-local"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | apiVersion: v1
12 | kind: Pod
13 | spec:
14 | volumes:
15 | - name: task-pv-storage
16 | persistentVolumeClaim:
17 | claimName: xgboostlocal
18 | containers:
19 | - name: xgboostjob
20 | image: docker.io/merlintang/xgboost-dist-iris:1.1
21 | volumeMounts:
22 | - name: task-pv-storage
23 | mountPath: /tmp/xgboost_model
24 | ports:
25 | - containerPort: 9991
26 | name: xgboostjob-port
27 | imagePullPolicy: Always
28 | args:
29 | - --job_type=Predict
30 | - --model_path=/tmp/xgboost_model/2
31 | - --model_storage_type=local
32 | Worker:
33 | replicas: 2
34 | restartPolicy: ExitCode
35 | template:
36 | apiVersion: v1
37 | kind: Pod
38 | spec:
39 | volumes:
40 | - name: task-pv-storage
41 | persistentVolumeClaim:
42 | claimName: xgboostlocal
43 | containers:
44 | - name: xgboostjob
45 | image: docker.io/merlintang/xgboost-dist-iris:1.1
46 | volumeMounts:
47 | - name: task-pv-storage
48 | mountPath: /tmp/xgboost_model
49 | ports:
50 | - containerPort: 9991
51 | name: xgboostjob-port
52 | imagePullPolicy: Always
53 | args:
54 | - --job_type=Predict
55 | - --model_path=/tmp/xgboost_model/2
56 | - --model_storage_type=local
57 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_train.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-train"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | apiVersion: v1
12 | kind: Pod
13 | spec:
14 | containers:
15 | - name: xgboostjob
16 | image: docker.io/merlintang/xgboost-dist-iris:1.1
17 | ports:
18 | - containerPort: 9991
19 | name: xgboostjob-port
20 | imagePullPolicy: Always
21 | args:
22 | - --job_type=Train
23 | - --xgboost_parameter=objective:multi:softprob,num_class:3
24 | - --n_estimators=10
25 | - --learning_rate=0.1
26 | - --model_path=autoAI/xgb-opt/2
27 | - --model_storage_type=oss
28 | - --oss_param=unknown
29 | Worker:
30 | replicas: 2
31 | restartPolicy: ExitCode
32 | template:
33 | apiVersion: v1
34 | kind: Pod
35 | spec:
36 | containers:
37 | - name: xgboostjob
38 | image: docker.io/merlintang/xgboost-dist-iris:1.1
39 | ports:
40 | - containerPort: 9991
41 | name: xgboostjob-port
42 | imagePullPolicy: Always
43 | args:
44 | - --job_type=Train
45 | - --xgboost_parameter="objective:multi:softprob,num_class:3"
46 | - --n_estimators=10
47 | - --learning_rate=0.1
48 |
49 |
50 |
--------------------------------------------------------------------------------
/config/samples/xgboost-dist/xgboostjob_v1alpha1_iris_train_local.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: "xgboostjob.kubeflow.org/v1alpha1"
2 | kind: "XGBoostJob"
3 | metadata:
4 | name: "xgboost-dist-iris-test-train-local"
5 | spec:
6 | xgbReplicaSpecs:
7 | Master:
8 | replicas: 1
9 | restartPolicy: Never
10 | template:
11 | apiVersion: v1
12 | kind: Pod
13 | spec:
14 | volumes:
15 | - name: task-pv-storage
16 | persistentVolumeClaim:
17 | claimName: xgboostlocal
18 | containers:
19 | - name: xgboostjob
20 | image: docker.io/merlintang/xgboost-dist-iris:1.1
21 | volumeMounts:
22 | - name: task-pv-storage
23 | mountPath: /tmp/xgboost_model
24 | ports:
25 | - containerPort: 9991
26 | name: xgboostjob-port
27 | imagePullPolicy: Always
28 | args:
29 | - --job_type=Train
30 | - --xgboost_parameter=objective:multi:softprob,num_class:3
31 | - --n_estimators=10
32 | - --learning_rate=0.1
33 | - --model_path=/tmp/xgboost_model/2
34 | - --model_storage_type=local
35 | Worker:
36 | replicas: 2
37 | restartPolicy: ExitCode
38 | template:
39 | apiVersion: v1
40 | kind: Pod
41 | spec:
42 | volumes:
43 | - name: task-pv-storage
44 | persistentVolumeClaim:
45 | claimName: xgboostlocal
46 | containers:
47 | - name: xgboostjob
48 | image: docker.io/merlintang/xgboost-dist-iris:1.1
49 | volumeMounts:
50 | - name: task-pv-storage
51 | mountPath: /tmp/xgboost_model
52 | ports:
53 | - containerPort: 9991
54 | name: xgboostjob-port
55 | imagePullPolicy: Always
56 | args:
57 | - --job_type=Train
58 | - --xgboost_parameter="objective:multi:softprob,num_class:3"
59 | - --n_estimators=10
60 | - --learning_rate=0.1
61 | - --model_path=/tmp/xgboost_model/2
62 | - --model_storage_type=local
63 |
--------------------------------------------------------------------------------
/go.mod:
--------------------------------------------------------------------------------
1 | module github.com/kubeflow/xgboost-operator
2 |
3 | go 1.12
4 |
5 | require (
6 | cloud.google.com/go v0.39.0 // indirect
7 | github.com/go-logr/zapr v0.1.1 // indirect
8 | github.com/kubeflow/common v0.3.1
9 | github.com/sirupsen/logrus v1.4.2
10 | go.uber.org/atomic v1.4.0 // indirect
11 | go.uber.org/zap v1.10.0 // indirect
12 | google.golang.org/appengine v1.6.0 // indirect
13 | gopkg.in/square/go-jose.v2 v2.3.1 // indirect
14 | k8s.io/api v0.16.9
15 | k8s.io/apimachinery v0.16.9
16 | k8s.io/client-go v0.16.9
17 | sigs.k8s.io/controller-runtime v0.4.0
18 | volcano.sh/volcano v0.4.0
19 | )
20 |
21 | replace (
22 | k8s.io/api => k8s.io/api v0.16.9
23 | k8s.io/apiextensions-apiserver => k8s.io/apiextensions-apiserver v0.16.9
24 | k8s.io/apimachinery => k8s.io/apimachinery v0.16.10-beta.0
25 | k8s.io/apiserver => k8s.io/apiserver v0.16.9
26 | k8s.io/cli-runtime => k8s.io/cli-runtime v0.16.9
27 | k8s.io/client-go => k8s.io/client-go v0.16.9
28 | k8s.io/cloud-provider => k8s.io/cloud-provider v0.16.9
29 | k8s.io/cluster-bootstrap => k8s.io/cluster-bootstrap v0.16.9
30 | k8s.io/code-generator => k8s.io/code-generator v0.16.10-beta.0
31 | k8s.io/component-base => k8s.io/component-base v0.16.9
32 | k8s.io/cri-api => k8s.io/cri-api v0.16.10-beta.0
33 | k8s.io/csi-translation-lib => k8s.io/csi-translation-lib v0.16.9
34 | k8s.io/kube-aggregator => k8s.io/kube-aggregator v0.16.9
35 | k8s.io/kube-controller-manager => k8s.io/kube-controller-manager v0.16.9
36 | k8s.io/kube-proxy => k8s.io/kube-proxy v0.16.9
37 | k8s.io/kube-scheduler => k8s.io/kube-scheduler v0.16.9
38 | k8s.io/kubectl => k8s.io/kubectl v0.16.9
39 | k8s.io/kubelet => k8s.io/kubelet v0.16.9
40 | k8s.io/legacy-cloud-providers => k8s.io/legacy-cloud-providers v0.16.9
41 | k8s.io/metrics => k8s.io/metrics v0.16.9
42 | k8s.io/node-api => k8s.io/node-api v0.16.9
43 | k8s.io/sample-apiserver => k8s.io/sample-apiserver v0.16.9
44 | k8s.io/sample-cli-plugin => k8s.io/sample-cli-plugin v0.16.9
45 | k8s.io/sample-controller => k8s.io/sample-controller v0.16.9
46 | )
47 |
--------------------------------------------------------------------------------
/hack/boilerplate.go.txt:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
--------------------------------------------------------------------------------
/manifests/base/cluster-role-binding.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRoleBinding
3 | metadata:
4 | name: cluster-role-binding
5 | roleRef:
6 | apiGroup: rbac.authorization.k8s.io
7 | kind: ClusterRole
8 | name: cluster-role
9 | subjects:
10 | - kind: ServiceAccount
11 | name: service-account
12 |
--------------------------------------------------------------------------------
/manifests/base/cluster-role.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: rbac.authorization.k8s.io/v1
2 | kind: ClusterRole
3 | metadata:
4 | name: cluster-role
5 | rules:
6 | - apiGroups:
7 | - apps
8 | resources:
9 | - deployments
10 | - deployments/status
11 | verbs:
12 | - get
13 | - list
14 | - watch
15 | - create
16 | - update
17 | - patch
18 | - delete
19 | - apiGroups:
20 | - xgboostjob.kubeflow.org
21 | resources:
22 | - xgboostjobs
23 | - xgboostjobs/status
24 | verbs:
25 | - get
26 | - list
27 | - watch
28 | - create
29 | - update
30 | - patch
31 | - delete
32 | - apiGroups:
33 | - admissionregistration.k8s.io
34 | resources:
35 | - mutatingwebhookconfigurations
36 | - validatingwebhookconfigurations
37 | verbs:
38 | - get
39 | - list
40 | - watch
41 | - create
42 | - update
43 | - patch
44 | - delete
45 | - apiGroups:
46 | - ""
47 | resources:
48 | - configmaps
49 | - endpoints
50 | - events
51 | - namespaces
52 | - persistentvolumeclaims
53 | - pods
54 | - secrets
55 | - services
56 | verbs:
57 | - get
58 | - list
59 | - watch
60 | - create
61 | - update
62 | - patch
63 | - delete
64 | - apiGroups:
65 | - storage.k8s.io
66 | resources:
67 | - storageclasses
68 | verbs:
69 | - get
70 | - list
71 | - watch
72 | - create
73 | - update
74 | - patch
75 | - delete
76 |
--------------------------------------------------------------------------------
/manifests/base/deployment.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: deployment
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | app: xgboost-operator
10 | template:
11 | metadata:
12 | labels:
13 | app: xgboost-operator
14 | spec:
15 | containers:
16 | - name: xgboost-operator
17 | command:
18 | - /root/manager
19 | - -mode=in-cluster
20 | image: gcr.io/kubeflow-images-public/xgboost-operator:v0.1.0
21 | imagePullPolicy: Always
22 | serviceAccountName: service-account
23 |
--------------------------------------------------------------------------------
/manifests/base/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - cluster-role.yaml
5 | - cluster-role-binding.yaml
6 | - crd.yaml
7 | - deployment.yaml
8 | - service-account.yaml
9 | - service.yaml
10 | namespace: kubeflow
11 | namePrefix: xgboost-operator-
12 | configMapGenerator:
13 | - envs:
14 | - params.env
15 | name: xgboost-operator-config
16 | images:
17 | - name: gcr.io/kubeflow-images-public/xgboost-operator
18 | newName: kubeflow/xgboost-operator
19 | newTag: v0.2.0
20 |
--------------------------------------------------------------------------------
/manifests/base/params.env:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kubeflow/xgboost-operator/17383e882bd730103d978af310145d0a45437c74/manifests/base/params.env
--------------------------------------------------------------------------------
/manifests/base/service-account.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ServiceAccount
3 | metadata:
4 | name: service-account
5 |
--------------------------------------------------------------------------------
/manifests/base/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: service
5 | annotations:
6 | prometheus.io/path: /metrics
7 | prometheus.io/scrape: "true"
8 | prometheus.io/port: "8080"
9 | labels:
10 | app: xgboost-operator
11 | spec:
12 | type: ClusterIP
13 | selector:
14 | app: xgboost-operator
15 | ports:
16 |   - port: 443 # NOTE(review): prometheus.io/port annotation above says 8080, but only 443 is exposed — confirm the scrape target
17 |
--------------------------------------------------------------------------------
/manifests/overlays/kubeflow/kustomization.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: kustomize.config.k8s.io/v1beta1
2 | kind: Kustomization
3 | resources:
4 | - ../../base
5 | namespace: kubeflow
6 | commonLabels:
7 | app.kubernetes.io/component: xgboostjob
8 | app.kubernetes.io/name: xgboost-operator
9 |
--------------------------------------------------------------------------------
/pkg/apis/addtoscheme_xgboostjob.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package apis
17 |
18 | import (
19 | "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
20 | )
21 |
// init appends the xgboostjob/v1 scheme-builder function to AddToSchemes,
// so that a single call to apis.AddToScheme installs the XGBoostJob types.
func init() {
	// Register the types with the Scheme so the components can map objects to GroupVersionKinds and back
	AddToSchemes = append(AddToSchemes, v1.SchemeBuilder.AddToScheme)
}
26 |
--------------------------------------------------------------------------------
/pkg/apis/apis.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | // Generate deepcopy for apis
17 | // go:generate go run ../../vendor/k8s.io/code-generator/cmd/deepcopy-gen/main.go -O zz_generated.deepcopy -i ./... -h ../../hack/boilerplate.go.txt
18 |
19 | // Package apis contains Kubernetes API groups.
20 | package apis
21 |
22 | import (
23 | "k8s.io/apimachinery/pkg/runtime"
24 | )
25 |
// AddToSchemes may be used to add all resources defined in the project to a Scheme.
// Entries are appended by the per-group init() functions (see addtoscheme_*.go files).
var AddToSchemes runtime.SchemeBuilder

// AddToScheme adds all Resources to the Scheme.
// It applies every registered scheme-builder function to s and returns the
// first error encountered, or nil if all registrations succeed.
func AddToScheme(s *runtime.Scheme) error {
	return AddToSchemes.AddToScheme(s)
}
33 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/group.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | // Package xgboostjob contains xgboostjob API versions
17 | package xgboostjob
18 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/constants.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
const (
	// GroupName is the group name used in this package.
	GroupName = "kubeflow.org"
	// Kind is the kind name.
	Kind = "XGBoostJob"
	// GroupVersion is the version.
	GroupVersion = "v1"
	// Plural is the plural form for XGBoostJob.
	Plural = "xgboostjobs"
	// Singular is the singular form for XGBoostJob.
	Singular = "xgboostjob"
	// XGBoostCRD is the CRD name for XGBoostJob (Plural + "." + GroupName).
	// NOTE(review): this uses group "kubeflow.org" while register.go registers
	// the API under "xgboostjob.kubeflow.org" — confirm which group is intended.
	XGBoostCRD = "xgboostjobs.kubeflow.org"

	// DefaultContainerName is the default name of the XGBoost container in a replica pod.
	DefaultContainerName = "xgboostjob"
	// DefaultContainerPortName is the default name of the container port.
	DefaultContainerPortName = "xgboostjob-port"
	// DefaultPort is the default container port number.
	DefaultPort = 9999
)
36 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/doc.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | // Package v1 contains API Schema definitions for the xgboostjob v1 API group
17 | // +k8s:openapi-gen=true
18 | // +k8s:deepcopy-gen=package,register
19 | // +k8s:conversion-gen=github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob
20 | // +k8s:defaulter-gen=TypeMeta
21 | // +groupName=xgboostjob.kubeflow.org
22 | package v1
23 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/register.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | // NOTE: Boilerplate only. Ignore this file.
17 |
18 | // Package v1 contains API Schema definitions for the xgboostjob v1 API group
19 | // +k8s:openapi-gen=true
20 | // +k8s:deepcopy-gen=package,register
21 | // +k8s:conversion-gen=github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob
22 | // +k8s:defaulter-gen=TypeMeta
23 | // +groupName=xgboostjob.kubeflow.org
24 | package v1
25 |
26 | import (
27 | "k8s.io/apimachinery/pkg/runtime/schema"
28 | "sigs.k8s.io/controller-runtime/pkg/runtime/scheme"
29 | )
30 |
var (
	// SchemeGroupVersion is group version used to register these objects.
	// NOTE(review): the group literal "xgboostjob.kubeflow.org" differs from the
	// GroupName constant ("kubeflow.org") in constants.go, and the version
	// duplicates the GroupVersion constant — confirm these should stay in sync.
	SchemeGroupVersion = schema.GroupVersion{Group: "xgboostjob.kubeflow.org", Version: "v1"}

	// SchemeBuilder is used to add go types to the GroupVersionKind scheme
	SchemeBuilder = &scheme.Builder{GroupVersion: SchemeGroupVersion}

	// SchemeGroupVersionKind is the GroupVersionKind of the resource.
	SchemeGroupVersionKind = SchemeGroupVersion.WithKind(Kind)

	// AddToScheme is required by pkg/client/...
	AddToScheme = SchemeBuilder.AddToScheme
)

// Resource is required by pkg/client/listers/...
// It maps an unversioned resource name (e.g. "xgboostjobs") to a GroupResource
// in this package's API group.
func Resource(resource string) schema.GroupResource {
	return SchemeGroupVersion.WithResource(resource).GroupResource()
}
49 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/v1_suite_test.go.back:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
18 | import (
19 | "log"
20 | "os"
21 | "path/filepath"
22 | "testing"
23 |
24 | "k8s.io/client-go/kubernetes/scheme"
25 | "k8s.io/client-go/rest"
26 | "sigs.k8s.io/controller-runtime/pkg/client"
27 | "sigs.k8s.io/controller-runtime/pkg/envtest"
28 | )
29 |
// cfg holds the REST config of the envtest control plane started in TestMain.
var cfg *rest.Config

// c is the client used by tests in this package to talk to the envtest API server.
var c client.Client

// TestMain starts a local envtest control plane with the project's CRDs,
// registers the XGBoostJob scheme, builds a client, runs the package tests,
// then stops the control plane and exits with the tests' status code.
func TestMain(m *testing.M) {
	t := &envtest.Environment{
		CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "..", "config", "crds")},
	}

	err := SchemeBuilder.AddToScheme(scheme.Scheme)
	if err != nil {
		log.Fatal(err)
	}

	if cfg, err = t.Start(); err != nil {
		log.Fatal(err)
	}

	if c, err = client.New(cfg, client.Options{Scheme: scheme.Scheme}); err != nil {
		log.Fatal(err)
	}

	code := m.Run()
	// NOTE(review): any error returned by Stop() is ignored here — acceptable
	// for test teardown, but confirm leaked control-plane processes are not an issue.
	t.Stop()
	os.Exit(code)
}
55 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/xgboostjob_types.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
18 | import (
19 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
20 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
21 | )
22 |
23 | // EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN!
24 | // NOTE: json tags are required. Any new fields you add must have json tags for the fields to be serialized.
25 |
// XGBoostJobSpec defines the desired state of XGBoostJob
type XGBoostJobSpec struct {
	// INSERT ADDITIONAL SPEC FIELDS - desired state of cluster
	// Important: Run "make" to regenerate code after modifying this file

	// RunPolicy encapsulates job-level run policies from kubeflow/common.
	// NOTE(review): `json:",inline"` on a named (non-embedded) field is not
	// honored by encoding/json — confirm the generated CRD/serialization
	// behaves as intended.
	RunPolicy commonv1.RunPolicy `json:",inline"`

	// XGBReplicaSpecs maps a replica type (e.g. "Master", "Worker" — see the
	// XGBoostReplicaType* constants in this package) to its replica spec.
	XGBReplicaSpecs map[commonv1.ReplicaType]*commonv1.ReplicaSpec `json:"xgbReplicaSpecs"`
}

// XGBoostJobStatus defines the observed state of XGBoostJob
type XGBoostJobStatus struct {
	// INSERT ADDITIONAL STATUS FIELD - define observed state of cluster
	// Important: Run "make" to regenerate code after modifying this file

	// JobStatus is the shared kubeflow/common job status, embedded inline.
	commonv1.JobStatus `json:",inline"`
}
41 |
// +genclient
// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object

// XGBoostJob is the Schema for the xgboostjobs API
// +k8s:openapi-gen=true
type XGBoostJob struct {
	metav1.TypeMeta   `json:",inline"`
	metav1.ObjectMeta `json:"metadata,omitempty"`

	// Spec is the desired state of the job.
	Spec XGBoostJobSpec `json:"spec,omitempty"`
	// Status is the most recently observed state of the job.
	Status XGBoostJobStatus `json:"status,omitempty"`
}

// +k8s:deepcopy-gen:interfaces=k8s.io/apimachinery/pkg/runtime.Object

// XGBoostJobList contains a list of XGBoostJob
type XGBoostJobList struct {
	metav1.TypeMeta `json:",inline"`
	metav1.ListMeta `json:"metadata,omitempty"`
	// Items is the list of XGBoostJobs in this response page.
	Items []XGBoostJob `json:"items"`
}
63 |
// XGBoostJobReplicaType is the type for XGBoostJobReplica.
type XGBoostJobReplicaType commonv1.ReplicaType

const (
	// XGBoostReplicaTypeMaster is the type for master replica.
	XGBoostReplicaTypeMaster XGBoostJobReplicaType = "Master"

	// XGBoostReplicaTypeWorker is the type for worker replicas.
	XGBoostReplicaTypeWorker XGBoostJobReplicaType = "Worker"
)

// init registers XGBoostJob and XGBoostJobList with this package's
// SchemeBuilder (see register.go) so they are added to the runtime scheme.
func init() {
	SchemeBuilder.Register(&XGBoostJob{}, &XGBoostJobList{})
}
78 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/xgboostjob_types_test.go.back:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
18 | import (
19 | "testing"
20 |
21 | "github.com/onsi/gomega"
22 | "golang.org/x/net/context"
23 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
24 | "k8s.io/apimachinery/pkg/types"
25 | )
26 |
27 | func TestStorageXGBoostJob(t *testing.T) {
28 | key := types.NamespacedName{
29 | Name: "foo",
30 | Namespace: "default",
31 | }
32 | created := &XGBoostJob{
33 | ObjectMeta: metav1.ObjectMeta{
34 | Name: "foo",
35 | Namespace: "default",
36 | }}
37 | g := gomega.NewGomegaWithT(t)
38 |
39 | // Test Create
40 | fetched := &XGBoostJob{}
41 | g.Expect(c.Create(context.TODO(), created)).NotTo(gomega.HaveOccurred())
42 |
43 | g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred())
44 | g.Expect(fetched).To(gomega.Equal(created))
45 |
46 | // Test Updating the Labels
47 | updated := fetched.DeepCopy()
48 | updated.Labels = map[string]string{"hello": "world"}
49 | g.Expect(c.Update(context.TODO(), updated)).NotTo(gomega.HaveOccurred())
50 |
51 | g.Expect(c.Get(context.TODO(), key, fetched)).NotTo(gomega.HaveOccurred())
52 | g.Expect(fetched).To(gomega.Equal(updated))
53 |
54 | // Test Delete
55 | g.Expect(c.Delete(context.TODO(), fetched)).NotTo(gomega.HaveOccurred())
56 | g.Expect(c.Get(context.TODO(), key, fetched)).To(gomega.HaveOccurred())
57 | }
58 |
--------------------------------------------------------------------------------
/pkg/apis/xgboostjob/v1/zz_generated.deepcopy.go:
--------------------------------------------------------------------------------
1 | // +build !ignore_autogenerated
2 |
3 | /*
4 |
5 | Licensed under the Apache License, Version 2.0 (the "License");
6 | you may not use this file except in compliance with the License.
7 | You may obtain a copy of the License at
8 |
9 | http://www.apache.org/licenses/LICENSE-2.0
10 |
11 | Unless required by applicable law or agreed to in writing, software
12 | distributed under the License is distributed on an "AS IS" BASIS,
13 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | See the License for the specific language governing permissions and
15 | limitations under the License.
16 | */
17 |
18 | // Code generated by controller-gen. DO NOT EDIT.
19 |
20 | package v1
21 |
22 | import (
23 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
24 | runtime "k8s.io/apimachinery/pkg/runtime"
25 | )
26 |
// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *XGBoostJob) DeepCopyInto(out *XGBoostJob) {
	*out = *in
	out.TypeMeta = in.TypeMeta
	in.ObjectMeta.DeepCopyInto(&out.ObjectMeta)
	in.Spec.DeepCopyInto(&out.Spec)
	in.Status.DeepCopyInto(&out.Status)
}

// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJob.
func (in *XGBoostJob) DeepCopy() *XGBoostJob {
	if in == nil {
		return nil
	}
	out := new(XGBoostJob)
	in.DeepCopyInto(out)
	return out
}

// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *XGBoostJob) DeepCopyObject() runtime.Object {
	if c := in.DeepCopy(); c != nil {
		return c
	}
	return nil
}

// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *XGBoostJobList) DeepCopyInto(out *XGBoostJobList) {
	*out = *in
	out.TypeMeta = in.TypeMeta
	in.ListMeta.DeepCopyInto(&out.ListMeta)
	if in.Items != nil {
		in, out := &in.Items, &out.Items
		*out = make([]XGBoostJob, len(*in))
		for i := range *in {
			(*in)[i].DeepCopyInto(&(*out)[i])
		}
	}
}

// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobList.
func (in *XGBoostJobList) DeepCopy() *XGBoostJobList {
	if in == nil {
		return nil
	}
	out := new(XGBoostJobList)
	in.DeepCopyInto(out)
	return out
}

// DeepCopyObject is an autogenerated deepcopy function, copying the receiver, creating a new runtime.Object.
func (in *XGBoostJobList) DeepCopyObject() runtime.Object {
	if c := in.DeepCopy(); c != nil {
		return c
	}
	return nil
}

// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *XGBoostJobSpec) DeepCopyInto(out *XGBoostJobSpec) {
	*out = *in
	in.RunPolicy.DeepCopyInto(&out.RunPolicy)
	if in.XGBReplicaSpecs != nil {
		in, out := &in.XGBReplicaSpecs, &out.XGBReplicaSpecs
		*out = make(map[commonv1.ReplicaType]*commonv1.ReplicaSpec, len(*in))
		for key, val := range *in {
			var outVal *commonv1.ReplicaSpec
			if val == nil {
				// Preserve explicit nil map entries rather than dropping the key.
				(*out)[key] = nil
			} else {
				// The shadowed in/out here refer to the map value, not the specs.
				in, out := &val, &outVal
				*out = new(commonv1.ReplicaSpec)
				(*in).DeepCopyInto(*out)
			}
			(*out)[key] = outVal
		}
	}
}

// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobSpec.
func (in *XGBoostJobSpec) DeepCopy() *XGBoostJobSpec {
	if in == nil {
		return nil
	}
	out := new(XGBoostJobSpec)
	in.DeepCopyInto(out)
	return out
}

// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil.
func (in *XGBoostJobStatus) DeepCopyInto(out *XGBoostJobStatus) {
	*out = *in
	in.JobStatus.DeepCopyInto(&out.JobStatus)
}

// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new XGBoostJobStatus.
func (in *XGBoostJobStatus) DeepCopy() *XGBoostJobStatus {
	if in == nil {
		return nil
	}
	out := new(XGBoostJobStatus)
	in.DeepCopyInto(out)
	return out
}
132 |
--------------------------------------------------------------------------------
/pkg/controller/v1/add_xgboostjob.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
18 | import (
19 | "github.com/kubeflow/xgboost-operator/pkg/controller/v1/xgboostjob"
20 | )
21 |
// init registers the XGBoostJob controller's setup function so that
// AddToManager (see controller.go) wires it into the manager at startup.
func init() {
	// AddToManagerFuncs is a list of functions to create controllers and add them to a manager.
	AddToManagerFuncs = append(AddToManagerFuncs, xgboostjob.Add)
}
26 |
--------------------------------------------------------------------------------
/pkg/controller/v1/controller.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package v1
17 |
18 | import (
19 | "sigs.k8s.io/controller-runtime/pkg/manager"
20 | )
21 |
22 | // AddToManagerFuncs is a list of functions to add all Controllers to the Manager
23 | var AddToManagerFuncs []func(manager.Manager) error
24 |
25 | // AddToManager adds all Controllers to the Manager
26 | func AddToManager(m manager.Manager) error {
27 | for _, f := range AddToManagerFuncs {
28 | if err := f(m); err != nil {
29 | return err
30 | }
31 | }
32 | return nil
33 | }
34 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/expectation.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "fmt"
20 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
21 | "github.com/kubeflow/common/pkg/controller.v1/common"
22 | "github.com/kubeflow/common/pkg/controller.v1/expectation"
23 | "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
24 | "github.com/sirupsen/logrus"
25 | corev1 "k8s.io/api/core/v1"
26 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
27 | utilruntime "k8s.io/apimachinery/pkg/util/runtime"
28 | "sigs.k8s.io/controller-runtime/pkg/event"
29 | "sigs.k8s.io/controller-runtime/pkg/reconcile"
30 | )
31 |
32 | // satisfiedExpectations returns true if the required adds/dels for the given job have been observed.
33 | // Add/del counts are established by the controller at sync time, and updated as controllees are observed by the controller
34 | // manager.
35 | func (r *ReconcileXGBoostJob) satisfiedExpectations(xgbJob *v1.XGBoostJob) bool {
36 | satisfied := false
37 | key, err := common.KeyFunc(xgbJob)
38 | if err != nil {
39 | utilruntime.HandleError(fmt.Errorf("couldn't get key for job object %#v: %v", xgbJob, err))
40 | return false
41 | }
42 | for rtype := range xgbJob.Spec.XGBReplicaSpecs {
43 | // Check the expectations of the pods.
44 | expectationPodsKey := expectation.GenExpectationPodsKey(key, string(rtype))
45 | satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationPodsKey)
46 | // Check the expectations of the services.
47 | expectationServicesKey := expectation.GenExpectationServicesKey(key, string(rtype))
48 | satisfied = satisfied || r.Expectations.SatisfiedExpectations(expectationServicesKey)
49 | }
50 | return satisfied
51 | }
52 |
53 | // onDependentCreateFunc modify expectations when dependent (pod/service) creation observed.
54 | func onDependentCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool {
55 | return func(e event.CreateEvent) bool {
56 | xgbr, ok := r.(*ReconcileXGBoostJob)
57 | if !ok {
58 | return true
59 | }
60 | rtype := e.Meta.GetLabels()[commonv1.ReplicaTypeLabel]
61 | if len(rtype) == 0 {
62 | return false
63 | }
64 |
65 | logrus.Info("Update on create function ", xgbr.ControllerName(), " create object ", e.Meta.GetName())
66 | if controllerRef := metav1.GetControllerOf(e.Meta); controllerRef != nil {
67 | var expectKey string
68 | if _, ok := e.Object.(*corev1.Pod); ok {
69 | expectKey = expectation.GenExpectationPodsKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype)
70 | }
71 |
72 | if _, ok := e.Object.(*corev1.Service); ok {
73 | expectKey = expectation.GenExpectationServicesKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype)
74 | }
75 | xgbr.Expectations.CreationObserved(expectKey)
76 | return true
77 | }
78 |
79 | return true
80 | }
81 | }
82 |
83 | // onDependentDeleteFunc modify expectations when dependent (pod/service) deletion observed.
84 | func onDependentDeleteFunc(r reconcile.Reconciler) func(event.DeleteEvent) bool {
85 | return func(e event.DeleteEvent) bool {
86 | xgbr, ok := r.(*ReconcileXGBoostJob)
87 | if !ok {
88 | return true
89 | }
90 |
91 | rtype := e.Meta.GetLabels()[commonv1.ReplicaTypeLabel]
92 | if len(rtype) == 0 {
93 | return false
94 | }
95 |
96 | logrus.Info("Update on deleting function ", xgbr.ControllerName(), " delete object ", e.Meta.GetName())
97 | if controllerRef := metav1.GetControllerOf(e.Meta); controllerRef != nil {
98 | var expectKey string
99 | if _, ok := e.Object.(*corev1.Pod); ok {
100 | expectKey = expectation.GenExpectationPodsKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype)
101 | }
102 |
103 | if _, ok := e.Object.(*corev1.Service); ok {
104 | expectKey = expectation.GenExpectationServicesKey(e.Meta.GetNamespace()+"/"+controllerRef.Name, rtype)
105 | }
106 |
107 | xgbr.Expectations.DeletionObserved(expectKey)
108 | return true
109 | }
110 |
111 | return true
112 | }
113 | }
114 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/job.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "context"
20 | "fmt"
21 | "reflect"
22 |
23 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
24 | commonutil "github.com/kubeflow/common/pkg/util"
25 | logger "github.com/kubeflow/common/pkg/util"
26 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
27 | "github.com/sirupsen/logrus"
28 | corev1 "k8s.io/api/core/v1"
29 | k8sv1 "k8s.io/api/core/v1"
30 | "k8s.io/apimachinery/pkg/api/errors"
31 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
32 | "k8s.io/apimachinery/pkg/types"
33 | "k8s.io/client-go/kubernetes/scheme"
34 | "sigs.k8s.io/controller-runtime/pkg/event"
35 | "sigs.k8s.io/controller-runtime/pkg/reconcile"
36 | )
37 |
// Reasons for job events.
const (
	// FailedDeleteJobReason is recorded on the job when deleting it fails.
	FailedDeleteJobReason     = "FailedDeleteJob"
	// SuccessfulDeleteJobReason is recorded on the job when it is deleted successfully.
	SuccessfulDeleteJobReason = "SuccessfulDeleteJob"
	// xgboostJobCreatedReason is added in a job when it is created.
	xgboostJobCreatedReason = "XGBoostJobCreated"

	// xgboostJobSucceededReason is added when all master replicas have succeeded.
	xgboostJobSucceededReason  = "XGBoostJobSucceeded"
	// xgboostJobRunningReason is added while the job has active replicas.
	xgboostJobRunningReason    = "XGBoostJobRunning"
	// xgboostJobFailedReason is added when a replica fails without an exit-code restart policy.
	xgboostJobFailedReason     = "XGBoostJobFailed"
	// xgboostJobRestartingReason is added when a failed replica will be restarted.
	xgboostJobRestartingReason = "XGBoostJobRestarting"
)
50 |
51 | // DeleteJob deletes the job
52 | func (r *ReconcileXGBoostJob) DeleteJob(job interface{}) error {
53 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob)
54 | if !ok {
55 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
56 | }
57 | if err := r.Delete(context.Background(), xgboostjob); err != nil {
58 | r.recorder.Eventf(xgboostjob, corev1.EventTypeWarning, FailedDeleteJobReason, "Error deleting: %v", err)
59 | log.Error(err, "failed to delete job", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name)
60 | return err
61 | }
62 | r.recorder.Eventf(xgboostjob, corev1.EventTypeNormal, SuccessfulDeleteJobReason, "Deleted job: %v", xgboostjob.Name)
63 | log.Info("job deleted", "namespace", xgboostjob.Namespace, "name", xgboostjob.Name)
64 | return nil
65 | }
66 |
67 | // GetJobFromInformerCache returns the Job from Informer Cache
68 | func (r *ReconcileXGBoostJob) GetJobFromInformerCache(namespace, name string) (metav1.Object, error) {
69 | job := &v1xgboost.XGBoostJob{}
70 | // Default reader for XGBoostJob is cache reader.
71 | err := r.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
72 | if err != nil {
73 | if errors.IsNotFound(err) {
74 | log.Error(err, "xgboost job not found", "namespace", namespace, "name", name)
75 | } else {
76 | log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
77 | }
78 | return nil, err
79 | }
80 | return job, nil
81 | }
82 |
83 | // GetJobFromAPIClient returns the Job from API server
84 | func (r *ReconcileXGBoostJob) GetJobFromAPIClient(namespace, name string) (metav1.Object, error) {
85 | job := &v1xgboost.XGBoostJob{}
86 |
87 | clientReader, err := getClientReaderFromClient(r.Client)
88 | if err != nil {
89 | return nil, err
90 | }
91 | err = clientReader.Get(context.Background(), types.NamespacedName{Namespace: namespace, Name: name}, job)
92 | if err != nil {
93 | if errors.IsNotFound(err) {
94 | log.Error(err, "xgboost job not found", "namespace", namespace, "name", name)
95 | } else {
96 | log.Error(err, "failed to get job from api-server", "namespace", namespace, "name", name)
97 | }
98 | return nil, err
99 | }
100 | return job, nil
101 | }
102 |
103 | // UpdateJobStatus updates the job status and job conditions
104 | func (r *ReconcileXGBoostJob) UpdateJobStatus(job interface{}, replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec, jobStatus *commonv1.JobStatus) error {
105 | xgboostJob, ok := job.(*v1xgboost.XGBoostJob)
106 | if !ok {
107 | return fmt.Errorf("%+v is not a type of xgboostJob", xgboostJob)
108 | }
109 |
110 | for rtype, spec := range replicas {
111 | status := jobStatus.ReplicaStatuses[rtype]
112 |
113 | succeeded := status.Succeeded
114 | expected := *(spec.Replicas) - succeeded
115 | running := status.Active
116 | failed := status.Failed
117 |
118 | logrus.Infof("XGBoostJob=%s, ReplicaType=%s expected=%d, running=%d, succeeded=%d , failed=%d",
119 | xgboostJob.Name, rtype, expected, running, succeeded, failed)
120 |
121 | if rtype == commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster) {
122 | if running > 0 {
123 | msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name)
124 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg)
125 | if err != nil {
126 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
127 | return err
128 | }
129 | }
130 | // when master is succeed, the job is finished.
131 | if expected == 0 {
132 | msg := fmt.Sprintf("XGBoostJob %s is successfully completed.", xgboostJob.Name)
133 | logrus.Info(msg)
134 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobSucceededReason, msg)
135 | if jobStatus.CompletionTime == nil {
136 | now := metav1.Now()
137 | jobStatus.CompletionTime = &now
138 | }
139 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobSucceeded, xgboostJobSucceededReason, msg)
140 | if err != nil {
141 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
142 | return err
143 | }
144 | return nil
145 | }
146 | }
147 | if failed > 0 {
148 | if spec.RestartPolicy == commonv1.RestartPolicyExitCode {
149 | msg := fmt.Sprintf("XGBoostJob %s is restarting because %d %s replica(s) failed.", xgboostJob.Name, failed, rtype)
150 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeWarning, xgboostJobRestartingReason, msg)
151 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRestarting, xgboostJobRestartingReason, msg)
152 | if err != nil {
153 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
154 | return err
155 | }
156 | } else {
157 | msg := fmt.Sprintf("XGBoostJob %s is failed because %d %s replica(s) failed.", xgboostJob.Name, failed, rtype)
158 | r.Recorder.Event(xgboostJob, k8sv1.EventTypeNormal, xgboostJobFailedReason, msg)
159 | if xgboostJob.Status.CompletionTime == nil {
160 | now := metav1.Now()
161 | xgboostJob.Status.CompletionTime = &now
162 | }
163 | err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobFailed, xgboostJobFailedReason, msg)
164 | if err != nil {
165 | logger.LoggerForJob(xgboostJob).Infof("Append job condition error: %v", err)
166 | return err
167 | }
168 | }
169 | }
170 | }
171 |
172 | // Some workers are still running, leave a running condition.
173 | msg := fmt.Sprintf("XGBoostJob %s is running.", xgboostJob.Name)
174 | logger.LoggerForJob(xgboostJob).Infof(msg)
175 |
176 | if err := commonutil.UpdateJobConditions(jobStatus, commonv1.JobRunning, xgboostJobRunningReason, msg); err != nil {
177 | logger.LoggerForJob(xgboostJob).Error(err, "failed to update XGBoost Job conditions")
178 | return err
179 | }
180 |
181 | return nil
182 | }
183 |
184 | // UpdateJobStatusInApiServer updates the job status in to cluster.
185 | func (r *ReconcileXGBoostJob) UpdateJobStatusInApiServer(job interface{}, jobStatus *commonv1.JobStatus) error {
186 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob)
187 | if !ok {
188 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
189 | }
190 |
191 | // Job status passed in differs with status in job, update in basis of the passed in one.
192 | if !reflect.DeepEqual(&xgboostjob.Status.JobStatus, jobStatus) {
193 | xgboostjob = xgboostjob.DeepCopy()
194 | xgboostjob.Status.JobStatus = *jobStatus.DeepCopy()
195 | }
196 |
197 | result := r.Update(context.Background(), xgboostjob)
198 |
199 | if result != nil {
200 | logger.LoggerForJob(xgboostjob).Error(result, "failed to update XGBoost Job conditions in the API server")
201 | return result
202 | }
203 |
204 | return nil
205 | }
206 |
207 | // onOwnerCreateFunc modify creation condition.
208 | func onOwnerCreateFunc(r reconcile.Reconciler) func(event.CreateEvent) bool {
209 | return func(e event.CreateEvent) bool {
210 | xgboostJob, ok := e.Meta.(*v1xgboost.XGBoostJob)
211 | if !ok {
212 | return true
213 | }
214 | scheme.Scheme.Default(xgboostJob)
215 | msg := fmt.Sprintf("xgboostJob %s is created.", e.Meta.GetName())
216 | logrus.Info(msg)
217 | //specific the run policy
218 |
219 | if xgboostJob.Spec.RunPolicy.CleanPodPolicy == nil {
220 | xgboostJob.Spec.RunPolicy.CleanPodPolicy = new(commonv1.CleanPodPolicy)
221 | xgboostJob.Spec.RunPolicy.CleanPodPolicy = &defaultCleanPodPolicy
222 | }
223 |
224 | if err := commonutil.UpdateJobConditions(&xgboostJob.Status.JobStatus, commonv1.JobCreated, xgboostJobCreatedReason, msg); err != nil {
225 | log.Error(err, "append job condition error")
226 | return false
227 | }
228 | return true
229 | }
230 | }
231 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/pod.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "context"
20 | "fmt"
21 | "strconv"
22 | "strings"
23 |
24 | "k8s.io/apimachinery/pkg/api/meta"
25 | "sigs.k8s.io/controller-runtime/pkg/client"
26 |
27 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
28 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
29 | corev1 "k8s.io/api/core/v1"
30 | )
31 |
32 | // GetPodsForJob returns the pods managed by the job. This can be achieved by selecting pods using label key "job-name"
33 | // i.e. all pods created by the job will come with label "job-name" =
34 | func (r *ReconcileXGBoostJob) GetPodsForJob(obj interface{}) ([]*corev1.Pod, error) {
35 | job, err := meta.Accessor(obj)
36 | if err != nil {
37 | return nil, err
38 | }
39 | // List all pods to include those that don't match the selector anymore
40 | // but have a ControllerRef pointing to this controller.
41 | podlist := &corev1.PodList{}
42 | err = r.List(context.Background(), podlist, client.MatchingLabels(r.GenLabels(job.GetName())))
43 | if err != nil {
44 | return nil, err
45 | }
46 |
47 | return convertPodList(podlist.Items), nil
48 | }
49 |
50 | // convertPodList convert pod list to pod point list
51 | func convertPodList(list []corev1.Pod) []*corev1.Pod {
52 | if list == nil {
53 | return nil
54 | }
55 | ret := make([]*corev1.Pod, 0, len(list))
56 | for i := range list {
57 | ret = append(ret, &list[i])
58 | }
59 | return ret
60 | }
61 |
62 | // SetPodEnv sets the pod env set for:
63 | // - XGBoost Rabit Tracker and worker
64 | // - LightGBM master and workers
65 | func SetPodEnv(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
66 | xgboostjob, ok := job.(*v1xgboost.XGBoostJob)
67 | if !ok {
68 | return fmt.Errorf("%+v is not a type of XGBoostJob", xgboostjob)
69 | }
70 |
71 | rank, err := strconv.Atoi(index)
72 | if err != nil {
73 | return err
74 | }
75 |
76 | // Add master offset for worker pods
77 | if strings.ToLower(rtype) == strings.ToLower(string(v1xgboost.XGBoostReplicaTypeWorker)) {
78 | masterSpec := xgboostjob.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster)]
79 | masterReplicas := int(*masterSpec.Replicas)
80 | rank += masterReplicas
81 | }
82 |
83 | masterAddr := computeMasterAddr(xgboostjob.Name, strings.ToLower(string(v1xgboost.XGBoostReplicaTypeMaster)), strconv.Itoa(0))
84 |
85 | masterPort, err := GetPortFromXGBoostJob(xgboostjob, v1xgboost.XGBoostReplicaTypeMaster)
86 | if err != nil {
87 | return err
88 | }
89 |
90 | totalReplicas := computeTotalReplicas(xgboostjob)
91 |
92 | var workerPort int32
93 | var workerAddrs []string
94 |
95 | if totalReplicas > 1 {
96 | workerPortTemp, err := GetPortFromXGBoostJob(xgboostjob, v1xgboost.XGBoostReplicaTypeWorker)
97 | if err != nil {
98 | return err
99 | }
100 | workerPort = workerPortTemp
101 | workerAddrs = make([]string, totalReplicas-1)
102 | for i := range workerAddrs {
103 | workerAddrs[i] = computeMasterAddr(xgboostjob.Name, strings.ToLower(string(v1xgboost.XGBoostReplicaTypeWorker)), strconv.Itoa(i))
104 | }
105 | }
106 |
107 | for i := range podTemplate.Spec.Containers {
108 | if len(podTemplate.Spec.Containers[i].Env) == 0 {
109 | podTemplate.Spec.Containers[i].Env = make([]corev1.EnvVar, 0)
110 | }
111 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
112 | Name: "MASTER_PORT",
113 | Value: strconv.Itoa(int(masterPort)),
114 | })
115 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
116 | Name: "MASTER_ADDR",
117 | Value: masterAddr,
118 | })
119 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
120 | Name: "WORLD_SIZE",
121 | Value: strconv.Itoa(int(totalReplicas)),
122 | })
123 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
124 | Name: "RANK",
125 | Value: strconv.Itoa(rank),
126 | })
127 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
128 | Name: "PYTHONUNBUFFERED",
129 | Value: "0",
130 | })
131 | // This variables are used if it is a LightGBM job
132 | if totalReplicas > 1 {
133 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
134 | Name: "WORKER_PORT",
135 | Value: strconv.Itoa(int(workerPort)),
136 | })
137 | podTemplate.Spec.Containers[i].Env = append(podTemplate.Spec.Containers[i].Env, corev1.EnvVar{
138 | Name: "WORKER_ADDRS",
139 | Value: strings.Join(workerAddrs, ","),
140 | })
141 | }
142 | }
143 |
144 | return nil
145 | }
146 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/pod_test.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "testing"
20 |
21 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
22 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
23 | v1 "k8s.io/api/core/v1"
24 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
25 | )
26 |
27 | func NewXGBoostJobWithMaster(worker int) *v1xgboost.XGBoostJob {
28 | job := NewXGoostJob(worker)
29 | master := int32(1)
30 | masterReplicaSpec := &commonv1.ReplicaSpec{
31 | Replicas: &master,
32 | Template: NewXGBoostReplicaSpecTemplate(),
33 | }
34 | job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeMaster)] = masterReplicaSpec
35 | return job
36 | }
37 |
38 | func NewXGoostJob(worker int) *v1xgboost.XGBoostJob {
39 |
40 | job := &v1xgboost.XGBoostJob{
41 | TypeMeta: metav1.TypeMeta{
42 | Kind: v1xgboost.Kind,
43 | },
44 | ObjectMeta: metav1.ObjectMeta{
45 | Name: "test-xgboostjob",
46 | Namespace: metav1.NamespaceDefault,
47 | },
48 | Spec: v1xgboost.XGBoostJobSpec{
49 | XGBReplicaSpecs: make(map[commonv1.ReplicaType]*commonv1.ReplicaSpec),
50 | },
51 | }
52 |
53 | if worker > 0 {
54 | worker := int32(worker)
55 | workerReplicaSpec := &commonv1.ReplicaSpec{
56 | Replicas: &worker,
57 | Template: NewXGBoostReplicaSpecTemplate(),
58 | }
59 | job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(v1xgboost.XGBoostReplicaTypeWorker)] = workerReplicaSpec
60 | }
61 |
62 | return job
63 | }
64 |
65 | func NewXGBoostReplicaSpecTemplate() v1.PodTemplateSpec {
66 | return v1.PodTemplateSpec{
67 | Spec: v1.PodSpec{
68 | Containers: []v1.Container{
69 | v1.Container{
70 | Name: v1xgboost.DefaultContainerName,
71 | Image: "test-image-for-kubeflow-xgboost-operator:latest",
72 | Args: []string{"Fake", "Fake"},
73 | Ports: []v1.ContainerPort{
74 | v1.ContainerPort{
75 | Name: v1xgboost.DefaultContainerPortName,
76 | ContainerPort: v1xgboost.DefaultPort,
77 | },
78 | },
79 | },
80 | },
81 | },
82 | }
83 | }
84 |
85 | func TestClusterSpec(t *testing.T) {
86 | type tc struct {
87 | job *v1xgboost.XGBoostJob
88 | rt v1xgboost.XGBoostJobReplicaType
89 | index string
90 | expectedClusterSpec map[string]string
91 | }
92 | testCase := []tc{
93 | tc{
94 | job: NewXGBoostJobWithMaster(0),
95 | rt: v1xgboost.XGBoostReplicaTypeMaster,
96 | index: "0",
97 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "1", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0"},
98 | },
99 | tc{
100 | job: NewXGBoostJobWithMaster(1),
101 | rt: v1xgboost.XGBoostReplicaTypeMaster,
102 | index: "1",
103 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "2", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0"},
104 | },
105 | tc{
106 | job: NewXGBoostJobWithMaster(2),
107 | rt: v1xgboost.XGBoostReplicaTypeMaster,
108 | index: "0",
109 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "0", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
110 | },
111 | tc{
112 | job: NewXGBoostJobWithMaster(2),
113 | rt: v1xgboost.XGBoostReplicaTypeWorker,
114 | index: "0",
115 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "1", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
116 | },
117 | tc{
118 | job: NewXGBoostJobWithMaster(2),
119 | rt: v1xgboost.XGBoostReplicaTypeWorker,
120 | index: "1",
121 | expectedClusterSpec: map[string]string{"WORLD_SIZE": "3", "MASTER_PORT": "9999", "RANK": "2", "MASTER_ADDR": "test-xgboostjob-master-0", "WORKER_PORT": "9999", "WORKER_ADDRS": "test-xgboostjob-worker-0,test-xgboostjob-worker-1"},
122 | },
123 | }
124 | for _, c := range testCase {
125 | demoTemplateSpec := c.job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(c.rt)].Template
126 | if err := SetPodEnv(c.job, &demoTemplateSpec, string(c.rt), c.index); err != nil {
127 | t.Errorf("Failed to set cluster spec: %v", err)
128 | }
129 | actual := demoTemplateSpec.Spec.Containers[0].Env
130 | for _, env := range actual {
131 | if val, ok := c.expectedClusterSpec[env.Name]; ok {
132 | if val != env.Value {
133 | t.Errorf("For name %s Got %s. Expected %s ", env.Name, env.Value, c.expectedClusterSpec[env.Name])
134 | }
135 | }
136 | }
137 | }
138 | }
139 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/service.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "context"
20 | "fmt"
21 | corev1 "k8s.io/api/core/v1"
22 | "k8s.io/apimachinery/pkg/api/meta"
23 | "sigs.k8s.io/controller-runtime/pkg/client"
24 | )
25 |
26 | // GetServicesForJob returns the services managed by the job. This can be achieved by selecting services using label key "job-name"
27 | // i.e. all services created by the job will come with label "job-name" =
28 | func (r *ReconcileXGBoostJob) GetServicesForJob(obj interface{}) ([]*corev1.Service, error) {
29 | job, err := meta.Accessor(obj)
30 | if err != nil {
31 | return nil, fmt.Errorf("%+v is not a type of XGBoostJob", job)
32 | }
33 | // List all pods to include those that don't match the selector anymore
34 | // but have a ControllerRef pointing to this controller.
35 | serviceList := &corev1.ServiceList{}
36 | err = r.List(context.Background(), serviceList, client.MatchingLabels(r.GenLabels(job.GetName())))
37 | if err != nil {
38 | return nil, err
39 | }
40 |
41 | //TODO support adopting/orphaning
42 | ret := convertServiceList(serviceList.Items)
43 |
44 | return ret, nil
45 | }
46 |
47 | // convertServiceList convert service list to service point list
48 | func convertServiceList(list []corev1.Service) []*corev1.Service {
49 | if list == nil {
50 | return nil
51 | }
52 | ret := make([]*corev1.Service, 0, len(list))
53 | for i := range list {
54 | ret = append(ret, &list[i])
55 | }
56 | return ret
57 | }
58 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/util.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "errors"
20 | "fmt"
21 | "os"
22 | "strings"
23 | "time"
24 |
25 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
26 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
27 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28 | kubeclientset "k8s.io/client-go/kubernetes"
29 | restclientset "k8s.io/client-go/rest"
30 | "sigs.k8s.io/controller-runtime/pkg/client"
31 | volcanoclient "volcano.sh/volcano/pkg/client/clientset/versioned"
32 | )
33 |
34 | // getClientReaderFromClient try to extract client reader from client, client
35 | // reader reads cluster info from api client.
36 | func getClientReaderFromClient(c client.Client) (client.Reader, error) {
37 | if dr, err := getDelegatingReader(c); err != nil {
38 | return nil, err
39 | } else {
40 | return dr.ClientReader, nil
41 | }
42 | }
43 |
44 | // getDelegatingReader try to extract DelegatingReader from client.
45 | func getDelegatingReader(c client.Client) (*client.DelegatingReader, error) {
46 | dc, ok := c.(*client.DelegatingClient)
47 | if !ok {
48 | return nil, errors.New("cannot convert from Client to DelegatingClient")
49 | }
50 | dr, ok := dc.Reader.(*client.DelegatingReader)
51 | if !ok {
52 | return nil, errors.New("cannot convert from DelegatingClient.Reader to Delegating Reader")
53 | }
54 | return dr, nil
55 | }
56 |
// computeMasterAddr builds the service-style address "<job>-<type>-<index>",
// replacing any "/" in the result with "-".
func computeMasterAddr(jobName, rtype, index string) string {
	addr := fmt.Sprintf("%s-%s-%s", jobName, rtype, index)
	return strings.Replace(addr, "/", "-", -1)
}
61 |
62 | // GetPortFromXGBoostJob gets the port of xgboost container.
63 | func GetPortFromXGBoostJob(job *v1xgboost.XGBoostJob, rtype v1xgboost.XGBoostJobReplicaType) (int32, error) {
64 | containers := job.Spec.XGBReplicaSpecs[commonv1.ReplicaType(rtype)].Template.Spec.Containers
65 | for _, container := range containers {
66 | if container.Name == v1xgboost.DefaultContainerName {
67 | ports := container.Ports
68 | for _, port := range ports {
69 | if port.Name == v1xgboost.DefaultContainerPortName {
70 | return port.ContainerPort, nil
71 | }
72 | }
73 | }
74 | }
75 | return -1, fmt.Errorf("failed to found the port")
76 | }
77 |
78 | func computeTotalReplicas(obj metav1.Object) int32 {
79 | job := obj.(*v1xgboost.XGBoostJob)
80 | jobReplicas := int32(0)
81 |
82 | if job.Spec.XGBReplicaSpecs == nil || len(job.Spec.XGBReplicaSpecs) == 0 {
83 | return jobReplicas
84 | }
85 | for _, r := range job.Spec.XGBReplicaSpecs {
86 | if r.Replicas == nil {
87 | continue
88 | } else {
89 | jobReplicas += *r.Replicas
90 | }
91 | }
92 | return jobReplicas
93 | }
94 |
95 | func createClientSets(config *restclientset.Config) (kubeclientset.Interface, kubeclientset.Interface, volcanoclient.Interface, error) {
96 | if config == nil {
97 | println("there is an error for the input config")
98 | return nil, nil, nil, nil
99 | }
100 |
101 | kubeClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "xgboostjob-operator"))
102 | if err != nil {
103 | return nil, nil, nil, err
104 | }
105 |
106 | leaderElectionClientSet, err := kubeclientset.NewForConfig(restclientset.AddUserAgent(config, "leader-election"))
107 | if err != nil {
108 | return nil, nil, nil, err
109 | }
110 |
111 | volcanoClientSet, err := volcanoclient.NewForConfig(restclientset.AddUserAgent(config, "volcano"))
112 | if err != nil {
113 | return nil, nil, nil, err
114 | }
115 |
116 | return kubeClientSet, leaderElectionClientSet, volcanoClientSet, nil
117 | }
118 |
119 | func homeDir() string {
120 | if h := os.Getenv("HOME"); h != "" {
121 | return h
122 | }
123 | return os.Getenv("USERPROFILE") // windows
124 | }
125 |
126 | func isGangSchedulerSet(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) bool {
127 | for _, spec := range replicas {
128 | if spec.Template.Spec.SchedulerName != "" && spec.Template.Spec.SchedulerName == gangSchedulerName {
129 | return true
130 | }
131 | }
132 | return false
133 | }
134 |
// FakeWorkQueue implements RateLimitingInterface but actually does nothing.
// It is handed to the shared JobController (see newReconciler) so the common
// controller code can be reused while all real queueing is handled by
// controller-runtime.
type FakeWorkQueue struct{}

// Add WorkQueue Add method; no-op.
func (f *FakeWorkQueue) Add(item interface{}) {}

// Len WorkQueue Len method; always reports an empty queue.
func (f *FakeWorkQueue) Len() int { return 0 }

// Get WorkQueue Get method; always returns no item and no shutdown.
func (f *FakeWorkQueue) Get() (item interface{}, shutdown bool) { return nil, false }

// Done WorkQueue Done method; no-op.
func (f *FakeWorkQueue) Done(item interface{}) {}

// ShutDown WorkQueue ShutDown method; no-op.
func (f *FakeWorkQueue) ShutDown() {}

// ShuttingDown WorkQueue ShuttingDown method; always reports shutting down.
func (f *FakeWorkQueue) ShuttingDown() bool { return true }

// AddAfter WorkQueue AddAfter method; no-op.
func (f *FakeWorkQueue) AddAfter(item interface{}, duration time.Duration) {}

// AddRateLimited WorkQueue AddRateLimited method; no-op.
func (f *FakeWorkQueue) AddRateLimited(item interface{}) {}

// Forget WorkQueue Forget method; no-op.
func (f *FakeWorkQueue) Forget(item interface{}) {}

// NumRequeues WorkQueue NumRequeues method; always zero.
func (f *FakeWorkQueue) NumRequeues(item interface{}) int { return 0 }
167 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/xgboostjob_controller.go:
--------------------------------------------------------------------------------
1 | /*
2 | Licensed under the Apache License, Version 2.0 (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 | http://www.apache.org/licenses/LICENSE-2.0
6 | Unless required by applicable law or agreed to in writing, software
7 | distributed under the License is distributed on an "AS IS" BASIS,
8 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 | See the License for the specific language governing permissions and
10 | limitations under the License.
11 | */
12 |
13 | package xgboostjob
14 |
15 | import (
16 | "context"
17 | "flag"
18 | commonv1 "github.com/kubeflow/common/pkg/apis/common/v1"
19 | "github.com/kubeflow/common/pkg/controller.v1/common"
20 | "github.com/kubeflow/common/pkg/controller.v1/control"
21 | "github.com/kubeflow/common/pkg/controller.v1/expectation"
22 | v1xgboost "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
23 | corev1 "k8s.io/api/core/v1"
24 | "path/filepath"
25 |
26 | "github.com/sirupsen/logrus"
27 | "k8s.io/apimachinery/pkg/api/errors"
28 | "k8s.io/apimachinery/pkg/runtime"
29 | "k8s.io/apimachinery/pkg/runtime/schema"
30 | "k8s.io/client-go/kubernetes/scheme"
31 | "k8s.io/client-go/rest"
32 | "k8s.io/client-go/tools/clientcmd"
33 | "k8s.io/client-go/tools/record"
34 | "sigs.k8s.io/controller-runtime/pkg/client"
35 | "sigs.k8s.io/controller-runtime/pkg/controller"
36 | "sigs.k8s.io/controller-runtime/pkg/handler"
37 | "sigs.k8s.io/controller-runtime/pkg/manager"
38 | "sigs.k8s.io/controller-runtime/pkg/predicate"
39 | "sigs.k8s.io/controller-runtime/pkg/reconcile"
40 | logf "sigs.k8s.io/controller-runtime/pkg/runtime/log"
41 | "sigs.k8s.io/controller-runtime/pkg/source"
42 | )
43 |
const (
	// controllerName is the name used for the event recorder and controller.
	controllerName      = "xgboostjob-operator"
	// labelXGBoostJobRole is the label key recording a resource's role within a job.
	labelXGBoostJobRole = "xgboostjob-job-role"
	// gang scheduler name.
	gangSchedulerName = "kube-batch"
)

var (
	// defaultTTLseconds is the default TTL (in seconds) — presumably for
	// finished jobs; usage not visible in this file, TODO confirm.
	defaultTTLseconds     = int32(100)
	// defaultCleanPodPolicy is applied when a job does not set CleanPodPolicy
	// (see onOwnerCreateFunc in job.go).
	defaultCleanPodPolicy = commonv1.CleanPodPolicyNone
)
// log is the package-level logger for this controller.
var log = logf.Log.WithName("controller")
56 |
57 | /**
58 | * USER ACTION REQUIRED: This is a scaffold file intended for the user to modify with their own Controller
59 | * business logic. Delete these comments after modifying this file.*
60 | */
61 |
// Add creates a new XGBoostJob Controller and adds it to the Manager with default RBAC. The Manager will set fields on the Controller
// and Start it when the Manager is Started.
func Add(mgr manager.Manager) error {
	return add(mgr, newReconciler(mgr))
}

// RecommendedKubeConfigPathEnv is the environment variable conventionally
// holding the kubeconfig path.
const RecommendedKubeConfigPathEnv = "KUBECONFIG"
69 |
70 | // newReconciler returns a new reconcile.Reconciler
71 | func newReconciler(mgr manager.Manager) reconcile.Reconciler {
72 |
73 | r := &ReconcileXGBoostJob{
74 | Client: mgr.GetClient(),
75 | scheme: mgr.GetScheme(),
76 | }
77 |
78 | r.recorder = mgr.GetEventRecorderFor(r.ControllerName())
79 |
80 | var mode string
81 | var kubeconfig *string
82 | var kcfg *rest.Config
83 | if home := homeDir(); home != "" {
84 | kubeconfig = flag.String("kubeconfig_", filepath.Join(home, ".kube", "config"), "(optional) absolute path to the kubeconfig file")
85 | } else {
86 | kubeconfig = flag.String("kubeconfig_", "", "absolute path to the kubeconfig file")
87 | }
88 | flag.Parse()
89 |
90 | mode = flag.Lookup("mode").Value.(flag.Getter).Get().(string)
91 | if mode == "local" {
92 | log.Info("Running controller in local mode, using kubeconfig file")
93 | /// TODO, add the master url and kubeconfigpath with user input
94 | config, err := clientcmd.BuildConfigFromFlags("", *kubeconfig)
95 | if err != nil {
96 | log.Info("Error building kubeconfig: %s", err.Error())
97 | panic(err.Error())
98 | }
99 | kcfg = config
100 | } else if mode == "in-cluster" {
101 | log.Info("Running controller in in-cluster mode")
102 | /// TODO, add the master url and kubeconfigpath with user input
103 | config, err := rest.InClusterConfig()
104 | if err != nil {
105 | log.Info("Error getting in-cluster kubeconfig")
106 | panic(err.Error())
107 | }
108 | kcfg = config
109 | } else {
110 | log.Info("Given mode is not valid: ", "mode", mode)
111 | panic("-mode should be either local or in-cluster")
112 | }
113 |
114 | // Create clients.
115 | kubeClientSet, _, volcanoClientSet, err := createClientSets(kcfg)
116 | if err != nil {
117 | log.Info("Error building kubeclientset: %s", err.Error())
118 | }
119 |
120 | // Create Informer factory
121 | xgboostjob := &v1xgboost.XGBoostJob{}
122 |
123 | gangScheduling := isGangSchedulerSet(xgboostjob.Spec.XGBReplicaSpecs)
124 |
125 | log.Info("gang scheduling is set: ", "gangscheduling", gangScheduling)
126 |
127 | // Initialize common job controller with components we only need.
128 | r.JobController = common.JobController{
129 | Controller: r,
130 | Expectations: expectation.NewControllerExpectations(),
131 | Config: common.JobControllerConfiguration{EnableGangScheduling: gangScheduling},
132 | WorkQueue: &FakeWorkQueue{},
133 | Recorder: r.recorder,
134 | KubeClientSet: kubeClientSet,
135 | VolcanoClientSet: volcanoClientSet,
136 | PodControl: control.RealPodControl{KubeClient: kubeClientSet, Recorder: r.recorder},
137 | ServiceControl: control.RealServiceControl{KubeClient: kubeClientSet, Recorder: r.recorder},
138 | }
139 |
140 | return r
141 | }
142 |
143 | // add adds a new Controller to mgr with r as the reconcile.Reconciler
144 | func add(mgr manager.Manager, r reconcile.Reconciler) error {
145 | // Create a new controller
146 | c, err := controller.New("xgboostjob-controller", mgr, controller.Options{Reconciler: r})
147 | if err != nil {
148 | return err
149 | }
150 |
151 | // Watch for changes to XGBoostJob
152 | err = c.Watch(&source.Kind{Type: &v1xgboost.XGBoostJob{}}, &handler.EnqueueRequestForObject{},
153 | predicate.Funcs{CreateFunc: onOwnerCreateFunc(r)},
154 | )
155 | if err != nil {
156 | return err
157 | }
158 |
159 | //inject watching for xgboostjob related pod
160 | err = c.Watch(&source.Kind{Type: &corev1.Pod{}}, &handler.EnqueueRequestForOwner{
161 | IsController: true,
162 | OwnerType: &v1xgboost.XGBoostJob{},
163 | },
164 | predicate.Funcs{CreateFunc: onDependentCreateFunc(r), DeleteFunc: onDependentDeleteFunc(r)},
165 | )
166 | if err != nil {
167 | return err
168 | }
169 |
170 | //inject watching for xgboostjob related service
171 | err = c.Watch(&source.Kind{Type: &corev1.Service{}}, &handler.EnqueueRequestForOwner{
172 | IsController: true,
173 | OwnerType: &v1xgboost.XGBoostJob{},
174 | },
175 | &predicate.Funcs{CreateFunc: onDependentCreateFunc(r), DeleteFunc: onDependentDeleteFunc(r)},
176 | )
177 | if err != nil {
178 | return err
179 | }
180 |
181 | return nil
182 | }
183 |
// Compile-time assertion that ReconcileXGBoostJob implements reconcile.Reconciler.
var _ reconcile.Reconciler = &ReconcileXGBoostJob{}
185 |
// ReconcileXGBoostJob reconciles a XGBoostJob object
type ReconcileXGBoostJob struct {
	// JobController embeds the shared job reconciliation machinery
	// (expectations, work queue, pod/service control) from the common library.
	common.JobController
	// Client reads and writes Kubernetes objects; set from mgr.GetClient().
	client.Client
	// scheme is the runtime scheme; set from mgr.GetScheme() in newReconciler.
	scheme *runtime.Scheme
	// recorder emits Kubernetes events on behalf of this controller.
	recorder record.EventRecorder
}
193 |
194 | // Reconcile reads that state of the cluster for a XGBoostJob object and makes changes based on the state read
195 | // and what is in the XGBoostJob.Spec
196 | // a Deployment as an example
197 | // Automatically generate RBAC rules to allow the Controller to read and write Deployments
198 | // +kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete
199 | // +kubebuilder:rbac:groups=apps,resources=deployments/status,verbs=get;update;patch
200 | // +kubebuilder:rbac:groups=xgboostjob.kubeflow.org,resources=xgboostjobs,verbs=get;list;watch;create;update;patch;delete
201 | // +kubebuilder:rbac:groups=xgboostjob.kubeflow.org,resources=xgboostjobs/status,verbs=get;update;patch
202 | func (r *ReconcileXGBoostJob) Reconcile(request reconcile.Request) (reconcile.Result, error) {
203 | // Fetch the XGBoostJob instance
204 | xgboostjob := &v1xgboost.XGBoostJob{}
205 | err := r.Get(context.Background(), request.NamespacedName, xgboostjob)
206 | if err != nil {
207 | if errors.IsNotFound(err) {
208 | // Object not found, return. Created objects are automatically garbage collected.
209 | // For additional cleanup logic use finalizers.
210 | return reconcile.Result{}, nil
211 | }
212 | // Error reading the object - requeue the request.
213 | return reconcile.Result{}, err
214 | }
215 |
216 | // Check reconcile is required.
217 | needSync := r.satisfiedExpectations(xgboostjob)
218 |
219 | if !needSync || xgboostjob.DeletionTimestamp != nil {
220 | log.Info("reconcile cancelled, job does not need to do reconcile or has been deleted",
221 | "sync", needSync, "deleted", xgboostjob.DeletionTimestamp != nil)
222 | return reconcile.Result{}, nil
223 | }
224 | // Set default priorities for xgboost job
225 | scheme.Scheme.Default(xgboostjob)
226 |
227 | // Use common to reconcile the job related pod and service
228 | err = r.ReconcileJobs(xgboostjob, xgboostjob.Spec.XGBReplicaSpecs, xgboostjob.Status.JobStatus, &xgboostjob.Spec.RunPolicy)
229 |
230 | if err != nil {
231 | logrus.Warnf("Reconcile XGBoost Job error %v", err)
232 | return reconcile.Result{}, err
233 | }
234 |
235 | return reconcile.Result{}, err
236 | }
237 |
// ControllerName returns the name this controller is registered under.
func (r *ReconcileXGBoostJob) ControllerName() string {
	return controllerName
}
241 |
// GetAPIGroupVersionKind returns the GroupVersionKind of the XGBoostJob API type.
func (r *ReconcileXGBoostJob) GetAPIGroupVersionKind() schema.GroupVersionKind {
	return v1xgboost.SchemeGroupVersionKind
}
245 |
// GetAPIGroupVersion returns the GroupVersion of the XGBoostJob API.
func (r *ReconcileXGBoostJob) GetAPIGroupVersion() schema.GroupVersion {
	return v1xgboost.SchemeGroupVersion
}
249 |
// GetGroupNameLabelValue returns the XGBoostJob API group name, used as the
// value of the group-name label on managed resources.
func (r *ReconcileXGBoostJob) GetGroupNameLabelValue() string {
	return v1xgboost.GroupName
}
253 |
// GetDefaultContainerName returns the default name of the XGBoost container
// in job pods.
func (r *ReconcileXGBoostJob) GetDefaultContainerName() string {
	return v1xgboost.DefaultContainerName
}
257 |
// GetDefaultContainerPortName returns the default name of the container port
// used by XGBoost job pods.
func (r *ReconcileXGBoostJob) GetDefaultContainerPortName() string {
	return v1xgboost.DefaultContainerPortName
}
261 |
// GetJobRoleKey returns the label key used to record a replica's role within
// the XGBoost job.
func (r *ReconcileXGBoostJob) GetJobRoleKey() string {
	return labelXGBoostJobRole
}
265 |
// IsMasterRole reports whether the given replica type is the XGBoost master.
// The replicas map and index parameters are unused; only rtype is inspected.
func (r *ReconcileXGBoostJob) IsMasterRole(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec,
	rtype commonv1.ReplicaType, index int) bool {
	return string(rtype) == string(v1xgboost.XGBoostReplicaTypeMaster)
}
270 |
// SetClusterSpec sets the cluster spec for the pod by delegating to SetPodEnv
// with the given replica type and index.
func (r *ReconcileXGBoostJob) SetClusterSpec(job interface{}, podTemplate *corev1.PodTemplateSpec, rtype, index string) error {
	return SetPodEnv(job, podTemplate, rtype, index)
}
275 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/xgboostjob_controller_suite_test.go.back:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | stdlog "log"
20 | "os"
21 | "path/filepath"
22 | "sync"
23 | "testing"
24 |
25 | "github.com/kubeflow/xgboost-operator/pkg/apis"
26 | "github.com/onsi/gomega"
27 | "k8s.io/client-go/kubernetes/scheme"
28 | "k8s.io/client-go/rest"
29 | "sigs.k8s.io/controller-runtime/pkg/envtest"
30 | "sigs.k8s.io/controller-runtime/pkg/manager"
31 | "sigs.k8s.io/controller-runtime/pkg/reconcile"
32 | )
33 |
// cfg holds the rest.Config of the envtest control plane; set in TestMain.
var cfg *rest.Config
35 |
36 | func TestMain(m *testing.M) {
37 | t := &envtest.Environment{
38 | CRDDirectoryPaths: []string{filepath.Join("..", "..", "..", "config", "crds")},
39 | }
40 | apis.AddToScheme(scheme.Scheme)
41 |
42 | var err error
43 | if cfg, err = t.Start(); err != nil {
44 | stdlog.Fatal(err)
45 | }
46 |
47 | code := m.Run()
48 | t.Stop()
49 | os.Exit(code)
50 | }
51 |
52 | // SetupTestReconcile returns a reconcile.Reconcile implementation that delegates to inner and
53 | // writes the request to requests after Reconcile is finished.
54 | func SetupTestReconcile(inner reconcile.Reconciler) (reconcile.Reconciler, chan reconcile.Request) {
55 | requests := make(chan reconcile.Request)
56 | fn := reconcile.Func(func(req reconcile.Request) (reconcile.Result, error) {
57 | result, err := inner.Reconcile(req)
58 | requests <- req
59 | return result, err
60 | })
61 | return fn, requests
62 | }
63 |
// StartTestManager runs mgr in a background goroutine. It returns the stop
// channel (close it to shut the manager down) and a WaitGroup that is done
// once the manager goroutine has exited; callers should close the channel and
// then Wait to guarantee a clean shutdown.
func StartTestManager(mgr manager.Manager, g *gomega.GomegaWithT) (chan struct{}, *sync.WaitGroup) {
	stop := make(chan struct{})
	wg := &sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		// mgr.Start blocks until the stop channel is closed; a startup error
		// fails the calling test via gomega.
		g.Expect(mgr.Start(stop)).NotTo(gomega.HaveOccurred())
	}()
	return stop, wg
}
75 |
--------------------------------------------------------------------------------
/pkg/controller/v1/xgboostjob/xgboostjob_controller_test.go.back:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package xgboostjob
17 |
18 | import (
19 | "testing"
20 | "time"
21 |
22 | xgboostjobv1 "github.com/kubeflow/xgboost-operator/pkg/apis/xgboostjob/v1"
23 | "github.com/onsi/gomega"
24 | "golang.org/x/net/context"
25 | appsv1 "k8s.io/api/apps/v1"
26 | apierrors "k8s.io/apimachinery/pkg/api/errors"
27 | metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
28 | "k8s.io/apimachinery/pkg/types"
29 | "sigs.k8s.io/controller-runtime/pkg/client"
30 | "sigs.k8s.io/controller-runtime/pkg/manager"
31 | "sigs.k8s.io/controller-runtime/pkg/reconcile"
32 | )
33 |
// c is the client the test uses to create, fetch, and delete objects.
var c client.Client

// expectedRequest is the reconcile request expected for the "foo" job.
var expectedRequest = reconcile.Request{NamespacedName: types.NamespacedName{Name: "foo", Namespace: "default"}}

// depKey names a Deployment "foo-deployment" — presumably left over from the
// kubebuilder scaffold; the XGBoost controller manages pods and services, not
// deployments — TODO confirm.
var depKey = types.NamespacedName{Name: "foo-deployment", Namespace: "default"}

// timeout bounds the Eventually assertions below.
const timeout = time.Second * 5
40 |
// TestReconcile drives the controller against the envtest control plane.
//
// NOTE(review): this test still follows the kubebuilder scaffold and expects
// a Deployment named "foo-deployment" to be created and reconciled, but the
// XGBoost controller reconciles pods/services rather than deployments —
// confirm whether the deployment expectations below can ever pass.
func TestReconcile(t *testing.T) {
	g := gomega.NewGomegaWithT(t)
	instance := &xgboostjobv1.XGBoostJob{ObjectMeta: metav1.ObjectMeta{Name: "foo", Namespace: "default"}}

	// Setup the Manager and Controller. Wrap the Controller Reconcile function so it writes each request to a
	// channel when it is finished.
	mgr, err := manager.New(cfg, manager.Options{})
	g.Expect(err).NotTo(gomega.HaveOccurred())
	c = mgr.GetClient()

	recFn, requests := SetupTestReconcile(newReconciler(mgr))
	g.Expect(add(mgr, recFn)).NotTo(gomega.HaveOccurred())

	stopMgr, mgrStopped := StartTestManager(mgr, g)

	// Shut the manager down and wait for its goroutine before the test exits.
	defer func() {
		close(stopMgr)
		mgrStopped.Wait()
	}()

	// Create the XGBoostJob object and expect the Reconcile and Deployment to be created
	err = c.Create(context.TODO(), instance)
	// The instance object may not be a valid object because it might be missing some required fields.
	// Please modify the instance object by adding required fields and then remove the following if statement.
	if apierrors.IsInvalid(err) {
		t.Logf("failed to create object, got an invalid object error: %v", err)
		return
	}
	g.Expect(err).NotTo(gomega.HaveOccurred())
	defer c.Delete(context.TODO(), instance)
	g.Eventually(requests, timeout).Should(gomega.Receive(gomega.Equal(expectedRequest)))

	deploy := &appsv1.Deployment{}
	g.Eventually(func() error { return c.Get(context.TODO(), depKey, deploy) }, timeout).
		Should(gomega.Succeed())

	// Delete the Deployment and expect Reconcile to be called for Deployment deletion
	g.Expect(c.Delete(context.TODO(), deploy)).NotTo(gomega.HaveOccurred())
	g.Eventually(requests, timeout).Should(gomega.Receive(gomega.Equal(expectedRequest)))
	g.Eventually(func() error { return c.Get(context.TODO(), depKey, deploy) }, timeout).
		Should(gomega.Succeed())

	// Manually delete Deployment since GC isn't enabled in the test control plane
	g.Eventually(func() error { return c.Delete(context.TODO(), deploy) }, timeout).
		Should(gomega.MatchError("deployments.apps \"foo-deployment\" not found"))

}
88 |
--------------------------------------------------------------------------------
/pkg/webhook/webhook.go:
--------------------------------------------------------------------------------
1 | /*
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 | */
15 |
16 | package webhook
17 |
18 | import (
19 | "sigs.k8s.io/controller-runtime/pkg/manager"
20 | )
21 |
22 | // AddToManagerFuncs is a list of functions to add all Controllers to the Manager
23 | var AddToManagerFuncs []func(manager.Manager) error
24 |
25 | // AddToManager adds all Controllers to the Manager
26 | // +kubebuilder:rbac:groups=admissionregistration.k8s.io,resources=mutatingwebhookconfigurations;validatingwebhookconfigurations,verbs=get;list;watch;create;update;patch;delete
27 | // +kubebuilder:rbac:groups="",resources=secrets,verbs=get;list;watch;create;update;patch;delete
28 | // +kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch;create;update;patch;delete
29 | func AddToManager(m manager.Manager) error {
30 | for _, f := range AddToManagerFuncs {
31 | if err := f(m); err != nil {
32 | return err
33 | }
34 | }
35 | return nil
36 | }
37 |
--------------------------------------------------------------------------------
/prow_config.yaml:
--------------------------------------------------------------------------------
1 | # This file configures the workflows to trigger in our Prow jobs.
2 | # see https://github.com/kubeflow/testing/blob/master/py/kubeflow/testing/run_e2e_workflow.py
3 | workflows:
4 | - app_dir: kubeflow/xgboost-operator/test/workflows
    # these super-short names are required so that identity lengths stay shorter than 64 characters
6 | component: build
7 | name: build
8 | job_types:
9 | - presubmit
10 | params:
11 | registry: "gcr.io/kubeflow-ci"
12 | - app_dir: kubeflow/xgboost-operator/test/workflows
13 | component: build
14 | name: build
15 | job_types:
16 | - postsubmit
17 | params:
18 | registry: "gcr.io/kubeflow-images-public"
19 |
--------------------------------------------------------------------------------
/test/workflows/.gitignore:
--------------------------------------------------------------------------------
1 | /.ksonnet/registries
2 | /app.override.yaml
3 | /.ks_environment
4 |
--------------------------------------------------------------------------------
/test/workflows/app.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: 0.3.0
2 | environments:
3 | dockerbuild:
4 | destination:
5 | namespace: kubeflow-releasing
6 | server: https://35.226.49.107
7 | k8sVersion: v1.7.0
8 | path: dockerbuild
9 | kind: ksonnet.io/app
10 | name: workflows
11 | registries:
12 | incubator:
13 | gitVersion:
14 | commitSha: ea3408d44c2d8ea4d321364e5533d5c60e74bce0
15 | refSpec: master
16 | protocol: github
17 | uri: github.com/ksonnet/parts/tree/master/incubator
18 | kubeflow:
19 | gitVersion:
20 | commitSha: 5c35580d76092788b089cb447be3f3097cffe60b
21 | refSpec: master
22 | protocol: github
23 | uri: github.com/google/kubeflow/tree/master/kubeflow
24 | version: 0.0.1
25 |
--------------------------------------------------------------------------------
/test/workflows/components/build.jsonnet:
--------------------------------------------------------------------------------
// Entry point for the "build" component: renders the Argo release workflow
// defined in kubeflow/automation/release.libsonnet with this app's params.
local env = std.extVar("__ksonnet/environments");
local params = std.extVar("__ksonnet/params").components.build;

local k = import "k.libsonnet";
local release = import "kubeflow/automation/release.libsonnet";
// Normalize extra_args: the literal string "null" means "no extra args".
local updatedParams = params {
  extra_args: if params.extra_args == "null" then "" else " " + params.extra_args,
};

std.prune(k.core.v1.list.new(release.parts(updatedParams.namespace, updatedParams.name, overrides=updatedParams).release))
11 |
--------------------------------------------------------------------------------
/test/workflows/components/params.libsonnet:
--------------------------------------------------------------------------------
{
  global: {
    // User-defined global parameters; accessible to all components and environments. Ex:
    // replicas: 4,
  },
  components: {
    // Parameters for the "build" component (the image release workflow).
    build: {
      bucket: "kubeflow-ci_temp",
      cluster: "kubeflow-testing",
      dockerfile: "Dockerfile",
      dockerfileDir: "kubeflow/xgboost-operator/",
      extra_args: "null",
      extra_repos: "kubeflow/testing@HEAD",
      gcpCredentialsSecretName: "gcp-credentials",
      image: "xgboost-operator",
      name: "xgboost-operator",
      namespace: "kubeflow-ci",
      nfsVolumeClaim: "nfs-external",
      project: "kubeflow-ci",
      prow_env: "REPO_OWNER=kubeflow,REPO_NAME=xgboost-operator,PULL_BASE_SHA=master",
      registry: "gcr.io/kubeflow-images-public",
      testing_image: "gcr.io/kubeflow-ci/test-worker/test-worker:v20190116-b7abb8d-e3b0c4",
      versionTag: "v1.0",
      zone: "us-central1-a",
    },
  },
}
--------------------------------------------------------------------------------
/test/workflows/components/util.libsonnet:
--------------------------------------------------------------------------------
{
  // convert a list of two items into a map representing an environment variable
  listToMap:: function(v)
    {
      name: v[0],
      value: v[1],
    },

  // Function to turn comma separated list of prow environment variables into a dictionary.
  // Example: "A=1,B=2" -> [{ name: "A", value: "1" }, { name: "B", value: "2" }].
  parseEnv:: function(v)
    local pieces = std.split(v, ",");
    if v != "" && std.length(pieces) > 0 then
      std.map(
        function(i) $.listToMap(std.split(i, "=")),
        std.split(v, ",")
      )
    else [],
}
19 |
--------------------------------------------------------------------------------
/test/workflows/environments/base.libsonnet:
--------------------------------------------------------------------------------
// Base environment: exposes the ksonnet components unchanged so individual
// environments can layer their own overrides on top.
local components = std.extVar("__ksonnet/components");
components + {
  // Insert user-specified overrides here.
}
5 |
--------------------------------------------------------------------------------
/test/workflows/environments/dockerbuild/globals.libsonnet:
--------------------------------------------------------------------------------
// Global parameter overrides for the dockerbuild environment (currently none).
{
}
3 |
--------------------------------------------------------------------------------
/test/workflows/environments/dockerbuild/main.jsonnet:
--------------------------------------------------------------------------------
local base = import "base.libsonnet";
// uncomment if you reference ksonnet-lib
// local k = import "k.libsonnet";

base + {
  // Insert user-specified overrides here. For example, if a component is
  // named "nginx-deployment", you might have something like:
  // "nginx-deployment"+: k.deployment.mixin.metadata.labels({foo: "bar"})
}
9 |
--------------------------------------------------------------------------------
/test/workflows/environments/dockerbuild/params.libsonnet:
--------------------------------------------------------------------------------
// Environment-specific parameter overrides for the dockerbuild environment.
local params = std.extVar("__ksonnet/params");
local globals = import "globals.libsonnet";
local envParams = params + {
  components +: {
    // Insert component parameter overrides here. Ex:
    // guestbook +: {
    //   name: "guestbook-dev",
    //   replicas: params.global.replicas,
    // },
  },
};

{
  components: {
    // Merge the environment globals into every component's parameters.
    [x]: envParams.components[x] + globals, for x in std.objectFields(envParams.components)
  },
}
18 |
--------------------------------------------------------------------------------
/test/workflows/vendor/kubeflow/automation/parts.yaml:
--------------------------------------------------------------------------------
1 | {
2 | "name": "automation",
3 | "apiVersion": "0.0.1",
4 | "kind": "ksonnet.io/parts",
5 | "description": "Toolset for kubeflow community.\n",
6 | "author": "kubeflow-team ",
7 | "contributors": [
8 | ],
9 | "repository": {
10 | "type": "git",
11 | "url": "https://github.com/kubeflow/kubeflow"
12 | },
13 | "bugs": {
14 | "url": "https://github.com/kubeflow/kubeflow/issues"
15 | },
16 | "keywords": [
17 | "kubernetes",
18 | "kubeflow",
19 | "release"
20 | ],
  "license": "Apache 2.0"
22 | }
23 |
--------------------------------------------------------------------------------
/test/workflows/vendor/kubeflow/automation/prototypes/release.jsonnet:
--------------------------------------------------------------------------------
// @apiVersion 0.1
// @name release
// @description image auto release workflow.
// @shortDescription setup auto release of any image in 30 min.
// @param name string Name to give to each of the components
// @param image string Image name to give to released image
// @param dockerfileDir string Path to dockerfile used for releasing image
// @optionalParam cluster string kubeflow-releasing Self explanatory.
// @optionalParam gcpCredentialsSecretName string gcp-credentials Name of GCP credential stored in cluster as a secret.
// @optionalParam namespace string kubeflow-releasing Self explanatory.
// @optionalParam nfsVolumeClaim string nfs-external Volumn name.
// @optionalParam project string kubeflow-releasing Self explanatory.
// @optionalParam registry string gcr.io/kubeflow-images-public Registry where the image will be pushed to.
// @optionalParam testing_image string gcr.io/kubeflow-releasing/worker:latest The image where we run the workflow.
// @optionalParam zone string us-central1-a GKE zone.
// @optionalParam bucket string kubeflow-releasing-artifacts GCS bucket storing artifacts.
// @optionalParam prow_env string REPO_OWNER=kubeflow,REPO_NAME=kubeflow,PULL_BASE_SHA=master Self explanatory.
// @optionalParam versionTag string latest Tag for released image.
// @optionalParam dockerfile string Dockerfile Name of your dockerfile.
// @optionalParam extra_repos string kubeflow/testing@HEAD Append your repo here (; separated) if your image is not in kubeflow.
// @optionalParam extra_args string null args for your build_image.sh

// NOTE(review): "params" below is not defined in this file — it is presumably
// injected by the ksonnet prototype machinery from the @param/@optionalParam
// directives above; confirm against the ksonnet prototype documentation.
local k = import "k.libsonnet";
local release = import "kubeflow/automation/release.libsonnet";
// Treat the literal string "null" as "no extra args".
local updatedParams = params {
  extra_args: if params.extra_args == "null" then "" else " " + params.extra_args,
};

std.prune(k.core.v1.list.new(release.parts(updatedParams.namespace, updatedParams.name, overrides=updatedParams).release))
30 |
--------------------------------------------------------------------------------
/test/workflows/vendor/kubeflow/automation/release.libsonnet:
--------------------------------------------------------------------------------
// Library producing the Argo "release" workflow that checks out the repos,
// builds/pushes the operator image via docker-in-docker, and copies artifacts
// to GCS. Consumed by build.jsonnet and the release prototype.
{
  // convert a list of two items into a map representing an environment variable
  listToMap:: function(v)
    {
      name: v[0],
      value: v[1],
    },

  // Function to turn comma separated list of prow environment variables into a dictionary.
  // Example: "A=1,B=2" -> [{ name: "A", value: "1" }, { name: "B", value: "2" }].
  parseEnv:: function(v)
    local pieces = std.split(v, ",");
    if v != "" && std.length(pieces) > 0 then
      std.map(
        function(i) $.listToMap(std.split(i, "=")),
        std.split(v, ",")
      )
    else [],


  // Default parameters.
  // NOTE(review): params.dockerfileDir is referenced below but has no default
  // here, so callers must always supply it in overrides — confirm.
  defaultParams:: {
    bucket: "kubeflow-releasing-artifacts",
    commit: "master",
    // Name of the secret containing GCP credentials.
    gcpCredentialsSecretName: "kubeflow-testing-credentials",
    name: "placeholder",
    namespace: "kubeflow-releasing",
    // The name of the NFS volume claim to use for test files.
    nfsVolumeClaim: "nfs-external",
    prow_env: "REPO_OWNER=kubeflow,REPO_NAME=kubeflow,PULL_BASE_SHA=master",
    registry: "gcr.io/kubeflow-images-public",
    versionTag: "latest",
    // The default image to use for the steps in the Argo workflow.
    testing_image: "gcr.io/kubeflow-ci/worker:latest",
    project: "kubeflow-releasing",
    cluster: "kubeflow-releasing",
    zone: "us-central1-a",
    image: "default-should-not-exist",
    dockerfile: "Dockerfile",
    extra_repos: "kubeflow/testing@HEAD",
    extra_args: "",
  },

  parts(namespace, name, overrides={}):: {
    // Workflow to release image.
    release::
      local params = $.defaultParams + overrides;

      local namespace = params.namespace;
      local testing_image = params.testing_image;
      local project = params.project;
      local cluster = params.cluster;
      local zone = params.zone;
      local name = params.name;

      local prow_env = $.parseEnv(params.prow_env);
      local bucket = params.bucket;

      local stepsNamespace = name;
      // mountPath is the directory where the volume to store the test data
      // should be mounted.
      local mountPath = "/mnt/" + "test-data-volume";
      // testDir is the root directory for all data for a particular test run.
      local testDir = mountPath + "/" + name;
      // outputDir is the directory to sync to GCS to contain the output for this job.
      local outputDir = testDir + "/output";
      local artifactsDir = outputDir + "/artifacts";
      // Source directory where all repos should be checked out
      local srcRootDir = testDir + "/src";
      // The directory containing the kubeflow/kubeflow repo
      local srcDir = srcRootDir + "/kubeflow/kubeflow";
      // The name to use for the volume to use to contain test data.
      local dataVolume = "kubeflow-test-volume";
      local kubeflowPy = srcDir;
      // The directory within the kubeflow_testing submodule containing
      // py scripts to use.
      local kubeflowTestingPy = srcRootDir + "/kubeflow/testing/py";

      // Location where build_image.sh
      local imageDir = srcRootDir + "/" + params.dockerfileDir;

      local releaseImage = params.registry + "/" + params.image;

      // Build an Argo template to execute a particular command.
      // step_name: Name for the template
      // command: List to pass as the container command.
      local buildTemplate(step_name, command, env_vars=[], sidecars=[]) = {
        name: step_name,
        container: {
          command: command,
          image: testing_image,
          env: [
            {
              // Add the source directories to the python path.
              name: "PYTHONPATH",
              value: kubeflowPy + ":" + kubeflowTestingPy,
            },
            {
              name: "GOOGLE_APPLICATION_CREDENTIALS",
              value: "/secret/gcp-credentials/key.json",
            },
            {
              name: "GITHUB_TOKEN",
              valueFrom: {
                secretKeyRef: {
                  name: "github-token",
                  key: "github_token",
                },
              },
            },
            {
              name: "GCP_PROJECT",
              value: project,
            },
            {
              name: "GCP_REGISTRY",
              value: params.registry,
            },
          ] + prow_env + env_vars,
          resources: {
            requests: {
              memory: "2Gi",
              cpu: "2",
            },
            limits: {
              memory: "32Gi",
              cpu: "16",
            },
          },
          volumeMounts: [
            {
              name: dataVolume,
              mountPath: mountPath,
            },
            {
              name: "github-token",
              mountPath: "/secret/github-token",
            },
            {
              name: "gcp-credentials",
              mountPath: "/secret/gcp-credentials",
            },
          ],
        },
        sidecars: sidecars,
      };  // buildTemplate


      // Template that runs build_image.sh against a docker-in-docker sidecar
      // (DOCKER_HOST points at the privileged "dind" sidecar below).
      local buildImageTemplate(step_name, imageDir, dockerfile, image) =
        buildTemplate(
          step_name,
          [
            // We need to explicitly specify bash because
            // build_image.sh is not in the container its a volume mounted file.
            "/bin/bash",
            "-c",
            imageDir + "/build_image.sh "
            + imageDir + "/" + dockerfile + " "
            + image + " "
            + params.versionTag
            + params.extra_args,
          ],
          [
            {
              name: "DOCKER_HOST",
              value: "127.0.0.1",
            },
          ],
          [{
            name: "dind",
            image: "docker:17.10-dind",
            securityContext: {
              privileged: true,
            },
            resources: {
              requests: {
                memory: "2Gi",
                cpu: "2",
              },
              limits: {
                memory: "32Gi",
                cpu: "16",
              },
            },
            mirrorVolumeMounts: true,
          }],
        );  // buildImageTemplate
      {
        apiVersion: "argoproj.io/v1alpha1",
        kind: "Workflow",
        metadata: {
          name: name,
          namespace: namespace,
        },
        spec: {
          entrypoint: "release",
          volumes: [
            {
              name: "github-token",
              secret: {
                secretName: "github-token",
              },
            },
            {
              name: "gcp-credentials",
              secret: {
                secretName: params.gcpCredentialsSecretName,
              },
            },
            {
              name: dataVolume,
              persistentVolumeClaim: {
                claimName: params.nfsVolumeClaim,
              },
            },
          ],  // volumes

          // onExit specifies the template that should always run when the workflow completes.
          onExit: "exit-handler",

          templates: [
            {
              name: "release",
              dag: {
                tasks: [
                  {
                    name: "checkout",
                    template: "checkout",
                  },
                  {
                    name: "image-build-release",
                    template: "image-build-release",
                    dependencies: ["checkout"],
                  },
                ],  // tasks
              },  //dag
            },  // release
            {
              name: "exit-handler",
              steps: [
                [{
                  name: "copy-artifacts",
                  template: "copy-artifacts",
                }],
              ],
            },
            {
              name: "checkout",
              container: {
                command: [
                  "/usr/local/bin/checkout.sh",
                ],
                args: [
                  srcRootDir,
                ],
                env: prow_env + [{
                  name: "EXTRA_REPOS",
                  value: params.extra_repos,
                }],
                image: testing_image,
                volumeMounts: [
                  {
                    name: dataVolume,
                    mountPath: mountPath,
                  },
                ],
              },
            },  // checkout

            buildImageTemplate("image-build-release", imageDir, params.dockerfile, releaseImage),

            buildTemplate(
              "copy-artifacts",
              [
                "python",
                "-m",
                "kubeflow.testing.prow_artifacts",
                "--artifacts_dir=" + outputDir,
                "copy_artifacts",
                "--bucket=" + bucket,
              ]
            ),  // copy-artifacts
          ],  // templates
        },
      },  // release
  },  // parts
}
288 |
--------------------------------------------------------------------------------