├── tests
    ├── __init__.py
    ├── resources
    │   ├── __init__.py
    │   ├── sample.json
    │   ├── linear_model.py
    │   └── linear_handler.py
    ├── conftest.py
    ├── fixtures
    │   └── plugin_fixtures.py
    ├── examples
    │   ├── test_torchserve_examples.py
    │   └── test_end_to_end_example.py
    ├── test_plugin.py
    └── test_cli.py
├── .flake8
├── examples
    ├── E2EBert
    │   ├── input.json
    │   ├── class_mapping.json
    │   ├── AGNews_Captum_Insights.png
    │   ├── requirements.txt
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── wrapper.py
    │   ├── README.md
    │   └── news_classifier_handler.py
    ├── BertNewsClassification
    │   ├── input.json
    │   ├── class_mapping.json
    │   ├── requirements.txt
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── news_classifier_handler.py
    │   └── README.md
    ├── Titanic
    │   ├── index_to_name.json
    │   ├── test_data
    │   │   ├── input.json
    │   │   ├── titanic_survived.csv
    │   │   └── titanic_not_survived.csv
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── titanic.py
    │   ├── README.md
    │   └── titanic_handler.py
    ├── IrisClassificationTorchScript
    │   ├── sample.json
    │   ├── index_to_name.json
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── iris_datamodule.py
    │   ├── README.md
    │   ├── iris_handler.py
    │   └── iris_classification.py
    ├── TextClassification
    │   ├── sample_text.txt
    │   ├── index_to_name.json
    │   ├── predict.py
    │   ├── README.md
    │   └── create_deployment.py
    ├── IrisClassification
    │   ├── index_to_name.json
    │   ├── sig_invalid_column_name.json
    │   ├── sample.json
    │   ├── sig_invalid_data_type.json
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── predict.py
    │   ├── create_deployment.py
    │   ├── iris_handler.py
    │   ├── iris_data_module.py
    │   ├── README.md
    │   └── iris_classification.py
    ├── MNIST
    │   ├── test_data
    │   │   └── one.png
    │   ├── index_to_name.json
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── register.py
    │   ├── predict.py
    │   ├── create_deployment.py
    │   ├── README.md
    │   ├── mnist_handler.py
    │   └── mnist_model.py
    ├── cifar10
    │   ├── test_data
    │   │   └── kitten.png
    │   ├── index_to_name.json
    │   ├── conda.yaml
    │   ├── MLproject
    │   ├── inference.py
    │   ├── create_deployment.py
    │   ├── README.md
    │   ├── cifar10_handler.py
    │   ├── cifar10_train.py
    │   └── cifar10_datamodule.py
    └── README.md
├── config.properties
├── mlflow_torchserve
    └── config.py
├── utils
    └── remove-conda-envs.sh
├── setup.py
├── .github
    ├── ISSUE_TEMPLATE
    │   ├── documentation-fix.md
    │   ├── feature_request.md
    │   └── bug_report.md
    ├── pull_request_template.md
    └── workflows
    │   ├── ci.yml
    │   └── build-wheel.yml
├── .gitignore
├── README.md
├── docs
    └── remote-deployment.rst
└── LICENSE.txt
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/resources/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/resources/sample.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": [2000]
3 |
4 | }
5 |
--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 100
3 | extend-ignore = E203
4 |
--------------------------------------------------------------------------------
/examples/E2EBert/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": ["This year business is good"]
3 | }
--------------------------------------------------------------------------------
/examples/BertNewsClassification/input.json:
--------------------------------------------------------------------------------
1 | {
2 | "data": ["This year business is good"]
3 | }
--------------------------------------------------------------------------------
/examples/Titanic/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0":"Not Survived",
3 | "1":"Survived"
4 | }
5 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/sample.json:
--------------------------------------------------------------------------------
1 | {"data": ["[4.4000, 3.0000, 1.3000, 0.2000]"]}
2 |
--------------------------------------------------------------------------------
/examples/Titanic/test_data/input.json:
--------------------------------------------------------------------------------
1 | {"input_file_path" : ["./test_data/titanic_survived.csv"]}
2 |
--------------------------------------------------------------------------------
/examples/TextClassification/sample_text.txt:
--------------------------------------------------------------------------------
1 | Bloomberg has decided to publish a new report on global economic situation.
2 |
--------------------------------------------------------------------------------
/examples/IrisClassification/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "SETOSA",
3 | "1": "VERSICOLOR",
4 | "2": "VIRGINICA"
5 | }
--------------------------------------------------------------------------------
/examples/MNIST/test_data/one.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/MNIST/test_data/one.png
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "SETOSA",
3 | "1": "VERSICOLOR",
4 | "2": "VIRGINICA"
5 | }
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from tests.fixtures.plugin_fixtures import start_torchserve, stop_torchserve, health_checkup
3 |
--------------------------------------------------------------------------------
/examples/E2EBert/class_mapping.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "World",
3 | "1": "Sports",
4 | "2": "Business",
5 | "3": "Sci/Tech"
6 | }
--------------------------------------------------------------------------------
/examples/TextClassification/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0":"World",
3 | "1":"Sports",
4 | "2":"Business",
5 | "3":"Sci/Tech"
6 | }
7 |
--------------------------------------------------------------------------------
/examples/cifar10/test_data/kitten.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/cifar10/test_data/kitten.png
--------------------------------------------------------------------------------
/config.properties:
--------------------------------------------------------------------------------
1 | inference_address=http://127.0.0.1:8080
2 | management_address=http://127.0.0.1:8081
3 | export_url=http://127.0.0.1:8000
4 |
--------------------------------------------------------------------------------
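A minimal sketch of reading the three addresses from a `key=value` file like the one above. `parse_properties` is a hypothetical helper written for illustration; it is not taken from the plugin, whose own parsing code is not shown in this listing.

```python
def parse_properties(text):
    """Parse simple key=value lines, skipping blank lines and '#' comments."""
    props = {}
    for raw in text.splitlines():
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        # Split on the first '=' only, so URLs containing '=' stay intact.
        key, _, value = line.partition("=")
        props[key.strip()] = value.strip()
    return props


# Same content as config.properties above.
SAMPLE = """\
inference_address=http://127.0.0.1:8080
management_address=http://127.0.0.1:8081
export_url=http://127.0.0.1:8000
"""

addresses = parse_properties(SAMPLE)
print(addresses["management_address"])
```

The keys match the `torchserve_address_names` list in `mlflow_torchserve/config.py`.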
/examples/E2EBert/AGNews_Captum_Insights.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/E2EBert/AGNews_Captum_Insights.png
--------------------------------------------------------------------------------
/examples/BertNewsClassification/class_mapping.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "World",
3 | "1": "Sports",
4 | "2": "Business",
5 | "3": "Sci/Tech"
6 | }
--------------------------------------------------------------------------------
/examples/Titanic/test_data/titanic_survived.csv:
--------------------------------------------------------------------------------
1 | age,sibsp,parch,fare,female,male,embark_C,embark_Q,embark_S,class_1,class_2,class_3
2 | 29,0,0,211.3375,1,0,0,0,1,1,0,0
--------------------------------------------------------------------------------
/examples/Titanic/test_data/titanic_not_survived.csv:
--------------------------------------------------------------------------------
1 | age,sibsp,parch,fare,female,male,embark_C,embark_Q,embark_S,class_1,class_2,class_3
2 | 30.0,1,2,151.55,0,1,0,0,1,1,0,0
3 |
--------------------------------------------------------------------------------
/examples/IrisClassification/sig_invalid_column_name.json:
--------------------------------------------------------------------------------
1 | {"data": ["{\"length (cm)\":{\"0\":4.4},\"width (cm)\":{\"0\":3.2},\"length (cm)\":{\"0\":1.3},\"width (cm)\":{\"0\":0.2}}"]}
2 |
--------------------------------------------------------------------------------
/examples/IrisClassification/sample.json:
--------------------------------------------------------------------------------
1 | {"data": ["{\"sepal length (cm)\":{\"0\":4.4},\"sepal width (cm)\":{\"0\":3.2},\"petal length (cm)\":{\"0\":1.3},\"petal width (cm)\":{\"0\":0.2}}"]}
2 |
--------------------------------------------------------------------------------
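The Iris payloads store the feature row as a JSON string nested inside the outer JSON document, keyed first by column name and then by row index. The sketch below decodes that shape with the standard library only. `decode_row` is a hypothetical helper, and how the example's handler actually decodes the payload is an assumption not confirmed by this listing.

```python
import json

# Contents of examples/IrisClassification/sample.json, as listed above.
SAMPLE = r'{"data": ["{\"sepal length (cm)\":{\"0\":4.4},\"sepal width (cm)\":{\"0\":3.2},\"petal length (cm)\":{\"0\":1.3},\"petal width (cm)\":{\"0\":0.2}}"]}'


def decode_row(payload, row="0"):
    """Decode the outer JSON, then the JSON string stored in data[0]."""
    frame = json.loads(json.loads(payload)["data"][0])
    # frame maps column name -> {row index -> value}; pick one row.
    return {column: values[row] for column, values in frame.items()}


features = decode_row(SAMPLE)
print(features)
```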
/examples/IrisClassification/sig_invalid_data_type.json:
--------------------------------------------------------------------------------
1 | {"data": ["{\"sepal length (cm)\":{\"0\":4},\"sepal width (cm)\":{\"0\":3},\"petal length (cm)\":{\"0\":1},\"petal width (cm)\":{\"0\":0}}"]}
2 |
--------------------------------------------------------------------------------
/examples/BertNewsClassification/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | mlflow
3 | torchserve
4 | torch-model-archiver
5 | numpy
6 | transformers
7 | scikit-learn
8 | tqdm
9 | torchtext
10 | pandas
11 | datasets
12 |
--------------------------------------------------------------------------------
/examples/E2EBert/requirements.txt:
--------------------------------------------------------------------------------
1 | torchvision>=0.9.1
2 | torch>=1.12.0
3 | pytorch_lightning>=1.7.6
4 | mlflow
5 | torchserve
6 | torch-model-archiver
7 | numpy
8 | transformers
9 | scikit-learn
10 | torchtext>=0.13.0
11 |
--------------------------------------------------------------------------------
/examples/MNIST/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "ZERO",
3 | "1": "ONE",
4 | "2": "TWO",
5 | "3": "THREE",
6 | "4": "FOUR",
7 | "5": "FIVE",
8 | "6": "SIX",
9 | "7": "SEVEN",
10 | "8": "EIGHT",
11 | "9": "NINE"
12 | }
--------------------------------------------------------------------------------
/examples/IrisClassification/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - scikit-learn
10 | - pytorch-lightning>=1.7.6
11 | - torch
12 | - torchvision
13 | - mlflow
14 |
--------------------------------------------------------------------------------
/examples/cifar10/index_to_name.json:
--------------------------------------------------------------------------------
1 | {
2 | "0": "airplane",
3 | "1": "automobile",
4 | "2": "bird",
5 | "3": "cat",
6 | "4": "deer",
7 | "5": "dog",
8 | "6": "frog",
9 | "7": "horse",
10 | "8": "ship",
11 | "9": "truck"
12 | }
--------------------------------------------------------------------------------
/examples/Titanic/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - pytorch
3 | - defaults
4 | - conda-forge
5 | dependencies:
6 | - python=3.8.2
7 | - pytorch>=1.9.0
8 | - pip
9 | - pip:
10 | - mlflow
11 | - pandas
12 | - captum
13 | - scikit-learn
14 | - ipython
15 | - prettytable
16 |
17 |
--------------------------------------------------------------------------------
/examples/MNIST/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - torch>=1.9.0
10 | - torchvision
11 | - matplotlib
12 | - pytorch-lightning>=1.7.6
13 | - mlflow>=1.14.0
14 | - captum
15 |
--------------------------------------------------------------------------------
/examples/BertNewsClassification/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - mlflow
10 | - scikit-learn
11 | - transformers
12 | - torchtext
13 | - torchvision
14 | - torch>=1.9.0
15 | - pandas
16 | - datasets
17 |
--------------------------------------------------------------------------------
/examples/cifar10/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - matplotlib
10 | - torch>=1.9.0
11 | - pytorch-lightning>=1.7.6
12 | - mlflow>=1.14.0
13 | - captum
14 | - webdataset
15 | - scikit-learn
16 | - torchvision
17 |
--------------------------------------------------------------------------------
/examples/E2EBert/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - pytorch
4 | - conda-forge
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - mlflow
10 | - captum
11 | - scikit-learn
12 | - transformers
13 | - torch>=1.9.0
14 | - torchtext>=0.10.0
15 | - pytorch-lightning>=1.7.6
16 | - torchvision>=0.10.0
17 | - torchdata
18 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/conda.yaml:
--------------------------------------------------------------------------------
1 | channels:
2 | - defaults
3 | - conda-forge
4 | - pytorch
5 | dependencies:
6 | - python=3.8.2
7 | - pip
8 | - pip:
9 | - mlflow
10 | - scikit-learn
11 | - torchvision
12 | - torch>=1.9.0
13 | - pytorch-lightning>=1.7.6
14 |
15 |
--------------------------------------------------------------------------------
/examples/Titanic/MLproject:
--------------------------------------------------------------------------------
1 | name: Titanic-Captum-Example
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 50}
9 | lr: {type: float, default: 0.1}
10 |
11 | command: |
12 | python titanic_captum_interpret.py \
13 | --max_epochs {max_epochs} \
14 | --lr {lr}
15 |
--------------------------------------------------------------------------------
/mlflow_torchserve/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | class Config(dict):
5 | def __init__(self):
6 | """
7 | Initializes constants from Environment variables
8 | """
9 | super().__init__()
10 | self["export_path"] = os.environ.get("EXPORT_PATH")
11 | self["config_properties"] = os.environ.get("CONFIG_PROPERTIES")
12 | self["torchserve_address_names"] = ["inference_address", "management_address", "export_url"]
13 | self["export_uri"] = os.environ.get("EXPORT_URL")
14 |
--------------------------------------------------------------------------------
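A short sketch of how `Config` behaves: values are read from the environment once, at construction time, and any unset variable yields `None`. The class body is copied from the listing above so the snippet runs on its own; the environment values are hypothetical and chosen only for illustration.

```python
import os


class Config(dict):
    # Copied from mlflow_torchserve/config.py above so the sketch is self-contained.
    def __init__(self):
        super().__init__()
        self["export_path"] = os.environ.get("EXPORT_PATH")
        self["config_properties"] = os.environ.get("CONFIG_PROPERTIES")
        self["torchserve_address_names"] = ["inference_address", "management_address", "export_url"]
        self["export_uri"] = os.environ.get("EXPORT_URL")


# Hypothetical values, for illustration only.
os.environ["EXPORT_PATH"] = "model_store"
os.environ["CONFIG_PROPERTIES"] = "config.properties"
os.environ.pop("EXPORT_URL", None)  # leave this one unset

config = Config()
print(config["export_path"])  # model_store
print(config["export_uri"])   # None, because EXPORT_URL is not set
```

Note that the dictionary key is `export_uri` while the environment variable is `EXPORT_URL`.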
/utils/remove-conda-envs.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -ex
4 |
5 | mlflow_envs=$(
6 | conda env list | # list (env name, env path) pairs
7 | cut -d' ' -f1 | # extract env names
8 | grep "^mlflow-[a-z0-9]\{40\}\$" # filter envs created by mlflow
9 | ) || true
10 |
11 | if [ -n "$mlflow_envs" ]; then
12 | for env in $mlflow_envs
13 | do
14 | conda remove --all --yes --name "$env"
15 | done
16 |
17 | conda clean --all --yes
18 | conda env list
19 | fi
20 |
21 | set +ex
22 |
--------------------------------------------------------------------------------
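The `grep` pattern in the script keeps only names of the form `mlflow-` followed by exactly 40 lowercase alphanumeric characters. A Python sketch of the same filter, useful for checking which names would be selected before removing anything; `mlflow_envs` is a hypothetical helper and the sample names are made up.

```python
import re

# Equivalent of the grep pattern; fullmatch anchors both ends of the name.
MLFLOW_ENV = re.compile(r"mlflow-[a-z0-9]{40}")


def mlflow_envs(names):
    """Return only the names that look like MLflow-created conda environments."""
    return [name for name in names if MLFLOW_ENV.fullmatch(name)]


names = [
    "base",
    "mlflow-" + "0a" * 20,     # 40-character suffix: selected
    "mlflow-short",            # suffix too short: skipped
    "my-mlflow-" + "0a" * 20,  # wrong prefix: skipped
]
selected = mlflow_envs(names)
print(selected)
```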
/examples/IrisClassification/MLproject:
--------------------------------------------------------------------------------
1 | name: iris-classification
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 100}
9 | devices: {type: int, default: None}
10 | strategy: {type: str, default: "None"}
11 | accelerator: {type: str, default: "None"}
12 |
13 | command: |
14 | python iris_classification.py \
15 | --max_epochs {max_epochs} \
16 | --devices {devices} \
17 | --strategy {strategy} \
18 | --accelerator {accelerator}
19 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/MLproject:
--------------------------------------------------------------------------------
1 | name: iris-classification
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 100}
9 | devices: {type: int, default: None}
10 | strategy: {type: str, default: "None"}
11 | accelerator: {type: str, default: "None"}
12 |
13 | command: |
14 | python iris_classification.py \
15 | --max_epochs {max_epochs} \
16 | --devices {devices} \
17 | --strategy {strategy} \
18 | --accelerator {accelerator}
19 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(
5 | name="mlflow-torchserve",
6 | version="0.2.0",
7 | description="Torch Serve Mlflow Deployment",
8 | long_description=open("README.md").read(),
9 | long_description_content_type="text/markdown",
10 | packages=find_packages(),
11 | # Require MLflow as a dependency of the plugin, so that plugin users can simply install
12 | # the plugin & then immediately use it with MLflow
13 | install_requires=["torchserve", "torch-model-archiver", "mlflow"],
14 | entry_points={"mlflow.deployments": "torchserve=mlflow_torchserve"},
15 | )
16 |
--------------------------------------------------------------------------------
/examples/Titanic/titanic.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class TitanicSimpleNNModel(nn.Module):
5 | def __init__(self):
6 | super().__init__()
7 | self.linear1 = nn.Linear(12, 12)
8 | self.sigmoid1 = nn.Sigmoid()
9 | self.linear2 = nn.Linear(12, 8)
10 | self.sigmoid2 = nn.Sigmoid()
11 | self.linear3 = nn.Linear(8, 2)
12 | self.softmax = nn.Softmax(dim=1)
13 |
14 | def forward(self, x):
15 | lin1_out = self.linear1(x)
16 | sigmoid_out1 = self.sigmoid1(lin1_out)
17 | sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1))
18 | return self.softmax(self.linear3(sigmoid_out2))
19 |
--------------------------------------------------------------------------------
/examples/MNIST/MLproject:
--------------------------------------------------------------------------------
1 | name: mnist-example
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 5}
9 | devices: {type: int, default: "None"}
10 | strategy: {type: str, default: "None"}
11 | accelerator: {type: str, default: "None"}
12 | registration_name: {type: str, default: "mnist_classifier"}
13 |
14 | command: |
15 | python mnist_model.py \
16 | --max_epochs {max_epochs} \
17 | --devices {devices} \
18 | --strategy {strategy} \
19 | --accelerator {accelerator} \
20 | --registration_name {registration_name}
21 |
--------------------------------------------------------------------------------
/examples/BertNewsClassification/MLproject:
--------------------------------------------------------------------------------
1 | name: bert-classification
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 5}
9 | num_train_samples: {type: int, default: 2000}
10 | num_test_samples: {type: int, default: 200}
11 | vocab_file: {type: str, default: 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt'}
12 | model_save_path: {type: str, default: 'models'}
13 |
14 | command: |
15 | python news_classifier.py \
16 | --max_epochs {max_epochs} \
17 | --num_train_samples {num_train_samples} \
18 | --num_test_samples {num_test_samples} \
19 | --vocab_file {vocab_file} \
20 | --model_save_path {model_save_path}
21 |
--------------------------------------------------------------------------------
/examples/E2EBert/MLproject:
--------------------------------------------------------------------------------
1 | name: bert-classification
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 5}
9 | devices: {type: int, default: None}
10 | num_samples: {type: int, default: 1000}
11 | vocab_file: {type: str, default: 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt'}
12 | strategy: {type: str, default: None}
13 | accelerator: {type: str, default: None}
14 |
15 | command: |
16 | python news_classifier.py \
17 | --max_epochs {max_epochs} \
18 | --devices {devices} \
19 | --num_samples {num_samples} \
20 | --vocab_file {vocab_file} \
21 | --strategy={strategy} \
22 | --accelerator={accelerator}
--------------------------------------------------------------------------------
/examples/cifar10/MLproject:
--------------------------------------------------------------------------------
1 | name: cifar10-example
2 |
3 | conda_env: conda.yaml
4 |
5 | entry_points:
6 | main:
7 | parameters:
8 | max_epochs: {type: int, default: 1}
9 | devices: {type: int, default: None}
10 | strategy: {type: str, default: "None"}
11 | accelerator: {type: str, default: "None"}
12 | num_samples_train: {type: int, default: 39}
13 | num_samples_val: {type: int, default: 9}
14 | num_samples_test: {type: int, default: 9}
15 |
16 | command: |
17 | python cifar10_train.py \
18 | --max_epochs {max_epochs} \
19 | --devices {devices} \
20 | --num_samples_train {num_samples_train} \
21 | --num_samples_val {num_samples_val} \
22 | --num_samples_test {num_samples_test} \
23 | --strategy {strategy} \
24 | --accelerator {accelerator}
25 |
--------------------------------------------------------------------------------
/examples/MNIST/register.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mlflow.deployments import get_deploy_client
4 |
5 |
6 | def register(parser_args):
7 | plugin = get_deploy_client(parser_args["target"])
8 | plugin.register_model(mar_file_path=parser_args["mar_file_name"])
9 | print("Registered Successfully")
10 |
11 |
12 | if __name__ == "__main__":
13 | parser = ArgumentParser(description="MNIST hand written digits classification example")
14 |
15 | parser.add_argument(
16 | "--target",
17 | type=str,
18 | default="torchserve",
19 | help="MLflow target (default: torchserve)",
20 | )
21 |
22 | parser.add_argument(
23 | "--mar_file_name",
24 | type=str,
25 | default="",
26 | help="mar file name to register (Ex: mnist_test.mar)",
27 | )
28 |
29 | args = parser.parse_args()
30 |
31 | register(vars(args))
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation-fix.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Documentation Fix
3 | about: Use this template for proposing documentation fixes/improvements.
4 | title: "[DOC-FIX]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | Thank you for submitting an issue. Please refer to our [issue policy](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md) for information on what types of issues we address.
11 |
12 | **Please fill in this documentation issue template to ensure a timely and thorough response.**
13 |
14 | ### Willingness to contribute
15 | The MLflow Community encourages documentation fix contributions. Would you or another member of your organization be willing to contribute a fix for this documentation issue to the MLflow TorchServe Deployment plugin code base?
16 |
17 | - [ ] Yes. I can contribute a documentation fix independently.
18 | - [ ] Yes. I would be willing to contribute a documentation fix with guidance from the MLflow community.
19 | - [ ] No. I cannot contribute a documentation fix at this time.
20 |
21 | ### URL(s) with the issue:
22 |
23 | Please provide a link to the documentation entry in question.
24 |
25 | ### Description of proposal (what needs changing):
26 | Provide a clear description. Why is the proposed documentation better?
27 |
--------------------------------------------------------------------------------
/examples/TextClassification/predict.py:
--------------------------------------------------------------------------------
1 | import json
2 | from argparse import ArgumentParser
3 |
4 | from mlflow.deployments import get_deploy_client
5 |
6 |
7 | def predict(parser_args):
8 |
9 | with open(parser_args["input_file_path"], "r") as fp:
10 | text = fp.read()
11 | plugin = get_deploy_client(parser_args["target"])
12 | prediction = plugin.predict(parser_args["deployment_name"], json.dumps(text))
13 | print("Prediction Result {}".format(prediction.to_json()))
14 |
15 |
16 | if __name__ == "__main__":
17 | parser = ArgumentParser(description="Text classifier example")
18 |
19 | parser.add_argument(
20 | "--target",
21 | type=str,
22 | default="torchserve",
23 | help="MLflow target (default: torchserve)",
24 | )
25 |
26 | parser.add_argument(
27 | "--deployment_name",
28 | type=str,
29 | default="text_classification",
30 | help="Deployment name (default: text_classification)",
31 | )
32 |
33 | parser.add_argument(
34 | "--input_file_path",
35 | type=str,
36 | default="sample_text.txt",
37 | help="Path to input text file for prediction (default: sample_text.txt)",
38 | )
39 |
40 | args = parser.parse_args()
41 |
42 | predict(vars(args))
43 |
--------------------------------------------------------------------------------
/examples/TextClassification/README.md:
--------------------------------------------------------------------------------
1 | # Deploying Text Classification model
2 |
3 | Download `train.py` and `model.py` from the [repository](https://github.com/pytorch/serve/tree/master/examples/text_classification)
4 | and then run one of the following commands to train the model on either CPU or GPU.
5 |
6 | CPU: `python train.py AG_NEWS --device cpu --save-model-path model.pt --dictionary source_vocab.pt`
7 |
8 | GPU: `python train.py AG_NEWS --device cuda --save-model-path model.pt --dictionary source_vocab.pt`
9 |
10 | At the end of the training, model file `model.pt` and vocabulary file `source_vocab.pt` will be stored in the current directory.
11 |
12 | ## Starting TorchServe
13 |
14 | Create an empty directory `model_store` and run the following command to start TorchServe.
15 |
16 | `torchserve --start --model-store model_store`
17 |
18 | ## Creating a new deployment
19 |
20 | This example uses the default TorchServe text handler to generate the mar file.
21 |
22 | To create a new deployment, run the following command:
23 |
24 | `python create_deployment.py --deployment_name text_classification --model_file model.py --serialized_file model.pt --extra_files "source_vocab.pt,index_to_name.json"`
25 |
26 | ## Predicting with the deployed model
27 |
28 | To run a prediction, execute the following script:
29 |
30 | `python predict.py`
31 |
32 | The prediction results will be printed in the console.
33 |
--------------------------------------------------------------------------------
/.github/pull_request_template.md:
--------------------------------------------------------------------------------
1 | ## What changes are proposed in this pull request?
2 |
3 | (Please fill in changes proposed in this fix)
4 |
5 | ## How is this patch tested?
6 |
7 | (Details)
8 |
9 | ## Release Notes
10 |
11 | ### Is this a user-facing change?
12 |
13 | - [ ] No. You can skip the rest of this section.
14 | - [ ] Yes. Give a description of this change to be included in the release notes for MLflow TorchServe Deployment Plugin users.
15 |
16 | (Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.)
17 |
18 | ### What component(s) does this PR affect?
19 | Components
20 | - [ ] `area/deploy`: Main deployment plugin logic
21 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin
22 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages
23 | - [ ] `area/examples`: Example code
24 |
25 |
26 | ### How should the PR be classified in the release notes? Choose one:
27 |
28 | - [ ] `rn/breaking-change` - The PR will be mentioned in the "Breaking Changes" section
29 | - [ ] `rn/none` - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section
30 | - [ ] `rn/feature` - A new user-facing feature worth mentioning in the release notes
31 | - [ ] `rn/bug-fix` - A user-facing bug fix worth mentioning in the release notes
32 | - [ ] `rn/documentation` - A user-facing documentation change worth mentioning in the release notes
33 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Mlflow
2 | mlruns/
3 | outputs/
4 |
5 | # Directories and files generated when running the examples
6 | sample.json
7 | output.json
8 | bert_base_uncased_vocab.txt
9 | *.pt
10 | data/
11 | .data/
12 | logs/
13 | model_store/
14 | models/
15 |
16 | # Mac
17 | .DS_Store
18 |
19 | # Byte-compiled / optimized / DLL files
20 | __pycache__
21 | *.py[cod]
22 | *$py.class
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 | node_modules
46 |
47 | # PyInstaller
48 | # Usually these files are written by a python script from a template
49 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
50 | *.manifest
51 | *.spec
52 |
53 | # Installer logs
54 | pip-log.txt
55 | pip-delete-this-directory.txt
56 |
57 | # Unit test / coverage reports
58 | htmlcov/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | .hypothesis/
66 | .pytest_cache/
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # Jupyter Notebook
72 | .ipynb_checkpoints
73 |
74 | # Environments
75 | env
76 | env3
77 | .env
78 | .venv
79 | env/
80 | venv/
81 | ENV/
82 | env.bak/
83 | venv.bak/
84 | .python-version
85 |
86 | # Editor files
87 | .*project
88 | *.swp
89 | *.swo
90 | *.idea
91 | *.vscode
92 | *.iml
93 | *~
94 |
95 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/examples/MNIST/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | import matplotlib.pyplot as plt
5 | from mlflow.deployments import get_deploy_client
6 | from torchvision import transforms
7 |
8 |
9 | def predict(parser_args):
10 | plugin = get_deploy_client(parser_args["target"])
11 | img = plt.imread(os.path.join(parser_args["input_file_path"]))
12 | mnist_transforms = transforms.Compose(
13 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
14 | )
15 |
16 | image_tensor = mnist_transforms(img)
17 | prediction = plugin.predict(parser_args["deployment_name"], image_tensor)
18 | print("Prediction Result {}".format(prediction.to_json()))
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = ArgumentParser(description="MNIST hand written digits classification example")
23 |
24 | parser.add_argument(
25 | "--target",
26 | type=str,
27 | default="torchserve",
28 | help="MLflow target (default: torchserve)",
29 | )
30 |
31 | parser.add_argument(
32 | "--deployment_name",
33 | type=str,
34 | default="mnist_classification",
35 | help="Deployment name (default: mnist_classification)",
36 | )
37 |
38 | parser.add_argument(
39 | "--input_file_path",
40 | type=str,
41 | default="test_data/one.png",
42 | help="Path to input image for prediction (default: test_data/one.png)",
43 | )
44 |
45 | args = parser.parse_args()
46 |
47 | predict(vars(args))
48 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/iris_datamodule.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | from pytorch_lightning import seed_everything
3 | import torch
4 | from sklearn.datasets import load_iris
5 | from torch.utils.data import TensorDataset
6 | from torch.utils.data import random_split
7 | from torch.utils.data.dataloader import DataLoader
8 |
9 |
10 | class IRISDataModule(pl.LightningDataModule):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def prepare_data(self):
15 | """
16 | Implementation of abstract class
17 | """
18 |
19 | def setup(self, stage=None):
20 |
21 | # Assign train/val datasets for use in dataloaders
22 | if stage == "fit" or stage is None:
23 | iris = load_iris()
24 | df = iris.data
25 | target = iris["target"]
26 |
27 | data = torch.Tensor(df).float()
28 | labels = torch.Tensor(target).long()
29 | RANDOM_SEED = 42
30 | seed_everything(RANDOM_SEED)
31 |
32 | data_set = TensorDataset(data, labels)
33 | self.train_set, self.val_set = random_split(data_set, [130, 20])
34 | self.train_set, self.test_set = random_split(self.train_set, [110, 20])
35 |
36 | def train_dataloader(self):
37 | return DataLoader(self.train_set, batch_size=4)
38 |
39 | def val_dataloader(self):
40 | return DataLoader(self.val_set, batch_size=4)
41 |
42 | def test_dataloader(self):
43 | return DataLoader(self.test_set, batch_size=4)
44 |
--------------------------------------------------------------------------------
/tests/fixtures/plugin_fixtures.py:
--------------------------------------------------------------------------------
1 | import atexit
2 | import json
3 | import os
4 | import shutil
5 | import subprocess
6 | import time
7 |
8 | import pytest
9 |
10 | from tests.resources import linear_model
11 |
12 |
13 | @pytest.fixture(scope="session")
14 | def start_torchserve():
15 | linear_model.main()
16 | if not os.path.isdir("model_store"):
17 | os.makedirs("model_store")
18 | cmd = "torchserve --ncs --start --model-store {}".format("./model_store")
19 | _ = subprocess.Popen(cmd, shell=True).wait()
20 |
21 | count = 0
22 | for _ in range(5):
23 | value = health_checkup()
24 | if value is not None and value != "" and json.loads(value)["status"] == "Healthy":
25 | time.sleep(1)
26 | break
27 | else:
28 | count += 1
29 | time.sleep(5)
30 | if count >= 5:
31 | raise Exception("Unable to connect to torchserve")
32 | return True
33 |
34 |
35 | def health_checkup():
36 | curl_cmd = "curl http://localhost:8080/ping"
37 | (value, _) = subprocess.Popen([curl_cmd], stdout=subprocess.PIPE, shell=True).communicate()
38 | return value.decode("utf-8")
39 |
40 |
41 | def stop_torchserve():
42 | cmd = "torchserve --stop"
43 | _ = subprocess.Popen(cmd, shell=True).wait()
44 |
45 | if os.path.isdir("model_store"):
46 | shutil.rmtree("model_store")
47 |
48 | if os.path.exists("tests/resources/linear_state_dict.pt"):
49 | os.remove("tests/resources/linear_state_dict.pt")
50 |
51 | if os.path.exists("tests/resources/linear_model.pt"):
52 | os.remove("tests/resources/linear_model.pt")
53 |
54 |
55 | atexit.register(stop_torchserve)
56 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | The examples in this folder illustrate training and deploying models using the `mlflow-torchserve` plugin.
2 |
3 | 1. `BertNewsClassification` - Demonstrates a workflow using a PyTorch-based BERT model.
4 | The example illustrates:
5 | model training (fine-tuning a pre-trained model)
6 | logging of the model, summary, parameters, and extra files at the end of training
7 | deployment of the model in TorchServe
8 |
9 | 2. `E2EBert` - Demonstrates a workflow using a PyTorch Lightning-based BERT model.
10 | The example illustrates:
11 | model training (fine-tuning a pre-trained model)
12 | model saving and loading using MLflow autolog
13 | deployment of the model in TorchServe
14 | calculating explanations using Captum
15 |
16 | 3. `IrisClassification` - Demonstrates distributed training (DDP) and deployment using the Iris dataset and the mlflow-torchserve plugin. This example illustrates the use of a model signature.
17 |
18 | 4. `IrisClassificationTorchScript` - Demonstrates saving a TorchScript version of the Iris classification model and deploying it using the mlflow-torchserve plugin with MLflow CLI commands.
19 |
20 | 5. `MNIST` - Demonstrates training an MNIST handwritten digit recognition model and deploying it using the mlflow-torchserve **Python plugin**. This example illustrates registering the model with MLflow and deploying it to a remote TorchServe instance.
21 |
22 | 6. `TextClassification` - Demonstrates training a text classification model and deploying it using the mlflow-torchserve **Python plugin**.
23 |
24 | 7. `Titanic` - Demonstrates training a model on the Titanic dataset using PyTorch and deploying it using the mlflow-torchserve plugin. This example illustrates validation using the `captum` library.
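
The `create_deployment.py` scripts in these examples all follow the same pattern: build a config dictionary and pass it to the deployment client. A minimal sketch of that pattern is shown below. The helper names `build_deploy_config` and `deploy` are hypothetical and not part of the plugin; the config keys (`MODEL_FILE`, `HANDLER`, `EXTRA_FILES`, `EXPORT_PATH`) and the `get_deploy_client` / `create_deployment` calls are the ones the example scripts use. Omitting `EXTRA_FILES` when it is empty is an assumption made here for illustration.

```python
def build_deploy_config(model_file, handler, extra_files="", export_path=""):
    """Hypothetical helper: assemble the config dict used by the example scripts."""
    config = {"MODEL_FILE": model_file, "HANDLER": handler}
    # Assumption: optional keys are only added when a non-empty value is given.
    if extra_files:
        config["EXTRA_FILES"] = extra_files
    if export_path:
        config["EXPORT_PATH"] = export_path
    return config


def deploy(name, model_uri, config, target="torchserve"):
    """Hypothetical wrapper around the deployment client call used in the examples."""
    # Imported lazily so build_deploy_config can be used without MLflow installed.
    from mlflow.deployments import get_deploy_client

    plugin = get_deploy_client(target)
    return plugin.create_deployment(name=name, model_uri=model_uri, config=config)
```

With a running TorchServe instance, a call such as `deploy("mnist_classification", "model.pth", build_deploy_config("mnist_model.py", "mnist_handler.py", "index_to_name.json", "model_store"))` would mirror what `examples/MNIST/create_deployment.py` does.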
--------------------------------------------------------------------------------
/examples/TextClassification/create_deployment.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from mlflow.deployments import get_deploy_client
3 |
4 |
5 | def create_deployment(parser_args):
6 | plugin = get_deploy_client(parser_args["target"])
7 | config = {
8 | "MODEL_FILE": parser_args["model_file"],
9 | "HANDLER": parser_args["handler"],
10 | "EXTRA_FILES": parser_args["extra_files"],
11 | }
12 | result = plugin.create_deployment(
13 | name=parser_args["deployment_name"],
14 | model_uri=parser_args["serialized_file"],
15 | config=config,
16 | )
17 |
18 | print("Deployment {result} created successfully".format(result=result["name"]))
19 |
20 |
21 | if __name__ == "__main__":
22 | parser = ArgumentParser(description="Text Classifier Example")
23 |
24 | parser.add_argument(
25 | "--target",
26 | type=str,
27 | default="torchserve",
28 | help="MLflow target (default: torchserve)",
29 | )
30 |
31 | parser.add_argument(
32 | "--deployment_name",
33 | type=str,
34 | default="text_classification",
35 | help="Deployment name (default: text_classification)",
36 | )
37 |
38 | parser.add_argument(
39 | "--model_file",
40 | type=str,
41 | default="model.py",
42 | help="Model file path (default: model.py)",
43 | )
44 |
45 | parser.add_argument(
46 | "--handler",
47 | type=str,
48 | default="text_classifier",
49 | help="Handler file path (default: text_classifier)",
50 | )
51 |
52 | parser.add_argument(
53 | "--extra_files",
54 | type=str,
55 | default="source_vocab.pt,index_to_name.json",
56 | help="List of extra files",
57 | )
58 |
59 | parser.add_argument(
60 | "--serialized_file",
61 | type=str,
62 | default="model.pt",
63 | help="Pytorch model path (default: model.pt)",
64 | )
65 |
66 | args = parser.parse_args()
67 |
68 | create_deployment(vars(args))
69 |
--------------------------------------------------------------------------------
/examples/IrisClassification/predict.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import ast
5 | from argparse import ArgumentParser
6 |
7 | from mlflow.deployments import get_deploy_client
8 |
9 |
10 | def convert_input_to_tensor(data):
11 | data = json.loads(data).get("data")
12 | input_tensor = torch.Tensor(ast.literal_eval(data[0]))
13 | return input_tensor
14 |
15 |
16 | def predict(parser_args):
17 | plugin = get_deploy_client(parser_args["target"])
18 | input_file = parser_args["input_file_path"]
19 | if not os.path.exists(input_file):
20 | raise Exception("Unable to locate input file : {}".format(input_file))
21 | else:
22 | with open(input_file) as fp:
23 | input_data = fp.read()
24 |
25 | data = json.loads(input_data).get("data")
26 | import pandas as pd
27 |
28 | df = pd.read_json(data[0])
29 | for column in df.columns:
30 | df[column] = df[column].astype("double")
31 |
32 | prediction = plugin.predict(deployment_name=parser_args["deployment_name"], df=input_data)
33 |
34 | print("Prediction Result {}".format(prediction.to_json()))
35 |
36 |
37 | if __name__ == "__main__":
38 | parser = ArgumentParser(description="Iris Classification Model")
39 |
40 | parser.add_argument(
41 | "--target",
42 | type=str,
43 | default="torchserve",
44 | help="MLflow target (default: torchserve)",
45 | )
46 |
47 | parser.add_argument(
48 | "--deployment_name",
49 | type=str,
50 | default="iris_classification",
51 | help="Deployment name (default: iris_classification)",
52 | )
53 |
54 | parser.add_argument(
55 | "--input_file_path",
56 | type=str,
57 | default="sample.json",
58 | help="Path to input file for prediction (default: sample.json)",
59 | )
60 |
61 | parser.add_argument(
62 | "--mlflow-model-uri",
63 | type=str,
64 | default="model",
65 | help="MLflow model URI",
66 | )
67 | args = parser.parse_args()
68 |
69 | predict(vars(args))
70 |
--------------------------------------------------------------------------------
/examples/cifar10/inference.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | from argparse import ArgumentParser
4 |
5 | from mlflow.deployments import get_deploy_client
6 |
7 |
8 | def predict(parser_args):
9 | plugin = get_deploy_client(parser_args["target"])
10 | image = open(parser_args["input_file_path"], "rb") # open binary file in read mode
11 | image_read = image.read()
12 | image_64_encode = base64.b64encode(image_read)
13 | bytes_array = image_64_encode.decode("utf-8")
14 | request = {"data": str(bytes_array)}
15 |
16 | inference_type = parser_args["inference_type"]
17 | if inference_type == "explanation":
18 | result = plugin.explain(parser_args["deployment_name"], json.dumps(request))
19 | else:
20 | result = plugin.predict(parser_args["deployment_name"], json.dumps(request))
21 |
22 | print("Prediction Result {}".format(result))
23 |
24 | output_path = parser_args["output_file_path"]
25 | if output_path:
26 | with open(output_path, "w") as fp:
27 | fp.write(result)
28 |
29 |
30 | if __name__ == "__main__":
31 | parser = ArgumentParser(description="Cifar10 classification example")
32 |
33 | parser.add_argument(
34 | "--target",
35 | type=str,
36 | default="torchserve",
37 | help="MLflow target (default: torchserve)",
38 | )
39 |
40 | parser.add_argument(
41 | "--deployment_name",
42 | type=str,
43 | default="cifar_test",
44 | help="Deployment name (default: cifar_test)",
45 | )
46 |
47 | parser.add_argument(
48 | "--input_file_path",
49 | type=str,
50 | default="test_data/kitten.png",
51 | help="Path to input image for prediction (default: test_data/kitten.png)",
52 | )
53 |
54 | parser.add_argument(
55 | "--output_file_path",
56 | type=str,
57 | default="",
58 | help="output path to write the result",
59 | )
60 |
61 | parser.add_argument(
62 | "--inference_type",
63 | type=str,
64 | default="predict",
65 | help="Option to run prediction/explanation",
66 | )
67 |
68 | args = parser.parse_args()
69 |
70 | predict(vars(args))
71 |
--------------------------------------------------------------------------------
/tests/examples/test_torchserve_examples.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pytest
3 | import shutil
4 | from mlflow import cli
5 | from click.testing import CliRunner
6 | from mlflow.utils import process
7 |
8 | EXAMPLES_DIR = "examples"
9 |
10 |
11 | def get_free_disk_space():
12 | # https://stackoverflow.com/a/48929832/6943581
13 | return shutil.disk_usage("/")[-1] / (2**30)
14 |
15 |
16 | @pytest.fixture(scope="function", autouse=True)
17 | def clean_envs_and_cache():
18 | yield
19 |
20 | if get_free_disk_space() < 7.0: # unit: GiB
21 | process._exec_cmd(["./utils/remove-conda-envs.sh"])
22 |
23 |
24 | @pytest.mark.parametrize(
25 | "directory, params",
26 | [
27 | ("IrisClassification", ["-P", "max_epochs=10"]),
28 | ("MNIST", ["-P", "max_epochs=1", "-P", "register=false"]),
29 | ("IrisClassificationTorchScript", ["-P", "max_epochs=10"]),
30 | (
31 | "BertNewsClassification",
32 | [
33 | "-P",
34 | "max_epochs=1",
35 | "-P",
36 | "num_train_samples=100",
37 | "-P",
38 | "num_test_samples=100",
39 | ],
40 | ),
41 | (
42 | "E2EBert",
43 | [
44 | "-P",
45 | "max_epochs=1",
46 | "-P",
47 | "num_samples=100",
48 | ],
49 | ),
50 | ("Titanic", ["-P", "max_epochs=10", "-P", "lr=0.1"]),
51 | (
52 | "cifar10",
53 | [
54 | "-P",
55 | "max_epochs=1",
56 | "-P",
57 | "num_samples_train=1",
58 | "-P",
59 | "num_samples_val=1",
60 | "-P",
61 | "num_samples_test=1",
62 | ],
63 | ),
64 | ],
65 | )
66 | def test_mlflow_run_example(directory, params):
67 | example_dir = os.path.join(EXAMPLES_DIR, directory)
68 | cli_run_list = [example_dir] + params
69 | CliRunner().invoke(cli.run, cli_run_list)
70 | # assert res.exit_code == 0, "Got non-zero exit code {0}. Output is: {1}".format(
71 | # res.exit_code, res.output
72 | # )
73 |
--------------------------------------------------------------------------------
/examples/cifar10/create_deployment.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mlflow.deployments import get_deploy_client
4 |
5 |
6 | def create_deployment(parser_args):
7 | plugin = get_deploy_client(parser_args["target"])
8 | config = {
9 | "MODEL_FILE": parser_args["model_file"],
10 | "HANDLER": parser_args["handler"],
11 | "EXTRA_FILES": parser_args["extra_files"],
12 | }
13 |
14 | if parser_args["export_path"] != "":
15 | config["EXPORT_PATH"] = parser_args["export_path"]
16 |
17 | result = plugin.create_deployment(
18 | name=parser_args["deployment_name"],
19 | model_uri=parser_args["model_uri"],
20 | config=config,
21 | )
22 |
23 | print("Deployment {result} created successfully".format(result=result["name"]))
24 |
25 |
26 | if __name__ == "__main__":
27 | parser = ArgumentParser(description="Cifar10 classification example")
28 |
29 | parser.add_argument(
30 | "--target",
31 | type=str,
32 | default="torchserve",
33 | help="MLflow target (default: torchserve)",
34 | )
35 |
36 | parser.add_argument(
37 | "--deployment_name",
38 | type=str,
39 | default="cifar_test",
40 | help="Deployment name (default: cifar_test)",
41 | )
42 |
43 | parser.add_argument(
44 | "--model_file",
45 | type=str,
46 | default="cifar10_train.py",
47 | help="Model file path (default: cifar10_train.py)",
48 | )
49 |
50 | parser.add_argument(
51 | "--handler",
52 | type=str,
53 | default="cifar10_handler.py",
54 | help="Handler file path (default: cifar10_handler.py)",
55 | )
56 |
57 | parser.add_argument(
58 | "--extra_files",
59 | type=str,
60 | default="index_to_name.json",
61 | help="List of extra files",
62 | )
63 |
64 | parser.add_argument(
65 | "--model_uri",
66 | type=str,
67 | default="resnet.pth",
68 | help="Model URI (default: resnet.pth)",
69 | )
70 |
71 | parser.add_argument(
72 | "--export_path",
73 | type=str,
74 | default="model_store",
75 | help="Path to model store (default: 'model_store')",
76 | )
77 |
78 | args = parser.parse_args()
79 |
80 | create_deployment(vars(args))
81 |
--------------------------------------------------------------------------------
/examples/IrisClassification/create_deployment.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mlflow.deployments import get_deploy_client
4 |
5 |
6 | def create_deployment(parser_args):
7 | plugin = get_deploy_client(parser_args["target"])
8 | config = {
9 | "MODEL_FILE": parser_args["model_file"],
10 | "HANDLER": parser_args["handler"],
11 | "EXTRA_FILES": parser_args["extra_files"],
12 | }
13 |
14 | if parser_args["export_path"] != "":
15 | config["EXPORT_PATH"] = parser_args["export_path"]
16 |
17 | result = plugin.create_deployment(
18 | name=parser_args["deployment_name"],
19 | model_uri=parser_args["serialized_file_path"],
20 | config=config,
21 | )
22 |
23 | print("Deployment {result} created successfully".format(result=result["name"]))
24 |
25 |
26 | if __name__ == "__main__":
27 | parser = ArgumentParser(description="Iris Classification example")
28 |
29 | parser.add_argument(
30 | "--target",
31 | type=str,
32 | default="torchserve",
33 | help="MLflow target (default: torchserve)",
34 | )
35 |
36 | parser.add_argument(
37 | "--deployment_name",
38 | type=str,
39 | default="iris_classification",
40 | help="Deployment name (default: iris_classification)",
41 | )
42 |
43 | parser.add_argument(
44 | "--model_file",
45 | type=str,
46 | default="iris_classification.py",
47 | help="Model file path (default: iris_classification.py)",
48 | )
49 |
50 | parser.add_argument(
51 | "--handler",
52 | type=str,
53 | default="iris_handler.py",
54 | help="Handler file path (default: iris_handler.py)",
55 | )
56 |
57 | parser.add_argument(
58 | "--extra_files",
59 | type=str,
60 | default="index_to_name.json,model/MLmodel",
61 | help="List of extra files",
62 | )
63 |
64 | parser.add_argument(
65 | "--serialized_file_path",
66 | type=str,
67 | default="model",
68 | help="Pytorch model path",
69 | )
70 |
71 | parser.add_argument(
72 | "--export_path",
73 | type=str,
74 | default="",
75 | help="Path to model store (default: '')",
76 | )
77 |
78 | args = parser.parse_args()
79 |
80 | create_deployment(vars(args))
81 |
--------------------------------------------------------------------------------
/examples/MNIST/create_deployment.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | from mlflow.deployments import get_deploy_client
4 |
5 |
6 | def create_deployment(parser_args):
7 | plugin = get_deploy_client(parser_args["target"])
8 | config = {
9 | "MODEL_FILE": parser_args["model_file"],
10 | "HANDLER": parser_args["handler"],
11 | "EXTRA_FILES": parser_args["extra_files"],
12 | }
13 |
14 | if parser_args["export_path"] != "":
15 | config["EXPORT_PATH"] = parser_args["export_path"]
16 |
17 | result = plugin.create_deployment(
18 | name=parser_args["deployment_name"],
19 | model_uri=parser_args["registered_model_uri"],
20 | config=config,
21 | )
22 |
23 | print("Deployment {result} created successfully".format(result=result["name"]))
24 |
25 |
26 | if __name__ == "__main__":
27 | parser = ArgumentParser(description="MNIST hand written digits classification example")
28 |
29 | parser.add_argument(
30 | "--target",
31 | type=str,
32 | default="torchserve",
33 | help="MLflow target (default: torchserve)",
34 | )
35 |
36 | parser.add_argument(
37 | "--deployment_name",
38 | type=str,
39 | default="mnist_classification",
40 | help="Deployment name (default: mnist_classification)",
41 | )
42 |
43 | parser.add_argument(
44 | "--model_file",
45 | type=str,
46 | default="mnist_model.py",
47 | help="Model file path (default: mnist_model.py)",
48 | )
49 |
50 | parser.add_argument(
51 | "--handler",
52 | type=str,
53 | default="mnist_handler.py",
54 | help="Handler file path (default: mnist_handler.py)",
55 | )
56 |
57 | parser.add_argument(
58 | "--extra_files",
59 | type=str,
60 | default="index_to_name.json",
61 | help="List of extra files",
62 | )
63 |
64 | parser.add_argument(
65 | "--registered_model_uri",
66 | type=str,
67 | default="models:/mnist_classifier/3",
68 | help="Registered model URI (default: models:/mnist_classifier/3)",
69 | )
70 |
71 | parser.add_argument(
72 | "--export_path",
73 | type=str,
74 | default="model_store",
75 | help="Path to model store (default: 'model_store')",
76 | )
77 |
78 | args = parser.parse_args()
79 |
80 | create_deployment(vars(args))
81 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[FR]"
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ---
11 | name: Feature Request
12 | about: Use this template for feature and enhancement proposals.
13 | labels: 'enhancement'
14 | title: "[FR]"
15 | ---
16 | Thank you for submitting a feature request. **Before proceeding, please review MLflow's [Issue Policy for feature requests](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md#feature-requests) and the [MLflow Contributing Guide](https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst)**.
17 |
18 | **Please fill in this feature request template to ensure a timely and thorough response.**
19 |
20 | ## Willingness to contribute
21 | The MLflow Community encourages new feature contributions. Would you or another member of your organization be willing to contribute an implementation of this feature (as an enhancement to the MLflow TorchServe Deployment plugin code base)?
22 |
23 | - [ ] Yes. I can contribute this feature independently.
24 | - [ ] Yes. I would be willing to contribute this feature with guidance from the MLflow community.
25 | - [ ] No. I cannot contribute this feature at this time.
26 |
27 | ## Proposal Summary
28 |
29 | (In a few sentences, provide a clear, high-level description of the feature request)
30 |
31 | ## Motivation
32 | - What is the use case for this feature?
33 | - Why is this use case valuable to support for MLflow TorchServe Deployment plugin users in general?
34 | - Why is this use case valuable to support for your project(s) or organization?
35 | - Why is it currently difficult to achieve this use case? (please be as specific as possible about why related MLflow TorchServe Deployment plugin features and components are insufficient)
36 |
37 | ### What component(s) does this feature affect?
38 | Components
39 | - [ ] `area/deploy`: Main deployment plugin logic
40 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin
41 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages
42 | - [ ] `area/examples`: Example code
43 |
44 | ## Details
45 |
46 | (Use this section to include any additional information about the feature. If you have a proposal for how to implement this feature, please include it here. For implementation guidelines, please refer to the [Contributing Guide](https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst#contribution-guidelines).)
47 |
--------------------------------------------------------------------------------
/tests/examples/test_end_to_end_example.py:
--------------------------------------------------------------------------------
1 | import mlflow
2 | import pytest
3 | import os
4 | from mlflow.utils import process
5 |
6 |
7 | @pytest.mark.usefixtures("start_torchserve")
8 | def test_mnist_example():
9 | os.environ["MKL_THREADING_LAYER"] = "GNU"
10 | home_dir = os.getcwd()
11 | mnist_dir = "examples/MNIST"
12 | example_command = ["python", "mnist_model.py", "--max_epochs", "1", "--register", "false"]
13 | process._exec_cmd(example_command, cwd=mnist_dir)
14 |
15 | assert os.path.exists(os.path.join(mnist_dir, "model.pth"))
16 | create_deployment_command = [
17 | "python",
18 | "create_deployment.py",
19 | "--export_path",
20 | os.path.join(home_dir, "model_store"),
21 | "--registered_model_uri",
22 | "model.pth",
23 | ]
24 |
25 | process._exec_cmd(create_deployment_command, cwd=mnist_dir)
26 |
27 | assert os.path.exists(os.path.join(home_dir, "model_store", "mnist_classification.mar"))
28 |
29 | predict_command = ["python", "predict.py"]
30 | res = process._exec_cmd(predict_command, cwd=mnist_dir)
31 | assert "ONE" in res.stdout
32 |
33 |
34 | @pytest.mark.usefixtures("start_torchserve")
35 | def test_iris_example(tmpdir):
36 | iris_dir = os.path.join("examples", "IrisClassification")
37 | home_dir = os.getcwd()
38 | example_command = ["python", os.path.join(iris_dir, "iris_classification.py")]
39 | extra_files = "{},{}".format(
40 | os.path.join(iris_dir, "index_to_name.json"),
41 | os.path.join(home_dir, "model/MLmodel"),
42 | )
43 | process._exec_cmd(example_command, cwd=home_dir)
44 | create_deployment_command = [
45 | "python",
46 | os.path.join(iris_dir, "create_deployment.py"),
47 | "--export_path",
48 | os.path.join(home_dir, "model_store"),
49 | "--handler",
50 | os.path.join(iris_dir, "iris_handler.py"),
51 | "--model_file",
52 | os.path.join(iris_dir, "iris_classification.py"),
53 | "--extra_files",
54 | extra_files,
55 | ]
56 |
57 | process._exec_cmd(create_deployment_command, cwd=home_dir)
58 | mlflow.end_run()
59 | assert os.path.exists(os.path.join(home_dir, "model_store", "iris_classification.mar"))
60 | predict_command = [
61 | "python",
62 | os.path.join(iris_dir, "predict.py"),
63 | "--input_file_path",
64 | os.path.join(iris_dir, "sample.json"),
65 | ]
66 | res = process._exec_cmd(predict_command, cwd=home_dir)
67 | assert "SETOSA" in res.stdout
68 |
--------------------------------------------------------------------------------
/tests/resources/linear_model.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=W0223
2 | # IMPORTS SECTION #
3 |
4 | import numpy as np
5 | import torch
6 | from torch.autograd import Variable
7 |
8 | # SYNTHETIC DATA PREPARATION #
9 |
10 | x_values = [i for i in range(11)]
11 |
12 | x_train = np.array(x_values, dtype=np.float32)
13 | x_train = x_train.reshape(-1, 1)
14 |
15 | y_values = [2 * i + 1 for i in x_values]
16 |
17 | y_train = np.array(y_values, dtype=np.float32)
18 | y_train = y_train.reshape(-1, 1)
19 |
20 |
21 | # DEFINING THE NETWORK FOR REGRESSION #
22 |
23 |
24 | class LinearRegression(torch.nn.Module):
25 | def __init__(self, inputSize, outputSize):
26 | super(LinearRegression, self).__init__()
27 | self.linear = torch.nn.Linear(inputSize, outputSize)
28 |
29 | def forward(self, x):
30 | out = self.linear(x)
31 | return out
32 |
33 |
34 | def main():
35 | # SECTION FOR HYPERPARAMETERS #
36 |
37 | inputDim = 1
38 | outputDim = 1
39 | learningRate = 0.01
40 | epochs = 100
41 |
42 | # INITIALIZING THE MODEL #
43 |
44 | model = LinearRegression(inputDim, outputDim)
45 |
46 | # FOR GPU #
47 | if torch.cuda.is_available():
48 | model.cuda()
49 |
50 | # INITIALIZING THE LOSS FUNCTION AND OPTIMIZER #
51 |
52 | criterion = torch.nn.MSELoss()
53 | optimizer = torch.optim.SGD(model.parameters(), lr=learningRate)
54 |
55 | # TRAINING STEP #
56 |
57 | for epoch in range(epochs):
58 | # Converting inputs and labels to Variable
59 | if torch.cuda.is_available():
60 | inputs = Variable(torch.from_numpy(x_train).cuda())
61 | labels = Variable(torch.from_numpy(y_train).cuda())
62 | else:
63 | inputs = Variable(torch.from_numpy(x_train))
64 | labels = Variable(torch.from_numpy(y_train))
65 |
66 | optimizer.zero_grad()
67 | outputs = model(inputs)
68 |
69 | loss = criterion(outputs, labels)
70 |
71 | loss.backward()
72 | optimizer.step()
73 | print("epoch {}, loss {}".format(epoch, loss.item()))
74 |
75 | # EVALUATION AND PREDICTION #
76 |
77 | with torch.no_grad():
78 | if torch.cuda.is_available():
79 | predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy()
80 | else:
81 | predicted = model(Variable(torch.from_numpy(x_train))).data.numpy()
82 | print(predicted)
83 |
84 | # SAVING THE MODEL #
85 | torch.save(model.state_dict(), "tests/resources/linear_state_dict.pt")
86 | torch.save(model, "tests/resources/linear_model.pt")
87 |
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on:
4 | push:
5 | branches: [master]
6 | pull_request:
7 | branches: [master]
8 |
9 | env:
10 | CONDA_BIN: /usr/share/miniconda/bin
11 |
12 | jobs:
13 | lint:
14 | runs-on: ubuntu-latest
15 | steps:
16 | - uses: actions/checkout@v2
17 | - uses: actions/setup-python@v2
18 | with:
19 | python-version: 3.7
20 | - name: Install dependencies
21 | run: pip install black==22.3.0 flake8==5.0.4
22 | - run: flake8 .
23 | - run: black --check --line-length=100 .
24 |
25 | test:
26 | runs-on: ubuntu-20.04
27 | steps:
28 | - uses: actions/checkout@v2
29 | - uses: actions/setup-python@v2
30 | with:
31 | python-version: 3.7
32 | - name: Install Java
33 | run: |
34 | sudo apt-get update
35 | sudo apt-get install openjdk-11-jdk
36 | - name: Check Java version
37 | run: |
38 | java -version
39 | - name: Enable conda
40 | run: |
41 | echo "/usr/share/miniconda/bin" >> $GITHUB_PATH
42 | - name: Install dependencies
43 | run: |
44 | pip install -e .
45 | pip install pytest==6.1.1 torchvision scikit-learn gorilla transformers torchtext matplotlib captum
46 |
47 | - name: Install pytorch lightning
48 | run: |
49 | pip install "pytorch-lightning>=1.2.3"
50 |
51 | - name: Add permissions for remove conda utility file
52 | run: |
53 | chmod +x utils/remove-conda-envs.sh
54 |
55 | - name: Run torchserve example
56 | run: |
57 | set -x
58 | git clone https://github.com/pytorch/serve.git
59 | cd serve/examples/image_classifier/resnet_18
60 | mkdir model_store
61 | wget --no-verbose https://download.pytorch.org/models/resnet18-5c106cde.pth
62 |
63 | torch-model-archiver \
64 | --model-name resnet-18 \
65 | --version 1.0 \
66 | --model-file model.py \
67 | --serialized-file resnet18-5c106cde.pth \
68 | --handler image_classifier \
69 | --export-path model_store \
70 | --extra-files ../index_to_name.json
71 |
72 | torchserve --start --model-store model_store --models resnet-18=resnet-18.mar || true
73 | sleep 10
74 | curl -s -X POST http://127.0.0.1:8080/predictions/resnet-18 -T ../kitten.jpg
75 | sleep 3
76 | torchserve --stop
77 |
78 | - name: Run tests
79 | run: |
80 | pytest tests --color=yes --verbose --durations=5
81 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | ---
11 | name: Bug Report
12 | about: Use this template for reporting bugs encountered while using MLflow TorchServe Deployment Plugin.
13 | labels: 'bug'
14 | title: "[BUG]"
15 | ---
16 | Thank you for submitting an issue. Please refer to our [issue policy](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md) for additional information about bug reports. For help with debugging your code, please refer to [Stack Overflow](https://stackoverflow.com/questions/tagged/mlflow).
17 |
18 | **Please fill in this bug report template to ensure a timely and thorough response.**
19 |
20 | ### Willingness to contribute
21 | The MLflow Community encourages bug fix contributions. Would you or another member of your organization be willing to contribute a fix for this bug to the MLflow code base?
22 |
23 | - [ ] Yes. I can contribute a fix for this bug independently.
24 | - [ ] Yes. I would be willing to contribute a fix for this bug with guidance from the MLflow community.
25 | - [ ] No. I cannot contribute a bug fix at this time.
26 |
27 | ### System information
28 | - **Have I written custom code (as opposed to using a stock example script provided in MLflow)**:
29 | - **OS Platform and Distribution (e.g., Linux Ubuntu 18.04)**:
30 | - **MLflow installed from (source or binary)**:
31 | - **MLflow version (run ``mlflow --version``)**:
32 | - **MLflow TorchServe Deployment plugin installed from (source or binary)**:
33 | - **MLflow TorchServe Deployment plugin version (run ``mlflow deployments --version``)**:
34 | - **TorchServe installed from (source or binary)**:
35 | - **TorchServe version (run ``torchserve --version``)**:
36 | - **Python version**:
37 | - **Exact command to reproduce**:
38 |
39 | ### Describe the problem
40 | Describe the problem clearly here. Include descriptions of the expected behavior and the actual behavior.
41 |
42 | ### Code to reproduce issue
43 | Provide a reproducible test case that is the bare minimum necessary to generate the problem.
44 |
45 | ### Other info / logs
46 | Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached.
47 |
48 |
49 | ### What component(s) does this bug affect?
50 | Components
51 | - [ ] `area/deploy`: Main deployment plugin logic
52 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin
53 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages
54 | - [ ] `area/examples`: Example code
55 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/README.md:
--------------------------------------------------------------------------------
1 | # Deploying Iris Classification using torchserve
2 |
3 | The code, adapted from this [repository](http://chappers.github.io/2020/04/19/torch-lightning-using-iris/),
4 | is almost entirely dedicated to training, with the addition of a single mlflow.pytorch.autolog() call to enable automatic logging of params, metrics, and the TorchScript model.
5 | TorchScript allows us to save the whole model locally and load it into a different environment, such as in a server written in
6 | a completely different language.
7 |
8 | ## Training the model
9 |
10 | To run the example via MLflow, navigate to the `examples/IrisClassificationTorchScript/` directory and run the command
11 |
12 | ```
13 | mlflow run .
14 |
15 | ```
16 |
17 | This will run `iris_classification.py` with the default set of parameters such as `--max_epochs=100`. You can see the default value in the MLproject file.
18 |
19 | In order to run the file with custom parameters, run the command
20 |
21 | ```
22 | mlflow run . -P max_epochs=X
23 | ```
24 |
25 | where X is your desired value for max_epochs.
26 |
27 | If the required modules are already installed and you would like to skip creating a conda environment, add the `--no-conda` argument.
28 |
29 | ```
30 | mlflow run . --no-conda
31 | ```
32 |
33 | After training, the model is converted to a TorchScript model with `torch.jit.script`.
34 | At the end of the training process, the scripted model is stored as `iris_ts.pt`.
35 |
36 | ## Starting TorchServe
37 |
38 | Create an empty directory named `model_store` and run the following command to start TorchServe.
39 |
40 | `torchserve --start --model-store model_store`
41 |
42 | Note:
43 | The mlflow-torchserve plugin generates the mar file inside the `model_store` directory. If the `model_store` directory is not present under the current folder,
44 | the plugin creates a new directory named `model_store` and generates the mar file inside it.
45 |
46 | If TorchServe is already running with a different `model_store` location, pass that path with the `EXPORT_PATH` config variable (`-C 'EXPORT_PATH='`).
47 |
48 | ## Creating a new deployment
49 |
50 | Run the following command to create a new deployment named `iris_test`
51 |
52 | `mlflow deployments create --name iris_test --target torchserve --model-uri iris_ts.pt -C "HANDLER=iris_handler.py" -C "EXTRA_FILES=index_to_name.json"`
53 |
54 | ## Running prediction based on deployed model
55 |
56 | Run the following command to run a prediction on the sample input, where `sample.json` is the sample input file and `output.json` stores the predicted outcome.
57 |
58 | `mlflow deployments predict --name iris_test --target torchserve --input-path sample.json --output-path output.json`
59 |
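The create and predict commands above can also be assembled programmatically, for example when scripting several deployments. The sketch below is only an illustration: the helper names `build_create_command` and `build_predict_command` are hypothetical and not part of the plugin. They simply reproduce the CLI flags shown in this README as argument lists that could be handed to `subprocess.run`.

```python
import shlex


def build_create_command(name, model_uri, handler, extra_files=None, export_path=None):
    """Return the `mlflow deployments create` call as a list of arguments.

    Hypothetical helper: it only mirrors the flags used in this README.
    """
    cmd = [
        "mlflow", "deployments", "create",
        "--name", name,
        "--target", "torchserve",
        "--model-uri", model_uri,
        "-C", f"HANDLER={handler}",
    ]
    if extra_files:
        cmd += ["-C", f"EXTRA_FILES={extra_files}"]
    if export_path:
        # Only needed when TorchServe runs with a non-default model store.
        cmd += ["-C", f"EXPORT_PATH={export_path}"]
    return cmd


def build_predict_command(name, input_path, output_path):
    """Return the `mlflow deployments predict` call as a list of arguments."""
    return [
        "mlflow", "deployments", "predict",
        "--name", name,
        "--target", "torchserve",
        "--input-path", input_path,
        "--output-path", output_path,
    ]


create_cmd = build_create_command(
    "iris_test", "iris_ts.pt", "iris_handler.py", extra_files="index_to_name.json"
)
predict_cmd = build_predict_command("iris_test", "sample.json", "output.json")

# Print shell-quoted versions for inspection; pass the lists to subprocess.run to execute.
print(shlex.join(create_cmd))
print(shlex.join(predict_cmd))
```

Keeping the commands as lists avoids shell-quoting problems when a path contains spaces.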
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/iris_handler.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import json
3 | import logging
4 | import os
5 | import numpy as np
6 | import torch
7 | from ts.torch_handler.base_handler import BaseHandler
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | class IRISClassifierHandler(BaseHandler):
13 |     """
14 |     IRISClassifier handler class. This handler takes an input tensor and
15 |     outputs the type of iris based on the input.
16 |     """
17 |
18 |     def __init__(self):
19 |         super(IRISClassifierHandler, self).__init__()
20 |
21 |     def initialize(self, context):
22 |         """Load the TorchScript model from the model archive and prepare it for inference."""
23 |
24 |         properties = context.system_properties
25 |         self.map_location = "cuda" if torch.cuda.is_available() else "cpu"
26 |         self.device = torch.device(
27 |             self.map_location + ":" + str(properties.get("gpu_id"))
28 |             if torch.cuda.is_available()
29 |             else self.map_location
30 |         )
31 |         self.manifest = context.manifest
32 |
33 |         model_dir = properties.get("model_dir")
34 |         self.batch_size = properties.get("batch_size")
35 |         serialized_file = self.manifest["model"]["serializedFile"]
36 |         model_pt_path = os.path.join(model_dir, serialized_file)
37 |
38 |         if not os.path.isfile(model_pt_path):
39 |             raise RuntimeError("Missing the model.pt file")
40 |
41 |         logger.debug("Loading torchscript model")
42 |         self.model = self._load_torchscript_model(model_pt_path)
43 |
44 |         self.model.to(self.device)
45 |         self.model.eval()
46 |
47 |         logger.debug("Model file %s loaded successfully", model_pt_path)
48 |
49 |         # Load the optional index-to-name mapping shipped via EXTRA_FILES so that
50 |         # postprocess can return a class name instead of a raw index.
51 |         mapping_file_path = os.path.join(model_dir, "index_to_name.json")
52 |         if os.path.isfile(mapping_file_path):
53 |             with open(mapping_file_path) as fp:
54 |                 self.mapping = json.load(fp)
55 |
56 |         self.initialized = True
57 |
58 |     def preprocess(self, data):
59 |         """
60 |         Preprocessing step - reads the input array and converts it to a tensor
61 |
62 |         :param data: Input to be passed through the layers for prediction
63 |
64 |         :return: output - Preprocessed input
65 |         """
66 |
67 |         input_data_str = data[0].get("data")
68 |         if input_data_str is None:
69 |             input_data_str = data[0].get("body")
70 |
71 |         input_data = input_data_str.decode("utf-8")
72 |         input_tensor = torch.Tensor(ast.literal_eval(input_data))
73 |         return input_tensor
74 |
75 |     def postprocess(self, inference_output):
76 |         """
77 |         Post-processes the inference output before it is returned to the user
78 |
79 |         :param inference_output: Output of inference
80 |
81 |         :return: output - Output after post processing
82 |         """
83 |
84 |         predicted_idx = str(np.argmax(inference_output.cpu().detach().numpy()))
85 |
86 |         if self.mapping:
87 |             return [self.mapping[str(predicted_idx)]]
88 |         return [predicted_idx]
89 |
90 |
91 | _service = IRISClassifierHandler()
92 |
--------------------------------------------------------------------------------
/examples/cifar10/README.md:
--------------------------------------------------------------------------------
1 | # Deploying Cifar10 image classification using torchserve
2 |
3 | This example demonstrates fine-tuning a ResNet model on the CIFAR-10 dataset.
4 |
5 | Follow the link below to set up the backend store
6 |
7 | https://www.mlflow.org/docs/latest/tracking.html#storage
8 |
9 | ## Training the model
10 |
11 | This example autologs the trained model and its relevant parameters and metrics into MLflow using a single line of code.
12 | The example also illustrates how to use the Python plugin to deploy and test the model.
13 | The Python scripts `create_deployment.py` and `inference.py` are used for this purpose.
14 |
15 | Run the following command to train the cifar10 model
16 |
17 | CPU: `mlflow run . -P max_epochs=5`
18 | GPU: `mlflow run . -P max_epochs=5 -P devices=2 -P strategy=ddp -P accelerator=gpu`
19 |
20 | At the end of training, the CIFAR-10 model is saved as a state dict (`resnet.pth`) in the current working directory.
21 |
22 | ## Starting torchserve
23 |
24 | Create an empty directory named `model_store` and run the following command to start TorchServe.
25 |
26 | `torchserve --start --model-store model_store`
27 |
28 | ## Creating a new deployment
29 |
30 | This example uses image path as input for prediction.
31 |
32 | To create a new deployment, run the following command
33 |
34 | `python create_deployment.py`
35 |
36 | It will create a new deployment named `cifar_test`.
37 |
38 | The following arguments can be passed to the `create_deployment.py` script:
39 |
40 | 1. deployment name - `--deployment_name`
41 | 2. path to serialized file - `--model_uri`
42 | 3. handler file path - `--handler`
43 | 4. model file path - `--model_file`
44 |
45 | Note:
46 | If TorchServe is running with a different `model_store` location, the model-store path
47 | can be passed as input using the `--export_path` argument.
48 |
49 | For example:
50 |
51 | `python create_deployment.py --deployment_name cifar_test1 --export_path /home/ubuntu/model_store`
52 |
53 | ## Predicting deployed model
54 |
55 | To perform prediction, run the following script
56 |
57 | `python inference.py`
58 |
59 | The prediction results will be printed in the console.
60 |
61 | To save the inference output to a file, run the following command
62 |
63 | `python inference.py --output_file_path prediction_result.json`
64 |
65 | The following arguments can be passed to the `inference.py` script:
66 |
67 | 1. deployment name - `--deployment_name`
68 | 2. input file path - `--input_file_path`
69 | 3. path to write the result - `--output_file_path`
70 |
71 |
72 | ## Calculate captum explanations
73 |
74 | To send an explain request, run the following script
75 |
76 | `python inference.py --inference_type explanation`
77 |
78 | To save the explanation output to a file, run the following command
79 |
80 | `python inference.py --inference_type explanation --output_file_path explanation_result.json`
81 |
82 |
83 | ## Viewing captum results
84 |
85 | Use the notebook - `Cifar10_Captum.ipynb` to view the captum results.
86 |
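The command-line arguments listed in this README could be parsed along the following lines. This is a hedged sketch rather than the actual contents of `inference.py`: only the flag names come from this README, while the default values and the `prediction` choice for `--inference_type` are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Build a parser for the flags described in this README (defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Query a deployed CIFAR-10 model")
    parser.add_argument("--deployment_name", default="cifar_test")
    parser.add_argument("--input_file_path", default="test_data/kitten.png")
    # When no output path is given, results would simply be printed to the console.
    parser.add_argument("--output_file_path", default=None)
    parser.add_argument(
        "--inference_type",
        choices=["prediction", "explanation"],  # "prediction" is an assumed default
        default="prediction",
    )
    return parser


args = build_parser().parse_args(
    ["--inference_type", "explanation", "--output_file_path", "explanation_result.json"]
)
print(args.deployment_name, args.inference_type, args.output_file_path)
```

Restricting `--inference_type` with `choices` makes a misspelled value fail early instead of silently falling back to a plain prediction.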
--------------------------------------------------------------------------------
/examples/E2EBert/wrapper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class AGNewsmodelWrapper(nn.Module):
7 |     def __init__(self, model):
8 |         super(AGNewsmodelWrapper, self).__init__()
9 |         self.model = model
10 |
11 |     def compute_bert_outputs(  # pylint: disable=no-self-use
12 |         self, model_bert, embedding_input, attention_mask=None, head_mask=None
13 |     ):
14 |         """Computes Bert Outputs.
15 |
16 |         Args:
17 |             model_bert : the bert model
18 |             embedding_input : input for bert embeddings.
19 |             attention_mask : attention mask
20 |             head_mask : head mask
21 |         Returns:
22 |             output : the bert output
23 |         """
24 |         if attention_mask is None:
25 |             attention_mask = torch.ones(  # pylint: disable=no-member
26 |                 embedding_input.shape[0], embedding_input.shape[1]
27 |             ).to(embedding_input)
28 |
29 |         extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
30 |
31 |         extended_attention_mask = extended_attention_mask.to(
32 |             dtype=next(model_bert.parameters()).dtype
33 |         )  # fp16 compatibility
34 |         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
35 |
36 |         if head_mask is not None:
37 |             if head_mask.dim() == 1:
38 |                 head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
39 |                 head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1)
40 |             elif head_mask.dim() == 2:
41 |                 head_mask = (
42 |                     head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
43 |                 )  # We can specify head_mask for each layer
44 |             head_mask = head_mask.to(
45 |                 dtype=next(model_bert.parameters()).dtype
46 |             )  # switch to float if needed + fp16 compatibility
47 |         else:
48 |             head_mask = [None] * model_bert.config.num_hidden_layers
49 |
50 |         encoder_outputs = model_bert.encoder(
51 |             embedding_input, extended_attention_mask, head_mask=head_mask
52 |         )
53 |         sequence_output = encoder_outputs[0]
54 |         pooled_output = model_bert.pooler(sequence_output)
55 |         outputs = (
56 |             sequence_output,
57 |             pooled_output,
58 |         ) + encoder_outputs[1:]
59 |         return outputs
60 |
61 |     def forward(self, embeddings, attention_mask=None):
62 |         """Forward function.
63 |
64 |         Args:
65 |             embeddings : bert embeddings.
66 |             attention_mask: Attention mask value
67 |         """
68 |         outputs = self.compute_bert_outputs(self.model.bert_model, embeddings, attention_mask)
69 |         pooled_output = outputs[1]
70 |         output = F.relu(self.model.fc1(pooled_output))
71 |         output = self.model.drop(output)
72 |         output = self.model.out(output)
73 |         return output
74 |
--------------------------------------------------------------------------------
/examples/MNIST/README.md:
--------------------------------------------------------------------------------
1 | # Deploying MNIST Handwritten Recognition using torchserve
2 |
3 | This example requires a backend store to be configured for MLflow.
4 |
5 | Follow the link below to set up the backend store
6 |
7 | https://www.mlflow.org/docs/latest/tracking.html#storage
8 |
9 | ## Training the model
10 | The model is used to classify handwritten digits.
11 | This example autologs the trained model and its relevant parameters and metrics into MLflow using a single line of code.
12 | The example also illustrates how to use the Python plugin to deploy and test the model.
13 | The Python scripts `create_deployment.py` and `predict.py` are used for this purpose.
14 |
15 | Run the following command to train the MNIST model
16 |
17 | CPU: `mlflow run . -P max_epochs=5`
18 | GPU: `mlflow run . -P max_epochs=5 -P devices=2 -P strategy=ddp -P accelerator=gpu`
19 |
20 | At the end of training, the MNIST model is saved as a state dict in the current working directory.
21 |
22 | ## Deploying in remote torchserve instance
23 |
24 | To deploy the model in a remote TorchServe instance, follow the steps in [remote-deployment.rst](../../docs/remote-deployment.rst) under the `docs` folder.
27 |
28 |
29 | ## Deploying in local torchserve instance
30 |
31 | ## Starting torchserve
32 |
33 | Create an empty directory named `model_store` and run the following command to start TorchServe.
34 |
35 | `torchserve --start --model-store model_store`
36 |
37 | ## Creating a new deployment
38 |
39 | This example uses image path as input for prediction.
40 |
41 | To create a new deployment, run the following command
42 |
43 | `python create_deployment.py`
44 |
45 | It will create a new deployment named `mnist_classification`.
46 |
47 | The following arguments can be passed to the `create_deployment.py` script:
48 |
49 | 1. deployment name - `--deployment_name`
50 | 2. registered mlflow model uri - `--registered_model_uri`
51 | 3. handler file path - `--handler`
52 | 4. model file path - `--model_file`
53 |
54 | For example, to create another deployment the script can be triggered as
55 |
56 | `python create_deployment.py --deployment_name mnist_deployment1`
57 |
58 | Note:
59 |
60 | If TorchServe is running with a different `model_store` location, the model-store path
61 | can be passed as input using the `--export_path` argument.
62 |
63 | For example:
64 |
65 | `python create_deployment.py --deployment_name mnist_deployment1 --export_path /home/ubuntu/model_store`
66 |
67 | ## Predicting deployed model
68 |
69 | To perform prediction, run the following script
70 |
71 | `python predict.py`
72 |
73 | The prediction results will be printed in the console.
74 |
75 | The following arguments can be passed to the `predict.py` script:
76 |
77 | 1. deployment name - `--deployment_name`
78 | 2. input file path - `--input_file_path`
79 |
80 | For example, to run a prediction against the second deployment created above, run the following command
81 |
82 | `python predict.py --deployment_name mnist_deployment1 --input_file_path test_data/one.png`
83 |
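For reference, the handler used by this example (`mnist_handler.py`) decodes the request body as UTF-8 JSON and reads its `data` field before building a tensor. The sketch below mirrors that round trip with the standard library only. `encode_payload` and `decode_payload` are hypothetical helper names, and the 2x2 pixel grid is a stand-in for a real image; how `predict.py` actually builds the payload from the image path is not shown here.

```python
import json


def encode_payload(pixels):
    """Serialize nested pixel lists into the JSON body shape the handler reads."""
    return json.dumps({"data": pixels}).encode("utf-8")


def decode_payload(body):
    """Mirror the handler's decoding step: bytes -> str -> JSON -> 'data' field."""
    return json.loads(body.decode("utf-8"))["data"]


# Stand-in for image data; a real MNIST input would be a 28x28 grid.
pixels = [[0.0, 0.5], [1.0, 0.25]]
body = encode_payload(pixels)
restored = decode_payload(body)

# The handler would next wrap the decoded list in a torch tensor.
print(type(body).__name__, len(restored), len(restored[0]))
```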
--------------------------------------------------------------------------------
/examples/IrisClassification/iris_handler.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import numpy as np
4 | import torch
5 | import pandas as pd
6 | import os
7 | import json
8 | from ts.torch_handler.base_handler import BaseHandler
9 | from mlflow.models.model import Model
10 | from mlflow.pyfunc import _enforce_schema
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class IRISClassifierHandler(BaseHandler):
16 |     """
17 |     IRISClassifier handler class. This handler takes an input tensor and
18 |     outputs the type of iris based on the input.
19 |     """
20 |
21 |     def __init__(self):
22 |         super(IRISClassifierHandler, self).__init__()
23 |         self.mlmodel = None
24 |
25 |     def preprocess(self, data):
26 |         """
27 |         Preprocessing step - reads the input, validates it against the model
28 |         signature and converts it to a tensor
29 |
30 |         :param data: Input to be passed through the layers for prediction
31 |
32 |         :return: output - Preprocessed input
33 |         """
34 |
35 |         data = json.loads(data[0]["data"].decode("utf-8"))
36 |         df = pd.DataFrame(data)
37 |
38 |         # Raises if column names or data types do not match the logged signature
39 |         _enforce_schema(df, self.mlmodel.get_input_schema())
40 |
41 |         input_tensor = torch.Tensor(list(df.iloc[0]))
42 |         return input_tensor
43 |
44 |     def extract_signature(self, mlmodel_file):
45 |         """Load the MLmodel file and make sure it contains a model signature."""
46 |         self.mlmodel = Model.load(mlmodel_file)
47 |         model_json = json.loads(Model.to_json(self.mlmodel))
48 |
49 |         if "signature" not in model_json:
50 |             raise Exception("Model Signature not found")
51 |
52 |     def initialize(self, ctx):
53 |         """
54 |         Loads the serialized model, the optional index-to-name mapping and the model signature
55 |         :param ctx: Context object that carries the system properties
56 |         """
57 |         properties = ctx.system_properties
58 |         self.device = torch.device(
59 |             "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
60 |         )
61 |         model_dir = properties.get("model_dir")
62 |
63 |         # Read model serialize/pt file
64 |         model_pt_path = os.path.join(model_dir, "model.pth")
65 |
66 |         self.model = torch.load(model_pt_path, map_location=self.device)
67 |         self.model.to(self.device)
68 |         self.model.eval()
69 |
70 |         logger.debug("Model file %s loaded successfully", model_pt_path)
71 |
72 |         mapping_file_path = os.path.join(model_dir, "index_to_name.json")
73 |         if os.path.exists(mapping_file_path):
74 |             with open(mapping_file_path) as fp:
75 |                 self.mapping = json.load(fp)
76 |         mlmodel_file = os.path.join(model_dir, "MLmodel")
77 |
78 |         self.extract_signature(mlmodel_file=mlmodel_file)
79 |
80 |         self.initialized = True
81 |
82 |     def postprocess(self, inference_output):
83 |         """
84 |         Post-processes the inference output before it is returned to the user
85 |
86 |         :param inference_output: Output of inference
87 |
88 |         :return: output - Output after post processing
89 |         """
90 |
91 |         predicted_idx = str(np.argmax(inference_output.cpu().detach().numpy()))
92 |
93 |         if self.mapping:
94 |             return [self.mapping[str(predicted_idx)]]
95 |         return [predicted_idx]
96 |
97 |
98 | _service = IRISClassifierHandler()
99 |
--------------------------------------------------------------------------------
/.github/workflows/build-wheel.yml:
--------------------------------------------------------------------------------
1 | name: Build wheel
2 |
3 | on:
4 |   push:
5 |     branches: [master]
6 |   pull_request:
7 |     branches: [master]
8 |
9 | jobs:
10 |   build:
11 |     runs-on: ubuntu-latest
12 |     steps:
13 |       - uses: actions/checkout@v2
14 |       - uses: actions/setup-python@v2
15 |         with:
16 |           python-version: "3.7"
17 |
18 |       - name: Build wheel
19 |         id: build-wheel
20 |         run: |
21 |           pip install wheel
22 |           python setup.py bdist_wheel
23 |
24 |           # set outputs
25 |           wheel_path=$(find dist -type f)
26 |           wheel_name=$(basename "$wheel_path")
27 |           wheel_size=$(stat -c %s "$wheel_path")
28 |           echo "::set-output name=wheel-path::${wheel_path}"
29 |           echo "::set-output name=wheel-name::${wheel_name}"
30 |           echo "::set-output name=wheel-size::${wheel_size}"
31 |
32 |       - name: Verify wheel can be installed
33 |         run: |
34 |           pip install ${{ steps.build-wheel.outputs.wheel-path }}
35 |
36 |       # Anyone with read access can download the uploaded wheel on GitHub.
37 |       - name: Store wheel
38 |         uses: actions/upload-artifact@v2
39 |         if: github.event_name == 'push'
40 |         with:
41 |           name: ${{ steps.build-wheel.outputs.wheel-name }}
42 |           path: ${{ steps.build-wheel.outputs.wheel-path }}
43 |
44 |       - name: Remove old wheels
45 |         uses: actions/github-script@v3
46 |         if: github.event_name == 'push'
47 |         env:
48 |           WHEEL_SIZE: ${{ steps.build-wheel.outputs.wheel-size }}
49 |         with:
50 |           github-token: ${{ secrets.GITHUB_TOKEN }}
51 |           script: |
52 |             const { owner, repo } = context.repo;
53 |
54 |             // For some reason, the newly-uploaded wheel in the previous step is not included.
55 |             const artifactsResp = await github.actions.listArtifactsForRepo({
56 |               owner,
57 |               repo,
58 |             });
59 |             const wheels = artifactsResp.data.artifacts.filter(({ name }) => name.endsWith(".whl"));
60 |
61 |             // The storage usage limit for a free github account is up to 500 MB. See the page below for details:
62 |             // https://docs.github.com/en/github/setting-up-and-managing-billing-and-payments-on-github/about-billing-for-github-actions
63 |             const MAX_SIZE_IN_BYTES = 300_000_000; // 300 MB
64 |
65 |             let index = 0;
66 |             let sum = parseInt(process.env.WHEEL_SIZE, 10); // include the newly-uploaded wheel
67 |             for (const [idx, { size_in_bytes }] of wheels.entries()) {
68 |               index = idx;
69 |               sum += size_in_bytes;
70 |               if (sum > MAX_SIZE_IN_BYTES) {
71 |                 break;
72 |               }
73 |             }
74 |
75 |             if (sum <= MAX_SIZE_IN_BYTES) {
76 |               return;
77 |             }
78 |
79 |             // Delete old wheels
80 |             const promises = wheels.slice(index).map(({ id: artifact_id }) =>
81 |               github.actions.deleteArtifact({
82 |                 owner,
83 |                 repo,
84 |                 artifact_id,
85 |               })
86 |             );
87 |             Promise.all(promises).then(data => console.log(data));
88 |
89 |
--------------------------------------------------------------------------------
/examples/IrisClassification/iris_data_module.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | import pytorch_lightning as pl
4 | import torch
5 | from pytorch_lightning import seed_everything
6 | from sklearn.datasets import load_iris
7 | from torch.utils.data import DataLoader, random_split, TensorDataset
8 |
9 |
10 | class IrisDataModule(pl.LightningDataModule):
11 |     def __init__(self, **kwargs):
12 |         """
13 |         Initialization of inherited lightning data module
14 |         """
15 |         super(IrisDataModule, self).__init__()
16 |
17 |         self.train_set = None
18 |         self.val_set = None
19 |         self.test_set = None
20 |         self.args = kwargs
21 |
22 |     def prepare_data(self):
23 |         """
24 |         Implementation of abstract class
25 |         """
26 |
27 |     def setup(self, stage=None):
28 |         """
29 |         Loads the data, parses it and splits it into train, validation and test sets
30 |
31 |         :param stage: Stage - training or testing
32 |         """
33 |         iris = load_iris()
34 |         df = iris.data
35 |         target = iris["target"]
36 |
37 |         data = torch.Tensor(df).float()
38 |         labels = torch.Tensor(target).long()
39 |         RANDOM_SEED = 42
40 |         seed_everything(RANDOM_SEED)
41 |
42 |         data_set = TensorDataset(data, labels)
43 |         self.train_set, self.val_set = random_split(data_set, [130, 20])
44 |         self.train_set, self.test_set = random_split(self.train_set, [110, 20])
45 |
46 |     @staticmethod
47 |     def add_model_specific_args(parent_parser):
48 |         """
49 |         Adds model specific arguments batch size and num workers
50 |
51 |         :param parent_parser: Application specific parser
52 |
53 |         :return: Returns the augmented argument parser
54 |         """
55 |         parser = ArgumentParser(parents=[parent_parser], add_help=False)
56 |         parser.add_argument(
57 |             "--batch-size",
58 |             type=int,
59 |             default=128,
60 |             metavar="N",
61 |             help="input batch size for training (default: 128)",
62 |         )
63 |         parser.add_argument(
64 |             "--num-workers",
65 |             type=int,
66 |             default=3,
67 |             metavar="N",
68 |             help="number of workers (default: 3)",
69 |         )
70 |         return parser
71 |
72 |     def create_data_loader(self, dataset):
73 |         """
74 |         Generic data loader function
75 |
76 |         :param dataset: Input data set
77 |
78 |         :return: Returns the constructed dataloader
79 |         """
80 |
81 |         return DataLoader(
82 |             dataset, batch_size=self.args["batch_size"], num_workers=self.args["num_workers"]
83 |         )
84 |
85 |     def train_dataloader(self):
86 |         train_loader = self.create_data_loader(dataset=self.train_set)
87 |         return train_loader
88 |
89 |     def val_dataloader(self):
90 |         validation_loader = self.create_data_loader(dataset=self.val_set)
91 |         return validation_loader
92 |
93 |     def test_dataloader(self):
94 |         test_loader = self.create_data_loader(dataset=self.test_set)
95 |         return test_loader
96 |
97 |
98 | if __name__ == "__main__":
99 |     pass
100 |
--------------------------------------------------------------------------------
/examples/MNIST/mnist_handler.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 |
5 | import torch
6 | from ts.torch_handler.image_classifier import ImageClassifier
7 |
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class MNISTDigitHandler(ImageClassifier):
12 |     """
13 |     MNISTDigitClassifier handler class. This handler takes a greyscale image
14 |     and returns the digit in that image.
15 |     """
16 |
17 |     def __init__(self):
18 |         super(MNISTDigitHandler, self).__init__()
19 |         self.mapping_file_path = None
20 |
21 |     def initialize(self, ctx):
22 |         """
23 |         Loads the eager mode model from its state dict
24 |         :param ctx: Context object that carries the system properties
25 |         """
26 |         properties = ctx.system_properties
27 |         self.device = torch.device(
28 |             "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
29 |         )
30 |         model_dir = properties.get("model_dir")
31 |
32 |         # Read model serialize/pt file
33 |         model_pt_path = os.path.join(model_dir, "model.pth")
34 |         from mnist_model import LightningMNISTClassifier
35 |
36 |         state_dict = torch.load(model_pt_path, map_location=self.device)
37 |         self.model = LightningMNISTClassifier()
38 |         self.model.load_state_dict(state_dict, strict=False)
39 |         self.model.to(self.device)
40 |         self.model.eval()
41 |
42 |         logger.debug("Model file %s loaded successfully", model_pt_path)
43 |
44 |         self.mapping_file_path = os.path.join(model_dir, "index_to_name.json")
45 |
46 |         self.initialized = True
47 |
48 |     def preprocess(self, data):
49 |         """
50 |         Decodes the JSON request body and converts its "data" field
51 |         to a tensor
52 |         :param data: Input to be passed through the layers for prediction
53 |         :return: output - Preprocessed image
54 |         """
55 |         image = data[0].get("data")
56 |         if image is None:
57 |             image = data[0].get("body")
58 |
59 |         image = image.decode("utf-8")
60 |         image = torch.Tensor(json.loads(image)["data"])
61 |         return image
62 |
63 |     def inference(self, img):
64 |         """
65 |         Predict the class (or classes) of an image using a trained deep learning model
66 |         :param img: Input to be passed through the layers for prediction
67 |         :return: output - Predicted label for the given input
68 |         """
69 |         # Convert 2D image to 1D vector
70 |         # img = np.expand_dims(img, 0)
71 |         # img = torch.from_numpy(img)
72 |         self.model.eval()
73 |         inputs = img.to(self.device)
74 |         outputs = self.model.forward(inputs)
75 |
76 |         _, y_hat = outputs.max(1)
77 |         predicted_idx = str(y_hat.item())
78 |         return [predicted_idx]
79 |
80 |     def postprocess(self, inference_output):
81 |         """
82 |         Post-processes the inference output before it is returned to the user
83 |
84 |         :param inference_output: Output of inference
85 |
86 |         :return: output - Output after post processing
87 |         """
88 |
89 |         # Map the predicted index to its label only when the mapping file was shipped
90 |         if self.mapping_file_path and os.path.isfile(self.mapping_file_path):
91 |             with open(self.mapping_file_path) as json_file:
92 |                 data = json.load(json_file)
93 |             inference_output = [json.dumps(data[inference_output[0]])]
94 |         return inference_output
95 |
--------------------------------------------------------------------------------
/tests/resources/linear_handler.py:
--------------------------------------------------------------------------------
1 | # IMPORTS SECTION #
2 |
3 | import logging
4 | import os
5 | import numpy as np
6 | import torch
7 | import json
8 |
9 | logger = logging.getLogger(__name__)
10 |
11 |
12 | # CLASS DEFINITION #
13 |
14 |
15 | class LinearRegressionHandler(object):
16 |     def __init__(self):
17 |         self.model = None
18 |         self.mapping = None
19 |         self.device = None
20 |         self.initialized = False
21 |
22 |     def initialize(self, ctx):
23 |         """
24 |         Loading the saved model from the serialized file
25 |         """
26 |
27 |         properties = ctx.system_properties
28 |         self.device = torch.device(
29 |             "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu"
30 |         )
31 |         model_dir = properties.get("model_dir")
32 |
33 |         # Read model serialize/pt file
34 |         model_pt_path = os.path.join(model_dir, "linear_state_dict.pt")
35 |         # Read model definition file
36 |         model_def_path = os.path.join(model_dir, "linear_model.py")
37 |
38 |         if not os.path.isfile(model_def_path):
39 |             # No model definition shipped: load the fully serialized model instead
40 |             model_pt_path = os.path.join(model_dir, "linear_model.pt")
41 |             self.model = torch.load(model_pt_path, map_location=self.device)
42 |         else:
43 |             from linear_model import LinearRegression
44 |
45 |             state_dict = torch.load(model_pt_path, map_location=self.device)
46 |             self.model = LinearRegression(1, 1)
47 |             self.model.load_state_dict(state_dict)
48 |
49 |         self.model.to(self.device)
50 |         self.model.eval()
51 |
52 |         logger.debug("Model file %s loaded successfully", model_pt_path)
53 |         self.initialized = True
54 |
55 |     def preprocess(self, data):
56 |         """
57 |         Preprocess the input to tensor and reshape it to be used as input to the network
58 |         """
59 |         data = data[0]
60 |         image = data.get("data")
61 |         if image is None:
62 |             image = data.get("body")
63 |             image = image.decode("utf-8")
64 |             number = float(json.loads(image)["data"][0])
65 |         else:
66 |             number = float(image)
67 |
68 |         np_data = np.array(number, dtype=np.float32)
69 |         np_data = np_data.reshape(-1, 1)
70 |         data_tensor = torch.from_numpy(np_data)
71 |         return data_tensor
72 |
73 |     def inference(self, num):
74 |
75 |         """
76 |         Does inference / prediction on the preprocessed input and returns the output
77 |         """
78 |
79 |         self.model.eval()
80 |         # torch.autograd.Variable is deprecated; tensors can be used directly
81 |         inputs = num.to(self.device)
82 |         outputs = self.model.forward(inputs)
83 |         return [outputs.detach().item()]
84 |
85 |     def postprocess(self, inference_output):
86 |
87 |         """
88 |         Does post processing on the output returned from the inference method
89 |         """
90 |         return inference_output
91 |
92 |
93 | # CLASS INITIALIZATION #
94 |
95 | _service = LinearRegressionHandler()
96 |
97 |
98 | def handle(data, context):
99 |
100 |     """
101 |     Default handler for the inference api which takes two parameters data and context
102 |     and returns the predicted output
103 |     """
104 |
105 |     if not _service.initialized:
106 |         _service.initialize(context)
107 |
108 |     if data is None:
109 |         return None
110 |
111 |     data = _service.preprocess(data)
112 |     data = _service.inference(data)
113 |     data = _service.postprocess(data)
114 |     print("Data: {}".format(data))
115 |     return data
116 |
--------------------------------------------------------------------------------
/examples/IrisClassificationTorchScript/iris_classification.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import mlflow.pytorch
4 | import pytorch_lightning as pl
5 | import torch
6 | import torch.nn as nn
7 | from sklearn.metrics import accuracy_score
8 | from torch.nn import functional as F
9 |
10 |
11 | class IrisClassification(pl.LightningModule):
12 |     def __init__(self):
13 |         super(IrisClassification, self).__init__()
14 |         self.fc1 = nn.Linear(4, 10)
15 |         self.fc2 = nn.Linear(10, 10)
16 |         self.fc3 = nn.Linear(10, 3)
17 |
18 |     def forward(self, x):
19 |         x = F.relu(self.fc1(x))
20 |         x = F.relu(self.fc2(x))
21 |         x = F.relu(self.fc3(x))
22 |         x = F.log_softmax(x, dim=1)  # normalize over the class axis, not the batch axis
23 |         return x
24 |
25 |     def cross_entropy_loss(self, logits, labels):
26 |         """
27 |         Loss Fn to compute loss
28 |         """
29 |         return F.nll_loss(logits, labels)
30 |
31 |     def training_step(self, train_batch, batch_idx):
32 |         """
33 |         Trains on the data in batches and returns the training loss for each batch
34 |         """
35 |         x, y = train_batch
36 |         logits = self.forward(x)
37 |         loss = self.cross_entropy_loss(logits, y)
38 |         return {"loss": loss}
39 |
40 |     def validation_step(self, val_batch, batch_idx):
41 |         """
42 |         Performs validation of data in batches
43 |         """
44 |         x, y = val_batch
45 |         logits = self.forward(x)
46 |         loss = self.cross_entropy_loss(logits, y)
47 |         return {"val_loss": loss}
48 |
49 |     def validation_epoch_end(self, outputs):
50 |         """
51 |         Computes average validation loss
52 |         """
53 |         avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
54 |         tensorboard_logs = {"val_loss": avg_loss}
55 |         return {"avg_val_loss": avg_loss, "log": tensorboard_logs}
56 |
57 |     def test_step(self, test_batch, batch_idx):
58 |         """
59 |         Performs test and computes test accuracy
60 |         """
61 |
62 |         x, y = test_batch
63 |         output = self.forward(x)
64 |         a, y_hat = torch.max(output, dim=1)
65 |         test_acc = accuracy_score(y_hat.cpu(), y.cpu())
66 |         return {"test_acc": torch.tensor(test_acc)}
67 |
68 |     def test_epoch_end(self, outputs):
69 |         """
70 |         Computes average test accuracy score
71 |         """
72 |         avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean()
73 |         return {"avg_test_acc": avg_test_acc}
74 |
75 |     def configure_optimizers(self):
76 |         """
77 |         Creates and returns Optimizer
78 |         """
79 |
80 |         self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05, momentum=0.9)
81 |         return self.optimizer
82 |
83 |
84 | if __name__ == "__main__":
85 |     parser = argparse.ArgumentParser()
86 |     parser = pl.Trainer.add_argparse_args(parent_parser=parser)
87 |     args = parser.parse_args()
88 |     dict_args = vars(args)
89 |
90 |     for argument in ["strategy", "accelerator", "devices"]:
91 |         if dict_args[argument] == "None":
92 |             dict_args[argument] = None
93 |
94 |     mlflow.pytorch.autolog()
95 |     model = IrisClassification()
96 |     import iris_datamodule
97 |
98 |     dm = iris_datamodule.IRISDataModule()
99 |     dm.prepare_data()
100 |     dm.setup("fit")
101 |
102 |     trainer = pl.Trainer.from_argparse_args(args)
103 |     trainer.fit(model, dm)
104 |     trainer.test(datamodule=dm)
105 |     trainer.test(datamodule=dm)
106 |     if trainer.global_rank == 0:
107 |         scripted_model = torch.jit.script(model)
108 |         torch.jit.save(scripted_model, "iris_ts.pt")
109 |
--------------------------------------------------------------------------------
/examples/Titanic/README.md:
--------------------------------------------------------------------------------
1 | # Titanic feature attribution analysis using Captum and TorchServe
2 |
3 | In this example, we demonstrate the basic features of the [Captum](https://captum.ai/) interpretability library and serve a model on TorchServe, using an example model trained on the Titanic survival data. You can download the data from [titanic](https://biostat.app.vumc.org/wiki/pub/Main/DataSets/titanic3.csv).
4 |
5 | We will first train a deep neural network on the data using PyTorch and use Captum to understand which of the features were most important and how the network reached its prediction.
6 |
7 | For more details about the attribution methods used in this example, see:
8 |
9 | 1. [Titanic_Basic_Interpret](https://captum.ai/tutorials/Titanic_Basic_Interpret)
10 | 2. [integrated-gradients](https://captum.ai/docs/algorithms#primary-attribution)
11 | 3. [layer-attributions](https://captum.ai/docs/algorithms#layer-attribution)
12 |
13 |
14 | The inference service returns the prediction and the average attribution score of each feature for a given target, for an input test record.
15 |
16 | ### Running the code
17 |
18 | To run the example via MLflow, navigate to the `examples/Titanic/` directory and run the commands
19 |
20 | ```
21 | mlflow run .
22 |
23 | ```
24 |
25 | This will run `titanic.py` with the default set of parameters, such as `--max_epochs=100`. You can see the default values in the `MLproject` file.
26 |
27 | In order to run the file with custom parameters, run the command
28 |
29 | ```
30 | mlflow run . -P max_epochs=X
31 | ```
32 |
33 | where X is your desired value for max_epochs.
34 |
35 | If you have the required modules for the file and would like to skip the creation of a conda environment, add the argument `--no-conda`.
36 |
37 | ```
38 | mlflow run . --no-conda
39 | ```
40 |
41 | The above commands train the Titanic model for further use.
42 |
43 |
44 | # Serve a custom model on TorchServe
45 |
46 | * Step - 1: Create a new model architecture file that contains a model class extending `torch.nn.Module`. In this example we have created the [titanic model file](titanic.py).
47 | * Step - 2: Write a custom handler to run inference on your model. In this example, we have added a [custom_handler](titanic_handler.py) that runs inference on the input record using the above model and makes a prediction.
48 | * Step - 3: Create an empty directory named `model_store` and run the following command to start TorchServe.
49 |
50 | ```bash
51 | torchserve --start --model-store model_store/
52 | ```
53 |
54 | ## Creating a new deployment
55 | Run the following command to create a new deployment named `titanic`.
56 |
57 | The `index_to_name.json` file is the mapping file, which converts the discrete output of the model to one of the classes (survived/not survived)
58 | based on the predefined mapping.
59 |
60 | ```bash
61 | mlflow deployments create --name titanic --target torchserve --model-uri models/titanic_state_dict.pt -C "MODEL_FILE=titanic.py" -C "HANDLER=titanic_handler.py" -C "EXTRA_FILES=index_to_name.json"
62 | ```
63 |
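As a rough illustration of the mapping step described above, the sketch below translates a discrete model output into a class label. The mapping contents are an assumption made for illustration; the actual `index_to_name.json` shipped with this example may use different keys or labels, and `to_class_name` is a hypothetical helper, not part of the handler.

```python
import json

# Hypothetical mapping contents; the real index_to_name.json may differ.
INDEX_TO_NAME_JSON = '{"0": "not survived", "1": "survived"}'


def to_class_name(class_index, mapping):
    """Translate a discrete model output into its class label."""
    # JSON object keys are always strings, so convert the index first.
    return mapping[str(class_index)]


mapping = json.loads(INDEX_TO_NAME_JSON)
print(to_class_name(1, mapping))
```

Keeping the mapping in a separate file passed through `EXTRA_FILES` means the labels can change without touching the handler code.
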
64 | ## Running prediction based on deployed model
65 |
66 | For testing, we use a sample test record stored in `input.json` in the `test_data` folder.
67 |
68 | Run the following command to invoke prediction on the test record; the output is stored in the `output.json` file.
69 |
70 | ```
71 | mlflow deployments predict --name titanic --target torchserve --input-path test_data/input.json --output-path output.json
72 | ```
73 |
74 | The model classifies the test record as survived or not survived and stores the result in `output.json`.
75 |
76 |
77 | Run the command below to invoke explain for feature importance attributions on the test record. It saves the attribution image `attributions_imp.png` in the `test_data` folder.
78 |
79 | ```
80 | mlflow deployments explain -t torchserve --name titanic --input-path test_data/input.json
81 | ```
82 |
83 | The explain command gives us the average attribution for each feature. From the feature attribution information, we obtain some interesting insights regarding the importance of various features.
84 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mlflow-TorchServe
2 |
3 | A plugin that integrates [TorchServe](https://github.com/pytorch/serve) with the MLflow pipeline.
4 | ``mlflow_torchserve`` enables MLflow users to deploy MLflow pipeline models to TorchServe.
5 | The plugin's command line APIs (also accessible through MLflow's Python package) make the deployment process seamless.
6 |
7 | ## Prerequisites
8 |
9 | The following packages need to be installed before running the TorchServe deployment plugin:
10 |
11 | 1. torch-model-archiver
12 | 2. torchserve
13 | 3. mlflow
14 |
15 |
16 | ## Installation
17 | The plugin package is available on PyPI and can be installed using
18 |
19 | ```bash
20 | pip install mlflow-torchserve
21 | ```
22 | ## Installation from Source
23 |
24 | The plugin package can also be installed from source using the following commands
25 | ```
26 | python setup.py build
27 | python setup.py install
28 | ```
29 |
30 | ## What does it do
31 | Installing this package uses Python's entry point mechanism to register the plugin in MLflow's
32 | plugin registry. This registry is consulted each time you launch an MLflow script or run an MLflow
33 | command from the command line.
34 |
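To see which deployment plugins are registered in the current environment, the entry points can be inspected directly. This is a minimal sketch, assuming MLflow discovers deployment plugins under the `mlflow.deployments` entry point group; `registered_deployment_plugins` is a hypothetical helper name, not part of this plugin.

```python
from importlib.metadata import entry_points


def registered_deployment_plugins(group="mlflow.deployments"):
    """Return {target name: module reference} for entry points in the group.

    The group name is an assumption about how MLflow looks up plugins.
    """
    eps = entry_points()
    if hasattr(eps, "select"):
        # Newer Python versions expose a selectable entry point collection.
        selected = eps.select(group=group)
    else:
        # Older versions return a plain dict keyed by group name.
        selected = eps.get(group, [])
    return {ep.name: ep.value for ep in selected}


print(registered_deployment_plugins())
```

If the plugin is installed, a `torchserve` entry would be expected in the result; an empty dictionary suggests no deployment plugin is registered in the active environment.
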
35 |
36 | ### Create deployment
37 | The `create` command line argument and the ``create_deployment`` Python
38 | API deploy a model built with MLflow to TorchServe.
39 |
40 | ##### CLI
41 | ```shell script
42 | mlflow deployments create -t torchserve -m --name DEPLOYMENT_NAME -C 'MODEL_FILE=' -C 'HANDLER='
43 | ```
44 |
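When the deployment is scripted, the CLI invocation above can be assembled from a configuration dictionary. The sketch below only builds the argument list and does not call MLflow; `build_create_command` is a hypothetical helper, and the sample values are borrowed from the Titanic example in this repository.

```python
def build_create_command(name, model_uri, config, target="torchserve"):
    """Build the argument list for `mlflow deployments create`."""
    args = [
        "mlflow", "deployments", "create",
        "-t", target,
        "-m", model_uri,
        "--name", name,
    ]
    # Each configuration entry becomes its own `-C KEY=VALUE` pair.
    for key, value in config.items():
        args += ["-C", f"{key}={value}"]
    return args


command = build_create_command(
    name="titanic",
    model_uri="models/titanic_state_dict.pt",
    config={"MODEL_FILE": "titanic.py", "HANDLER": "titanic_handler.py"},
)
print(" ".join(command))
```

The resulting list can be passed to `subprocess.run(command, check=True)` once TorchServe is running.
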
45 | ##### Python API
46 | ```python
47 | from mlflow.deployments import get_deploy_client
48 | target_uri = 'torchserve'
49 | plugin = get_deploy_client(target_uri)
50 | plugin.create_deployment(name=, model_uri=, config={"MODEL_FILE": , "HANDLER": })
51 | ```
52 |
53 | ### Update deployment
54 | The update API can be used to modify configuration parameters, such as the number of workers or the version, of an already deployed model.
55 | TorchServe will make sure the user experience is seamless while changing the model in a live environment.
56 |
57 | ##### CLI
58 | ```shell script
59 | mlflow deployments update -t torchserve --name -C "min-worker="
60 | ```
61 |
62 | ##### Python API
63 | ```python
64 | plugin.update_deployment(name=, config={'min-worker': })
65 | ```
66 |
67 | ### Delete deployment
68 | Deletes an existing deployment. An exception is raised if the model is not already deployed.
69 |
70 | ##### CLI
71 | ```shell script
72 | mlflow deployments delete -t torchserve --name
73 | ```
74 |
75 | ##### Python API
76 | ```python
77 | plugin.delete_deployment(name=)
78 | ```
79 |
80 | ### List all deployments
81 | Lists the names of all the models deployed on the configured TorchServe.
82 |
83 | ##### CLI
84 | ```shell script
85 | mlflow deployments list -t torchserve
86 | ```
87 |
88 | ##### Python API
89 | ```python
90 | plugin.list_deployments()
91 | ```
92 |
93 | ### Get deployment details
94 | The get API fetches the details of the deployed model. By default, it fetches all the versions of the
95 | deployed model.
96 |
97 | ##### CLI
98 | ```shell script
99 | mlflow deployments get -t torchserve --name
100 | ```
101 |
102 | ##### Python API
103 | ```python
104 | plugin.get_deployment(name=)
105 | ```
106 |
107 | ### Run Prediction on deployed model
108 | The predict API runs predictions on the deployed model.
109 |
110 | For the prediction inputs, DataFrame, Tensor and JSON formats are supported. The Python API supports all
111 | three formats. When invoked via the command line, one needs to pass the path of a JSON file that contains the inputs.
112 |
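For the command line route, the inputs have to be written to a JSON file first. The sketch below shows one way to do that with the standard library. The payload shape is an assumption for illustration only, since the expected schema depends on the handler of the deployed model, and `write_input_file` is a hypothetical helper.

```python
import json
import os
import tempfile


def write_input_file(payload, path):
    """Serialize a prediction payload to a JSON file and return its path."""
    with open(path, "w") as f:
        json.dump(payload, f)
    return path


# Hypothetical payload; the real schema depends on the deployed handler.
payload = {"data": [5.0]}

with tempfile.TemporaryDirectory() as tmp_dir:
    input_path = write_input_file(payload, os.path.join(tmp_dir, "input.json"))
    # Read the file back to confirm it contains valid JSON.
    with open(input_path) as f:
        round_tripped = json.load(f)

print(round_tripped)
```

The path of the written file is what the `--input-path` option of the predict command expects.
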
113 | ##### CLI
114 | ```shell script
115 | mlflow deployments predict -t torchserve --name --input-path --output-path