├── .gitignore
├── Dockerfile
├── README.md
├── __init__.py
├── backend
│   ├── Dockerfile
│   ├── main.py
│   ├── models.py
│   └── requirements.txt
├── docker-compose.yml
├── frontend
│   ├── Dockerfile
│   ├── requirements.txt
│   └── streamlit_main.py
├── ml
│   ├── data.py
│   ├── models.py
│   ├── train.py
│   └── utils.py
├── requirements.txt
└── resources
    ├── arch.png
    ├── pred_pic.png
    └── train_pic.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------

# Created by https://www.toptal.com/developers/gitignore/api/python
# Edit at https://www.toptal.com/developers/gitignore?templates=python

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# End of https://www.toptal.com/developers/gitignore/api/python


data/
*.ipynb
*.db
mlruns/

--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.9-slim

RUN pip install mlflow

EXPOSE 5000

CMD ["mlflow", "ui", "--backend-store-uri", "sqlite:///db/backend.db", "--host", "0.0.0.0"]

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# FastAPI + MLflow + Streamlit

Set up the environment; this should cover all the dependencies:
```bash
pip install -r requirements.txt
```
# Start app
From the root directory, run the following.

Streamlit
```bash
streamlit run frontend/streamlit_main.py
```

FastAPI
```bash
uvicorn backend.main:app
```

MLflow UI
```bash
mlflow ui --backend-store-uri sqlite:///db/backend.db
```

## Docker
- MLflow: http://localhost:5000
- FastAPI: http://localhost:8000/docs
- Streamlit: http://localhost:8501/

```bash
docker-compose build
docker-compose up
```

# Architecture
![image](resources/arch.png)

# UI
![image](resources/train_pic.png)
![image](resources/pred_pic.png)
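
## Example requests
With the backend running (locally or through docker-compose), the endpoints can be exercised directly. The payloads below are illustrative: they mirror the pydantic models in `backend/models.py`, and the model name, hyperparameters and version are just example values.

Start a training task:
```bash
curl -X POST http://localhost:8000/train \
  -H "Content-Type: application/json" \
  -d '{"model_name": "my-model", "hyperparams": {"input_dim": 784, "hidden_dims": [64, 64], "output_dim": 10}, "epochs": 5}'
```

List the registered models, then delete one version:
```bash
curl http://localhost:8000/models
curl -X POST http://localhost:8000/delete \
  -H "Content-Type: application/json" \
  -d '{"model_name": "my-model", "model_version": 1}'
```
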
## TODO
- [x] Dockerize
- [ ] Testing
- [ ] Maybe add celery instead of that background task? (Needs extra configs though)

--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zademn/mnist-mlops-learning/cd70fb33dd9e2fd6d0563856c09f9bddd3b61d4b/__init__.py

--------------------------------------------------------------------------------
/backend/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.9-slim

RUN mkdir /backend
RUN mkdir /ml

COPY ./backend/requirements.txt /backend/requirements.txt
RUN pip install -r backend/requirements.txt

COPY ./backend /backend
COPY ./ml /ml

EXPOSE 8000

CMD ["uvicorn", "backend.main:app", "--host", "0.0.0.0", "--port", "8000"]

--------------------------------------------------------------------------------
/backend/main.py:
--------------------------------------------------------------------------------
import numpy as np
from fastapi import FastAPI
from fastapi import BackgroundTasks
from urllib.parse import urlparse

import mlflow
from mlflow.tracking import MlflowClient
from ml.train import Trainer
from ml.models import LinearModel
from ml.data import load_mnist_data
from ml.utils import set_device
from backend.models import DeleteApiData, TrainApiData, PredictApiData


# mlflow.set_tracking_uri("sqlite:///backend.db")
mlflow.set_tracking_uri("sqlite:///db/backend.db")
app = FastAPI()
mlflowclient = MlflowClient(
    mlflow.get_tracking_uri(), mlflow.get_registry_uri())


def train_model_task(model_name: str, hyperparams: dict, epochs: int):
    """Task that trains the model. It is meant to run in the background.
    Since training is a heavy computation, a stronger task runner like Celery
    would be a better fit; for simplicity it is kept as a FastAPI background task."""

    # Setup env
    device = set_device()
    # Set MLflow tracking
    mlflow.set_experiment("MNIST")
    with mlflow.start_run() as run:
        # Log hyperparameters
        mlflow.log_params(hyperparams)

        # Prepare for training
        print("Loading data...")
        train_dataloader, test_dataloader = load_mnist_data()

        # Train
        print("Training model")
        model = LinearModel(hyperparams).to(device)
        trainer = Trainer(model, device=device)  # Default configs
        history = trainer.train(epochs, train_dataloader, test_dataloader)

        print("Logging results")
        # Log in mlflow
        for metric_name, metric_values in history.items():
            for metric_value in metric_values:
                mlflow.log_metric(metric_name, metric_value)

        # Register model
        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme
        print(f"{tracking_url_type_store=}")

        # Model registry does not work with file store
        if tracking_url_type_store != "file":
            mlflow.pytorch.log_model(
                model, "LinearModel", registered_model_name=model_name,
                conda_env=mlflow.pytorch.get_default_conda_env())
        else:
            mlflow.pytorch.log_model(
                model, "LinearModel-MNIST", registered_model_name=model_name)
        # Transition to production: search for the latest version registered
        # under this name and stage it to production.
        mv = mlflowclient.search_model_versions(
            f"name='{model_name}'")[-1]  # Take last model version
        mlflowclient.transition_model_version_stage(
            name=mv.name, version=mv.version, stage="production")
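
# A possible Celery variant of the task above (untested sketch; see the TODO in
# the README). It assumes a broker is available, e.g. Redis at
# redis://localhost:6379/0:
#
#   from celery import Celery
#   celery_app = Celery("tasks", broker="redis://localhost:6379/0")
#
#   @celery_app.task
#   def train_model_celery(model_name: str, hyperparams: dict, epochs: int):
#       train_model_task(model_name, hyperparams, epochs)
#
# The /train endpoint would then call train_model_celery.delay(...) instead of
# background_tasks.add_task(...).
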
@app.get("/")
async def read_root():
    return {"Tracking URI": mlflow.get_tracking_uri(),
            "Registry URI": mlflow.get_registry_uri()}


@app.get("/models")
async def get_models_api():
    """Gets a list with model names"""
    model_list = mlflowclient.list_registered_models()
    model_list = [model.name for model in model_list]
    return model_list


@app.post("/train")
async def train_api(data: TrainApiData, background_tasks: BackgroundTasks):
    """Creates a model based on hyperparameters and trains it."""
    hyperparams = data.hyperparams
    epochs = data.epochs
    model_name = data.model_name

    background_tasks.add_task(
        train_model_task, model_name, hyperparams, epochs)

    return {"result": "Training task started"}


@app.post("/predict")
async def predict_api(data: PredictApiData):
    """Predicts on the provided image"""
    img = data.input_image
    model_name = data.model_name
    # Fetch the latest model in production
    model = mlflow.pyfunc.load_model(
        model_uri=f"models:/{model_name}/Production"
    )
    # Preprocess the image:
    # flatten the input, create a batch of one and normalize to [0, 1]
    img = np.array(img, dtype=np.float32).flatten()[np.newaxis, ...] / 255
    # Postprocess the result
    pred = model.predict(img)
    print(pred)
    res = int(np.argmax(pred[0]))
    return {"result": res}


@app.post("/delete")
async def delete_model_api(data: DeleteApiData):
    model_name = data.model_name
    version = data.model_version

    if version is None:
        # Delete all versions
        mlflowclient.delete_registered_model(name=model_name)
        response = {"result": f"Deleted all versions of model {model_name}"}
    elif isinstance(version, list):
        for v in version:
            mlflowclient.delete_model_version(name=model_name, version=v)
        response = {
            "result": f"Deleted versions {version} of model {model_name}"}
    else:
        mlflowclient.delete_model_version(name=model_name, version=version)
        response = {
            "result": f"Deleted version {version} of model {model_name}"}
    return response

--------------------------------------------------------------------------------
/backend/models.py:
--------------------------------------------------------------------------------
from typing import Any, Optional, Union
from pydantic import BaseModel


class TrainApiData(BaseModel):
    model_name: str
    hyperparams: dict[str, Any]
    epochs: int


class PredictApiData(BaseModel):
    input_image: Any
    model_name: str


class DeleteApiData(BaseModel):
    model_name: str
    model_version: Optional[Union[list[int], int]]  # list[int] | int in Python 3.10+

--------------------------------------------------------------------------------
/backend/requirements.txt:
--------------------------------------------------------------------------------
mlflow
numpy
pandas
matplotlib
torch
torchvision
fastapi[all]
tqdm

--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
version: "3"

services:
  mlflow:
    build: .
    volumes:
      - ./db:/db
    ports:
      - 5000:5000
    networks:
      deploy_network:
    container_name: mlflow_mnist

  fastapi:
    build:
      context: ./
      dockerfile: ./backend/Dockerfile
    volumes:
      - ./db:/db
    depends_on:
      - mlflow
    ports:
      - 8000:8000
    networks:
      deploy_network:
    container_name: fastapi_mnist

  streamlit:
    build: frontend/
    environment:
      - BACKEND_URL=http://fastapi_mnist:8000
    depends_on:
      - fastapi
    ports:
      - 8501:8501
    networks:
      deploy_network:
    container_name: streamlit_mnist

networks:
  deploy_network:
    driver: bridge
    # external:
    #   name: net-name

--------------------------------------------------------------------------------
/frontend/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.9-slim

RUN mkdir /frontend

COPY requirements.txt /frontend

WORKDIR /frontend

RUN pip install -r requirements.txt

COPY . /frontend

EXPOSE 8501

CMD ["streamlit", "run", "streamlit_main.py"]

--------------------------------------------------------------------------------
/frontend/requirements.txt:
--------------------------------------------------------------------------------
streamlit
numpy
pandas
matplotlib
streamlit-drawable-canvas
opencv-python-headless
requests

--------------------------------------------------------------------------------
/frontend/streamlit_main.py:
--------------------------------------------------------------------------------
import streamlit as st
from streamlit_drawable_canvas import st_canvas

import cv2
import requests
import urllib.parse
import json
import os

# Configs
MODEL_INPUT_SIZE = 28
CANVAS_SIZE = MODEL_INPUT_SIZE * 8

if os.environ.get("BACKEND_URL") is not None:
    BACKEND_URL = os.environ.get("BACKEND_URL")
else:
    BACKEND_URL = "http://localhost:8000"

MODELS_URL = urllib.parse.urljoin(BACKEND_URL, "models")
TRAIN_URL = urllib.parse.urljoin(BACKEND_URL, "train")
PREDICT_URL = urllib.parse.urljoin(BACKEND_URL, "predict")
DELETE_URL = urllib.parse.urljoin(BACKEND_URL, "delete")


st.title("MNIST training and prediction")
st.sidebar.subheader("Page navigation")
page = st.sidebar.selectbox(label="", options=[
                            "Train", "Predict", "Delete"])
st.sidebar.write("https://github.com/zademn")

if page == "Train":
    # Conv is not provided yet
    st.session_state.model_type = st.selectbox(
        "Model type", options=["Linear", "Conv"])

    model_name = st.text_input(label="Model name", value="My Model")

    if st.session_state.model_type == "Linear":
        num_layers = st.select_slider(
            label="Number of hidden layers", options=[1, 2, 3])
        cols = st.columns(num_layers)
        hidden_dims = [64] * num_layers
        for i in range(num_layers):
            hidden_dims[i] = cols[i].number_input(
                label=f"Number of neurons in layer {i}", min_value=2, max_value=128, value=hidden_dims[i])

        hyperparams = {
            "input_dim": 28 * 28,
            "hidden_dims": hidden_dims,
            "output_dim": 10,
        }

    epochs = st.number_input("Epochs", min_value=1, value=5, max_value=128)

    if st.button("Train"):
        st.write(f"{hyperparams=}")
        to_post = {"model_name": model_name,
                   "hyperparams": hyperparams, "epochs": epochs}
        response = requests.post(url=TRAIN_URL, data=json.dumps(to_post))
        if response.ok:
            res = response.json()["result"]
        else:
            res = "Training task failed"
        st.write(res)

    # if st.session_state.model_type == "Conv":
    #     pass

elif page == "Predict":

    try:
        response = requests.get(MODELS_URL)
        if response.ok:
            model_list = response.json()
            model_name = st.selectbox(
                label="Select your model", options=model_list)
        else:
            st.write("No models found")
    except requests.exceptions.ConnectionError:
        st.write("Couldn't reach backend")
    # Setup canvas
    st.write("Draw something here")
    canvas_res = st_canvas(
        fill_color="black",
        stroke_width=20,
        stroke_color="white",
        width=CANVAS_SIZE,
        height=CANVAS_SIZE,
        drawing_mode="freedraw",
        key="canvas",
        display_toolbar=True
    )

    # Get image
    if canvas_res.image_data is not None:
        # Scale the image down to the model input size
        img = cv2.resize(canvas_res.image_data.astype("uint8"),
                         (MODEL_INPUT_SIZE, MODEL_INPUT_SIZE))
        # Rescale the image upwards to show it
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        img_rescaled = cv2.resize(
            img, (CANVAS_SIZE, CANVAS_SIZE), interpolation=cv2.INTER_NEAREST)
        st.write("Model input")
        st.image(img_rescaled)

        # Predict on the press of a button
        if st.button("Predict"):
            try:
                response_predict = requests.post(url=PREDICT_URL,
                                                 data=json.dumps({"input_image": img.tolist(), "model_name": model_name}))
                if response_predict.ok:
                    res = response_predict.json()
                    st.markdown(f"**Prediction**: {res['result']}")
                else:
                    st.write("Some error occurred")
            except requests.exceptions.ConnectionError:
                st.write("Couldn't reach backend")

elif page == "Delete":
    try:
        response = requests.get(MODELS_URL)
        if response.ok:
            model_list = response.json()
            model_name = st.selectbox(
                label="Select your model", options=model_list)
        else:
            st.write("No models found")
    except requests.exceptions.ConnectionError:
        st.write("Couldn't reach backend")

    to_post = {"model_name": model_name}
    # Delete on the press of a button
    if st.button("Delete"):
        try:
            response = requests.post(url=DELETE_URL,
                                     data=json.dumps(to_post))
            if response.ok:
                res = response.json()
                st.write(res["result"])
            else:
                st.write("Some error occurred")
        except requests.exceptions.ConnectionError:
            st.write("Couldn't reach backend")
else:
    st.write("Page does not exist")

--------------------------------------------------------------------------------
/ml/data.py:
--------------------------------------------------------------------------------
import torch
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader


def load_mnist_data(root='data', flatten=True, batch_size=32):
    if flatten:
        transform = torchvision.transforms.Compose(
            [torchvision.transforms.ToTensor(),
             torchvision.transforms.Lambda(lambda x: torch.flatten(x))]
        )
    else:
        # Note: no trailing comma here; it would turn the transform into a tuple
        transform = torchvision.transforms.ToTensor()

    train_dataset = MNIST(root=root, download=True, transform=transform)
    test_dataset = MNIST(root=root, train=False,
                         download=True, transform=transform)

    # Use the batch_size argument instead of a hardcoded value
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataloader, test_dataloader

--------------------------------------------------------------------------------
/ml/models.py:
--------------------------------------------------------------------------------
import torch


class LinearModel(torch.nn.Module):
    def __init__(self, hyperparameters: dict):
        super(LinearModel, self).__init__()

        # Get the model config,
        # e.g. {"input_dim": 784, "hidden_dims": [64, 64], "output_dim": 10}
        self.input_dim = hyperparameters['input_dim']
        self.output_dim = hyperparameters['output_dim']
        self.hidden_dims = hyperparameters['hidden_dims']
        self.negative_slope = hyperparameters.get("negative_slope", .2)

        # Create the layer list
        self.layers = torch.nn.ModuleList([])
        all_dims = [self.input_dim, *self.hidden_dims, self.output_dim]
        for in_dim, out_dim in zip(all_dims[:-1], all_dims[1:]):
            self.layers.append(torch.nn.Linear(in_dim, out_dim))

        self.num_layers = len(self.layers)

    def forward(self, x):
        for i in range(self.num_layers - 1):
            x = self.layers[i](x)
            x = torch.nn.functional.leaky_relu(
                x, negative_slope=self.negative_slope)
        # Return raw logits: the trainer's CrossEntropyLoss applies
        # log-softmax internally, so applying softmax here would be wrong.
        return self.layers[-1](x)

--------------------------------------------------------------------------------
/ml/train.py:
--------------------------------------------------------------------------------
import torch
from tqdm import tqdm


class Trainer:
    def __init__(self, model, optimizer=None, criterion=None, device=None):
        """Initialize the trainer"""
        self.model = model
        if optimizer is not None:
            self.optimizer = optimizer
        else:
            self.optimizer = torch.optim.Adam(model.parameters(), lr=.001)

        self.criterion = torch.nn.CrossEntropyLoss() if criterion is None else criterion

        if device is None:
            self.device = "cpu"
        else:
            self.device = device

        self.model = self.model.to(self.device)  # Use the resolved device

    def get_model(self):
        return self.model

    def train(self, num_epochs, train_dataloader, val_dataloader=None):
        """Trains the model and logs the results"""
        # Set up the result dict
        results = {"train_loss": [], "train_acc": []}
        if val_dataloader is not None:
            results["val_loss"] = []
            results["val_acc"] = []

        # Start training
        for epoch in tqdm(range(num_epochs)):
            train_loss, train_acc = self.train_epoch(
                dataloader=train_dataloader)
            results["train_loss"].append(train_loss)
            results["train_acc"].append(train_acc)
            # Validate only if we have a val dataloader
            if val_dataloader is not None:
                val_loss, val_acc = self.eval_epoch(dataloader=val_dataloader)
                results["val_loss"].append(val_loss)
                results["val_acc"].append(val_acc)

        return results

    def train_epoch(self, dataloader):
        """Trains one epoch"""
        self.model.train()
        total_loss = 0.
        total_correct = 0.
        for i, batch in enumerate(dataloader):
            # Send to device
            X, y = batch
            X = X.to(self.device)
            y = y.to(self.device)

            # Train step
            self.optimizer.zero_grad()  # Clear gradients.
            outs = self.model(X)  # Perform a single forward pass.
            loss = self.criterion(outs, y)

            loss.backward()  # Derive gradients.
            self.optimizer.step()  # Update parameters based on gradients.

            # Compute metrics
            total_loss += loss.detach().item()
            total_correct += torch.sum(torch.argmax(outs,
                                       dim=-1) == y).detach().item()
        # Divide by the dataset size so the last (possibly partial) batch is counted correctly
        total_acc = total_correct / len(dataloader.dataset)
        return total_loss, total_acc

    def eval_epoch(self, dataloader):
        self.model.eval()
        total_loss = 0.
        total_correct = 0.
        with torch.no_grad():  # No gradients needed for evaluation
            for i, batch in enumerate(dataloader):
                # Send to device
                X, y = batch
                X = X.to(self.device)
                y = y.to(self.device)

                # Eval
                outs = self.model(X)
                loss = self.criterion(outs, y)

                # Compute metrics
                total_loss += loss.item()
                total_correct += torch.sum(torch.argmax(outs,
                                           dim=-1) == y).item()
        total_acc = total_correct / len(dataloader.dataset)
        return total_loss, total_acc
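
# Quick usage sketch (illustrative only; it mirrors what backend/main.py does):
#
#   from ml.data import load_mnist_data
#   from ml.models import LinearModel
#   from ml.utils import set_device
#
#   device = set_device()
#   model = LinearModel({"input_dim": 784, "hidden_dims": [64], "output_dim": 10})
#   trainer = Trainer(model, device=device)
#   train_dl, test_dl = load_mnist_data()
#   history = trainer.train(5, train_dl, test_dl)
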
--------------------------------------------------------------------------------
/ml/utils.py:
--------------------------------------------------------------------------------
import torch


def set_device(cuda: bool = True):
    """
    Return the torch device: cuda if it is requested and available, else cpu
    """
    # Set device
    device = torch.device("cuda" if (
        torch.cuda.is_available() and cuda) else "cpu")
    # Set default tensor types
    # torch.set_default_tensor_type("torch.FloatTensor")
    # if device.type == "cuda":
    #     torch.set_default_tensor_type("torch.cuda.FloatTensor")
    return device

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
streamlit
mlflow
numpy
pandas
matplotlib
torch
torchvision
fastapi[all]
uvicorn
streamlit-drawable-canvas
tqdm
scikit-learn
opencv-python-headless

--------------------------------------------------------------------------------
/resources/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zademn/mnist-mlops-learning/cd70fb33dd9e2fd6d0563856c09f9bddd3b61d4b/resources/arch.png

--------------------------------------------------------------------------------
/resources/pred_pic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zademn/mnist-mlops-learning/cd70fb33dd9e2fd6d0563856c09f9bddd3b61d4b/resources/pred_pic.png

--------------------------------------------------------------------------------
/resources/train_pic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zademn/mnist-mlops-learning/cd70fb33dd9e2fd6d0563856c09f9bddd3b61d4b/resources/train_pic.png
--------------------------------------------------------------------------------