├── .assets └── dashboard.png ├── .gitignore ├── LICENSE ├── README.md ├── dashboards └── model.json ├── kubernetes ├── load_tests │ ├── locust_master.yaml │ └── locust_worker.yaml └── models │ └── wine_quality.yaml ├── load_test ├── Dockerfile ├── locustfile.py └── requirements.txt └── model ├── Dockerfile ├── __init__.py ├── app ├── api.py ├── monitoring.py └── schemas.py ├── artifacts ├── model.joblib └── scaler.joblib ├── requirements.txt ├── setup.py └── train.py /.assets/dashboard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeremyjordan/ml-monitoring/7cf0f336bd4cbbfd6587fd24edb84ff84620c26b/.assets/dashboard.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project 2 | .vscode/settings.json 3 | .DS_Store 4 | *.ipynb 5 | .ipynb_checkpoints 6 | secrets.json 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | 14 | # Distribution / packaging 15 | *.egg-info/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jeremy Jordan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ml-monitoring 2 | *Jeremy Jordan* 3 | 4 | This repository provides an example setup for monitoring an ML system deployed on Kubernetes. 5 | 6 | Blog post: https://www.jeremyjordan.me/ml-monitoring/ 7 | 8 | Components: 9 | - ML model served via `FastAPI` 10 | - Export server metrics via `prometheus-fastapi-instrumentator` 11 | - Simulate production traffic via `locust` 12 | - Monitor and store metrics via `Prometheus` 13 | - Visualize metrics via `Grafana` 14 | 15 | ![](.assets/dashboard.png) 16 | 17 | ## Setup 18 | 19 | 1. Ensure you can connect to a Kubernetes cluster and have [`kubectl`](https://kubernetes.io/docs/tasks/tools/install-kubectl/) and [`helm`](https://helm.sh/docs/intro/install/) installed. 20 | - You can easily spin up a Kubernetes cluster on your local machine using [minikube](https://minikube.sigs.k8s.io/docs/start/). 
21 | ``` 22 | minikube start --driver=docker --memory 4g --nodes 2 23 | ``` 24 | 25 | 2. Deploy Prometheus and Grafana onto the cluster using the [community Helm chart](https://github.com/prometheus-community/helm-charts/tree/main/charts/kube-prometheus-stack). 26 | ``` 27 | kubectl create namespace monitoring 28 | helm repo add prometheus-community https://prometheus-community.github.io/helm-charts 29 | helm install prometheus-stack prometheus-community/kube-prometheus-stack -n monitoring 30 | ``` 31 | 3. Verify the resources were deployed successfully. 32 | ``` 33 | kubectl get all -n monitoring 34 | ``` 35 | 4. Connect to the Grafana dashboard. 36 | ``` 37 | kubectl port-forward svc/prometheus-stack-grafana 8000:80 -n monitoring 38 | ``` 39 | - Go to http://127.0.0.1:8000/ 40 | - Log in with the credentials: 41 | - Username: admin 42 | - Password: prom-operator 43 | - (This password can be configured in the Helm chart `values.yaml` file) 44 | 5. Import the model dashboard. 45 | - On the left sidebar, click the "+" and select "Import". 46 | - Copy and paste the JSON defined in `dashboards/model.json` in the text area. 47 | 48 | ## Deploy a model 49 | 50 | This repository includes an example REST service which exposes an ML model trained on the [UCI Wine Quality dataset](https://archive.ics.uci.edu/ml/datasets/wine+quality). 51 | 52 | You can launch the service on Kubernetes by running: 53 | 54 | ``` 55 | kubectl apply -f kubernetes/models/ 56 | ``` 57 | 58 | You can also build and run the Docker container locally. 59 | 60 | ``` 61 | docker build -t wine-quality-model -f model/Dockerfile model/ 62 | docker run -d -p 3000:80 -e ENABLE_METRICS=true wine-quality-model 63 | ``` 64 | 65 | > **Note:** In order for Prometheus to scrape metrics from this service, we need to define a `ServiceMonitor` resource. This resource must have the label `release: prometheus-stack` in order to be discovered. This is configured in the `Prometheus` resource spec via the `serviceMonitorSelector` attribute. 66 | 67 | You can verify the label required by running: 68 | ``` 69 | kubectl get prometheuses.monitoring.coreos.com prometheus-stack-kube-prom-prometheus -n monitoring -o yaml 70 | ``` 71 | 72 | ## Simulate production traffic 73 | 74 | We can simulate production traffic using a Python load testing tool called [`locust`](https://locust.io/). This will make HTTP requests to our model server and provide us with data to view in the monitoring dashboard. 75 | 76 | You can begin the load test by running: 77 | ``` 78 | kubectl apply -f kubernetes/load_tests/ 79 | ``` 80 | By default, production traffic will be simulated for a duration of 5 minutes. This can be changed by updating the image arguments in the `kubernetes/load_tests/locust_master.yaml` manifest. 81 | 82 | You can also modify the community [Helm chart](https://github.com/deliveryhero/helm-charts/tree/master/stable/locust/templates) instead of using the manifests defined in this repo. 83 | 84 | 85 | ## Uploading new images 86 | 87 | This process can eventually be automated with a Github action, but remains manual for now. 88 | 89 | 1. Obtain a personal access token to connect with the Github container registry. 90 | ``` 91 | echo "INSERT_TOKEN_HERE" >> ~/.github/cr_token 92 | ``` 93 | 2. Authenticate with the Github container registry. 94 | ``` 95 | cat ~/.github/cr_token | docker login ghcr.io -u jeremyjordan --password-stdin 96 | ``` 97 | 3. Build and tag new Docker images. 
98 | ``` 99 | MODEL_TAG=0.3 100 | docker build -t wine-quality-model:$MODEL_TAG -f model/Dockerfile model/ 101 | docker tag wine-quality-model:$MODEL_TAG ghcr.io/jeremyjordan/wine-quality-model:$MODEL_TAG 102 | ``` 103 | 104 | ``` 105 | LOAD_TAG=0.2 106 | docker build -t locust-load-test:$LOAD_TAG -f load_test/Dockerfile load_test/ 107 | docker tag locust-load-test:$LOAD_TAG ghcr.io/jeremyjordan/locust-load-test:$LOAD_TAG 108 | ``` 109 | 4. Push Docker images to the container registry. 110 | ``` 111 | docker push ghcr.io/jeremyjordan/wine-quality-model:$MODEL_TAG 112 | docker push ghcr.io/jeremyjordan/locust-load-test:$LOAD_TAG 113 | ``` 114 | 5. Update Kubernetes manifests to use the new image tag. 115 | 116 | ## Teardown instructions 117 | 118 | To stop the model REST server, run: 119 | ``` 120 | kubectl delete -f kubernetes/models/ 121 | ``` 122 | 123 | To stop the load tests, run: 124 | ``` 125 | kubectl delete -f kubernetes/load_tests/ 126 | ``` 127 | 128 | To remove the Prometheus stack, run: 129 | ``` 130 | helm uninstall prometheus-stack -n monitoring 131 | ``` 132 | -------------------------------------------------------------------------------- /dashboards/model.json: -------------------------------------------------------------------------------- 1 | { 2 | "annotations": { 3 | "list": [ 4 | { 5 | "builtIn": 1, 6 | "datasource": "-- Grafana --", 7 | "enable": true, 8 | "hide": true, 9 | "iconColor": "rgba(0, 211, 255, 1)", 10 | "name": "Annotations & Alerts", 11 | "type": "dashboard" 12 | } 13 | ] 14 | }, 15 | "editable": true, 16 | "gnetId": null, 17 | "graphTooltip": 0, 18 | "id": 25, 19 | "links": [], 20 | "panels": [ 21 | { 22 | "collapsed": false, 23 | "datasource": null, 24 | "gridPos": { 25 | "h": 1, 26 | "w": 24, 27 | "x": 0, 28 | "y": 0 29 | }, 30 | "id": 18, 31 | "panels": [], 32 | "title": "Model Metrics", 33 | "type": "row" 34 | }, 35 | { 36 | "datasource": null, 37 | "fieldConfig": { 38 | "defaults": { 39 | "custom": {} 40 | }, 41 | "overrides": [] 42 | }, 43 | "gridPos": { 44 | "h": 9, 45 | "w": 5, 46 | "x": 0, 47 | "y": 1 48 | }, 49 | "id": 10, 50 | "options": { 51 | "content": "# Model Predictions\n\nThis model is responsible for predicting \nthe quality of wine given a number of chemical \nattributes. 
The scores range from 0-10 where\n0 is the worst and 10 is the best.", 52 | "mode": "markdown" 53 | }, 54 | "pluginVersion": "7.1.0", 55 | "timeFrom": null, 56 | "timeShift": null, 57 | "title": "", 58 | "type": "text" 59 | }, 60 | { 61 | "datasource": null, 62 | "fieldConfig": { 63 | "defaults": { 64 | "custom": {}, 65 | "mappings": [], 66 | "max": 10, 67 | "min": 0, 68 | "noValue": "No Data", 69 | "thresholds": { 70 | "mode": "absolute", 71 | "steps": [ 72 | { 73 | "color": "red", 74 | "value": null 75 | }, 76 | { 77 | "color": "yellow", 78 | "value": 1 79 | }, 80 | { 81 | "color": "green", 82 | "value": 3 83 | }, 84 | { 85 | "color": "#EAB839", 86 | "value": 7 87 | }, 88 | { 89 | "color": "red", 90 | "value": 9 91 | } 92 | ] 93 | } 94 | }, 95 | "overrides": [] 96 | }, 97 | "gridPos": { 98 | "h": 9, 99 | "w": 4, 100 | "x": 5, 101 | "y": 1 102 | }, 103 | "id": 8, 104 | "options": { 105 | "orientation": "auto", 106 | "reduceOptions": { 107 | "calcs": [ 108 | "mean" 109 | ], 110 | "fields": "", 111 | "values": false 112 | }, 113 | "showThresholdLabels": false, 114 | "showThresholdMarkers": true 115 | }, 116 | "pluginVersion": "7.2.1", 117 | "targets": [ 118 | { 119 | "expr": "sum(rate(fastapi_regression_model_output_sum[30m])) / sum(rate(fastapi_regression_model_output_count[30m]))", 120 | "interval": "", 121 | "legendFormat": "", 122 | "refId": "A" 123 | } 124 | ], 125 | "timeFrom": null, 126 | "timeShift": null, 127 | "title": "Model Score (30m avg)", 128 | "type": "gauge" 129 | }, 130 | { 131 | "cards": { 132 | "cardPadding": null, 133 | "cardRound": null 134 | }, 135 | "color": { 136 | "cardColor": "#F2495C", 137 | "colorScale": "sqrt", 138 | "colorScheme": "interpolateViridis", 139 | "exponent": 0.5, 140 | "mode": "spectrum" 141 | }, 142 | "dataFormat": "tsbuckets", 143 | "datasource": null, 144 | "description": "Average per second count of predictions which fall into each bucket.", 145 | "fieldConfig": { 146 | "defaults": { 147 | "custom": { 148 | "align": null, 149 | "filterable": false 150 | }, 151 | "mappings": [], 152 | "thresholds": { 153 | "mode": "absolute", 154 | "steps": [ 155 | { 156 | "color": "green", 157 | "value": null 158 | }, 159 | { 160 | "color": "red", 161 | "value": 80 162 | } 163 | ] 164 | } 165 | }, 166 | "overrides": [] 167 | }, 168 | "gridPos": { 169 | "h": 9, 170 | "w": 15, 171 | "x": 9, 172 | "y": 1 173 | }, 174 | "heatmap": {}, 175 | "hideZeroBuckets": true, 176 | "highlightCards": true, 177 | "id": 6, 178 | "legend": { 179 | "show": true 180 | }, 181 | "pluginVersion": "7.2.1", 182 | "reverseYBuckets": false, 183 | "targets": [ 184 | { 185 | "expr": "sum(rate(fastapi_regression_model_output_bucket[1m])) by (le)", 186 | "format": "heatmap", 187 | "interval": "", 188 | "legendFormat": "{{le}}", 189 | "refId": "A" 190 | } 191 | ], 192 | "timeFrom": null, 193 | "timeShift": null, 194 | "title": "Model Prediction Distribution", 195 | "tooltip": { 196 | "show": true, 197 | "showHistogram": false 198 | }, 199 | "type": "heatmap", 200 | "xAxis": { 201 | "show": true 202 | }, 203 | "xBucketNumber": null, 204 | "xBucketSize": null, 205 | "yAxis": { 206 | "decimals": null, 207 | "format": "short", 208 | "logBase": 1, 209 | "max": null, 210 | "min": null, 211 | "show": true, 212 | "splitFactor": null 213 | }, 214 | "yBucketBound": "auto", 215 | "yBucketNumber": null, 216 | "yBucketSize": null 217 | }, 218 | { 219 | "collapsed": false, 220 | "datasource": null, 221 | "gridPos": { 222 | "h": 1, 223 | "w": 24, 224 | "x": 0, 225 | "y": 10 226 | }, 227 | "id": 16, 228 | 
"panels": [], 229 | "title": "Service Metrics", 230 | "type": "row" 231 | }, 232 | { 233 | "datasource": null, 234 | "fieldConfig": { 235 | "defaults": { 236 | "custom": {} 237 | }, 238 | "overrides": [] 239 | }, 240 | "gridPos": { 241 | "h": 9, 242 | "w": 5, 243 | "x": 0, 244 | "y": 11 245 | }, 246 | "id": 12, 247 | "options": { 248 | "content": "# Request Throughput\n\n2xx status codes denote a successful model\nprediction.\n\n4xx status codes are common when clients\nPOST data that does not meet the expected\nschema.\n\nWe calculate the success rate as the percentage\nof HTTP requests that do not return a 5xx code.", 249 | "mode": "markdown" 250 | }, 251 | "pluginVersion": "7.1.0", 252 | "timeFrom": null, 253 | "timeShift": null, 254 | "title": "", 255 | "type": "text" 256 | }, 257 | { 258 | "datasource": null, 259 | "fieldConfig": { 260 | "defaults": { 261 | "custom": {}, 262 | "mappings": [], 263 | "max": 1, 264 | "min": 0.8, 265 | "noValue": "No Data", 266 | "thresholds": { 267 | "mode": "absolute", 268 | "steps": [ 269 | { 270 | "color": "red", 271 | "value": null 272 | }, 273 | { 274 | "color": "yellow", 275 | "value": 0.95 276 | }, 277 | { 278 | "color": "green", 279 | "value": 0.99 280 | } 281 | ] 282 | }, 283 | "unit": "percentunit" 284 | }, 285 | "overrides": [] 286 | }, 287 | "gridPos": { 288 | "h": 9, 289 | "w": 4, 290 | "x": 5, 291 | "y": 11 292 | }, 293 | "id": 14, 294 | "options": { 295 | "orientation": "auto", 296 | "reduceOptions": { 297 | "calcs": [ 298 | "mean" 299 | ], 300 | "fields": "", 301 | "values": false 302 | }, 303 | "showThresholdLabels": false, 304 | "showThresholdMarkers": true 305 | }, 306 | "pluginVersion": "7.2.1", 307 | "targets": [ 308 | { 309 | "expr": "sum(rate(fastapi_http_requests_total{status!=\"5xx\"}[30m])) / sum(rate(fastapi_http_requests_total[30m]))", 310 | "interval": "", 311 | "legendFormat": "", 312 | "refId": "A" 313 | } 314 | ], 315 | "timeFrom": null, 316 | "timeShift": null, 317 | "title": "HTTP Success Rate (30m avg)", 318 | "type": "gauge" 319 | }, 320 | { 321 | "aliasColors": {}, 322 | "bars": false, 323 | "dashLength": 10, 324 | "dashes": false, 325 | "datasource": null, 326 | "fieldConfig": { 327 | "defaults": { 328 | "custom": {}, 329 | "unit": "reqps" 330 | }, 331 | "overrides": [] 332 | }, 333 | "fill": 10, 334 | "fillGradient": 0, 335 | "gridPos": { 336 | "h": 9, 337 | "w": 15, 338 | "x": 9, 339 | "y": 11 340 | }, 341 | "hiddenSeries": false, 342 | "id": 2, 343 | "legend": { 344 | "alignAsTable": false, 345 | "avg": false, 346 | "current": false, 347 | "hideEmpty": false, 348 | "hideZero": false, 349 | "max": false, 350 | "min": false, 351 | "show": true, 352 | "total": false, 353 | "values": false 354 | }, 355 | "lines": true, 356 | "linewidth": 1, 357 | "nullPointMode": "null", 358 | "options": { 359 | "alertThreshold": true 360 | }, 361 | "percentage": false, 362 | "pluginVersion": "7.2.1", 363 | "pointradius": 2, 364 | "points": false, 365 | "renderer": "flot", 366 | "seriesOverrides": [], 367 | "spaceLength": 10, 368 | "stack": true, 369 | "steppedLine": false, 370 | "targets": [ 371 | { 372 | "expr": "sum(rate(fastapi_http_requests_total[1m])) by (status) ", 373 | "interval": "", 374 | "legendFormat": "", 375 | "refId": "A" 376 | } 377 | ], 378 | "thresholds": [], 379 | "timeFrom": null, 380 | "timeRegions": [], 381 | "timeShift": null, 382 | "title": "Requests", 383 | "tooltip": { 384 | "shared": true, 385 | "sort": 0, 386 | "value_type": "individual" 387 | }, 388 | "type": "graph", 389 | "xaxis": { 390 | "buckets": 
null, 391 | "mode": "time", 392 | "name": null, 393 | "show": true, 394 | "values": [] 395 | }, 396 | "yaxes": [ 397 | { 398 | "format": "reqps", 399 | "label": null, 400 | "logBase": 1, 401 | "max": null, 402 | "min": "0", 403 | "show": true 404 | }, 405 | { 406 | "format": "short", 407 | "label": null, 408 | "logBase": 1, 409 | "max": null, 410 | "min": null, 411 | "show": true 412 | } 413 | ], 414 | "yaxis": { 415 | "align": false, 416 | "alignLevel": null 417 | } 418 | }, 419 | { 420 | "datasource": null, 421 | "description": "", 422 | "fieldConfig": { 423 | "defaults": { 424 | "custom": {} 425 | }, 426 | "overrides": [] 427 | }, 428 | "gridPos": { 429 | "h": 9, 430 | "w": 5, 431 | "x": 0, 432 | "y": 20 433 | }, 434 | "id": 20, 435 | "options": { 436 | "content": "# Request Latencies\n\nWe have two SLOs for API latency:\n- Average latency SLO is breached if p50\nlatencies exceed 50ms for 5 minutes.\n- Extreme latency SLO is breached if p99\nlatencies exceed 300ms for 5 minutes. ", 437 | "mode": "markdown" 438 | }, 439 | "pluginVersion": "7.1.0", 440 | "timeFrom": null, 441 | "timeShift": null, 442 | "title": "", 443 | "type": "text" 444 | }, 445 | { 446 | "datasource": null, 447 | "fieldConfig": { 448 | "defaults": { 449 | "custom": {}, 450 | "mappings": [], 451 | "min": 0, 452 | "thresholds": { 453 | "mode": "absolute", 454 | "steps": [ 455 | { 456 | "color": "green", 457 | "value": null 458 | }, 459 | { 460 | "color": "red", 461 | "value": 80 462 | } 463 | ] 464 | } 465 | }, 466 | "overrides": [] 467 | }, 468 | "gridPos": { 469 | "h": 9, 470 | "w": 4, 471 | "x": 5, 472 | "y": 20 473 | }, 474 | "id": 23, 475 | "options": { 476 | "colorMode": "value", 477 | "graphMode": "area", 478 | "justifyMode": "auto", 479 | "orientation": "auto", 480 | "reduceOptions": { 481 | "calcs": [ 482 | "last" 483 | ], 484 | "fields": "", 485 | "values": false 486 | }, 487 | "textMode": "auto" 488 | }, 489 | "pluginVersion": "7.2.1", 490 | "targets": [ 491 | { 492 | "expr": "sum(fastapi_inprogress)", 493 | "instant": false, 494 | "interval": "", 495 | "legendFormat": "", 496 | "refId": "A" 497 | } 498 | ], 499 | "timeFrom": null, 500 | "timeShift": null, 501 | "title": "Requests In Progress", 502 | "type": "stat" 503 | }, 504 | { 505 | "aliasColors": {}, 506 | "bars": false, 507 | "dashLength": 10, 508 | "dashes": false, 509 | "datasource": null, 510 | "fieldConfig": { 511 | "defaults": { 512 | "custom": {}, 513 | "unit": "s" 514 | }, 515 | "overrides": [] 516 | }, 517 | "fill": 1, 518 | "fillGradient": 0, 519 | "gridPos": { 520 | "h": 9, 521 | "w": 15, 522 | "x": 9, 523 | "y": 20 524 | }, 525 | "hiddenSeries": false, 526 | "id": 4, 527 | "legend": { 528 | "avg": false, 529 | "current": false, 530 | "max": false, 531 | "min": false, 532 | "show": true, 533 | "total": false, 534 | "values": false 535 | }, 536 | "lines": true, 537 | "linewidth": 1, 538 | "nullPointMode": "null", 539 | "options": { 540 | "alertThreshold": true 541 | }, 542 | "percentage": false, 543 | "pluginVersion": "7.2.1", 544 | "pointradius": 2, 545 | "points": false, 546 | "renderer": "flot", 547 | "seriesOverrides": [], 548 | "spaceLength": 10, 549 | "stack": false, 550 | "steppedLine": false, 551 | "targets": [ 552 | { 553 | "expr": "histogram_quantile(0.99, \n sum(\n rate(fastapi_http_request_duration_seconds_bucket[1m])\n ) by (le)\n)", 554 | "interval": "", 555 | "legendFormat": "99th percentile", 556 | "refId": "A" 557 | }, 558 | { 559 | "expr": "histogram_quantile(0.95, \n sum(\n 
rate(fastapi_http_request_duration_seconds_bucket[1m])\n ) by (le)\n)", 560 | "interval": "", 561 | "legendFormat": "95th percentile", 562 | "refId": "B" 563 | }, 564 | { 565 | "expr": "histogram_quantile(0.50, \n sum(\n rate(fastapi_http_request_duration_seconds_bucket[1m])\n ) by (le)\n)", 566 | "interval": "", 567 | "legendFormat": "50th percentile", 568 | "refId": "C" 569 | } 570 | ], 571 | "thresholds": [], 572 | "timeFrom": null, 573 | "timeRegions": [], 574 | "timeShift": null, 575 | "title": "Latency", 576 | "tooltip": { 577 | "shared": true, 578 | "sort": 0, 579 | "value_type": "individual" 580 | }, 581 | "type": "graph", 582 | "xaxis": { 583 | "buckets": null, 584 | "mode": "time", 585 | "name": null, 586 | "show": true, 587 | "values": [] 588 | }, 589 | "yaxes": [ 590 | { 591 | "format": "s", 592 | "label": null, 593 | "logBase": 1, 594 | "max": null, 595 | "min": null, 596 | "show": true 597 | }, 598 | { 599 | "format": "short", 600 | "label": null, 601 | "logBase": 1, 602 | "max": null, 603 | "min": null, 604 | "show": true 605 | } 606 | ], 607 | "yaxis": { 608 | "align": false, 609 | "alignLevel": null 610 | } 611 | }, 612 | { 613 | "datasource": null, 614 | "fieldConfig": { 615 | "defaults": { 616 | "custom": {} 617 | }, 618 | "overrides": [] 619 | }, 620 | "gridPos": { 621 | "h": 9, 622 | "w": 5, 623 | "x": 0, 624 | "y": 29 625 | }, 626 | "id": 29, 627 | "options": { 628 | "content": "# Request Size\n\nThis visualizes the request body sizes for\nthe model predict route. We expect the request\nbody to contain a dictionary with 11 keys with \nall values being floats. ", 629 | "mode": "markdown" 630 | }, 631 | "pluginVersion": "7.1.0", 632 | "timeFrom": null, 633 | "timeShift": null, 634 | "title": "", 635 | "type": "text" 636 | }, 637 | { 638 | "datasource": null, 639 | "fieldConfig": { 640 | "defaults": { 641 | "custom": {}, 642 | "mappings": [], 643 | "noValue": "No Data", 644 | "thresholds": { 645 | "mode": "absolute", 646 | "steps": [ 647 | { 648 | "color": "green", 649 | "value": null 650 | }, 651 | { 652 | "color": "#EAB839", 653 | "value": 500 654 | } 655 | ] 656 | }, 657 | "unit": "decbytes" 658 | }, 659 | "overrides": [] 660 | }, 661 | "gridPos": { 662 | "h": 9, 663 | "w": 4, 664 | "x": 5, 665 | "y": 29 666 | }, 667 | "id": 31, 668 | "options": { 669 | "colorMode": "value", 670 | "graphMode": "none", 671 | "justifyMode": "auto", 672 | "orientation": "auto", 673 | "reduceOptions": { 674 | "calcs": [ 675 | "mean" 676 | ], 677 | "fields": "", 678 | "values": false 679 | }, 680 | "textMode": "auto" 681 | }, 682 | "pluginVersion": "7.2.1", 683 | "targets": [ 684 | { 685 | "expr": "sum(rate(fastapi_http_request_size_bytes_sum{handler=\"/predict\"}[30m])) / sum(rate(fastapi_http_request_size_bytes_count{handler=\"/predict\"}[30m]))", 686 | "interval": "", 687 | "legendFormat": "", 688 | "refId": "A" 689 | } 690 | ], 691 | "timeFrom": null, 692 | "timeShift": null, 693 | "title": "Request Size (30m avg)", 694 | "type": "stat" 695 | }, 696 | { 697 | "aliasColors": {}, 698 | "bars": false, 699 | "dashLength": 10, 700 | "dashes": false, 701 | "datasource": null, 702 | "fieldConfig": { 703 | "defaults": { 704 | "custom": {}, 705 | "unit": "decbytes" 706 | }, 707 | "overrides": [] 708 | }, 709 | "fill": 1, 710 | "fillGradient": 0, 711 | "gridPos": { 712 | "h": 9, 713 | "w": 15, 714 | "x": 9, 715 | "y": 29 716 | }, 717 | "hiddenSeries": false, 718 | "id": 30, 719 | "legend": { 720 | "avg": false, 721 | "current": false, 722 | "max": false, 723 | "min": false, 724 | "show": true, 
725 | "total": false, 726 | "values": false 727 | }, 728 | "lines": true, 729 | "linewidth": 1, 730 | "nullPointMode": "null", 731 | "options": { 732 | "alertThreshold": true 733 | }, 734 | "percentage": false, 735 | "pluginVersion": "7.2.1", 736 | "pointradius": 2, 737 | "points": false, 738 | "renderer": "flot", 739 | "seriesOverrides": [], 740 | "spaceLength": 10, 741 | "stack": false, 742 | "steppedLine": false, 743 | "targets": [ 744 | { 745 | "expr": "sum(rate(fastapi_http_request_size_bytes_sum{handler=\"/predict\"}[1m])) / sum(rate(fastapi_http_request_size_bytes_count{handler=\"/predict\"}[1m]))", 746 | "interval": "", 747 | "legendFormat": "", 748 | "refId": "A" 749 | } 750 | ], 751 | "thresholds": [], 752 | "timeFrom": null, 753 | "timeRegions": [], 754 | "timeShift": null, 755 | "title": "Request Size", 756 | "tooltip": { 757 | "shared": true, 758 | "sort": 0, 759 | "value_type": "individual" 760 | }, 761 | "type": "graph", 762 | "xaxis": { 763 | "buckets": null, 764 | "mode": "time", 765 | "name": null, 766 | "show": true, 767 | "values": [] 768 | }, 769 | "yaxes": [ 770 | { 771 | "format": "decbytes", 772 | "label": null, 773 | "logBase": 1, 774 | "max": null, 775 | "min": "0", 776 | "show": true 777 | }, 778 | { 779 | "format": "short", 780 | "label": null, 781 | "logBase": 1, 782 | "max": null, 783 | "min": null, 784 | "show": true 785 | } 786 | ], 787 | "yaxis": { 788 | "align": false, 789 | "alignLevel": null 790 | } 791 | }, 792 | { 793 | "collapsed": false, 794 | "datasource": null, 795 | "gridPos": { 796 | "h": 1, 797 | "w": 24, 798 | "x": 0, 799 | "y": 38 800 | }, 801 | "id": 25, 802 | "panels": [], 803 | "title": "Resource Metrics", 804 | "type": "row" 805 | } 806 | ], 807 | "refresh": "5s", 808 | "schemaVersion": 26, 809 | "style": "dark", 810 | "tags": [], 811 | "templating": { 812 | "list": [] 813 | }, 814 | "time": { 815 | "from": "now-5m", 816 | "to": "now" 817 | }, 818 | "timepicker": {}, 819 | "timezone": "", 820 | "title": "Model Dashboard", 821 | "uid": "K1ZW-uxGz", 822 | "version": 2 823 | } -------------------------------------------------------------------------------- /kubernetes/load_tests/locust_master.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: locust-master 5 | labels: 6 | app: locust-master 7 | spec: 8 | replicas: 1 9 | selector: 10 | matchLabels: 11 | app: locust-master 12 | template: 13 | metadata: 14 | labels: 15 | app: locust-master 16 | spec: 17 | containers: 18 | - name: locust 19 | image: ghcr.io/jeremyjordan/locust-load-test:0.2 20 | command: 21 | - locust 22 | args: 23 | - -f 24 | - locustfile.py 25 | - --host 26 | - http://wine-quality-model-service.default 27 | - --master 28 | - --expect-workers=2 29 | - --headless 30 | - -u 31 | - "5" 32 | - -r 33 | - "1" 34 | - --run-time 35 | - 5m 36 | ports: 37 | - name: loc-master-web 38 | containerPort: 8089 39 | protocol: TCP 40 | - name: loc-master-p1 41 | containerPort: 5557 42 | protocol: TCP 43 | - name: loc-master-p2 44 | containerPort: 5558 45 | protocol: TCP 46 | resources: 47 | requests: 48 | memory: 100Mi 49 | cpu: 100m 50 | limits: 51 | memory: 200Mi 52 | cpu: 200m 53 | 54 | --- 55 | kind: Service 56 | apiVersion: v1 57 | metadata: 58 | name: locust-master 59 | labels: 60 | app: locust-master 61 | spec: 62 | ports: 63 | - port: 8089 64 | targetPort: loc-master-web 65 | protocol: TCP 66 | name: loc-master-web 67 | - port: 5557 68 | targetPort: loc-master-p1 69 | protocol: TCP 70 
| name: loc-master-p1 71 | - port: 5558 72 | targetPort: loc-master-p2 73 | protocol: TCP 74 | name: loc-master-p2 75 | selector: 76 | app: locust-master 77 | type: ClusterIP 78 | -------------------------------------------------------------------------------- /kubernetes/load_tests/locust_worker.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: locust-worker 5 | labels: 6 | app: locust-worker 7 | spec: 8 | replicas: 2 9 | selector: 10 | matchLabels: 11 | app: locust-worker 12 | template: 13 | metadata: 14 | labels: 15 | app: locust-worker 16 | spec: 17 | containers: 18 | - name: locust 19 | image: ghcr.io/jeremyjordan/locust-load-test:0.2 20 | command: 21 | - locust 22 | args: 23 | - -f 24 | - locustfile.py 25 | - --worker 26 | - --master-host=locust-master.default 27 | ports: 28 | - name: loc-master-web 29 | containerPort: 8089 30 | protocol: TCP 31 | - name: loc-master-p1 32 | containerPort: 5557 33 | protocol: TCP 34 | - name: loc-master-p2 35 | containerPort: 5558 36 | protocol: TCP 37 | resources: 38 | requests: 39 | memory: 100Mi 40 | cpu: 100m 41 | limits: 42 | memory: 200Mi 43 | cpu: 200m 44 | 45 | --- 46 | kind: Service 47 | apiVersion: v1 48 | metadata: 49 | name: locust-worker 50 | labels: 51 | app: locust-worker 52 | spec: 53 | ports: 54 | - port: 8089 55 | targetPort: loc-master-web 56 | protocol: TCP 57 | name: loc-master-web 58 | - port: 5557 59 | targetPort: loc-master-p1 60 | protocol: TCP 61 | name: loc-master-p1 62 | - port: 5558 63 | targetPort: loc-master-p2 64 | protocol: TCP 65 | name: loc-master-p2 66 | selector: 67 | app: locust-worker 68 | type: ClusterIP 69 | -------------------------------------------------------------------------------- /kubernetes/models/wine_quality.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: wine-quality-model 5 | labels: 6 | app: wine-quality-model 7 | spec: 8 | selector: 9 | matchLabels: 10 | app: wine-quality-model 11 | replicas: 2 12 | strategy: 13 | type: RollingUpdate 14 | template: 15 | metadata: 16 | labels: 17 | app: wine-quality-model 18 | spec: 19 | containers: 20 | - name: fastapi-wine-quality-model 21 | image: ghcr.io/jeremyjordan/wine-quality-model:0.3 22 | imagePullPolicy: Always 23 | env: 24 | - name: ENABLE_METRICS 25 | value: "true" 26 | - name: METRICS_NAMESPACE 27 | value: "fastapi" 28 | - name: METRICS_SUBSYSTEM 29 | value: "" 30 | ports: 31 | - name: app 32 | containerPort: 80 33 | resources: 34 | requests: 35 | memory: 0.5Gi 36 | cpu: "0.25" 37 | limits: 38 | memory: 1Gi 39 | cpu: "0.5" 40 | livenessProbe: 41 | httpGet: 42 | path: /healthcheck 43 | port: 80 44 | 45 | --- 46 | apiVersion: v1 47 | kind: Service 48 | metadata: 49 | name: wine-quality-model-service 50 | labels: 51 | app: wine-quality-model 52 | spec: 53 | selector: 54 | app: wine-quality-model 55 | ports: 56 | - name: app 57 | port: 80 58 | type: ClusterIP 59 | 60 | --- 61 | apiVersion: monitoring.coreos.com/v1 62 | kind: ServiceMonitor 63 | metadata: 64 | name: wine-quality-model-servicemonitor 65 | # must be in the same namespace that Prometheus is running in 66 | namespace: monitoring 67 | labels: 68 | app: wine-quality-model 69 | release: prometheus-stack 70 | spec: 71 | selector: 72 | matchLabels: 73 | app: wine-quality-model 74 | endpoints: 75 | - path: metrics 76 | port: app 77 | interval: 15s 78 | namespaceSelector: 79 | any: true 80 | 
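# The `release: prometheus-stack` label above is what the Prometheus
# serviceMonitorSelector matches on (see the README). One way to confirm the target
# was discovered is to port-forward the Prometheus service and open its targets page;
# the service name below is what the kube-prometheus-stack chart typically creates for
# this release name, so verify it first with `kubectl get svc -n monitoring`:
#
#   kubectl port-forward svc/prometheus-stack-kube-prom-prometheus 9090:9090 -n monitoring
#   # then browse to http://127.0.0.1:9090/targets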
-------------------------------------------------------------------------------- /load_test/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | COPY requirements.txt requirements.txt 3 | RUN pip install -r requirements.txt 4 | COPY locustfile.py locustfile.py 5 | EXPOSE 8089 6 | CMD ["locust", "-f", "locustfile.py", "--host", "http://127.0.0.1:3000"] -------------------------------------------------------------------------------- /load_test/locustfile.py: -------------------------------------------------------------------------------- 1 | """ 2 | Run load tests: 3 | 4 | locust -f load_test/locustfile.py --host http://127.0.0.1:3000 5 | """ 6 | 7 | from locust import HttpUser, task 8 | import pandas as pd 9 | import random 10 | 11 | feature_columns = { 12 | "fixed acidity": "fixed_acidity", 13 | "volatile acidity": "volatile_acidity", 14 | "citric acid": "citric_acid", 15 | "residual sugar": "residual_sugar", 16 | "chlorides": "chlorides", 17 | "free sulfur dioxide": "free_sulfur_dioxide", 18 | "total sulfur dioxide": "total_sulfur_dioxide", 19 | "density": "density", 20 | "pH": "ph", 21 | "sulphates": "sulphates", 22 | "alcohol": "alcohol_pct_vol", 23 | } 24 | dataset = ( 25 | pd.read_csv( 26 | "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv", 27 | delimiter=";", 28 | ) 29 | .rename(columns=feature_columns) 30 | .drop("quality", axis=1) 31 | .to_dict(orient="records") 32 | ) 33 | 34 | 35 | class WinePredictionUser(HttpUser): 36 | @task(1) 37 | def healthcheck(self): 38 | self.client.get("/healthcheck") 39 | 40 | @task(10) 41 | def prediction(self): 42 | record = random.choice(dataset).copy() 43 | self.client.post("/predict", json=record) 44 | 45 | @task(2) 46 | def prediction_bad_value(self): 47 | record = random.choice(dataset).copy() 48 | corrupt_key = random.choice(list(record.keys())) 49 | record[corrupt_key] = "bad data" 50 | self.client.post("/predict", json=record) 51 | -------------------------------------------------------------------------------- /load_test/requirements.txt: -------------------------------------------------------------------------------- 1 | locust==1.4.1 2 | pandas==0.23.3 3 | -------------------------------------------------------------------------------- /model/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.7 2 | COPY requirements.txt requirements.txt 3 | RUN pip install -r requirements.txt 4 | WORKDIR /workdir 5 | COPY setup.py /workdir/setup.py 6 | COPY . /workdir/model 7 | RUN pip install . 
8 | EXPOSE 80 9 | CMD ["uvicorn", "model.app.api:app", "--host", "0.0.0.0", "--port", "80"] 10 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | from pathlib import Path 5 | 6 | 7 | logger = logging.getLogger(__name__) 8 | logging.basicConfig() 9 | 10 | REPO_DIR = Path(__file__).parent.parent 11 | 12 | 13 | if Path(REPO_DIR / "secrets.json").exists(): 14 | logger.info("Reading secrets into environment variables...") 15 | secrets = json.loads(Path(REPO_DIR / "secrets.json").read_text()) 16 | for k, v in secrets.items(): 17 | os.environ[k] = v 18 | -------------------------------------------------------------------------------- /model/app/api.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | from fastapi import FastAPI, Response 5 | 6 | from joblib import load 7 | from .schemas import Wine, Rating, feature_names 8 | from .monitoring import instrumentator 9 | 10 | ROOT_DIR = Path(__file__).parent.parent 11 | 12 | app = FastAPI() 13 | scaler = load(ROOT_DIR / "artifacts/scaler.joblib") 14 | model = load(ROOT_DIR / "artifacts/model.joblib") 15 | instrumentator.instrument(app).expose(app, include_in_schema=False, should_gzip=True) 16 | 17 | 18 | @app.get("/") 19 | async def root(): 20 | return "Wine Quality Ratings" 21 | 22 | 23 | @app.post("/predict", response_model=Rating) 24 | def predict(response: Response, sample: Wine): 25 | sample_dict = sample.dict() 26 | features = np.array([sample_dict[f] for f in feature_names]).reshape(1, -1) 27 | features_scaled = scaler.transform(features) 28 | prediction = model.predict(features_scaled)[0] 29 | response.headers["X-model-score"] = str(prediction) 30 | return Rating(quality=prediction) 31 | 32 | 33 | @app.get("/healthcheck") 34 | async def healthcheck(): 35 | return {"status": "ok"} 36 | -------------------------------------------------------------------------------- /model/app/monitoring.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable 3 | 4 | import numpy as np 5 | from prometheus_client import Histogram 6 | from prometheus_fastapi_instrumentator import Instrumentator, metrics 7 | from prometheus_fastapi_instrumentator.metrics import Info 8 | 9 | NAMESPACE = os.environ.get("METRICS_NAMESPACE", "fastapi") 10 | SUBSYSTEM = os.environ.get("METRICS_SUBSYSTEM", "model") 11 | 12 | instrumentator = Instrumentator( 13 | should_group_status_codes=True, 14 | should_ignore_untemplated=True, 15 | should_respect_env_var=True, 16 | should_instrument_requests_inprogress=True, 17 | excluded_handlers=["/metrics"], 18 | env_var_name="ENABLE_METRICS", 19 | inprogress_name="fastapi_inprogress", 20 | inprogress_labels=True, 21 | ) 22 | 23 | 24 | # ----- custom metrics ----- 25 | def regression_model_output( 26 | metric_name: str = "regression_model_output", 27 | metric_doc: str = "Output value of regression model", 28 | metric_namespace: str = "", 29 | metric_subsystem: str = "", 30 | buckets=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, float("inf")), 31 | ) -> Callable[[Info], None]: 32 | METRIC = Histogram( 33 | metric_name, 34 | metric_doc, 35 | buckets=buckets, 36 | namespace=metric_namespace, 37 | subsystem=metric_subsystem, 38 | ) 39 | 40 | def instrumentation(info: Info) -> None: 41 | if info.modified_handler == 
"/predict": 42 | predicted_quality = info.response.headers.get("X-model-score") 43 | if predicted_quality: 44 | METRIC.observe(float(predicted_quality)) 45 | 46 | return instrumentation 47 | 48 | 49 | # ----- add metrics ----- 50 | instrumentator.add( 51 | metrics.request_size( 52 | should_include_handler=True, 53 | should_include_method=True, 54 | should_include_status=True, 55 | metric_namespace=NAMESPACE, 56 | metric_subsystem=SUBSYSTEM, 57 | ) 58 | ) 59 | instrumentator.add( 60 | metrics.response_size( 61 | should_include_handler=True, 62 | should_include_method=True, 63 | should_include_status=True, 64 | metric_namespace=NAMESPACE, 65 | metric_subsystem=SUBSYSTEM, 66 | ) 67 | ) 68 | instrumentator.add( 69 | metrics.latency( 70 | should_include_handler=True, 71 | should_include_method=True, 72 | should_include_status=True, 73 | metric_namespace=NAMESPACE, 74 | metric_subsystem=SUBSYSTEM, 75 | ) 76 | ) 77 | instrumentator.add( 78 | metrics.requests( 79 | should_include_handler=True, 80 | should_include_method=True, 81 | should_include_status=True, 82 | metric_namespace=NAMESPACE, 83 | metric_subsystem=SUBSYSTEM, 84 | ) 85 | ) 86 | 87 | buckets = (*np.arange(0, 10.5, 0.5).tolist(), float("inf")) 88 | instrumentator.add( 89 | regression_model_output(metric_namespace=NAMESPACE, metric_subsystem=SUBSYSTEM, buckets=buckets) 90 | ) 91 | -------------------------------------------------------------------------------- /model/app/schemas.py: -------------------------------------------------------------------------------- 1 | """ 2 | More information on the features used to describe wine: 3 | http://repositorium.sdum.uminho.pt/bitstream/1822/10029/1/wine5.pdf 4 | """ 5 | 6 | 7 | from pydantic import BaseModel, Field 8 | 9 | 10 | feature_names = [ 11 | "fixed_acidity", 12 | "volatile_acidity", 13 | "citric_acid", 14 | "residual_sugar", 15 | "chlorides", 16 | "free_sulfur_dioxide", 17 | "total_sulfur_dioxide", 18 | "density", 19 | "ph", 20 | "sulphates", 21 | "alcohol_pct_vol", 22 | ] 23 | 24 | 25 | class Wine(BaseModel): 26 | fixed_acidity: float = Field( 27 | ..., ge=0, description="grams per cubic decimeter of tartaric acid" 28 | ) 29 | volatile_acidity: float = Field( 30 | ..., ge=0, description="grams per cubic decimeter of acetic acid" 31 | ) 32 | citric_acid: float = Field(..., ge=0, description="grams per cubic decimeter of citric acid") 33 | residual_sugar: float = Field( 34 | ..., ge=0, description="grams per cubic decimeter of residual sugar" 35 | ) 36 | chlorides: float = Field(..., ge=0, description="grams per cubic decimeter of sodium chloride") 37 | free_sulfur_dioxide: float = Field( 38 | ..., ge=0, description="milligrams per cubic decimeter of free sulfur dioxide" 39 | ) 40 | total_sulfur_dioxide: float = Field( 41 | ..., ge=0, description="milligrams per cubic decimeter of total sulfur dioxide" 42 | ) 43 | density: float = Field(..., ge=0, description="grams per cubic meter") 44 | ph: float = Field(..., ge=0, lt=14, description="measure of the acidity or basicity") 45 | sulphates: float = Field( 46 | ..., ge=0, description="grams per cubic decimeter of potassium sulphate" 47 | ) 48 | alcohol_pct_vol: float = Field(..., ge=0, le=100, description="alcohol percent by volume") 49 | 50 | 51 | class Rating(BaseModel): 52 | quality: float = Field( 53 | ..., 54 | ge=0, 55 | le=10, 56 | description="wine quality grade ranging from 0 (very bad) to 10 (excellent)", 57 | ) 58 | -------------------------------------------------------------------------------- /model/artifacts/model.joblib: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeremyjordan/ml-monitoring/7cf0f336bd4cbbfd6587fd24edb84ff84620c26b/model/artifacts/model.joblib -------------------------------------------------------------------------------- /model/artifacts/scaler.joblib: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jeremyjordan/ml-monitoring/7cf0f336bd4cbbfd6587fd24edb84ff84620c26b/model/artifacts/scaler.joblib -------------------------------------------------------------------------------- /model/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi==0.109.1 2 | numpy==1.18.5 3 | pandas==0.23.3 4 | prometheus-fastapi-instrumentator==5.7.0 5 | pydantic==1.6.2 6 | scikit-learn==0.24.0 7 | uvicorn==0.13.2 -------------------------------------------------------------------------------- /model/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="model", 5 | version="0.1", 6 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 7 | install_requires=[], 8 | ) 9 | -------------------------------------------------------------------------------- /model/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train a scikit-learn model on UCI Wine Quality Dataset 3 | https://archive.ics.uci.edu/ml/datasets/wine+quality 4 | """ 5 | 6 | import logging 7 | from pathlib import Path 8 | 9 | import pandas as pd 10 | from joblib import dump 11 | from sklearn import preprocessing 12 | from sklearn.experimental import enable_hist_gradient_boosting # noqa 13 | from sklearn.ensemble import HistGradientBoostingRegressor 14 | from sklearn.metrics import mean_squared_error 15 | from sklearn.model_selection import train_test_split 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def prepare_dataset(test_size=0.2, random_seed=1): 21 | dataset = pd.read_csv( 22 | "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv", 23 | delimiter=";", 24 | ) 25 | dataset = dataset.rename(columns=lambda x: x.lower().replace(" ", "_")) 26 | train_df, test_df = train_test_split(dataset, test_size=test_size, random_state=random_seed) 27 | return {"train": train_df, "test": test_df} 28 | 29 | 30 | def train(): 31 | logger.info("Preparing dataset...") 32 | dataset = prepare_dataset() 33 | train_df = dataset["train"] 34 | test_df = dataset["test"] 35 | 36 | # separate features from target 37 | y_train = train_df["quality"] 38 | X_train = train_df.drop("quality", axis=1) 39 | y_test = test_df["quality"] 40 | X_test = test_df.drop("quality", axis=1) 41 | 42 | logger.info("Training model...") 43 | scaler = preprocessing.StandardScaler().fit(X_train) 44 | X_train = scaler.transform(X_train) 45 | X_test = scaler.transform(X_test) 46 | model = HistGradientBoostingRegressor(max_iter=50).fit(X_train, y_train) 47 | 48 | y_pred = model.predict(X_test) 49 | error = mean_squared_error(y_test, y_pred) 50 | logger.info(f"Test MSE: {error}") 51 | 52 | logger.info("Saving artifacts...") 53 | Path("artifacts").mkdir(exist_ok=True) 54 | dump(model, "artifacts/model.joblib") 55 | dump(scaler, "artifacts/scaler.joblib") 56 | 57 | 58 | if __name__ == "__main__": 59 | logging.basicConfig(level=logging.INFO) 60 | train() 61 | 
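# Usage note: the artifacts/ directory above is created relative to the current
# working directory, while model/app/api.py loads the joblib files from
# model/artifacts/. Running the script from within the model/ directory keeps the
# two in sync, e.g.:
#
#   cd model && python train.py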
--------------------------------------------------------------------------------
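For reference, a minimal request against the `/predict` route defined in `model/app/api.py`, assuming the container is running locally via the `docker run -d -p 3000:80 ...` command from the README. The field names follow `model/app/schemas.py`; the values are illustrative (they mirror a record from the UCI red wine dataset), and the response should be a JSON body of the form `{"quality": <float>}`.

```
# adjust the host/port to match your deployment (e.g. the in-cluster service)
curl -X POST http://127.0.0.1:3000/predict \
  -H "Content-Type: application/json" \
  -d '{"fixed_acidity": 7.4, "volatile_acidity": 0.7, "citric_acid": 0.0,
       "residual_sugar": 1.9, "chlorides": 0.076, "free_sulfur_dioxide": 11.0,
       "total_sulfur_dioxide": 34.0, "density": 0.9978, "ph": 3.51,
       "sulphates": 0.56, "alcohol_pct_vol": 9.4}'
```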