├── tests ├── __init__.py ├── resources │ ├── __init__.py │ ├── sample.json │ ├── linear_model.py │ └── linear_handler.py ├── conftest.py ├── fixtures │ └── plugin_fixtures.py ├── examples │ ├── test_torchserve_examples.py │ └── test_end_to_end_example.py ├── test_plugin.py └── test_cli.py ├── .flake8 ├── examples ├── E2EBert │ ├── input.json │ ├── class_mapping.json │ ├── AGNews_Captum_Insights.png │ ├── requirements.txt │ ├── conda.yaml │ ├── MLproject │ ├── wrapper.py │ ├── README.md │ └── news_classifier_handler.py ├── BertNewsClassification │ ├── input.json │ ├── class_mapping.json │ ├── requirements.txt │ ├── conda.yaml │ ├── MLproject │ ├── news_classifier_handler.py │ └── README.md ├── Titanic │ ├── index_to_name.json │ ├── test_data │ │ ├── input.json │ │ ├── titanic_survived.csv │ │ └── titanic_not_survived.csv │ ├── conda.yaml │ ├── MLproject │ ├── titanic.py │ ├── README.md │ └── titanic_handler.py ├── IrisClassificationTorchScript │ ├── sample.json │ ├── index_to_name.json │ ├── conda.yaml │ ├── MLproject │ ├── iris_datamodule.py │ ├── README.md │ ├── iris_handler.py │ └── iris_classification.py ├── TextClassification │ ├── sample_text.txt │ ├── index_to_name.json │ ├── predict.py │ ├── README.md │ └── create_deployment.py ├── IrisClassification │ ├── index_to_name.json │ ├── sig_invalid_column_name.json │ ├── sample.json │ ├── sig_invalid_data_type.json │ ├── conda.yaml │ ├── MLproject │ ├── predict.py │ ├── create_deployment.py │ ├── iris_handler.py │ ├── iris_data_module.py │ ├── README.md │ └── iris_classification.py ├── MNIST │ ├── test_data │ │ └── one.png │ ├── index_to_name.json │ ├── conda.yaml │ ├── MLproject │ ├── register.py │ ├── predict.py │ ├── create_deployment.py │ ├── README.md │ ├── mnist_handler.py │ └── mnist_model.py ├── cifar10 │ ├── test_data │ │ └── kitten.png │ ├── index_to_name.json │ ├── conda.yaml │ ├── MLproject │ ├── inference.py │ ├── create_deployment.py │ ├── README.md │ ├── cifar10_handler.py │ ├── cifar10_train.py │ └── cifar10_datamodule.py └── README.md ├── config.properties ├── mlflow_torchserve └── config.py ├── utils └── remove-conda-envs.sh ├── setup.py ├── .github ├── ISSUE_TEMPLATE │ ├── documentation-fix.md │ ├── feature_request.md │ └── bug_report.md ├── pull_request_template.md └── workflows │ ├── ci.yml │ └── build-wheel.yml ├── .gitignore ├── README.md ├── docs └── remote-deployment.rst └── LICENSE.txt /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/resources/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/resources/sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": [2000] 3 | 4 | } 5 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | extend-ignore = E203 4 | -------------------------------------------------------------------------------- /examples/E2EBert/input.json: -------------------------------------------------------------------------------- 1 | { 2 | "data": ["This year business is good"] 3 | } -------------------------------------------------------------------------------- /examples/BertNewsClassification/input.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "data": ["This year business is good"] 3 | } -------------------------------------------------------------------------------- /examples/Titanic/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0":"Not Survived", 3 | "1":"Survived" 4 | } 5 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/sample.json: -------------------------------------------------------------------------------- 1 | {"data": ["[4.4000, 3.0000, 1.3000, 0.2000]"]} 2 | -------------------------------------------------------------------------------- /examples/Titanic/test_data/input.json: -------------------------------------------------------------------------------- 1 | {"input_file_path" : ["./test_data/titanic_survived.csv"]} 2 | -------------------------------------------------------------------------------- /examples/TextClassification/sample_text.txt: -------------------------------------------------------------------------------- 1 | Bloomberg has decided to publish a new report on global economic situation. 2 | -------------------------------------------------------------------------------- /examples/IrisClassification/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "SETOSA", 3 | "1": "VERSICOLOR", 4 | "2": "VIRGINICA" 5 | } -------------------------------------------------------------------------------- /examples/MNIST/test_data/one.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/MNIST/test_data/one.png -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "SETOSA", 3 | "1": "VERSICOLOR", 4 | "2": "VIRGINICA" 5 | } -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from tests.fixtures.plugin_fixtures import start_torchserve, stop_torchserve, health_checkup 3 | -------------------------------------------------------------------------------- /examples/E2EBert/class_mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "World", 3 | "1": "Sports", 4 | "2": "Business", 5 | "3": "Sci/Tech" 6 | } -------------------------------------------------------------------------------- /examples/TextClassification/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0":"World", 3 | "1":"Sports", 4 | "2":"Business", 5 | "3":"Sci/Tec" 6 | } 7 | -------------------------------------------------------------------------------- /examples/cifar10/test_data/kitten.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/cifar10/test_data/kitten.png -------------------------------------------------------------------------------- /config.properties: -------------------------------------------------------------------------------- 1 | inference_address=http://127.0.0.1:8080 2 | 
management_address=http://127.0.0.1:8081 3 | export_url=http://127.0.0.1:8000 4 | -------------------------------------------------------------------------------- /examples/E2EBert/AGNews_Captum_Insights.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlflow/mlflow-torchserve/HEAD/examples/E2EBert/AGNews_Captum_Insights.png -------------------------------------------------------------------------------- /examples/BertNewsClassification/class_mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "World", 3 | "1": "Sports", 4 | "2": "Business", 5 | "3": "Sci/Tech" 6 | } -------------------------------------------------------------------------------- /examples/Titanic/test_data/titanic_survived.csv: -------------------------------------------------------------------------------- 1 | age,sibsp,parch,fare,female,male,embark_C,embark_Q,embark_S,class_1,class_2,class_3 2 | 29,0,0,211.3375,1,0,0,0,1,1,0,0 -------------------------------------------------------------------------------- /examples/Titanic/test_data/titanic_not_survived.csv: -------------------------------------------------------------------------------- 1 | age,sibsp,parch,fare,female,male,embark_C,embark_Q,embark_S,class_1,class_2,class_3 2 | 30.0,1,2,151.55,0,1,0,0,1,1,0,0 3 | -------------------------------------------------------------------------------- /examples/IrisClassification/sig_invalid_column_name.json: -------------------------------------------------------------------------------- 1 | {"data": ["{\"length (cm)\":{\"0\":4.4},\"width (cm)\":{\"0\":3.2},\"length (cm)\":{\"0\":1.3},\"width (cm)\":{\"0\":0.2}}"]} 2 | -------------------------------------------------------------------------------- /examples/IrisClassification/sample.json: -------------------------------------------------------------------------------- 1 | {"data": ["{\"sepal length (cm)\":{\"0\":4.4},\"sepal width (cm)\":{\"0\":3.2},\"petal length (cm)\":{\"0\":1.3},\"petal width (cm)\":{\"0\":0.2}}"]} 2 | -------------------------------------------------------------------------------- /examples/IrisClassification/sig_invalid_data_type.json: -------------------------------------------------------------------------------- 1 | {"data": ["{\"sepal length (cm)\":{\"0\":4},\"sepal width (cm)\":{\"0\":3},\"petal length (cm)\":{\"0\":1},\"petal width (cm)\":{\"0\":0}}"]} 2 | -------------------------------------------------------------------------------- /examples/BertNewsClassification/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | mlflow 3 | torchserve 4 | torch-model-archiver 5 | numpy 6 | transformers 7 | sklearn 8 | tqdm 9 | torchtext 10 | pandas 11 | datasets 12 | -------------------------------------------------------------------------------- /examples/E2EBert/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision>=0.9.1 2 | torch>=1.12.0 3 | pytorch_lightning>=1.7.6 4 | mlflow 5 | torchserve 6 | torch-model-archiver 7 | numpy 8 | transformers 9 | sklearn 10 | torchtext>=0.13.0 11 | -------------------------------------------------------------------------------- /examples/MNIST/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "ZERO", 3 | "1": "ONE", 4 | "2": "TWO", 5 | "3": "THREE", 6 | "4": "FOUR", 7 | "5": "FIVE", 8 | "6": "SIX", 9 | "7": 
"SEVEN", 10 | "8": "EIGHT", 11 | "9": "NINE" 12 | } -------------------------------------------------------------------------------- /examples/IrisClassification/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - sklearn 10 | - pytorch-lightning>=1.7.6 11 | - torch 12 | - torchvision 13 | - mlflow 14 | -------------------------------------------------------------------------------- /examples/cifar10/index_to_name.json: -------------------------------------------------------------------------------- 1 | { 2 | "0": "airplane", 3 | "1": "automobile", 4 | "2": "bird", 5 | "3": "cat", 6 | "4": "deer", 7 | "5": "dog", 8 | "6": "frog", 9 | "7": "horse", 10 | "8": "ship", 11 | "9": "truck" 12 | } -------------------------------------------------------------------------------- /examples/Titanic/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - defaults 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.2 7 | - pytorch>=1.9.0 8 | - pip 9 | - pip: 10 | - mlflow 11 | - pandas 12 | - captum 13 | - sklearn 14 | - ipython 15 | - prettytable 16 | 17 | -------------------------------------------------------------------------------- /examples/MNIST/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - torch>=1.9.0 10 | - torchvision 11 | - matplotlib 12 | - pytorch-lightning>=1.7.6 13 | - mlflow>=1.14.0 14 | - captum 15 | -------------------------------------------------------------------------------- /examples/BertNewsClassification/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - mlflow 10 | - scikit-learn 11 | - transformers 12 | - torchtext 13 | - torchvision 14 | - torch>=1.9.0 15 | - pandas 16 | - datasets 17 | -------------------------------------------------------------------------------- /examples/cifar10/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - matplotlib 10 | - torch>=1.9.0 11 | - pytorch-lightning>=1.7.6 12 | - mlflow>=1.14.0 13 | - captum 14 | - webdataset 15 | - sklearn 16 | - torchvision 17 | -------------------------------------------------------------------------------- /examples/E2EBert/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - pytorch 4 | - conda-forge 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - mlflow 10 | - captum 11 | - sklearn 12 | - transformers 13 | - torch>=1.9.0 14 | - torchtext>=0.10.0 15 | - pytorch-lightning>=1.7.6 16 | - torchvision>=0.10.0 17 | - torchdata 18 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/conda.yaml: -------------------------------------------------------------------------------- 1 | channels: 2 | - defaults 3 | - conda-forge 4 | - pytorch 5 | dependencies: 6 | - python=3.8.2 7 | - pip 8 | - pip: 9 | - mlflow 10 | - scikit-learn 11 | - 
torchvision 12 | - torch>=1.9.0 13 | 14 | - pytorch-lightning>=1.7.6 15 | 16 | -------------------------------------------------------------------------------- /examples/Titanic/MLproject: -------------------------------------------------------------------------------- 1 | name: Titanic-Captum-Example 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 50} 9 | lr: {type: float, default: 0.1} 10 | 11 | command: | 12 | python titanic_captum_interpret.py \ 13 | --max_epochs {max_epochs} \ 14 | --lr {lr} 15 | -------------------------------------------------------------------------------- /mlflow_torchserve/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | class Config(dict): 5 | def __init__(self): 6 | """ 7 | Initializes constants from Environment variables 8 | """ 9 | super().__init__() 10 | self["export_path"] = os.environ.get("EXPORT_PATH") 11 | self["config_properties"] = os.environ.get("CONFIG_PROPERTIES") 12 | self["torchserve_address_names"] = ["inference_address", "management_address", "export_url"] 13 | self["export_uri"] = os.environ.get("EXPORT_URL") 14 | -------------------------------------------------------------------------------- /utils/remove-conda-envs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | mlflow_envs=$( 6 | conda env list | # list (env name, env path) pairs 7 | cut -d' ' -f1 | # extract env names 8 | grep "^mlflow-[a-z0-9]\{40\}\$" # filter envs created by mlflow 9 | ) || true 10 | 11 | if [ ! -z "$mlflow_envs" ]; then 12 | for env in $mlflow_envs 13 | do 14 | conda remove --all --yes --name $env 15 | done 16 | 17 | conda clean --all --yes 18 | conda env list 19 | fi 20 | 21 | set +ex 22 | -------------------------------------------------------------------------------- /examples/IrisClassification/MLproject: -------------------------------------------------------------------------------- 1 | name: iris-classification 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 100} 9 | devices: {type: int, default: None} 10 | strategy: {type: str, default: "None"} 11 | accelerator: {type: str, default: "None"} 12 | 13 | command: | 14 | python iris_classification.py \ 15 | --max_epochs {max_epochs} \ 16 | --devices {devices} \ 17 | --strategy {strategy} \ 18 | --accelerator {accelerator} 19 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/MLproject: -------------------------------------------------------------------------------- 1 | name: iris-classification 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 100} 9 | devices: {type: int, default: None} 10 | strategy: {type: str, default: "None"} 11 | accelerator: {type: str, default: "None"} 12 | 13 | command: | 14 | python iris_classification.py \ 15 | --max_epochs {max_epochs} \ 16 | --devices {devices} \ 17 | --strategy {strategy} \ 18 | --accelerator {accelerator} 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name="mlflow-torchserve", 6 | version="0.2.0", 7 | description="Torch Serve Mlflow 
Deployment", 8 | long_description=open("README.md").read(), 9 | long_description_content_type="text/markdown", 10 | packages=find_packages(), 11 | # Require MLflow as a dependency of the plugin, so that plugin users can simply install 12 | # the plugin & then immediately use it with MLflow 13 | install_requires=["torchserve", "torch-model-archiver", "mlflow"], 14 | entry_points={"mlflow.deployments": "torchserve=mlflow_torchserve"}, 15 | ) 16 | -------------------------------------------------------------------------------- /examples/Titanic/titanic.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class TitanicSimpleNNModel(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | self.linear1 = nn.Linear(12, 12) 8 | self.sigmoid1 = nn.Sigmoid() 9 | self.linear2 = nn.Linear(12, 8) 10 | self.sigmoid2 = nn.Sigmoid() 11 | self.linear3 = nn.Linear(8, 2) 12 | self.softmax = nn.Softmax(dim=1) 13 | 14 | def forward(self, x): 15 | lin1_out = self.linear1(x) 16 | sigmoid_out1 = self.sigmoid1(lin1_out) 17 | sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1)) 18 | return self.softmax(self.linear3(sigmoid_out2)) 19 | -------------------------------------------------------------------------------- /examples/MNIST/MLproject: -------------------------------------------------------------------------------- 1 | name: mnist-example 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 5} 9 | devices: {type: int, default: "None"} 10 | strategy: {type: str, default: "None"} 11 | accelerator: {type: str, default: "None"} 12 | registration_name: {type: str, default: "mnist_classifier"} 13 | 14 | command: | 15 | python mnist_model.py \ 16 | --max_epochs {max_epochs} \ 17 | --devices {devices} \ 18 | --strategy {strategy} \ 19 | --accelerator {accelerator} \ 20 | --registration_name {registration_name} 21 | -------------------------------------------------------------------------------- /examples/BertNewsClassification/MLproject: -------------------------------------------------------------------------------- 1 | name: bert-classification 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 5} 9 | num_train_samples: {type:int, default: 2000} 10 | num_test_samples: {type:int, default: 200} 11 | vocab_file: {type: str, default: 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt'} 12 | model_save_path: {type: str, default: 'models'} 13 | 14 | command: | 15 | python news_classifier.py \ 16 | --max_epochs {max_epochs} \ 17 | --num_train_samples {num_train_samples} \ 18 | --num_test_samples {num_test_samples} \ 19 | --vocab_file {vocab_file} \ 20 | --model_save_path {model_save_path} 21 | -------------------------------------------------------------------------------- /examples/E2EBert/MLproject: -------------------------------------------------------------------------------- 1 | name: bert-classification 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 5} 9 | devices: {type: int, default: None} 10 | num_samples: {type: int, default: 1000} 11 | vocab_file: {type: str, default: 'https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt'} 12 | strategy: {type str, default: None} 13 | accelerator: {type str, default: None} 14 | 15 | command: | 16 | python news_classifier.py \ 17 | --max_epochs 
{max_epochs} \ 18 | --devices {devices} \ 19 | --num_samples {num_samples} \ 20 | --vocab_file {vocab_file} \ 21 | --strategy={strategy} \ 22 | --accelerator={accelerator} -------------------------------------------------------------------------------- /examples/cifar10/MLproject: -------------------------------------------------------------------------------- 1 | name: cifar10-example 2 | 3 | conda_env: conda.yaml 4 | 5 | entry_points: 6 | main: 7 | parameters: 8 | max_epochs: {type: int, default: 1} 9 | devices: {type: int, default: None} 10 | strategy: {type: str, default: "None"} 11 | accelerator: {type: str, default: "None"} 12 | num_samples_train: {type: int, default: 39} 13 | num_samples_val: {type: int, default: 9} 14 | num_samples_test: {type: int, default: 9} 15 | 16 | command: | 17 | python cifar10_train.py \ 18 | --max_epochs {max_epochs} \ 19 | --devices {devices} \ 20 | --num_samples_train {num_samples_train} \ 21 | --num_samples_val {num_samples_val} \ 22 | --num_samples_test {num_samples_test} \ 23 | --strategy {strategy} \ 24 | --accelerator {accelerator} 25 | -------------------------------------------------------------------------------- /examples/MNIST/register.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mlflow.deployments import get_deploy_client 4 | 5 | 6 | def register(parser_args): 7 | plugin = get_deploy_client(parser_args["target"]) 8 | plugin.register_model(mar_file_path=parser_args["mar_file_name"]) 9 | print("Registered Successfully") 10 | 11 | 12 | if __name__ == "__main__": 13 | parser = ArgumentParser(description="MNIST hand written digits classification example") 14 | 15 | parser.add_argument( 16 | "--target", 17 | type=str, 18 | default="torchserve", 19 | help="MLflow target (default: torchserve)", 20 | ) 21 | 22 | parser.add_argument( 23 | "--mar_file_name", 24 | type=str, 25 | default="", 26 | help="mar file name to register (Ex: mnist_test.mar)", 27 | ) 28 | 29 | args = parser.parse_args() 30 | 31 | register(vars(args)) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-fix.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation Fix 3 | about: Use this template for proposing documentation fixes/improvements. 4 | title: "[DOC-FIX]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | Thank you for submitting an issue. Please refer to our [issue policy](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md) for information on what types of issues we address. 11 | 12 | **Please fill in this documentation issue template to ensure a timely and thorough response.** 13 | 14 | ### Willingness to contribute 15 | The MLflow Community encourages documentation fix contributions. Would you or another member of your organization be willing to contribute a fix for this documentation issue to the MLflow TorchServe Deployment plugin code base? 16 | 17 | - [ ] Yes. I can contribute a documentation fix independently. 18 | - [ ] Yes. I would be willing to contribute a document fix with guidance from the MLflow community. 19 | - [ ] No. I cannot contribute a documentation fix at this time. 20 | 21 | ### URL(s) with the issue: 22 | 23 | Please provide a link to the documentation entry in question. 24 | 25 | ### Description of proposal (what needs changing): 26 | Provide a clear description. Why is the proposed documentation better? 
27 | -------------------------------------------------------------------------------- /examples/TextClassification/predict.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | 4 | from mlflow.deployments import get_deploy_client 5 | 6 | 7 | def predict(parser_args): 8 | 9 | with open(parser_args["input_file_path"], "r") as fp: 10 | text = fp.read() 11 | plugin = get_deploy_client(parser_args["target"]) 12 | prediction = plugin.predict(parser_args["deployment_name"], json.dumps(text)) 13 | print("Prediction Result {}".format(prediction.to_json())) 14 | 15 | 16 | if __name__ == "__main__": 17 | parser = ArgumentParser(description="Text classifier example") 18 | 19 | parser.add_argument( 20 | "--target", 21 | type=str, 22 | default="torchserve", 23 | help="MLflow target (default: torchserve)", 24 | ) 25 | 26 | parser.add_argument( 27 | "--deployment_name", 28 | type=str, 29 | default="text_classification", 30 | help="Deployment name (default: text_classification)", 31 | ) 32 | 33 | parser.add_argument( 34 | "--input_file_path", 35 | type=str, 36 | default="sample_text.txt", 37 | help="Path to input text file for prediction (default: sample_text.txt)", 38 | ) 39 | 40 | args = parser.parse_args() 41 | 42 | predict(vars(args)) 43 | -------------------------------------------------------------------------------- /examples/TextClassification/README.md: -------------------------------------------------------------------------------- 1 | # Deploying Text Classification model 2 | 3 | Download `train.py` and `model.py` from the [repository](https://github.com/pytorch/serve/tree/master/examples/text_classification) 4 | and subsequently run the following command to train the model on either CPU or GPU. 5 | 6 | CPU: `python train.py AG_NEWS --device cpu --save-model-path model.pt --dictionary source_vocab.pt` 7 | 8 | GPU: `python train.py AG_NEWS --device cuda --save-model-path model.pt --dictionary source_vocab.pt` 9 | 10 | At the end of the training, the model file `model.pt` and vocabulary file `source_vocab.pt` will be stored in the current directory. 11 | 12 | ## Starting TorchServe 13 | 14 | Create an empty directory `model_store` and run the following command to start TorchServe. 15 | 16 | `torchserve --start --model-store model_store` 17 | 18 | ## Creating a new deployment 19 | 20 | This example uses the default TorchServe text handler to generate the mar file. 21 | 22 | To create a new deployment, run the following command: 23 | 24 | `python create_deployment.py --deployment_name text_classification --model_file model.py --serialized_file model.pt --extra_files "source_vocab.pt,index_to_name.json"` 25 | 26 | ## Predicting with the deployed model 27 | 28 | To perform prediction, run the following script: 29 | 30 | `python predict.py` 31 | 32 | The prediction results will be printed to the console. 33 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## What changes are proposed in this pull request? 2 | 3 | (Please fill in changes proposed in this fix) 4 | 5 | ## How is this patch tested? 6 | 7 | (Details) 8 | 9 | ## Release Notes 10 | 11 | ### Is this a user-facing change? 12 | 13 | - [ ] No. You can skip the rest of this section. 14 | - [ ] Yes. Give a description of this change to be included in the release notes for MLflow TorchServe Deployment Plugin users. 
15 | 16 | (Details in 1-2 sentences. You can just refer to another PR with a description if this PR is part of a larger change.) 17 | 18 | ### What component(s) does this PR affect? 19 | Components 20 | - [ ] `area/deploy`: Main deployment plugin logic 21 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin 22 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages 23 | - [ ] `area/examples`: Example code 24 | 25 | 26 | ### How should the PR be classified in the release notes? Choose one: 27 | 28 | - [ ] `rn/breaking-change` - The PR will be mentioned in the "Breaking Changes" section 29 | - [ ] `rn/none` - No description will be included. The PR will be mentioned only by the PR number in the "Small Bugfixes and Documentation Updates" section 30 | - [ ] `rn/feature` - A new user-facing feature worth mentioning in the release notes 31 | - [ ] `rn/bug-fix` - A user-facing bug fix worth mentioning in the release notes 32 | - [ ] `rn/documentation` - A user-facing documentation change worth mentioning in the release notes 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Mlflow 2 | mlruns/ 3 | outputs/ 4 | 5 | # Directories and files generated when running the examples 6 | sample.json 7 | output.json 8 | bert_base_uncased_vocab.txt 9 | *.pt 10 | data/ 11 | .data/ 12 | logs/ 13 | model_store/ 14 | models/ 15 | 16 | # Mac 17 | .DS_Store 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | node_modules 46 | 47 | # PyInstaller 48 | # Usually these files are written by a python script from a template 49 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
50 | *.manifest 51 | *.spec 52 | 53 | # Installer logs 54 | pip-log.txt 55 | pip-delete-this-directory.txt 56 | 57 | # Unit test / coverage reports 58 | htmlcov/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # Environments 75 | env 76 | env3 77 | .env 78 | .venv 79 | env/ 80 | venv/ 81 | ENV/ 82 | env.bak/ 83 | venv.bak/ 84 | .python-version 85 | 86 | # Editor files 87 | .*project 88 | *.swp 89 | *.swo 90 | *.idea 91 | *.vscode 92 | *.iml 93 | *~ 94 | 95 | 96 | 97 | 98 | -------------------------------------------------------------------------------- /examples/MNIST/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import matplotlib.pyplot as plt 5 | from mlflow.deployments import get_deploy_client 6 | from torchvision import transforms 7 | 8 | 9 | def predict(parser_args): 10 | plugin = get_deploy_client(parser_args["target"]) 11 | img = plt.imread(os.path.join(parser_args["input_file_path"])) 12 | mnist_transforms = transforms.Compose( 13 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 14 | ) 15 | 16 | image_tensor = mnist_transforms(img) 17 | prediction = plugin.predict(parser_args["deployment_name"], image_tensor) 18 | print("Prediction Result {}".format(prediction.to_json())) 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = ArgumentParser(description="MNIST hand written digits classification example") 23 | 24 | parser.add_argument( 25 | "--target", 26 | type=str, 27 | default="torchserve", 28 | help="MLflow target (default: torchserve)", 29 | ) 30 | 31 | parser.add_argument( 32 | "--deployment_name", 33 | type=str, 34 | default="mnist_classification", 35 | help="Deployment name (default: mnist_classification)", 36 | ) 37 | 38 | parser.add_argument( 39 | "--input_file_path", 40 | type=str, 41 | default="test_data/one.png", 42 | help="Path to input image for prediction (default: test_data/one.png)", 43 | ) 44 | 45 | args = parser.parse_args() 46 | 47 | predict(vars(args)) 48 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/iris_datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from pytorch_lightning import seed_everything 3 | import torch 4 | from sklearn.datasets import load_iris 5 | from torch.utils.data import TensorDataset 6 | from torch.utils.data import random_split 7 | from torch.utils.data.dataloader import DataLoader 8 | 9 | 10 | class IRISDataModule(pl.LightningDataModule): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def prepare_data(self): 15 | """ 16 | Implementation of abstract class 17 | """ 18 | 19 | def setup(self, stage=None): 20 | 21 | # Assign train/val datasets for use in dataloaders 22 | if stage == "fit" or stage is None: 23 | iris = load_iris() 24 | df = iris.data 25 | target = iris["target"] 26 | 27 | data = torch.Tensor(df).float() 28 | labels = torch.Tensor(target).long() 29 | RANDOM_SEED = 42 30 | seed_everything(RANDOM_SEED) 31 | 32 | data_set = TensorDataset(data, labels) 33 | self.train_set, self.val_set = random_split(data_set, [130, 20]) 34 | self.train_set, self.test_set = random_split(self.train_set, [110, 20]) 35 | 36 | def train_dataloader(self): 37 | 
37 | return DataLoader(self.train_set, batch_size=4) 38 | 39 | def val_dataloader(self): 40 | return DataLoader(self.val_set, batch_size=4) 41 | 42 | def test_dataloader(self): 43 | return DataLoader(self.test_set, batch_size=4) 44 | -------------------------------------------------------------------------------- /tests/fixtures/plugin_fixtures.py: -------------------------------------------------------------------------------- 1 | import atexit 2 | import json 3 | import os 4 | import shutil 5 | import subprocess 6 | import time 7 | 8 | import pytest 9 | 10 | from tests.resources import linear_model 11 | 12 | 13 | @pytest.fixture(scope="session") 14 | def start_torchserve(): 15 | linear_model.main() 16 | if not os.path.isdir("model_store"): 17 | os.makedirs("model_store") 18 | cmd = "torchserve --ncs --start --model-store {}".format("./model_store") 19 | _ = subprocess.Popen(cmd, shell=True).wait() 20 | 21 | count = 0 22 | for _ in range(5): 23 | value = health_checkup() 24 | if value is not None and value != "" and json.loads(value)["status"] == "Healthy": 25 | time.sleep(1) 26 | break 27 | else: 28 | count += 1 29 | time.sleep(5) 30 | if count >= 5: 31 | raise Exception("Unable to connect to torchserve") 32 | return True 33 | 34 | 35 | def health_checkup(): 36 | curl_cmd = "curl http://localhost:8080/ping" 37 | (value, _) = subprocess.Popen([curl_cmd], stdout=subprocess.PIPE, shell=True).communicate() 38 | return value.decode("utf-8") 39 | 40 | 41 | def stop_torchserve(): 42 | cmd = "torchserve --stop" 43 | _ = subprocess.Popen(cmd, shell=True).wait() 44 | 45 | if os.path.isdir("model_store"): 46 | shutil.rmtree("model_store") 47 | 48 | if os.path.exists("tests/resources/linear_state_dict.pt"): 49 | os.remove("tests/resources/linear_state_dict.pt") 50 | 51 | if os.path.exists("tests/resources/linear_model.pt"): 52 | os.remove("tests/resources/linear_model.pt") 53 | 54 | 55 | atexit.register(stop_torchserve) 56 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | The examples in this folder illustrate training and deploying models using the `mlflow-torchserve` plugin 2 | 3 | 1.`BertNewsClassification` - Demonstrates a workflow using a pytorch-based BERT model. 4 | The example illustrates 5 |
  • model training (fine tuning a pre-trained model) 6 |
  • logging of model, summary, parameters and extra files at the end of training 7 |
  • deployment of the model in TorchServe 8 | 9 | 2.`E2EBert` - Demonstrates a workflow using a pytorch-lightning based BERT model. 10 | The example illustrates, 11 |
  • model training (fine tuning a pre-trained model) 12 |
  • model saving and loading using mlflow autolog 13 |
  • deployment of the model in TorchServe 14 |
  • Calculating explanations using captum 15 | 16 | 3.`IrisClassification` - Demonstrates distributed training (DDP) and deployment using the Iris dataset and the MLflow-torchserve plugin. This example illustrates the use of a model signature. 17 | 18 | 4.`IrisClassificationTorchScript` - Demonstrates saving a TorchScript version of the Iris classification model and deploying it using the MLflow-torchserve plugin with `MLflow cli` commands 19 | 20 | 5.`MNIST` - Demonstrates training of an MNIST handwritten digit recognition model and deployment of the model using the MLflow-torchserve **python plugin**. This example illustrates registering the model to MLflow and deploying the model to a remote TorchServe instance. 21 | 22 | 6.`TextClassification` - Demonstrates training of the TextClassification example and deployment of the model using the MLflow-torchserve **python plugin** 23 | 24 | 7.`Titanic` - Demonstrates training on the Titanic dataset using pytorch and deploying the model using the mlflow-torchserve plugin. This example illustrates validation using the `captum` library. -------------------------------------------------------------------------------- /examples/TextClassification/create_deployment.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from mlflow.deployments import get_deploy_client 3 | 4 | 5 | def create_deployment(parser_args): 6 | plugin = get_deploy_client(parser_args["target"]) 7 | config = { 8 | "MODEL_FILE": parser_args["model_file"], 9 | "HANDLER": parser_args["handler"], 10 | "EXTRA_FILES": parser_args["extra_files"], 11 | } 12 | result = plugin.create_deployment( 13 | name=parser_args["deployment_name"], 14 | model_uri=parser_args["serialized_file"], 15 | config=config, 16 | ) 17 | 18 | print("Deployment {result} created successfully".format(result=result["name"])) 19 | 20 | 21 | if __name__ == "__main__": 22 | parser = ArgumentParser(description="Text Classifier Example") 23 | 24 | parser.add_argument( 25 | "--target", 26 | type=str, 27 | default="torchserve", 28 | help="MLflow target (default: torchserve)", 29 | ) 30 | 31 | parser.add_argument( 32 | "--deployment_name", 33 | type=str, 34 | default="text_classification", 35 | help="Deployment name (default: text_classification)", 36 | ) 37 | 38 | parser.add_argument( 39 | "--model_file", 40 | type=str, 41 | default="model.py", 42 | help="Model file path (default: model.py)", 43 | ) 44 | 45 | parser.add_argument( 46 | "--handler", 47 | type=str, 48 | default="text_classifier", 49 | help="Handler file path (default: text_classifier)", 50 | ) 51 | 52 | parser.add_argument( 53 | "--extra_files", 54 | type=str, 55 | default="source_vocab.pt,index_to_name.json", 56 | help="List of extra files", 57 | ) 58 | 59 | parser.add_argument( 60 | "--serialized_file", 61 | type=str, 62 | default="model.pt", 63 | help="Pytorch model path (default: model.pt)", 64 | ) 65 | 66 | args = parser.parse_args() 67 | 68 | create_deployment(vars(args)) 69 | -------------------------------------------------------------------------------- /examples/IrisClassification/predict.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import ast 5 | from argparse import ArgumentParser 6 | 7 | from mlflow.deployments import get_deploy_client 8 | 9 | 10 | def convert_input_to_tensor(data): 11 | data = json.loads(data).get("data") 12 | input_tensor = torch.Tensor(ast.literal_eval(data[0])) 13 | return input_tensor 14 |
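# NOTE: convert_input_to_tensor (above) parses list-style payloads such as
# examples/IrisClassificationTorchScript/sample.json -- {"data": ["[4.4, 3.0, 1.3, 0.2]"]}.
# The predict flow below does not call it; it forwards the signature-style payload
# (a JSON-serialized dataframe, see examples/IrisClassification/sample.json) as-is.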
15 | 16 | def predict(parser_args): 17 | plugin = get_deploy_client(parser_args["target"]) 18 | input_file = parser_args["input_file_path"] 19 | if not os.path.exists(input_file): 20 | raise Exception("Unable to locate input file : {}".format(input_file)) 21 | else: 22 | with open(input_file) as fp: 23 | input_data = fp.read() 24 | 25 | data = json.loads(input_data).get("data") 26 | import pandas as pd 27 | 28 | df = pd.read_json(data[0])  # parse the payload into a dataframe matching the model signature 29 | for column in df.columns: 30 | df[column] = df[column].astype("double")  # the signature expects double-typed columns; the raw JSON payload below is what is actually sent 31 | 32 | prediction = plugin.predict(deployment_name=parser_args["deployment_name"], df=input_data) 33 | 34 | print("Prediction Result {}".format(prediction.to_json())) 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = ArgumentParser(description="Iris Classification Model") 39 | 40 | parser.add_argument( 41 | "--target", 42 | type=str, 43 | default="torchserve", 44 | help="MLflow target (default: torchserve)", 45 | ) 46 | 47 | parser.add_argument( 48 | "--deployment_name", 49 | type=str, 50 | default="iris_classification", 51 | help="Deployment name (default: iris_classification)", 52 | ) 53 | 54 | parser.add_argument( 55 | "--input_file_path", 56 | type=str, 57 | default="sample.json", 58 | help="Path to input file for prediction (default: sample.json)", 59 | ) 60 | 61 | parser.add_argument( 62 | "--mlflow-model-uri", 63 | type=str, 64 | default="model", 65 | help="MLflow model URI", 66 | ) 67 | args = parser.parse_args() 68 | 69 | predict(vars(args)) 70 | -------------------------------------------------------------------------------- /examples/cifar10/inference.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | from argparse import ArgumentParser 4 | 5 | from mlflow.deployments import get_deploy_client 6 | 7 | 8 | def predict(parser_args): 9 | plugin = get_deploy_client(parser_args["target"]) 10 | image = open(parser_args["input_file_path"], "rb")  # open binary file in read mode 11 | image_read = image.read() 12 | image_64_encode = base64.b64encode(image_read) 13 | bytes_array = image_64_encode.decode("utf-8") 14 | request = {"data": str(bytes_array)} 15 | 16 | inference_type = parser_args["inference_type"] 17 | if inference_type == "explanation": 18 | result = plugin.explain(parser_args["deployment_name"], json.dumps(request)) 19 | else: 20 | result = plugin.predict(parser_args["deployment_name"], json.dumps(request)) 21 | 22 | print("Prediction Result {}".format(result)) 23 | 24 | output_path = parser_args["output_file_path"] 25 | if output_path: 26 | with open(output_path, "w") as fp: 27 | fp.write(result) 28 | 29 | 30 | if __name__ == "__main__": 31 | parser = ArgumentParser(description="Cifar10 classification example") 32 | 33 | parser.add_argument( 34 | "--target", 35 | type=str, 36 | default="torchserve", 37 | help="MLflow target (default: torchserve)", 38 | ) 39 | 40 | parser.add_argument( 41 | "--deployment_name", 42 | type=str, 43 | default="cifar_test", 44 | help="Deployment name (default: cifar_test)", 45 | ) 46 | 47 | parser.add_argument( 48 | "--input_file_path", 49 | type=str, 50 | default="test_data/kitten.png", 51 | help="Path to input image for prediction (default: test_data/kitten.png)", 52 | ) 53 | 54 | parser.add_argument( 55 | "--output_file_path", 56 | type=str, 57 | default="", 58 | help="output path to write the result", 59 | ) 60 | 61 | parser.add_argument( 62 | "--inference_type", 63 | type=str, 64 | default="predict", 65 | help="Option to run prediction/explanation", 66 | )
67 | 68 | args = parser.parse_args() 69 | 70 | predict(vars(args)) 71 | -------------------------------------------------------------------------------- /tests/examples/test_torchserve_examples.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import shutil 4 | from mlflow import cli 5 | from click.testing import CliRunner 6 | from mlflow.utils import process 7 | 8 | EXAMPLES_DIR = "examples" 9 | 10 | 11 | def get_free_disk_space(): 12 | # https://stackoverflow.com/a/48929832/6943581 13 | return shutil.disk_usage("/")[-1] / (2**30) 14 | 15 | 16 | @pytest.fixture(scope="function", autouse=True) 17 | def clean_envs_and_cache(): 18 | yield 19 | 20 | if get_free_disk_space() < 7.0: # unit: GiB 21 | process._exec_cmd(["./utils/remove-conda-envs.sh"]) 22 | 23 | 24 | @pytest.mark.parametrize( 25 | "directory, params", 26 | [ 27 | ("IrisClassification", ["-P", "max_epochs=10"]), 28 | ("MNIST", ["-P", "max_epochs=1", "-P", "register=false"]), 29 | ("IrisClassificationTorchScript", ["-P", "max_epochs=10"]), 30 | ( 31 | "BertNewsClassification", 32 | [ 33 | "-P", 34 | "max_epochs=1", 35 | "-P", 36 | "num_train_samples=100", 37 | "-P", 38 | "num_test_samples=100", 39 | ], 40 | ), 41 | ( 42 | "E2EBert", 43 | [ 44 | "-P", 45 | "max_epochs=1", 46 | "-P", 47 | "num_samples=100", 48 | ], 49 | ), 50 | ("Titanic", ["-P", "max_epochs=10", "-P", "lr=0.1"]), 51 | ( 52 | "cifar10", 53 | [ 54 | "-P", 55 | "max_epochs=1", 56 | "-P", 57 | "num_samples_train=1", 58 | "-P", 59 | "num_samples_val=1", 60 | "-P", 61 | "num_samples_test=1", 62 | ], 63 | ), 64 | ], 65 | ) 66 | def test_mlflow_run_example(directory, params): 67 | example_dir = os.path.join(EXAMPLES_DIR, directory) 68 | cli_run_list = [example_dir] + params 69 | CliRunner().invoke(cli.run, cli_run_list) 70 | # assert res.exit_code == 0, "Got non-zero exit code {0}. 
Output is: {1}".format( 71 | # res.exit_code, res.output 72 | # ) 73 | -------------------------------------------------------------------------------- /examples/cifar10/create_deployment.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mlflow.deployments import get_deploy_client 4 | 5 | 6 | def create_deployment(parser_args): 7 | plugin = get_deploy_client(parser_args["target"]) 8 | config = { 9 | "MODEL_FILE": parser_args["model_file"], 10 | "HANDLER": parser_args["handler"], 11 | "EXTRA_FILES": parser_args["extra_files"], 12 | } 13 | 14 | if parser_args["export_path"] != "": 15 | config["EXPORT_PATH"] = parser_args["export_path"] 16 | 17 | result = plugin.create_deployment( 18 | name=parser_args["deployment_name"], 19 | model_uri=parser_args["model_uri"], 20 | config=config, 21 | ) 22 | 23 | print("Deployment {result} created successfully".format(result=result["name"])) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser(description="Cifar10 classification example") 28 | 29 | parser.add_argument( 30 | "--target", 31 | type=str, 32 | default="torchserve", 33 | help="MLflow target (default: torchserve)", 34 | ) 35 | 36 | parser.add_argument( 37 | "--deployment_name", 38 | type=str, 39 | default="cifar_test", 40 | help="Deployment name (default: cifar_test)", 41 | ) 42 | 43 | parser.add_argument( 44 | "--model_file", 45 | type=str, 46 | default="cifar10_train.py", 47 | help="Model file path (default: cifar10_train.py)", 48 | ) 49 | 50 | parser.add_argument( 51 | "--handler", 52 | type=str, 53 | default="cifar10_handler.py", 54 | help="Handler file path (default: cifar10_handler.py)", 55 | ) 56 | 57 | parser.add_argument( 58 | "--extra_files", 59 | type=str, 60 | default="index_to_name.json", 61 | help="List of extra files", 62 | ) 63 | 64 | parser.add_argument( 65 | "--model_uri", 66 | type=str, 67 | default="resnet.pth", 68 | help="List of extra files", 69 | ) 70 | 71 | parser.add_argument( 72 | "--export_path", 73 | type=str, 74 | default="model_store", 75 | help="Path to model store (default: 'model_store')", 76 | ) 77 | 78 | args = parser.parse_args() 79 | 80 | create_deployment(vars(args)) 81 | -------------------------------------------------------------------------------- /examples/IrisClassification/create_deployment.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mlflow.deployments import get_deploy_client 4 | 5 | 6 | def create_deployment(parser_args): 7 | plugin = get_deploy_client(parser_args["target"]) 8 | config = { 9 | "MODEL_FILE": parser_args["model_file"], 10 | "HANDLER": parser_args["handler"], 11 | "EXTRA_FILES": parser_args["extra_files"], 12 | } 13 | 14 | if parser_args["export_path"] != "": 15 | config["EXPORT_PATH"] = parser_args["export_path"] 16 | 17 | result = plugin.create_deployment( 18 | name=parser_args["deployment_name"], 19 | model_uri=parser_args["serialized_file_path"], 20 | config=config, 21 | ) 22 | 23 | print("Deployment {result} created successfully".format(result=result["name"])) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser(description="Iris Classification example") 28 | 29 | parser.add_argument( 30 | "--target", 31 | type=str, 32 | default="torchserve", 33 | help="MLflow target (default: torchserve)", 34 | ) 35 | 36 | parser.add_argument( 37 | "--deployment_name", 38 | type=str, 39 | default="iris_classification", 40 | 
help="Deployment name (default: iris_classification)", 41 | ) 42 | 43 | parser.add_argument( 44 | "--model_file", 45 | type=str, 46 | default="iris_classification.py", 47 | help="Model file path (default: iris_classification.py)", 48 | ) 49 | 50 | parser.add_argument( 51 | "--handler", 52 | type=str, 53 | default="iris_handler.py", 54 | help="Handler file path (default: iris_handler.py)", 55 | ) 56 | 57 | parser.add_argument( 58 | "--extra_files", 59 | type=str, 60 | default="index_to_name.json,model/MLmodel", 61 | help="List of extra files", 62 | ) 63 | 64 | parser.add_argument( 65 | "--serialized_file_path", 66 | type=str, 67 | default="model", 68 | help="Pytorch model path", 69 | ) 70 | 71 | parser.add_argument( 72 | "--export_path", 73 | type=str, 74 | default="", 75 | help="Path to model store (default: '')", 76 | ) 77 | 78 | args = parser.parse_args() 79 | 80 | create_deployment(vars(args)) 81 | -------------------------------------------------------------------------------- /examples/MNIST/create_deployment.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mlflow.deployments import get_deploy_client 4 | 5 | 6 | def create_deployment(parser_args): 7 | plugin = get_deploy_client(parser_args["target"]) 8 | config = { 9 | "MODEL_FILE": parser_args["model_file"], 10 | "HANDLER": parser_args["handler"], 11 | "EXTRA_FILES": parser_args["extra_files"], 12 | } 13 | 14 | if parser_args["export_path"] != "": 15 | config["EXPORT_PATH"] = parser_args["export_path"] 16 | 17 | result = plugin.create_deployment( 18 | name=parser_args["deployment_name"], 19 | model_uri=parser_args["registered_model_uri"], 20 | config=config, 21 | ) 22 | 23 | print("Deployment {result} created successfully".format(result=result["name"])) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser(description="MNIST hand written digits classification example") 28 | 29 | parser.add_argument( 30 | "--target", 31 | type=str, 32 | default="torchserve", 33 | help="MLflow target (default: torchserve)", 34 | ) 35 | 36 | parser.add_argument( 37 | "--deployment_name", 38 | type=str, 39 | default="mnist_classification", 40 | help="Deployment name (default: mnist_classification)", 41 | ) 42 | 43 | parser.add_argument( 44 | "--model_file", 45 | type=str, 46 | default="mnist_model.py", 47 | help="Model file path (default: mnist_model.py)", 48 | ) 49 | 50 | parser.add_argument( 51 | "--handler", 52 | type=str, 53 | default="mnist_handler.py", 54 | help="Handler file path (default: mnist_handler.py)", 55 | ) 56 | 57 | parser.add_argument( 58 | "--extra_files", 59 | type=str, 60 | default="index_to_name.json", 61 | help="List of extra files", 62 | ) 63 | 64 | parser.add_argument( 65 | "--registered_model_uri", 66 | type=str, 67 | default="models:/mnist_classifier/3", 68 | help="Registered model name (default: models:/mnist_classifier/1)", 69 | ) 70 | 71 | parser.add_argument( 72 | "--export_path", 73 | type=str, 74 | default="model_store", 75 | help="Path to model store (default: 'model_store')", 76 | ) 77 | 78 | args = parser.parse_args() 79 | 80 | create_deployment(vars(args)) 81 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[FR]" 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | --- 
11 | name: Feature Request 12 | about: Use this template for feature and enhancement proposals. 13 | labels: 'enhancement' 14 | title: "[FR]" 15 | --- 16 | Thank you for submitting a feature request. **Before proceeding, please review MLflow's [Issue Policy for feature requests](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md#feature-requests) and the [MLflow Contributing Guide](https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst)**. 17 | 18 | **Please fill in this feature request template to ensure a timely and thorough response.** 19 | 20 | ## Willingness to contribute 21 | The MLflow Community encourages new feature contributions. Would you or another member of your organization be willing to contribute an implementation of this feature (as an enhancement to the MLflow TorchServe Deployment plugin code base)? 22 | 23 | - [ ] Yes. I can contribute this feature independently. 24 | - [ ] Yes. I would be willing to contribute this feature with guidance from the MLflow community. 25 | - [ ] No. I cannot contribute this feature at this time. 26 | 27 | ## Proposal Summary 28 | 29 | (In a few sentences, provide a clear, high-level description of the feature request) 30 | 31 | ## Motivation 32 | - What is the use case for this feature? 33 | - Why is this use case valuable to support for MLflow TorchServe Deployment plugin users in general? 34 | - Why is this use case valuable to support for your project(s) or organization? 35 | - Why is it currently difficult to achieve this use case? (please be as specific as possible about why related MLflow TorchServe Deployment plugin features and components are insufficient) 36 | 37 | ### What component(s) does this feature affect? 38 | Components 39 | - [ ] `area/deploy`: Main deployment plugin logic 40 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin 41 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages 42 | - [ ] `area/examples`: Example code 43 | 44 | ## Details 45 | 46 | (Use this section to include any additional information about the feature. If you have a proposal for how to implement this feature, please include it here. For implementation guidelines, please refer to the [Contributing Guide](https://github.com/mlflow/mlflow/blob/master/CONTRIBUTING.rst#contribution-guidelines).) 
47 | -------------------------------------------------------------------------------- /tests/examples/test_end_to_end_example.py: -------------------------------------------------------------------------------- 1 | import mlflow 2 | import pytest 3 | import os 4 | from mlflow.utils import process 5 | 6 | 7 | @pytest.mark.usefixtures("start_torchserve") 8 | def test_mnist_example(): 9 | os.environ["MKL_THREADING_LAYER"] = "GNU" 10 | home_dir = os.getcwd() 11 | mnist_dir = "examples/MNIST" 12 | example_command = ["python", "mnist_model.py", "--max_epochs", "1", "--register", "false"] 13 | process._exec_cmd(example_command, cwd=mnist_dir) 14 | 15 | assert os.path.exists(os.path.join(mnist_dir, "model.pth")) 16 | create_deployment_command = [ 17 | "python", 18 | "create_deployment.py", 19 | "--export_path", 20 | os.path.join(home_dir, "model_store"), 21 | "--registered_model_uri", 22 | "model.pth", 23 | ] 24 | 25 | process._exec_cmd(create_deployment_command, cwd=mnist_dir) 26 | 27 | assert os.path.exists(os.path.join(home_dir, "model_store", "mnist_classification.mar")) 28 | 29 | predict_command = ["python", "predict.py"] 30 | res = process._exec_cmd(predict_command, cwd=mnist_dir) 31 | assert "ONE" in res.stdout 32 | 33 | 34 | @pytest.mark.usefixtures("start_torchserve") 35 | def test_iris_example(tmpdir): 36 | iris_dir = os.path.join("examples", "IrisClassification") 37 | home_dir = os.getcwd() 38 | example_command = ["python", os.path.join(iris_dir, "iris_classification.py")] 39 | extra_files = "{},{}".format( 40 | os.path.join(iris_dir, "index_to_name.json"), 41 | os.path.join(home_dir, "model/MLmodel"), 42 | ) 43 | process._exec_cmd(example_command, cwd=home_dir) 44 | create_deployment_command = [ 45 | "python", 46 | os.path.join(iris_dir, "create_deployment.py"), 47 | "--export_path", 48 | os.path.join(home_dir, "model_store"), 49 | "--handler", 50 | os.path.join(iris_dir, "iris_handler.py"), 51 | "--model_file", 52 | os.path.join(iris_dir, "iris_classification.py"), 53 | "--extra_files", 54 | extra_files, 55 | ] 56 | 57 | process._exec_cmd(create_deployment_command, cwd=home_dir) 58 | mlflow.end_run() 59 | assert os.path.exists(os.path.join(home_dir, "model_store", "iris_classification.mar")) 60 | predict_command = [ 61 | "python", 62 | os.path.join(iris_dir, "predict.py"), 63 | "--input_file_path", 64 | os.path.join(iris_dir, "sample.json"), 65 | ] 66 | res = process._exec_cmd(predict_command, cwd=home_dir) 67 | assert "SETOSA" in res.stdout 68 | -------------------------------------------------------------------------------- /tests/resources/linear_model.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W0223 2 | # IMPORTS SECTION # 3 | 4 | import numpy as np 5 | import torch 6 | from torch.autograd import Variable 7 | 8 | # SYNTHETIC DATA PREPARATION # 9 | 10 | x_values = [i for i in range(11)] 11 | 12 | x_train = np.array(x_values, dtype=np.float32) 13 | x_train = x_train.reshape(-1, 1) 14 | 15 | y_values = [2 * i + 1 for i in x_values] 16 | 17 | y_train = np.array(y_values, dtype=np.float32) 18 | y_train = y_train.reshape(-1, 1) 19 | 20 | 21 | # DEFINING THE NETWORK FOR REGRESSION # 22 | 23 | 24 | class LinearRegression(torch.nn.Module): 25 | def __init__(self, inputSize, outputSize): 26 | super(LinearRegression, self).__init__() 27 | self.linear = torch.nn.Linear(inputSize, outputSize) 28 | 29 | def forward(self, x): 30 | out = self.linear(x) 31 | return out 32 | 33 | 34 | def main(): 35 | # SECTION FOR HYPERPARAMETERS # 36 
| 37 | inputDim = 1 38 | outputDim = 1 39 | learningRate = 0.01 40 | epochs = 100 41 | 42 | # INITIALIZING THE MODEL # 43 | 44 | model = LinearRegression(inputDim, outputDim) 45 | 46 | # FOR GPU # 47 | if torch.cuda.is_available(): 48 | model.cuda() 49 | 50 | # INITIALIZING THE LOSS FUNCTION AND OPTIMIZER # 51 | 52 | criterion = torch.nn.MSELoss() 53 | optimizer = torch.optim.SGD(model.parameters(), lr=learningRate) 54 | 55 | # TRAINING STEP # 56 | 57 | for epoch in range(epochs): 58 | # Converting inputs and labels to Variable 59 | if torch.cuda.is_available(): 60 | inputs = Variable(torch.from_numpy(x_train).cuda()) 61 | labels = Variable(torch.from_numpy(y_train).cuda()) 62 | else: 63 | inputs = Variable(torch.from_numpy(x_train)) 64 | labels = Variable(torch.from_numpy(y_train)) 65 | 66 | optimizer.zero_grad() 67 | outputs = model(inputs) 68 | 69 | loss = criterion(outputs, labels) 70 | 71 | loss.backward() 72 | optimizer.step() 73 | print("epoch {}, loss {}".format(epoch, loss.item())) 74 | 75 | # EVALUATION AND PREDICTION # 76 | 77 | with torch.no_grad(): 78 | if torch.cuda.is_available(): 79 | predicted = model(Variable(torch.from_numpy(x_train).cuda())).cpu().data.numpy() 80 | else: 81 | predicted = model(Variable(torch.from_numpy(x_train))).data.numpy() 82 | print(predicted) 83 | 84 | # SAVING THE MODEL # 85 | torch.save(model.state_dict(), "tests/resources/linear_state_dict.pt") 86 | torch.save(model, "tests/resources/linear_model.pt") 87 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | env: 10 | CONDA_BIN: /usr/share/miniconda/bin 11 | 12 | jobs: 13 | lint: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - uses: actions/checkout@v2 17 | - uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.7 20 | - name: Install dependencies 21 | run: pip install black==22.3.0 flake8==5.0.4 22 | - run: flake8 . 23 | - run: black --check --line-length=100 . 24 | 25 | test: 26 | runs-on: ubuntu-20.04 27 | steps: 28 | - uses: actions/checkout@v2 29 | - uses: actions/setup-python@v2 30 | with: 31 | python-version: 3.7 32 | - name: Install Java 33 | run: | 34 | sudo apt-get update 35 | sudo apt-get install openjdk-11-jdk 36 | - name: Check Java version 37 | run: | 38 | java -version 39 | - name: Enable conda 40 | run: | 41 | echo "/usr/share/miniconda/bin" >> $GITHUB_PATH 42 | - name: Install dependencies 43 | run: | 44 | pip install -e . 
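# the editable install above makes the test steps below exercise the local plugin source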
45 | pip install pytest==6.1.1 torchvision scikit-learn gorilla transformers torchtext matplotlib captum 46 | 47 | - name: Install pytorch lightning 48 | run: | 49 | pip install "pytorch-lightning>=1.2.3" 50 | 51 | - name: Add execute permission for the remove-conda-envs utility script 52 | run: | 53 | chmod +x utils/remove-conda-envs.sh 54 | 55 | - name: Run torchserve example 56 | run: | 57 | set -x 58 | git clone https://github.com/pytorch/serve.git 59 | cd serve/examples/image_classifier/resnet_18 60 | mkdir model_store 61 | wget --no-verbose https://download.pytorch.org/models/resnet18-5c106cde.pth 62 | 63 | torch-model-archiver \ 64 | --model-name resnet-18 \ 65 | --version 1.0 \ 66 | --model-file model.py \ 67 | --serialized-file resnet18-5c106cde.pth \ 68 | --handler image_classifier \ 69 | --export-path model_store \ 70 | --extra-files ../index_to_name.json 71 | 72 | torchserve --start --model-store model_store --models resnet-18=resnet-18.mar || true 73 | sleep 10 74 | curl -s -X POST http://127.0.0.1:8080/predictions/resnet-18 -T ../kitten.jpg 75 | sleep 3 76 | torchserve --stop 77 | 78 | - name: Run tests 79 | run: | 80 | pytest tests --color=yes --verbose --durations=5 81 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | --- 11 | name: Bug Report 12 | about: Use this template for reporting bugs encountered while using MLflow TorchServe Deployment Plugin. 13 | labels: 'bug' 14 | title: "[BUG]" 15 | --- 16 | Thank you for submitting an issue. Please refer to our [issue policy](https://www.github.com/mlflow/mlflow/blob/master/ISSUE_POLICY.md) for additional information about bug reports. For help with debugging your code, please refer to [Stack Overflow](https://stackoverflow.com/questions/tagged/mlflow). 17 | 18 | **Please fill in this bug report template to ensure a timely and thorough response.** 19 | 20 | ### Willingness to contribute 21 | The MLflow Community encourages bug fix contributions. Would you or another member of your organization be willing to contribute a fix for this bug to the MLflow code base? 22 | 23 | - [ ] Yes. I can contribute a fix for this bug independently. 24 | - [ ] Yes. I would be willing to contribute a fix for this bug with guidance from the MLflow community. 25 | - [ ] No. I cannot contribute a bug fix at this time. 26 | 27 | ### System information 28 | - **Have I written custom code (as opposed to using a stock example script provided in MLflow)**: 29 | - **OS Platform and Distribution (e.g., Linux Ubuntu 18.04)**: 30 | - **MLflow installed from (source or binary)**: 31 | - **MLflow version (run ``mlflow --version``)**: 32 | - **MLflow TorchServe Deployment plugin installed from (source or binary)**: 33 | - **MLflow TorchServe Deployment plugin version (run ``mlflow deployments --version``)**: 34 | - **TorchServe installed from (source or binary)**: 35 | - **TorchServe version (run ``torchserve --version``)**: 36 | - **Python version**: 37 | - **Exact command to reproduce**: 38 | 39 | ### Describe the problem 40 | Describe the problem clearly here. Include descriptions of the expected behavior and the actual behavior. 41 | 42 | ### Code to reproduce issue 43 | Provide a reproducible test case that is the bare minimum necessary to generate the problem.
44 | 45 | ### Other info / logs 46 | Include any logs or source code that would be helpful to diagnose the problem. If including tracebacks, please include the full traceback. Large logs and files should be attached. 47 | 48 | 49 | ### What component(s) does this bug affect? 50 | Components 51 | - [ ] `area/deploy`: Main deployment plugin logic 52 | - [ ] `area/build`: Build and test infrastructure for MLflow TorchServe Deployment Plugin 53 | - [ ] `area/docs`: MLflow TorchServe Deployment Plugin documentation pages 54 | - [ ] `area/examples`: Example code 55 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/README.md: -------------------------------------------------------------------------------- 1 | # Deploying Iris Classification using torchserve 2 | 3 | The code, adapted from this [repository](http://chappers.github.io/2020/04/19/torch-lightning-using-iris/), 4 | is almost entirely dedicated to training, with the addition of a single mlflow.pytorch.autolog() call to enable automatic logging of params, metrics, and the TorchScript model. 5 | TorchScript allows us to save the whole model locally and load it into a different environment, such as in a server written in 6 | a completely different language. 7 | 8 | ## Training the model 9 | 10 | To run the example via MLflow, navigate to the `examples/IrisClassificationTorchScript/` directory and run the command 11 | 12 | ``` 13 | mlflow run . 14 | 15 | ``` 16 | 17 | This will run `iris_classification.py` with the default set of parameters such as `--max_epochs=100`. You can see the default value in the MLproject file. 18 | 19 | In order to run the file with custom parameters, run the command 20 | 21 | ``` 22 | mlflow run . -P max_epochs=X 23 | ``` 24 | 25 | where X is your desired value for max_epochs. 26 | 27 | If you have the required modules for the file and would like to skip the creation of a conda environment, add the argument --no-conda. 28 | 29 | ``` 30 | mlflow run . --no-conda 31 | ``` 32 | 33 | After the training, we will convert the model to a TorchScript model using the function `torch.jit.script`. 34 | At the end of the training process, the scripted model is stored as `iris_ts.pt`. 35 | 36 | ## Starting TorchServe 37 | 38 | Create an empty directory `model_store` and run the following command to start torchserve. 39 | 40 | `torchserve --start --model-store model_store` 41 | 42 | Note: 43 | The mlflow-torchserve plugin generates the mar file inside the "model_store" directory. If the `model_store` directory is not present under the current folder, 44 | the plugin creates a new directory named "model_store" and generates the mar file inside it. 45 | 46 | If TorchServe is already running with a different "model_store" location, make sure to pass the "model_store" path with the "EXPORT_PATH" config variable (`-C 'EXPORT_PATH='`) 47 | 48 | ## Creating a new deployment 49 | 50 | Run the following command to create a new deployment named `iris_test` 51 | 52 | `mlflow deployments create --name iris_test --target torchserve --model-uri iris_ts.pt -C "HANDLER=iris_handler.py" -C "EXTRA_FILES=index_to_name.json"` 53 | 54 | ## Running prediction based on deployed model 55 | 56 | Run the following command to invoke prediction of our sample input, where sample.json is the sample input file and output.json stores the predicted outcome.
57 | 58 | `mlflow deployments predict --name iris_test --target torchserve --input-path sample.json --output-path output.json` 59 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/iris_handler.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import logging 3 | import os 4 | import numpy as np 5 | import torch 6 | from ts.torch_handler.base_handler import BaseHandler 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class IRISClassifierHandler(BaseHandler): 12 | """ 13 | IRISClassifier handler class. This handler takes an input tensor and 14 | output the type of iris based on the input 15 | """ 16 | 17 | def __init__(self): 18 | super(IRISClassifierHandler, self).__init__() 19 | 20 | def initialize(self, context): 21 | """First try to load torchscript else load eager mode state_dict based model""" 22 | 23 | properties = context.system_properties 24 | self.map_location = "cuda" if torch.cuda.is_available() else "cpu" 25 | self.device = torch.device( 26 | self.map_location + ":" + str(properties.get("gpu_id")) 27 | if torch.cuda.is_available() 28 | else self.map_location 29 | ) 30 | self.manifest = context.manifest 31 | 32 | model_dir = properties.get("model_dir") 33 | self.batch_size = properties.get("batch_size") 34 | serialized_file = self.manifest["model"]["serializedFile"] 35 | model_pt_path = os.path.join(model_dir, serialized_file) 36 | 37 | if not os.path.isfile(model_pt_path): 38 | raise RuntimeError("Missing the model.pt file") 39 | 40 | logger.debug("Loading torchscript model") 41 | self.model = self._load_torchscript_model(model_pt_path) 42 | 43 | self.model.to(self.device) 44 | self.model.eval() 45 | 46 | logger.debug("Model file %s loaded successfully", model_pt_path) 47 | 48 | self.initialized = True 49 | 50 | def preprocess(self, data): 51 | """ 52 | preprocessing step - Reads the input array and converts it to tensor 53 | 54 | :param data: Input to be passed through the layers for prediction 55 | 56 | :return: output - Preprocessed input 57 | """ 58 | 59 | input_data_str = data[0].get("data") 60 | if input_data_str is None: 61 | input_data_str = data[0].get("body") 62 | 63 | input_data = input_data_str.decode("utf-8") 64 | input_tensor = torch.Tensor(ast.literal_eval(input_data)) 65 | return input_tensor 66 | 67 | def postprocess(self, inference_output): 68 | """ 69 | Does postprocess after inference to be returned to user 70 | 71 | :param inference_output: Output of inference 72 | 73 | :return: output - Output after post processing 74 | """ 75 | 76 | predicted_idx = str(np.argmax(inference_output.cpu().detach().numpy())) 77 | 78 | if self.mapping: 79 | return [self.mapping[str(predicted_idx)]] 80 | return [predicted_idx] 81 | 82 | 83 | _service = IRISClassifierHandler() 84 | -------------------------------------------------------------------------------- /examples/cifar10/README.md: -------------------------------------------------------------------------------- 1 | # Deploying Cifar10 image classification using torchserve 2 | 3 | This example demonstrates fine tuning of resnet model using cifar10 dataset. 4 | 5 | Follow the link given below to set backend store 6 | 7 | https://www.mlflow.org/docs/latest/tracking.html#storage 8 | 9 | ## Training the model 10 | 11 | This example, autologs the trained model and its relevant parameters and metrics into mlflow using a single line of code. 
12 | The example also illustrates how one can use the python plugin to deploy and test the model. 13 | Python scripts `create_deployment.py` and `predict.py` have been used for this purpose. 14 | 15 | Run the following command to train the cifar10 model 16 | 17 | CPU: `mlflow run . -P max_epochs=5` 18 | GPU: `mlflow run . -P max_epochs=5 -P devices=2 -P strategy=ddp -P accelerator=gpu` 19 | 20 | At the end of the training, the Cifar10 model will be saved as a state dict (resnet.pth) in the current working directory. 21 | 22 | ## Starting torchserve 23 | 24 | Create an empty directory `model_store` and run the following command to start torchserve. 25 | 26 | `torchserve --start --model-store model_store` 27 | 28 | ## Creating a new deployment 29 | 30 | This example uses image path as input for prediction. 31 | 32 | To create a new deployment, run the following command 33 | 34 | `python create_deployment.py` 35 | 36 | It will create a new deployment named `cifar_test`. 37 | 38 | Following are the arguments which can be passed to create_deployment script 39 | 40 | 1. deployment name - `--deployment_name` 41 | 2. path to serialized file - `--model_uri` 42 | 3. handler file path - `--handler` 43 | 4. model file path - `--model_file` 44 | 45 | Note: 46 | If TorchServe is running with a different "model_store" location, the model-store path 47 | can be passed as input using the `--export_path` argument. 48 | 49 | For example: 50 | 51 | `python create_deployment.py --deployment_name cifar_test1 --export_path /home/ubuntu/model_store` 52 | 53 | ## Predicting deployed model 54 | 55 | To perform prediction, run the following script 56 | 57 | `python inference.py` 58 | 59 | The prediction results will be printed in the console. 60 | 61 | To save the inference output to a file, run the following command 62 | 63 | `python inference.py --output_file_path prediction_result.json` 64 | 65 | Following are the arguments which can be passed to predict_deployment script 66 | 67 | 1. deployment name - `--deployment_name` 68 | 2. input file path - `--input_file_path` 69 | 3. path to write the result - `--output_file_path` 70 | 71 | 72 | ## Calculate captum explanations 73 | 74 | To perform explain request, run the following script 75 | 76 | `python inference.py --inference_type explanation` 77 | 78 | To save the explanation output to a file, run the following command 79 | 80 | `python inference.py --inference_type explanation --output_file_path explanation_result.json` 81 | 82 | 83 | ## Viewing captum results 84 | 85 | Use the notebook - `Cifar10_Captum.ipynb` to view the captum results. 86 | -------------------------------------------------------------------------------- /examples/E2EBert/wrapper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class AGNewsmodelWrapper(nn.Module): 7 | def __init__(self, model): 8 | super(AGNewsmodelWrapper, self).__init__() 9 | self.model = model 10 | 11 | def compute_bert_outputs( # pylint: disable=no-self-use 12 | self, model_bert, embedding_input, attention_mask=None, head_mask=None 13 | ): 14 | """Computes Bert Outputs. 15 | 16 | Args: 17 | model_bert : the bert model 18 | embedding_input : input for bert embeddings.
19 | attention_mask : attention mask 20 | head_mask : head mask 21 | Returns: 22 | output : the bert output 23 | """ 24 | if attention_mask is None: 25 | attention_mask = torch.ones( # pylint: disable=no-member 26 | embedding_input.shape[0], embedding_input.shape[1] 27 | ).to(embedding_input) 28 | 29 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 30 | 31 | extended_attention_mask = extended_attention_mask.to( 32 | dtype=next(model_bert.parameters()).dtype 33 | ) # fp16 compatibility 34 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 35 | 36 | if head_mask is not None: 37 | if head_mask.dim() == 1: 38 | head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 39 | head_mask = head_mask.expand(model_bert.config.num_hidden_layers, -1, -1, -1, -1) 40 | elif head_mask.dim() == 2: 41 | head_mask = ( 42 | head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) 43 | ) # We can specify head_mask for each layer 44 | head_mask = head_mask.to( 45 | dtype=next(model_bert.parameters()).dtype 46 | ) # switch to float if needed + fp16 compatibility 47 | else: 48 | head_mask = [None] * model_bert.config.num_hidden_layers 49 | 50 | encoder_outputs = model_bert.encoder( 51 | embedding_input, extended_attention_mask, head_mask=head_mask 52 | ) 53 | sequence_output = encoder_outputs[0] 54 | pooled_output = model_bert.pooler(sequence_output) 55 | outputs = ( 56 | sequence_output, 57 | pooled_output, 58 | ) + encoder_outputs[1:] 59 | return outputs 60 | 61 | def forward(self, embeddings, attention_mask=None): 62 | """Forward function. 63 | 64 | Args: 65 | embeddings : bert embeddings. 66 | attention_mask: Attention mask value 67 | """ 68 | outputs = self.compute_bert_outputs(self.model.bert_model, embeddings, attention_mask) 69 | pooled_output = outputs[1] 70 | output = F.relu(self.model.fc1(pooled_output)) 71 | output = self.model.drop(output) 72 | output = self.model.out(output) 73 | return output 74 | -------------------------------------------------------------------------------- /examples/MNIST/README.md: -------------------------------------------------------------------------------- 1 | # Deploying MNIST Handwritten Recognition using torchserve 2 | 3 | This example requires a backend store to be set for mlflow. 4 | 5 | Follow the link given below to set backend store 6 | 7 | https://www.mlflow.org/docs/latest/tracking.html#storage 8 | 9 | ## Training the model 10 | The model is used to classify handwritten digits. 11 | This example autologs the trained model and its relevant parameters and metrics into mlflow using a single line of code. 12 | The example also illustrates how one can use the python plugin to deploy and test the model. 13 | Python scripts `create_deployment.py` and `predict.py` have been used for this purpose. 14 | 15 | Run the following command to train the MNIST model 16 | 17 | CPU: `mlflow run . -P max_epochs=5` 18 | GPU: `mlflow run . -P max_epochs=5 -P devices=2 -P strategy=ddp -P accelerator=gpu` 19 | 20 | At the end of the training, the MNIST model will be saved as a state dict in the current working directory. 21 | 22 | ## Deploying in remote torchserve instance 23 | 24 | To deploy the model in a remote torchserve instance, follow 25 | 26 | the steps in [remote-deployment.rst](../../docs/remote-deployment.rst) under the `docs` folder. 27 | 28 | 29 | ## Deploying in local torchserve instance 30 | 31 | ## Starting torchserve 32 | 33 | Create an empty directory `model_store` and run the following command to start torchserve.
34 | 35 | `torchserve --start --model-store model_store` 36 | 37 | ## Creating a new deployment 38 | 39 | This example uses image path as input for prediction. 40 | 41 | To create a new deployment, run the following command 42 | 43 | `python create_deployment.py` 44 | 45 | It will create a new deployment named `mnist_classification`. 46 | 47 | Following are the arguments which can be passed to create_deployment script 48 | 49 | 1. deployment name - `--deployment_name` 50 | 2. registered mlflow model uri - `--registered_model_uri` 51 | 3. handler file path - `--handler` 52 | 4. model file path - `--model_file` 53 | 54 | For example, to create another deployment, the script can be triggered as 55 | 56 | `python create_deployment.py --deployment_name mnist_deployment1` 57 | 58 | Note: 59 | 60 | If TorchServe is running with a different "model_store" location, the model-store path 61 | can be passed as input using the `--export_path` argument. 62 | 63 | For example: 64 | 65 | `python create_deployment.py --deployment_name mnist_deployment1 --export_path /home/ubuntu/model_store` 66 | 67 | ## Predicting deployed model 68 | 69 | To perform prediction, run the following script 70 | 71 | `python predict.py` 72 | 73 | The prediction results will be printed in the console. 74 | 75 | Following are the arguments which can be passed to predict_deployment script 76 | 77 | 1. deployment name - `--deployment_name` 78 | 2. input file path - `--input_file_path` 79 | 80 | For example, to perform prediction on the second deployment which we created, run the following command 81 | 82 | `python predict.py --deployment_name mnist_deployment1 --input_file_path test_data/one.png` 83 | -------------------------------------------------------------------------------- /examples/IrisClassification/iris_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | import pandas as pd 6 | import os 7 | import json 8 | from ts.torch_handler.base_handler import BaseHandler 9 | from mlflow.models.model import Model 10 | from mlflow.pyfunc import _enforce_schema 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class IRISClassifierHandler(BaseHandler): 16 | """ 17 | IRISClassifier handler class.
This handler takes an input tensor and 18 | output the type of iris based on the input 19 | """ 20 | 21 | def __init__(self): 22 | super(IRISClassifierHandler, self).__init__() 23 | self.mlmodel = None 24 | 25 | def preprocess(self, data): 26 | """ 27 | preprocessing step - Reads the input array and converts it to tensor 28 | 29 | :param data: Input to be passed through the layers for prediction 30 | 31 | :return: output - Preprocessed input 32 | """ 33 | 34 | data = json.loads(data[0]["data"].decode("utf-8")) 35 | df = pd.DataFrame(data) 36 | 37 | _enforce_schema(df, self.mlmodel.get_input_schema()) 38 | 39 | input_tensor = torch.Tensor(list(df.iloc[0])) 40 | return input_tensor 41 | 42 | def extract_signature(self, mlmodel_file): 43 | self.mlmodel = Model.load(mlmodel_file) 44 | model_json = json.loads(Model.to_json(self.mlmodel)) 45 | 46 | if "signature" not in model_json.keys(): 47 | raise Exception("Model Signature not found") 48 | 49 | def initialize(self, ctx): 50 | """ 51 | First try to load torchscript else load eager mode state_dict based model 52 | :param ctx: System properties 53 | """ 54 | properties = ctx.system_properties 55 | self.device = torch.device( 56 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 57 | ) 58 | model_dir = properties.get("model_dir") 59 | 60 | # Read model serialize/pt file 61 | model_pt_path = os.path.join(model_dir, "model.pth") 62 | 63 | self.model = torch.load(model_pt_path, map_location=self.device) 64 | 65 | logger.debug("Model file %s loaded successfully", model_pt_path) 66 | 67 | mapping_file_path = os.path.join(model_dir, "index_to_name.json") 68 | if os.path.exists(mapping_file_path): 69 | with open(mapping_file_path) as fp: 70 | self.mapping = json.load(fp) 71 | mlmodel_file = os.path.join(model_dir, "MLmodel") 72 | 73 | self.extract_signature(mlmodel_file=mlmodel_file) 74 | 75 | self.initialized = True 76 | 77 | def postprocess(self, inference_output): 78 | """ 79 | Does postprocess after inference to be returned to user 80 | 81 | :param inference_output: Output of inference 82 | 83 | :return: output - Output after post processing 84 | """ 85 | 86 | predicted_idx = str(np.argmax(inference_output.cpu().detach().numpy())) 87 | 88 | if self.mapping: 89 | return [self.mapping[str(predicted_idx)]] 90 | return [predicted_idx] 91 | 92 | 93 | _service = IRISClassifierHandler() 94 | -------------------------------------------------------------------------------- /.github/workflows/build-wheel.yml: -------------------------------------------------------------------------------- 1 | name: Build wheel 2 | 3 | on: 4 | push: 5 | branches: [master] 6 | pull_request: 7 | branches: [master] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v2 14 | - uses: actions/setup-python@v2 15 | with: 16 | python-version: "3.7" 17 | 18 | - name: Build wheel 19 | id: build-wheel 20 | run: | 21 | pip install wheel 22 | python setup.py bdist_wheel 23 | 24 | # set outputs 25 | wheel_path=$(find dist -type f) 26 | wheel_name=$(basename $wheel_path) 27 | wheel_size=$(stat -c %s $wheel_path) 28 | echo "::set-output name=wheel-path::${wheel_path}" 29 | echo "::set-output name=wheel-name::${wheel_name}" 30 | echo "::set-output name=wheel-size::${wheel_size}" 31 | 32 | - name: Verify wheel can be installed 33 | run: | 34 | pip install ${{ steps.build-wheel.outputs.wheel-path }} 35 | 36 | # Anyone with read access can download the uploaded wheel on GitHub. 
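# Note: the artifact upload below runs only on push events (see the `if:` condition),
# so pull request builds do not accumulate wheel artifacts.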
37 | - name: Store wheel 38 | uses: actions/upload-artifact@v2 39 | if: github.event_name == 'push' 40 | with: 41 | name: ${{ steps.build-wheel.outputs.wheel-name }} 42 | path: ${{ steps.build-wheel.outputs.wheel-path }} 43 | 44 | - name: Remove old wheels 45 | uses: actions/github-script@v3 46 | if: github.event_name == 'push' 47 | env: 48 | WHEEL_SIZE: ${{ steps.build-wheel.outputs.wheel-size }} 49 | with: 50 | github-token: ${{ secrets.GITHUB_TOKEN }} 51 | script: | 52 | const { owner, repo } = context.repo; 53 | 54 | // For some reason, the newly-uploaded wheel in the previous step is not included. 55 | const artifactsResp = await github.actions.listArtifactsForRepo({ 56 | owner, 57 | repo, 58 | }); 59 | const wheels = artifactsResp.data.artifacts.filter(({ name }) => name.endsWith(".whl")); 60 | 61 | // The storage usage limit for a free github account is up to 500 MB. See the page below for details: 62 | // https://docs.github.com/en/github/setting-up-and-managing-billing-and-payments-on-github/about-billing-for-github-actions 63 | MAX_SIZE_IN_BYTES = 300_000_000; // 300 MB 64 | 65 | let index = 0; 66 | let sum = parseInt(process.env.WHEEL_SIZE); // include the newly-uploaded wheel 67 | for (const [idx, { size_in_bytes }] of wheels.entries()) { 68 | index = idx; 69 | sum += size_in_bytes; 70 | if (sum > MAX_SIZE_IN_BYTES) { 71 | break; 72 | } 73 | } 74 | 75 | if (sum <= MAX_SIZE_IN_BYTES) { 76 | return; 77 | } 78 | 79 | // Delete old wheels 80 | const promises = wheels.slice(index).map(({ id: artifact_id }) => 81 | github.actions.deleteArtifact({ 82 | owner, 83 | repo, 84 | artifact_id, 85 | }) 86 | ); 87 | Promise.all(promises).then(data => console.log(data)); 88 | 89 | -------------------------------------------------------------------------------- /examples/IrisClassification/iris_data_module.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import pytorch_lightning as pl 4 | import torch 5 | from pytorch_lightning import seed_everything 6 | from sklearn.datasets import load_iris 7 | from torch.utils.data import DataLoader, random_split, TensorDataset 8 | 9 | 10 | class IrisDataModule(pl.LightningDataModule): 11 | def __init__(self, **kwargs): 12 | """ 13 | Initialization of inherited lightning data module 14 | """ 15 | super(IrisDataModule, self).__init__() 16 | 17 | self.train_set = None 18 | self.val_set = None 19 | self.test_set = None 20 | self.args = kwargs 21 | 22 | def prepare_data(self): 23 | """ 24 | Implementation of abstract class 25 | """ 26 | 27 | def setup(self, stage=None): 28 | """ 29 | Downloads the data, parses it and splits the data into train, test and validation sets 30 | 31 | :param stage: Stage - training or testing 32 | """ 33 | iris = load_iris() 34 | df = iris.data 35 | target = iris["target"] 36 | 37 | data = torch.Tensor(df).float() 38 | labels = torch.Tensor(target).long() 39 | RANDOM_SEED = 42 40 | seed_everything(RANDOM_SEED) 41 | 42 | data_set = TensorDataset(data, labels) 43 | self.train_set, self.val_set = random_split(data_set, [130, 20]) 44 | self.train_set, self.test_set = random_split(self.train_set, [110, 20]) 45 | 46 | @staticmethod 47 | def add_model_specific_args(parent_parser): 48 | """ 49 | Adds the model-specific arguments batch size and num workers 50 | 51 | :param parent_parser: Application specific parser 52 | 53 | :return: Returns the augmented argument parser 54 | """ 55 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 56 |
parser.add_argument( 57 | "--batch-size", 58 | type=int, 59 | default=128, 60 | metavar="N", 61 | help="input batch size for training (default: 128)", 62 | ) 63 | parser.add_argument( 64 | "--num-workers", 65 | type=int, 66 | default=3, 67 | metavar="N", 68 | help="number of workers (default: 3)", 69 | ) 70 | return parser 71 | 72 | def create_data_loader(self, dataset): 73 | """ 74 | Generic data loader function 75 | 76 | :param dataset: Input data set 77 | 78 | :return: Returns the constructed dataloader 79 | """ 80 | 81 | return DataLoader( 82 | dataset, batch_size=self.args["batch_size"], num_workers=self.args["num_workers"] 83 | ) 84 | 85 | def train_dataloader(self): 86 | train_loader = self.create_data_loader(dataset=self.train_set) 87 | return train_loader 88 | 89 | def val_dataloader(self): 90 | validation_loader = self.create_data_loader(dataset=self.val_set) 91 | return validation_loader 92 | 93 | def test_dataloader(self): 94 | test_loader = self.create_data_loader(dataset=self.test_set) 95 | return test_loader 96 | 97 | 98 | if __name__ == "__main__": 99 | pass 100 | -------------------------------------------------------------------------------- /examples/MNIST/mnist_handler.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import os 4 | 5 | import torch 6 | from ts.torch_handler.image_classifier import ImageClassifier 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class MNISTDigitHandler(ImageClassifier): 12 | """ 13 | MNISTDigitClassifier handler class. This handler takes a greyscale image 14 | and returns the digit in that image. 15 | """ 16 | 17 | def __init__(self): 18 | super(MNISTDigitHandler, self).__init__() 19 | self.mapping_file_path = None 20 | 21 | def initialize(self, ctx): 22 | """ 23 | First try to load torchscript else load eager mode state_dict based model 24 | :param ctx: System properties 25 | """ 26 | properties = ctx.system_properties 27 | self.device = torch.device( 28 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 29 | ) 30 | model_dir = properties.get("model_dir") 31 | 32 | # Read model serialize/pt file 33 | model_pt_path = os.path.join(model_dir, "model.pth") 34 | from mnist_model import LightningMNISTClassifier 35 | 36 | state_dict = torch.load(model_pt_path, map_location=self.device) 37 | self.model = LightningMNISTClassifier() 38 | self.model.load_state_dict(state_dict, strict=False) 39 | self.model.to(self.device) 40 | self.model.eval() 41 | 42 | logger.debug("Model file %s loaded successfully", model_pt_path) 43 | 44 | self.mapping_file_path = os.path.join(model_dir, "index_to_name.json") 45 | 46 | self.initialized = True 47 | 48 | def preprocess(self, data): 49 | """ 50 | Reads the input JSON payload and converts it to a tensor 51 | for the MNIST model 52 | :param data: Input to be passed through the layers for prediction 53 | :return: output - Preprocessed image tensor 54 | """ 55 | image = data[0].get("data") 56 | if image is None: 57 | image = data[0].get("body") 58 | 59 | image = image.decode("utf-8") 60 | image = torch.Tensor(json.loads(image)["data"]) 61 | return image 62 | 63 | def inference(self, img): 64 | """ 65 | Predict the class (or classes) of an image using a trained deep learning model 66 | :param img: Input to be passed through the layers for prediction 67 | :return: output - Predicted label for the given input 68 | """ 69 | # Convert 2D image to 1D vector 70 | # img = np.expand_dims(img, 0) 71 | # img =
torch.from_numpy(img) 72 | self.model.eval() 73 | inputs = img.to(self.device) 74 | outputs = self.model.forward(inputs) 75 | 76 | _, y_hat = outputs.max(1) 77 | predicted_idx = str(y_hat.item()) 78 | return [predicted_idx] 79 | 80 | def postprocess(self, inference_output): 81 | """ 82 | Does postprocess after inference to be returned to user 83 | 84 | :param inference_output: Output of inference 85 | 86 | :return: output - Output after post processing 87 | """ 88 | 89 | if self.mapping_file_path: 90 | with open(self.mapping_file_path) as json_file: 91 | data = json.load(json_file) 92 | inference_output = [json.dumps(data[inference_output[0]])] 93 | return inference_output 94 | return [inference_output] 95 | -------------------------------------------------------------------------------- /tests/resources/linear_handler.py: -------------------------------------------------------------------------------- 1 | # IMPORTS SECTION # 2 | 3 | import logging 4 | import os 5 | import numpy as np 6 | import torch 7 | from torch.autograd import Variable 8 | import json 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | # CLASS DEFINITION # 14 | 15 | 16 | class LinearRegressionHandler(object): 17 | def __init__(self): 18 | self.model = None 19 | self.mapping = None 20 | self.device = None 21 | self.initialized = False 22 | 23 | def initialize(self, ctx): 24 | """ 25 | Loading the saved model from the serialized file 26 | """ 27 | 28 | properties = ctx.system_properties 29 | self.device = torch.device( 30 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 31 | ) 32 | model_dir = properties.get("model_dir") 33 | 34 | # Read model serialize/pt file 35 | model_pt_path = os.path.join(model_dir, "linear_state_dict.pt") 36 | # Read model definition file 37 | model_def_path = os.path.join(model_dir, "linear_model.py") 38 | 39 | if not os.path.isfile(model_def_path): 40 | model_pt_path = os.path.join(model_dir, "linear_model.pt") 41 | self.model = torch.load(model_pt_path, map_location=self.device) 42 | else: 43 | from linear_model import LinearRegression 44 | 45 | state_dict = torch.load(model_pt_path, map_location=self.device) 46 | self.model = LinearRegression(1, 1) 47 | self.model.load_state_dict(state_dict) 48 | 49 | self.model.to(self.device) 50 | self.model.eval() 51 | 52 | logger.debug("Model file %s loaded successfully", model_pt_path) 53 | self.initialized = True 54 | 55 | def preprocess(self, data): 56 | """ 57 | Preprocess the input to tensor and reshape it to be used as input to the network 58 | """ 59 | data = data[0] 60 | image = data.get("data") 61 | if image is None: 62 | image = data.get("body") 63 | image = image.decode("utf-8") 64 | number = float(json.loads(image)["data"][0]) 65 | else: 66 | number = float(image) 67 | 68 | np_data = np.array(number, dtype=np.float32) 69 | np_data = np_data.reshape(-1, 1) 70 | data_tensor = torch.from_numpy(np_data) 71 | return data_tensor 72 | 73 | def inference(self, num): 74 | 75 | """ 76 | Does inference / prediction on the preprocessed input and returns the output 77 | """ 78 | 79 | self.model.eval() 80 | inputs = Variable(num).to(self.device) 81 | outputs = self.model.forward(inputs) 82 | return [outputs.detach().item()] 83 | 84 | def postprocess(self, inference_output): 85 | 86 | """ 87 | Does post processing on the output returned from the inference method 88 | """ 89 | return inference_output 90 | 91 | 92 | # CLASS INITIALIZATION # 93 | 94 | _service = LinearRegressionHandler() 95 | 96 | 97 | def handle(data, 
context): 98 | 99 | """ 100 | Default handler for the inference api which takes two parameters data and context 101 | and returns the predicted output 102 | """ 103 | 104 | if not _service.initialized: 105 | _service.initialize(context) 106 | 107 | if data is None: 108 | return None 109 | 110 | data = _service.preprocess(data) 111 | data = _service.inference(data) 112 | data = _service.postprocess(data) 113 | print("Data: {}".format(data)) 114 | return data 115 | -------------------------------------------------------------------------------- /examples/IrisClassificationTorchScript/iris_classification.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import mlflow.pytorch 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn as nn 7 | from sklearn.metrics import accuracy_score 8 | from torch.nn import functional as F 9 | 10 | 11 | class IrisClassification(pl.LightningModule): 12 | def __init__(self): 13 | super(IrisClassification, self).__init__() 14 | self.fc1 = nn.Linear(4, 10) 15 | self.fc2 = nn.Linear(10, 10) 16 | self.fc3 = nn.Linear(10, 3) 17 | 18 | def forward(self, x): 19 | x = F.relu(self.fc1(x)) 20 | x = F.relu(self.fc2(x)) 21 | x = F.relu(self.fc3(x)) 22 | x = F.log_softmax(x, dim=1) 23 | return x 24 | 25 | def cross_entropy_loss(self, logits, labels): 26 | """ 27 | Loss Fn to compute loss 28 | """ 29 | return F.nll_loss(logits, labels) 30 | 31 | def training_step(self, train_batch, batch_idx): 32 | """ 33 | Trains the data as batches and returns training loss on each batch 34 | """ 35 | x, y = train_batch 36 | logits = self.forward(x) 37 | loss = self.cross_entropy_loss(logits, y) 38 | return {"loss": loss} 39 | 40 | def validation_step(self, val_batch, batch_idx): 41 | """ 42 | Performs validation of data in batches 43 | """ 44 | x, y = val_batch 45 | logits = self.forward(x) 46 | loss = self.cross_entropy_loss(logits, y) 47 | return {"val_loss": loss} 48 | 49 | def validation_epoch_end(self, outputs): 50 | """ 51 | Computes average validation loss 52 | """ 53 | avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() 54 | tensorboard_logs = {"val_loss": avg_loss} 55 | return {"avg_val_loss": avg_loss, "log": tensorboard_logs} 56 | 57 | def test_step(self, test_batch, batch_idx): 58 | """ 59 | Performs test and computes test accuracy 60 | """ 61 | 62 | x, y = test_batch 63 | output = self.forward(x) 64 | a, y_hat = torch.max(output, dim=1) 65 | test_acc = accuracy_score(y_hat.cpu(), y.cpu()) 66 | return {"test_acc": torch.tensor(test_acc)} 67 | 68 | def test_epoch_end(self, outputs): 69 | """ 70 | Computes average test accuracy score 71 | """ 72 | avg_test_acc = torch.stack([x["test_acc"] for x in outputs]).mean() 73 | return {"avg_test_acc": avg_test_acc} 74 | 75 | def configure_optimizers(self): 76 | """ 77 | Creates and returns Optimizer 78 | """ 79 | 80 | self.optimizer = torch.optim.SGD(self.parameters(), lr=0.05, momentum=0.9) 81 | return self.optimizer 82 | 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser = pl.Trainer.add_argparse_args(parent_parser=parser) 87 | args = parser.parse_args() 88 | dict_args = vars(args) 89 | 90 | for argument in ["strategy", "accelerator", "devices"]: 91 | if dict_args[argument] == "None": 92 | dict_args[argument] = None 93 | 94 | mlflow.pytorch.autolog() 95 | model = IrisClassification() 96 | import iris_datamodule 97 | 98 | dm = iris_datamodule.IRISDataModule() 99 | dm.prepare_data() 100 | dm.setup("fit")
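# NOTE: the Trainer below picks up max_epochs, accelerator, devices and strategy
# from the argparse arguments wired in via pl.Trainer.add_argparse_args above;
# the default values come from the MLproject file.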
101 | 102 | trainer = pl.Trainer.from_argparse_args(args) 103 | trainer.fit(model, dm) 104 | trainer.test(datamodule=dm) 106 | if trainer.global_rank == 0: 107 | scripted_model = torch.jit.script(model) 108 | torch.jit.save(scripted_model, "iris_ts.pt") 109 | -------------------------------------------------------------------------------- /examples/Titanic/README.md: -------------------------------------------------------------------------------- 1 | # Titanic features attribution analysis using Captum and TorchServe 2 | 3 | In this example, we will demonstrate the basic features of the [Captum](https://captum.ai/) interpretability, and serving the model on torchserve, through an example model trained on the Titanic survival data. You can download the data from [titanic](https://biostat.app.vumc.org/wiki/pub/Main/DataSets/titanic3.csv) 4 | 5 | We will first train a deep neural network on the data using PyTorch and use Captum to understand which of the features were most important and how the network reached its prediction. 6 | 7 | You can get more details about the attribution methods used in this example 8 | 9 | 1. [Titanic_Basic_Interpret](https://captum.ai/tutorials/Titanic_Basic_Interpret) 10 | 2. [integrated-gradients](https://captum.ai/docs/algorithms#primary-attribution) 11 | 3. [layer-attributions](https://captum.ai/docs/algorithms#layer-attribution) 12 | 13 | 14 | The inference service would return the prediction and the average attribution score of the features for a given target for an input test record. 15 | 16 | ### Running the code 17 | 18 | To run the example via MLflow, navigate to the `examples/Titanic/` directory and run the commands 19 | 20 | ``` 21 | mlflow run . 22 | 23 | ``` 24 | 25 | This will run `titanic_captum_interpret.py` with the default set of parameters such as `--max_epochs=100`. You can see the default value in the MLproject file. 26 | 27 | In order to run the file with custom parameters, run the command 28 | 29 | ``` 30 | mlflow run . -P max_epochs=X 31 | ``` 32 | 33 | where X is your desired value for max_epochs. 34 | 35 | If you have the required modules for the file and would like to skip the creation of a conda environment, add the argument --no-conda. 36 | 37 | ``` 38 | mlflow run . --no-conda 39 | ``` 40 | 41 | The above commands will train the Titanic model for further use. 42 | 43 | 44 | # Serve a custom model on TorchServe 45 | 46 | * Step - 1: Create a new model architecture file which contains model class extended from torch.nn.modules. In this example we have created [titanic model file](titanic.py). 47 | * Step - 2: Write a custom handler to run the inference on your model. In this example, we have added a [custom_handler](titanic_handler.py) which runs the inference on the input record using the above model and makes a prediction. 48 | * Step - 3: Create an empty directory model_store and run the following command to start torchserve. 49 | 50 | ```bash 51 | torchserve --start --model-store model_store/ 52 | ``` 53 | 54 | ## Creating a new deployment 55 | Run the following command to create a new deployment named `titanic` 56 | 57 | The `index_to_name.json` file is the mapping file, which will convert the discrete output of the model to one of the classes (survived/not survived) 58 | based on the predefined mapping.
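For reference, a minimal mapping file for this binary task could look like the following (a sketch; the keys are the class indices emitted by the model):

```json
{
  "0": "Not Survived",
  "1": "Survived"
}
```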
59 | 60 | ```bash 61 | mlflow deployments create --name titanic --target torchserve --model-uri models/titanic_state_dict.pt -C "MODEL_FILE=titanic.py" -C "HANDLER=titanic_handler.py" -C "EXTRA_FILES=index_to_name.json" 62 | ``` 63 | 64 | ## Running prediction based on deployed model 65 | 66 | For testing, we are going to use a sample test record placed in the test_data folder in input.json 67 | 68 | Run the following command to invoke prediction on the test record; the output is stored in the output.json file. 69 | 70 | ``` 71 | mlflow deployments predict --name titanic --target torchserve --input-path test_data/input.json --output-path output.json 72 | ``` 73 | 74 | This model will classify the test record as survived or not survived and store it in `output.json` 75 | 76 | 77 | Run the below command to invoke explain for feature importance attributions on the test record. It will save the attribution image attributions_imp.png in the test_data folder. 78 | 79 | ``` 80 | mlflow deployments explain -t torchserve --name titanic --input-path test_data/input.json 81 | ``` 82 | 83 | This explain command gives us the average attribution for each feature. From the feature attribution information, we obtain some interesting insights regarding the importance of various features. 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mlflow-TorchServe 2 | 3 | A plugin that integrates [TorchServe](https://github.com/pytorch/serve) with the MLflow pipeline. 4 | ``mlflow_torchserve`` enables mlflow users to deploy mlflow pipeline models into TorchServe. 5 | Command line APIs of the plugin (also accessible through mlflow's python package) make the deployment process seamless. 6 | 7 | ## Prerequisites 8 | 9 | Following is the list of packages which need to be installed before running the TorchServe deployment plugin 10 | 11 | 1. torch-model-archiver 12 | 2. torchserve 13 | 3. mlflow 14 | 15 | 16 | ## Installation 17 | The plugin package is available in pypi and can be installed using 18 | 19 | ```bash 20 | pip install mlflow-torchserve 21 | ``` 22 | ## Installation from Source 23 | 24 | The plugin package can also be installed from source using the following commands 25 | ``` 26 | python setup.py build 27 | python setup.py install 28 | ``` 29 | 30 | ## What does it do 31 | Installing this package uses python's entrypoint mechanism to register the plugin into MLflow's 32 | plugin registry. This registry will be invoked each time you launch an MLflow script or command line 33 | command. 34 | 35 | 36 | ### Create deployment 37 | The `create` command line argument and the ``create_deployment`` python 38 | API do the deployment of a model built with MLflow to TorchServe. 39 | 40 | ##### CLI 41 | ```shell script 42 | mlflow deployments create -t torchserve -m --name DEPLOYMENT_NAME -C 'MODEL_FILE=' -C 'HANDLER=' 43 | ``` 44 | 45 | ##### Python API 46 | ```python 47 | from mlflow.deployments import get_deploy_client 48 | target_uri = 'torchserve' 49 | plugin = get_deploy_client(target_uri) 50 | plugin.create_deployment(name=, model_uri=, config={"MODEL_FILE": , "HANDLER": }) 51 | ``` 52 | 53 | ### Update deployment 54 | The update API can be used to modify the configuration parameters such as number of workers, version etc., of an already deployed model. 55 | TorchServe will make sure the user experience is seamless while changing the model in a live environment.
56 | 57 | ##### CLI 58 | ```shell script 59 | mlflow deployments update -t torchserve --name -C "min-worker=" 60 | ``` 61 | 62 | ##### Python API 63 | ```python 64 | plugin.update_deployment(name=, config={'min-worker': }) 65 | ``` 66 | 67 | ### Delete deployment 68 | Deletes an existing deployment. An exception will be raised if the model is not already deployed. 69 | 70 | ##### CLI 71 | ```shell script 72 | mlflow deployments delete -t torchserve --name 73 | ``` 74 | 75 | ##### Python API 76 | ```python 77 | plugin.delete_deployment(name=) 78 | ``` 79 | 80 | ### List all deployments 81 | Lists the names of all the models deployed on the configured TorchServe. 82 | 83 | ##### CLI 84 | ```shell script 85 | mlflow deployments list -t torchserve 86 | ``` 87 | 88 | ##### Python API 89 | ```python 90 | plugin.list_deployments() 91 | ``` 92 | 93 | ### Get deployment details 94 | Get API fetches the details of the deployed model. By default, Get API fetches all the versions of the 95 | deployed model. 96 | 97 | ##### CLI 98 | ```shell script 99 | mlflow deployments get -t torchserve --name 100 | ``` 101 | 102 | ##### Python API 103 | ```python 104 | plugin.get_deployment(name=) 105 | ``` 106 | 107 | ### Run Prediction on deployed model 108 | The predict API enables running prediction on the deployed model. 109 | 110 | For the prediction inputs, DataFrame, Tensor and JSON formats are supported. The python API supports all three of these 111 | formats. When invoked via the command line, one needs to pass the path of the json file that contains the inputs. 112 | 113 | ##### CLI 114 | ```shell script 115 | mlflow deployments predict -t torchserve --name --input-path --output-path 116 | ``` 117 | 118 | output-path is an optional parameter. Without the output path parameter, the result will be printed to the console. 119 | 120 | ##### Python API 121 | ```python 122 | plugin.predict(name=, df=) 123 | ``` 124 | 125 | ### Plugin help 126 | Run the following command to get the plugin help string. 127 | 128 | ##### CLI 129 | ```shell script 130 | mlflow deployments help -t torchserve 131 | ``` 132 | 133 | 134 | -------------------------------------------------------------------------------- /examples/BertNewsClassification/news_classifier_handler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | from transformers import BertTokenizer 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class NewsClassifierHandler(object): 12 | """ 13 | NewsClassifierHandler class.
This handler takes a review / sentence 14 | and returns the label as either world / sports / business / sci-tech 15 | """ 16 | 17 | def __init__(self): 18 | self.model = None 19 | self.mapping = None 20 | self.device = None 21 | self.initialized = False 22 | self.class_mapping_file = None 23 | self.VOCAB_FILE = None 24 | 25 | def initialize(self, ctx): 26 | """ 27 | First try to load torchscript else load eager mode state_dict based model 28 | :param ctx: System properties 29 | """ 30 | 31 | properties = ctx.system_properties 32 | self.device = torch.device( 33 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 34 | ) 35 | model_dir = properties.get("model_dir") 36 | 37 | # Read model serialize/pt file 38 | model_pt_path = os.path.join(model_dir, "model.pth") 39 | # Read model definition file 40 | model_def_path = os.path.join(model_dir, "news_classifier.py") 41 | if not os.path.isfile(model_def_path): 42 | raise RuntimeError("Missing the model definition file") 43 | 44 | self.VOCAB_FILE = os.path.join(model_dir, "bert_base_uncased_vocab.txt") 45 | if not os.path.isfile(self.VOCAB_FILE): 46 | raise RuntimeError("Missing the vocab file") 47 | 48 | self.class_mapping_file = os.path.join(model_dir, "class_mapping.json") 49 | 50 | self.model = torch.load(model_pt_path, map_location=self.device) 51 | self.model.to(self.device) 52 | self.model.eval() 53 | 54 | logger.debug("Model file %s loaded successfully", model_pt_path) 55 | self.initialized = True 56 | 57 | def preprocess(self, data): 58 | """ 59 | Receives text in the form of JSON and converts it into an encoding for the inference stage 60 | :param data: Input to be passed through the layers for prediction 61 | :return: output - preprocessed encoding 62 | """ 63 | 64 | text = data[0].get("data") 65 | if text is None: 66 | text = data[0].get("body") 67 | 68 | text = text.decode("utf-8") 69 | 70 | tokenizer = BertTokenizer(self.VOCAB_FILE) 71 | encoding = tokenizer.encode_plus( 72 | text, 73 | max_length=32, 74 | add_special_tokens=True, # Add '[CLS]' and '[SEP]' 75 | return_token_type_ids=False, 76 | padding="max_length", 77 | return_attention_mask=True, 78 | return_tensors="pt", # Return PyTorch tensors 79 | truncation=True, 80 | ) 81 | 82 | return encoding 83 | 84 | def inference(self, encoding): 85 | """ 86 | Predict the class of the news text - World / Sports / Business / Sci-Tech 87 | :param encoding: Input encoding to be passed through the layers for prediction 88 | :return: output - predicted output 89 | """ 90 | 91 | self.model.eval() 92 | inputs = encoding.to(self.device) 93 | outputs = self.model.forward(**inputs) 94 | 95 | out = np.argmax(outputs.cpu().detach()) 96 | return [out.item()] 97 | 98 | def postprocess(self, inference_output): 99 | """ 100 | Does postprocess after inference to be returned to user 101 | :param inference_output: Output of inference 102 | :return: output - Output after post processing 103 | """ 104 | if self.class_mapping_file: 105 | with open(self.class_mapping_file) as json_file: 106 | data = json.load(json_file) 107 | inference_output = json.dumps(data[str(inference_output[0])]) 108 | return [inference_output] 109 | 110 | return inference_output 111 | 112 | 113 | _service = NewsClassifierHandler() 114 | 115 | 116 | def handle(data, context): 117 | """ 118 | Default function that is called when predict is invoked 119 | :param data: Input to be passed through the layers for prediction 120 | :param context: dict containing system properties 121 | :return: output - Output after postprocess 122
| """ 123 | if not _service.initialized: 124 | _service.initialize(context) 125 | 126 | if data is None: 127 | return None 128 | 129 | data = _service.preprocess(data) 130 | data = _service.inference(data) 131 | data = _service.postprocess(data) 132 | 133 | return data 134 | -------------------------------------------------------------------------------- /examples/E2EBert/README.md: -------------------------------------------------------------------------------- 1 | 2 | # An End-2-End Deep Learning Workflow with BERT PreTrained Model 3 | 4 | An `end-2-end` workflow describing how model training is done,followed by storing all the relevant information leading to model deployment and 5 | testing. A pretrained BERT model is used to illustrate the workflow. 6 | 7 | ## Finetuning the BERT Pretrained Model 8 | The code, adapted from this [repository](https://github.com/maknotavailable/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py), 9 | is almost entirely dedicated to model training, with the addition of a single mlflow.pytorch.autolog() call to enable automatic logging of params, metrics, and models, 10 | including the extra files, followed by saving the finetuned model along with extra artifact files such as the vocabulary file and class mapping file, which are essential to make the model 11 | work and help in transforming the model outputs into corresponding labels respectively. 12 | 13 | ## Package Requirement 14 | 15 | Ensure to install the `mlflow-torchserve` [prerequisites](https://github.com/mlflow/mlflow-torchserve#prerequisites) and 16 | [package](https://github.com/mlflow/mlflow-torchserve#installation) in your current python environment before starting. 17 | 18 | Install the required packages using the following command 19 | 20 | `pip install -r requirements.txt` 21 | 22 | 23 | ### Running the code 24 | To run the example via MLflow, navigate to the `mlflow-torchserve/examples/E2EBert` directory and run the command 25 | 26 | ``` 27 | mlflow run . 28 | ``` 29 | 30 | This will run `news_classifier.py` with the default set of parameters such as `--max_epochs=5`. You can see the default value in the `MLproject` file. 31 | 32 | In order to run the file with custom parameters, run the command 33 | 34 | ``` 35 | mlflow run . -P max_epochs=X 36 | ``` 37 | 38 | where `X` is your desired value for `max_epochs`. 39 | 40 | If you have the required modules for the file and would like to skip the creation of a conda environment, add the argument `--no-conda`. 41 | 42 | ``` 43 | mlflow run . --no-conda 44 | 45 | ``` 46 | 47 | To run it in gpu, use the following command 48 | 49 | ``` 50 | mlflow run . -P devices=2 -P strategy=ddp -P accelerator=gpu 51 | ``` 52 | 53 | Run the `news_classifier.py` script which will fine tune the model based on news dataset. 54 | 55 | By default, the script exports the model file as `state_dict.pth` and generates a sample input file `input.json` 56 | 57 | 58 | 59 | ### Passing custom training parameters 60 | 61 | The parameters can be overridden via the command line: 62 | 63 | 1. max_epochs - Number of epochs to train model. Training can be interrupted early via Ctrl+C 64 | 2. num_samples -Number of input samples required for training 65 | 66 | 67 | 68 | For example: 69 | ``` 70 | mlflow run . 
-P max_epochs=5 71 | ``` 72 | 73 | Or to run the training script directly with custom parameters: 74 | 75 | ``` 76 | python news_classifier.py \ 77 | --max_epochs 5 78 | ``` 79 | 80 | ## Starting TorchServe 81 | 82 | Create an empty directory `model_store` and run the following command to start TorchServe. 83 | 84 | `torchserve --start --model-store model_store` 85 | 86 | ## Creating a new deployment 87 | 88 | Run the following command to create a new deployment named `news_classification_test` 89 | 90 | `mlflow deployments create -t torchserve -m state_dict.pth --name news_classification_test -C "MODEL_FILE=news_classifier.py" -C "HANDLER=news_classifier_handler.py" -C "EXTRA_FILES=class_mapping.json,bert_base_uncased_vocab.txt,wrapper.py"` 91 | 92 | Note: The TorchServe plugin determines the version number by itself based on the deployment name. Hence, the version number 93 | is not a mandatory argument for the plugin. For example, the above command will create a deployment `news_classification_test` with version 1. 94 | 95 | If needed, the version number can also be explicitly mentioned as a config variable. 96 | 97 | `mlflow deployments create -t torchserve -m state_dict.pth --name news_classification_test -C "VERSION=1.0" -C "MODEL_FILE=news_classifier.py" -C "HANDLER=news_classifier_handler.py" -C "EXTRA_FILES=class_mapping.json,bert_base_uncased_vocab.txt,wrapper.py"` 98 | 99 | Note: 100 | 101 | By default, the mlflow-torchserve plugin generates the mar file inside the "model_store" directory. If the model store directory is not present under the current folder, 102 | the plugin creates a new directory named "model_store" and generates the mar file inside it. 103 | 104 | If TorchServe is already running with a different "model_store" location, make sure to pass the "model_store" path with the "EXPORT_PATH" config variable (`-C 'EXPORT_PATH='`) 105 | 106 | 107 | ## Running prediction and explain based on deployed model 108 | 109 | For testing the fine-tuned model, a sample input text is placed in `input.json` 110 | 111 | Run the following command to invoke prediction of our sample input 112 | 113 | `mlflow deployments predict --name news_classification_test --target torchserve --input-path input.json --output-path output.json` 114 | 115 | Run the following command to invoke explain of our sample input 116 | 117 | 118 | `mlflow deployments explain --name news_classification_test --target torchserve --input-path input.json --output-path output.json` 119 | 120 | All the Captum Insights visualizations can be seen in the Jupyter notebook added in this example -------------------------------------------------------------------------------- /examples/IrisClassification/README.md: -------------------------------------------------------------------------------- 1 | # Deploying Iris Classification using TorchServe 2 | 3 | The code is adapted from this [repository](http://chappers.github.io/2020/04/19/torch-lightning-using-iris/). 4 | This example illustrates the process of logging the model signature and performing validations during the deployment phase. 5 | The model is trained using all the IRIS dataset features, namely sepal-length, sepal-width, petal-length and petal-width. On train completion, the model along with its signature is saved using `mlflow.pytorch.save_model` (a minimal sketch of this is shown below). 6 | After deployment, the model first validates the input signature and predicts the test input as belonging to one of the IRIS flower species, namely `SETOSA`, `VERSICOLOR`, `VIRGINICA`.
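As a quick illustration, the following is a minimal sketch of how a signature can be captured and saved alongside a PyTorch model using MLflow's public `infer_signature` API. The tiny linear layer here is only a stand-in for the trained network, not the example's actual model:

```python
import mlflow.pytorch
import torch
from mlflow.models.signature import infer_signature
from sklearn.datasets import load_iris

# Load the iris features as a DataFrame so the column names
# ("sepal length (cm)", ...) are recorded in the signature.
X = load_iris(as_frame=True).data

model = torch.nn.Linear(4, 3)  # stand-in for the trained classifier
with torch.no_grad():
    preds = model(torch.tensor(X.values, dtype=torch.float32)).numpy()

# infer_signature captures the column names and dtypes; the handler later
# enforces this schema against incoming requests.
signature = infer_signature(X, preds)
mlflow.pytorch.save_model(model, "model", signature=signature)
```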
All the Captum Insights visualizations can be seen in the Jupyter notebook added in this example.
--------------------------------------------------------------------------------
/examples/IrisClassification/README.md:
--------------------------------------------------------------------------------
# Deploying Iris Classification using TorchServe

This example, adapted from this [repository](http://chappers.github.io/2020/04/19/torch-lightning-using-iris/), illustrates the process of logging the model signature and performing validations during the deployment phase.
The model is trained using all four Iris dataset features, namely sepal length, sepal width, petal length, and petal width. On completion of training, the model along with its signature is saved using `mlflow.pytorch.save_model`.
After deployment, the model first validates the input against the signature and then predicts the test input as belonging to one of the Iris flower species, namely `SETOSA`, `VERSICOLOR`, or `VIRGINICA`.

### Running the code

To run the example via MLflow, navigate to the `examples/IrisClassification/` directory and run the command

```
mlflow run .
```

This will run `iris_classification.py` with the default set of parameters, such as `--max_epochs=100`. You can see the default values in the `MLproject` file.

In order to run the file with custom parameters, run the command

```
mlflow run . -P max_epochs=X
```

where `X` is your desired value for `max_epochs`.

If you already have the required modules installed and would like to skip the creation of a conda environment, add the argument `--no-conda`.

```
mlflow run . --no-conda
```

To run it on GPUs, use the following command

```
mlflow run . -P devices=2 -P strategy=ddp -P accelerator=gpu
```

At the end of the training process, the model and its signature are saved in the `model` directory.

## Starting TorchServe

Create an empty directory `model_store` and run the following command to start TorchServe.

`torchserve --start --model-store model_store`

## Creating a new deployment

Run the following command to create a new deployment named `iris_classification`

`python create_deployment.py --extra_files "index_to_name.json,model/MLmodel"`

The default parameters are set in the `create_deployment.py` script and can be overridden via parser arguments.

Note:
The mlflow-torchserve plugin generates the mar file inside the `model_store` directory. If the `model_store` directory is not present under the current folder, the plugin creates a new directory named `model_store` and generates the mar file inside it.

If TorchServe is already running with a different `model_store` location, ensure to pass the `model_store` path with the `EXPORT_PATH` config variable (`-C 'EXPORT_PATH='`).
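At its core, `create_deployment.py` reduces to a single deployments-client call. The sketch below shows the idea with the defaults spelled out; the `MODEL_FILE` and `HANDLER` values are assumptions based on the files in this directory, so see the script itself for the exact arguments:

```python
from mlflow.deployments import get_deploy_client

client = get_deploy_client("torchserve")

# model/MLmodel carries the signature validated at prediction time;
# index_to_name.json maps class indices to flower names.
client.create_deployment(
    "iris_classification",
    "model",  # directory written by mlflow.pytorch.save_model at the end of training
    config={
        "MODEL_FILE": "iris_classification.py",
        "HANDLER": "iris_handler.py",
        "EXTRA_FILES": "index_to_name.json,model/MLmodel",
    },
)
```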
## Running prediction based on deployed model

`python create_deployment.py --deployment_name iris_classification_1 --serialized_file_path `

Note:
MLflow stores the model signature inside the MLmodel file, and it is important to pass the MLmodel path as an EXTRA_FILES argument to the create_deployment script.

## Validating the model signature and running prediction based on deployed model

The IrisClassification model takes 4 different parameters - sepal length, sepal width, petal length, and petal width.

For testing the [iris dataset](http://archive.ics.uci.edu/ml/datasets/Iris/), we are going to use a sample input tensor placed in the `sample.json` file.

Run the following command to invoke prediction on our sample input; the output is printed to the console.

`python predict.py --input_file_path sample.json`

The `mlflow-torchserve` plugin validates the input data against the model signature saved during the training process.

The following enforcements are made during the model signature validation process:

1. The number of columns present in the input - in this example, the number of columns is 4.
2. The column names should match the names specified in the model signature (i.e. during the training process). The column names for the iris classification example are expected to be `sepal length (cm)`, `sepal width (cm)`, `petal length (cm)`, `petal width (cm)`.
3. The input values must match the datatype specified in the model signature - in this example the input values must be `double`.

If any of the above enforcements fails, a validation exception is raised.

`sample.json` is created in line with all of these enforcements.

To learn more about the model signature implementation in detail, check `iris_handler.py` and `mlflow_torchserve/SignatureValidator.py`.

The model will classify the flower species based on the input test data as one among the three types. A sample output is shown below.

```Prediction Result SETOSA```


To understand the model signature and its output, the following sample files are provided. Run the following commands to see the validation errors.

`python predict.py --input_file_path sig_invalid_column_name.json`

When the input column name doesn't match the model signature, a validation exception is thrown as below.

```mlflow.exceptions.MlflowException: Model is missing inputs ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']. Note that there were extra inputs: ['length (cm)', 'width (cm)']```


`python predict.py --input_file_path sig_invalid_data_type.json`

When the input data type doesn't match the model signature, a validation exception is thrown as below.

```mlflow.exceptions.MlflowException: Incompatible input types for column sepal length (cm). Can not safely convert int64 to float64.```
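For reference, the signature that drives these checks is declared at the end of `iris_classification.py`. The condensed sketch below restates that schema and shows what a conforming input row looks like; the numeric values are illustrative:

```python
import pandas as pd
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema

# The schema saved alongside the model: four named double columns in, a long out.
input_schema = Schema(
    [
        ColSpec("double", "sepal length (cm)"),
        ColSpec("double", "sepal width (cm)"),
        ColSpec("double", "petal length (cm)"),
        ColSpec("double", "petal width (cm)"),
    ]
)
signature = ModelSignature(inputs=input_schema, outputs=Schema([ColSpec("long")]))

# A conforming input: exact column names and float64 (double) values.
sample = pd.DataFrame(
    [[4.4, 3.0, 1.3, 0.2]],
    columns=[
        "sepal length (cm)",
        "sepal width (cm)",
        "petal length (cm)",
        "petal width (cm)",
    ],
)
```

Renaming a column or sending integer values reproduces the two exceptions shown above.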
--------------------------------------------------------------------------------
/docs/remote-deployment.rst:
--------------------------------------------------------------------------------
.. _remote deployment:

==============================================
Steps for deploying model in Remote TorchServe
==============================================

Steps to be done on the remote server:
======================================

Start the TorchServe instance on the remote server:

.. code-block::

    torchserve --start --model-store model_store --ts-config config.properties

Skip this step if TorchServe is already running on the remote server.

Steps to be done locally:
=========================

Verify Remote TorchServe Connectivity:
--------------------------------------

Run the following command to test connectivity to the remote TorchServe instance.

.. code-block::

    curl "http://:8081/models"

This command should retrieve the list of models currently present in the remote TorchServe instance. This is an optional step; if you are sure about the connectivity, proceed with the deployment steps stated below.

Training the model:
-------------------

Clone the code from `https://github.com/mlflow/mlflow-torchserve <https://github.com/mlflow/mlflow-torchserve>`_ and move to the `mlflow-torchserve/examples/MNIST` folder.

Train the model by running the following command:

.. code-block::

    mlflow run . -P registration_name=mnist_classifier

The script trains the model and, at the end of the training process, registers it in MLflow as `mnist_classifier`.


Setting up Config Properties:
-----------------------------

Set the management and inference URLs in the config properties file. The `config.properties` file is placed in the root directory of the repository.

For example:

.. code-block::

    inference_address=http://:8080
    management_address=http://:8081

Setting Environment variable:
-----------------------------

Once the config properties file is updated with the remote TorchServe instance details, set the environment variable CONFIG_PROPERTIES to the path of the `config.properties` file.

For example:

.. code-block::

    export CONFIG_PROPERTIES=/home/ubuntu/mlflow-torchserve/config.properties

Install mlflow-torchserve Plugin:
---------------------------------

Skip this step if the mlflow-torchserve plugin is already installed.

Install the TorchServe plugin using the following command:

.. code-block::

    pip install mlflow-torchserve

Creating a new deployment:
--------------------------

Run the following script to start the deployment process.

.. code-block::

    python create_deployment.py --deployment_name mnist_test --registered_model_uri models:/mnist_classifier/1

This command will generate a `mnist_test.mar` file inside the `model_store` folder.
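Under the hood, the script reduces to a single deployments-client call. The following is a simplified sketch; the `MODEL_FILE` and `HANDLER` values are assumptions based on the files in the MNIST example folder, so see `create_deployment.py` for the exact arguments:

.. code-block:: python

    from mlflow.deployments import get_deploy_client

    client = get_deploy_client("torchserve")
    client.create_deployment(
        "mnist_test",
        "models:/mnist_classifier/1",  # registered model URI from the training step
        config={
            "MODEL_FILE": "mnist_model.py",
            "HANDLER": "mnist_handler.py",
        },
    )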
Since the model needs to be deployed on a remote TorchServe instance, the mar file needs to be exposed as a public URL.

Here is an example of hosting the mar file using a Python HTTP server and ngrok. Any alternate mechanism can be used to expose the mar file as a public URL (for example, uploading it to an S3 bucket and assigning the necessary permissions to download it from an http/https URL).

Start the HTTP server from the model store directory as below:

.. code-block::

    cd model_store
    python -m http.server

This hosts the file on the local instance.
To verify, download the file from the browser or from the terminal using wget.

Open the browser and hit `http://localhost:8000/mnist_test.mar <http://localhost:8000/mnist_test.mar>`_

Or in the terminal do `wget http://localhost:8000/mnist_test.mar`

The `mnist_test.mar` file will be downloaded. However, the remote TorchServe instance cannot reach a mar file hosted on localhost.

Download and unzip ngrok from the following URL - `https://ngrok.com/download <https://ngrok.com/download>`_

Run the following command to start ngrok:

.. code-block::

    ./ngrok http 8000

Copy the web address from the forwarding section and update the `export_url` parameter in the `config.properties` file.

For example:

.. code-block::

    inference_address=http://:8080
    management_address=http://:8081
    export_url=http://eda154810618.ngrok.io


Verify that the mar file can be downloaded through the ngrok URL. Open a browser and hit

.. code-block::

    http://eda154810618.ngrok.io/mnist_test.mar

The `mnist_test.mar` file should be downloaded.

We are now all set to perform registration. To register the model in the remote TorchServe instance, run

.. code-block::

    python register.py --mar_file_name mnist_test.mar

The plugin will download the mar file from the ngrok URL and register the model in the remote TorchServe instance.

To verify the deployment, run

.. code-block::

    mlflow deployments list -t torchserve

This command will list the mnist model registered in the remote TorchServe instance.

Prediction:
-----------

The model is registered in the remote TorchServe instance and is ready for prediction. To run a sample prediction, invoke the prediction script as below:

.. code-block::

    python predict.py --deployment_name mnist_test

The prediction result "ONE" will be displayed in the console.
--------------------------------------------------------------------------------
/examples/BertNewsClassification/README.md:
--------------------------------------------------------------------------------
# Deploying BERT - News Classification using TorchServe

The code, adapted from this [repository](https://github.com/maknotavailable/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/modeling.py), is almost entirely dedicated to model training (fine-tuning). The PyTorch model, including extra files such as the vocabulary file and class mapping file that are essential to make the model functional, is saved locally using the function `mlflow.pytorch.save_model`. By default, the script exports the model file as `bert_pytorch.pt` and generates a sample input file `input.json`.
This example workflow includes the following steps:
1. A pre-trained Hugging Face BERT model is fine-tuned to classify news.
2. At the end of training, the model is saved along with its summary, parameters, and extra files.
3. The model is deployed in TorchServe.

The TorchServe deployment plugin detects the `requirements.txt` and the extra files; hence, during mar file generation, TorchServe automatically bundles the `requirements.txt` and extra files along with the model.


## Package Requirement

Ensure that the `mlflow-torchserve` [prerequisites](https://github.com/mlflow/mlflow-torchserve#prerequisites) and
[package](https://github.com/mlflow/mlflow-torchserve#installation) are installed in your current Python environment before starting.

Install the required packages using the following command

`pip install -r requirements.txt`


### Running the code
To run the example via MLflow, navigate to the `mlflow-torchserve/examples/BertNewsClassification` directory and run the command

```
mlflow run .
```

This will run `news_classifier.py` with the default set of parameters, such as `--max_epochs=5`. You can see the default values in the `MLproject` file.

In order to run the file with custom parameters, run the command

```
mlflow run . -P max_epochs=X
```

where `X` is your desired value for `max_epochs`.

If you already have the required modules installed and would like to skip the creation of a conda environment, add the argument `--no-conda`.

```
mlflow run . --no-conda
```

Run the following command to train in distributed mode

```
torchrun --nnodes 1 --nproc_per_node 4 news_classifier.py --max_epochs 1 --num_train_samples 2000 --num_test_samples 200
```

Note: The arguments `requirements_file` and `extra_files` in `mlflow.pytorch.log_model` are optional.
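A condensed sketch of such a save call follows; the paths are illustrative, and `model` stands in for the fine-tuned module produced by training, so see `news_classifier.py` for the actual invocation:

```python
import mlflow.pytorch

# `model` is a placeholder for the fine-tuned BERT module produced by training.
mlflow.pytorch.save_model(
    model,
    path="models",
    requirements_file="requirements.txt",
    extra_files=["class_mapping.json", "bert_base_uncased_vocab.txt"],
)
```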
Running the `news_classifier.py` script fine-tunes the model on the news dataset.

By default, the script exports the model file as `bert_pytorch.pt` and generates a sample input file `input.json`.

### Passing custom training parameters

The parameters can be overridden via the command line:

1. max_epochs - Number of epochs to train the model. Training can be interrupted early via Ctrl+C
2. num_train_samples - Number of input samples required for training
3. num_test_samples - Number of input samples required for testing


For example:
```
mlflow run . -P max_epochs=5
```

Or to run the training script directly with custom parameters:
```
python news_classifier.py \
    --max_epochs 5 \
    --model_save_path /home/ubuntu/mlflow-torchserve/examples/BertNewsClassification/models
```

To run the training script in a GPU environment:
```
torchrun news_classifier.py \
    --max_epochs 5 \
    --model_save_path /home/ubuntu/mlflow-torchserve/examples/BertNewsClassification/models
```

## Starting TorchServe

Create an empty directory `model_store` and run the following command to start TorchServe.

`torchserve --start --model-store model_store`


## Creating a new deployment

Run the following command to create a new deployment named `news_classification_test`

`mlflow deployments create -t torchserve -m file:///home/ubuntu/mlflow-torchserve/examples/BertNewsClassification/models --name news_classification_test -C "MODEL_FILE=news_classifier.py" -C "HANDLER=news_classifier_handler.py"`

The TorchServe plugin determines the version number by itself based on the deployment name; hence, the version number is not a mandatory argument for the plugin. For example, the above command will create a deployment `news_classification_test` with version 1.

If needed, the version number can also be mentioned explicitly as a config variable.


`mlflow deployments create -t torchserve -m file:///home/ubuntu/mlflow-torchserve/examples/BertNewsClassification/models --name news_classification_test -C "VERSION=1.0" -C "MODEL_FILE=news_classifier.py" -C "HANDLER=news_classifier_handler.py"`

Note:
The mlflow-torchserve plugin generates the mar file inside the `model_store` directory. If the `model_store` directory is not present under the current folder, the plugin creates a new directory named `model_store` and generates the mar file inside it.

If TorchServe is already running with a different `model_store` location, ensure to pass the `model_store` path with the `EXPORT_PATH` config variable (`-C 'EXPORT_PATH='`).

## Running prediction based on deployed model

The deployed BERT model will predict the classification of the given news text and store the output in `output.json`.
Run the following command to invoke prediction of our sample input (input.json) 122 | 123 | `mlflow deployments predict --name news_classification_test --target torchserve --input-path input.json --output-path output.json` 124 | -------------------------------------------------------------------------------- /examples/cifar10/cifar10_handler.py: -------------------------------------------------------------------------------- 1 | """ Cifar10 Custom Handler.""" 2 | 3 | import base64 4 | import io 5 | import json 6 | import logging 7 | import os 8 | from abc import ABC 9 | from base64 import b64encode 10 | from io import BytesIO 11 | 12 | import numpy as np 13 | import torch 14 | from PIL import Image 15 | from captum.attr import IntegratedGradients, Occlusion, LayerGradCam 16 | from captum.attr import visualization as viz 17 | from matplotlib.colors import LinearSegmentedColormap 18 | from torchvision import transforms 19 | from ts.torch_handler.image_classifier import ImageClassifier 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class CIFAR10Classification(ImageClassifier, ABC): 25 | """ 26 | Base class for all vision handlers 27 | """ 28 | 29 | def initialize(self, ctx): # pylint: disable=arguments-differ 30 | """In this initialize function, the CIFAR10 trained model is loaded and 31 | the Integrated Gradients,occlusion and layer_gradcam Algorithm for 32 | Captum Explanations is initialized here. 33 | Args: 34 | ctx (context): It is a JSON Object containing information 35 | pertaining to the model artifacts parameters. 36 | """ 37 | self.manifest = ctx.manifest 38 | properties = ctx.system_properties 39 | model_dir = properties.get("model_dir") 40 | print("Model dir is {}".format(model_dir)) 41 | serialized_file = self.manifest["model"]["serializedFile"] 42 | mapping_file_path = os.path.join(model_dir, "index_to_name.json") 43 | if os.path.exists(mapping_file_path): 44 | with open(mapping_file_path) as fp: 45 | self.mapping = json.load(fp) 46 | model_pt_path = os.path.join(model_dir, serialized_file) 47 | self.device = torch.device( 48 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 49 | ) 50 | from cifar10_train import CIFAR10Classifier 51 | 52 | self.model = CIFAR10Classifier() 53 | self.model.load_state_dict(torch.load(model_pt_path)) 54 | self.model.to(self.device) 55 | self.model.eval() 56 | self.model.zero_grad() 57 | logger.info("CIFAR10 model from path %s loaded successfully", model_dir) 58 | 59 | # Read the mapping file, index to object name 60 | mapping_file_path = os.path.join(model_dir, "class_mapping.json") 61 | if os.path.isfile(mapping_file_path): 62 | print("Mapping file present") 63 | with open(mapping_file_path) as pointer: 64 | self.mapping = json.load(pointer) 65 | else: 66 | print("Mapping file missing") 67 | logger.warning("Missing the class_mapping.json file.") 68 | 69 | self.ig = IntegratedGradients(self.model) 70 | self.layer_gradcam = LayerGradCam(self.model, self.model.model_conv.layer4[2].conv3) 71 | self.occlusion = Occlusion(self.model) 72 | self.initialized = True 73 | self.image_processing = transforms.Compose( 74 | [ 75 | transforms.Resize(224), 76 | transforms.CenterCrop(224), 77 | transforms.ToTensor(), 78 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 79 | ] 80 | ) 81 | 82 | def _get_img(self, row): 83 | """Compat layer: normally the envelope should just return the data 84 | directly, but older version of KFServing envelope and 85 | Torchserve in general didn't have things 
set up right 86 | """ 87 | 88 | if isinstance(row, dict): 89 | image = row.get("data") or row.get("body") 90 | else: 91 | image = row 92 | 93 | if isinstance(image, bytearray): 94 | # if the image is a string of bytesarray. 95 | image = base64.b64decode(image) 96 | 97 | return image 98 | 99 | def preprocess(self, data): 100 | """The preprocess function of cifar10 program 101 | converts the input data to a float tensor 102 | Args: 103 | data (List): Input data from the request is in the form of a Tensor 104 | Returns: 105 | list : The preprocess function returns 106 | the input image as a list of float tensors. 107 | """ 108 | images = [] 109 | 110 | for row in data: 111 | image = self._get_img(row) 112 | 113 | # If the image is sent as bytesarray 114 | if isinstance(image, (bytearray, bytes)): 115 | image = Image.open(io.BytesIO(image)) 116 | image = self.image_processing(image) 117 | else: 118 | # if the image is a list 119 | image = torch.FloatTensor(image) 120 | 121 | images.append(image) 122 | 123 | return torch.stack(images).to(self.device) 124 | 125 | def attribute_image_features(self, algorithm, data, **kwargs): 126 | """Calculate tensor attributions""" 127 | self.model.zero_grad() 128 | tensor_attributions = algorithm.attribute(data, target=0, **kwargs) 129 | return tensor_attributions 130 | 131 | def output_bytes(self, fig): 132 | """Convert image to bytes""" 133 | fout = BytesIO() 134 | fig.savefig(fout, format="png") 135 | fout.seek(0) 136 | return fout.getvalue() 137 | 138 | def get_insights(self, tensor_data, _, target=0): 139 | default_cmap = LinearSegmentedColormap.from_list( 140 | "custom blue", 141 | [(0, "#ffffff"), (0.25, "#0000ff"), (1, "#0000ff")], 142 | N=256, 143 | ) 144 | 145 | attributions_ig, _ = self.attribute_image_features( 146 | self.ig, 147 | tensor_data, 148 | baselines=tensor_data * 0, 149 | return_convergence_delta=True, 150 | n_steps=15, 151 | ) 152 | 153 | matplot_viz_ig, _ = viz.visualize_image_attr_multiple( 154 | np.transpose(attributions_ig.squeeze().cpu().detach().numpy(), (1, 2, 0)), 155 | np.transpose(tensor_data.squeeze().cpu().detach().numpy(), (1, 2, 0)), 156 | use_pyplot=False, 157 | methods=["original_image", "heat_map"], 158 | cmap=default_cmap, 159 | show_colorbar=True, 160 | signs=["all", "positive"], 161 | titles=["Original", "Integrated Gradients"], 162 | ) 163 | 164 | ig_bytes = self.output_bytes(matplot_viz_ig) 165 | 166 | output = [ 167 | {"b64": b64encode(row).decode("utf8")} if isinstance(row, (bytes, bytearray)) else row 168 | for row in [ig_bytes] 169 | ] 170 | return output 171 | -------------------------------------------------------------------------------- /examples/Titanic/titanic_handler.py: -------------------------------------------------------------------------------- 1 | from titanic import TitanicSimpleNNModel 2 | import json 3 | import logging 4 | import os 5 | import torch 6 | from ts.torch_handler.base_handler import BaseHandler 7 | from captum.attr import IntegratedGradients 8 | import matplotlib.pyplot as plt 9 | import pandas as pd 10 | import numpy as np 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class TitanicHandler(BaseHandler): 16 | """ 17 | Titanic handler class for titanic classifier . 
18 | """ 19 | 20 | def __init__(self): 21 | super(TitanicHandler, self).__init__() 22 | self.initialized = False 23 | self.feature_names = None 24 | self.inference_output = [] 25 | self.predicted_idx = None 26 | self.out_probs = None 27 | self.delta = None 28 | self.input_file_path = None 29 | 30 | def initialize(self, ctx): 31 | """In this initialize function, the Titanic trained model is loaded and 32 | the Integrated Gradients Algorithm for Captum Explanations 33 | is initialized here. 34 | 35 | Args: 36 | ctx (context): It is a JSON Object containing information 37 | pertaining to the model artifacts parameters. 38 | """ 39 | self.manifest = ctx.manifest 40 | properties = ctx.system_properties 41 | model_dir = properties.get("model_dir") 42 | print("Model dir is {}".format(model_dir)) 43 | serialized_file = self.manifest["model"]["serializedFile"] 44 | model_pt_path = os.path.join(model_dir, serialized_file) 45 | self.device = torch.device( 46 | "cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu" 47 | ) 48 | 49 | self.model = TitanicSimpleNNModel() 50 | self.model.load_state_dict(torch.load(model_pt_path)) 51 | self.model.to(self.device) 52 | self.model.eval() 53 | 54 | logger.info("Titanic model from path %s loaded successfully", model_dir) 55 | 56 | # Read the mapping file, index to object name 57 | mapping_file_path = os.path.join(model_dir, "index_to_name.json") 58 | if os.path.isfile(mapping_file_path): 59 | print("Mapping file present") 60 | with open(mapping_file_path) as f: 61 | self.mapping = json.load(f) 62 | else: 63 | print("Mapping file missing") 64 | logger.warning("Missing the index_to_name.json file.") 65 | 66 | # ------------------------------- Captum initialization ----------------------------# 67 | self.ig = IntegratedGradients(self.model) 68 | self.initialized = True 69 | 70 | def preprocess(self, data): 71 | """Basic text preprocessing, based on the user's chocie of application mode. 72 | 73 | Args: 74 | data (csv): The Input data in the form of csv is passed on to the preprocess 75 | function. 76 | 77 | Returns: 78 | list : The preprocess function returns a list of Tensor and feature names 79 | """ 80 | self.input_file_path = data[0]["input_file_path"] 81 | if isinstance(self.input_file_path, bytearray): 82 | self.input_file_path = self.input_file_path.decode() 83 | data = pd.read_csv(self.input_file_path) 84 | self.feature_names = list(data.columns) 85 | data = data.to_numpy() 86 | data = torch.from_numpy(data).type(torch.FloatTensor) 87 | return data 88 | 89 | def inference(self, data): 90 | """Predict the class (survived or not survived) of the received input json file using the 91 | serialized model. 92 | 93 | Args: 94 | input_batch (list): List of Tensors from the pre-process function is passed here 95 | 96 | Returns: 97 | list : It returns a list of the predicted value for the input test record 98 | """ 99 | data = data.to(self.device) 100 | self.out_probs = self.model(data) 101 | self.predicted_idx = self.out_probs.argmax(1).item() 102 | prediction = self.mapping[str(self.predicted_idx)] 103 | self.inference_output.append(prediction) 104 | logger.info("Model predicted: '%s'", prediction) 105 | return [prediction] 106 | 107 | def postprocess(self, inference_output): 108 | """Post Process Function converts the predicted response into Torchserve readable format. 109 | 110 | Args: 111 | inference_output (list): It contains the predicted response of the input record. 
        Returns:
            (list): Returns a list of the Predictions and Explanations.
        """
        return inference_output

    def get_insights(
        self, input_tensor, title="Average Feature Importances", axis_title="Features"
    ):
        """Applies Integrated Gradients to compute feature importances and saves a bar plot of them.

        Args:
            input_tensor (tensor): Batch of input features to attribute.
            title (str): Title of the saved importance plot.
            axis_title (str): X-axis label of the saved importance plot.

        Returns:
            (list): Returns a dict of feature names and their importances
        """
        ig = IntegratedGradients(self.model)
        input_tensor.requires_grad_()
        input_tensor = input_tensor.to(self.device)
        attr, self.delta = ig.attribute(input_tensor, target=1, return_convergence_delta=True)
        attr = attr.cpu().detach().numpy()
        importances = np.mean(attr, axis=0)
        feature_imp_dict = {}
        for i in range(len(self.feature_names)):
            feature_imp_dict[str(self.feature_names[i])] = importances[i]
        x_pos = np.arange(len(self.feature_names))
        fig, ax = plt.subplots(figsize=(12, 6))
        ax.bar(x_pos, importances, align="center")
        ax.set(title=title, xlabel=axis_title)
        ax.set_xticks(x_pos)
        ax.set_xticklabels(self.feature_names, rotation="vertical")
        path = os.path.join(
            os.path.dirname(os.path.abspath(self.input_file_path)), "attributions_imp.png"
        )
        print("path of the saved image", path)

        plt.savefig(path)
        logger.info("Saved attributions image")
        return [feature_imp_dict]

    def explain_handle(self, data_preprocess, raw_data):
        """Captum explanations handler

        Args:
            data_preprocess (Torch Tensor): Preprocessed data to be used for captum
            raw_data (list): The unprocessed data to get target from the request

        Returns:
            dict : A dictionary response with the explanations response.
160 | """ 161 | output_explain = self.get_insights(data_preprocess) 162 | return output_explain 163 | -------------------------------------------------------------------------------- /examples/cifar10/cifar10_train.py: -------------------------------------------------------------------------------- 1 | """Cifar10 training module.""" 2 | import mlflow.pytorch 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import pytorch_lightning as pl 6 | import torch 7 | import torch.nn.functional as F 8 | from argparse import ArgumentParser 9 | from torchmetrics import Accuracy 10 | from torch import nn 11 | from torchvision import models 12 | 13 | 14 | class CIFAR10Classifier( 15 | pl.LightningModule 16 | ): # pylint: disable=too-many-ancestors,too-many-instance-attributes 17 | """Cifar10 model class.""" 18 | 19 | def __init__(self, **kwargs): 20 | """Initializes the network, optimizer and scheduler.""" 21 | super(CIFAR10Classifier, self).__init__() # pylint: disable=super-with-arguments 22 | self.model_conv = models.resnet50(pretrained=True) 23 | for param in self.model_conv.parameters(): 24 | param.requires_grad = False 25 | num_ftrs = self.model_conv.fc.in_features 26 | num_classes = 10 27 | self.model_conv.fc = nn.Linear(num_ftrs, num_classes) 28 | 29 | self.scheduler = None 30 | self.optimizer = None 31 | self.args = kwargs 32 | 33 | self.train_acc = Accuracy() 34 | self.val_acc = Accuracy() 35 | self.test_acc = Accuracy() 36 | 37 | self.preds = [] 38 | self.target = [] 39 | 40 | def forward(self, x_var): 41 | """Forward function.""" 42 | out = self.model_conv(x_var) 43 | return out 44 | 45 | def training_step(self, train_batch, batch_idx): 46 | """Training Step 47 | Args: 48 | train_batch : training batch 49 | batch_idx : batch id number 50 | Returns: 51 | train accuracy 52 | """ 53 | if batch_idx == 0: 54 | self.reference_image = (train_batch[0][0]).unsqueeze( 55 | 0 56 | ) # pylint: disable=attribute-defined-outside-init 57 | # self.reference_image.resize((1,1,28,28)) 58 | print("\n\nREFERENCE IMAGE!!!") 59 | print(self.reference_image.shape) 60 | x_var, y_var = train_batch 61 | output = self.forward(x_var) 62 | _, y_hat = torch.max(output, dim=1) 63 | loss = F.cross_entropy(output, y_var) 64 | self.log("train_loss", loss) 65 | self.train_acc(y_hat, y_var) 66 | self.log("train_acc", self.train_acc.compute()) 67 | return {"loss": loss} 68 | 69 | def test_step(self, test_batch, batch_idx): 70 | """Testing step 71 | Args: 72 | test_batch : test batch data 73 | batch_idx : tests batch id 74 | Returns: 75 | test accuracy 76 | """ 77 | 78 | x_var, y_var = test_batch 79 | output = self.forward(x_var) 80 | _, y_hat = torch.max(output, dim=1) 81 | loss = F.cross_entropy(output, y_var) 82 | self.log("test_loss", loss, sync_dist=True) 83 | self.test_acc(y_hat, y_var) 84 | self.preds += y_hat.tolist() 85 | self.target += y_var.tolist() 86 | 87 | self.log("test_acc", self.test_acc.compute()) 88 | return {"test_acc": self.test_acc.compute()} 89 | 90 | def validation_step(self, val_batch, batch_idx): 91 | """Testing step. 
92 | Args: 93 | val_batch : val batch data 94 | batch_idx : val batch id 95 | Returns: 96 | validation accuracy 97 | """ 98 | 99 | x_var, y_var = val_batch 100 | output = self.forward(x_var) 101 | _, y_hat = torch.max(output, dim=1) 102 | loss = F.cross_entropy(output, y_var) 103 | self.log("val_loss", loss, sync_dist=True) 104 | self.val_acc(y_hat, y_var) 105 | self.log("val_acc", self.val_acc.compute()) 106 | return {"val_step_loss": loss, "val_loss": loss} 107 | 108 | def configure_optimizers(self): 109 | """Initializes the optimizer and learning rate scheduler. 110 | Returns: 111 | output - Initialized optimizer and scheduler 112 | """ 113 | self.optimizer = torch.optim.Adam( 114 | self.parameters(), 115 | lr=self.args.get("lr", 0.001), 116 | weight_decay=self.args.get("weight_decay", 0), 117 | eps=self.args.get("eps", 1e-8), 118 | ) 119 | self.scheduler = { 120 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 121 | self.optimizer, 122 | mode="min", 123 | factor=0.2, 124 | patience=3, 125 | min_lr=1e-6, 126 | verbose=True, 127 | ), 128 | "monitor": "val_loss", 129 | } 130 | return [self.optimizer], [self.scheduler] 131 | 132 | def makegrid(self, output, numrows): # pylint: disable=no-self-use 133 | """Makes grids. 134 | Args: 135 | output : Tensor output 136 | numrows : num of rows. 137 | Returns: 138 | c_array : gird array 139 | """ 140 | outer = torch.Tensor.cpu(output).detach() 141 | plt.figure(figsize=(20, 5)) 142 | b_array = np.array([]).reshape(0, outer.shape[2]) 143 | c_array = np.array([]).reshape(numrows * outer.shape[2], 0) 144 | i = 0 145 | j = 0 146 | while i < outer.shape[1]: 147 | img = outer[0][i] 148 | b_array = np.concatenate((img, b_array), axis=0) 149 | j += 1 150 | if j == numrows: 151 | c_array = np.concatenate((c_array, b_array), axis=1) 152 | b_array = np.array([]).reshape(0, outer.shape[2]) 153 | j = 0 154 | 155 | i += 1 156 | return c_array 157 | 158 | def show_activations(self, x_var): 159 | """Showns activation 160 | Args: 161 | x_var: x variable 162 | """ 163 | 164 | # logging reference image 165 | self.logger.experiment.add_image( 166 | "input", torch.Tensor.cpu(x_var[0][0]), self.current_epoch, dataformats="HW" 167 | ) 168 | 169 | # logging layer 1 activations 170 | out = self.model_conv.conv1(x_var) 171 | c_grid = self.makegrid(out, 4) 172 | self.logger.experiment.add_image("layer 1", c_grid, self.current_epoch, dataformats="HW") 173 | 174 | def training_epoch_end(self, outputs): 175 | """Training epoch end. 
176 | Args: 177 | outputs: outputs of train end 178 | """ 179 | self.show_activations(self.reference_image) 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = ArgumentParser(description="PyTorch Cifar10 Example") 184 | parser.add_argument( 185 | "--download_path", 186 | type=str, 187 | default="output/processing", 188 | help="Path to write cifar10 dataset", 189 | ) 190 | parser = pl.Trainer.add_argparse_args(parent_parser=parser) 191 | 192 | from cifar10_datamodule import CIFAR10DataModule 193 | 194 | parser = CIFAR10DataModule.add_model_specific_args(parent_parser=parser) 195 | 196 | args = parser.parse_args() 197 | dict_args = vars(args) 198 | 199 | for argument in ["strategy", "accelerator", "devices"]: 200 | if dict_args[argument] == "None": 201 | dict_args[argument] = None 202 | 203 | mlflow.pytorch.autolog() 204 | 205 | model = CIFAR10Classifier(**dict_args) 206 | 207 | dm = CIFAR10DataModule(**dict_args) 208 | dm.setup(stage="fit") 209 | 210 | trainer = pl.Trainer.from_argparse_args(args) 211 | 212 | trainer.fit(model, dm) 213 | trainer.test(datamodule=dm) 214 | 215 | torch.save(trainer.lightning_module.state_dict(), "resnet.pth") 216 | -------------------------------------------------------------------------------- /examples/IrisClassification/iris_classification.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=W0221 2 | # pylint: disable=W0613 3 | # pylint: disable=W0223 4 | import argparse 5 | from argparse import ArgumentParser 6 | 7 | import mlflow 8 | import pytorch_lightning as pl 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from mlflow.models.signature import ModelSignature 13 | from mlflow.types.schema import Schema, ColSpec 14 | from pytorch_lightning import seed_everything 15 | from torchmetrics import Accuracy 16 | from sklearn.datasets import load_iris 17 | from torch.utils.data import DataLoader, random_split, TensorDataset 18 | 19 | 20 | class IrisClassification(pl.LightningModule): 21 | def __init__(self, **kwargs): 22 | super(IrisClassification, self).__init__() 23 | 24 | self.train_acc = Accuracy(task="multiclass", num_classes=3) 25 | self.val_acc = Accuracy(task="multiclass", num_classes=3) 26 | self.test_acc = Accuracy(task="multiclass", num_classes=3) 27 | self.args = kwargs 28 | 29 | self.fc1 = nn.Linear(4, 10) 30 | self.fc2 = nn.Linear(10, 10) 31 | self.fc3 = nn.Linear(10, 3) 32 | self.cross_entropy_loss = nn.CrossEntropyLoss() 33 | 34 | def forward(self, x): 35 | x = F.relu(self.fc1(x)) 36 | x = F.relu(self.fc2(x)) 37 | x = F.relu(self.fc3(x)) 38 | return x 39 | 40 | @staticmethod 41 | def add_model_specific_args(parent_parser): 42 | """ 43 | Add model specific arguments like learning rate 44 | 45 | :param parent_parser: Application specific parser 46 | 47 | :return: Returns the augmented arugument parser 48 | """ 49 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 50 | parser.add_argument( 51 | "--lr", 52 | type=float, 53 | default=0.01, 54 | metavar="LR", 55 | help="learning rate (default: 0.001)", 56 | ) 57 | return parser 58 | 59 | def configure_optimizers(self): 60 | return torch.optim.Adam(self.parameters(), self.args["lr"]) 61 | 62 | def training_step(self, batch, batch_idx): 63 | x, y = batch 64 | logits = self.forward(x) 65 | _, y_hat = torch.max(logits, dim=1) 66 | loss = self.cross_entropy_loss(logits, y) 67 | self.train_acc(y_hat, y) 68 | self.log( 69 | "train_acc", 70 | self.train_acc.compute(), 71 | on_step=False, 72 | 
on_epoch=True, 73 | ) 74 | self.log("train_loss", loss) 75 | return {"loss": loss} 76 | 77 | def validation_step(self, batch, batch_idx): 78 | x, y = batch 79 | logits = self.forward(x) 80 | _, y_hat = torch.max(logits, dim=1) 81 | loss = F.cross_entropy(logits, y) 82 | self.val_acc(y_hat, y) 83 | self.log("val_acc", self.val_acc.compute()) 84 | self.log("val_loss", loss, sync_dist=True) 85 | 86 | def test_step(self, batch, batch_idx): 87 | x, y = batch 88 | logits = self.forward(x) 89 | _, y_hat = torch.max(logits, dim=1) 90 | self.test_acc(y_hat, y) 91 | self.log("test_acc", self.test_acc.compute()) 92 | 93 | 94 | class IrisDataModule(pl.LightningDataModule): 95 | def __init__(self, **kwargs): 96 | """ 97 | Initialization of inherited lightning data module 98 | """ 99 | super(IrisDataModule, self).__init__() 100 | 101 | self.train_set = None 102 | self.val_set = None 103 | self.test_set = None 104 | self.args = kwargs 105 | 106 | def prepare_data(self): 107 | """ 108 | Implementation of abstract class 109 | """ 110 | 111 | def setup(self, stage=None): 112 | """ 113 | Downloads the data, parse it and split the data into train, test, validation data 114 | 115 | :param stage: Stage - training or testing 116 | """ 117 | iris = load_iris() 118 | df = iris.data 119 | target = iris["target"] 120 | 121 | data = torch.Tensor(df).float() 122 | labels = torch.Tensor(target).long() 123 | RANDOM_SEED = 42 124 | seed_everything(RANDOM_SEED) 125 | 126 | data_set = TensorDataset(data, labels) 127 | self.train_set, self.val_set = random_split(data_set, [130, 20]) 128 | self.train_set, self.test_set = random_split(self.train_set, [110, 20]) 129 | 130 | @staticmethod 131 | def add_model_specific_args(parent_parser): 132 | """ 133 | Adds model specific arguments batch size and num workers 134 | 135 | :param parent_parser: Application specific parser 136 | 137 | :return: Returns the augmented arugument parser 138 | """ 139 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 140 | parser.add_argument( 141 | "--batch-size", 142 | type=int, 143 | default=128, 144 | metavar="N", 145 | help="input batch size for training (default: 16)", 146 | ) 147 | parser.add_argument( 148 | "--num-workers", 149 | type=int, 150 | default=3, 151 | metavar="N", 152 | help="number of workers (default: 3)", 153 | ) 154 | return parser 155 | 156 | def create_data_loader(self, dataset): 157 | """ 158 | Generic data loader function 159 | 160 | :param data_set: Input data set 161 | 162 | :return: Returns the constructed dataloader 163 | """ 164 | 165 | return DataLoader( 166 | dataset, batch_size=self.args["batch_size"], num_workers=self.args["num_workers"] 167 | ) 168 | 169 | def train_dataloader(self): 170 | train_loader = self.create_data_loader(dataset=self.train_set) 171 | return train_loader 172 | 173 | def val_dataloader(self): 174 | validation_loader = self.create_data_loader(dataset=self.val_set) 175 | return validation_loader 176 | 177 | def test_dataloader(self): 178 | test_loader = self.create_data_loader(dataset=self.test_set) 179 | return test_loader 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser(description="Iris Classification model") 184 | 185 | parser.add_argument( 186 | "--save-model", 187 | type=bool, 188 | default=True, 189 | help="For Saving the current Model", 190 | ) 191 | 192 | parser = pl.Trainer.add_argparse_args(parent_parser=parser) 193 | parser = IrisClassification.add_model_specific_args(parent_parser=parser) 194 | parser = 
IrisDataModule.add_model_specific_args(parent_parser=parser) 195 | 196 | args = parser.parse_args() 197 | dict_args = vars(args) 198 | 199 | for argument in ["strategy", "accelerator", "devices"]: 200 | if dict_args[argument] == "None": 201 | dict_args[argument] = None 202 | 203 | dm = IrisDataModule(**dict_args) 204 | dm.prepare_data() 205 | dm.setup(stage="fit") 206 | 207 | model = IrisClassification(**dict_args) 208 | trainer = pl.Trainer.from_argparse_args(args) 209 | trainer.fit(model, dm) 210 | trainer.test(datamodule=dm) 211 | 212 | if trainer.global_rank == 0: 213 | input_schema = Schema( 214 | [ 215 | ColSpec("double", "sepal length (cm)"), 216 | ColSpec("double", "sepal width (cm)"), 217 | ColSpec("double", "petal length (cm)"), 218 | ColSpec("double", "petal width (cm)"), 219 | ] 220 | ) 221 | output_schema = Schema([ColSpec("long")]) 222 | signature = ModelSignature(inputs=input_schema, outputs=output_schema) 223 | mlflow.pytorch.save_model(trainer.lightning_module, "model", signature=signature) 224 | -------------------------------------------------------------------------------- /examples/E2EBert/news_classifier_handler.py: -------------------------------------------------------------------------------- 1 | from captum.attr import IntegratedGradients 2 | import json 3 | import logging 4 | import os 5 | import numpy as np 6 | import torch 7 | from transformers import BertTokenizer 8 | from ts.torch_handler.base_handler import BaseHandler 9 | from captum.attr import visualization 10 | import torch.nn.functional as F 11 | from news_classifier import BertNewsClassifier 12 | from wrapper import AGNewsmodelWrapper 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class NewsClassifierHandler(BaseHandler): 18 | """ 19 | NewsClassifierHandler class. 
This handler takes a review / sentence 20 | and returns the label as either world / sports / business /sci-tech 21 | """ 22 | 23 | def __init__(self): 24 | self.model = None 25 | self.mapping = None 26 | self.device = None 27 | self.initialized = False 28 | self.class_mapping_file = None 29 | self.VOCAB_FILE = None 30 | 31 | def initialize(self, ctx): 32 | """ 33 | First try to load torchscript else load eager mode state_dict based model 34 | 35 | :param ctx: System properties 36 | """ 37 | 38 | properties = ctx.system_properties 39 | self.device = torch.device("cuda") if torch.cuda.is_available() else "cpu" 40 | model_dir = properties.get("model_dir") 41 | 42 | # Read model serialize/pt file 43 | model_pt_path = os.path.join(model_dir, "state_dict.pth") 44 | # Read model definition file 45 | model_def_path = os.path.join(model_dir, "news_classifier.py") 46 | if not os.path.isfile(model_def_path): 47 | raise RuntimeError("Missing the model definition file") 48 | self.VOCAB_FILE = os.path.join(model_dir, "bert_base_uncased_vocab.txt") 49 | if not os.path.isfile(self.VOCAB_FILE): 50 | raise RuntimeError("Missing the vocab file") 51 | 52 | self.class_mapping_file = os.path.join(model_dir, "class_mapping.json") 53 | 54 | state_dict = torch.load(model_pt_path, map_location=self.device) 55 | self.model = BertNewsClassifier() 56 | self.model.load_state_dict(state_dict) 57 | self.model.to(self.device) 58 | self.model.eval() 59 | 60 | logger.debug("Model file %s loaded successfully", model_pt_path) 61 | self.initialized = True 62 | 63 | def preprocess(self, data): 64 | """ 65 | Receives text in form of json and converts it into an encoding for the inference stage 66 | 67 | :param data: Input to be passed through the layers for prediction 68 | 69 | :return: output - preprocessed encoding 70 | """ 71 | 72 | text = data[0].get("data") 73 | if text is None: 74 | text = data[0].get("body") 75 | 76 | self.text = text.decode("utf-8") 77 | 78 | self.tokenizer = BertTokenizer(self.VOCAB_FILE) 79 | self.input_ids = torch.tensor( 80 | [self.tokenizer.encode(self.text, add_special_tokens=True)] 81 | ).to(self.device) 82 | return self.input_ids 83 | 84 | def inference(self, input_ids): 85 | """ 86 | Predict the class for a review / sentence whether 87 | it is belong to world / sports / business /sci-tech 88 | :param encoding: Input encoding to be passed through the layers for prediction 89 | 90 | :return: output - predicted output 91 | """ 92 | inputs = self.input_ids.to(self.device) 93 | self.outputs = self.model.forward(inputs) 94 | self.out = np.argmax(self.outputs.cpu().detach()) 95 | return [self.out.item()] 96 | 97 | def postprocess(self, inference_output): 98 | """ 99 | Does postprocess after inference to be returned to user 100 | 101 | :param inference_output: Output of inference 102 | 103 | :return: output - Output after post processing 104 | """ 105 | if os.path.exists(self.class_mapping_file): 106 | with open(self.class_mapping_file) as json_file: 107 | data = json.load(json_file) 108 | inference_output = json.dumps(data[str(inference_output[0])]) 109 | return [inference_output] 110 | 111 | return inference_output 112 | 113 | def add_attributions_to_visualizer( 114 | self, 115 | attributions, 116 | tokens, 117 | pred_prob, 118 | pred_class, 119 | true_class, 120 | attr_class, 121 | delta, 122 | vis_data_records, 123 | ): 124 | attributions = attributions.sum(dim=2).squeeze(0) 125 | attributions = attributions / torch.norm(attributions) 126 | attributions = attributions.cpu().detach().numpy() 127 | 
128 | # storing couple samples in an array for visualization purposes 129 | vis_data_records.append( 130 | visualization.VisualizationDataRecord( 131 | attributions, 132 | pred_prob, 133 | pred_class, 134 | true_class, 135 | attr_class, 136 | attributions.sum(), 137 | tokens, 138 | delta, 139 | ) 140 | ) 141 | 142 | def score_func(self, o): 143 | output = F.softmax(o, dim=1) 144 | pre_pro = np.argmax(output.cpu().detach()) 145 | return pre_pro 146 | 147 | def summarize_attributions(self, attributions): 148 | """Summarises the attribution across multiple runs 149 | Args: 150 | attributions ([list): attributions from the Integrated Gradients 151 | Returns: 152 | list : Returns the attributions after normalizing them. 153 | """ 154 | attributions = attributions.sum(dim=-1).squeeze(0) 155 | attributions = attributions / torch.norm(attributions) 156 | return attributions 157 | 158 | def explain_handle(self, model_wraper, text, target=1): 159 | """Captum explanations handler 160 | Args: 161 | data_preprocess (Torch Tensor): 162 | Preprocessed data to be used for captum 163 | raw_data (list): The unprocessed data to get target from the request 164 | Returns: 165 | dict : A dictionary response with the explanations response. 166 | """ 167 | vis_data_records_base = [] 168 | model_wrapper = AGNewsmodelWrapper(self.model) 169 | tokenizer = BertTokenizer(self.VOCAB_FILE) 170 | model_wrapper.eval() 171 | model_wrapper.zero_grad() 172 | encoding = tokenizer.encode_plus( 173 | self.text, return_attention_mask=True, return_tensors="pt", add_special_tokens=False 174 | ) 175 | input_ids = encoding["input_ids"] 176 | attention_mask = encoding["attention_mask"] 177 | input_ids = input_ids.to(self.device) 178 | attention_mask = attention_mask.to(self.device) 179 | input_embedding_test = model_wrapper.model.bert_model.embeddings(input_ids) 180 | preds = model_wrapper(input_embedding_test, attention_mask) 181 | out = np.argmax(preds.cpu().detach(), axis=1) 182 | out = out.item() 183 | ig_1 = IntegratedGradients(model_wrapper) 184 | attributions, delta = ig_1.attribute( # pylint: disable=no-member 185 | input_embedding_test, 186 | n_steps=500, 187 | return_convergence_delta=True, 188 | target=1, 189 | ) 190 | tokens = tokenizer.convert_ids_to_tokens(input_ids[0].cpu().numpy().tolist()) 191 | feature_imp_dict = {} 192 | feature_imp_dict["words"] = tokens 193 | attributions_sum = self.summarize_attributions(attributions) 194 | feature_imp_dict["importances"] = attributions_sum.tolist() 195 | feature_imp_dict["delta"] = delta[0].tolist() 196 | self.add_attributions_to_visualizer( 197 | attributions, tokens, self.score_func(preds), out, 2, 1, delta, vis_data_records_base 198 | ) 199 | return [feature_imp_dict] 200 | -------------------------------------------------------------------------------- /tests/test_plugin.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytest 5 | import torch 6 | from mlflow import deployments 7 | from mlflow.exceptions import MlflowException 8 | 9 | f_target = "torchserve" 10 | f_deployment_id = "test" 11 | f_deployment_name_version = "test/2.0" 12 | f_deployment_name_all = "test/all" 13 | f_flavor = None 14 | f_model_uri = os.path.join("tests/resources", "linear_state_dict.pt") 15 | 16 | model_version = "1.0" 17 | model_file_path = os.path.join("tests/resources", "linear_model.py") 18 | handler_file_path = os.path.join("tests/resources", "linear_handler.py") 19 | sample_input_file = 
os.path.join("tests/resources", "sample.json") 20 | sample_output_file = os.path.join("tests/resources", "output.json") 21 | 22 | 23 | @pytest.mark.usefixtures("start_torchserve") 24 | def test_create_deployment_success(): 25 | client = deployments.get_deploy_client(f_target) 26 | ret = client.create_deployment( 27 | f_deployment_id, 28 | f_model_uri, 29 | f_flavor, 30 | config={ 31 | "VERSION": model_version, 32 | "MODEL_FILE": model_file_path, 33 | "HANDLER": handler_file_path, 34 | }, 35 | ) 36 | assert isinstance(ret, dict) 37 | assert ret["name"] == f_deployment_id + "/" + model_version 38 | assert ret["flavor"] == f_flavor 39 | 40 | 41 | def test_create_deployment_no_version(): 42 | client = deployments.get_deploy_client(f_target) 43 | ret = client.create_deployment( 44 | f_deployment_id, 45 | f_model_uri, 46 | f_flavor, 47 | config={"MODEL_FILE": model_file_path, "HANDLER": handler_file_path}, 48 | ) 49 | assert isinstance(ret, dict) 50 | assert ret["name"] == f_deployment_name_version 51 | assert ret["flavor"] == f_flavor 52 | 53 | 54 | def test_list_success(): 55 | client = deployments.get_deploy_client(f_target) 56 | ret = client.list_deployments() 57 | isNamePresent = False 58 | for i in range(len(ret)): 59 | if list(ret[i].keys())[0] == f_deployment_id: 60 | isNamePresent = True 61 | break 62 | if isNamePresent: 63 | assert True 64 | else: 65 | assert False 66 | 67 | 68 | @pytest.mark.parametrize( 69 | "deployment_name", [f_deployment_id, f_deployment_name_version, f_deployment_name_all] 70 | ) 71 | def test_get_success(deployment_name): 72 | client = deployments.get_deploy_client(f_target) 73 | ret = client.get_deployment(deployment_name) 74 | print("Return value is ", json.loads(ret["deploy"])) 75 | if deployment_name == f_deployment_id: 76 | assert json.loads(ret["deploy"])[0]["modelName"] == f_deployment_id 77 | elif deployment_name == f_deployment_name_version: 78 | assert ( 79 | json.loads(ret["deploy"])[0]["modelVersion"] == f_deployment_name_version.split("/")[1] 80 | ) 81 | else: 82 | assert len(json.loads(ret["deploy"])) == 2 83 | 84 | 85 | def test_wrong_target_name(): 86 | with pytest.raises(MlflowException): 87 | deployments.get_deploy_client("wrong_target") 88 | 89 | 90 | @pytest.mark.parametrize( 91 | "deployment_name, config", 92 | [(f_deployment_name_version, {"SET-DEFAULT": "true"}), (f_deployment_id, {"MIN_WORKER": 3})], 93 | ) 94 | def test_update_deployment_success(deployment_name, config): 95 | client = deployments.get_deploy_client(f_target) 96 | ret = client.update_deployment(deployment_name, config) 97 | assert ret["flavor"] is None 98 | 99 | 100 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 101 | def test_predict_success(deployment_name): 102 | client = deployments.get_deploy_client(f_target) 103 | with open(sample_input_file) as fp: 104 | data = fp.read() 105 | pred = client.predict(deployment_name, data) 106 | assert pred is not None 107 | 108 | 109 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 110 | def test_predict_tensor_input(deployment_name): 111 | client = deployments.get_deploy_client(f_target) 112 | data = torch.Tensor([5000]) 113 | pred = client.predict(deployment_name, data) 114 | assert pred is not None 115 | 116 | 117 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 118 | def test_delete_success(deployment_name): 119 | client = deployments.get_deploy_client(f_target) 120 | assert 
client.delete_deployment(deployment_name) is None 121 | 122 | 123 | f_dummy = "dummy" 124 | 125 | 126 | def test_create_no_handler_exception(): 127 | client = deployments.get_deploy_client(f_target) 128 | with pytest.raises(Exception, match="Config Variable HANDLER - missing"): 129 | client.create_deployment( 130 | f_deployment_id, 131 | f_model_uri, 132 | f_flavor, 133 | config={"VERSION": model_version, "MODEL_FILE": model_file_path}, 134 | ) 135 | 136 | 137 | def test_create_wrong_handler_exception(): 138 | client = deployments.get_deploy_client(f_target) 139 | with pytest.raises(Exception, match="No such file or directory"): 140 | client.create_deployment( 141 | f_deployment_id, 142 | f_model_uri, 143 | f_flavor, 144 | config={"VERSION": model_version, "MODEL_FILE": model_file_path, "HANDLER": f_dummy}, 145 | ) 146 | 147 | 148 | def test_create_wrong_model_exception(): 149 | client = deployments.get_deploy_client(f_target) 150 | with pytest.raises(Exception, match="No such file or directory"): 151 | client.create_deployment( 152 | f_deployment_id, 153 | f_model_uri, 154 | f_flavor, 155 | config={"VERSION": model_version, "MODEL_FILE": f_dummy, "HANDLER": handler_file_path}, 156 | ) 157 | 158 | 159 | def test_create_mar_file_exception(): 160 | client = deployments.get_deploy_client(f_target) 161 | with pytest.raises(Exception, match="No such file or directory"): 162 | client.create_deployment( 163 | f_deployment_id, 164 | f_dummy, 165 | config={ 166 | "VERSION": model_version, 167 | "MODEL_FILE": model_file_path, 168 | "HANDLER": handler_file_path, 169 | }, 170 | ) 171 | 172 | 173 | def test_update_invalid_name(): 174 | client = deployments.get_deploy_client(f_target) 175 | with pytest.raises(Exception, match="Unable to update deployment with name %s" % f_dummy): 176 | client.update_deployment(f_dummy) 177 | 178 | 179 | def test_get_invalid_name(): 180 | client = deployments.get_deploy_client(f_target) 181 | with pytest.raises(Exception, match="Unable to get deployments with name %s" % f_dummy): 182 | client.get_deployment(f_dummy) 183 | 184 | 185 | def test_delete_invalid_name(): 186 | client = deployments.get_deploy_client(f_target) 187 | with pytest.raises(Exception, match="Unable to delete deployment for name %s" % f_dummy): 188 | client.delete_deployment(f_dummy) 189 | 190 | 191 | def test_predict_exception(): 192 | client = deployments.get_deploy_client(f_target) 193 | with pytest.raises(Exception, match="Unable to parse input json string"): 194 | client.predict(f_dummy, "sample") 195 | 196 | 197 | def test_explain_exception(): 198 | client = deployments.get_deploy_client(f_target) 199 | with pytest.raises(Exception, match="Unable to parse input json string"): 200 | client.explain(f_dummy, "sample") 201 | 202 | 203 | def test_explain_name_exception(): 204 | with open(sample_input_file) as fp: 205 | data = fp.read() 206 | client = deployments.get_deploy_client(f_target) 207 | with pytest.raises(Exception, match="Unable to infer the results for the name %s" % f_dummy): 208 | client.explain(f_dummy, data) 209 | 210 | 211 | def test_predict_name_exception(): 212 | with open(sample_input_file) as fp: 213 | data = fp.read() 214 | client = deployments.get_deploy_client(f_target) 215 | with pytest.raises(Exception, match="Unable to infer the results for the name %s" % f_dummy): 216 | client.predict(f_dummy, data) 217 | -------------------------------------------------------------------------------- /examples/cifar10/cifar10_datamodule.py: 
-------------------------------------------------------------------------------- 1 | """Cifar10 data module.""" 2 | import os 3 | import subprocess 4 | from argparse import ArgumentParser 5 | from pathlib import Path 6 | 7 | import pytorch_lightning as pl 8 | import torchvision 9 | import webdataset as wds 10 | from sklearn.model_selection import train_test_split 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms 13 | 14 | 15 | class CIFAR10DataModule(pl.LightningDataModule): # pylint: disable=too-many-instance-attributes 16 | """Data module class.""" 17 | 18 | def __init__(self, **kwargs): 19 | """Initialization of inherited lightning data module.""" 20 | super(CIFAR10DataModule, self).__init__() # pylint: disable=super-with-arguments 21 | 22 | self.train_dataset = None 23 | self.valid_dataset = None 24 | self.test_dataset = None 25 | self.train_data_loader = None 26 | self.val_data_loader = None 27 | self.test_data_loader = None 28 | self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 29 | self.valid_transform = transforms.Compose( 30 | [ 31 | transforms.ToTensor(), 32 | self.normalize, 33 | ] 34 | ) 35 | 36 | self.train_transform = transforms.Compose( 37 | [ 38 | transforms.RandomResizedCrop(32), 39 | transforms.RandomHorizontalFlip(), 40 | transforms.ToTensor(), 41 | self.normalize, 42 | ] 43 | ) 44 | self.args = kwargs 45 | 46 | def prepare_data(self): 47 | """Implementation of abstract class.""" 48 | output_path = self.args.get("download_path", "output/processing") 49 | Path(output_path).mkdir(parents=True, exist_ok=True) 50 | 51 | trainset = torchvision.datasets.CIFAR10(root="./", train=True, download=True) 52 | testset = torchvision.datasets.CIFAR10(root="./", train=False, download=True) 53 | 54 | Path(output_path + "/train").mkdir(parents=True, exist_ok=True) 55 | Path(output_path + "/val").mkdir(parents=True, exist_ok=True) 56 | Path(output_path + "/test").mkdir(parents=True, exist_ok=True) 57 | 58 | RANDOM_SEED = 25 59 | y = trainset.targets 60 | trainset, valset, y_train, y_val = train_test_split( 61 | trainset, y, stratify=y, shuffle=True, test_size=0.2, random_state=RANDOM_SEED 62 | ) 63 | 64 | for dataset, split in [(trainset, "train"), (valset, "val"), (testset, "test")]: 65 | with wds.ShardWriter( 66 | output_path + "/" + split + "/" + split + "-%d.tar", maxcount=1000 67 | ) as sink: 68 | for index, (image, cls) in enumerate(dataset): 69 | sink.write({"__key__": "%06d" % index, "ppm": image, "cls": cls}) 70 | 71 | entry_point = ["ls", "-R", output_path] 72 | run_code = subprocess.run( 73 | entry_point, stdout=subprocess.PIPE 74 | ) # pylint: disable=subprocess-run-check 75 | print(run_code.stdout) 76 | 77 | @staticmethod 78 | def get_num_files(input_path): 79 | """Gets num files.
80 | Args: 81 | input_path : path to input 82 | """ 83 | return len(os.listdir(input_path)) - 1 84 | 85 | @staticmethod 86 | def add_model_specific_args(parent_parser): 87 | """ 88 | Adds the data module specific command line arguments to the parser 89 | 90 | :param parent_parser: Application specific parser 91 | 92 | :return: Returns the augmented argument parser 93 | """ 94 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 95 | parser.add_argument( 96 | "--num_samples_train", 97 | type=int, 98 | help="Number of samples for training (max: 39)", 99 | ) 100 | 101 | parser.add_argument( 102 | "--num_samples_val", 103 | type=int, 104 | help="Number of samples for Validation (max: 9)", 105 | ) 106 | 107 | parser.add_argument( 108 | "--num_samples_test", 109 | type=int, 110 | help="Number of samples for Testing (max: 9)", 111 | ) 112 | 113 | return parser 114 | 115 | def setup(self, stage=None): 116 | """Parses the prepared data and splits it into train, test and 117 | validation datasets. 118 | Args: 119 | stage: Stage - training or testing 120 | """ 121 | 122 | data_path = self.args.get("train_glob", "output/processing") 123 | 124 | train_base_url = data_path + "/train" 125 | val_base_url = data_path + "/val" 126 | test_base_url = data_path + "/test" 127 | 128 | train_count = self.args["num_samples_train"] 129 | val_count = self.args["num_samples_val"] 130 | test_count = self.args["num_samples_test"] 131 | 132 | if not train_count: 133 | train_count = self.get_num_files(train_base_url) 134 | 135 | if not val_count: 136 | val_count = self.get_num_files(val_base_url) 137 | 138 | if not test_count: 139 | test_count = self.get_num_files(test_base_url) 140 | 141 | train_url = "{}/{}-{}".format(train_base_url, "train", "{0.." + str(train_count) + "}.tar") 142 | valid_url = "{}/{}-{}".format(val_base_url, "val", "{0.." + str(val_count) + "}.tar") 143 | test_url = "{}/{}-{}".format(test_base_url, "test", "{0.." + str(test_count) + "}.tar") 144 | 145 | self.train_dataset = ( 146 | wds.WebDataset( 147 | train_url, handler=wds.warn_and_continue, nodesplitter=wds.shardlists.split_by_node 148 | ) 149 | .shuffle(100) 150 | .decode("pil") 151 | .rename(image="ppm;jpg;jpeg;png", info="cls") 152 | .map_dict(image=self.train_transform) 153 | .to_tuple("image", "info") 154 | .batched(40) 155 | ) 156 | 157 | self.valid_dataset = ( 158 | wds.WebDataset( 159 | valid_url, handler=wds.warn_and_continue, nodesplitter=wds.shardlists.split_by_node 160 | ) 161 | .shuffle(100) 162 | .decode("pil") 163 | .rename(image="ppm", info="cls") 164 | .map_dict(image=self.valid_transform) 165 | .to_tuple("image", "info") 166 | .batched(20) 167 | ) 168 | 169 | self.test_dataset = ( 170 | wds.WebDataset( 171 | test_url, handler=wds.warn_and_continue, nodesplitter=wds.shardlists.split_by_node 172 | ) 173 | .shuffle(100) 174 | .decode("pil") 175 | .rename(image="ppm", info="cls") 176 | .map_dict(image=self.valid_transform) 177 | .to_tuple("image", "info") 178 | .batched(20) 179 | ) 180 | 181 | def create_data_loader(self, dataset, batch_size, num_workers): # pylint: disable=no-self-use 182 | """Creates data loader.""" 183 | return DataLoader(dataset, batch_size=batch_size, num_workers=num_workers) 184 | 185 | def train_dataloader(self): 186 | """Train Data loader.
187 | Returns: 188 | output - Train data loader for the given input 189 | """ 190 | self.train_data_loader = self.create_data_loader( 191 | self.train_dataset, 192 | self.args.get("train_batch_size", None), 193 | self.args.get("train_num_workers", 4), 194 | ) 195 | return self.train_data_loader 196 | 197 | def val_dataloader(self): 198 | """Validation Data Loader. 199 | Returns: 200 | output - Validation data loader for the given input 201 | """ 202 | self.val_data_loader = self.create_data_loader( 203 | self.valid_dataset, 204 | self.args.get("val_batch_size", None), 205 | self.args.get("val_num_workers", 4), 206 | ) 207 | return self.val_data_loader 208 | 209 | def test_dataloader(self): 210 | """Test Data Loader. 211 | Returns: 212 | output - Test data loader for the given input 213 | """ 214 | # No separate test_* arguments are defined, so the test loader 215 | # deliberately reuses the validation batch size and worker count. 216 | self.test_data_loader = self.create_data_loader( 217 | self.test_dataset, 218 | self.args.get("val_batch_size", None), 219 | self.args.get("val_num_workers", 4), 220 | ) 221 | return self.test_data_loader 222 | -------------------------------------------------------------------------------- /examples/MNIST/mnist_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # Trains an MNIST digit recognizer using PyTorch Lightning 3 | # NOTE: This example requires you to first install 4 | # pytorch-lightning (using pip install pytorch-lightning) 5 | # and mlflow (using pip install mlflow). 6 | # 7 | # pylint: disable=arguments-differ 8 | # pylint: disable=unused-argument 9 | # pylint: disable=abstract-method 10 | 11 | from argparse import ArgumentParser 12 | 13 | import mlflow.pytorch 14 | import pytorch_lightning as pl 15 | import torch 16 | from torch.nn.parallel import ( 17 | DistributedDataParallel, 18 | DataParallel, 19 | ) 20 | from pytorch_lightning import seed_everything 21 | from torchmetrics import Accuracy 22 | from torch.nn import functional as F 23 | from torch.utils.data import DataLoader, random_split 24 | from torchvision import datasets, transforms 25 | 26 | 27 | class MNISTDataModule(pl.LightningDataModule): 28 | def __init__(self, **kwargs): 29 | """ 30 | Initialization of inherited lightning data module 31 | """ 32 | super(MNISTDataModule, self).__init__() 33 | self.df_train = None 34 | self.df_val = None 35 | self.df_test = None 36 | self.train_data_loader = None 37 | self.val_data_loader = None 38 | self.test_data_loader = None 39 | self.args = kwargs 40 | 41 | # transforms for images 42 | self.transform = transforms.Compose( 43 | [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] 44 | ) 45 | 46 | def setup(self, stage=None): 47 | """ 48 | Downloads the data, parses it and splits it into train, test and validation data 49 | 50 | :param stage: Stage - training or testing 51 | """ 52 | 53 | RANDOM_SEED = 42 54 | seed_everything(RANDOM_SEED) 55 | 56 | self.df_train = datasets.MNIST( 57 | "dataset", download=True, train=True, transform=self.transform 58 | ) 59 | self.df_train, self.df_val = random_split(self.df_train, [55000, 5000]) 60 | self.df_test = datasets.MNIST( 61 | "dataset", download=True, train=False, transform=self.transform 62 | ) 63 | 64 | @staticmethod 65 | def add_model_specific_args(parent_parser): 66 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 67 | parser.add_argument( 68 | "--batch-size", 69 | type=int, 70 | default=64, 71 | metavar="N", 72 | help="input batch size for training (default: 64)", 73 | ) 74 | parser.add_argument( 75 | "--num-workers", 76 | type=int, 77 | default=3, 78 |
metavar="N", 79 | help="number of workers (default: 3)", 80 | ) 81 | return parser 82 | 83 | def create_data_loader(self, df): 84 | """ 85 | Generic data loader function 86 | 87 | :param df: Input tensor 88 | 89 | :return: Returns the constructed dataloader 90 | """ 91 | return DataLoader( 92 | df, 93 | batch_size=self.args["batch_size"], 94 | num_workers=self.args["num_workers"], 95 | ) 96 | 97 | def train_dataloader(self): 98 | """ 99 | :return: output - Train data loader for the given input 100 | """ 101 | return self.create_data_loader(self.df_train) 102 | 103 | def val_dataloader(self): 104 | """ 105 | :return: output - Validation data loader for the given input 106 | """ 107 | return self.create_data_loader(self.df_val) 108 | 109 | def test_dataloader(self): 110 | """ 111 | :return: output - Test data loader for the given input 112 | """ 113 | return self.create_data_loader(self.df_test) 114 | 115 | 116 | class LightningMNISTClassifier(pl.LightningModule): 117 | def __init__(self, **kwargs): 118 | """mlflow.start_run() 119 | Initializes the network 120 | """ 121 | super(LightningMNISTClassifier, self).__init__() 122 | 123 | self.train_acc = Accuracy(task="multiclass", num_classes=10) 124 | self.val_acc = Accuracy(task="multiclass", num_classes=10) 125 | self.test_acc = Accuracy(task="multiclass", num_classes=10) 126 | 127 | # mnist images are (1, 28, 28) (channels, width, height) 128 | self.layer_1 = torch.nn.Linear(28 * 28, 128) 129 | self.layer_2 = torch.nn.Linear(128, 256) 130 | self.layer_3 = torch.nn.Linear(256, 10) 131 | self.args = kwargs 132 | 133 | @staticmethod 134 | def add_model_specific_args(parent_parser): 135 | parser = ArgumentParser(parents=[parent_parser], add_help=False) 136 | parser.add_argument( 137 | "--lr", 138 | type=float, 139 | default=0.001, 140 | metavar="LR", 141 | help="learning rate (default: 0.001)", 142 | ) 143 | return parser 144 | 145 | def forward(self, x): 146 | """ 147 | :param x: Input data 148 | 149 | :return: output - mnist digit label for the input image 150 | """ 151 | batch_size = x.size()[0] 152 | 153 | # (b, 1, 28, 28) -> (b, 1*28*28) 154 | x = x.view(batch_size, -1) 155 | 156 | # layer 1 (b, 1*28*28) -> (b, 128) 157 | x = self.layer_1(x) 158 | x = torch.relu(x) 159 | 160 | # layer 2 (b, 128) -> (b, 256) 161 | x = self.layer_2(x) 162 | x = torch.relu(x) 163 | 164 | # layer 3 (b, 256) -> (b, 10) 165 | x = self.layer_3(x) 166 | 167 | # probability distribution over labels 168 | x = torch.log_softmax(x, dim=1) 169 | 170 | return x 171 | 172 | def cross_entropy_loss(self, logits, labels): 173 | """ 174 | Initializes the loss function 175 | 176 | :return: output - Initialized cross entropy loss function 177 | """ 178 | return F.nll_loss(logits, labels) 179 | 180 | def training_step(self, train_batch, batch_idx): 181 | """ 182 | Training the data as batches and returns training loss on each batch 183 | :param train_batch: Batch data 184 | :param batch_idx: Batch indices 185 | 186 | :return: output - Training loss 187 | """ 188 | x, y = train_batch 189 | logits = self.forward(x) 190 | loss = self.cross_entropy_loss(logits, y) 191 | _, y_hat = torch.max(logits, dim=1) 192 | self.train_acc(y_hat, y) 193 | self.log("train_acc", self.train_acc.compute()) 194 | self.log("train_loss", loss) 195 | return {"loss": loss} 196 | 197 | def validation_step(self, val_batch, batch_idx): 198 | """ 199 | Performs validation of data in batches 200 | 201 | :param val_batch: Batch data 202 | :param batch_idx: Batch indices 203 | 204 | :return: output - valid step 
loss 205 | """ 206 | x, y = val_batch 207 | logits = self.forward(x) 208 | loss = self.cross_entropy_loss(logits, y) 209 | _, y_hat = torch.max(logits, dim=1) 210 | self.val_acc(y_hat, y) 211 | self.log("val_acc", self.val_acc.compute(), sync_dist=True) 212 | self.log("val_loss", loss, sync_dist=True) 213 | 214 | def test_step(self, test_batch, batch_idx): 215 | """ 216 | Performs test and computes the accuracy of the model 217 | 218 | :param test_batch: Batch data 219 | :param batch_idx: Batch indices 220 | 221 | :return: output - Testing accuracy 222 | """ 223 | x, y = test_batch 224 | output = self.forward(x) 225 | _, y_hat = torch.max(output, dim=1) 226 | 227 | self.test_acc(y_hat, y) 228 | self.log("test_acc", self.test_acc.compute()) 229 | 230 | def prepare_data(self): 231 | """ 232 | Prepares the data for training and prediction 233 | """ 234 | return {} 235 | 236 | def configure_optimizers(self): 237 | """ 238 | Initializes the optimizer and learning rate scheduler 239 | 240 | :return: output - Initialized optimizer and scheduler 241 | """ 242 | optimizer = torch.optim.Adam(self.parameters(), lr=self.args["lr"]) 243 | scheduler = { 244 | "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau( 245 | optimizer, 246 | mode="min", 247 | factor=0.2, 248 | patience=2, 249 | min_lr=1e-6, 250 | verbose=True, 251 | ), 252 | "monitor": "val_loss", 253 | } 254 | return [optimizer], [scheduler] 255 | 256 | 257 | def get_model(trainer): 258 | is_dp_module = isinstance(trainer.model, (DistributedDataParallel, DataParallel)) 259 | model = trainer.model.module if is_dp_module else trainer.model 260 | return model 261 | 262 | 263 | if __name__ == "__main__": 264 | parser = ArgumentParser(description="PyTorch Autolog Mnist Example") 265 | parser.add_argument( 266 | "--registration_name", type=str, default="mnist_classifier", help="Model registration name" 267 | ) 268 | parser.add_argument( 269 | "--register", type=str, default="true", help="To enable/disable model registration" 270 | ) 271 | parser = pl.Trainer.add_argparse_args(parent_parser=parser) 272 | parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parser) 273 | parser = MNISTDataModule.add_model_specific_args(parent_parser=parser) 274 | 275 | args = parser.parse_args() 276 | dict_args = vars(args) 277 | 278 | for argument in ["strategy", "accelerator", "devices"]: 279 | if dict_args[argument] == "None": 280 | dict_args[argument] = None 281 | 282 | model = LightningMNISTClassifier(**dict_args) 283 | 284 | dm = MNISTDataModule(**dict_args) 285 | dm.prepare_data() 286 | dm.setup(stage="fit") 287 | 288 | trainer = pl.Trainer.from_argparse_args(args) 289 | 290 | mlflow.pytorch.autolog() 291 | with mlflow.start_run() as run: 292 | trainer.fit(model, dm) 293 | trainer.test(datamodule=dm) 294 | active_run = mlflow.active_run() 295 | if dict_args["register"] == "true": 296 | mlflow.register_model( 297 | model_uri=active_run.info.artifact_uri, name=dict_args["registration_name"] 298 | ) 299 | else: 300 | torch.save(trainer.lightning_module.state_dict(), "model.pth") 301 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | APPENDIX: How to apply the Apache License to your work. 
178 | 179 | To apply the Apache License to your work, attach the following 180 | boilerplate notice, with the fields enclosed by brackets "[]" 181 | replaced with your own identifying information. (Don't include 182 | the brackets!) The text should be enclosed in the appropriate 183 | comment syntax for the file format. We also recommend that a 184 | file or class name and description of purpose be included on the 185 | same "printed page" as the copyright notice for easier 186 | identification within third-party archives. 187 | 188 | Copyright [yyyy] [name of copyright owner] 189 | 190 | Licensed under the Apache License, Version 2.0 (the "License"); 191 | you may not use this file except in compliance with the License. 192 | You may obtain a copy of the License at 193 | 194 | http://www.apache.org/licenses/LICENSE-2.0 195 | 196 | Unless required by applicable law or agreed to in writing, software 197 | distributed under the License is distributed on an "AS IS" BASIS, 198 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 199 | See the License for the specific language governing permissions and 200 | limitations under the License. 201 | -------------------------------------------------------------------------------- /tests/test_cli.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import pytest 5 | from click.testing import CliRunner 6 | from mlflow import deployments 7 | from mlflow.deployments import cli 8 | 9 | f_target = "torchserve" 10 | f_deployment_id = "cli_test" 11 | f_deployment_name_version = "cli_test/2.0" 12 | f_deployment_name_all = "cli_test/all" 13 | f_flavor = None 14 | f_model_uri = os.path.join("tests/resources", "linear_state_dict.pt") 15 | 16 | model_version = "1.0" 17 | model_file_path = os.path.join("tests/resources", "linear_model.py") 18 | incorrect_model_file_path = os.path.join("tests/resources", "linear_model1.py") 19 | handler_file_path = os.path.join("tests/resources", "linear_handler.py") 20 | sample_input_file = os.path.join("tests/resources", "sample.json") 21 | sample_incorrect_input_file = os.path.join("tests/resources", "sample.txt") 22 | handler_file = "HANDLER={handler_file_path}".format(handler_file_path=handler_file_path) 23 | model_file = "MODEL_FILE={model_file_path}".format(model_file_path=model_file_path) 24 | incorrect_model_file = "MODEL_FILE={model_file_path}".format( 25 | model_file_path=incorrect_model_file_path 26 | ) 27 | runner = CliRunner() 28 | 29 | 30 | @pytest.mark.usefixtures("start_torchserve") 31 | def test_create_cli_version_success(): 32 | version = "VERSION={version}".format(version="1.0") 33 | _ = deployments.get_deploy_client(f_target) 34 | res = runner.invoke( 35 | cli.create_deployment, 36 | [ 37 | "-f", 38 | f_flavor, 39 | "-m", 40 | f_model_uri, 41 | "-t", 42 | f_target, 43 | "--name", 44 | f_deployment_id, 45 | "-C", 46 | model_file, 47 | "-C", 48 | handler_file, 49 | "-C", 50 | version, 51 | ], 52 | ) 53 | assert "{} deployment {} is created".format(f_flavor, f_deployment_id + "/1.0") in res.stdout 54 | 55 | 56 | def test_create_cli_success_without_version(): 57 | _ = deployments.get_deploy_client(f_target) 58 | res = runner.invoke( 59 | cli.create_deployment, 60 | [ 61 | "-f", 62 | f_flavor, 63 | "-m", 64 | f_model_uri, 65 | "-t", 66 | f_target, 67 | "--name", 68 | f_deployment_id, 69 | "-C", 70 | model_file, 71 | "-C", 72 | handler_file, 73 | ], 74 | ) 75 | assert "{} deployment {} is created".format(f_flavor, 
f_deployment_name_version) in res.stdout 76 | 77 | 78 | def test_create_cli_failure_without_version(): 79 | _ = deployments.get_deploy_client(f_target) 80 | res = runner.invoke( 81 | cli.create_deployment, 82 | [ 83 | "-f", 84 | f_flavor, 85 | "-m", 86 | f_model_uri, 87 | "-t", 88 | f_target, 89 | "--name", 90 | f_deployment_id, 91 | "-C", 92 | incorrect_model_file, 93 | "-C", 94 | handler_file, 95 | ], 96 | ) 97 | assert "No such file or directory" in str(res.exception) and res.exit_code == 1 98 | res = runner.invoke( 99 | cli.create_deployment, 100 | [ 101 | "-m", 102 | f_model_uri, 103 | "-t", 104 | f_target, 105 | "--name", 106 | f_deployment_id, 107 | "-C", 108 | model_file, 109 | ], 110 | ) 111 | assert str(res.exception) == "Config Variable HANDLER - missing" 112 | res = runner.invoke(cli.create_deployment) 113 | assert ( 114 | res.exit_code == 2 115 | and res.output == "Usage: create [OPTIONS]\nTry 'create --help' for help.\n\n" 116 | "Error: Missing option '--name'.\n" 117 | ) 118 | res = runner.invoke( 119 | cli.create_deployment, 120 | [ 121 | "-m", 122 | f_model_uri, 123 | "-t", 124 | f_target, 125 | "--name", 126 | f_deployment_id, 127 | "-C", 128 | handler_file, 129 | ], 130 | ) 131 | assert "Unable to register the model" in str(res.exception) 132 | res = runner.invoke( 133 | cli.create_deployment, 134 | [ 135 | "-t", 136 | f_target, 137 | "--name", 138 | f_deployment_id, 139 | "-C", 140 | handler_file, 141 | "-C", 142 | model_file, 143 | ], 144 | ) 145 | assert res.exit_code == 2 146 | res = runner.invoke( 147 | cli.create_deployment, 148 | [ 149 | "-m", 150 | f_model_uri, 151 | "--name", 152 | f_deployment_id, 153 | "-C", 154 | handler_file, 155 | "-C", 156 | model_file, 157 | ], 158 | ) 159 | assert res.exit_code == 2 160 | res = runner.invoke( 161 | cli.create_deployment, 162 | [ 163 | "-m", 164 | f_model_uri, 165 | "-t", 166 | f_target, 167 | "-C", 168 | handler_file, 169 | "-C", 170 | handler_file, 171 | ], 172 | ) 173 | assert res.exit_code == 2 174 | 175 | 176 | @pytest.mark.parametrize( 177 | "deployment_name, config", 178 | [(f_deployment_name_version, "SET-DEFAULT=true"), (f_deployment_id, "MIN_WORKER=3")], 179 | ) 180 | def test_update_cli_success(deployment_name, config): 181 | res = runner.invoke( 182 | cli.update_deployment, 183 | [ 184 | "--flavor", 185 | f_flavor, 186 | "--model-uri", 187 | f_model_uri, 188 | "--target", 189 | f_target, 190 | "--name", 191 | deployment_name, 192 | "-C", 193 | config, 194 | ], 195 | ) 196 | assert ( 197 | "Deployment {} is updated (with flavor {})".format(deployment_name, f_flavor) in res.stdout 198 | ) 199 | 200 | 201 | def test_list_cli_success(): 202 | res = runner.invoke(cli.list_deployment, ["--target", f_target]) 203 | assert "{}".format(f_deployment_id) in res.stdout 204 | 205 | 206 | @pytest.mark.parametrize( 207 | "deployment_name", [f_deployment_id, f_deployment_name_version, f_deployment_name_all] 208 | ) 209 | def test_get_cli_success(deployment_name): 210 | res = runner.invoke(cli.get_deployment, ["--name", deployment_name, "--target", f_target]) 211 | ret = json.loads(res.stdout.split("deploy:")[1]) 212 | if deployment_name == f_deployment_id: 213 | assert ret[0]["modelName"] == f_deployment_id 214 | elif deployment_name == f_deployment_name_version: 215 | assert ret[0]["modelVersion"] == f_deployment_name_version.split("/")[1] 216 | else: 217 | assert len(ret) == 2 218 | 219 | 220 | @pytest.mark.parametrize( 221 | "deployment_name", [f_deployment_id, f_deployment_name_version, f_deployment_name_all] 222 | ) 223 
| def test_get_cli_failure(deployment_name): 224 | res = runner.invoke( 225 | cli.get_deployment, 226 | ) 227 | assert ( 228 | res.exit_code == 2 229 | and res.output 230 | == "Usage: get [OPTIONS]\nTry 'get --help' for help.\n\nError: Missing option '--name'.\n" 231 | ) 232 | res = runner.invoke(cli.get_deployment, ["--name", deployment_name]) 233 | assert ( 234 | res.exit_code == 2 235 | and res.output == "Usage: get [OPTIONS]\nTry 'get --help' for help.\n\n" 236 | "Error: Missing option '--target' / '-t'.\n" 237 | ) 238 | 239 | 240 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 241 | def test_predict_cli_success(deployment_name): 242 | res = runner.invoke( 243 | cli.predict, 244 | ["--name", deployment_name, "--target", f_target, "--input-path", sample_input_file], 245 | ) 246 | assert res.exit_code == 0 247 | 248 | 249 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 250 | def test_predict_cli_failure(deployment_name): 251 | res = runner.invoke( 252 | cli.predict, 253 | ["--name", deployment_name, "--target", f_target], 254 | ) 255 | assert ( 256 | res.exit_code == 2 257 | and res.output == "Usage: predict [OPTIONS]\nTry 'predict --help' for help.\n\n" 258 | "Error: Missing option '--input-path' / '-I'.\n" 259 | ) 260 | res = runner.invoke( 261 | cli.predict, 262 | ["--name", deployment_name, "--input-path", sample_input_file], 263 | ) 264 | assert ( 265 | res.exit_code == 2 266 | and res.output == "Usage: predict [OPTIONS]\nTry 'predict --help' for help.\n\n" 267 | "Error: Missing option '--target' / '-t'.\n" 268 | ) 269 | res = runner.invoke( 270 | cli.predict, 271 | ["--target", f_target, "--input-path", sample_input_file], 272 | ) 273 | assert ( 274 | res.exit_code == 2 275 | and res.output == "Usage: predict [OPTIONS]\nTry 'predict --help' for help.\n\n" 276 | "Error: Must specify exactly one of --name or --endpoint.\n" 277 | ) 278 | res = runner.invoke( 279 | cli.predict, 280 | [ 281 | "--name", 282 | deployment_name, 283 | "--target", 284 | f_target, 285 | "--input-path", 286 | sample_incorrect_input_file, 287 | ], 288 | ) 289 | assert res.exception 290 | 291 | 292 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 293 | def test_explain_cli_success(deployment_name): 294 | runner.invoke( 295 | cli.explain, 296 | ["--name", deployment_name, "--target", f_target, "--input-path", sample_input_file], 297 | ) 298 | 299 | 300 | @pytest.mark.parametrize("deployment_name", [f_deployment_name_version, f_deployment_id]) 301 | def test_explain_cli_failure(deployment_name): 302 | res = runner.invoke( 303 | cli.explain, 304 | ["--name", deployment_name, "--target", f_target], 305 | ) 306 | assert ( 307 | res.exit_code == 2 308 | and res.output == "Usage: explain [OPTIONS]\nTry 'explain --help' for help.\n\n" 309 | "Error: Missing option '--input-path' / '-I'.\n" 310 | ) 311 | res = runner.invoke( 312 | cli.explain, 313 | ["--name", deployment_name, "--input-path", sample_input_file], 314 | ) 315 | assert ( 316 | res.exit_code == 2 317 | and res.output == "Usage: explain [OPTIONS]\nTry 'explain --help' for help.\n\n" 318 | "Error: Missing option '--target' / '-t'.\n" 319 | ) 320 | res = runner.invoke( 321 | cli.explain, 322 | [ 323 | "--name", 324 | deployment_name, 325 | "--target", 326 | f_target, 327 | "--input-path", 328 | sample_incorrect_input_file, 329 | ], 330 | ) 331 | assert res.exception 332 | 333 | 334 | @pytest.mark.parametrize("deployment_name", 
[f_deployment_id + "/1.0", f_deployment_name_version]) 335 | def test_delete_cli_success(deployment_name): 336 | res = runner.invoke( 337 | cli.delete_deployment, 338 | ["--name", deployment_name, "--target", f_target], 339 | ) 340 | assert "Deployment {} is deleted".format(deployment_name) in res.stdout 341 | 342 | 343 | @pytest.mark.parametrize("deployment_name", [f_deployment_id + "/1.0", f_deployment_name_version]) 344 | def test_delete_cli_failure(deployment_name): 345 | res = runner.invoke( 346 | cli.delete_deployment, 347 | ["--name", deployment_name], 348 | ) 349 | assert ( 350 | res.exit_code == 2 351 | and res.output == "Usage: delete [OPTIONS]\nTry 'delete --help' for help.\n\n" 352 | "Error: Missing option '--target' / '-t'.\n" 353 | ) 354 | 355 | res = runner.invoke( 356 | cli.delete_deployment, 357 | ["--target", f_target], 358 | ) 359 | assert ( 360 | res.exit_code == 2 361 | and res.output == "Usage: delete [OPTIONS]\nTry 'delete --help' for help.\n\n" 362 | "Error: Missing option '--name'.\n" 363 | ) 364 | 365 | res = runner.invoke( 366 | cli.delete_deployment, 367 | ) 368 | assert ( 369 | res.exit_code == 2 370 | and res.output == "Usage: delete [OPTIONS]\nTry 'delete --help' for help.\n\n" 371 | "Error: Missing option '--name'.\n" 372 | ) 373 | --------------------------------------------------------------------------------