├── .github └── workflows │ ├── release.yml │ └── test.yml ├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── examples └── huggingface_playground.py ├── marimo_labs ├── __init__.py └── huggingface │ ├── README.md │ ├── __init__.py │ ├── _client_utils.py │ ├── _load.py │ ├── _load_utils.py │ ├── _outputs.py │ └── _processing_utils.py ├── pyproject.toml ├── scripts └── pyfix.sh └── tests └── huggingface ├── conftest.py ├── media_data.py └── test_processing_utils.py /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish release 2 | 3 | # release a new version of marimo_labs on tag push 4 | on: 5 | push: 6 | tags: 7 | - "[0-9]+.[0-9]+.[0-9]+" 8 | 9 | jobs: 10 | publish_release: 11 | name: 📤 Publish release 12 | runs-on: ubuntu-latest 13 | environment: release 14 | permissions: 15 | # IMPORTANT: this permission is mandatory for trusted publishing 16 | id-token: write 17 | defaults: 18 | run: 19 | shell: bash 20 | 21 | steps: 22 | - name: ⬇️ Checkout repo 23 | uses: actions/checkout@v4 24 | 25 | - name: 🐍 Setup Python 3.10 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.10" 29 | cache: "pip" 30 | 31 | - name: 📦 Build marimo labs 32 | run: | 33 | python -m pip install --upgrade pip 34 | pip install . 
35 | pip install build 36 | python -m build 37 | 38 | - name: Publish package distributions to PyPI 39 | uses: pypa/gh-action-pypi-publish@release/v1 40 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: {} 7 | 8 | jobs: 9 | test_python: 10 | name: Tests on ${{ matrix.os }}, Python ${{ matrix.python-version }} 11 | runs-on: ${{ matrix.os }} 12 | timeout-minutes: 10 13 | defaults: 14 | run: 15 | shell: bash 16 | 17 | strategy: 18 | matrix: 19 | os: [ubuntu-latest] 20 | python-version: ["3.8"] 21 | include: 22 | - os: ubuntu-latest 23 | python-version: "3.9" 24 | - os: ubuntu-latest 25 | python-version: "3.10" 26 | - os: ubuntu-latest 27 | python-version: "3.11" 28 | - os: ubuntu-latest 29 | python-version: "3.12" 30 | steps: 31 | - uses: actions/checkout@v4 32 | - name: Set up Python ${{ matrix.python-version }} 33 | uses: actions/setup-python@v5 34 | with: 35 | python-version: ${{ matrix.python-version }} 36 | cache: "pip" 37 | # Lint, typecheck, test 38 | - name: Install marimo-labs with dev dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install .[dev] 42 | - name: Lint 43 | run: | 44 | ruff check marimo_labs/ 45 | - name: Typecheck 46 | if: ${{ matrix.python-version == '3.9' || matrix.python-version == '3.10' }} 47 | run: | 48 | mypy --config-file pyproject.toml marimo_labs/ 49 | - name: Test 50 | run: | 51 | pytest -v tests/ 52 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Data files, images 2 | *.csv 3 | *.pdf 4 | *.txt 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 
| .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Sphinx documentation 63 | docs/_build/ 64 | 65 | # Jupyter Notebook 66 | .ipynb_checkpoints 67 | 68 | # IPython 69 | profile_default/ 70 | ipython_config.py 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # pipenv 76 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 77 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 78 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 79 | # install all needed dependencies. 
80 | #Pipfile.lock 81 | 82 | # Environments 83 | .env 84 | .venv 85 | env/ 86 | venv/ 87 | ENV/ 88 | env.bak/ 89 | venv.bak/ 90 | 91 | # editors 92 | *.swp 93 | 94 | # misc 95 | .DS_Store 96 | .env.local 97 | .env.development.local 98 | .env.test.local 99 | .env.production.local 100 | 101 | npm-debug.log* 102 | yarn-debug.log* 103 | yarn-error.log* 104 | pnpm-debug.log* 105 | 106 | marimo/_static/ 107 | marimo/_lsp/ 108 | 109 | .vscode 110 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: help 2 | # show this help 3 | help: 4 | @# prints help for rules preceded with comment 5 | @# https://stackoverflow.com/a/35730928 6 | @awk '/^#/{c=substr($$0,3);next}c&&/^[[:alpha:]][[:alnum:]_-]+:/{print substr($$1,1,index($$1,":")),c}1{c=0}' Makefile | column -s: -t 7 | 8 | .PHONY: py 9 | py: 10 | pip install -e . 
11 | 12 | .PHONY: check 13 | check: 14 | ./scripts/pyfix.sh 15 | 16 | .PHONY: check-test 17 | # run all checks and tests 18 | check-test: check test 19 | 20 | .PHONY: test 21 | # run tests 22 | test: 23 | pytest 24 | 25 | .PHONY: wheel 26 | # build wheel 27 | wheel: 28 | python -m build 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # marimo labs 2 | 3 | This repository contains experimental functionality for the [marimo notebook](https://github.com/marimo-team/marimo): 4 | 5 | 🤗 `marimo_labs.huggingface`: interactively experiment with any model on HuggingFace 6 | 7 | **Installation.** 8 | 9 | ```shell 10 | pip install marimo-labs 11 | ``` 12 | 13 | **Examples**: 14 | 15 | Use `marimo_labs.huggingface.load` to load and query any model on HuggingFace. See examples for usage: 16 | 17 | - [Playground to experiment with text, image, and audio models](https://marimo.app/l/tmk0k2) 18 | - [Run stable diffusion](https://marimo.app/l/ugvgap) 19 | - [Load and run any HuggingFace model](https://marimo.app/l/ynxf6q) 20 | 21 | HuggingFace provides a rate-limited inference API; increase your rate limit by adding a (free) [HuggingFace API token](https://huggingface.co/docs/hub/en/security-tokens).
22 | 23 | image 24 | -------------------------------------------------------------------------------- /examples/huggingface_playground.py: -------------------------------------------------------------------------------- 1 | import marimo 2 | 3 | __generated_with = "0.3.9" 4 | app = marimo.App() 5 | 6 | 7 | @app.cell 8 | def __(): 9 | import marimo_labs as molabs 10 | return molabs, 11 | 12 | 13 | @app.cell(hide_code=True) 14 | def __(mo, model_type, model_type_to_model): 15 | models = mo.ui.dropdown( 16 | model_type_to_model[model_type.value], label="Choose a model" 17 | ) 18 | 19 | mo.hstack([model_type, models if model_type.value else ""], justify="start") 20 | return models, 21 | 22 | 23 | @app.cell 24 | def load_model(mo, models, molabs): 25 | mo.stop(models.value is None) 26 | model = molabs.huggingface.load(models.value) 27 | return model, 28 | 29 | 30 | @app.cell 31 | def __(mo, model): 32 | mo.md( 33 | f""" 34 | Example inputs: 35 | 36 | {mo.as_html(model.examples)} 37 | """ 38 | ) if model.examples is not None else None 39 | return 40 | 41 | 42 | @app.cell 43 | def __(mo, model): 44 | inputs = model.inputs 45 | 46 | 47 | mo.vstack([mo.md("_Submit inputs to run inference_ 👇"), inputs]) 48 | return inputs, 49 | 50 | 51 | @app.cell 52 | def __(inputs, mo, model): 53 | mo.stop(inputs.value is None) 54 | 55 | output = model.inference_function(inputs.value) 56 | output 57 | return output, 58 | 59 | 60 | @app.cell 61 | def __(): 62 | import marimo as mo 63 | return mo, 64 | 65 | 66 | @app.cell(hide_code=True) 67 | def __(mo): 68 | audio_models = { 69 | "audio classification": "models/ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition", 70 | "audio to audio": "models/facebook/xm_transformer_sm_all-en", 71 | "speech recognition": "models/facebook/wav2vec2-base-960h", 72 | "text to speech": "models/julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train", 73 | } 74 | 75 | image_models = { 76 | "image classification": 
"models/google/vit-base-patch16-224", 77 | "text to image": "models/runwayml/stable-diffusion-v1-5", 78 | "image to text": "models/Salesforce/blip-image-captioning-base", 79 | "object detection": "models/microsoft/table-transformer-detection", 80 | } 81 | 82 | text_models = { 83 | "feature extraction": "models/julien-c/distilbert-feature-extraction", 84 | "fill mask": "models/distilbert/distilbert-base-uncased", 85 | "zero-shot classification": "models/facebook/bart-large-mnli", 86 | "visual question answering": "models/dandelin/vilt-b32-finetuned-vqa", 87 | "sentence similarity": "models/sentence-transformers/all-MiniLM-L6-v2", 88 | "question answering": "models/deepset/xlm-roberta-base-squad2", 89 | "summarization": "models/facebook/bart-large-cnn", 90 | "text-classification": "models/distilbert/distilbert-base-uncased-finetuned-sst-2-english", 91 | "text generation": "models/openai-community/gpt2", 92 | "text2text generation": "models/valhalla/t5-small-qa-qg-hl", 93 | "translation": "models/Helsinki-NLP/opus-mt-en-ar", 94 | "token classification": "models/huggingface-course/bert-finetuned-ner", 95 | "document question answering": "models/impira/layoutlm-document-qa", 96 | } 97 | 98 | model_type_to_model = { 99 | "text": text_models, 100 | "image": image_models, 101 | "audio": audio_models, 102 | None: [], 103 | } 104 | 105 | model_type = mo.ui.dropdown( 106 | ["text", "image", "audio"], label="Choose a model type" 107 | ) 108 | return ( 109 | audio_models, 110 | image_models, 111 | model_type, 112 | model_type_to_model, 113 | text_models, 114 | ) 115 | 116 | 117 | if __name__ == "__main__": 118 | app.run() 119 | -------------------------------------------------------------------------------- /marimo_labs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.metadata 2 | import warnings 3 | 4 | import marimo_labs.huggingface as huggingface 5 | 6 | try: 7 | __version__ = 
importlib.metadata.version(__name__) 8 | except importlib.metadata.PackageNotFoundError as e: 9 | warnings.warn( 10 | f"Could not determine version of {__name__}\n{e!r}", stacklevel=2 11 | ) 12 | __version__ = "unknown" 13 | 14 | __all__ = ["huggingface"] 15 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/README.md: -------------------------------------------------------------------------------- 1 | # HuggingFace model integration 2 | 3 | This module lets you instantiate input controls to run huggingface 4 | inference APIs. You can experiment with image generation, text similarity, 5 | text to speech, and more. 6 | 7 | Code adapted from gradio. 8 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["load"] 2 | 3 | from marimo_labs.huggingface._load import load 4 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/_client_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | import mimetypes 5 | import secrets 6 | import tempfile 7 | from pathlib import Path 8 | 9 | import requests # type: ignore 10 | 11 | 12 | def get_extension(encoding: str) -> str | None: 13 | encoding = encoding.replace("audio/wav", "audio/x-wav") 14 | typ = mimetypes.guess_type(encoding)[0] 15 | if typ == "audio/flac": # flac is not supported by mimetypes 16 | return "flac" 17 | elif typ is None: 18 | return None 19 | extension = mimetypes.guess_extension(typ) 20 | if extension is not None and extension.startswith("."): 21 | extension = extension[1:] 22 | return extension 23 | 24 | 25 | def strip_invalid_filename_characters( 26 | filename: str, max_bytes: int = 200 27 | ) -> str: 28 | """Strips invalid characters from a 
filename 29 | 30 | Ensures that the file_length is less than `max_bytes` bytes. 31 | """ 32 | filename = "".join( 33 | [char for char in filename if char.isalnum() or char in "._- "] 34 | ) 35 | filename_len = len(filename.encode()) 36 | if filename_len > max_bytes: 37 | while filename_len > max_bytes: 38 | if len(filename) == 0: 39 | break 40 | filename = filename[:-1] 41 | filename_len = len(filename.encode()) 42 | return filename 43 | 44 | 45 | def decode_base64_to_binary(encoding: str) -> tuple[bytes, str | None]: 46 | extension = get_extension(encoding) 47 | data = encoding.rsplit(",", 1)[-1] 48 | return base64.b64decode(data), extension 49 | 50 | 51 | def decode_base64_to_file( 52 | encoding: str, 53 | file_path: str | None = None, 54 | direct: str | Path | None = None, 55 | prefix: str | None = None, 56 | ): 57 | directory = Path(direct or tempfile.gettempdir()) / secrets.token_hex(20) 58 | directory.mkdir(exist_ok=True, parents=True) 59 | data, extension = decode_base64_to_binary(encoding) 60 | if file_path is not None and prefix is None: 61 | filename = Path(file_path).name 62 | prefix = filename 63 | if "." in filename: 64 | prefix = filename[0 : filename.index(".")] 65 | extension = filename[filename.index(".") + 1 :] 66 | 67 | if prefix is not None: 68 | prefix = strip_invalid_filename_characters(prefix) 69 | 70 | if extension is None: 71 | file_obj = tempfile.NamedTemporaryFile( 72 | delete=False, prefix=prefix, dir=directory 73 | ) 74 | else: 75 | file_obj = tempfile.NamedTemporaryFile( 76 | delete=False, 77 | prefix=prefix, 78 | suffix="." 
+ extension, 79 | dir=directory, 80 | ) 81 | file_obj.write(data) 82 | file_obj.flush() 83 | return file_obj 84 | 85 | 86 | def get_mimetype(filename: str) -> str | None: 87 | if filename.endswith(".vtt"): 88 | return "text/vtt" 89 | mimetype = mimetypes.guess_type(filename)[0] 90 | if mimetype is not None: 91 | mimetype = mimetype.replace("x-wav", "wav").replace("x-flac", "flac") 92 | return mimetype 93 | 94 | 95 | def is_http_url_like(possible_url) -> bool: 96 | """ 97 | Check if the given value is a string that looks like an HTTP(S) URL. 98 | """ 99 | if not isinstance(possible_url, str): 100 | return False 101 | return possible_url.startswith(("http://", "https://")) 102 | 103 | 104 | def encode_file_to_base64(f: str | Path) -> str: 105 | with open(f, "rb") as file: 106 | encoded_string = base64.b64encode(file.read()) 107 | base64_str = str(encoded_string, "utf-8") 108 | mimetype = get_mimetype(str(f)) 109 | return ( 110 | "data:" 111 | + (mimetype if mimetype is not None else "") 112 | + ";base64," 113 | + base64_str 114 | ) 115 | 116 | 117 | def encode_url_to_base64(url: str) -> str: 118 | resp = requests.get(url) 119 | resp.raise_for_status() 120 | encoded_string = base64.b64encode(resp.content) 121 | base64_str = str(encoded_string, "utf-8") 122 | mimetype = get_mimetype(url) 123 | return ( 124 | "data:" 125 | + (mimetype if mimetype is not None else "") 126 | + ";base64," 127 | + base64_str 128 | ) 129 | 130 | 131 | def encode_url_or_file_to_base64(path: str | Path) -> str: 132 | path = str(path) 133 | if is_http_url_like(path): 134 | return encode_url_to_base64(path) 135 | return encode_file_to_base64(path) 136 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/_load.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import tempfile 5 | from dataclasses import dataclass 6 | from pathlib import Path 7 | 
from typing import Any, Callable 8 | 9 | import huggingface_hub # type: ignore 10 | import marimo as mo 11 | import requests # type: ignore 12 | 13 | from marimo_labs.huggingface import _load_utils, _outputs 14 | from marimo_labs.huggingface._processing_utils import ( 15 | encode_to_base64, 16 | save_base64_to_cache, 17 | ) 18 | 19 | 20 | @dataclass 21 | class HFModel: 22 | title: str 23 | inputs: mo.ui.form 24 | examples: list[Any] | None 25 | inference_function: Callable[..., _outputs.Output] 26 | 27 | 28 | def load( 29 | name: str, 30 | hf_token: str | None = None, 31 | **kwargs, 32 | ) -> HFModel: 33 | """Constructs a demo from a Hugging Face repo. 34 | 35 | Can accept model repos (if src is "models"). The input and output 36 | components are automatically loaded from the repo. Note that if a Space is 37 | loaded, certain high-level attributes of the Blocks (e.g. custom `css`, 38 | `js`, and `head` attributes) will not be loaded. 39 | 40 | Parameters: 41 | name: the name of the model (e.g. "gpt2" or "facebook/bart-base") or 42 | space (e.g. "flax-community/spanish-gpt2"), can include the `src` 43 | as prefix (e.g. "models/facebook/bart-base") 44 | 45 | src: the source of the model: `models` or `spaces` (or leave empty if 46 | source is provided as a prefix in `name`) 47 | 48 | hf_token: optional access token for loading private Hugging Face Hub 49 | models. Find your token here: 50 | https://huggingface.co/settings/tokens. 
51 | 52 | Returns: 53 | an HFModel object 54 | 55 | Example: 56 | 57 | ```python 58 | import marimo_labs 59 | 60 | model = marimo_labs.load("models/runwayml/stable-diffusion-v1-5") 61 | inputs = model.inputs 62 | inputs 63 | ``` 64 | 65 | ```python 66 | model.inference_function(inputs.value) 67 | ``` 68 | """ 69 | return load_model_from_repo(name=name, hf_token=hf_token, **kwargs) 70 | 71 | 72 | def load_model_from_repo( 73 | name: str, 74 | hf_token: str | None = None, 75 | **kwargs, 76 | ) -> HFModel: 77 | """Creates and returns an HFModel""" 78 | # Separate the repo type (e.g. "model") from repo name (e.g. 79 | # "google/vit-base-patch16-224") 80 | tokens = name.split("/") 81 | if len(tokens) <= 1: 82 | raise ValueError( 83 | "Either `src` parameter must be provided, or " 84 | "`name` must be formatted as {src}/{repo name}" 85 | ) 86 | src = tokens[0] 87 | name = "/".join(tokens[1:]) 88 | 89 | factory_methods: dict[str, Callable] = { 90 | # for each repo type, we have a method that returns the Interface given 91 | # the model name & optionally an hf_token 92 | "huggingface": from_model, 93 | "models": from_model, 94 | } 95 | if src.lower() not in factory_methods: 96 | raise ValueError( 97 | f"parameter: src must be one of {factory_methods.keys()}" 98 | ) 99 | 100 | return factory_methods[src](name, hf_token, **kwargs) 101 | 102 | 103 | def from_model(model_name: str, hf_token: str | None, **kwargs) -> HFModel: 104 | del kwargs 105 | 106 | model_url = f"https://huggingface.co/{model_name}" 107 | api_url = f"https://api-inference.huggingface.co/models/{model_name}" 108 | 109 | print(f"Fetching model from: {model_url}") 110 | 111 | headers = ( 112 | {"Authorization": f"Bearer {hf_token}"} if hf_token is not None else {} 113 | ) 114 | response = requests.request("GET", api_url, headers=headers) 115 | if response.status_code != 200: 116 | raise ValueError( 117 | f"Could not find model: {model_name}. 
" 118 | "If it is a private or gated model, please provide your " 119 | "Hugging Face access token " 120 | "(https://huggingface.co/settings/tokens) as the argument for the " 121 | "`hf_token` parameter." 122 | ) 123 | p = response.json().get("pipeline_tag") 124 | 125 | headers["X-Wait-For-Model"] = "true" 126 | client = huggingface_hub.InferenceClient( 127 | model=model_name, headers=headers, token=hf_token 128 | ) 129 | 130 | # For tasks that are not yet supported by the InferenceClient 131 | MARIMOLABS_CACHE = os.environ.get( 132 | "MARIMOLABS_TEMP_DIR" 133 | ) or str( # noqa: N806 134 | Path(tempfile.gettempdir()) / "marimo_labs" 135 | ) 136 | 137 | def custom_post_binary(data): 138 | response = requests.request( 139 | "POST", api_url, headers=headers, content=data 140 | ) 141 | return save_base64_to_cache( 142 | encode_to_base64(response), cache_dir=MARIMOLABS_CACHE 143 | ) 144 | 145 | inputs: mo.Html # actually a UIElement, but not in public API ... 146 | preprocess = None 147 | postprocess: Callable[..., Any] | None = None 148 | examples: Any = None 149 | fn: Callable[..., _outputs.Output] 150 | 151 | # example model: ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition 152 | if p == "audio-classification": 153 | inputs = mo.ui.file( 154 | filetypes=["audio/*"], label="Upload an audio file", kind="area" 155 | ) 156 | postprocess = _load_utils.postprocess_label 157 | examples = [ 158 | "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" 159 | ] 160 | fn = _outputs.construct_output_function( 161 | inference_function=_load_utils.file_contents_wrapper( 162 | client.audio_classification 163 | ) 164 | ) 165 | # example model: facebook/xm_transformer_sm_all-en 166 | elif p == "audio-to-audio": 167 | inputs = mo.ui.file( 168 | filetypes=["audio/*"], label="Upload an audio file", kind="area" 169 | ) 170 | # output_function = components.Audio(label="Output") 171 | examples = [ 172 | 
"https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" 173 | ] 174 | fn = _outputs.construct_output_function( 175 | lambda v: custom_post_binary(v.contents), 176 | _outputs.audio_output_from_path, 177 | ) 178 | # example model: facebook/wav2vec2-base-960h 179 | elif p == "automatic-speech-recognition": 180 | inputs = mo.ui.file( 181 | filetypes=["audio/*"], label="Upload an audio file", kind="area" 182 | ) 183 | # outputs = components.Textbox(label="Output") 184 | examples = [ 185 | "https://gradio-builds.s3.amazonaws.com/demo-files/audio_sample.wav" 186 | ] 187 | fn = _outputs.construct_output_function( 188 | _load_utils.file_contents_wrapper( 189 | client.automatic_speech_recognition 190 | ) 191 | ) 192 | # example model: microsoft/DialoGPT-medium 193 | elif p == "conversational": 194 | raise NotImplementedError 195 | # inputs = [ 196 | # components.Textbox(render=False), 197 | # components.State(render=False), 198 | # ] 199 | # outputs = [ 200 | # components.Chatbot(render=False), 201 | # components.State(render=False), 202 | # ] 203 | # examples = [["Hello World"]] 204 | # preprocess = external_utils.chatbot_preprocess 205 | # postprocess = external_utils.chatbot_postprocess 206 | # fn = client.conversational 207 | # example model: julien-c/distilbert-feature-extraction 208 | elif p == "feature-extraction": 209 | inputs = mo.ui.text_area(label="Text to featurize") 210 | # outputs = components.Dataframe(label="Output") 211 | postprocess = lambda v: v[0] if len(v) == 1 else v # type: ignore # noqa: E731 212 | fn = _outputs.construct_output_function(client.feature_extraction) 213 | # example model: distilbert/distilbert-base-uncased 214 | elif p == "fill-mask": 215 | inputs = mo.ui.text_area(label="Masked text") 216 | # outputs = components.Label(label="Classification") 217 | examples = [ 218 | "Hugging Face is the AI community, working together, to " 219 | "[MASK] the future." 
220 | ] 221 | postprocess = _load_utils.postprocess_mask_tokens 222 | fn = _outputs.construct_output_function(client.fill_mask) 223 | # Example: google/vit-base-patch16-224 224 | elif p == "image-classification": 225 | inputs = mo.ui.file( 226 | filetypes="image/*", label="Input Image", kind="area" 227 | ) 228 | # outputs = components.Label(label="Classification") 229 | postprocess = _load_utils.postprocess_label 230 | examples = [ 231 | "https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg" 232 | ] 233 | fn = _outputs.construct_output_function( 234 | _load_utils.file_contents_wrapper(client.image_classification) 235 | ) 236 | # Example: deepset/xlm-roberta-base-squad2 237 | elif p == "question-answering": 238 | inputs = mo.ui.array( 239 | [ 240 | mo.ui.text_area(label="Question"), 241 | mo.ui.text_area(rows=7, label="Context"), 242 | ] 243 | ) 244 | # outputs = [ 245 | # components.Textbox(label="Answer"), 246 | # components.Label(label="Score"), 247 | # ] 248 | examples = [ 249 | [ 250 | "What entity was responsible for the Apollo program?", 251 | "The Apollo program, also known as Project Apollo, was the " 252 | "third United States human spaceflight" 253 | " program carried out by the National Aeronautics and Space " 254 | "Administration (NASA), which accomplished" 255 | " landing the first humans on the Moon from 1969 to 1972.", 256 | ] 257 | ] 258 | postprocess = _load_utils.postprocess_question_answering 259 | fn = _outputs.construct_output_function(client.question_answering) 260 | # Example: facebook/bart-large-cnn 261 | elif p == "summarization": 262 | inputs = mo.ui.text_area(label="Text to summarize") 263 | # outputs = components.Textbox(label="Summary") 264 | examples = [ 265 | [ 266 | "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. 
During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct." # noqa: E501 267 | ] 268 | ] 269 | fn = _outputs.construct_output_function(client.summarization) 270 | # Example: distilbert-base-uncased-finetuned-sst-2-english 271 | elif p == "text-classification": 272 | inputs = mo.ui.text_area(label="Text to classify") 273 | # outputs = components.Label(label="Classification") 274 | examples = ["I feel great"] 275 | postprocess = _load_utils.postprocess_label 276 | fn = _outputs.construct_output_function(client.text_classification) 277 | # Example: gpt2 278 | elif p == "text-generation": 279 | inputs = mo.ui.text_area(label="Prompt") 280 | # outputs = inputs 281 | examples = ["Once upon a time"] 282 | fn = _outputs.construct_output_function( 283 | _load_utils.text_generation_wrapper(client) 284 | ) 285 | # Example: valhalla/t5-small-qa-qg-hl 286 | elif p == "text2text-generation": 287 | inputs = mo.ui.text_area(label="Input text") 288 | # outputs = components.Textbox(label="Generated Text") 289 | examples = ["Translate English to Arabic: How are you?"] 290 | fn = _outputs.construct_output_function(client.text_generation) 291 | # Example: Helsinki-NLP/opus-mt-en-ar 292 | elif p == "translation": 293 | inputs = mo.ui.text_area(label="Text to translate") 294 | # outputs = components.Textbox(label="Translation") 295 | examples = ["Hello, how are you?"] 296 | fn = _outputs.construct_output_function(client.translation) 297 | # Example: facebook/bart-large-mnli 298 | elif p == 
"zero-shot-classification": 299 | inputs = mo.ui.array( 300 | [ 301 | mo.ui.text_area(label="Input text"), 302 | mo.ui.text_area( 303 | label="Possible class names (comma-separated)" 304 | ), 305 | mo.ui.checkbox(label="Allow multiple true classes"), 306 | ] 307 | ) 308 | # outputs = components.Label(label="Classification") 309 | postprocess = _load_utils.postprocess_label 310 | examples = [["I feel great", "happy, sad", False]] 311 | fn = _outputs.construct_output_function( 312 | _load_utils.zero_shot_classification_wrapper(client) 313 | ) 314 | # Example: sentence-transformers/distilbert-base-nli-stsb-mean-tokens 315 | elif p == "sentence-similarity": 316 | inputs = mo.ui.array( 317 | [ 318 | mo.ui.text_area( 319 | label="Source Sentence", 320 | placeholder="Enter an original sentence", 321 | ), 322 | mo.ui.text_area( 323 | rows=7, 324 | placeholder=( 325 | "Sentences to compare to -- separate each " 326 | "sentence by a newline" 327 | ), 328 | label="Sentences to compare to", 329 | ), 330 | ] 331 | ) 332 | # outputs = components.JSON(label="Similarity scores") 333 | examples = [["That is a happy person", "That person is very happy"]] 334 | fn = _outputs.construct_output_function( 335 | _load_utils.sentence_similarity_wrapper(client) 336 | ) 337 | # Example: julien-c/ljspeech_tts_train_tacotron2_raw_phn_tacotron_g2p_en_no_space_train # noqa: E501 338 | elif p == "text-to-speech": 339 | inputs = mo.ui.text_area(label="Text") 340 | # outputs = components.Audio(label="Audio") 341 | examples = ["Hello, how are you?"] 342 | fn = _outputs.construct_output_function( 343 | client.text_to_speech, _outputs.audio_output_from_bytes 344 | ) 345 | # example model: osanseviero/BigGAN-deep-128 346 | elif p == "text-to-image": 347 | inputs = mo.ui.text_area(label="Prompt") 348 | # outputs = components.Image(label="Output") 349 | examples = ["A beautiful sunset"] 350 | fn = _outputs.construct_output_function( 351 | client.text_to_image, _outputs.image_output 352 | ) 353 | # 
example model: huggingface-course/bert-finetuned-ner 354 | elif p == "token-classification": 355 | inputs = mo.ui.text_area(label="Input") 356 | # outputs = components.HighlightedText(label="Output") 357 | examples = [ 358 | "marimo is a new kind of programming environment for ML/AI." 359 | ] 360 | fn = _outputs.construct_output_function( 361 | _load_utils.token_classification_wrapper(client) 362 | ) 363 | # example model: impira/layoutlm-document-qa 364 | elif p == "document-question-answering": 365 | inputs = mo.ui.array( 366 | [ 367 | mo.ui.file(label="Upload a document", kind="area"), 368 | mo.ui.text_area(label="Question"), 369 | ] 370 | ) 371 | postprocess = _load_utils.postprocess_label 372 | # outputs = components.Label(label="Label") 373 | fn = _outputs.construct_output_function( 374 | lambda file_upload_results, text: client.document_question_answering( # noqa: E501 375 | file_upload_results[0].contents, 376 | text, 377 | ) 378 | ) 379 | # example model: dandelin/vilt-b32-finetuned-vqa 380 | elif p == "visual-question-answering": 381 | inputs = mo.ui.array( 382 | [ 383 | mo.ui.file( 384 | filetypes=["image/*"], label="Input Image", kind="area" 385 | ), 386 | mo.ui.text_area(label="Question"), 387 | ] 388 | ) 389 | # outputs = components.Label(label="Label") 390 | postprocess = _load_utils.postprocess_visual_question_answering 391 | examples = [ 392 | [ 393 | "https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg", 394 | "What animal is in the image?", 395 | ] 396 | ] 397 | fn = _outputs.construct_output_function( 398 | lambda file_upload_results, text: client.visual_question_answering( 399 | file_upload_results[0].contents, 400 | text, 401 | ) 402 | ) 403 | # example model: Salesforce/blip-image-captioning-base 404 | elif p == "image-to-text": 405 | inputs = mo.ui.file( 406 | filetypes=["image/*"], label="Input Image", kind="area" 407 | ) 408 | # outputs = components.Textbox(label="Generated Text") 409 | examples = [ 410 | 
"https://gradio-builds.s3.amazonaws.com/demo-files/cheetah-002.jpg" 411 | ] 412 | fn = _outputs.construct_output_function( 413 | _load_utils.file_contents_wrapper(client.image_to_text) 414 | ) 415 | # example model: rajistics/autotrain-Adult-934630783 416 | elif p in ["tabular-classification", "tabular-regression"]: 417 | raise NotImplementedError 418 | # examples = _load_utils.get_tabular_examples(model_name) 419 | # col_names, examples = _load_utils.cols_to_rows(examples) 420 | # examples = [[examples]] if examples else None 421 | # inputs = components.Dataframe( 422 | # label="Input Rows", 423 | # type="pandas", 424 | # headers=col_names, 425 | # col_count=(len(col_names), "fixed"), 426 | # render=False, 427 | # ) 428 | # outputs = components.Dataframe( 429 | # label="Predictions", type="array", headers=["prediction"] 430 | # ) 431 | # fn = external_utils.tabular_wrapper 432 | # output_function = fn 433 | # example model: microsoft/table-transformer-detection 434 | elif p == "object-detection": 435 | inputs = mo.ui.file( 436 | filetypes=["image/*"], label="Input Image", kind="area" 437 | ) 438 | # outputs = components.AnnotatedImage(label="Annotations") 439 | # TODO(akshayka): output representation 440 | fn = _outputs.construct_output_function( 441 | _load_utils.file_contents_wrapper( 442 | _load_utils.object_detection_wrapper(client) 443 | ) 444 | ) 445 | else: 446 | raise ValueError(f"Unsupported pipeline type: {p}") 447 | 448 | def query_huggingface_inference_endpoints(data: Any) -> _outputs.Output: 449 | if not isinstance(data, (list, tuple)): 450 | data = [data] 451 | 452 | if preprocess is not None: 453 | data = preprocess(*data) 454 | output = fn(*data) # type: ignore 455 | value = output.value 456 | if postprocess is not None: 457 | value = postprocess(value) # type: ignore 458 | output.value = value 459 | return output 460 | 461 | query_huggingface_inference_endpoints.__name__ = model_name 462 | return HFModel( 463 | title=model_name, 464 | 
inputs=inputs.form(bordered=False), # type: ignore 465 | examples=examples, 466 | inference_function=query_huggingface_inference_endpoints, 467 | ) 468 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/_load_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import re 5 | import warnings 6 | from typing import Any, Callable 7 | 8 | import requests # type: ignore 9 | import yaml # type: ignore 10 | from huggingface_hub import InferenceClient # type: ignore 11 | 12 | 13 | def get_tabular_examples(model_name: str) -> dict[str, list[float]]: 14 | readme = requests.get( 15 | f"https://huggingface.co/{model_name}/resolve/main/README.md" 16 | ) 17 | example_data: Any 18 | if readme.status_code != 200: 19 | warnings.warn( # noqa: B028 20 | f"Cannot load examples from README for {model_name}", UserWarning 21 | ) 22 | example_data = {} 23 | else: 24 | yaml_regex = re.search( 25 | "(?:^|[\r\n])---[\n\r]+([\\S\\s]*?)[\n\r]+---([\n\r]|$)", 26 | readme.text, 27 | ) 28 | if yaml_regex is None: 29 | example_data = {} 30 | else: 31 | example_yaml = next( 32 | yaml.safe_load_all(readme.text[: yaml_regex.span()[-1]]) 33 | ) 34 | example_data = example_yaml.get("widget", {}).get( 35 | "structuredData", {} 36 | ) 37 | if not example_data: 38 | raise ValueError( 39 | f"No example data found in README.md of {model_name} - Cannot " 40 | "build demo. " 41 | "See the README.md here: " 42 | "https://huggingface.co/scikit-learn/tabular-playground/blob/main/README.md " # noqa: E501 43 | "for a reference on how to provide example data to your model." 
44 | ) 45 | # replace nan with string NaN for inference Endpoints 46 | for data in example_data.values(): 47 | for i, val in enumerate(data): 48 | if isinstance(val, float) and math.isnan(val): 49 | data[i] = "NaN" 50 | return example_data 51 | 52 | 53 | def cols_to_rows( 54 | example_data: dict[str, list[float]], 55 | ) -> tuple[list[str], list[list[float]]]: 56 | headers = list(example_data.keys()) 57 | n_rows = max(len(example_data[header] or []) for header in headers) 58 | data = [] 59 | row_data: list[Any] 60 | for row_index in range(n_rows): 61 | row_data = [] 62 | for header in headers: 63 | col = example_data[header] or [] 64 | if row_index >= len(col): 65 | row_data.append("NaN") 66 | else: 67 | row_data.append(col[row_index]) 68 | data.append(row_data) 69 | return headers, data 70 | 71 | 72 | def postprocess_label(scores: list[dict[str, str | float]]) -> dict: 73 | return {c["label"]: c["score"] for c in scores} 74 | 75 | 76 | def postprocess_mask_tokens(scores: list[dict[str, str | float]]) -> dict: 77 | return {c["token_str"]: c["score"] for c in scores} 78 | 79 | 80 | def postprocess_question_answering(answer: dict) -> tuple[str, dict]: 81 | return answer["answer"], {answer["answer"]: answer["score"]} 82 | 83 | 84 | def postprocess_visual_question_answering( 85 | scores: list[dict[str, str | float]] 86 | ) -> dict: 87 | return {c["answer"]: c["score"] for c in scores} 88 | 89 | 90 | def zero_shot_classification_wrapper(client: InferenceClient): 91 | def zero_shot_classification_inner( 92 | inp: str, labels: str, multi_label: bool 93 | ): 94 | return client.zero_shot_classification( 95 | inp, labels.split(","), multi_label=multi_label 96 | ) 97 | 98 | return zero_shot_classification_inner 99 | 100 | 101 | def sentence_similarity_wrapper(client: InferenceClient): 102 | def sentence_similarity_inner(inp: str, sentences: str): 103 | return client.sentence_similarity(inp, sentences.split("\n")) 104 | 105 | return sentence_similarity_inner 106 | 107 | 108 | 
def text_generation_wrapper(client: InferenceClient): 109 | def text_generation_inner(inp: str): 110 | return inp + client.text_generation(inp) 111 | 112 | return text_generation_inner 113 | 114 | 115 | def format_ner_list( 116 | input_string: str, ner_groups: list[dict[str, str | int]] 117 | ) -> list[Any]: 118 | if len(ner_groups) == 0: 119 | return [(input_string, None)] 120 | 121 | output = [] 122 | end = 0 123 | prev_end = 0 124 | 125 | for group in ner_groups: 126 | entity, start, end = ( 127 | group["entity_group"], 128 | group["start"], 129 | group["end"], # type: ignore 130 | ) 131 | output.append((input_string[prev_end:start], None)) # type: ignore 132 | output.append((input_string[start:end], entity)) # type: ignore 133 | prev_end = end 134 | 135 | output.append((input_string[end:], None)) 136 | return output 137 | 138 | 139 | def file_contents_wrapper(fn: Callable[..., Any]) -> Callable[..., Any]: 140 | return lambda file_upload_results: fn(file_upload_results.contents) 141 | 142 | 143 | def token_classification_wrapper( 144 | client: InferenceClient, 145 | ) -> Callable[[str], list[dict[str, str | int]]]: 146 | def token_classification_inner(inp: str) -> list[dict[str, str | int]]: 147 | ner_list = client.token_classification(inp) 148 | return format_ner_list(inp, ner_list) # type: ignore 149 | 150 | return token_classification_inner 151 | 152 | 153 | def tabular_wrapper( 154 | client: InferenceClient, pipeline: str 155 | ) -> Callable[..., Any]: 156 | # This wrapper is needed to handle an issue in the InfereneClient where the 157 | # model name is not automatically loaded when using the 158 | # tabular_classification and tabular_regression methods. 
159 | # See: https://github.com/huggingface/huggingface_hub/issues/2015 160 | def tabular_inner(data): 161 | if pipeline not in ("tabular_classification", "tabular_regression"): 162 | raise TypeError(f"pipeline type {pipeline!r} not supported") 163 | assert client.model # noqa: S101 164 | if pipeline == "tabular_classification": 165 | return client.tabular_classification(data, model=client.model) 166 | else: 167 | return client.tabular_regression(data, model=client.model) 168 | 169 | return tabular_inner 170 | 171 | 172 | def object_detection_wrapper( 173 | client: InferenceClient, 174 | ) -> Callable[[str], tuple[str, list[Any]]]: 175 | def object_detection_inner(inp: str) -> tuple[str, list[Any]]: 176 | annotations = client.object_detection(inp) 177 | formatted_annotations = [ 178 | ( 179 | ( 180 | a["box"]["xmin"], 181 | a["box"]["ymin"], 182 | a["box"]["xmax"], 183 | a["box"]["ymax"], 184 | ), 185 | a["label"], 186 | ) 187 | for a in annotations 188 | ] 189 | return (inp, formatted_annotations) 190 | 191 | return object_detection_inner 192 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/_outputs.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | from typing import Any, Callable 5 | 6 | import marimo as mo 7 | import PIL.Image 8 | 9 | 10 | class Output: 11 | def __init__(self, value: object, html: mo.Html | None = None) -> None: 12 | self.value = value 13 | self.html = html 14 | 15 | def _mime_(self) -> tuple[str, str]: 16 | return ( 17 | "text/html", 18 | ( 19 | self.html.text 20 | if self.html is not None 21 | else mo.as_html(self.value).text 22 | ), 23 | ) 24 | 25 | 26 | def default_output(data: bytes) -> Output: 27 | return Output(value=data) 28 | 29 | 30 | def construct_output_function( 31 | inference_function: Callable[..., object], 32 | output: Callable[[Any], Output] = default_output, 33 | ): 34 | 
return lambda *args: output(inference_function(*args)) 35 | 36 | 37 | def image_output(value: PIL.Image.Image) -> Output: 38 | stream = io.BytesIO() 39 | value.save(stream, format="PNG") 40 | return Output(value, html=mo.image(stream)) 41 | 42 | 43 | def audio_output_from_path(value: str) -> Output: 44 | with open(value, "rb") as f: 45 | audio = mo.audio(io.BytesIO(f.read())) 46 | 47 | return Output(value=value, html=audio) 48 | 49 | 50 | def audio_output_from_bytes(value: bytes) -> Output: 51 | audio = mo.audio(io.BytesIO(value)) 52 | return Output(value=value, html=audio) 53 | -------------------------------------------------------------------------------- /marimo_labs/huggingface/_processing_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import base64 4 | import hashlib 5 | from os.path import abspath 6 | from pathlib import Path 7 | 8 | import requests # type: ignore 9 | 10 | from marimo_labs.huggingface import _client_utils 11 | 12 | 13 | def hash_base64(base64_encoding: str, chunk_num_blocks: int = 128) -> str: 14 | sha1 = hashlib.sha1() 15 | for i in range( 16 | 0, len(base64_encoding), chunk_num_blocks * sha1.block_size 17 | ): 18 | data = base64_encoding[i : i + chunk_num_blocks * sha1.block_size] 19 | sha1.update(data.encode("utf-8")) 20 | return sha1.hexdigest() 21 | 22 | 23 | def save_base64_to_cache( 24 | base64_encoding: str, cache_dir: str, file_name: str | None = None 25 | ) -> str: 26 | """Converts a base64 encoding to a file and returns the path to the file if 27 | the file doesn't already exist. Otherwise returns the path to the existing 28 | file. 
29 | """ 30 | temp_dir_name = hash_base64(base64_encoding) 31 | temp_dir = Path(cache_dir) / temp_dir_name 32 | temp_dir.mkdir(exist_ok=True, parents=True) 33 | 34 | guess_extension = _client_utils.get_extension(base64_encoding) 35 | if file_name: 36 | file_name = _client_utils.strip_invalid_filename_characters(file_name) 37 | elif guess_extension: 38 | file_name = f"file.{guess_extension}" 39 | else: 40 | file_name = "file" 41 | 42 | full_temp_file_path = str(abspath(temp_dir / file_name)) # type: ignore 43 | 44 | if not Path(full_temp_file_path).exists(): 45 | data, _ = _client_utils.decode_base64_to_binary(base64_encoding) 46 | with open(full_temp_file_path, "wb") as fb: 47 | fb.write(data) 48 | 49 | return full_temp_file_path 50 | 51 | 52 | def extract_base64_data(x: str) -> str: 53 | """Just extracts the base64 data from a general base64 string.""" 54 | return x.rsplit(",", 1)[-1] 55 | 56 | 57 | def to_binary(x: str | dict) -> bytes: 58 | """Converts a base64 string or dictionary to a binary string that can be 59 | sent in a POST.""" 60 | if isinstance(x, dict): 61 | if x.get("data"): 62 | base64str = x["data"] 63 | else: 64 | base64str = _client_utils.encode_url_or_file_to_base64(x["path"]) 65 | else: 66 | base64str = x 67 | return base64.b64decode(extract_base64_data(base64str)) 68 | 69 | 70 | def encode_to_base64(r: requests.Response) -> str: 71 | # Handles the different ways HF API returns the prediction 72 | base64_repr = base64.b64encode(r.content).decode("utf-8") 73 | data_prefix = ";base64," 74 | # Case 1: base64 representation already includes data prefix 75 | if data_prefix in base64_repr: 76 | return base64_repr 77 | else: 78 | content_type = r.headers.get("content-type") 79 | # Case 2: the data prefix is a key in the response 80 | if content_type == "application/json": 81 | try: 82 | data = r.json()[0] 83 | content_type = data["content-type"] 84 | base64_repr = data["blob"] 85 | except KeyError as ke: 86 | raise ValueError( 87 | "Cannot determine 
content type returned by external API." 88 | ) from ke 89 | # Case 3: the data prefix is included in the response headers 90 | else: 91 | pass 92 | new_base64 = f"data:{content_type};base64,{base64_repr}" 93 | return new_base64 94 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "marimo_labs" 7 | description = "Cutting-edge experimental features for marimo" 8 | version = "0.1.0" 9 | dependencies = [ 10 | "marimo>=0.3.8", 11 | "huggingface_hub>=0.19.3", 12 | "requests>=2.0", 13 | "pyyaml>=5.0", 14 | "Pillow>=7.0", 15 | "numpy>=1.21.0", 16 | ] 17 | readme = "README.md" 18 | license = { file = "LICENSE" } 19 | requires-python = ">=3.8" 20 | classifiers = [ 21 | "Operating System :: OS Independent", 22 | "License :: OSI Approved :: Apache Software License", 23 | "Environment :: Console", 24 | "Environment :: Web Environment", 25 | "Intended Audience :: Developers", 26 | "Intended Audience :: Science/Research", 27 | "Intended Audience :: Education", 28 | "Programming Language :: Python", 29 | "Programming Language :: Python :: 3.8", 30 | "Programming Language :: Python :: 3.9", 31 | "Programming Language :: Python :: 3.10", 32 | "Programming Language :: Python :: 3.11", 33 | "Programming Language :: Python :: 3 :: Only", 34 | ] 35 | 36 | [project.urls] 37 | homepage = "https://github.com/marimo-team/marimo_labs" 38 | 39 | [project.optional-dependencies] 40 | dev = [ 41 | "black~=23.3.0", 42 | "build~=0.10.0", 43 | "mypy~=1.9.0", 44 | "ruff~=0.3.5", 45 | "typos~=1.20.4", 46 | "pytest~=8.1.1", 47 | ] 48 | 49 | [tool.setuptools.packages.find] 50 | # project source is entirely contained in the `marimo` package 51 | include = ["marimo_labs*"] 52 | 53 | [tool.ruff] 54 | line-length=79 55 | exclude = [ 56 | "docs", 57 | 
"build", 58 | ] 59 | lint.ignore = [] 60 | lint.select = [ 61 | # pyflakes 62 | "F", 63 | # pycodestyle 64 | "E", 65 | # warning 66 | "W", 67 | # flake8 builtin-variable-shadowing 68 | "A001", 69 | # flake8 builtin-argument-shadowing 70 | "A002", 71 | # flake8-unused-arguments 72 | "ARG", 73 | # flake8-bugbear 74 | "B", 75 | # future annotations 76 | "FA102", 77 | # isort 78 | "I001", 79 | ] 80 | 81 | # Never try to fix `F401` (unused imports). 82 | lint.unfixable = ["F401"] 83 | 84 | [tool.ruff.lint.isort] 85 | required-imports = ["from __future__ import annotations"] 86 | 87 | [tool.black] 88 | line-length = 79 89 | 90 | [tool.mypy] 91 | strict = false 92 | exclude = [] 93 | warn_unused_ignores=false 94 | 95 | [tool.pytest.ini_options] 96 | minversion = "6.0" 97 | testpaths = [ 98 | "tests", 99 | ] 100 | 101 | [tool.typos.default.extend-words] 102 | wheres = "wheres" 103 | -------------------------------------------------------------------------------- /scripts/pyfix.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | echo "[fix: ruff]" 4 | ruff check marimo_labs/ --fix 5 | echo "[fix: black]" 6 | black marimo_labs/ 7 | black tests/ 8 | echo "[check: typecheck]" 9 | mypy marimo_labs/ 10 | -------------------------------------------------------------------------------- /tests/huggingface/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | @pytest.fixture(autouse=True) 5 | def marimo_temp_dir(monkeypatch, tmp_path): 6 | """tmp_path is unique to each test function. 
7 | It will be cleared automatically according to pytest docs: 8 | https://docs.pytest.org/en/6.2.x/reference.html#tmp-path 9 | """ 10 | monkeypatch.setenv("MARIMO_TEMP_DIR", str(tmp_path)) 11 | return tmp_path 12 | -------------------------------------------------------------------------------- /tests/huggingface/test_processing_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import pytest 7 | from PIL import Image, ImageCms 8 | 9 | from marimo_labs.huggingface import _processing_utils as processing_utils 10 | import media_data 11 | 12 | 13 | class TestTempFileManagement: 14 | def test_save_b64_to_cache(self, marimo_temp_dir): 15 | base64_file_1 = media_data.BASE64_IMAGE 16 | base64_file_2 = media_data.BASE64_AUDIO["data"] 17 | 18 | f = processing_utils.save_base64_to_cache( 19 | base64_file_1, cache_dir=marimo_temp_dir 20 | ) 21 | try: # Delete if already exists from before this test 22 | os.remove(f) 23 | except OSError: 24 | pass 25 | 26 | f = processing_utils.save_base64_to_cache( 27 | base64_file_1, cache_dir=marimo_temp_dir 28 | ) 29 | assert ( 30 | len([f for f in marimo_temp_dir.glob("**/*") if f.is_file()]) == 1 31 | ) 32 | 33 | f = processing_utils.save_base64_to_cache( 34 | base64_file_1, cache_dir=marimo_temp_dir 35 | ) 36 | assert ( 37 | len([f for f in marimo_temp_dir.glob("**/*") if f.is_file()]) == 1 38 | ) 39 | 40 | f = processing_utils.save_base64_to_cache( 41 | base64_file_2, cache_dir=marimo_temp_dir 42 | ) 43 | assert ( 44 | len([f for f in marimo_temp_dir.glob("**/*") if f.is_file()]) == 2 45 | ) 46 | --------------------------------------------------------------------------------