├── .coveragerc ├── .github ├── CODEOWNERS ├── ISSUE_TEMPLATE │ └── epic-or-story.md ├── release-drafter.yml └── workflows │ └── check-test-release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pylintrc ├── LICENSE ├── MANIFEST.in ├── README.md ├── hack └── scripts │ └── ci │ └── free-space.sh ├── mlem ├── __init__.py ├── api │ ├── __init__.py │ ├── commands.py │ ├── migrations.py │ └── utils.py ├── cli │ ├── __init__.py │ ├── apply.py │ ├── build.py │ ├── checkenv.py │ ├── clone.py │ ├── config.py │ ├── declare.py │ ├── deployment.py │ ├── dev.py │ ├── import_object.py │ ├── info.py │ ├── init.py │ ├── link.py │ ├── main.py │ ├── migrate.py │ ├── serve.py │ ├── types.py │ └── utils.py ├── config.py ├── constants.py ├── contrib │ ├── __init__.py │ ├── bitbucketfs.py │ ├── callable.py │ ├── catboost.py │ ├── docker │ │ ├── __init__.py │ │ ├── base.py │ │ ├── context.py │ │ ├── copy.j2 │ │ ├── dockerfile.j2 │ │ ├── helpers.py │ │ ├── install_req.j2 │ │ └── utils.py │ ├── dvc.py │ ├── fastapi.py │ ├── flyio │ │ ├── __init__.py │ │ ├── meta.py │ │ └── utils.py │ ├── git.py │ ├── github.py │ ├── gitlabfs.py │ ├── heroku │ │ ├── __init__.py │ │ ├── build.py │ │ ├── config.py │ │ ├── meta.py │ │ ├── server.py │ │ └── utils.py │ ├── kubernetes │ │ ├── __init__.py │ │ ├── base.py │ │ ├── build.py │ │ ├── context.py │ │ ├── resources.yaml.j2 │ │ ├── service.py │ │ └── utils.py │ ├── lightgbm.py │ ├── numpy.py │ ├── onnx.py │ ├── pandas.py │ ├── pil.py │ ├── pip │ │ ├── __init__.py │ │ ├── base.py │ │ ├── setup.py.j2 │ │ └── source.py.j2 │ ├── prometheus.py │ ├── rabbitmq.py │ ├── requirements.py │ ├── sagemaker │ │ ├── __init__.py │ │ ├── build.py │ │ ├── config.py │ │ ├── copy.j2 │ │ ├── env_setup.py │ │ ├── meta.py │ │ ├── mlem_sagemaker.tf │ │ ├── post_copy.j2 │ │ ├── runtime.py │ │ └── utils.py │ ├── scipy.py │ ├── sklearn.py │ ├── streamlit │ │ ├── __init__.py │ │ ├── _template.py │ │ ├── server.py │ │ └── utils.py │ ├── tensorflow.py │ ├── torch.py │ ├── 
torchvision.py │ ├── venv.py │ └── xgboost.py ├── core │ ├── __init__.py │ ├── artifacts.py │ ├── base.py │ ├── data_type.py │ ├── errors.py │ ├── hooks.py │ ├── import_objects.py │ ├── index.py │ ├── meta_io.py │ ├── metadata.py │ ├── model.py │ ├── objects.py │ └── requirements.py ├── ext.py ├── log.py ├── polydantic │ ├── __init__.py │ ├── core.py │ └── lazy.py ├── runtime │ ├── __init__.py │ ├── client.py │ ├── interface.py │ ├── middleware.py │ └── server.py ├── telemetry.py ├── ui.py ├── utils │ ├── __init__.py │ ├── backport.py │ ├── entrypoints.py │ ├── fslock.py │ ├── git.py │ ├── importing.py │ ├── mlem.isort.cfg │ ├── module.py │ ├── path.py │ ├── root.py │ └── templates.py └── version.py ├── pyproject.toml ├── renovate.json ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── __main__.py ├── api │ ├── __init__.py │ ├── test_commands.py │ ├── test_migrations.py │ └── test_utils.py ├── cli │ ├── __init__.py │ ├── conftest.py │ ├── test_apply.py │ ├── test_build.py │ ├── test_checkenv.py │ ├── test_clone.py │ ├── test_config.py │ ├── test_declare.py │ ├── test_deployment.py │ ├── test_import_path.py │ ├── test_info.py │ ├── test_init.py │ ├── test_link.py │ ├── test_main.py │ ├── test_serve.py │ ├── test_stderr.py │ └── test_types.py ├── conftest.py ├── contrib │ ├── __init__.py │ ├── conftest.py │ ├── resources │ │ ├── im.jpg │ │ └── pandas │ │ │ └── .mlem.yaml │ ├── test_bitbucket.py │ ├── test_callable.py │ ├── test_catboost.py │ ├── test_docker │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── resources │ │ │ └── dockerfile.j2 │ │ ├── test_base.py │ │ ├── test_context.py │ │ ├── test_deploy.py │ │ ├── test_pack.py │ │ └── test_utils.py │ ├── test_fastapi.py │ ├── test_flyio.py │ ├── test_github.py │ ├── test_gitlab.py │ ├── test_heroku.py │ ├── test_kubernetes │ │ ├── __init__.py │ │ ├── conftest.py │ │ ├── test_base.py │ │ ├── test_context.py │ │ └── utils.py │ ├── test_lightgbm.py │ ├── test_numpy.py │ ├── test_onnx.py │ ├── test_pandas.py │ ├── 
test_pil.py │ ├── test_pip.py │ ├── test_prometheus.py │ ├── test_rabbitmq.py │ ├── test_requirements.py │ ├── test_scipy.py │ ├── test_sklearn.py │ ├── test_streamlit.py │ ├── test_tensorflow.py │ ├── test_torch.py │ ├── test_venv.py │ └── test_xgboost.py ├── core │ ├── __init__.py │ ├── conftest.py │ ├── custom_requirements │ │ ├── model_trainer.py │ │ ├── pack_1 │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ └── model_type.py │ │ ├── pack_2 │ │ │ └── __init__.py │ │ ├── pkg │ │ │ ├── __init__.py │ │ │ ├── impl.py │ │ │ └── subpkg │ │ │ │ ├── __init__.py │ │ │ │ ├── impl.py │ │ │ │ └── testfile.json │ │ ├── pkg_import.py │ │ ├── proxy_model.py │ │ ├── proxy_pkg_import.py │ │ ├── shell_reqs.py │ │ ├── test_remote_custom_model.py │ │ ├── test_requirements.py │ │ ├── test_shell_reqs.py │ │ ├── unused_code.py │ │ ├── use_model.py │ │ └── use_model_meta.py │ ├── resources │ │ ├── emoji_model_inside.py │ │ ├── emoji_model_outside.py │ │ ├── emoji_model_shell.py │ │ ├── file.txt │ │ └── server.yaml │ ├── test_artifacts.py │ ├── test_base.py │ ├── test_data_io.py │ ├── test_data_type.py │ ├── test_meta_io.py │ ├── test_metadata.py │ ├── test_model_type.py │ ├── test_objects.py │ └── test_requirements.py ├── polydantic │ ├── __init__.py │ ├── test_lazy.py │ ├── test_multi.py │ └── test_serde.py ├── resources │ ├── empty │ │ └── .mlem.yaml │ └── storage │ │ └── .mlem.yaml ├── runtime │ ├── __init__.py │ ├── test_client.py │ ├── test_interface.py │ └── test_model_interface.py ├── test_config.py ├── test_ext.py ├── test_setup.py ├── test_telemetry.py └── utils │ ├── __init__.py │ ├── module_tools_mock_req.py │ ├── test_entrypoints.py │ ├── test_fslock.py │ ├── test_module_tools.py │ ├── test_path.py │ ├── test_root.py │ └── test_save.ipynb └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | mlem/analytics.py 4 | 5 | [report] 6 | exclude_lines = 7 | pragma: no cover 8 | raise NotImplementedError 9 
| @(abc\.)?abstractmethod 10 | @overload 11 | -------------------------------------------------------------------------------- /.github/CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @iterative/mlem 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/epic-or-story.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Epic/Story Issue Template 3 | about: Template for "top" level issues - Epics (>2 weeks) / Stories (<2 weeks) 4 | title: 'Epic: New Feature' 5 | labels: epic 6 | assignees: 7 | 8 | --- 9 | 10 | ## Summary / Background 11 | What do you want to achieve, why? business context 12 | ... 13 | 14 | ## Scope 15 | 16 | What will be impacted and what won't be? 17 | What needs implementation and what is invariant? 18 | e.g. 19 | - user should be able to run workflow X from UI 20 | - enable workflow Y, Z from CLI 21 | 22 | ## Assumptions 23 | Product / UX assumptions as well as technical assumptions / limitations 24 | e.g. 25 | * Support only Python Runtime 26 | * Focus on DVC experiments only 27 | * Deployment environments don't change often and can be picked up from shared configuration 28 | 29 | ## Open Questions 30 | e.g. 31 | - How should access control work for shared artifacts (workflow X) 32 | - Python runtime assumption - is it really valid? in light of <...> 33 | 34 | ## Blockers / Dependencies 35 | List issues or other conditions / blockers 36 | 37 | ## General Approach 38 | Invocation example: 39 | ```shell 40 | $ mapper-run task.tar.gz --ray-cluster : 41 | ``` 42 | 43 | ## Steps 44 | 45 | ### Must have (p1) 46 | - [ ] subissue2 47 | - [ ] step 2 48 | - info 49 | - info 50 | 51 | ### Optional / followup (p2) 52 | - [ ] ⌛ step 3 wip 53 | - [ ] step 4 54 | 55 | ## Timelines 56 | 57 | Put your estimations here. 
Update once certainty changes 58 | - end of week (Feb 3) for prototype with workflows X, Y 59 | - Feb 15 - MVP in production 60 | - Low priority followups can be done later 61 | -------------------------------------------------------------------------------- /.github/release-drafter.yml: -------------------------------------------------------------------------------- 1 | # Config for https://github.com/apps/release-drafter 2 | name-template: "$RESOLVED_VERSION 🐶" 3 | tag-template: "$RESOLVED_VERSION" 4 | change-template: "- $TITLE (#$NUMBER) @$AUTHOR" 5 | version-resolver: 6 | major: 7 | labels: 8 | - "major" 9 | minor: 10 | labels: 11 | - "minor" 12 | patch: 13 | labels: 14 | - "patch" 15 | default: patch 16 | exclude-labels: 17 | - "skip-changelog" 18 | template: | 19 | 20 | ## Changes 21 | 22 | $CHANGES 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | 140 | .idea 141 | 142 | mlem/_mlem_version.py 143 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-added-large-files 8 | - id: check-case-conflict 9 | - id: check-docstring-first 10 | - id: check-executables-have-shebangs 11 | - id: check-toml 12 | - id: check-merge-conflict 13 | - id: check-yaml 14 | exclude: examples/layouts 15 | - id: debug-statements 16 | - id: end-of-file-fixer 17 | - id: mixed-line-ending 18 | - id: sort-simple-yaml 19 | - id: trailing-whitespace 20 | - repo: https://github.com/pycqa/flake8 21 | rev: 6.0.0 22 | hooks: 23 | - id: flake8 24 | args: 25 | - '-j8' 26 | additional_dependencies: 27 | - flake8-bugbear 28 | - flake8-comprehensions 29 | - flake8-debugger 30 | - flake8-string-format 31 | - repo: https://github.com/psf/black 32 | rev: 23.1.0 33 | hooks: 34 | - id: black 35 | - repo: 'https://github.com/PyCQA/isort' 36 | rev: 5.12.0 37 | hooks: 38 | - id: isort 39 | - repo: https://github.com/pre-commit/mirrors-mypy 40 | rev: v0.991 41 | hooks: 
42 | - id: mypy 43 | additional_dependencies: 44 | - types-requests 45 | - types-six 46 | - types-PyYAML 47 | - pydantic>=1.9.0,<2 48 | - types-filelock 49 | - types-emoji 50 | - repo: local 51 | hooks: 52 | - id: pylint 53 | name: pylint 54 | entry: pylint -v 55 | language: system 56 | types: [ python ] 57 | - repo: https://github.com/PyCQA/bandit 58 | rev: 1.7.4 59 | hooks: 60 | - id: bandit 61 | exclude: tests/ 62 | args: 63 | - -iii # high level 64 | - -lll # high confidence 65 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft mlem 2 | recursive-exclude * tests 3 | -------------------------------------------------------------------------------- /hack/scripts/ci/free-space.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | # Copyright 2023 The Nuclio Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # 16 | 17 | print_free_space() { 18 | df --human-readable 19 | } 20 | 21 | # before cleanup 22 | print_free_space 23 | 24 | # clean unneeded os packages and misc 25 | sudo apt-get remove -y '^dotnet-.*' 26 | sudo apt-get remove -y 'php.*' 27 | sudo apt-get remove -y \ 28 | azure-cli \ 29 | google-cloud-sdk \ 30 | google-chrome-stable \ 31 | firefox \ 32 | powershell 33 | 34 | sudo apt-get autoremove --yes 35 | sudo apt clean 36 | 37 | # cleanup unneeded share dirs ~30GB 38 | sudo rm --recursive --force \ 39 | /usr/local/lib/android \ 40 | /usr/share/dotnet \ 41 | /usr/share/miniconda \ 42 | /usr/share/swift 43 | 44 | # clean unneeded docker images 45 | docker system prune --all --force 46 | 47 | # post cleanup 48 | print_free_space 49 | -------------------------------------------------------------------------------- /mlem/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLEM is a tool to help you version and deploy your Machine Learning models: 3 | * Serialize any model trained in Python into ready-to-deploy format 4 | * Model lifecycle management using Git and GitOps principles 5 | * Provider-agnostic deployment 6 | """ 7 | import mlem.log # noqa 8 | 9 | from . 
import api # noqa 10 | from .config import LOCAL_CONFIG 11 | from .ext import ExtensionLoader 12 | from .version import __version__ 13 | 14 | if LOCAL_CONFIG.AUTOLOAD_EXTS: 15 | ExtensionLoader.load_all() 16 | 17 | __all__ = ["api", "__version__", "LOCAL_CONFIG"] 18 | -------------------------------------------------------------------------------- /mlem/api/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLEM's Python API 3 | """ 4 | from ..core.metadata import load, load_meta, save 5 | from .commands import ( 6 | apply, 7 | apply_remote, 8 | build, 9 | clone, 10 | deploy, 11 | import_object, 12 | init, 13 | link, 14 | serve, 15 | ) 16 | 17 | __all__ = [ 18 | "save", 19 | "load", 20 | "load_meta", 21 | "clone", 22 | "init", 23 | "link", 24 | "build", 25 | "apply", 26 | "apply_remote", 27 | "import_object", 28 | "deploy", 29 | "serve", 30 | ] 31 | -------------------------------------------------------------------------------- /mlem/api/migrations.py: -------------------------------------------------------------------------------- 1 | import posixpath 2 | from typing import Callable, List, Optional 3 | 4 | from yaml import safe_dump, safe_load 5 | 6 | from mlem.core.errors import MlemObjectNotFound 7 | from mlem.core.meta_io import MLEM_EXT, Location 8 | from mlem.core.metadata import find_meta_location 9 | from mlem.ui import echo 10 | from mlem.utils.path import make_posix 11 | 12 | 13 | def migrate(path: str, project: Optional[str] = None, recursive: bool = False): 14 | path = posixpath.join(make_posix(project or ""), make_posix(path)) 15 | location = Location.resolve(path) 16 | try: 17 | location = find_meta_location(location) 18 | _migrate_one(location) 19 | return 20 | except MlemObjectNotFound: 21 | pass 22 | 23 | postfix = f"/**{MLEM_EXT}" if recursive else f"/*{MLEM_EXT}" 24 | for filepath in location.fs.glob( 25 | location.fullpath + postfix, detail=False 26 | ): 27 | print(filepath) 28 | loc = 
location.copy() 29 | loc.update_path(filepath) 30 | _migrate_one(loc) 31 | 32 | 33 | def apply_migrations(payload: dict): 34 | changed = False 35 | for migration in _migrations: 36 | migrated = migration(payload) 37 | if migrated is not None: 38 | payload = migrated 39 | changed = True 40 | return payload, changed 41 | 42 | 43 | def _migrate_one(location: Location): 44 | with location.open("r") as f: 45 | payload = safe_load(f) 46 | 47 | payload, changed = apply_migrations(payload) 48 | 49 | if changed: 50 | echo(f"Migrated MLEM Object at {location}") 51 | with location.open("w") as f: 52 | safe_dump(payload, f) 53 | 54 | 55 | def _migrate_to_028(meta: dict) -> Optional[dict]: 56 | if "object_type" not in meta: 57 | return None 58 | 59 | if "description" in meta: 60 | meta.pop("description") 61 | 62 | if "labels" in meta: 63 | meta.pop("labels") 64 | return meta 65 | 66 | 67 | def _migrate_to_040(meta: dict) -> Optional[dict]: 68 | if "object_type" not in meta or meta["object_type"] != "model": 69 | return None 70 | 71 | if "model_type" not in meta: 72 | return None 73 | 74 | main_model = meta.pop("model_type") 75 | meta["processors"] = {"model": main_model} 76 | meta["call_orders"] = { 77 | method: [("model", method)] for method in main_model["methods"] 78 | } 79 | return meta 80 | 81 | 82 | _migrations: List[Callable[[dict], Optional[dict]]] = [ 83 | _migrate_to_028, 84 | _migrate_to_040, 85 | ] 86 | -------------------------------------------------------------------------------- /mlem/cli/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | MLEM's command-line interface 3 | """ 4 | from mlem.cli.apply import apply 5 | from mlem.cli.build import build 6 | from mlem.cli.checkenv import checkenv 7 | from mlem.cli.clone import clone 8 | from mlem.cli.config import config 9 | from mlem.cli.declare import declare 10 | from mlem.cli.deployment import deployment 11 | from mlem.cli.dev import dev 12 | from 
mlem.cli.import_object import import_object 13 | from mlem.cli.info import pretty_print 14 | from mlem.cli.init import init 15 | from mlem.cli.link import link 16 | from mlem.cli.main import app 17 | from mlem.cli.migrate import migrate 18 | from mlem.cli.serve import serve 19 | from mlem.cli.types import list_types 20 | 21 | __all__ = [ 22 | "apply", 23 | "deployment", 24 | "app", 25 | "init", 26 | "build", 27 | "pretty_print", 28 | "link", 29 | "clone", 30 | "serve", 31 | "config", 32 | "declare", 33 | "import_object", 34 | "list_types", 35 | "dev", 36 | "checkenv", 37 | "migrate", 38 | ] 39 | 40 | 41 | def main(): 42 | app() 43 | 44 | 45 | if __name__ == "__main__": 46 | main() 47 | -------------------------------------------------------------------------------- /mlem/cli/build.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from typer import Typer 4 | 5 | from mlem.cli.main import ( 6 | app, 7 | mlem_command, 8 | mlem_group, 9 | mlem_group_callback, 10 | option_file_conf, 11 | option_load, 12 | option_model, 13 | option_project, 14 | option_rev, 15 | ) 16 | from mlem.cli.utils import ( 17 | abc_fields_parameters, 18 | config_arg, 19 | for_each_impl, 20 | lazy_class_docstring, 21 | make_not_required, 22 | ) 23 | from mlem.core.metadata import load_meta 24 | from mlem.core.objects import MlemBuilder, MlemModel 25 | from mlem.telemetry import pass_telemetry_params 26 | 27 | build = Typer( 28 | name="build", 29 | help=""" 30 | Build models into re-usable assets you can distribute and use in production, 31 | such as a Docker image or Python package. 
32 | """, 33 | cls=mlem_group("runtime", aliases=["export"]), 34 | subcommand_metavar="builder", 35 | ) 36 | app.add_typer(build) 37 | 38 | 39 | @mlem_group_callback(build, required=["model", "load"]) 40 | def build_load( 41 | model: str = make_not_required(option_model), 42 | project: Optional[str] = option_project, 43 | rev: Optional[str] = option_rev, 44 | load: str = option_load("builder"), 45 | ): 46 | from mlem.api.commands import build 47 | 48 | mlem_model = load_meta(model, project, rev, force_type=MlemModel) 49 | with pass_telemetry_params(): 50 | build( 51 | config_arg( 52 | MlemBuilder, 53 | load, 54 | None, 55 | conf=None, 56 | file_conf=None, 57 | ), 58 | mlem_model, 59 | ) 60 | 61 | 62 | @for_each_impl(MlemBuilder) 63 | def create_build_command(type_name): 64 | @mlem_command( 65 | type_name, 66 | section="builders", 67 | parent=build, 68 | dynamic_metavar="__kwargs__", 69 | dynamic_options_generator=abc_fields_parameters( 70 | type_name, MlemBuilder 71 | ), 72 | hidden=type_name.startswith("_"), 73 | lazy_help=lazy_class_docstring(MlemBuilder.abs_name, type_name), 74 | no_pass_from_parent=["file_conf"], 75 | is_generated_from_ext=True, 76 | ) 77 | def build_type( 78 | model: str = option_model, 79 | project: Optional[str] = option_project, 80 | rev: Optional[str] = option_rev, 81 | file_conf: List[str] = option_file_conf("builder"), 82 | **__kwargs__ 83 | ): 84 | from mlem.api.commands import build 85 | 86 | mlem_model = load_meta(model, project, rev, force_type=MlemModel) 87 | with pass_telemetry_params(): 88 | build( 89 | config_arg( 90 | MlemBuilder, 91 | None, 92 | type_name, 93 | conf=None, 94 | file_conf=file_conf, 95 | **__kwargs__ 96 | ), 97 | mlem_model, 98 | ) 99 | -------------------------------------------------------------------------------- /mlem/cli/checkenv.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from typer import Argument 4 | 5 | from mlem.cli.main 
import ( 6 | PATH_METAVAR, 7 | mlem_command, 8 | option_project, 9 | option_rev, 10 | ) 11 | from mlem.core.metadata import load_meta 12 | from mlem.core.objects import MlemData, MlemModel 13 | from mlem.ui import EMOJI_OK, echo 14 | 15 | 16 | @mlem_command("checkenv", hidden=True) 17 | def checkenv( 18 | path: str = Argument(..., help="Path to object", metavar=PATH_METAVAR), 19 | project: Optional[str] = option_project, 20 | rev: Optional[str] = option_rev, 21 | ): 22 | """Check that current Python environment satisfies object requirements.""" 23 | meta = load_meta(path, project, rev, follow_links=True, load_value=False) 24 | if isinstance(meta, (MlemModel, MlemData)): 25 | meta.checkenv() 26 | echo(EMOJI_OK + "Requirements are satisfied!") 27 | -------------------------------------------------------------------------------- /mlem/cli/clone.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from typer import Argument 4 | 5 | from mlem.cli.main import ( 6 | mlem_command, 7 | option_project, 8 | option_rev, 9 | option_target_project, 10 | ) 11 | from mlem.telemetry import pass_telemetry_params 12 | 13 | 14 | @mlem_command("clone", section="object") 15 | def clone( 16 | uri: str = Argument(..., help="URI to object you want to clone"), 17 | target: str = Argument(..., help="Path to store the downloaded object."), 18 | project: Optional[str] = option_project, 19 | rev: Optional[str] = option_rev, 20 | target_project: Optional[str] = option_target_project, 21 | ): 22 | """Copy a MLEM Object from `uri` and 23 | saves a copy of it to `target` path. 
24 | """ 25 | from mlem.api.commands import clone 26 | 27 | with pass_telemetry_params(): 28 | clone( 29 | uri, 30 | target, 31 | project=project, 32 | rev=rev, 33 | target_project=target_project, 34 | ) 35 | -------------------------------------------------------------------------------- /mlem/cli/config.py: -------------------------------------------------------------------------------- 1 | import posixpath 2 | from typing import Optional 3 | 4 | from pydantic import parse_obj_as 5 | from typer import Argument, Option, Typer 6 | from yaml import safe_dump, safe_load 7 | 8 | from mlem.cli.main import app, mlem_command, mlem_group, option_project 9 | from mlem.config import get_config_cls 10 | from mlem.constants import MLEM_CONFIG_FILE_NAME 11 | from mlem.core.base import SmartSplitDict, get_recursively, smart_split 12 | from mlem.core.errors import MlemError 13 | from mlem.core.meta_io import get_fs, get_uri 14 | from mlem.ui import EMOJI_OK, echo 15 | from mlem.utils.root import find_project_root 16 | 17 | config = Typer(name="config", cls=mlem_group("common")) 18 | app.add_typer(config) 19 | 20 | 21 | @config.callback() 22 | def config_callback(): 23 | """Manipulate MLEM configuration.""" 24 | 25 | 26 | @mlem_command("set", parent=config) 27 | def config_set( 28 | name: str = Argument(..., help="Dotted name of option"), 29 | value: str = Argument(..., help="New value"), 30 | project: Optional[str] = option_project, 31 | validate: bool = Option( 32 | True, help="Whether to validate config schema after" 33 | ), 34 | ): 35 | """Set configuration value 36 | 37 | Documentation: 38 | """ 39 | fs, path = get_fs(project or "") 40 | project = find_project_root(path, fs=fs) 41 | try: 42 | section, name = name.split(".", maxsplit=1) 43 | except ValueError as e: 44 | raise MlemError("[name] should contain at least one dot") from e 45 | config_file_path = posixpath.join(project, MLEM_CONFIG_FILE_NAME) 46 | with fs.open(config_file_path) as f: 47 | new_conf = safe_load(f) or 
{} 48 | 49 | conf = SmartSplitDict(new_conf.get(section, {})) 50 | conf[name] = value 51 | new_conf[section] = conf.build() 52 | if validate: 53 | config_cls = get_config_cls(section) 54 | config_cls(**new_conf[section]) 55 | with fs.open(config_file_path, "w", encoding="utf8") as f: 56 | safe_dump( 57 | new_conf, 58 | f, 59 | ) 60 | echo( 61 | EMOJI_OK 62 | + f"Set `{name}` to `{value}` in project {get_uri(fs, path, True)}" 63 | ) 64 | 65 | 66 | @mlem_command("get", parent=config) 67 | def config_get( 68 | name: str = Argument(..., help="Dotted name of option"), 69 | project: Optional[str] = option_project, 70 | ): 71 | """Get configuration value 72 | 73 | Documentation: 74 | """ 75 | fs, path = get_fs(project or "") 76 | project = find_project_root(path, fs=fs) 77 | section, name = name.split(".", maxsplit=1) 78 | config_cls = get_config_cls(section) 79 | with fs.open(posixpath.join(project, MLEM_CONFIG_FILE_NAME)) as f: 80 | try: 81 | payload = safe_load(f) or {} 82 | config_obj = parse_obj_as(config_cls, payload.get(section, {})) 83 | config_dict = config_obj.dict(skip_defaults=False) 84 | echo( 85 | get_recursively( 86 | config_dict, smart_split(name, "."), ignore_case=True 87 | ) 88 | ) 89 | except KeyError as e: 90 | raise MlemError(f"No such option `{name}`") from e 91 | -------------------------------------------------------------------------------- /mlem/cli/dev.py: -------------------------------------------------------------------------------- 1 | from typer import Argument, Typer 2 | 3 | from mlem.cli.main import app, mlem_command, mlem_group 4 | from mlem.ui import echo 5 | from mlem.utils.entrypoints import ( 6 | MLEM_ENTRY_POINT, 7 | find_abc_implementations, 8 | load_entrypoints, 9 | ) 10 | 11 | dev = Typer(name="dev", cls=mlem_group("common"), hidden=True) 12 | app.add_typer(dev) 13 | 14 | 15 | @dev.callback() 16 | def dev_callback(): 17 | """Developer utility tools 18 | 19 | Documentation: 20 | """ 21 | 22 | 23 | @mlem_command(parent=dev, 
aliases=["fi"]) 24 | def find_implementations_diff( 25 | root: str = Argument(MLEM_ENTRY_POINT, help="root entry point") 26 | ): 27 | """Loads `root` module or package and finds implementations of MLEM base classes 28 | Shows differences between what was found and what is registered in entrypoints 29 | 30 | Documentation: 31 | """ 32 | exts = {e.entry for e in load_entrypoints().values()} 33 | impls = set(find_abc_implementations(root)[MLEM_ENTRY_POINT]) 34 | extra = exts.difference(impls) 35 | if extra: 36 | echo("Remove implementations:") 37 | echo("\n".join(extra)) 38 | new = impls.difference(exts) 39 | if new: 40 | echo("Add implementations:") 41 | echo("\n".join(new)) 42 | -------------------------------------------------------------------------------- /mlem/cli/import_object.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from typer import Argument, Option 4 | 5 | from mlem.cli.main import ( 6 | mlem_command, 7 | option_project, 8 | option_rev, 9 | option_target_project, 10 | ) 11 | from mlem.core.import_objects import ImportHook 12 | from mlem.telemetry import pass_telemetry_params 13 | from mlem.utils.entrypoints import list_implementations 14 | 15 | 16 | @mlem_command("import", section="object") 17 | def import_object( 18 | uri: str = Argument(..., help="File to import"), 19 | target: str = Argument(..., help="Path to save MLEM object"), 20 | project: Optional[str] = option_project, 21 | rev: Optional[str] = option_rev, 22 | target_project: Optional[str] = option_target_project, 23 | copy: bool = Option( 24 | True, 25 | help="Whether to create a copy of file in target location or just link existing file", 26 | ), 27 | type_: Optional[str] = Option(None, "--type", help=f"Specify how to read file Available types: {list_implementations(ImportHook)}", show_default="auto infer"), # type: ignore 28 | ): 29 | """Create a `.mlem` metafile for a model or data in any file or directory.""" 
30 | from mlem.api.commands import import_object 31 | 32 | with pass_telemetry_params(): 33 | import_object( 34 | uri, 35 | project=project, 36 | rev=rev, 37 | target=target, 38 | target_project=target_project, 39 | copy_data=copy, 40 | type_=type_, 41 | ) 42 | -------------------------------------------------------------------------------- /mlem/cli/info.py: -------------------------------------------------------------------------------- 1 | from json import dumps 2 | from pprint import pprint 3 | from typing import List, Optional, Type 4 | 5 | from typer import Argument, Option 6 | 7 | from mlem.cli.main import mlem_command, option_json, option_project, option_rev 8 | from mlem.core.metadata import load_meta 9 | from mlem.core.objects import MLEM_EXT, MlemLink, MlemObject 10 | from mlem.telemetry import pass_telemetry_params 11 | from mlem.ui import echo, set_echo 12 | 13 | OBJECT_TYPE_NAMES = {"data": "Data"} 14 | 15 | 16 | def _print_objects_of_type(cls: Type[MlemObject], objects: List[MlemObject]): 17 | if len(objects) == 0: 18 | return 19 | 20 | echo( 21 | OBJECT_TYPE_NAMES.get( 22 | cls.object_type, cls.object_type.capitalize() + "s" 23 | ) 24 | + ":" 25 | ) 26 | for meta in objects: 27 | if ( 28 | isinstance(meta, MlemLink) 29 | and meta.name != meta.path[: -len(MLEM_EXT)] 30 | ): 31 | link = f"-> {meta.path[:-len(MLEM_EXT)]}" 32 | else: 33 | link = "" 34 | echo("", "-", meta.name, *[link] if link else []) 35 | 36 | 37 | @mlem_command("pprint", hidden=True) 38 | def pretty_print( 39 | path: str = Argument(..., help="Path to object"), 40 | project: Optional[str] = option_project, 41 | rev: Optional[str] = option_rev, 42 | follow_links: bool = Option( 43 | False, 44 | "-f", 45 | "--follow-links", 46 | help="If specified, follow the link to the actual object.", 47 | ), 48 | json: bool = option_json, 49 | ): 50 | """Display all details about a specific MLEM Object from an existing MLEM 51 | project. 
@mlem_command("init", section="common")
def init(
    path: str = Argument(
        ".",
        help="Where to init project",
        show_default=False,
        metavar=PATH_METAVAR,
    )
):
    """Initialize a MLEM project."""
    # Deferred import — presumably keeps CLI startup light and avoids an
    # import cycle with mlem.api; shadows this command's own name below.
    from mlem.api.commands import init

    # Propagate CLI telemetry context into the API call.
    with pass_telemetry_params():
        init(path)
@mlem_command("migrate", section="object")
def migrate(
    path: str = Argument(
        ...,
        help="URI of the MLEM object you are migrating or directory to migrate",
    ),
    project: Optional[str] = option_project,
    recursive: bool = Option(
        False, "--recursive", "-r", help="Enable recursive search of directory"
    ),
):
    """Migrate metadata objects from older MLEM version"""
    # Deferred import, aliased so it does not shadow this command's name.
    from mlem.api.migrations import migrate as api_migrate

    # Run the actual migration with CLI telemetry context attached.
    with pass_telemetry_params():
        api_migrate(path, project, recursive=recursive)
@mlem_group_callback(serve, required=["model", "load"])
def serve_load(
    model: str = make_not_required(option_model),
    project: Optional[str] = option_project,
    rev: Optional[str] = option_rev,
    load: Optional[str] = option_load("server"),
):
    # Group callback for `mlem serve --load <file>`: the server
    # configuration comes entirely from the `load` file rather than from
    # a named server subcommand.
    from mlem.api.commands import serve

    with pass_telemetry_params():
        serve(
            load_meta(model, project, rev, force_type=MlemModel),
            # Build the Server purely from the loaded file; no inline
            # conf / file_conf overrides are supported on this path.
            config_arg(
                Server,
                load,
                None,
                conf=None,
                file_conf=None,
            ),
        )
# Directory (relative to project root) where deployment state is kept.
MLEM_STATE_DIR = ".mlem.state"
# File extension for individual deployment state files.
MLEM_STATE_EXT = ".state"

# Canonical model method names used throughout serving/apply code.
PREDICT_METHOD_NAME = "predict"
PREDICT_PROBA_METHOD_NAME = "predict_proba"
# Name of the single data argument passed to model methods.
PREDICT_ARG_NAME = "data"
TRANSFORM_METHOD_NAME = "transform"

# Name of the per-project MLEM configuration file.
MLEM_CONFIG_FILE_NAME = ".mlem.yaml"
def build_model_image(
    model: MlemModel,
    name: str,
    server: Server = None,  # NOTE(review): effectively Optional[Server]
    daemon: DockerDaemon = None,  # None -> local DockerDaemon()
    registry: DockerRegistry = None,  # None -> plain (local) DockerRegistry()
    tag: str = "latest",
    repository: str = None,
    force_overwrite: bool = True,
    push: bool = True,
    **build_args
) -> DockerImage:
    """Build (and by default push) a docker image serving `model`.

    Any extra keyword arguments are forwarded to DockerBuildArgs
    (e.g. platform, base image options).
    Returns the DockerImage descriptor produced by the builder.
    """
    registry = registry or DockerRegistry()
    daemon = daemon or DockerDaemon()
    image = DockerImageOptions(
        name=name, tag=tag, repository=repository, registry=registry
    )
    builder = DockerImageBuilder(
        server=server,
        args=DockerBuildArgs(**build_args),
        image=image,
        daemon=daemon,
        force_overwrite=force_overwrite,
        push=push,
    )
    return builder.build(model)
/mlem/contrib/docker/install_req.j2: -------------------------------------------------------------------------------- 1 | # install Git in case something in requirements.txt will be installed from Git repo 2 | RUN {{ package_install_cmd }} git {{ package_clean_cmd }} 3 | {% if packages %}RUN {{ package_install_cmd }} {{ packages|join(" ") }} {{ package_clean_cmd }}{% endif %} 4 | COPY requirements.txt . 5 | RUN pip install -r requirements.txt && pip cache purge 6 | {{ mlem_install }} 7 | -------------------------------------------------------------------------------- /mlem/contrib/flyio/__init__.py: -------------------------------------------------------------------------------- 1 | """fly.io Deployments support 2 | Extension type: deployment 3 | 4 | Implements MlemEnv, MlemDeployment and DeployState to work with fly.io 5 | """ 6 | -------------------------------------------------------------------------------- /mlem/contrib/flyio/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | import subprocess 4 | from typing import Any, Dict 5 | 6 | from pydantic import BaseModel, parse_obj_as 7 | 8 | from mlem.core.errors import DeploymentError 9 | 10 | FLY_TOML = "fly.toml" 11 | 12 | 13 | def check_flyctl_exec(): 14 | try: 15 | run_flyctl("version", wrap_error=False) 16 | except subprocess.SubprocessError as e: 17 | raise DeploymentError( 18 | "flyctl executable is not available. 
def run_flyctl(
    command: str,
    workdir: str = None,
    kwargs: Dict[str, Any] = None,
    wrap_error=True,
):
    """Run `flyctl <command>` and return its stdout as bytes.

    Args:
        command: flyctl subcommand, possibly with spaces ("scale show").
        workdir: directory to run in (where fly.toml lives), if any.
        kwargs: rendered as command-line options; a value of True becomes
            a bare flag `--key`, anything else becomes `--key <value>`.
        wrap_error: if True, wrap subprocess failures in DeploymentError.

    Raises:
        DeploymentError: on subprocess failure when wrap_error is True;
            otherwise the original subprocess.SubprocessError propagates.
    """
    kwargs = kwargs or {}
    cmd = ["flyctl"] + command.split()
    # Build argv directly instead of joining and re-splitting a string:
    # the previous implementation broke option values containing spaces
    # into multiple tokens.
    for key, value in kwargs.items():
        cmd.append(f"--{key}")
        if value is not True:
            cmd.append(str(value))
    try:
        return subprocess.check_output(cmd, cwd=workdir)
    except subprocess.SubprocessError as e:
        if wrap_error:
            raise DeploymentError(e) from e
        raise
class LocalGitResolver(UriResolver):
    """Resolve git repositories on local fs"""

    type: ClassVar = "local_git"
    versioning_support: ClassVar = True

    @classmethod
    def check(
        cls,
        path: str,
        project: Optional[str],
        rev: Optional[str],
        fs: Optional[AbstractFileSystem],
    ) -> bool:
        """Return True if this resolver should handle the given location."""
        if isinstance(fs, GitFileSystem):
            return True
        # Without a revision there is nothing git-specific to resolve;
        # let plain filesystem resolvers handle it.
        if rev is None:
            return False
        return cls._find_local_git(path) is not None

    @classmethod
    def get_fs(
        cls, uri: str, rev: Optional[str]
    ) -> Tuple[AbstractFileSystem, str]:
        """Build a GitFileSystem rooted at the repo containing `uri`.

        Returns the filesystem and the path relative to the repo root.
        """
        git_dir = cls._find_local_git(uri)
        fs, _, (path,) = get_fs_token_paths(
            os.path.relpath(uri, git_dir),
            protocol="git",
            storage_options={"ref": rev, "path": git_dir},
        )
        return fs, path

    @classmethod
    def get_uri(
        cls,
        path: str,
        project: Optional[str],
        rev: Optional[str],
        fs: GitFileSystem,
    ):
        """Render a `git://repo:rev@path` URI for the given location."""
        fullpath = posixpath.join(project or "", path)
        # Fall back to the filesystem's checked-out ref when no explicit
        # revision was requested.
        return f"git://{fs.repo.workdir}:{rev or fs.ref}@{fullpath}"

    @classmethod
    def _find_local_git(cls, path: str) -> Optional[str]:
        """Return the working dir of the repo containing `path`, or None."""
        try:
            return Repo(path, search_parent_directories=True).working_dir
        except (InvalidGitRepositoryError, NoSuchPathError):
            return None
class HerokuRemoteRegistry(RemoteRegistry):
    """Heroku docker registry"""

    type: ClassVar = "heroku"
    api_key: Optional[str] = None
    """HEROKU_API_KEY"""
    host: str = DEFAULT_HEROKU_REGISTRY
    """Registry host"""

    def uri(self, image: str):
        # Heroku's registry does not accept an explicit tag; strip the
        # ":tag" suffix the base implementation appends.
        return super().uri(image).split(":")[0]

    def login(self, client):
        """Authenticate docker `client` against the Heroku registry.

        Raises:
            ValueError: if no API key is available or login fails.
        """
        from .utils import get_api_key

        password = self.api_key or get_api_key()
        if password is None:
            raise ValueError(
                "Cannot login to heroku docker registry: no api key provided"
            )
        try:
            # Heroku expects the literal username "_" with the API key
            # as the password.
            self._login(self.host, client, "_", password)
        except Exception as e:
            # Bug fix: previously raised `ValueError([])`, which discarded
            # the underlying error message entirely.
            raise ValueError(
                f"Cannot login to heroku docker registry: {e}"
            ) from e
class HerokuConfig(MlemConfigBase):
    """MLEM config section for Heroku integration."""

    # Heroku API key; resolved from env var HEROKU_API_KEY or the
    # `heroku` section of the MLEM config file.
    API_KEY: Optional[str] = None

    class Config:
        env_prefix = "heroku_"
        section = "heroku"


# Module-level singleton read by the Heroku deployment code.
HEROKU_CONFIG = HerokuConfig()
def build_k8s_docker(
    meta: MlemModel,
    image_name: str,
    registry: Optional[DockerRegistry],
    daemon: Optional[DockerDaemon],
    server: Server,
    platform: Optional[str] = "linux/amd64",
    # runners usually do not support arm64 images built on Mac M1 devices
    build_arg: Optional[List[str]] = None,
    set_env: Optional[List[str]] = None,
):
    """Build a docker image for deploying `meta` to Kubernetes.

    The image is tagged with the model's meta hash, so every model
    version gets a distinct, content-derived tag.
    Returns the DockerImage produced by `build_model_image`.
    """
    echo(EMOJI_BUILD + f"Creating docker image {image_name}")
    with set_offset(2):
        return build_model_image(
            meta,
            image_name,
            server,
            daemon=daemon,
            registry=registry,
            tag=meta.meta_hash(),  # content-addressed tag per model version
            force_overwrite=True,
            platform=platform,
            build_arg=build_arg or [],
            set_env=set_env or [],
        )
def create_k8s_resources(generator: K8sYamlGenerator):
    """Apply the generated Kubernetes manifests to the cluster.

    Writes the rendered yaml to a temp file, creates the resources, and
    treats "AlreadyExists" failures specially: for an existing deployment
    whose running image differs from the desired one, patch the
    deployment in place; any other failure is re-raised.
    """
    k8s_client = client.ApiClient()
    with tempfile.TemporaryDirectory(prefix="mlem_k8s_yaml_build_") as tempdir:
        filename = os.path.join(tempdir, "resource.yaml")
        generator.write(filename)
        try:
            utils.create_from_yaml(k8s_client, filename, verbose=True)
        except utils.FailToCreateError as e:
            # create_from_yaml aggregates one ApiException per resource.
            failures = e.api_exceptions
            for each_failure in failures:
                error_info = json.loads(each_failure.body)
                # Only "AlreadyExists" is tolerated; anything else aborts.
                if error_info["reason"] != "AlreadyExists":
                    raise e
                if error_info["details"]["kind"] == "deployments":
                    # Inspect the image of the first pod in the namespace
                    # to decide whether a redeploy (patch) is needed.
                    existing_image_uri = (
                        client.CoreV1Api()
                        .list_namespaced_pod(generator.namespace)
                        .items[0]
                        .spec.containers[0]
                        .image
                    )
                    if existing_image_uri != generator.image_uri:
                        api_instance = client.AppsV1Api()
                        # Strategic-merge patch swapping only the
                        # container image for the named container.
                        body = {
                            "spec": {
                                "template": {
                                    "spec": {
                                        "containers": [
                                            {
                                                "name": generator.image_name,
                                                "image": generator.image_uri,
                                            }
                                        ]
                                    }
                                }
                            }
                        }
                        api_instance.patch_namespaced_deployment(
                            generator.image_name,
                            generator.namespace,
                            body,
                            pretty=True,
                        )
class PILImageSerializer(BinarySerializer):
    """Serializes numpy arrays to/from images"""

    type: ClassVar = "pil_numpy"
    support_files: ClassVar = True

    format: str = "jpeg"
    "Image format to use"

    def serialize(self, data_type: NumpyNdarrayType, instance: Any) -> bytes:
        # Delegate to dump() and materialize the buffer as bytes.
        with self.dump(data_type, instance) as stream:
            return stream.getvalue()

    @contextlib.contextmanager
    def dump(
        self, data_type: NumpyNdarrayType, instance: Any
    ) -> Iterator[BytesIO]:
        # Encode the array as an image into an in-memory buffer and yield
        # it rewound to the start, ready for reading.
        stream = BytesIO()
        Image.fromarray(instance).save(stream, format=self.format)
        stream.seek(0)
        yield stream

    def deserialize(
        self, data_type: NumpyNdarrayType, obj: Union[bytes, BinaryIO]
    ) -> Any:
        # Accept raw bytes or any binary file-like object.
        source = BytesIO(obj) if isinstance(obj, bytes) else obj
        return numpy.array(Image.open(source))
import os

from mlem.core.metadata import load_meta

# Load the packaged model stored next to this generated module.
model = load_meta(os.path.join(os.path.dirname(__file__), "model"), load_value=True)

def _create_method(name):
    # Factory producing a module-level proxy for a model method; binding
    # `name` here (not in the loop below) avoids the late-binding closure
    # pitfall when the template expands multiple methods.
    def method(*args, **kwargs):
        return getattr(model, name)(*args, **kwargs)
    method.__name__ = name
    return method

{% for method in methods %}
{{ method }} = _create_method("{{ method }}")
{% endfor %}
class PrometheusFastAPIMiddleware(FastAPIMiddleware):
    """Middleware for FastAPI server that exposes /metrics endpoint to be scraped by Prometheus"""

    type: ClassVar = "prometheus_fastapi"

    metrics: List[str] = []
    """Import strings of additional metric objects to register"""
    instrumentator_cache: Optional[Instrumentator] = None
    """Instrumentator instance to use. If not provided, a new one will be created"""

    class Config:
        arbitrary_types_allowed = True
        # Keep the lazily-built instrumentator out of serialized metadata.
        exclude = {"instrumentator_cache"}

    @property
    def instrumentator(self):
        # Lazily build and cache the instrumentator on first access.
        if self.instrumentator_cache is None:
            self.instrumentator_cache = self.get_instrumentator()
        return self.instrumentator_cache

    def on_app_init(self, app: FastAPI):
        self.instrumentator.instrument(app)

        # /metrics can only be mounted once the app has started.
        @app.on_event("startup")
        async def _startup():
            self.instrumentator.expose(app)

    def on_init(self):
        pass

    def on_request(self, request):
        # Pass-through: this middleware only observes via instrumentation.
        return request

    def on_response(self, request, response):
        return response

    def get_instrumentator(self):
        """Build an Instrumentator with all configured extra metrics."""
        instrumentator = Instrumentator()
        for metric in self._iter_metric_objects():
            # todo: check object type
            instrumentator.add(metric)
        return instrumentator

    def _iter_metric_objects(self):
        # Resolve each configured import string to the metric object.
        for metric in self.metrics:
            # todo: meaningful error on import error
            yield import_string_with_local(metric)

    def get_requirements(self):
        # Include requirements of user-supplied metric objects as well.
        reqs = super().get_requirements()
        for metric in self._iter_metric_objects():
            reqs += get_object_requirements(metric)
        return reqs
class RequirementsBuilder(MlemBuilder):
    """MlemBuilder implementation for building requirements"""

    type: ClassVar = "requirements"

    target: Optional[str] = None
    """Target path for requirements"""
    req_type: str = "installable"
    """Type of requirements, example: unix"""

    @validator("req_type")
    def get_req_type(cls, req_type):  # pylint: disable=no-self-argument
        """Validate that req_type names a registered Requirement implementation."""
        # Hoisted: list_implementations scans entrypoints, so call it once
        # instead of twice (membership test + error message).
        allowed = list_implementations(Requirement)
        if req_type not in allowed:
            raise ValueError(
                f"req_type {req_type} is not valid. Allowed options are: {allowed}"
            )
        return req_type

    def build(self, obj: MlemModel):
        """Print the model's requirements of the configured type to stdout,
        or materialize them to ``self.target`` when a target is set.
        """
        req_type_cls = load_impl_ext(Requirement.abs_name, self.req_type)
        assert issubclass(req_type_cls, Requirement)
        reqs = obj.requirements.of_type(req_type_cls)
        if self.target is None:
            # Plain print (not echo) so the space-separated list can be
            # consumed by shell pipelines, e.g. `pip install $(...)`.
            reqs_representation = [r.get_repr() for r in reqs]
            requirement_string = " ".join(reqs_representation)
            print(requirement_string)
        else:
            echo(EMOJI_PACK + "Materializing requirements...")
            req_type_cls.materialize(reqs, self.target)
            echo(EMOJI_OK + f"Materialized to {self.target}!")
import shutil 3 | import subprocess 4 | 5 | from mlem.ui import echo 6 | 7 | MLEM_TF = "mlem_sagemaker.tf" 8 | 9 | 10 | def _tf_command(tf_dir, command, *flags, **args): 11 | args = " ".join(f"-var='{k}={v}'" for k, v in args.items()) 12 | return " ".join( 13 | [ 14 | "terraform", 15 | f"-chdir={tf_dir}", 16 | command, 17 | *flags, 18 | args, 19 | ] 20 | ) 21 | 22 | 23 | def _tf_get_var(tf_dir, varname): 24 | return ( 25 | subprocess.check_output( 26 | _tf_command(tf_dir, "output", varname), shell=True # nosec: B602 27 | ) 28 | .decode("utf8") 29 | .strip() 30 | .strip('"') 31 | ) 32 | 33 | 34 | def sagemaker_terraform( 35 | user_name: str = "mlem", 36 | role_name: str = "mlem", 37 | region_name: str = "us-east-1", 38 | profile: str = "default", 39 | plan: bool = False, 40 | work_dir: str = ".", 41 | export_secret: str = None, 42 | ): 43 | if not os.path.exists(work_dir): 44 | os.makedirs(work_dir, exist_ok=True) 45 | 46 | shutil.copy( 47 | os.path.join(os.path.dirname(__file__), MLEM_TF), 48 | os.path.join(work_dir, MLEM_TF), 49 | ) 50 | subprocess.check_output( 51 | _tf_command(work_dir, "init"), 52 | shell=True, # nosec: B602 53 | ) 54 | 55 | flags = ["-auto-approve"] if not plan else [] 56 | 57 | echo( 58 | subprocess.check_output( 59 | _tf_command( 60 | work_dir, 61 | "plan" if plan else "apply", 62 | *flags, 63 | role_name=role_name, 64 | user_name=user_name, 65 | region_name=region_name, 66 | profile=profile, 67 | ), 68 | shell=True, # nosec: B602 69 | ) 70 | ) 71 | 72 | if not plan and export_secret: 73 | if os.path.exists(export_secret): 74 | print( 75 | f"Creds already present at {export_secret}, please backup and remove them" 76 | ) 77 | return 78 | key_id = _tf_get_var(work_dir, "access_key_id") 79 | access_secret = _tf_get_var(work_dir, "secret_access_key") 80 | region = _tf_get_var(work_dir, "region_name") 81 | profile = _tf_get_var(work_dir, "aws_user") 82 | print(profile, region) 83 | if export_secret.endswith(".csv"): 84 | secrets = f"""User 
Name,Access key ID,Secret access key 85 | {profile},{key_id},{access_secret}""" 86 | print( 87 | f"Import new profile:\naws configure import --csv file://{export_secret}\naws configure set region {region} --profile {profile}" 88 | ) 89 | else: 90 | secrets = f"""export AWS_ACCESS_KEY_ID={key_id} 91 | export AWS_SECRET_ACCESS_KEY={access_secret} 92 | export AWS_REGION={region} 93 | """ 94 | print(f"Source envs:\nsource {export_secret}") 95 | with open(export_secret, "w", encoding="utf8") as f: 96 | f.write(secrets) 97 | -------------------------------------------------------------------------------- /mlem/contrib/sagemaker/mlem_sagemaker.tf: -------------------------------------------------------------------------------- 1 | variable "profile" { 2 | description = "AWS Profile to use for API calls" 3 | type = string 4 | default = "default" 5 | } 6 | 7 | variable "role_name" { 8 | description = "AWS role name" 9 | type = string 10 | default = "mlem" 11 | } 12 | 13 | variable "user_name" { 14 | description = "AWS user name" 15 | type = string 16 | default = "mlem" 17 | } 18 | 19 | variable "region_name" { 20 | description = "AWS region name" 21 | type = string 22 | default = "us-east-1" 23 | } 24 | 25 | provider "aws" { 26 | region = var.region_name 27 | profile = var.profile 28 | } 29 | 30 | resource "aws_iam_user" "aws_user" { 31 | name = var.user_name 32 | } 33 | 34 | resource "aws_iam_access_key" "aws_user" { 35 | user = aws_iam_user.aws_user.name 36 | } 37 | 38 | resource "aws_iam_user_policy_attachment" "sagemaker_policy" { 39 | user = aws_iam_user.aws_user.name 40 | policy_arn = "arn:aws:iam::aws:policy/AmazonSageMakerFullAccess" 41 | } 42 | 43 | resource "aws_iam_user_policy_attachment" "ecr_policy" { 44 | user = aws_iam_user.aws_user.name 45 | policy_arn = "arn:aws:iam::aws:policy/AmazonEC2ContainerRegistryFullAccess" 46 | } 47 | 48 | resource "aws_iam_role" "aws_role" { 49 | name = var.role_name 50 | description = "MLEM SageMaker Role" 51 | assume_role_policy 
def init_aws_vars(
    profile=None, role=None, bucket=None, region=None, account=None
):
    """Resolve AWS/SageMaker connection settings, filling in defaults.

    Each argument overrides the discovered value; missing values fall back
    to the boto/sagemaker session defaults, the mlem ``aws`` config section,
    or (for account) an STS ``get_caller_identity`` call.

    :return: tuple of the created ``sagemaker.Session`` and an ``AWSVars``
        with the resolved bucket/region/account/role/profile.
    """
    boto_session = boto3.Session(profile_name=profile, region_name=region)
    sess = sagemaker.Session(boto_session, default_bucket=bucket)

    bucket = (
        bucket or sess.default_bucket()
    )  # Replace with your own bucket name if needed
    region = region or boto_session.region_name
    # project="" reads the global (non-project) mlem config
    config = project_config(project="", section=AWSConfig)
    role = role or config.ROLE or sagemaker.get_execution_role(sess)
    # NOTE(review): the STS lookup happens even when role/bucket were given;
    # requires valid credentials — confirm acceptable for offline use
    account = account or boto_session.client("sts").get_caller_identity().get(
        "Account"
    )
    return sess, AWSVars(
        bucket=bucket,
        region=region,
        account=account,
        role_name=role,
        profile=profile or config.PROFILE,
    )
def generate_image_name(deploy_id):
    """Return the docker image name used for the given deployment id."""
    prefix = "mlem-sagemaker-image"
    return f"{prefix}-{deploy_id}"
class TorchImageSerializer(BinarySerializer):
    """Serializes torch tensors to/from images"""

    type: ClassVar = "torch_image"
    # This serializer can also produce a file-like stream (see dump)
    support_files: ClassVar = True

    def serialize(
        self, data_type: TorchTensorDataType, instance: Any
    ) -> bytes:
        """Encode the tensor as a JPEG image and return the raw bytes.

        NOTE: JPEG is lossy, so serialize/deserialize is not an exact
        round-trip of tensor values.
        """
        return _to_buffer(instance).read()

    @contextlib.contextmanager
    def dump(
        self, data_type: TorchTensorDataType, instance: Any
    ) -> Iterator[BinaryIO]:
        """File-based variant of serialize: yield an in-memory JPEG stream."""
        yield _to_buffer(instance)

    def deserialize(
        self, data_type: TorchTensorDataType, obj: Union[bytes, BinaryIO]
    ) -> Any:
        """Decode JPEG bytes (or a binary stream) back into a tensor."""
        if isinstance(obj, bytes):
            buffer = obj
        else:
            buffer = obj.read()
        # Wrap the encoded bytes into a flat uint8 tensor, which is the
        # input format torchvision's decode_image expects
        return decode_image(frombuffer(buffer, dtype=uint8))
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/mlem/9aa89423f6f0f38dba50c3dc04fc7674738f1260/mlem/core/index.py -------------------------------------------------------------------------------- /mlem/log.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loggers used in other parts of MLEM 3 | """ 4 | import logging.config 5 | 6 | from mlem.config import LOCAL_CONFIG 7 | 8 | LOG_LEVEL = LOCAL_CONFIG.LOG_LEVEL 9 | if LOCAL_CONFIG.DEBUG: 10 | LOG_LEVEL = logging.getLevelName(logging.DEBUG) 11 | 12 | logging_config = { 13 | "version": 1, 14 | "disable_existing_loggers": False, 15 | "formatters": { 16 | "standard": { 17 | "format": "%(asctime)s [%(levelname)s] %(name)s: %(message)s" 18 | }, 19 | }, 20 | "handlers": { 21 | "default": { 22 | "level": LOG_LEVEL, 23 | "formatter": "standard", 24 | "class": "logging.StreamHandler", 25 | "stream": "ext://sys.stdout", 26 | }, 27 | }, 28 | "loggers": { 29 | "mlem": { 30 | "handlers": ["default"], 31 | "level": LOG_LEVEL, 32 | } 33 | }, 34 | } 35 | 36 | logging.config.dictConfig(logging_config) 37 | -------------------------------------------------------------------------------- /mlem/polydantic/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Python class that enables polymorphism on top of pydantic BaseModel 3 | """ 4 | from .core import PolyModel 5 | 6 | __all__ = ["PolyModel"] 7 | -------------------------------------------------------------------------------- /mlem/runtime/__init__.py: -------------------------------------------------------------------------------- 1 | from .interface import Interface, InterfaceMethod 2 | 3 | __all__ = ["Interface", "InterfaceMethod"] 4 | -------------------------------------------------------------------------------- /mlem/runtime/middleware.py: 
class Middlewares(BaseModel):
    __root__: List[Middleware] = []
    """Middlewares to add to server"""

    def on_init(self):
        """Run one-time initialization for every configured middleware."""
        for mw in self.__root__:
            mw.on_init()

    def on_request(self, request):
        """Thread *request* through the middlewares in declaration order."""
        for mw in self.__root__:
            request = mw.on_request(request)
        return request

    def on_response(self, request, response):
        """Thread *response* through the middlewares in reverse order,
        mirroring the order of the request pass."""
        for mw in reversed(self.__root__):
            response = mw.on_response(request, response)
        return response

    def get_requirements(self) -> Requirements:
        """Aggregate the requirements of all configured middlewares."""
        total = Requirements.new()
        for mw in self.__root__:
            total += mw.get_requirements()
        return total
def api_telemetry(f):
    """Decorator that reports a telemetry event for an API function call.

    Tracks two pieces of module-level state:
    - ``_is_api_running`` suppresses events for nested API calls, so only
      the outermost call sends an event (and only when not invoked via CLI,
      where the CLI layer reports instead);
    - ``_pass_params`` (set via ``pass_telemetry_params``) forwards the
      event's kwargs to the caller's telemetry via ``log_param``.

    Both globals are restored to their previous values on exit, so the
    decorator is safe for reentrant/nested use.
    """
    @wraps(f)
    def inner(*args, **kwargs):
        global _is_api_running, _pass_params  # pylint: disable=global-statement
        # Save previous state so nested calls restore it in finally below
        is_nested = _is_api_running
        pass_params = _pass_params
        _pass_params = False
        _is_api_running = True
        try:
            from mlem.cli.utils import is_cli

            with telemetry.event_scope("api", f.__name__) as event:
                try:
                    return f(*args, **kwargs)
                except Exception as exc:
                    # Record only the exception class name, then re-raise
                    event.error = exc.__class__.__name__
                    raise
                finally:
                    # Only the outermost, non-CLI call sends the event
                    if not is_nested and not is_cli():
                        telemetry.send_event(
                            event.interface,
                            event.action,
                            event.error,
                            **event.kwargs,
                        )

        finally:
            # NOTE(review): if event_scope (or the is_cli import) raised
            # before `event` was bound, this would hit NameError when
            # pass_params is True — presumably unreachable; confirm.
            if pass_params:
                for key, value in event.kwargs.items():
                    telemetry.log_param(key, value)
            _is_api_running = is_nested
            _pass_params = pass_params

    return inner
def echo(*message):
    """Print *message* through the currently configured echo function.

    A no-op while echoing is disabled (no echo function set). Honors the
    global offset by prepending indentation to the message.
    """
    if _echo_func is None:
        return
    parts = message
    if _offset > 0:
        parts = [" " * (_offset - 1), *message]
    _echo_func(*parts)
    class cached_property:
        """Backport of ``functools.cached_property`` copied from CPython 3.8.

        A descriptor that computes the wrapped function once per instance
        and stores the result in the instance ``__dict__`` under the
        attribute name, so later lookups bypass the descriptor entirely.
        """

        def __init__(self, func):
            self.func = func
            # Filled in by __set_name__ when the owning class is created
            self.attrname = None
            self.__doc__ = func.__doc__
            # Guards the compute-and-store against concurrent first access
            self.lock = RLock()

        def __set_name__(self, owner, name):
            if self.attrname is None:
                self.attrname = name
            elif name != self.attrname:
                raise TypeError(
                    "Cannot assign the same cached_property to two different names "
                    f"({self.attrname!r} and {name!r})."
                )

        def __get__(self, instance, owner=None):
            # Class-level access returns the descriptor itself
            if instance is None:
                return self
            if self.attrname is None:
                raise TypeError(
                    "Cannot use cached_property instance without calling __set_name__ on it."
                )
            try:
                cache = instance.__dict__
            except (
                AttributeError
            ):  # not all objects have __dict__ (e.g. class defines slots)
                msg = (
                    f"No '__dict__' attribute on {type(instance).__name__!r} "
                    f"instance to cache {self.attrname!r} property."
                )
                raise TypeError(msg) from None
            val = cache.get(self.attrname, _NOT_FOUND)
            if val is _NOT_FOUND:
                with self.lock:
                    # check if another thread filled cache while we awaited lock
                    val = cache.get(self.attrname, _NOT_FOUND)
                    if val is _NOT_FOUND:
                        val = self.func(instance)
                        try:
                            cache[self.attrname] = val
                        except TypeError:
                            msg = (
                                f"The '__dict__' attribute on {type(instance).__name__!r} instance "
                                f"does not support item assignment for caching {self.attrname!r} property."
                            )
                            raise TypeError(msg) from None
            return val
def is_long_sha(sha: str):
    """Check whether *sha* is a full 40-character lowercase hex git SHA.

    :return: a truthy ``re.Match`` on success, ``None`` otherwise.
    """
    # fullmatch instead of ^...$ anchors: with `$`, a trailing newline
    # ("<sha>\n") would also be accepted, which is not a valid SHA.
    return re.fullmatch(r"[a-f\d]{40}", sha)
def import_string(path):
    """Import object from dotted path (``package.module.object``).

    :raises ImportError: if the module cannot be imported, the module has
        no such attribute, or *path* contains no dot at all.
    """
    module_name, sep, object_name = path.rpartition(".")
    if not sep:
        # import_module("") would raise a confusing
        # ValueError("Empty module name"); fail with a clear ImportError
        raise ImportError(
            f"Cannot import {path}: expected a dotted path like 'module.object'"
        )
    mod = import_module(module_name)
    try:
        return getattr(mod, object_name)
    except AttributeError as e:
        raise ImportError(
            f"No object {object_name} in module {module_name}"
        ) from e
def make_posix(path: Optional[str]):
    """Normalize a Windows-style path to posix form.

    Replaces every backslash with a forward slash; falsy inputs
    (``None``, ``""``) are returned unchanged.
    """
    return path.replace("\\", "/") if path else path
def find_project_root(
    path: str = ".",
    fs: AbstractFileSystem = None,
    raise_on_missing: bool = True,
    recursive: bool = True,
) -> Optional[str]:
    """Search for mlem project root folder, starting from the given path
    and up the directory tree.

    :param path: where to start the search; a file path is allowed and is
        replaced by its parent directory first.
    :param fs: filesystem to search on; defaults to the local filesystem.
    :param raise_on_missing: raise MlemProjectNotFound instead of
        returning None when no project is found.
    :param recursive: when False, only check *path* itself, without
        walking up the tree.
    :return: path of the project root, or None (if raise_on_missing is
        False and nothing was found).
    """
    if fs is None:
        fs = LocalFileSystem()
    # Absolutize local paths so walking up with dirname terminates at "/"
    if isinstance(fs, LocalFileSystem) and not os.path.isabs(path):
        path = os.path.abspath(path)
    _path = path[:]
    if not recursive:
        if mlem_project_exists(_path, fs):
            return _path
    else:
        # If given a file (or a non-existent path), start from its parent
        if fs.isfile(_path) or not fs.exists(_path):
            _path = os.path.dirname(_path)
        while True:
            if mlem_project_exists(_path, fs):
                return _path
            # dirname of the filesystem root is the root itself —
            # this is the loop termination condition
            if _path == os.path.dirname(_path):
                break

            _path = os.path.dirname(_path)
    if raise_on_missing:
        raise MlemProjectNotFound(path, fs)
    return None
with fs.open(path, "w") as f: 42 | f.write(self.generate(**additional)) 43 | 44 | @classmethod 45 | def from_model(cls: Type[T], obj, templates_dir: List[str] = None) -> T: 46 | args = { 47 | f: getattr(obj, f) for f in cls.__fields__ if f != "templates_dir" 48 | } 49 | return cls(templates_dir=templates_dir or [], **args) 50 | -------------------------------------------------------------------------------- /mlem/version.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._mlem_version import version as __version__ 3 | from ._mlem_version import version_tuple 4 | except ImportError: 5 | try: 6 | from setuptools_scm import get_version 7 | 8 | __version__ = get_version(root="..", relative_to=__file__) 9 | except (LookupError, ImportError): 10 | __version__ = "UNKNOWN" 11 | version_tuple = () # type: ignore 12 | 13 | __all__ = ["__version__", "version_tuple"] 14 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 79 3 | include = '\.pyi?$' 4 | exclude = ''' 5 | /( 6 | \.eggs 7 | | \.git 8 | | \.hg 9 | | \.mypy_cache 10 | | \.tox 11 | | \.venv 12 | | _build 13 | | buck-out 14 | | build 15 | | dist 16 | )/ 17 | ''' 18 | 19 | [build-system] 20 | requires = ["setuptools>=48", "setuptools_scm[toml]>=6.3.1", "setuptools_scm_git_archive==1.1"] 21 | build-backend = "setuptools.build_meta" 22 | 23 | [tool.setuptools_scm] 24 | write_to = "mlem/_mlem_version.py" 25 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": [ 3 | "config:base" 4 | ] 5 | } 6 | -------------------------------------------------------------------------------- /setup.cfg: 
-------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # too long lines, for now. TODO: https://github.com/iterative/mlem/issues/3 4 | E501, 5 | # Whitespace before ':' 6 | E203, 7 | # Too many leading '#' for block comment 8 | E266, 9 | # Line break occurred before a binary operator 10 | W503, 11 | # Do not perform function calls in argument defaults: conflicts with typer 12 | B008, 13 | # unindexed parameters in the str.format, see: 14 | P1, 15 | # Invalid first argument 'cls' used for instance method. 16 | B902, 17 | # ABCs without methods 18 | B024, 19 | # Use f"{obj!r}" instead of f"'{obj}'" 20 | B028, 21 | # https://pypi.org/project/flake8-string-format/ 22 | max_line_length = 79 23 | max-complexity = 15 24 | select = B,C,E,F,W,T4,B902,T,P 25 | show_source = true 26 | count = true 27 | 28 | [isort] 29 | profile = black 30 | known_first_party = mlem,tests 31 | line_length = 79 32 | 33 | [tool:pytest] 34 | log_level = debug 35 | markers = 36 | long: Marks long-running tests 37 | docker: Marks tests that needs Docker 38 | kubernetes: Marks tests that needs Kubernetes 39 | conda: Marks tests that need conda 40 | testpaths = 41 | tests 42 | addopts = -rav --durations=0 --cov=mlem --cov-report=term-missing --cov-report=xml 43 | 44 | [mypy] 45 | # Error output 46 | show_column_numbers = True 47 | show_error_codes = True 48 | show_error_context = True 49 | show_traceback = True 50 | pretty = True 51 | exclude = mlem/deploy/* 52 | disable_error_code = misc, type-abstract, annotation-unchecked 53 | # TODO: enable no_implicit_optional with 54 | # https://github.com/hauntsaninja/no_implicit_optional 55 | no_implicit_optional = False 56 | check_untyped_defs = False 57 | # plugins = pydantic.mypy 58 | 59 | # See https://mypy.readthedocs.io/en/latest/running_mypy.html#missing-imports. 
60 | ignore_missing_imports = True 61 | 62 | # Warnings 63 | warn_no_return = True 64 | warn_redundant_casts = True 65 | warn_unreachable = True 66 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/mlem/9aa89423f6f0f38dba50c3dc04fc7674738f1260/tests/__init__.py -------------------------------------------------------------------------------- /tests/__main__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | if __name__ == "__main__": 6 | sys.exit(pytest.main(["-v", *sys.argv[1:]])) 7 | -------------------------------------------------------------------------------- /tests/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/mlem/9aa89423f6f0f38dba50c3dc04fc7674738f1260/tests/api/__init__.py -------------------------------------------------------------------------------- /tests/api/test_migrations.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from pydantic import ValidationError 3 | from yaml import safe_dump 4 | 5 | from mlem.api.migrations import migrate 6 | from mlem.core.metadata import load_meta 7 | from mlem.core.objects import MlemModel, MlemObject 8 | 9 | model_02 = ( 10 | { 11 | "object_type": "model", 12 | "description": "machine learning should be mlemming", 13 | "labels": ["mlemming", "it", "should", "be"], 14 | "artifacts": {}, 15 | "model_type": {"type": "sklearn", "methods": {"lol": {}}}, 16 | }, 17 | MlemModel( 18 | artifacts={}, 19 | call_orders={"lol": [("model", "lol")]}, 20 | processors={"model": {"type": "sklearn", "methods": {"lol": {}}}}, 21 | ), 22 | ) 23 | 24 | 25 | model_03 = ( 26 | { 27 | "object_type": "model", 28 | "artifacts": {}, 29 | 
"model_type": {"type": "sklearn", "methods": {"lol": {}}}, 30 | }, 31 | MlemModel( 32 | artifacts={}, 33 | call_orders={"lol": [("model", "lol")]}, 34 | processors={"model": {"type": "sklearn", "methods": {"lol": {}}}}, 35 | ), 36 | ) 37 | 38 | 39 | @pytest.mark.parametrize("old_data", [model_02, model_03]) 40 | def test_single(tmpdir, old_data): 41 | path = tmpdir / "model.mlem" 42 | old_payload, new_object = old_data 43 | path.write_text(safe_dump(old_payload), encoding="utf8") 44 | 45 | migrate(str(path)) 46 | 47 | meta = load_meta(path, try_migrations=False) 48 | 49 | assert isinstance(meta, MlemObject) 50 | assert meta == new_object 51 | 52 | 53 | @pytest.mark.parametrize("old_data,new_data", [model_02, model_03]) 54 | @pytest.mark.parametrize("recursive", [True, False]) 55 | def test_directory(tmpdir, old_data, new_data, recursive): 56 | subdir_path = tmpdir / "subdir" / "model.mlem" 57 | (tmpdir / "subdir").mkdir() 58 | subdir_path.write_text(safe_dump(old_data), encoding="utf8") 59 | for i in range(3): 60 | path = tmpdir / f"model{i}.mlem" 61 | path.write_text(safe_dump(old_data), encoding="utf8") 62 | 63 | migrate(str(tmpdir), recursive=recursive) 64 | 65 | for i in range(3): 66 | path = tmpdir / f"model{i}.mlem" 67 | meta = load_meta(path, try_migrations=False) 68 | assert isinstance(meta, MlemObject) 69 | assert meta == new_data 70 | 71 | if recursive: 72 | meta = load_meta(subdir_path, try_migrations=False) 73 | assert isinstance(meta, MlemObject) 74 | assert meta == new_data 75 | else: 76 | try: 77 | assert load_meta(subdir_path, try_migrations=False) != new_data 78 | except ValidationError: 79 | pass 80 | 81 | 82 | @pytest.mark.parametrize("old_data,new_data", [model_02, model_03]) 83 | def test_load_with_migration(tmpdir, old_data, new_data): 84 | path = tmpdir / "model.mlem" 85 | path.write_text(safe_dump(old_data), encoding="utf8") 86 | 87 | meta = load_meta(path, try_migrations=True) 88 | 89 | assert isinstance(meta, MlemObject) 90 | assert meta 
== new_data 91 | -------------------------------------------------------------------------------- /tests/api/test_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import PosixPath 2 | 3 | import pytest 4 | 5 | from mlem.api.utils import get_model_meta 6 | from mlem.core.objects import MlemModel, ModelAnalyzer 7 | 8 | 9 | def funclen(x): 10 | return len(x) 11 | 12 | 13 | @pytest.fixture 14 | def one_custom_processor_model(tmp_path: PosixPath): 15 | model = MlemModel() 16 | model.add_processor( 17 | "textlen", ModelAnalyzer.analyze(funclen, sample_data="word") 18 | ) 19 | model.call_orders["textlen"] = [("textlen", "__call__")] 20 | path = str(tmp_path / "model") 21 | model.dump(path) 22 | return path 23 | 24 | 25 | @pytest.fixture 26 | def two_custom_processors_model(tmp_path: PosixPath): 27 | model = MlemModel() 28 | model.add_processor( 29 | "textlen", ModelAnalyzer.analyze(funclen, sample_data="word") 30 | ) 31 | model.call_orders["textlen"] = [("textlen", "__call__")] 32 | model.add_processor( 33 | "textlen2", ModelAnalyzer.analyze(funclen, sample_data="word") 34 | ) 35 | model.call_orders["textlen2"] = [("textlen2", "__call__")] 36 | path = str(tmp_path / "model") 37 | model.dump(path) 38 | return path 39 | 40 | 41 | def test_get_model_meta_one_processor(one_custom_processor_model): 42 | model = get_model_meta(one_custom_processor_model, load_value=True) 43 | assert model.textlen("tenletters") == 10 44 | 45 | 46 | def test_get_model_meta_two_processors(two_custom_processors_model): 47 | model = get_model_meta(two_custom_processors_model, load_value=True) 48 | assert model.textlen("tenletters") == 10 49 | -------------------------------------------------------------------------------- /tests/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/mlem/9aa89423f6f0f38dba50c3dc04fc7674738f1260/tests/cli/__init__.py 
-------------------------------------------------------------------------------- /tests/cli/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from click.testing import Result 3 | from typer.testing import CliRunner 4 | 5 | from mlem import LOCAL_CONFIG 6 | from mlem.cli import app 7 | 8 | app.pretty_exceptions_short = False 9 | 10 | 11 | class Runner: 12 | def __init__(self): 13 | self._runner = CliRunner(mix_stderr=False) 14 | 15 | def invoke(self, *args, raise_on_error: bool = False, **kwargs) -> Result: 16 | result = self._runner.invoke(app, *args, **kwargs) 17 | if raise_on_error and result.exit_code != 0: 18 | if result.exit_code == 1: 19 | raise result.exception 20 | raise RuntimeError(result.stderr) 21 | return result 22 | 23 | 24 | @pytest.fixture 25 | def runner() -> Runner: 26 | return Runner() 27 | 28 | 29 | @pytest.fixture 30 | def no_debug(): 31 | tmp = LOCAL_CONFIG.DEBUG 32 | try: 33 | LOCAL_CONFIG.DEBUG = False 34 | yield 35 | finally: 36 | LOCAL_CONFIG.DEBUG = tmp 37 | -------------------------------------------------------------------------------- /tests/cli/test_build.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path 3 | from typing import ClassVar 4 | 5 | from pydantic import parse_obj_as 6 | from yaml import safe_dump 7 | 8 | from mlem.cli.build import create_build_command 9 | from mlem.contrib.fastapi import FastAPIServer 10 | from mlem.core.objects import MlemBuilder, MlemModel 11 | from mlem.runtime.server import Server 12 | from mlem.utils.path import make_posix 13 | from tests.cli.conftest import Runner 14 | 15 | 16 | class BuilderMock(MlemBuilder): 17 | """mock""" 18 | 19 | type: ClassVar = "mock" 20 | target: str 21 | """target""" 22 | server: Server 23 | """server""" 24 | 25 | def build(self, obj: MlemModel): 26 | with open(self.target, "w", encoding="utf8") as f: 27 | f.write(obj.loc.path + "\n") 28 | 
json.dump(self.server.dict(), f) 29 | 30 | 31 | create_build_command(BuilderMock.type) 32 | 33 | 34 | def test_build(runner: Runner, model_meta_saved_single, tmp_path): 35 | path = os.path.join(tmp_path, "packed") 36 | result = runner.invoke( 37 | f"build mock -m {make_posix(model_meta_saved_single.loc.uri)} --target {make_posix(path)} --server fastapi --server.port 1000" 38 | ) 39 | 40 | assert result.exit_code == 0, ( 41 | result.stdout, 42 | result.stderr, 43 | result.exception, 44 | ) 45 | 46 | with open(path, encoding="utf8") as f: 47 | lines = f.read().splitlines() 48 | assert len(lines) == 2 49 | path, serv = lines 50 | assert path == model_meta_saved_single.loc.path 51 | assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( 52 | port=1000 53 | ) 54 | 55 | 56 | def test_build_with_file_conf( 57 | runner: Runner, model_meta_saved_single, tmp_path 58 | ): 59 | path = os.path.join(tmp_path, "packed") 60 | server_path = os.path.join(tmp_path, "server.yaml") 61 | with open(server_path, "w", encoding="utf8") as f: 62 | safe_dump(FastAPIServer(port=9999).dict(), f) 63 | 64 | result = runner.invoke( 65 | f"build mock -m {make_posix(model_meta_saved_single.loc.uri)} --target {make_posix(path)} --file_conf server={make_posix(server_path)}" 66 | ) 67 | 68 | assert result.exit_code == 0, (result.exception, result.output) 69 | 70 | with open(path, encoding="utf8") as f: 71 | lines = f.read().splitlines() 72 | assert len(lines) == 2 73 | path, serv = lines 74 | assert path == model_meta_saved_single.loc.path 75 | assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( 76 | port=9999 77 | ) 78 | 79 | 80 | def test_build_with_load(runner: Runner, model_meta_saved_single, tmp_path): 81 | path = os.path.join(tmp_path, "packed") 82 | load_path = os.path.join(tmp_path, "builder.yaml") 83 | builder = BuilderMock( 84 | server=FastAPIServer(port=9999), target=make_posix(path) 85 | ) 86 | with open(load_path, "w", encoding="utf8") as f: 87 | 
safe_dump(builder.dict(), f) 88 | 89 | result = runner.invoke( 90 | f"build -m {make_posix(model_meta_saved_single.loc.uri)} --load {make_posix(load_path)}" 91 | ) 92 | 93 | assert result.exit_code == 0, (result.exception, result.output) 94 | 95 | with open(path, encoding="utf8") as f: 96 | lines = f.read().splitlines() 97 | assert len(lines) == 2 98 | path, serv = lines 99 | assert path == model_meta_saved_single.loc.path 100 | assert parse_obj_as(Server, json.loads(serv)) == FastAPIServer( 101 | port=9999 102 | ) 103 | -------------------------------------------------------------------------------- /tests/cli/test_checkenv.py: -------------------------------------------------------------------------------- 1 | from mlem.core.metadata import load_meta 2 | from mlem.core.objects import MlemModel 3 | 4 | 5 | def test_checkenv(runner, model_path_mlem_project): 6 | model_path, _ = model_path_mlem_project 7 | result = runner.invoke( 8 | ["checkenv", model_path], 9 | ) 10 | assert result.exit_code == 0, ( 11 | result.stdout, 12 | result.stderr, 13 | result.exception, 14 | ) 15 | 16 | meta = load_meta(model_path, load_value=False, force_type=MlemModel) 17 | meta.requirements.__root__[0].version = "asdad" 18 | meta.update() 19 | 20 | result = runner.invoke( 21 | ["checkenv", model_path], 22 | ) 23 | assert result.exit_code == 1, ( 24 | result.stdout, 25 | result.stderr, 26 | result.exception, 27 | ) 28 | -------------------------------------------------------------------------------- /tests/cli/test_clone.py: -------------------------------------------------------------------------------- 1 | import posixpath 2 | import tempfile 3 | 4 | from mlem.core.metadata import load_meta 5 | from tests.cli.conftest import Runner 6 | 7 | 8 | def test_model_cloning(runner: Runner, model_path): 9 | with tempfile.TemporaryDirectory() as path: 10 | path = posixpath.join(path, "cloned") 11 | result = runner.invoke(["clone", model_path, path]) 12 | assert result.exit_code == 0, ( 13 | 
result.stdout, 14 | result.stderr, 15 | result.exception, 16 | ) 17 | load_meta(path, load_value=False) 18 | -------------------------------------------------------------------------------- /tests/cli/test_import_path.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import pytest 4 | 5 | from mlem.core.metadata import load 6 | 7 | 8 | @pytest.fixture 9 | def write_model_pickle(model): 10 | def write(path): 11 | with open(path, "wb") as f: 12 | pickle.dump(model, f) 13 | 14 | return write 15 | 16 | 17 | @pytest.mark.parametrize("file_ext, type_", [(".pkl", None), ("", "pickle")]) 18 | def test_import_model_pickle_copy( 19 | runner, write_model_pickle, train, tmpdir, file_ext, type_ 20 | ): 21 | path = str(tmpdir / "mymodel" + file_ext) 22 | write_model_pickle(path) 23 | 24 | out_path = str(tmpdir / "mlem_model") 25 | 26 | result = runner.invoke( 27 | ["import", path, out_path, "--type", type_, "--copy"], 28 | ) 29 | assert result.exit_code == 0, ( 30 | result.stdout, 31 | result.stderr, 32 | result.exception, 33 | ) 34 | 35 | loaded = load(out_path) 36 | loaded.predict(train) 37 | -------------------------------------------------------------------------------- /tests/cli/test_info.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from pydantic import parse_obj_as 5 | 6 | from mlem.core.meta_io import MLEM_EXT 7 | from mlem.core.objects import MlemModel, MlemObject 8 | from tests.conftest import MLEM_TEST_REPO, long 9 | 10 | 11 | def test_pretty_print(runner, model_path_mlem_project): 12 | model_path, _ = model_path_mlem_project 13 | result = runner.invoke( 14 | ["pprint", model_path + MLEM_EXT], 15 | ) 16 | assert result.exit_code == 0, ( 17 | result.stdout, 18 | result.stderr, 19 | result.exception, 20 | ) 21 | 22 | result = runner.invoke( 23 | ["pprint", model_path + MLEM_EXT, "--json"], 24 | ) 25 | assert result.exit_code == 0, ( 
26 | result.stdout, 27 | result.stderr, 28 | result.exception, 29 | ) 30 | meta = parse_obj_as(MlemObject, json.loads(result.stdout)) 31 | assert isinstance(meta, MlemModel) 32 | 33 | 34 | @long 35 | def test_pretty_print_remote(runner, current_test_branch): 36 | model_path = os.path.join( 37 | MLEM_TEST_REPO, "tree", current_test_branch, "simple/data/model" 38 | ) 39 | result = runner.invoke( 40 | ["pprint", model_path + MLEM_EXT], 41 | ) 42 | assert result.exit_code == 0, ( 43 | result.stdout, 44 | result.stderr, 45 | result.exception, 46 | ) 47 | -------------------------------------------------------------------------------- /tests/cli/test_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from mlem.constants import MLEM_CONFIG_FILE_NAME 4 | from mlem.utils.path import make_posix 5 | from tests.cli.conftest import Runner 6 | 7 | 8 | def test_init(runner: Runner, tmpdir): 9 | result = runner.invoke(f"init {make_posix(str(tmpdir))}") 10 | assert result.exit_code == 0, result.exception 11 | assert os.path.isfile(tmpdir / MLEM_CONFIG_FILE_NAME) 12 | -------------------------------------------------------------------------------- /tests/cli/test_link.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | from mlem.api import load_meta 5 | from mlem.core.meta_io import MLEM_EXT 6 | from mlem.core.objects import MlemLink, MlemModel 7 | 8 | 9 | def test_link(runner, model_path): 10 | with tempfile.TemporaryDirectory() as dir: 11 | link_path = os.path.join(dir, "latest.mlem") 12 | result = runner.invoke( 13 | ["link", model_path, link_path, "--abs"], 14 | ) 15 | assert result.exit_code == 0, ( 16 | result.stdout, 17 | result.stderr, 18 | result.exception, 19 | ) 20 | assert os.path.exists(link_path) 21 | model = load_meta(link_path) 22 | assert isinstance(model, MlemModel) 23 | 24 | 25 | def test_link_mlem_project(runner, 
model_path_mlem_project): 26 | model_path, project = model_path_mlem_project 27 | link_name = "latest.mlem" 28 | result = runner.invoke( 29 | ["link", model_path, link_name, "--target-project", project], 30 | ) 31 | assert result.exit_code == 0, ( 32 | result.stdout, 33 | result.stderr, 34 | result.exception, 35 | ) 36 | link_path = os.path.join(project, link_name) 37 | assert os.path.exists(link_path) 38 | link_object = load_meta(link_path, follow_links=False) 39 | assert isinstance(link_object, MlemLink) 40 | assert link_object.path[: -len(MLEM_EXT)] == os.path.basename(model_path) 41 | model = load_meta(link_path) 42 | assert isinstance(model, MlemModel) 43 | -------------------------------------------------------------------------------- /tests/cli/test_main.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | from click import Context, Group 4 | from typer.main import get_command_from_info, get_group, get_group_from_info 5 | 6 | from mlem.cli import app 7 | from tests.cli.conftest import Runner 8 | from tests.conftest import long 9 | 10 | 11 | def iter_group(group: Group, prefix=()): 12 | prefix = prefix + (group.name,) 13 | yield prefix, group 14 | for name, c in group.commands.items(): 15 | if isinstance(c, Group): 16 | yield from iter_group(c, prefix) 17 | else: 18 | yield prefix + (name,), c 19 | 20 | 21 | @pytest.fixture 22 | def app_cli_cmd(): 23 | commands = [ 24 | get_command_from_info( 25 | c, pretty_exceptions_short=False, rich_markup_mode="rich" 26 | ) 27 | for c in app.registered_commands 28 | ] 29 | groups = [ 30 | get_group_from_info( 31 | g, pretty_exceptions_short=False, rich_markup_mode="rich" 32 | ) 33 | for g in app.registered_groups 34 | ] 35 | return [(c.name, c) for c in commands] + [ 36 | (" ".join(names), cmd) for g in groups for names, cmd in iter_group(g) 37 | ] 38 | 39 | 40 | @long 41 | def test_commands_help(app_cli_cmd): 42 | no_help = [] 43 | no_link = [] 44 
| link_broken = [] 45 | group = get_group(app) 46 | ctx = Context(group, info_name="mlem", help_option_names=["-h", "--help"]) 47 | 48 | with ctx: 49 | for name, cli_cmd in app_cli_cmd: 50 | if cli_cmd.help is None: 51 | no_help.append(name) 52 | elif "Documentation: <" not in cli_cmd.help: 53 | no_link.append(name) 54 | else: 55 | link = cli_cmd.help.split("Documentation: <")[1].split(">")[0] 56 | response = requests.head(link, timeout=5) 57 | try: 58 | response.raise_for_status() 59 | except requests.HTTPError: 60 | link_broken.append(name) 61 | 62 | assert len(no_help) == 0, f"{no_help} cli commands do not have help!" 63 | assert ( 64 | len(no_link) == 0 65 | ), f"{no_link} cli commands do not have documentation link!" 66 | assert ( 67 | len(link_broken) == 0 68 | ), f"{link_broken} cli commands have broken documentation links!" 69 | 70 | 71 | def test_commands_args_help(app_cli_cmd): 72 | no_help = [] 73 | for name, cmd in app_cli_cmd: 74 | dynamic_metavar = getattr(cmd, "dynamic_metavar", None) 75 | for arg in cmd.params: 76 | if arg.name == dynamic_metavar: 77 | continue 78 | if arg.help is None: 79 | no_help.append(f"{name}:{arg.name}") 80 | assert len(no_help) == 0, f"{no_help} cli commands args do not have help!" 
81 | 82 | 83 | @pytest.mark.parametrize("cmd", ["--help", "-h"]) 84 | def test_help(runner: Runner, cmd): 85 | result = runner.invoke(cmd) 86 | assert result.exit_code == 0, ( 87 | result.stdout, 88 | result.stderr, 89 | result.exception, 90 | ) 91 | 92 | 93 | def test_cli_commands_help(runner: Runner, app_cli_cmd): 94 | for name, _ in app_cli_cmd: 95 | runner.invoke(name + " --help", raise_on_error=True) 96 | 97 | 98 | def test_version(runner: Runner): 99 | from mlem import __version__ 100 | 101 | result = runner.invoke("--version") 102 | assert result.exit_code == 0, ( 103 | result.stdout, 104 | result.stderr, 105 | result.exception, 106 | ) 107 | assert __version__ in result.stdout 108 | -------------------------------------------------------------------------------- /tests/cli/test_serve.py: -------------------------------------------------------------------------------- 1 | from typing import ClassVar 2 | 3 | from mlem.cli.serve import create_serve_command 4 | from mlem.runtime import Interface 5 | from mlem.runtime.server import Server 6 | from mlem.ui import echo 7 | from tests.cli.conftest import Runner 8 | 9 | 10 | class MockServer(Server): 11 | """mock""" 12 | 13 | type: ClassVar = "mock" 14 | param: str = "wrong" 15 | """param""" 16 | 17 | def serve(self, interface: Interface): 18 | echo(self.param) 19 | 20 | 21 | create_serve_command(MockServer.type) 22 | 23 | 24 | def test_serve(runner: Runner, model_single_path): 25 | result = runner.invoke(f"serve mock -m {model_single_path} --param aaa") 26 | assert result.exit_code == 0, ( 27 | result.stdout, 28 | result.stderr, 29 | result.exception, 30 | ) 31 | assert result.stdout.splitlines()[-1] == "aaa" 32 | -------------------------------------------------------------------------------- /tests/cli/test_stderr.py: -------------------------------------------------------------------------------- 1 | from io import StringIO 2 | from unittest import mock 3 | 4 | import pytest 5 | 6 | from mlem.core.errors import 
MlemError 7 | from mlem.ui import echo, stderr_echo 8 | 9 | EXCEPTION_MESSAGE = "Test Exception Message" 10 | 11 | 12 | @pytest.mark.usefixtures("no_debug") 13 | def test_stderr_exception(runner): 14 | # patch the ls command and ensure it throws an expection. 15 | with mock.patch( 16 | "mlem.api.commands.init", side_effect=Exception(EXCEPTION_MESSAGE) 17 | ): 18 | result = runner.invoke( 19 | ["init"], 20 | ) 21 | assert result.exit_code == 1, ( 22 | result.stdout, 23 | result.stderr, 24 | result.exception, 25 | ) 26 | assert len(result.stderr) > 0, "Output is empty, but should not be" 27 | assert EXCEPTION_MESSAGE in result.stderr 28 | 29 | 30 | MLEM_ERROR_MESSAGE = "Test Mlem Error Message" 31 | 32 | 33 | @pytest.mark.usefixtures("no_debug") 34 | def test_stderr_mlem_error(runner): 35 | # patch the ls command and ensure it throws a mlem error. 36 | with mock.patch( 37 | "mlem.api.commands.init", side_effect=MlemError(MLEM_ERROR_MESSAGE) 38 | ): 39 | result = runner.invoke( 40 | ["init"], 41 | ) 42 | assert result.exit_code == 1, ( 43 | result.stdout, 44 | result.stderr, 45 | result.exception, 46 | ) 47 | assert len(result.stderr) > 0, "Output is empty, but should not be" 48 | assert MLEM_ERROR_MESSAGE in result.stderr 49 | 50 | 51 | STDERR_MESSAGE = "Test Stderr Message" 52 | 53 | 54 | def test_stderr_echo(): 55 | with mock.patch("sys.stderr", new_callable=StringIO) as mock_stderr: 56 | with stderr_echo(): 57 | echo(STDERR_MESSAGE) 58 | mock_stderr.seek(0) 59 | output = mock_stderr.read() 60 | assert len(output) > 0, "Output is empty, but should not be" 61 | assert STDERR_MESSAGE in output 62 | -------------------------------------------------------------------------------- /tests/cli/test_types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | from mlem.cli.types import iterate_type_fields 7 | from mlem.cli.utils import 
get_attribute_docstrings, get_field_help 8 | from mlem.core.base import MlemABC, load_impl_ext 9 | from mlem.utils.entrypoints import list_implementations 10 | from tests.cli.conftest import Runner 11 | 12 | 13 | def test_types(runner: Runner): 14 | result = runner.invoke("types") 15 | assert result.exit_code == 0, (result.exception, result.output) 16 | assert all(typename in result.output for typename in MlemABC.abs_types) 17 | 18 | 19 | @pytest.mark.parametrize("abs_name", MlemABC.abs_types.keys()) 20 | def test_types_abs_name(runner: Runner, abs_name): 21 | result = runner.invoke(f"types {abs_name}") 22 | assert result.exit_code == 0, result.exception 23 | assert set(result.output.splitlines()) == set( 24 | list_implementations(abs_name, include_hidden=False) 25 | ) 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "abs_name,subtype", 30 | [ 31 | (abs_name, subtype) 32 | for abs_name, root_type in MlemABC.abs_types.items() 33 | for subtype in list_implementations(root_type, include_hidden=False) 34 | if not subtype.startswith("tests.") and "mock" not in subtype 35 | ], 36 | ) 37 | def test_types_abs_name_subtype(runner: Runner, abs_name, subtype): 38 | result = runner.invoke(f"types {abs_name} {subtype}") 39 | assert result.exit_code == 0, result.exception 40 | if not subtype.startswith("tests."): 41 | assert "docstring missing" not in result.output 42 | 43 | 44 | @pytest.mark.parametrize( 45 | "abs_name,subtype", 46 | [ 47 | (abs_name, subtype) 48 | for abs_name, root_type in MlemABC.abs_types.items() 49 | for subtype in list_implementations(root_type, include_hidden=False) 50 | if not subtype.startswith("tests.") and "mock" not in subtype 51 | ], 52 | ) 53 | def test_fields_capitalized(abs_name, subtype): 54 | impl = load_impl_ext(abs_name, subtype) 55 | ad = get_attribute_docstrings(impl) 56 | allowed_lowercase = ["md5"] 57 | capitalized = { 58 | k: v[0] == v[0].capitalize() 59 | if all(not v.startswith(prefix) for prefix in allowed_lowercase) 60 | else True 61 | 
for k, v in ad.items() 62 | } 63 | assert capitalized == {k: True for k in ad} 64 | 65 | 66 | def test_iter_type_fields_subclass(): 67 | class Parent(BaseModel): 68 | parent: str 69 | """parent""" 70 | 71 | class Child(Parent): 72 | child: str 73 | """child""" 74 | excluded: Optional[str] = None 75 | 76 | class Config: 77 | fields = {"excluded": {"exclude": True}} 78 | 79 | fields = list(iterate_type_fields(Child)) 80 | 81 | assert len(fields) == 2 82 | assert {get_field_help(Child, f.path) for f in fields} == { 83 | "parent", 84 | "child", 85 | } 86 | 87 | 88 | def test_iter_type_fields_subclass_multiinheritance(): 89 | class Parent(BaseModel): 90 | parent: str 91 | """parent""" 92 | 93 | class Parent2(BaseModel): 94 | parent2 = "" 95 | """parent2""" 96 | 97 | class Child(Parent, Parent2): 98 | child: str 99 | """child""" 100 | 101 | fields = list(iterate_type_fields(Child)) 102 | 103 | assert len(fields) == 3 104 | assert {get_field_help(Child, f.path) for f in fields} == { 105 | "parent", 106 | "child", 107 | "parent2", 108 | } 109 | -------------------------------------------------------------------------------- /tests/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iterative/mlem/9aa89423f6f0f38dba50c3dc04fc7674738f1260/tests/contrib/__init__.py -------------------------------------------------------------------------------- /tests/contrib/conftest.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | import pytest 4 | 5 | from mlem.contrib.docker.context import use_mlem_source 6 | from tests.conftest import long 7 | 8 | 9 | @pytest.fixture() 10 | def uses_docker_build(): 11 | with use_mlem_source("whl"): 12 | yield 13 | 14 | 15 | def has_conda(): 16 | try: 17 | ret = subprocess.run(["conda"], check=True) 18 | return ret.returncode == 0 19 | except FileNotFoundError: 20 | return False 21 | 22 | 23 | def 
@pytest.fixture()
def fs_no_auth():
    """Yield a BitBucketFileSystem with ``BITBUCKET_USERNAME`` unset.

    Restores the variable afterwards even when it was set to an empty
    string: the previous ``if username:`` truthiness check silently lost
    empty-string values, leaking fixture state into later tests.
    """
    username = os.environ.get("BITBUCKET_USERNAME", None)
    try:
        os.environ.pop("BITBUCKET_USERNAME", None)
        yield BitBucketFileSystem(MLEM_TEST_REPO_PROJECT)
    finally:
        if username is not None:
            os.environ["BITBUCKET_USERNAME"] = username
@long
@pytest.mark.parametrize(
    "uri",
    [
        MLEM_TEST_REPO_URI + "/src/main/path",
        f"bitbucket://{MLEM_TEST_REPO_PROJECT}@main/path",
    ],
)
def test_uri_resolver(uri):
    """Both URL styles resolve to a BitBucketFileSystem plus bare path."""
    filesystem, resolved_path = get_fs(uri)

    assert isinstance(filesystem, BitBucketFileSystem)
    assert resolved_path == "path"
@pytest.fixture
def catboost_params(tmpdir):
    """Minimal CatBoost params: one iteration, artifacts under tmpdir."""
    params = {"iterations": 1}
    params["train_dir"] = str(tmpdir)
    return params
@pytest.fixture(scope="session")
def dockerenv_local(tmp_path_factory):
    """Session-scoped local DockerEnv dumped into a fresh temp directory."""
    target = tmp_path_factory.mktemp("dockerenv_local")
    return DockerEnv().dump(str(target))
def has_docker():
    """Whether docker-based tests should run in this environment.

    Skipped when ``SKIP_DOCKER_TESTS`` is "true", or when the CI matrix
    entry is present but is not the designated ubuntu-latest / python 3.8
    combination; otherwise the answer is whether a docker daemon responds.
    """
    if os.environ.get("SKIP_DOCKER_TESTS", None) == "true":
        return False
    matrix_os = os.environ.get("GITHUB_MATRIX_OS")
    matrix_python = os.environ.get("GITHUB_MATRIX_PYTHON")
    if matrix_os is not None and matrix_os != "ubuntu-latest":
        return False
    if matrix_python is not None and matrix_python != "3.8":
        return False
    return is_docker_running()
@long
@pytest.mark.parametrize(
    "modelmeta", [lazy_fixture("model_meta"), lazy_fixture("model_meta_saved")]
)
def test_build_dir(tmpdir, modelmeta):
    """Building a docker dir produces every expected file in the target."""
    built = build(
        DockerDirBuilder(server=FastAPIServer(), target=str(tmpdir)),
        modelmeta,
    )
    assert isinstance(built, DockerModelDirectory)
    expected_files = (
        "run.sh",
        "Dockerfile",
        "requirements.txt",
        "model",
        "model.mlem",
    )
    for filename in expected_files:
        assert os.path.isfile(tmpdir / filename)
@docker_test
def test_pack_image_with_processors(
    processors_model, dockerenv_local, uses_docker_build
):
    """An image built from a model with processors serves predictions."""
    image = build(
        DockerImageBuilder(
            server=FastAPIServer(),
            image=DockerImage(name="pack_docker_test_image_proc"),
            force_overwrite=True,
        ),
        processors_model,
    )
    assert isinstance(image, DockerImage)
    assert dockerenv_local.image_exists(image)

    container = (
        TestContainer(image.name)
        .with_env("DOCKER_TLS_CERTDIR", "")
        .with_exposed_ports(SERVER_PORT)
    )
    with container as service:
        # give the server inside the container time to come up
        time.sleep(10)
        port = service.get_exposed_port(SERVER_PORT)
        response = requests.post(
            f"http://localhost:{port}/predict",
            json={"data": ["1", "2", "3"]},
        )
        assert response.status_code == 200
        assert response.json() == 4
@docker_test
def test_image_not_exists():
    """None of these repository:tag combinations exist on Docker Hub."""
    missing = [
        "python:this_does_not_exist",
        "mlem:this_does_not_exist",
        "minio:this_does_not_exist",
        "registry:this_does_not_exist",
        "this_does_not_exist:latest",
    ]
    for image in missing:
        assert not image_exists_at_dockerhub(image)
def test_flyio_create_app(tmp_path: Path):
    """deploy() on a fresh FlyioApp launches the app once via flyctl."""
    flyio_app = FlyioApp(org="org", app_name="test")
    flyio_app.dump(str(tmp_path))
    state = flyio_app.get_state()
    status = FlyioStatusModel(
        Name="test", Hostname="fly.io", Status="Deployed"
    )

    # stub out all flyctl interaction and the docker build step
    with patch("mlem.contrib.flyio.meta.run_flyctl") as run_flyctl, patch(
        "mlem.contrib.flyio.meta.read_fly_toml"
    ) as read_fly_toml, patch(
        "mlem.contrib.flyio.meta.get_status"
    ) as get_status, patch(
        "mlem.contrib.flyio.meta.FlyioApp._build_in_dir"
    ):
        get_status.return_value = status
        read_fly_toml.return_value = ""
        flyio_app.deploy(state)

    run_flyctl.assert_called_once_with(
        "launch",
        workdir=ANY,
        kwargs={
            "auto-confirm": True,
            "reuse-app": True,
            "region": "lax",
            "no-deploy": True,
            "name": "test",
            "org": "org",
        },
    )
def test_get_github_kwargs__empty_path(set_mock_refs):
    """A tree URL with a trailing slash and no path yields an empty 'path'."""
    set_mock_refs("ref")
    kwargs = GithubResolver.get_kwargs(
        "https://github.com/org/repo/tree/ref/"
    )
    assert kwargs == {
        "org": "org",
        "repo": "repo",
        "path": "",
        "sha": "ref",
    }
@long
@pytest.mark.parametrize(
    "rev",
    ["main", "branch", "tag", "3897d2ab"],
)
def test_uri_resolver_rev(rev):
    """Resolving with an explicit rev roots a GitlabFileSystem at that rev."""
    location = Location.resolve(MLEM_TEST_REPO_URI, None, rev=rev, fs=None)
    fs = location.fs
    assert isinstance(fs, GitlabFileSystem)
    assert fs.root == rev
    assert "README.md" in fs.ls("")
def has_k8s():
    """Whether kubernetes tests should run in this environment.

    Skipped when ``SKIP_K8S_TESTS`` is "true", or when the CI matrix entry
    is present but is not the designated ubuntu-latest / python 3.9
    combination; otherwise the answer is whether minikube is reachable.
    """
    if os.environ.get("SKIP_K8S_TESTS", None) == "true":
        return False
    matrix_os = os.environ.get("GITHUB_MATRIX_OS")
    matrix_python = os.environ.get("GITHUB_MATRIX_PYTHON")
    if matrix_os is not None and matrix_os != "ubuntu-latest":
        return False
    if matrix_python is not None and matrix_python != "3.9":
        return False
    return is_minikube_running()
mark = pytest.mark.kubernetes 43 | skip = pytest.mark.skipif( 44 | not has_k8s(), reason="kubernetes is unavailable or skipped" 45 | ) 46 | return long(mark(skip(f))) 47 | -------------------------------------------------------------------------------- /tests/contrib/test_kubernetes/test_context.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from mlem.contrib.kubernetes.context import ( 4 | ImagePullPolicy, 5 | K8sYamlBuildArgs, 6 | K8sYamlGenerator, 7 | ) 8 | from mlem.contrib.kubernetes.service import LoadBalancerService 9 | from tests.conftest import _cut_empty_lines 10 | 11 | 12 | @pytest.fixture 13 | def k8s_default_manifest(): 14 | return _cut_empty_lines( 15 | """apiVersion: v1 16 | kind: Namespace 17 | metadata: 18 | name: mlem 19 | labels: 20 | name: mlem 21 | 22 | --- 23 | 24 | apiVersion: apps/v1 25 | kind: Deployment 26 | metadata: 27 | name: ml 28 | namespace: mlem 29 | spec: 30 | selector: 31 | matchLabels: 32 | app: ml 33 | template: 34 | metadata: 35 | labels: 36 | app: ml 37 | spec: 38 | containers: 39 | - name: ml 40 | image: ml:latest 41 | imagePullPolicy: Always 42 | ports: 43 | - containerPort: 8080 44 | 45 | --- 46 | 47 | apiVersion: v1 48 | kind: Service 49 | metadata: 50 | name: ml 51 | namespace: mlem 52 | labels: 53 | run: ml 54 | spec: 55 | ports: 56 | - port: 8080 57 | protocol: TCP 58 | targetPort: 8080 59 | selector: 60 | app: ml 61 | type: NodePort 62 | """ 63 | ) 64 | 65 | 66 | @pytest.fixture 67 | def k8s_manifest(): 68 | return _cut_empty_lines( 69 | """apiVersion: v1 70 | kind: Namespace 71 | metadata: 72 | name: hello 73 | labels: 74 | name: hello 75 | 76 | --- 77 | 78 | apiVersion: apps/v1 79 | kind: Deployment 80 | metadata: 81 | name: test 82 | namespace: hello 83 | spec: 84 | selector: 85 | matchLabels: 86 | app: test 87 | template: 88 | metadata: 89 | labels: 90 | app: test 91 | spec: 92 | containers: 93 | - name: test 94 | image: test:latest 95 | 
def _generate_k8s_manifest(**kwargs):
    """Render a K8sYamlGenerator manifest and strip empty lines for compare."""
    rendered = K8sYamlGenerator(**kwargs).generate()
    return _cut_empty_lines(rendered)
def test_pil_serializer(np_image):
    """Round-trip an image array through PILImageSerializer byte-for-byte."""
    data_type = DataType.create(np_image)

    assert isinstance(data_type, NumpyNdarrayType)

    payload = PILImageSerializer().serialize(data_type, np_image)
    assert isinstance(payload, bytes)
    image_again = PILImageSerializer().deserialize(data_type, payload)

    # np.equal returns an element-wise boolean array; asserting on it
    # directly raises "truth value of an array is ambiguous" for any
    # multi-pixel image. Reduce with .all() to get a scalar check.
    assert np.equal(image_again, np_image).all()
@long
def test_pip_package(tmpdir, model_meta_saved_single):
    """Build a pip package from a model, install it, and call predict()."""
    path = str(tmpdir)
    builder = PipBuilder(target=path, package_name=PIP_PACKAGE_NAME)
    builder.build(model_meta_saved_single)

    install_output = subprocess.check_output(
        "pip install -e . --no-deps", shell=True, cwd=path
    )
    print(install_output.decode("utf8"))
    try:
        subprocess.check_output(
            f'python -c "import {PIP_PACKAGE_NAME}; '
            f'print({PIP_PACKAGE_NAME}.predict([[1,2,3,4]]))"',
            shell=True,
        )
    finally:
        # always clean the package out of the environment again
        uninstall_output = subprocess.check_output(
            f"pip uninstall {PIP_PACKAGE_NAME} -y", shell=True
        )
        print(uninstall_output)
def test_prometheus_fastapi_middleware(create_mlem_client, create_client):
    """FastAPI server with the Prometheus middleware still serves predictions.

    Builds an identity model, wraps it in a FastAPIServer configured with
    PrometheusFastAPIMiddleware, and verifies the OpenAPI docs endpoint and a
    /predict round-trip through the mlem client interface.
    """
    # imported lazily so the module can be collected without prometheus extras
    from mlem.contrib.prometheus import PrometheusFastAPIMiddleware

    model = MlemModel.from_obj(lambda x: x, sample_data=10)
    model_interface = ModelInterface.from_model(model)

    server = FastAPIServer(
        standardize=True,
        middlewares=Middlewares(__root__=[PrometheusFastAPIMiddleware()]),
    )
    interface = ServerInterface.create(server, model_interface)
    client = create_client(server, interface)

    # the app must still expose its schema with the middleware installed
    docs = client.get("/openapi.json")
    assert docs.status_code == 200, docs.json()

    # metrics = client.get("/metrics")
    # assert metrics.status_code == 200, metrics
    # assert metrics.text

    mlem_client: Client = create_mlem_client(client)
    remote_interface = mlem_client.interface
    # serialize via the advertised argument data type, then deserialize the
    # response via the advertised return data type: identity model => 1 -> 1
    dt = remote_interface.__root__["predict"].args[0].data_type
    response = client.post("/predict", json={"data": dt.serialize(1)})
    assert response.status_code == 200
    resp = remote_interface.__root__["predict"].returns.data_type.deserialize(
        response.json()
    )
    assert resp == 1
class ServeThread(Thread):
    """Run ``mlem.api.serve`` in a background thread.

    ``dead`` is False only while the serve call is running, so callers can
    poll it to see whether startup succeeded.
    """

    def __init__(self, model, server):
        super().__init__()
        self.model = model
        self.server = server
        self.dead = True

    def run(self) -> None:
        self.dead = False
        try:
            serve(self.model, self.server)
        finally:
            self.dead = True


@pytest.fixture
def rmq_server(model_meta_saved_single, rmq_instance):
    """Start a RabbitMQ-backed model server against the test container,
    retrying the thread start up to 10 times before giving up."""
    server = RabbitMQServer(
        host=rmq_instance.get_container_host_ip(),
        port=int(rmq_instance.get_exposed_port(RMQ_PORT)),
        queue_prefix="aaa",
    )
    for _ in range(10):
        worker = ServeThread(model_meta_saved_single, server)
        worker.start()
        time.sleep(0.5)
        if not worker.dead:
            break
        worker.join()
    else:
        raise RuntimeError("could not start rmq serving")

    yield server


@long
@docker_test
def test_serving(rmq_server):
    """The client should eventually connect and get an ndarray back."""
    last_error = None
    for _ in range(20):
        try:
            client = RabbitMQClient(
                host=rmq_server.host,
                port=rmq_server.port,
                queue_prefix=rmq_server.queue_prefix,
            )
            prediction = client.predict(np.array([[1.0, 1.0, 1.0, 1.0]]))
            assert isinstance(prediction, np.ndarray)
            break
        except AMQPError as exc:
            # broker may not be ready yet; back off and retry
            time.sleep(0.5)
            last_error = exc
    else:
        if last_error is not None:
            raise last_error
        pytest.fail("could not connect to server")
def test_build_reqs(tmp_path, model_meta):
    """Requirements written to a target file must equal the model's pip reqs."""
    reqs_file = str(tmp_path / "reqs.txt")
    RequirementsBuilder(target=reqs_file).build(model_meta)
    with open(reqs_file, "r", encoding="utf-8") as f:
        assert model_meta.requirements.to_pip() == f.read().splitlines()


def test_build_reqs_with_invalid_req_type():
    """An unknown req_type must be rejected at construction time."""
    with pytest.raises(
        ValidationError, match="req_type invalid is not valid."
    ):
        RequirementsBuilder(req_type="invalid")


def test_build_requirements_should_print_with_no_path(capsys, model_meta):
    """Without a target the builder prints space-separated pip requirements."""
    RequirementsBuilder().build(model_meta)
    captured = capsys.readouterr()
    assert captured.out == " ".join(model_meta.requirements.to_pip()) + "\n"


def test_unix_requirement(capsys):
    """req_type="unix" prints the newline-separated unix requirements."""
    np_payload = np.linspace(0, 2, 5).reshape((-1, 1))
    dataset = lgb.Dataset(
        np_payload,
        label=np_payload.reshape((-1,)).tolist(),
        free_raw_data=False,
    )
    booster = lgb.train({}, dataset, 1)
    model = MlemModel.from_obj(booster, sample_data=dataset)
    RequirementsBuilder(req_type="unix").build(model)
    captured = capsys.readouterr()
    assert str(captured.out).endswith(
        "\n".join(model.requirements.to_unix()) + "\n"
    )
@pytest.fixture
def raw_data():
    # COO-style triplet: values plus (row, col) coordinates
    rows = np.array([0, 0, 1, 2, 2, 2])
    cols = np.array([0, 2, 2, 0, 1, 2])
    values = np.array([1, 2, 3, 4, 5, 6])
    return values, (rows, cols)


@pytest.fixture
def sparse_mat(raw_data):
    return csr_matrix(raw_data, shape=(3, 3), dtype="float32")


@pytest.fixture
def schema():
    # expected pydantic JSON schema of a 3x3 sparse matrix payload
    return {
        "title": "ScipySparse",
        "type": "array",
        "items": {
            "type": "array",
            "items": {"type": "number"},
            "minItems": 3,
            "maxItems": 3,
        },
    }


@pytest.fixture
def sparse_data_type(sparse_mat):
    return DataAnalyzer.analyze(sparse_mat)


def test_sparce_matrix(sparse_mat, schema):
    """Analyzer must recognize a csr_matrix and record its dtype and shape."""
    assert ScipySparseMatrix.is_object_valid(sparse_mat)
    analyzed = DataAnalyzer.analyze(sparse_mat)
    assert analyzed.dict() == {
        "dtype": "float32",
        "type": "csr_matrix",
        "shape": (3, 3),
    }
    payload_model = analyzed.get_model()
    assert payload_model.__name__ == "ScipySparse"
    assert payload_model.schema() == schema
    assert isinstance(analyzed, ScipySparseMatrix)
    assert analyzed.dtype == "float32"
    assert analyzed.get_requirements().modules == ["scipy"]


def test_serialization(raw_data, sparse_mat):
    """serialize -> deserialize must round-trip the matrix contents."""
    analyzed = DataAnalyzer.analyze(sparse_mat)
    restored = analyzed.deserialize(analyzed.serialize(sparse_mat))
    assert np.array_equal(sparse_mat.todense(), restored.todense())


def test_write_read(sparse_mat):
    """Write/read through artifact storage preserves the dense contents."""
    analyzed = DataAnalyzer.analyze(sparse_mat).bind(sparse_mat)
    data_write_read_check(
        analyzed,
        custom_eq=lambda x, y: np.array_equal(x.todense(), y.todense()),
    )
def test_augment_model():
    """augment_model should unwrap a model with a single list field into its
    element type and provide an augmenter that wraps values back."""

    class M1(BaseModel):
        field: str

    # non-list model: returned unchanged, augmenter is identity
    wrap, inner = augment_model(M1)
    assert inner == M1
    assert wrap(1) == 1

    class M2(BaseModel):
        field: List[str]

    # single list-of-scalars field: element type exposed, value wrapped back
    wrap, inner = augment_model(M2)
    assert inner == str
    assert wrap("1") == M2(field=["1"])

    class M3(BaseModel):
        field: List[M1]

    # single list-of-models field behaves the same way
    wrap, inner = augment_model(M3)
    assert inner == M1
    assert wrap(M1(field="1")) == M3(field=[M1(field="1")])

    class M4(BaseModel):
        field: List[str]
        field2: List[str]

    # ambiguous: two list fields cannot be augmented
    wrap, inner = augment_model(M4)
    assert inner is None
def process_conda_list_output(installed_pkgs):
    """Parse ``conda list`` output into a ``{package_name: channel}`` dict.

    The first three lines of the output are header lines and are skipped.
    Each remaining line is whitespace-separated columns; rows with fewer
    than four columns (no channel) are ignored.

    :param installed_pkgs: decoded stdout of ``conda list``
    :return: mapping of package name (column 1) to channel (column 4)
    """
    result = {}
    for line in installed_pkgs.splitlines()[3:]:
        # str.split() on whitespace is equivalent to re.findall(r"[^\s]+")
        columns = line.split()
        if len(columns) >= 4:
            result[columns[0]] = columns[3]
    return result
@long
def test_build_venv(tmp_path, model_meta):
    """Every pip requirement of the model must be installed in the venv."""
    venv_path = str(tmp_path / "venv")
    builder = VenvBuilder(target=venv_path)
    env_dir = builder.build(model_meta)
    installed = set(
        builder.get_installed_packages(env_dir).decode().splitlines()
    )
    assert set(model_meta.requirements.to_pip()).issubset(installed)


def test_install_in_current_venv_not_active(tmp_path, model_meta):
    """current_env=True must fail when no virtual environment is active."""
    builder = VenvBuilder(target=str(tmp_path / "venv"), current_env=True)
    with pytest.raises(MlemError, match="No virtual environment detected"):
        builder.build(model_meta)


@long
def test_install_in_current_active_venv(sys_prefix_path, model_meta):
    """Building into the currently active venv installs the requirements."""
    builder = VenvBuilder(target=sys_prefix_path)
    env_dir = os.path.abspath(sys_prefix_path)
    builder.create_virtual_env()
    # freshly created env starts empty
    assert builder.get_installed_packages(env_dir).decode() == ""
    # mark this env as the active one, then install into it in-place
    os.environ["VIRTUAL_ENV"] = env_dir
    builder.current_env = True
    builder.build(model_meta)
    installed = (
        builder.get_installed_packages(env_dir).decode().splitlines()
    )
    for req in model_meta.requirements.to_pip():
        assert req in installed
from pack_2 import name


class TestM:
    """Minimal model class for the custom-requirements tests.

    ``name`` is imported from pack_2 so that analyzing this class pulls
    pack_2 in as a custom requirement.
    """

    name = name

    # NOTE(review): `_init_` (single underscores) is not `__init__`, so it is
    # never invoked by `TestM()` — which other tests rely on. Presumably
    # deliberate; confirm before "fixing" the name.
    def _init_(self, alpha: float, max_lag: int):
        self.alpha = alpha  # pylint: disable=attribute-defined-outside-init
        self.max_lag = (  # pylint: disable=attribute-defined-outside-init
            max_lag
        )
class TestModelType(ModelType, ModelHook):
    """Test-only ModelType hook that claims ``pack_1.model.TestM`` objects."""

    @classmethod
    def process(
        cls,
        obj: Any,
        sample_data: Optional[Any] = None,
        methods_sample_data: Optional[typing.Dict[str, Any]] = None,
        **kwargs
    ) -> ModelType:
        # Ignores the object and sample data entirely: plain pickle IO and
        # no analyzed methods are enough for the requirements tests.
        return TestModelType(io=SimplePickleIO(), methods={})

    @classmethod
    def is_object_valid(cls, obj: Any) -> bool:
        # This hook applies only to the TestM class from pack_1.
        return isinstance(obj, model.TestM)

    def m1(self) -> typing.Dict[str, str]:
        # Dummy method; always returns an empty mapping.
        return {}

    def m2(self, data: np.array):  # pylint: disable=unused-argument
        # Dummy method; accepts data and returns None.
        return
def subpkg_func():
    """Do nothing and return None.

    The enclosing module's ``isort`` import is the real payload: it lets the
    requirements analyzer tests verify that deep sub-package imports are
    collected.
    """
    return None
def translate(text: str):
    """
    Translate: replace each whitespace-separated token of *text* with a
    random letter from a-g (one RNG draw per token).
    """
    alphabet = list("abcdefg")
    return " ".join(np.random.choice(alphabet) for _ in text.split())
def test_requirements_analyzer__model_works(tmpdir):
    """Materialized custom requirements plus the pickled model must be enough
    to run the model from a clean working directory in a subprocess."""
    from proxy_model import model

    reqs = get_object_requirements(model)
    reqs.materialize_custom(tmpdir)
    # package data files must come along with the code
    assert os.path.exists(
        os.path.join(tmpdir, "pkg", "subpkg", "testfile.json")
    )

    with open(os.path.join(tmpdir, "model.pkl"), "wb") as f:
        dill.dump(model, f)

    shutil.copy(
        os.path.join(os.path.dirname(__file__), "use_model.py"), tmpdir
    )

    run = subprocess.run(
        "python use_model.py", shell=True, cwd=tmpdir, check=False
    )
    assert run.returncode == 0


def test_model_custom_requirements(tmpdir):
    """Model metadata should carry pack_1/pack_2 as custom requirements and
    numpy as the only installable one, and be loadable in isolation."""
    from pack_1.model_type import (  # pylint: disable=unused-import # noqa
        TestModelType,
    )

    meta = get_object_metadata(TestM(), 1)
    assert isinstance(meta, MlemModel)

    meta.dump(os.path.join(tmpdir, "model"))
    meta.requirements.materialize_custom(tmpdir)
    shutil.copy(
        os.path.join(os.path.dirname(__file__), "use_model_meta.py"), tmpdir
    )

    run = subprocess.run(
        "python use_model_meta.py", shell=True, cwd=tmpdir, check=False
    )
    assert run.returncode == 0, run.stderr

    assert {x.name for x in meta.requirements.custom} == {"pack_1", "pack_2"}
    assert {x.module for x in meta.requirements.installable} == {"numpy"}
@pytest.fixture
def script_code():
    """Source text of shell_reqs.py, read from next to this test file."""
    script = os.path.join(os.path.dirname(__file__), "shell_reqs.py")
    with open(script, encoding="utf8") as f:
        return f.read()


exec_param = pytest.mark.parametrize("executable", ["python", "ipython"])


@exec_param
@pytest.mark.xfail
def test_cmd(tmpdir, script_code, executable):
    """Requirements must be collected when the script runs via `-c`."""
    rc = subprocess.check_call([executable, "-c", script_code], cwd=tmpdir)
    assert rc == 0
    save("a", os.path.join(tmpdir, "data"))
    subprocess.check_call(["mlem", "apply", "model", "data"], cwd=tmpdir)

    meta = load_meta(os.path.join(tmpdir, "model"), force_type=MlemModel)
    # numpy must be the single, pinned requirement
    assert len(meta.requirements.__root__) == 1
    assert meta.requirements.to_pip() == [f"numpy=={np.__version__}"]
@exec_param
@pytest.mark.xfail
def test_pipe_iter(tmpdir, script_code, executable):
    """Requirements must be collected when the script is fed line by line
    through the interpreter's stdin."""
    with subprocess.Popen(executable, stdin=subprocess.PIPE) as proc:
        for line in script_code.splitlines(keepends=True):
            proc.stdin.write(line.encode("utf8"))
        proc.communicate(b"exit()")
        assert proc.returncode == 0
    save("a", os.path.join(tmpdir, "data"))
    subprocess.check_call(["mlem", "apply", "model", "data"], cwd=tmpdir)

    meta = load_meta(os.path.join(tmpdir, "model"), force_type=MlemModel)
    assert len(meta.requirements.__root__) == 1
    assert meta.requirements.to_pip() == [f"numpy=={np.__version__}"]


@exec_param
@pytest.mark.xfail
def test_script(tmpdir, script_code, executable):
    """Requirements must be collected when the code runs as a script file."""
    with open(tmpdir / "script.py", "w", encoding="utf8") as f:
        f.write(script_code)
    rc = subprocess.check_call([executable, "script.py"], cwd=tmpdir)
    assert rc == 0
    save("a", os.path.join(tmpdir, "data"))
    subprocess.check_call(["mlem", "apply", "model", "data"], cwd=tmpdir)

    meta = load_meta(os.path.join(tmpdir, "model"), force_type=MlemModel)
    assert len(meta.requirements.__root__) == 1
    assert meta.requirements.to_pip() == [f"numpy=={np.__version__}"]
def translate(text: str):
    """
    Translate dog barks to emoji, as you hear them
    """
    # imports are deliberately local: the requirements analyzer must pick
    # them up from inside the function body
    import emoji
    import numpy as np

    emojis = list(emoji.EMOJI_DATA.keys())
    return " ".join(
        np.random.choice(emojis) for _ in text.split()  # type: ignore
    )


def main():
    import sys

    import mlem

    mlem.api.save(translate, sys.argv[1], sample_data="Woof woof!")


if __name__ == "__main__":
    main()
from importlib import import_module

# loaded via import_module to exercise requirements collection for
# dynamically imported modules
emoji = import_module("emoji")
np = import_module("numpy")


def translate(text: str):
    """
    Translate dog barks to emoji, as you hear them
    """
    emojis = list(emoji.EMOJI_DATA.keys())
    return " ".join(
        np.random.choice(emojis) for _ in text.split()  # type: ignore
    )


def main():
    import sys

    import mlem

    mlem.api.save(translate, sys.argv[1], sample_data="Woof woof!")


if __name__ == "__main__":
    main()
@long
@need_aws_auth
def test_fsspec_backend_s3_upload(tmpdir, s3_tmp_path, s3_storage):
    """Uploading a local file to S3 yields an FSSpecArtifact that round-trips."""
    target = os.path.basename(s3_tmp_path("upload"))
    source = resource_path(__file__, "file.txt")

    artifact = s3_storage.upload(source, target)
    assert isinstance(artifact, FSSpecArtifact)
    assert artifact.hash != ""
    assert artifact.size > 0

    # materialize back to a local path and compare contents byte-for-byte
    downloaded = str(tmpdir / "file.txt")
    artifact.materialize(downloaded)
    with open(downloaded, "r", encoding="utf8") as got:
        with open(source, "r", encoding="utf8") as want:
            assert got.read() == want.read()
def test_local_storage_relative(tmpdir):
    """Artifacts created via a relative LocalStorage land inside the subdir."""
    storage = LocalStorage(uri=str(tmpdir))
    relative = storage.relative(LocalFileSystem(), "subdir")

    # artifact created through open()
    with relative.open("file2") as (stream, opened):
        stream.write(b"1")
    assert isinstance(opened, LocalArtifact)
    assert opened.hash != ""
    assert opened.size > 0
    assert opened.uri == "file2"
    assert os.path.isfile(os.path.join(tmpdir, "subdir", opened.uri))

    # artifact created through upload()
    uploaded = relative.upload(__file__, "file")
    assert isinstance(uploaded, LocalArtifact)
    assert uploaded.uri == "file"
    assert uploaded.hash != ""
    assert uploaded.size > 0
    assert os.path.isfile(os.path.join(tmpdir, "subdir", uploaded.uri))
@pytest.mark.parametrize("format", list(PANDAS_FORMATS.keys()))
def test_pandas_read_write(format):
    """A DataFrame written in every supported pandas format reads back equal."""
    frame = pd.DataFrame([{"a": 1, "b": 2}])
    original_type = DataType.create(frame)
    storage = FSSpecStorage(uri="memory://")

    reader, artifacts = PandasWriter(format=format).write(
        original_type, storage, "/data"
    )
    assert isinstance(reader, PandasReader)
    assert len(artifacts) == 1
    assert storage.get_fs().exists("/data")

    restored_type = reader.read(artifacts)
    assert isinstance(restored_type, DataType)
    assert restored_type == original_type
    assert isinstance(restored_type.data, pd.DataFrame)
    assert restored_type.data.equals(frame)
def test_infer_signature_unspecified(model):
    """Without auto_infer, argument and return types remain unspecified.

    Renamed from ``test_infer_signatire_unspecified`` (typo) for consistency
    with ``test_infer_signature_var`` below; pytest still collects it.
    """
    signature = Signature.from_method(model.predict)
    assert signature.name == "predict"
    assert signature.returns == UnspecifiedDataType()
    # args are (self, X) since predict is an unbound method
    assert len(signature.args) == 2
    arg = signature.args[0]
    assert arg.name == "X"
    assert arg.type_ == UnspecifiedDataType()


def test_infer_signature(model, train):
    """With auto_infer and sample data, types are inferred from the sample.

    Renamed from ``test_infer_signatire`` (typo).
    """
    signature = Signature.from_method(model.predict, auto_infer=True, X=train)
    assert signature.name == "predict"
    assert signature.returns == NumpyNdarrayType(
        shape=(None,), dtype=model.predict(train).dtype.name
    )
    assert len(signature.args) == 2
    arg = signature.args[0]
    assert arg.name == "X"
    # the inferred argument type depends on the sample's container type
    if isinstance(train, np.ndarray):
        assert arg.type_ == NumpyNdarrayType(shape=(None, 4), dtype="float64")
    elif isinstance(train, pd.DataFrame):
        assert arg.type_ == DataFrameType(
            columns=["0", "1", "2", "3"],
            dtypes=["float64", "float64", "float64", "float64"],
            index_cols=[],
        )
def test_deserialization_and_cache():
    """Parsing keeps the raw dict; the typed payload is built lazily and cached."""
    raw = {"field": {"value": 1}}

    model = parse_obj_as(Model, raw)

    assert isinstance(model, Model)
    assert isinstance(model.field_raw, dict)
    parsed = model.field
    assert isinstance(parsed, Payload)
    # the validator increments value on parse
    assert parsed.value == 2
    # second access must return the cached, already-validated value
    assert model.field.value == 2
def test_setting_optional_field():
    """An optional lazy field defaults to None and accepts object or raw dict."""
    obj = ModelWithOptional()
    assert obj.field is None

    # assigning a parsed object triggers the validator (+1)
    obj.field = Payload(value=0)
    assert obj.field.value == 1

    # assigning a raw dict re-parses on next access (+1 again)
    obj.field_raw = {"value": 5}
    assert obj.field.value == 6
@pytest.mark.parametrize("port", [None, 80])
def test_mlem_client_base_url(port):
    """base_url includes the port suffix only when a port is given.

    Bug fix: the original assertion lacked parentheses, so it parsed as
    ``(client.base_url == f"http://:{port}") if port else "http://"`` — for
    ``port=None`` it asserted a non-empty (truthy) string and passed
    vacuously, checking nothing.
    """
    client = HTTPClient(host="", port=port)
    expected = f"http://:{port}" if port else "http://"
    assert client.base_url == expected
class Container(DataType):
    """Minimal concrete DataType used as a stand-in payload type in these tests.

    Only ``field`` matters (for equality comparisons in the descriptor
    assertions below); serialization is a stub and writing is unsupported.
    """

    # registered type name used in serialized descriptors
    type: ClassVar[str] = "test_container"
    field: int

    def serialize(
        self, instance: Any  # pylint: disable=unused-argument
    ) -> dict:
        """Stub: always serializes to an empty dict."""
        return {}

    def deserialize(self, obj: dict) -> Any:
        """Stub: ignores the payload and returns None implicitly."""
        pass

    def get_requirements(self) -> Requirements:
        """This stub type brings in no package requirements."""
        return Requirements.new([])

    def get_writer(
        self, project: str = None, filename: str = None, **kwargs
    ) -> DataWriter:
        """Writing is intentionally unsupported for this stub."""
        raise NotImplementedError
def test_interface_descriptor__to_dict(interface: Interface):
    """The versioned descriptor serializes version, meta and method schemas."""
    descriptor = interface.get_versioned_descriptor()

    container_schema = {"field": 5, "type": "test_container"}
    expected_arg = {
        "default": None,
        "name": "arg1",
        "required": True,
        "data_type": container_schema,
        "serializer": None,
    }
    expected = {
        "version": mlem.__version__,
        "meta": None,
        "methods": {
            "method1": {
                "args": [expected_arg],
                "name": "method1",
                "returns": {
                    "data_type": container_schema,
                    "serializer": None,
                },
            }
        },
    }
    assert descriptor.dict() == expected
def test_with_serde(pd_model: MlemModel):
    """A raw JSON-style payload deserializes and executes through the interface."""
    interface = ModelInterface.from_model(pd_model)

    payload = {"values": [{"a": 1, "b": 1}]}

    # use the serializer of predict's first argument to turn the raw dict
    # back into the expected data object
    arg_type = pd_model.model_type.methods["predict"].args[0].type_
    deserialized = arg_type.get_serializer().deserialize(payload)

    interface.execute("predict", {"X": deserialized})
def test_dvc_extras():
    """Our ``dvc-*`` extras in setup.py must mirror the extras DVC provides.

    Versions prior to 2.15 had a typo in DVC's own extras, so only newer
    versions are checked. importlib_metadata inspects the locally installed
    package, so this may pass locally but fail in CI.
    """
    if version.parse(dvc.__version__) > version.parse("2.15"):
        # meta-extras DVC exposes that we intentionally do not re-export
        skipped = {"all", "dev", "lint", "terraform", "tests", "testing"}
        correct_extras = {
            f"dvc-{extra}": [f"dvc[{extra}]~=2.0"]
            for extra in importlib_metadata.metadata("dvc").get_all(
                "Provides-Extra"
            )
            if extra not in skipped
        }
        # idiom fix: str.startswith instead of slice-compare
        specified_extras = {
            name: reqs
            for name, reqs in extras.items()
            if name.startswith("dvc-")
        }
        assert correct_extras == specified_extras
def test_list_implementations_meta():
    """Implementations are listable by class or by registered string name."""
    # the abs-root lookup works with either the string alias or the class
    for root in ("meta", MlemObject):
        assert "model" in list_implementations(root)

    # narrowing to a subtype works with every combination of alias/class
    for root in ("meta", MlemObject):
        for subtype in (MlemEnv, "env"):
            assert "docker" in list_implementations(root, subtype)
def test_fslock_concurrent(tmpdir):
    """Ten threads appending under FSLock lose no writes and leave no locks."""
    first, last = 0, 10
    workers = [
        Thread(target=_work, args=(tmpdir, n)) for n in range(first, last)
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    with open(os.path.join(tmpdir, NAME), encoding="utf8") as f:
        contents = f.read()

    # every worker's line is present, in order of their staggered start times
    assert contents.splitlines() == [str(i) for i in range(first, last)]
    # all lock files were cleaned up; only the data file remains
    assert os.listdir(tmpdir) == [NAME]
def test_find_root_error():
    """Outside any MLEM project the root lookup raises or returns None."""
    outside = os.path.dirname(__file__)
    # lenient mode: signal absence with None
    assert find_project_root(outside, raise_on_missing=False) is None
    # strict mode: raise instead
    with pytest.raises(MlemProjectNotFound):
        find_project_root(outside, raise_on_missing=True)
"2023-02-13T14:11:29.270233Z", 32 | "iopub.status.idle": "2023-02-13T14:11:29.969468Z", 33 | "shell.execute_reply": "2023-02-13T14:11:29.968705Z" 34 | }, 35 | "pycharm": { 36 | "is_executing": true, 37 | "name": "#%%\n" 38 | } 39 | }, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | "numpy==1.22.4\n" 46 | ] 47 | } 48 | ], 49 | "source": [ 50 | "from mlem.utils.module import get_object_requirements\n", 51 | "\n", 52 | "res = get_object_requirements(func)\n", 53 | "\n", 54 | "print(\" \".join(res.to_pip()))" 55 | ] 56 | } 57 | ], 58 | "metadata": { 59 | "kernelspec": { 60 | "display_name": "Python 3", 61 | "language": "python", 62 | "name": "python3" 63 | }, 64 | "language_info": { 65 | "codemirror_mode": { 66 | "name": "ipython", 67 | "version": 3 68 | }, 69 | "file_extension": ".py", 70 | "mimetype": "text/x-python", 71 | "name": "python", 72 | "nbconvert_exporter": "python", 73 | "pygments_lexer": "ipython3", 74 | "version": "3.9.13" 75 | } 76 | }, 77 | "nbformat": 4, 78 | "nbformat_minor": 0 79 | } 80 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38 3 | 4 | [testenv] 5 | commands = pytest 6 | extras = tests 7 | passenv = 8 | GITHUB_USERNAME 9 | GITHUB_TOKEN 10 | --------------------------------------------------------------------------------