├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ ├── test.yml │ └── wheels.yml ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── graph-diagram-dark.png └── graph-diagram.png ├── example.png ├── examples ├── basketball.png ├── basketball_workflow.py ├── construction.jpg ├── construction_workflow1.py ├── construction_workflow2.py ├── construction_workflow3.py ├── dense_street1.jpg ├── dense_street2.jpg ├── dense_street3.jpg ├── drivers_license.jpg ├── drivers_license_workflow.py ├── eiffel_photo.png ├── eiffel_photo_workflow.py ├── kites.jpg ├── kites_workflow.py └── kites_workflow_local.py ├── gradio_example.html ├── kites_workflow.html ├── mypy.ini ├── overeasy ├── __init__.py ├── agents │ ├── __init__.py │ ├── dark_mode.txt │ ├── misc │ │ ├── __init__.py │ │ ├── class_map.py │ │ ├── filter_class.py │ │ ├── map_agent.py │ │ ├── max_confidence.py │ │ ├── min_crop.py │ │ ├── nms.py │ │ ├── pad_crop.py │ │ ├── split_crop.py │ │ └── to_classification.py │ ├── model_agents.py │ ├── scrape.py │ ├── split_join_agent.py │ └── workflow.py ├── assets │ ├── egg_logo.png │ └── favicon.ico ├── dirs.py ├── download_utils.py ├── logging.py ├── models │ ├── LLMs │ │ ├── __init__.py │ │ ├── anthropic.py │ │ ├── gemini.py │ │ ├── openai.py │ │ ├── pali_gemma.py │ │ └── qwenvl.py │ ├── __init__.py │ ├── classification │ │ ├── __init__.py │ │ ├── clip.py │ │ ├── laion_clip.py │ │ └── siglip.py │ ├── detection │ │ ├── __init__.py │ │ ├── detclipv2.py │ │ ├── detic.py │ │ ├── dino.py │ │ ├── owlv2.py │ │ └── yoloworld.py │ └── recognition │ │ └── __init__.py ├── py.typed ├── types │ ├── __init__.py │ ├── base.py │ ├── detections.py │ └── type_utils.py └── visualize_utils.py ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── tests ├── conftest.py ├── count_eggs.jpg ├── dogs │ ├── dog1.png │ ├── dog2.png │ ├── dog3.png │ ├── dog4.png │ ├── dog5.png │ ├── dog6.png │ └── dog7.png ├── plate.jpg ├── test.png ├── test_construction_workflows.py ├── test_detection_models.py ├── test_import.py ├── test_instructor_agents.py ├── test_large_local_models.py ├── test_misc_agents.py ├── test_model_agents.py ├── test_owl.py ├── test_saving_vis.py └── test_split_join.py └── warmup.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Debugging Info** 27 | Please include all of the following 28 | - OS: [e.g. 
Mac, Linux, Windows] 29 | - Python Version 30 | - CUDA Version (nvcc --version) 31 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Run Tests 2 | on: 3 | workflow_dispatch: # Only manual triggering 4 | 5 | jobs: 6 | test: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | matrix: 10 | os: [ubuntu-latest, macos-latest] 11 | python-version: ['3.10', '3.11', '3.12'] 12 | include: 13 | - os: ubuntu-latest 14 | architecture: x64 15 | - os: macos-latest 16 | architecture: x64 17 | - os: macos-latest 18 | architecture: arm64 19 | steps: 20 | - name: Check out the repository 21 | uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | architecture: ${{ matrix.architecture }} 27 | - name: Install Poetry 28 | run: | 29 | curl -sSL https://install.python-poetry.org | python - 30 | - name: Install dependencies 31 | run: | 32 | poetry install 33 | - name: Run tests 34 | run: | 35 | poetry run pytest -------------------------------------------------------------------------------- /.github/workflows/wheels.yml: -------------------------------------------------------------------------------- 1 | name: Build and upload to PyPI 2 | 3 | on: 4 | workflow_dispatch: 5 | release: 6 | types: 7 | - published 8 | 9 | jobs: 10 | build_wheels: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v3 17 | with: 18 | python-version: '3.9' # or any version you're targeting 19 | 20 | - name: Build wheel 21 | run: | 22 | python -m pip install build 23 | python -m build --wheel --outdir wheelhouse 24 | 25 | - uses: actions/upload-artifact@v4 26 | with: 27 | name: pure-python-wheels 28 | path: ./wheelhouse/*.whl 29 | 30 | build_sdist: 31 | name: Build source distribution 32 | runs-on: ubuntu-latest 33 | steps: 34 | - uses: actions/checkout@v4 35 | 36 | - name: Build sdist 37 | run: pipx run build --sdist 38 | 39 | - uses: actions/upload-artifact@v4 40 | with: 41 | name: cibw-sdist 42 | path: dist/*.tar.gz 43 | 44 | upload_pypi: 45 | needs: [build_wheels, build_sdist] 46 | runs-on: ubuntu-latest 47 | environment: prod 48 | permissions: 49 | id-token: write 50 | if: github.event_name == 'release' && github.event.action == 'published' 51 | # or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this) 52 | # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v') 53 | steps: 54 | - uses: actions/download-artifact@v4 55 | with: 56 | # unpacks all CIBW artifacts into dist/ 57 | pattern: cibw-* 58 | path: dist 59 | merge-multiple: true 60 | 61 | - uses: pypa/gh-action-pypi-publish@release/v1 62 | # To test: repository-url: https://test.pypi.org/legacy/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | tests/outputs 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | *.DS_Store 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 
29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Overeasy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

🥚 Overeasy 2 |
3 | 4 | Issues 5 | License 6 | Docs 7 | Colab Demo 8 | 9 |

10 | 11 | 12 | 13 | 14 |

Create powerful zero-shot vision models!

15 | 16 | 17 | Overeasy allows you to chain zero-shot vision models to create custom end-to-end pipelines for tasks like: 18 | 19 | - 📦 Bounding Box Detection 20 | - 🏷️ Classification 21 | - 🖌️ Segmentation (Coming Soon!) 22 | 23 | All of this can be achieved without needing to collect and annotate large training datasets. 24 | 25 | Overeasy makes it simple to combine pre-trained zero-shot models to build powerful custom computer vision solutions. 26 | 27 | 28 | ## Installation 29 | It's as easy as 30 | ```bash 31 | pip install overeasy 32 | ``` 33 | 34 | For installing extras, refer to our [Docs](https://docs.overeasy.sh/installation/installing-extras). 35 | 36 | ## Key Features 37 | - `🤖 Agents`: Specialized tools that perform specific image processing tasks. 38 | - `🧩 Workflows`: Define a sequence of Agents to process images in a structured manner. 39 | - `🔗 Execution Graphs`: Manage and visualize the image processing pipeline. 40 | - `🔎 Detections`: Represent bounding boxes, segmentation masks, and classifications. 41 | 42 | 43 | ## Documentation 44 | For more details on types, library structure, and available models, please refer to our [Docs](https://docs.overeasy.sh). 45 | 46 | ## Example Usage 47 | 48 | > Note: If you don't have a local GPU, you can run our examples by making a copy of this [Colab notebook](https://colab.research.google.com/drive/1Mkx9S6IG5130wiP9WmwgINiyw0hPsh3c?usp=sharing#scrollTo=L0_U27WJaTNO). 49 | 50 | 51 | Download the example image: 52 | ```bash 53 | !wget https://github.com/overeasy-sh/overeasy/blob/73adbaeba51f532a7023243266da826ed1ced6ec/examples/construction.jpg?raw=true -O construction.jpg 54 | ``` 55 | 56 | Example workflow to identify whether a person is wearing PPE on a work site: 57 | ```python 58 | from overeasy import * 59 | from overeasy.models import OwlV2 60 | from PIL import Image 61 | 62 | workflow = Workflow([ 63 | # Detect each head in the input image 64 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()), 65 | # Applies Non-Maximum Suppression to remove overlapping bounding boxes 66 | NMSAgent(iou_threshold=0.5, score_threshold=0), 67 | # Splits the input image into images of each detected head 68 | SplitAgent(), 69 | # Classifies the split images using CLIP 70 | ClassificationAgent(classes=["hard hat", "no hard hat"]), 71 | # Maps the returned class names 72 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}), 73 | # Combines results back into a BoundingBox Detection 74 | JoinAgent() 75 | ]) 76 | 77 | image = Image.open("./construction.jpg") 78 | result, graph = workflow.execute(image) 79 | workflow.visualize(graph) 80 | ``` 81 | 82 | ### Diagram 83 | 84 | Here's a diagram of this workflow. Each layer in the graph represents a step in the workflow: 85 | 87 | 88 | 89 | 90 | 91 | Diagram 92 | 93 | 94 | The image and data attributes in each node are used together to visualize the current state of the workflow. Calling the `visualize` function on the workflow will spawn a Gradio instance that looks like [this](https://overeasy-sh.github.io/gradio-example/Gradio.html). To run the same workflow over several images at once, see the batch-processing sketch at the end of this README. 95 | 96 | ## Support 97 | If you have any questions or need assistance, please open an issue or reach out to us at help@overeasy.sh. 98 | 99 | 100 | Let's build amazing vision models together 🍳! 
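
### Batch Processing

`Workflow` also exposes an `execute_multiple` method (see `overeasy/agents/workflow.py`) that runs the same pipeline over a list of images and returns one list of leaf nodes plus one execution graph per image. A minimal sketch, reusing the construction workflow above; the image paths here are placeholders, not files shipped with the repo:

```python
from overeasy import *
from overeasy.models import OwlV2
from PIL import Image

workflow = Workflow([
    BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()),
    NMSAgent(iou_threshold=0.5, score_threshold=0),
    SplitAgent(),
    ClassificationAgent(classes=["hard hat", "no hard hat"]),
    ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}),
    JoinAgent()
])

# Placeholder paths: substitute your own images
images = [Image.open(p) for p in ["site1.jpg", "site2.jpg"]]

# One list of leaf ExecutionNodes and one ExecutionGraph per input image
all_leaves, all_graphs = workflow.execute_multiple(images)
for leaves in all_leaves:
    for node in leaves:
        print(node.data)  # the Detections produced for that image
```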
101 | -------------------------------------------------------------------------------- /assets/graph-diagram-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/assets/graph-diagram-dark.png -------------------------------------------------------------------------------- /assets/graph-diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/assets/graph-diagram.png -------------------------------------------------------------------------------- /example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/example.png -------------------------------------------------------------------------------- /examples/basketball.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/basketball.png -------------------------------------------------------------------------------- /examples/basketball_workflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import GPTVision 4 | from PIL import Image 5 | 6 | workflow = Workflow([ 7 | BinaryChoiceAgent("Do the Celtics have possession of the ball?", model=GPTVision(model="gpt-4o")) 8 | ]) 9 | image_path = os.path.join(os.path.dirname(__file__), "basketball.png") 10 | image = Image.open(image_path) 11 | result, graph = workflow.execute(image) 12 | print(result[0].data.class_names) -------------------------------------------------------------------------------- /examples/construction.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/construction.jpg -------------------------------------------------------------------------------- /examples/construction_workflow1.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import Workflow, BoundingBoxSelectAgent, ToClassificationAgent, JoinAgent, SplitAgent, InstructorImageAgent, YOLOWorld 3 | from pydantic import BaseModel 4 | from PIL import Image 5 | 6 | class PPE(BaseModel): 7 | hardhat: bool 8 | vest: bool 9 | boots: bool 10 | 11 | 12 | workflow = Workflow([ 13 | BoundingBoxSelectAgent(classes=["person"], model=YOLOWorld()), 14 | SplitAgent(), 15 | InstructorImageAgent(response_model=PPE), 16 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"), 17 | # ClassMapAgent({"has ppe": "has ppe", "no ppe": "no ppe"}), 18 | JoinAgent(), 19 | ]) 20 | 21 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg") 22 | image = Image.open(image_path) 23 | result, graph = workflow.execute(image) 24 | workflow.visualize(graph) 25 | -------------------------------------------------------------------------------- /examples/construction_workflow2.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import OwlV2 4 | from PIL import Image 5 | 6 | workflow = Workflow([ 7 | 
BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()), 8 | NMSAgent(iou_threshold=0.5, score_threshold=0), 9 | SplitAgent(), 10 | ClassificationAgent(classes=["hard hat", "no hard hat"]), 11 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}), 12 | JoinAgent(), 13 | ]) 14 | 15 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg") 16 | image = Image.open(image_path) 17 | result, graph = workflow.execute(image) 18 | workflow.visualize(graph) 19 | -------------------------------------------------------------------------------- /examples/construction_workflow3.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import OwlV2 4 | from PIL import Image 5 | 6 | workflow = Workflow([ 7 | # Detect each head in the input image 8 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()), 9 | # Applies Non-Maximum Suppression to the detected heads 10 | # It's recommended to use NMS when using OwlV2 11 | NMSAgent(iou_threshold=0.5), 12 | # Splits the input image into an image of each detected head 13 | SplitAgent(), 14 | # Classifies PPE using CLIP 15 | ClassificationAgent(classes=["hard hat", "no hard hat"]), 16 | # Maps the returned class names 17 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}), 18 | # Combines results back into a BoundingBox Detection 19 | JoinAgent() 20 | ]) 21 | 22 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg") 23 | image = Image.open(image_path) 24 | result, graph = workflow.execute(image) 25 | workflow.visualize(graph) -------------------------------------------------------------------------------- /examples/dense_street1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street1.jpg -------------------------------------------------------------------------------- /examples/dense_street2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street2.jpg -------------------------------------------------------------------------------- /examples/dense_street3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street3.jpg -------------------------------------------------------------------------------- /examples/drivers_license.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/drivers_license.jpg -------------------------------------------------------------------------------- /examples/drivers_license_workflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import GPTVision 4 | from PIL import Image 5 | from pydantic import BaseModel 6 | 7 | class License(BaseModel): 8 | license_number: str 9 | expiration_date: str 10 | state: str 11 | date_of_birth: str 12 | address: str 13 | first_name: str 14 | last_name: str 15 | 16 | workflow = Workflow([ 17 | InstructorImageAgent(response_model=License, model=GPTVision()) 18 | ]) 19 | 
20 | image_path = os.path.join(os.path.dirname(__file__), "drivers_license.jpg") 21 | image = Image.open(image_path) 22 | result, graph = workflow.execute(image) 23 | print(repr(result[0].data)) -------------------------------------------------------------------------------- /examples/eiffel_photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/eiffel_photo.png -------------------------------------------------------------------------------- /examples/eiffel_photo_workflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import GPTVision 4 | from PIL import Image 5 | 6 | workflow = Workflow([ 7 | DenseCaptioningAgent(model=GPTVision(model="gpt-4o")) 8 | ]) 9 | image_path = os.path.join(os.path.dirname(__file__), "eiffel_photo.png") 10 | image = Image.open(image_path) 11 | result, graph = workflow.execute(image) 12 | print(result[0].data) -------------------------------------------------------------------------------- /examples/kites.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/kites.jpg -------------------------------------------------------------------------------- /examples/kites_workflow.py: -------------------------------------------------------------------------------- 1 | import os 2 | from overeasy import * 3 | from overeasy.models import OwlV2 4 | from PIL import Image 5 | 6 | workflow = Workflow([ 7 | BoundingBoxSelectAgent(classes=["butterfly kite"], model=OwlV2()), 8 | NMSAgent(iou_threshold=0.8, score_threshold=0.6), 9 | ]) 10 | 11 | image_path = os.path.join(os.path.dirname(__file__), "kites.jpg") 12 | image = Image.open(image_path) 13 | results, graph = workflow.execute(image) 14 | workflow.visualize(graph) -------------------------------------------------------------------------------- /examples/kites_workflow_local.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../'))) 4 | 5 | import os 6 | from overeasy import * 7 | from overeasy.models import OwlV2 8 | from PIL import Image 9 | 10 | workflow = Workflow([ 11 | BoundingBoxSelectAgent(classes=["butterfly kite"], model=OwlV2()), 12 | NMSAgent(iou_threshold=0.8, score_threshold=0.6), 13 | ]) 14 | 15 | image_path = os.path.join(os.path.dirname(__file__), "kites.jpg") 16 | image = Image.open(image_path) 17 | results, graph = workflow.execute(image) 18 | workflow.visualize_to_file(graph, "kites_workflow.html") -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = True 3 | -------------------------------------------------------------------------------- /overeasy/__init__.py: -------------------------------------------------------------------------------- 1 | from overeasy.agents import * 2 | from overeasy.models import * 3 | 4 | __version__ = "0.2.16" 5 | 6 | 7 | from overeasy.dirs import ROOT as _ROOT 8 | import os as _os 9 | _os.makedirs(_ROOT, exist_ok=True) 10 | -------------------------------------------------------------------------------- 
/overeasy/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from .misc import * 2 | from .split_join_agent import SplitAgent, JoinAgent 3 | from .workflow import Workflow 4 | from .model_agents import * 5 | -------------------------------------------------------------------------------- /overeasy/agents/dark_mode.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /overeasy/agents/misc/__init__.py: -------------------------------------------------------------------------------- 1 | from .pad_crop import PadCropAgent 2 | from .split_crop import SplitCropAgent 3 | from .nms import NMSAgent 4 | from .class_map import ClassMapAgent 5 | from .map_agent import MapAgent 6 | from .to_classification import ToClassificationAgent 7 | from .filter_class import FilterClassesAgent 8 | from .max_confidence import ConfidenceFilterAgent -------------------------------------------------------------------------------- /overeasy/agents/misc/class_map.py: -------------------------------------------------------------------------------- 1 | from overeasy.types import DetectionAgent, Detections 2 | from typing import Dict 3 | 4 | class ClassMapAgent(DetectionAgent): 5 | def __init__(self, class_map: Dict[str, str]): 6 | self.class_map = class_map 7 | 8 | def _execute(self, dets: Detections) -> Detections: 9 | unmapped_classes = set(dets.class_names) - set(self.class_map.keys()) 10 | if unmapped_classes: 11 | raise ValueError(f"Class names {unmapped_classes} not mapped") 12 | class_names = [self.class_map[x] for x in dets.class_names] 13 | 14 | new_dets = Detections( 15 | xyxy=dets.xyxy, 16 | class_ids=dets.class_ids, 17 | confidence=dets.confidence, 18 | classes=class_names, 19 | data=dets.data, 20 | detection_type=dets.detection_type 21 | ) 22 | return new_dets 23 | 24 | def __repr__(self): 25 | return f"{self.__class__.__name__}(class_map={self.class_map})" 26 | -------------------------------------------------------------------------------- /overeasy/agents/misc/filter_class.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from overeasy.types import Detections, DetectionAgent 3 | from typing import List, Optional 4 | import numpy as np 5 | 6 | @dataclass 7 | class FilterClassesAgent(DetectionAgent): 8 | class_names: Optional[List[str]] 9 | class_ids: Optional[List[int]] 10 | 11 | """ 12 | Filter detections by class name or class id. 
13 | """ 14 | def __init__(self, class_names: Optional[List[str]] = None, class_ids: Optional[List[int]] = None): 15 | self.class_names = class_names 16 | self.class_ids = class_ids 17 | 18 | if class_names is None and class_ids is None: 19 | raise ValueError("Must specify class_names or class_ids") 20 | if class_names is not None and class_ids is not None: 21 | raise ValueError("Can only specify one of class_names or class_ids") 22 | 23 | 24 | def _execute(self, dets: Detections) -> Detections: 25 | if self.class_ids is not None: 26 | mask = np.isin(dets.class_ids, self.class_ids) 27 | elif self.class_names is not None: 28 | mask = np.isin(dets.class_names, self.class_names) 29 | else: 30 | raise ValueError("No filter specified") 31 | 32 | return dets[mask] 33 | 34 | def __repr__(self): 35 | return f"{self.__class__.__name__}(class_names={self.class_names}, class_ids={self.class_ids})" 36 | 37 | -------------------------------------------------------------------------------- /overeasy/agents/misc/map_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Any 2 | from overeasy.types import DataAgent 3 | 4 | class MapAgent(DataAgent): 5 | def __init__(self, fn: Callable[[Any], Any]): 6 | self.fn = fn 7 | 8 | def _execute(self, data: Any) -> Any: 9 | return self.fn(data) 10 | 11 | def __repr__(self): 12 | return f"{self.__class__.__name__}(fn={self.fn})" -------------------------------------------------------------------------------- /overeasy/agents/misc/max_confidence.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from overeasy.types import Detections, DetectionAgent 3 | from typing import List, Optional 4 | import numpy as np 5 | 6 | @dataclass 7 | class ConfidenceFilterAgent(DetectionAgent): 8 | max_n: Optional[int] = None 9 | min_confidence: Optional[float] = None 10 | 11 | """ 12 | Filter detections by confidence, either selecting the top N or those above a minimum confidence threshold. 
13 | """ 14 | def __init__(self, max_n: Optional[int] = None, min_confidence: Optional[float] = None): 15 | self.max_n = max_n 16 | self.min_confidence = min_confidence 17 | 18 | if max_n is None and min_confidence is None: 19 | raise ValueError("Must specify either max_n or min_confidence") 20 | if max_n is not None and max_n < 1: 21 | raise ValueError("max_n must be at least 1") 22 | if min_confidence is not None and (min_confidence < 0 or min_confidence > 1): 23 | raise ValueError("min_confidence must be between 0 and 1") 24 | 25 | def _execute(self, dets: Detections) -> Detections: 26 | if dets.confidence is None: 27 | raise ValueError("Detections must have confidence scores") 28 | 29 | indices = np.arange(len(dets.confidence)) 30 | 31 | if self.min_confidence is not None: 32 | indices = indices[dets.confidence >= self.min_confidence] 33 | 34 | if self.max_n is not None: 35 | if len(indices) > self.max_n: 36 | sorted_indices = np.argsort(dets.confidence[indices])[-self.max_n:][::-1] 37 | indices = indices[sorted_indices] 38 | else: 39 | indices = indices[np.argsort(dets.confidence[indices])[::-1]] 40 | 41 | if len(indices) == 0: 42 | return Detections.empty() 43 | 44 | return dets[indices.astype(int)] 45 | 46 | def __repr__(self): 47 | return f"{self.__class__.__name__}(max_n={self.max_n}, min_confidence={self.min_confidence})" 48 | -------------------------------------------------------------------------------- /overeasy/agents/misc/min_crop.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/agents/misc/min_crop.py -------------------------------------------------------------------------------- /overeasy/agents/misc/nms.py: -------------------------------------------------------------------------------- 1 | from overeasy.types import Detections, DetectionType, DetectionAgent 2 | import cv2 3 | 4 | def do_nms(dets: Detections, nms_threshold: float, score_threshold : float = 0.0) -> Detections: 5 | if dets.detection_type != DetectionType.BOUNDING_BOX: 6 | raise ValueError("Only bounding box detections are supported for NMS.") 7 | 8 | bboxes = dets.xyxy.copy() 9 | scores = dets.confidence 10 | if scores is None: 11 | raise ValueError("Confidence scores are required for NMS.") 12 | bboxes[:, 2:] -= bboxes[:, :2] # cv2.dnn.NMSBoxes expects (x, y, w, h) boxes, so convert from xyxy 13 | indices = list(cv2.dnn.NMSBoxes(bboxes, scores, score_threshold, nms_threshold)) 14 | 15 | return dets[indices] 16 | 17 | class NMSAgent(DetectionAgent): 18 | # IoU (Intersection over Union) Threshold: Determines the minimum overlap between two bounding boxes to consider them as the same object. 19 | # Score Threshold: Filters out detections that have a confidence score below this value before applying NMS. 
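# Illustrative example (values are hypothetical, not from the original source):
# NMSAgent(iou_threshold=0.5, score_threshold=0.25) first drops boxes with
# confidence < 0.25, then suppresses any remaining box whose IoU with a
# higher-scoring box exceeds 0.5.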
20 | def __init__(self, iou_threshold: float, score_threshold: float = 0.0): 21 | self.iou_threshold = iou_threshold 22 | self.score_threshold = score_threshold 23 | 24 | def _execute(self, dets: Detections) -> Detections: 25 | if dets.detection_type != DetectionType.BOUNDING_BOX: 26 | raise ValueError("Only bounding box detections are supported for NMS.") 27 | return do_nms(dets, self.iou_threshold, self.score_threshold) 28 | 29 | def __repr__(self): 30 | return f"NMSAgent(iou_threshold={self.iou_threshold}, score_threshold={self.score_threshold})" 31 | -------------------------------------------------------------------------------- /overeasy/agents/misc/pad_crop.py: -------------------------------------------------------------------------------- 1 | from overeasy.types import Detections, DetectionType, DetectionAgent 2 | from pydantic.dataclasses import dataclass 3 | import numpy as np 4 | 5 | @dataclass 6 | class PadCropAgent(DetectionAgent): 7 | x1_pad: int 8 | y1_pad: int 9 | x2_pad: int 10 | y2_pad: int 11 | 12 | @classmethod 13 | def from_uniform_padding(cls, padding): 14 | return cls(padding, padding, padding, padding) 15 | 16 | @classmethod 17 | def from_xy_padding(cls, x_pad, y_pad): 18 | return cls(x_pad, y_pad, x_pad, y_pad) 19 | 20 | def _execute(self, dets: Detections) -> Detections: 21 | if dets.detection_type != DetectionType.BOUNDING_BOX: 22 | raise ValueError("Only bounding box detections are supported for padding.") 23 | 24 | padded_bboxes = [] 25 | for bbox in dets.xyxy: 26 | x1, y1, x2, y2 = bbox 27 | padded_bbox = [ 28 | x1 - self.x1_pad, 29 | y1 - self.y1_pad, 30 | x2 + self.x2_pad, 31 | y2 + self.y2_pad 32 | ] 33 | padded_bboxes.append(padded_bbox) 34 | 35 | dets.xyxy = np.array(padded_bboxes) 36 | dets.__post_init__() 37 | return dets 38 | 39 | def __repr__(self): 40 | return f"{self.__class__.__name__}(x1_pad={self.x1_pad}, y1_pad={self.y1_pad}, x2_pad={self.x2_pad}, y2_pad={self.y2_pad})" -------------------------------------------------------------------------------- /overeasy/agents/misc/split_crop.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from pydantic.dataclasses import dataclass 3 | from typing import Tuple 4 | from overeasy.types import ExecutionNode, ImageDetectionAgent, Detections, DetectionType 5 | import numpy as np 6 | 7 | class SplitCropAgent(ImageDetectionAgent): 8 | # WE HAVE PYDANTIC AT HOME 9 | def __init__(self, split: Tuple[int, int]): 10 | if len(split) != 2: 11 | raise ValueError("Split must be a tuple of two integers") 12 | if split[0] <= 0 or split[1] <= 0: 13 | raise ValueError("Split must be greater than 0") 14 | if type(split[0]) != int or type(split[1]) != int: 15 | raise ValueError("Split must be a tuple of two integers") 16 | 17 | self.rows = split[0] 18 | self.columns = split[1] 19 | 20 | 21 | def _execute(self, image: Image.Image) -> Detections: 22 | # Convert PIL Image to numpy array 23 | width, height = image.width, image.height 24 | 25 | # Calculate the size of each block 26 | block_height = height // self.rows 27 | block_width = width // self.columns 28 | 29 | left_over_height = height - block_height * self.rows 30 | left_over_width = width - block_width * self.columns 31 | 32 | # Initialize a list to hold the bounding boxes 33 | bounding_boxes = [] 34 | 35 | # Create bounding boxes 36 | for row in range(self.rows): 37 | for col in range(self.columns): 38 | x1 = col * block_width + min(col, left_over_width) 39 | y1 = row * block_height + min(row, 
left_over_height) 40 | x2 = x1 + block_width + (1 if col < left_over_width else 0) 41 | y2 = y1 + block_height + (1 if row < left_over_height else 0) 42 | bounding_boxes.append((x1, y1, x2, y2)) 43 | 44 | # Convert lists to numpy arrays 45 | bounding_boxes_np = np.array(bounding_boxes) 46 | confidence_np = np.ones(len(bounding_boxes_np)) 47 | class_ids_np = np.arange(len(bounding_boxes_np)) 48 | classes_np = np.array([f'split_{i+1}' for i in range(len(bounding_boxes_np))]) 49 | 50 | # Create Detections object 51 | det = Detections( 52 | xyxy=bounding_boxes_np, 53 | confidence=confidence_np, # Confidence for each box 54 | class_ids=class_ids_np, # Unique class ID for each box 55 | classes=classes_np, # Class names 56 | detection_type=DetectionType.BOUNDING_BOX 57 | ) 58 | 59 | return det 60 | 61 | def __repr__(self): 62 | return f"{self.__class__.__name__}(split=({self.rows}, {self.columns}))" -------------------------------------------------------------------------------- /overeasy/agents/misc/to_classification.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Callable, Any, List 2 | from overeasy.types import DataAgent, ExecutionNode, Detections 3 | 4 | class ToClassificationAgent(DataAgent): 5 | def __init__(self, fn: Union[Callable[[Any], str], Callable[[Any], List[str]]]): 6 | self.fn = fn 7 | 8 | def _execute(self, data: Any) -> Any: 9 | res = self.fn(data) 10 | if isinstance(res, list) and all(isinstance(x, str) for x in res): 11 | return Detections.from_classification(res) 12 | elif isinstance(res, str): 13 | return Detections.from_classification([res]) 14 | else: 15 | raise ValueError(f"{self.__class__.__name__} must return a string or list of strings") 16 | 17 | def __repr__(self): 18 | return f"{self.__class__.__name__}(fn={self.fn})" 19 | -------------------------------------------------------------------------------- /overeasy/agents/model_agents.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from overeasy.models import * 4 | from overeasy.types import * 5 | from pydantic import BaseModel 6 | from typing import List, Union, Optional, Dict, Any 7 | import instructor 8 | import base64, io 9 | import google.generativeai as genai 10 | 11 | __all__ = [ 12 | "BoundingBoxSelectAgent", 13 | "VisionPromptAgent", 14 | "DenseCaptioningAgent", 15 | "TextPromptAgent", 16 | "BinaryChoiceAgent", 17 | "ClassificationAgent", 18 | "OCRAgent", 19 | "InstructorImageAgent", 20 | "InstructorTextAgent" 21 | ] 22 | 23 | class BoundingBoxSelectAgent(ImageDetectionAgent): 24 | def __init__(self, classes: List[str], model: BoundingBoxModel = GroundingDINO()): 25 | self.classes = classes 26 | self.model = model 27 | 28 | def _execute(self, image: Image.Image) -> Detections: 29 | return self.model.detect(image, self.classes) 30 | 31 | def __repr__(self): 32 | model_name = self.model.__class__.__name__ if self.model else "None" 33 | return f"{self.__class__.__name__}(classes={self.classes}, model={model_name})" 34 | 35 | class VisionPromptAgent(ImageToTextAgent): 36 | def __init__(self, query: str, model: MultimodalLLM = GPTVision()): 37 | self.query = query 38 | self.model = model 39 | 40 | def _execute(self, image: Image.Image)-> str: 41 | prompt = f"""{self.query}""" 42 | response = self.model.prompt_with_image(image, prompt) 43 | return response 44 | 45 | def __repr__(self): 46 | model_name = self.model.__class__.__name__ 47 | return 
f"{self.__class__.__name__}(query={self.query}, model={model_name})" 48 | 49 | class DenseCaptioningAgent(ImageToTextAgent): 50 | def __init__(self, model: Union[MultimodalLLM, CaptioningModel] = GPTVision()): 51 | self.model = model 52 | 53 | def _execute(self, image: Image.Image)-> str: 54 | prompt = f"""Describe the following image in detail""" 55 | if isinstance(self.model, MultimodalLLM): 56 | response = self.model.prompt_with_image(image, prompt) 57 | else: 58 | response = self.model.caption(image) 59 | 60 | return response 61 | 62 | def __repr__(self): 63 | model_name = self.model.__class__.__name__ 64 | return f"{self.__class__.__name__}(model={model_name})" 65 | 66 | class TextPromptAgent(TextAgent): 67 | def __init__(self, query: str, model: LLM = GPT()): 68 | self.query = query 69 | self.model = model 70 | 71 | def _execute(self, text: str)-> str: 72 | prompt = f"""{text}\n{self.query}""" 73 | response = self.model.prompt(prompt) 74 | return response 75 | 76 | def __repr__(self): 77 | model_name = self.model.__class__.__name__ 78 | return f"{self.__class__.__name__}(query={self.query}, model={model_name})" 79 | 80 | # Convert binary_choice into an agent class 81 | class BinaryChoiceAgent(ImageToTextAgent): 82 | def __init__(self, query: str, model: MultimodalLLM = GPTVision()): 83 | self.query = query 84 | self.model = model 85 | 86 | def _execute(self, image: Image.Image)-> str: 87 | prompt = f"""{self.query}""" 88 | response = self.model.prompt_with_image(image, prompt) 89 | truthy = "yes" in response.lower() 90 | assigned_class = "yes" if truthy else "no" 91 | return assigned_class 92 | 93 | def __repr__(self): 94 | model_name = self.model.__class__.__name__ 95 | return f"{self.__class__.__name__}(query={self.query}, model={model_name})" 96 | 97 | class ClassificationAgent(ImageDetectionAgent): 98 | def __init__(self, classes, model: ClassificationModel = CLIP()): 99 | self.classes = classes 100 | self.model = model 101 | 102 | def _execute(self, image: Image.Image)-> Detections: 103 | selected_class = self.model.classify(image, self.classes) 104 | return selected_class 105 | 106 | def __repr__(self): 107 | model_name = self.model.__class__.__name__ 108 | return f"{self.__class__.__name__}(classes={self.classes}, model={model_name})" 109 | 110 | class OCRAgent(ImageToTextAgent): 111 | def __init__(self, model: Optional[OCRModel] = None): 112 | self.model = model if model is not None else GPTVision() 113 | 114 | def _execute(self, image: Image.Image)-> str: 115 | response = self.model.parse_text(image) 116 | return response 117 | 118 | def __repr__(self): 119 | model_name = self.model.__class__.__name__ 120 | return f"{self.__class__.__name__}(model={model_name})" 121 | 122 | import warnings 123 | warnings.filterwarnings("ignore", category=DeprecationWarning) 124 | 125 | options = Union[GPTVision, Gemini, Claude] 126 | 127 | 128 | 129 | class InstructorImageAgent(ImageToDataAgent): 130 | 131 | def __init__(self, response_model: type[BaseModel], model: Union[GPTVision, Gemini, Claude] = GPTVision(), max_tokens: int = 4096, extra_context: Optional[List[Dict[str, str]]] = None): 132 | self.response_model = response_model 133 | self.model = model 134 | self.extra_context = extra_context if extra_context is not None else [] 135 | self.max_tokens = max_tokens 136 | if not isinstance(self.model, GPTVision) and not isinstance(self.model, Gemini) and not isinstance(self.model, Claude): 137 | raise ValueError("Model must be a GPTVision, Gemini, or Claude") 138 | 139 | def _execute(self, 
image: Image.Image)-> Any: 140 | model_name = "" 141 | model_client = self.model.client 142 | if model_client is None: 143 | raise ValueError("No client found. Please call load_resources() on model") 144 | 145 | if isinstance(self.model, GPTVision): 146 | client = instructor.from_openai(model_client) 147 | model_name = self.model.model 148 | elif isinstance(self.model, Gemini): 149 | client = instructor.from_gemini(model_client) 150 | model_name = self.model.model_name 151 | elif isinstance(self.model, Claude): 152 | client = instructor.from_anthropic(model_client) 153 | model_name = self.model.model 154 | 155 | 156 | if isinstance(self.model, Gemini): 157 | return client.chat.completions.create( 158 | response_model=self.response_model, 159 | messages=[ 160 | *self.extra_context, 161 | { 162 | "role": "user", 163 | "content": image 164 | } 165 | ], 166 | max_tokens=self.max_tokens, 167 | ) 168 | elif isinstance(self.model, GPTVision): 169 | buffered = io.BytesIO() 170 | image.save(buffered, format="PNG") 171 | base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8') 172 | 173 | messages = [*self.extra_context, {"role": "user", "content": [ 174 | { 175 | "type": "image_url", 176 | "image_url": { 177 | "url": f"data:image/png;base64,{base64_image}" 178 | } 179 | } 180 | ]}] 181 | return client.chat.completions.create( 182 | model=model_name, 183 | response_model=self.response_model, 184 | messages=messages, 185 | max_tokens=self.max_tokens, 186 | ) 187 | else: 188 | buffered = io.BytesIO() 189 | image.save(buffered, format="PNG") 190 | base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8') 191 | 192 | messages = [*self.extra_context, {"role": "user", "content": [ 193 | { 194 | "type": "image", 195 | "source": { 196 | "type": "base64", 197 | "media_type": "image/png", 198 | "data": base64_image 199 | } 200 | } 201 | ]}] 202 | 203 | return client.chat.completions.create( 204 | model=model_name, 205 | response_model=self.response_model, 206 | messages=messages, 207 | max_tokens=self.max_tokens, 208 | ) 209 | 210 | def __repr__(self): 211 | model_name = self.model 212 | return f"{self.__class__.__name__}(response_model={self.response_model}, model={repr(model_name)})" 213 | 214 | 215 | class InstructorTextAgent(TextAgent): 216 | def __init__(self, response_model: type[BaseModel], model: Union[GPT, Gemini, Claude] = GPT(), max_tokens: int = 4096, extra_context: Optional[List[Dict[str, str]]] = None): 217 | self.response_model = response_model 218 | self.model = model 219 | self.extra_context = extra_context if extra_context is not None else [] 220 | self.max_tokens = max_tokens 221 | if not isinstance(self.model, GPT) and not isinstance(self.model, Gemini) and not isinstance(self.model, Claude): 222 | raise ValueError("Model must be a GPT, Gemini, or Claude") 223 | 224 | def _execute(self, text: str)-> Any: 225 | model_client = self.model.client 226 | if model_client is None: 227 | raise ValueError("No client found. 
Please call load_resources() on model") 228 | 229 | if isinstance(self.model, GPT): 230 | client = instructor.from_openai(model_client) 231 | model_name = self.model.model 232 | elif isinstance(self.model, Gemini): 233 | client = instructor.from_gemini(model_client) 234 | model_name = self.model.model_name 235 | elif isinstance(self.model, Claude): 236 | client = instructor.from_anthropic(model_client) 237 | model_name = self.model.model 238 | 239 | if isinstance(self.model, Gemini): 240 | structured_response = client.chat.completions.create( 241 | response_model=self.response_model, 242 | messages=[*self.extra_context, {"role": "user", "parts": [ 243 | text 244 | ]}], 245 | max_tokens=self.max_tokens, 246 | ) 247 | else: 248 | structured_response = client.chat.completions.create( 249 | model=model_name, 250 | messages=[*self.extra_context, {"role": "user", "content": text}], 251 | response_model=self.response_model, 252 | max_tokens=self.max_tokens, 253 | ) 254 | 255 | 256 | return structured_response 257 | 258 | def __repr__(self): 259 | model_name = self.model 260 | return f"{self.__class__.__name__}(response_model={self.response_model}, model={repr(model_name)})" 261 | -------------------------------------------------------------------------------- /overeasy/agents/scrape.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from urllib.parse import urlparse, urljoin 3 | import requests 4 | 5 | import os 6 | 7 | 8 | def _is_valid_xml_char(c): 9 | codepoint = ord(c) 10 | return ( 11 | 0x20 <= codepoint <= 0xD7FF or 12 | codepoint in (0x9, 0xA, 0xD) or 13 | 0xE000 <= codepoint <= 0xFFFD or 14 | 0x10000 <= codepoint <= 0x10FFFF 15 | ) 16 | 17 | def _sanitize_string(s): 18 | return ''.join(c for c in s if _is_valid_xml_char(c)) 19 | 20 | def _get_page_source(url): 21 | from playwright.sync_api import sync_playwright 22 | with sync_playwright() as p: 23 | browser = p.chromium.launch(headless=True) 24 | page = browser.new_page() 25 | page.goto(url) 26 | page.wait_for_load_state('domcontentloaded') 27 | content = page.content() 28 | browser.close() 29 | return content 30 | 31 | def _download_resource(url, base_url): 32 | full_url = urljoin(base_url, url) if not url.startswith(('http://', 'https://')) else url 33 | response = requests.get(full_url) 34 | response.raise_for_status() 35 | return response.content 36 | 37 | def _embed_stylesheets(tree, base_url): 38 | from lxml import html 39 | 40 | for link in tree.xpath('//link[@rel="stylesheet"]'): 41 | href = link.get('href') 42 | if href: 43 | css_content = _download_resource(href, base_url).decode('utf-8') 44 | style = html.Element('style') 45 | style.text = css_content 46 | link.getparent().replace(link, style) 47 | 48 | def _embed_scripts(tree, base_url): 49 | from lxml import html 50 | 51 | for script in tree.xpath('//script[@src]'): 52 | src = script.get('src') 53 | if src: 54 | js_content = _download_resource(src, base_url) 55 | script_tag = html.Element('script') 56 | script_tag.set('type', 'text/javascript') 57 | script_tag.text = _sanitize_string(js_content.decode('utf-8', errors='ignore')) 58 | script.getparent().replace(script, script_tag) 59 | 60 | for link in tree.xpath('//link[@rel="modulepreload"]'): 61 | href = link.get('href') 62 | if href: 63 | js_content = _download_resource(href, base_url) 64 | script_tag = html.Element('script') 65 | script_tag.set('type', 'module') 66 | script_tag.text = _sanitize_string(js_content.decode('utf-8', errors='ignore')) 67 | 
link.getparent().replace(link, script_tag) 68 | 69 | def _embed_images(tree, base_url): 70 | for img in tree.xpath('//img'): 71 | src = img.get('src') 72 | if src and not src.startswith('data:'): 73 | img_content = _download_resource(src, base_url) 74 | img_base64 = base64.b64encode(img_content).decode('utf-8') 75 | img.set('src', f"data:image/{src.split('.')[-1]};base64,{img_base64}") 76 | 77 | def _remove_footers(tree): 78 | for footer in tree.xpath('//footer'): 79 | footer.getparent().remove(footer) 80 | 81 | def _process_webpage(url): 82 | from lxml import html 83 | 84 | parsed_url = urlparse(url) 85 | base_url = f"{parsed_url.scheme}://{parsed_url.netloc}" 86 | 87 | html_content = _get_page_source(url) 88 | tree = html.fromstring(html_content) 89 | 90 | _embed_stylesheets(tree, base_url) 91 | _embed_scripts(tree, base_url) 92 | _embed_images(tree, base_url) 93 | _remove_footers(tree) 94 | 95 | for button in tree.xpath('//button[@aria-label="Download"]'): 96 | button.getparent().remove(button) 97 | 98 | dark_mode_path = os.path.join(os.path.dirname(__file__), 'dark_mode.txt') 99 | with open(dark_mode_path, 'r') as f: 100 | dark_mode_content = f.read() 101 | 102 | html_root = tree.getroottree().getroot() 103 | dark_mode_element = html.fromstring(dark_mode_content) 104 | html_root.append(dark_mode_element) 105 | 106 | return html.tostring(tree, pretty_print=True, encoding='utf-8').decode('utf-8') 107 | 108 | def scrape_and_inline_to_buffer(url): 109 | """ 110 | Scrape a webpage, inline all resources, remove footers, and return the result as a string. 111 | 112 | Args: 113 | url (str): The URL of the webpage to scrape. 114 | 115 | Returns: 116 | str: The processed HTML content as a string. 117 | """ 118 | return _process_webpage(url) 119 | 120 | def scrape_and_inline_to_file(url, output_file='gradio_visualization.html'): 121 | """ 122 | Scrape a webpage, inline all resources, remove footers, and save the result to a file. 123 | 124 | Args: 125 | url (str): The URL of the webpage to scrape. 126 | output_file (str): The name of the file to save the result. Defaults to 'gradio_visualization.html'. 
127 | 128 | Returns: 129 | None 130 | """ 131 | processed_html = _process_webpage(url) 132 | with open(output_file, 'w', encoding='utf-8') as file: 133 | file.write(processed_html) 134 | 135 | -------------------------------------------------------------------------------- /overeasy/agents/split_join_agent.py: -------------------------------------------------------------------------------- 1 | from overeasy.types import Agent, ExecutionNode, Detections, DetectionType, ExecutionGraph, NullExecutionNode 2 | from overeasy.logging import log_time 3 | from typing import List, Tuple, Union 4 | import numpy as np 5 | 6 | Node = Union[ExecutionNode, NullExecutionNode] 7 | 8 | class SplitAgent(Agent): 9 | @log_time 10 | def execute(self, node: ExecutionNode) -> List[Node]: 11 | result : List[Node] = [] 12 | if not isinstance(node.data, Detections): 13 | raise ValueError(f"ExecutionNode data must be of type Detections, got {type(node.data)}") 14 | detections: Detections = node.data 15 | if detections.detection_type != DetectionType.BOUNDING_BOX: 16 | raise ValueError("Detection type must be BOUNDING_BOX") 17 | 18 | for split_detection in detections.split(): 19 | parent_xyxy = split_detection.xyxy[0] 20 | child_image = node.image.crop(parent_xyxy) 21 | child_detection = Detections( 22 | xyxy=np.array([[0, 0, child_image.width, child_image.height]]), 23 | class_ids=split_detection.class_ids, 24 | confidence=split_detection.confidence, 25 | classes=split_detection.classes, 26 | data=split_detection.data, 27 | detection_type=DetectionType.CLASSIFICATION 28 | ) 29 | result.append(ExecutionNode(child_image, child_detection, parent_detection=split_detection)) 30 | 31 | if len(result) == 0: 32 | result.append(NullExecutionNode(parent_detection=detections)) 33 | 34 | return result 35 | 36 | def __repr__(self): 37 | return f"{self.__class__.__name__}()" 38 | 39 | def combine_detections(dets: List[Detections], parent_dets: List[Detections]) -> Detections: 40 | if not all(isinstance(x, Detections) for x in dets) or not all(isinstance(x, Detections) for x in parent_dets): 41 | raise ValueError("Cannot combine detections") 42 | 43 | det_types = [x.detection_type for x in dets] 44 | if not all(x == det_types[0] for x in det_types): 45 | raise ValueError("Cannot combine detections of different types") 46 | 47 | det_type = det_types[0] 48 | parent_detection_type = parent_dets[0].detection_type 49 | 50 | if det_type == DetectionType.CLASSIFICATION: 51 | if any(len(det.class_ids) > 1 for det in dets): 52 | raise ValueError("Detections with multiple classes are not supported for combination") 53 | 54 | uniq_classes = set() 55 | for det in dets: 56 | for cls in det.classes: 57 | uniq_classes.add(cls) 58 | classes = np.array(list(uniq_classes)) 59 | class_id_map = {cls: idx for idx, cls in enumerate(classes)} 60 | xyxy = [] 61 | class_ids = [] 62 | confidence = [] 63 | 64 | 65 | for det, parent_det in zip(dets, parent_dets): 66 | if len(parent_det.class_ids) == 0: 67 | continue 68 | confidence.append(det.confidence[0] if det.confidence is not None else None) 69 | cls = det.class_names[0] 70 | class_ids.append(class_id_map[cls]) 71 | xyxy.append(parent_det.xyxy[0]) 72 | 73 | return Detections( 74 | xyxy=np.array(xyxy) if len(xyxy) > 0 else np.zeros((0, 4)), 75 | class_ids=np.array(class_ids), 76 | confidence=np.array(confidence), 77 | classes=classes, 78 | detection_type=parent_detection_type 79 | ) 80 | 81 | elif det_type == DetectionType.BOUNDING_BOX: 82 | # TODO: Implement this 83 | pass 84 | elif det_type == 
DetectionType.SEGMENTATION: 85 | # TODO: Implement this 86 | pass 87 | 88 | 89 | return Detections.empty() 90 | 91 | 92 | class JoinAgent(Agent): 93 | # Mutates the input graph by adding a new layer of child nodes 94 | @log_time 95 | def join(self, graph: ExecutionGraph, target_split: int) -> List[Node]: 96 | topsort = graph.top_sort() 97 | parents_all = topsort[target_split] 98 | 99 | if len(parents_all) == 1 and isinstance(parents_all[0], NullExecutionNode): 100 | if len(topsort[-1]) == 1 and isinstance(topsort[-1][0], NullExecutionNode): 101 | null_child = NullExecutionNode() 102 | graph.add_child(topsort[-1][0], null_child) 103 | return [null_child] 104 | else: 105 | raise ValueError("Graph is formatted incorrectly") 106 | 107 | leaves: List[Node] = [] 108 | def merge_nodes(node_and_parent_det: List[Tuple[Node, Detections]], parent: ExecutionNode): 109 | if len(node_and_parent_det) == 0: 110 | return 111 | nodes = [x[0] for x in node_and_parent_det] 112 | parent_dets = [x[1] for x in node_and_parent_det] 113 | 114 | original_data = [node.data if isinstance(node, ExecutionNode) else None for node in nodes] 115 | merged_node = ExecutionNode(parent.image, original_data) 116 | 117 | if all(isinstance(x, Detections) for x in original_data): 118 | merged_node.data = combine_detections(original_data, parent_dets) #type: ignore 119 | 120 | for node in nodes: 121 | graph.add_child(node, merged_node) 122 | leaves.append(merged_node) 123 | 124 | # Filter out non execution nodes 125 | parents = [p for p in parents_all if isinstance(p, ExecutionNode)] 126 | split_children = [graph.children(p) for p in parents] 127 | to_merge = topsort[-1] 128 | 129 | ind = 0 130 | for (parent_list, parent_node) in zip(split_children, parents): 131 | nodes: List[Tuple[Node, Detections]] = [] 132 | for child in parent_list: 133 | if child.parent_detection is None: 134 | raise ValueError("Parent detection is None") 135 | node = to_merge[ind] 136 | 137 | nodes.append((node, child.parent_detection)) 138 | ind+=1 139 | if isinstance(parent_node, NullExecutionNode): 140 | raise ValueError("Parent node is NullExecutionNode") 141 | merge_nodes(nodes, parent_node) 142 | 143 | 144 | 145 | 146 | return leaves 147 | 148 | def __repr__(self): 149 | return f"{self.__class__.__name__}()" -------------------------------------------------------------------------------- /overeasy/agents/workflow.py: -------------------------------------------------------------------------------- 1 | from overeasy.types import * 2 | from overeasy.agents import SplitAgent, JoinAgent 3 | from overeasy.visualize_utils import annotate 4 | from typing import List, Tuple 5 | from PIL import Image 6 | import numpy as np 7 | import gradio as gr 8 | from math import fabs, sqrt 9 | from tqdm import tqdm 10 | from dataclasses import field, dataclass 11 | from overeasy.dirs import FAVICON_PATH 12 | from typing import Optional, Any, Dict, Union 13 | from io import StringIO 14 | from overeasy.agents.scrape import scrape_and_inline_to_buffer 15 | import time 16 | import threading 17 | import sys 18 | 19 | def _visualize_layer(layer: List[Node]) -> List[Tuple[Optional[Image.Image], str]]: 20 | images: List[Tuple[Optional[Image.Image], str]] = [] 21 | 22 | if all(isinstance(node, ExecutionNode) and isinstance(node.data, Detections) for node in layer): 23 | for node in layer: 24 | assert isinstance(node, ExecutionNode) 25 | detection = node.data 26 | if detection.detection_type == DetectionType.BOUNDING_BOX: 27 | images.append((annotate(node.image, detection), 
"Bounding Box")) 28 | elif detection.detection_type == DetectionType.SEGMENTATION: 29 | images.append((annotate(node.image, detection), "Segmentation")) 30 | elif detection.detection_type == DetectionType.CLASSIFICATION: 31 | labels = [f"{name} {score:.2f}" if score is not None else f"{name}" for name, score in zip(detection.class_names, detection.confidence_scores)] 32 | stringified = labels[0] if len(labels)==1 else str(labels) 33 | images.append((node.image, stringified)) 34 | else: 35 | images = [(x.image, str(x.data)) if isinstance(x, ExecutionNode) else (None, "None") for x in layer] 36 | return images 37 | 38 | # Can ignore JoinAgents as they are handled earlier in the flow 39 | def handle_node(node: ExecutionNode, agent: Agent) -> List[ExecutionNode]: 40 | if isinstance(agent, SplitAgent): 41 | return agent.execute(node) 42 | # Superclass so we dont have to check for children 43 | elif isinstance(agent, ImageToDataAgent) : 44 | return [ExecutionNode(node.image, agent.execute(node.image))] 45 | elif isinstance(agent, DataAgent): 46 | return [ExecutionNode(node.image, agent.execute(node.data))] 47 | else: 48 | raise ValueError(f"Agent {agent} is not a valid agent type") 49 | @dataclass(frozen=True) 50 | class Workflow: 51 | steps: List[Agent] 52 | join_to_split_map: Dict[int, int] = field(default_factory=dict, init=False) 53 | 54 | def __post_init__(self): 55 | splits = [] 56 | for i, agent in enumerate(self.steps): 57 | if isinstance(agent, SplitAgent): 58 | splits.append(i) 59 | elif isinstance(agent, JoinAgent): 60 | if len(splits) == 0: 61 | raise ValueError(f"JoinAgent at index {i} has no matching SplitAgent") 62 | self.join_to_split_map[i] = splits.pop() 63 | 64 | 65 | def __repr__(self): 66 | return f"Workflow(steps={self.steps})" 67 | 68 | # Return leaves of the graph and the graph itself 69 | def execute(self, input_image: Image.Image, data: Optional[Any] = None) -> Tuple[List[ExecutionNode], ExecutionGraph]: 70 | 71 | if input_image is None: 72 | raise ValueError("Input image is None") 73 | elif isinstance(input_image, np.ndarray): 74 | raise ValueError("Input image is a numpy array, please convert it to a PIL.Image first using Image.fromarray(), make sure to convert your color channels to RGB!") 75 | elif not isinstance(input_image, Image.Image): 76 | raise ValueError("Input image is not a valid image format, must be a PIL.Image") 77 | 78 | input_image = input_image.convert("RGB") 79 | 80 | root = ExecutionNode(input_image, data) 81 | graph = ExecutionGraph(root) 82 | 83 | intermediate_results : List[Node] = [root] 84 | for ind, agent in enumerate(self.steps): 85 | if hasattr(agent, 'model') and isinstance(agent.model, Model): 86 | agent.model.load_resources() 87 | try: 88 | if isinstance(agent, JoinAgent): 89 | target_split = self.join_to_split_map[ind] 90 | intermediate_results = agent.join(graph, target_split) 91 | continue 92 | 93 | next_results: List[Node] = [] 94 | for node in intermediate_results: 95 | if isinstance(node, NullExecutionNode): 96 | null_node = NullExecutionNode() 97 | graph.add_child(node, null_node) 98 | next_results.append(null_node) 99 | continue 100 | # Now node must be a ExecutionNode 101 | assert isinstance(node, ExecutionNode) 102 | res = handle_node(node, agent) 103 | for child in res: 104 | graph.add_child(node, child) 105 | next_results.extend(res) 106 | 107 | intermediate_results = next_results 108 | finally: 109 | if hasattr(agent, 'model') and isinstance(agent.model, Model): 110 | agent.model.release_resources() 111 | 112 | return [node 
for node in intermediate_results if isinstance(node, ExecutionNode)], graph 113 | 114 | 115 | def execute_multiple(self, input_images: List[Image.Image]) -> Tuple[List[List[ExecutionNode]], List[ExecutionGraph]]: 116 | if not all(isinstance(img, Image.Image) for img in input_images): 117 | raise ValueError("All input images must be of type PIL.Image") 118 | # Normalize image format 119 | input_images = [img.convert("RGB") for img in input_images] 120 | 121 | all_graphs = [ExecutionGraph(ExecutionNode(img, None)) for img in input_images] 122 | intermediate_results: List[List[Node]] = [[graph.root] for graph in all_graphs] 123 | 124 | for ind, agent in enumerate(tqdm(self.steps, desc="Processing steps")): 125 | try: 126 | if hasattr(agent, 'model') and isinstance(agent.model, Model): 127 | agent.model.load_resources() 128 | 129 | if isinstance(agent, JoinAgent): 130 | target_split = self.join_to_split_map[ind] 131 | for i, graph in enumerate(all_graphs): 132 | intermediate_results[i] = agent.join(graph, target_split) 133 | else: 134 | for i in range(len(all_graphs)): 135 | next_results: List[Node] = [] 136 | for node in intermediate_results[i]: 137 | if isinstance(node, NullExecutionNode): 138 | null_node = NullExecutionNode() 139 | all_graphs[i].add_child(node, null_node) 140 | next_results.append(null_node) 141 | elif isinstance(node, ExecutionNode): 142 | res = handle_node(node, agent) 143 | for child in res: 144 | all_graphs[i].add_child(node, child) 145 | next_results.extend(res) 146 | intermediate_results[i] = next_results 147 | 148 | finally: 149 | if hasattr(agent, 'model') and isinstance(agent.model, Model): 150 | agent.model.release_resources() 151 | 152 | all_leaves = [ 153 | [node for node in results if isinstance(node, ExecutionNode)] 154 | for results in intermediate_results 155 | ] 156 | 157 | return all_leaves, all_graphs 158 | 159 | def to_steps(self, graph: ExecutionGraph) -> List[Tuple[str, List[Tuple[Optional[Image.Image], str]], str]]: 160 | workflow_steps_names = ["Input Image"] 161 | code_snippets = [""] 162 | for i in range(len(self.steps)): 163 | workflow_steps_names.append(self.steps[i].__class__.__name__) 164 | code_snippets.append(repr(self.steps[i])) 165 | 166 | layers = graph.top_sort() 167 | steps = [] 168 | for i, layer in enumerate(layers): 169 | code = f"```python\n{code_snippets[i]}\n```" 170 | layer_images = _visualize_layer(layer) 171 | steps.append((f"Step {i+1}: {workflow_steps_names[i]}", layer_images, code)) 172 | 173 | return steps 174 | 175 | def visualize(self, graph: ExecutionGraph, share: bool = False, prevent_thread_lock: bool = False): 176 | steps = self.to_steps(graph) 177 | css = """ 178 | .gradio-container .single-image { 179 | display: flex; 180 | justify-content: center; 181 | align-items: center; 182 | } 183 | """ 184 | with gr.Blocks(css=css) as demo: 185 | for step_title, images, code in steps: 186 | with gr.Group(): 187 | gr.Markdown(f"### {step_title}") 188 | gr.Markdown(f"{code}") 189 | filtered_images = [(img, label) for img, label in images if img is not None] 190 | if len(filtered_images) > 1: 191 | gr.Gallery( 192 | value=filtered_images, 193 | height="max-content", 194 | object_fit="contain", 195 | show_label=False, 196 | columns=max(3, min(9, round(sqrt(len(images))))), 197 | ) 198 | elif len(filtered_images) == 1: 199 | [image] = filtered_images 200 | with gr.Row(): 201 | with gr.Column(scale=1): 202 | pass 203 | with gr.Column(scale=4): 204 | gr.Image( 205 | width=640, 206 | value=image[0], 207 | label=image[1], 208 | ) 209 | 
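# Empty side columns (1:4:1 scale) keep the single image centered in the row.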
with gr.Column(scale=1): 210 | pass 211 | elif len(filtered_images) == 0: 212 | with gr.Row(): 213 | with gr.Column(scale=1): 214 | pass 215 | with gr.Column(scale=4): 216 | gr.Text("No image to display: skipped because the previous split was empty") 217 | with gr.Column(scale=1): 218 | pass 219 | 220 | 221 | demo.launch(favicon_path=FAVICON_PATH, share=share, prevent_thread_lock=prevent_thread_lock) 222 | 223 | def visualize_to_html_string(self, graph: ExecutionGraph): 224 | output_buffer = StringIO() 225 | original_stdout = sys.stdout 226 | 227 | def run_gradio(): 228 | try: 229 | sys.stdout = output_buffer 230 | self.visualize(graph, prevent_thread_lock=True) 231 | finally: 232 | sys.stdout = original_stdout 233 | 234 | gradio_thread = threading.Thread(target=run_gradio) 235 | gradio_thread.daemon = True 236 | gradio_thread.start() 237 | 238 | # Monitor the output buffer for the "Running on" string 239 | start_time = time.time() 240 | url = None 241 | while time.time() - start_time < 10: # Wait for up to 10 seconds 242 | output_buffer.seek(0) 243 | for line in output_buffer: 244 | if "Running on local URL:" in line: 245 | url = line.split("Running on local URL:")[1].strip() 246 | break 247 | if url: 248 | break 249 | time.sleep(0.1) # Short sleep to prevent busy-waiting 250 | 251 | sys.stdout = original_stdout 252 | if url: 253 | gradio_thread.join() 254 | buffer = scrape_and_inline_to_buffer(url) 255 | return buffer 256 | else: 257 | gradio_thread.join() 258 | raise RuntimeError("Failed to determine Gradio server URL") 259 | 260 | 261 | def visualize_to_file(self, graph: ExecutionGraph, file_path: str): 262 | html_string = self.visualize_to_html_string(graph) 263 | with open(file_path, "w") as f: 264 | f.write(html_string)
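# --- Hedged usage sketch (illustrative only, not part of this module) ---
# Minimal end-to-end use of Workflow. The detection/classification agents named
# below are hypothetical stand-ins; SplitAgent and JoinAgent are the real classes.
#
#   from PIL import Image
#   workflow = Workflow(steps=[
#       MyDetectorAgent(classes=["person"]),    # hypothetical: emits BOUNDING_BOX Detections
#       SplitAgent(),                           # one cropped child node per box
#       MyClassifierAgent(["helmet", "none"]),  # hypothetical: classifies each crop
#       JoinAgent(),                            # merges results back onto the parent image
#   ])
#   leaves, graph = workflow.execute(Image.open("input.jpg"))
#   workflow.visualize_to_file(graph, "workflow.html")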
-------------------------------------------------------------------------------- /overeasy/assets/egg_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/assets/egg_logo.png -------------------------------------------------------------------------------- /overeasy/assets/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/assets/favicon.ico -------------------------------------------------------------------------------- /overeasy/dirs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | ROOT = os.path.expanduser("~/.overeasy") 4 | FAVICON_PATH = os.path.join(os.path.dirname(__file__), "./assets/favicon.ico") -------------------------------------------------------------------------------- /overeasy/download_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import urllib.request 4 | import progressbar 5 | 6 | class URLProgressBar(): 7 | def __init__(self, filename: str): 8 | print(f"\nDownloading {filename} ...") 9 | self.pbar = None 10 | 11 | def __call__(self, block_num, block_size, total_size): 12 | if not self.pbar: 13 | self.pbar = progressbar.ProgressBar(maxval=total_size) 14 | self.pbar.start() 15 | 16 | downloaded = block_num * block_size 17 | if downloaded < total_size: 18 | self.pbar.update(downloaded) 19 | else: 20 | self.pbar.finish() 21 | 22 | # It's important to download files this way to prevent files from being corrupted 23 | # if an error occurs while downloading. 24 | def atomic_retrieve_and_rename(url: str, destination_path: str): 25 | tmp_file_path = None 26 | try: 27 | # Create the temp file next to the destination so os.rename stays on one filesystem (and therefore atomic) 28 | with tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination_path) or ".") as tmp_file: 29 | tmp_file_path = tmp_file.name 30 | urllib.request.urlretrieve(url, tmp_file_path, reporthook=URLProgressBar(os.path.basename(destination_path))) 31 | os.rename(tmp_file_path, destination_path) 32 | except Exception as e: 33 | if tmp_file_path and os.path.exists(tmp_file_path): 34 | os.remove(tmp_file_path) 35 | raise e -------------------------------------------------------------------------------- /overeasy/logging.py: -------------------------------------------------------------------------------- 1 | import time 2 | from tabulate import tabulate 3 | from functools import wraps 4 | from collections import defaultdict 5 | from typing import Callable, Any, Dict 6 | 7 | function_stats: Dict[str, Dict[str, float]] = defaultdict(lambda: {"count": 0, "total_time": 0.0}) 8 | 9 | def log_time(func: Callable[..., Any]) -> Callable[..., Any]: 10 | @wraps(func) 11 | def wrapper(*args: Any, **kwargs: Any) -> Any: 12 | start_time = time.perf_counter() 13 | result = func(*args, **kwargs) 14 | end_time = time.perf_counter() 15 | 16 | if args and hasattr(args[0], '__class__'): 17 | name = f"{args[0].__class__.__name__}.{func.__name__}" 18 | else: 19 | name = func.__qualname__ 20 | 21 | function_stats[name]["count"] += 1 22 | function_stats[name]["total_time"] += end_time - start_time 23 | 24 | return result 25 | return wrapper 26 | 27 | 28 | def print_summary(): 29 | total_run_time = sum(stats["total_time"] for stats in function_stats.values()) 30 | table = [] 31 | 32 | for func_name, stats in function_stats.items(): 33 | average_time = stats["total_time"] / stats["count"] 34 | 35 | if average_time < 0.1: 36 | average_time_str = f"{average_time*1000:.2f}ms" 37 | elif average_time < 60: 38 | average_time_str = f"{average_time:.2f}s" 39 | else: 40 | minutes, seconds = divmod(average_time, 60) 41 | average_time_str = f"{int(minutes)}m {seconds:.2f}s" 42 | 43 | proportion = (stats["total_time"] / total_run_time) * 100 44 | table.append([func_name, stats['count'], average_time_str, f"{proportion:.2f}%"]) 45 | 46 | headers = ["Function Name", "Calls", "Average Time", "Proportion of Total Runtime"] 47 | print(tabulate(table, headers=headers, tablefmt="grid")) -------------------------------------------------------------------------------- /overeasy/models/LLMs/__init__.py: -------------------------------------------------------------------------------- 1 | from .openai import GPT, GPTVision 2 | from .qwenvl import QwenVL 3 | from .gemini import Gemini 4 | from .anthropic import Claude 5 | from .pali_gemma import PaliGemma -------------------------------------------------------------------------------- /overeasy/models/LLMs/anthropic.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from overeasy.types import MultimodalLLM, OCRModel, Model 3 | from typing import Optional 4 | import anthropic 5 | import base64 6 | import io 7 | import os 8 | 9 | class Claude(MultimodalLLM, OCRModel): 10 | models = ["claude-3-5-sonnet-20240620", "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229"] 11 | def __init__(self, model: str = 'claude-3-5-sonnet-20240620', api_key: Optional[str] = None): 12 | self.api_key = api_key if api_key is not None else os.getenv("ANTHROPIC_API_KEY") 13 | # 
Support for shortened model names 14 | self.model = model 15 | self.client = None 16 | 17 | def load_resources(self): 18 | if self.api_key is None: 19 | raise ValueError("No API key found. Please provide an API key, or set the ANTHROPIC_API_KEY environment variable.") 20 | self.client = anthropic.Anthropic(api_key=self.api_key) 21 | 22 | super().load_resources() 23 | 24 | def release_resources(self): 25 | self.client = None 26 | super().release_resources() 27 | 28 | def prompt_with_image(self, image: Image.Image, query: str) -> str: 29 | if self.client is None: 30 | raise ValueError("Anthropic client not loaded. Please call load_resources() first.") 31 | buffered = io.BytesIO() 32 | image.save(buffered, format="JPEG") 33 | image_data = base64.b64encode(buffered.getvalue()).decode("utf-8") 34 | 35 | message = self.client.messages.create( 36 | model=self.model, 37 | max_tokens=1024, 38 | messages=[ 39 | { 40 | "role": "user", 41 | "content": [ 42 | { 43 | "type": "image", 44 | "source": { 45 | "type": "base64", 46 | "media_type": "image/jpeg", 47 | "data": image_data, 48 | }, 49 | }, 50 | { 51 | "type": "text", 52 | "text": query 53 | } 54 | ], 55 | } 56 | ], 57 | ) 58 | return message.content[0].text # type: ignore 59 | 60 | def prompt(self, query: str) -> str: 61 | if self.client is None: 62 | raise ValueError("Anthropic client not loaded. Please call load_resources() first.") 63 | message = self.client.messages.create( 64 | model=self.model, 65 | max_tokens=1024, 66 | messages=[ 67 | {"role": "user", "content": query} 68 | ] 69 | ) 70 | return message.content[0].text # type: ignore 71 | 72 | def parse_text(self, image: Image.Image) -> str: 73 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.") 74 | 75 | -------------------------------------------------------------------------------- /overeasy/models/LLMs/gemini.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from overeasy.types import MultimodalLLM, OCRModel, Model 4 | from typing import Optional 5 | import google.generativeai as genai 6 | import warnings 7 | 8 | models = [ 9 | "gemini-1.5-flash", 10 | "gemini-1.5-pro", 11 | "gemini-pro-vision", 12 | ] 13 | 14 | class Gemini(MultimodalLLM, OCRModel): 15 | def __init__(self, api_key: Optional[str] = None, model: str = "gemini-1.5-pro"): 16 | if model not in models: 17 | warnings.warn(f"Model {model} not supported for Gemini. Please use one of the following: {models}") 18 | self.api_key = api_key if api_key is not None else os.getenv("GOOGLE_API_KEY") 19 | self.model_name = model 20 | self.client = None 21 | 22 | def load_resources(self): 23 | if self.api_key is None: 24 | warnings.warn("No API key found. Using Gemini free tier. For higher usage please provide an API key, or set the GOOGLE_API_KEY environment variable.") 25 | else: 26 | genai.configure(api_key=self.api_key) 27 | self.client = genai.GenerativeModel(model_name=self.model_name) 28 | 29 | 30 | def release_resources(self): 31 | self.client = None 32 | 33 | def prompt_with_image(self, image: Image.Image, query: str) -> str: 34 | if self.client is None: 35 | raise ValueError("Client is not loaded. Please call load_resources() first.") 36 | response = self.client.generate_content([query, image], stream=True) 37 | response.resolve() 38 | return response.text 39 | 40 | 41 | def prompt(self, query: str) -> str: 42 | if self.client is None: 43 | raise ValueError("Client is not loaded. 
Please call load_resources() first.") 44 | response = self.client.generate_content([query]) 45 | response.resolve() 46 | return response.text 47 | 48 | 49 | def parse_text(self, image: Image.Image) -> str: 50 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.") -------------------------------------------------------------------------------- /overeasy/models/LLMs/openai.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from overeasy.types import MultimodalLLM, LLM, OCRModel, Model 4 | from typing import Literal, Optional 5 | import io 6 | import base64 7 | import requests 8 | import warnings 9 | import backoff 10 | import openai 11 | 12 | def encode_image_to_base64(image: Image.Image) -> str: 13 | buffered = io.BytesIO() 14 | image.save(buffered, format="PNG") 15 | return base64.b64encode(buffered.getvalue()).decode('utf-8') 16 | 17 | current_models = [ 18 | "gpt-4o-mini", 19 | "gpt-4o", 20 | "gpt-4o-mini-2024-07-18", 21 | "gpt-4o-2024-05-13", 22 | "gpt-4-turbo", 23 | "gpt-4-turbo-2024-04-09", 24 | "gpt-4-turbo-preview", 25 | "gpt-4-0125-preview", 26 | "gpt-4-1106-preview", 27 | "gpt-4-vision-preview", 28 | "gpt-4-1106-vision-preview", 29 | "gpt-4", 30 | "gpt-4-0613", 31 | "gpt-4-32k", 32 | "gpt-4-32k-0613", 33 | "gpt-3.5-turbo-0125", 34 | "gpt-3.5-turbo", 35 | "gpt-3.5-turbo-1106", 36 | "gpt-3.5-turbo-instruct", 37 | "gpt-3.5-turbo-16k", 38 | "gpt-3.5-turbo-0613", 39 | "gpt-3.5-turbo-16k-0613" 40 | ] 41 | 42 | 43 | @backoff.on_exception(backoff.expo, openai.RateLimitError, max_tries=6) 44 | def _prompt(query: str, model: str, client: openai.OpenAI, temperature: float, max_tokens: int) -> str: 45 | response = client.chat.completions.create( 46 | model=model, 47 | messages=[ 48 | {"role": "user", "content": query} 49 | ], 50 | max_tokens=max_tokens, 51 | temperature=temperature 52 | ) 53 | 54 | result = response.choices[0].message.content 55 | if result is None: 56 | raise ValueError("No content found in response") 57 | 58 | return result 59 | 60 | 61 | 62 | class GPT(LLM): 63 | def __init__(self, api_key: Optional[str] = None, 64 | model: str = "gpt-3.5-turbo", 65 | temperature: float = 0.5, 66 | max_tokens: int = 1024): 67 | self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY") 68 | self.model = model 69 | self.temperature = temperature 70 | self.max_tokens = max_tokens 71 | self.client = None 72 | if self.model not in current_models: 73 | warnings.warn(f"Model {model} may not be supported. Please provide a valid model.") 74 | 75 | def prompt(self, query: str) -> str: 76 | if self.client is None: 77 | raise ValueError("Client is not loaded. 
Please call load_resources() first.") 78 | return _prompt(query, self.model, self.client, self.temperature, self.max_tokens) 79 | 80 | def load_resources(self): 81 | self.client = openai.OpenAI(api_key=self.api_key) 82 | 83 | def release_resources(self): 84 | self.client = None 85 | 86 | class GPTVision(MultimodalLLM, OCRModel, GPT): 87 | def __init__(self, api_key: Optional[str] = None, 88 | model : Literal["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-turbo-2024-04-09"] = "gpt-4o", 89 | temperature: float = 0.5, 90 | max_tokens: int = 1024 91 | ): 92 | self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY") 93 | self.model = model 94 | self.temperature = temperature 95 | self.max_tokens = max_tokens 96 | self.client = None 97 | 98 | def prompt_with_image(self, image: Image.Image, query: str) -> str: 99 | if self.client is None: 100 | raise ValueError("Client is not loaded. Please call load_resources() first.") 101 | 102 | base64_image = encode_image_to_base64(image) 103 | 104 | messages = [ 105 | { 106 | "role": "user", 107 | "content": [ 108 | { 109 | "type": "text", 110 | "text": query 111 | }, 112 | { 113 | "type": "image_url", 114 | "image_url": { 115 | "url": f"data:image/png;base64,{base64_image}" 116 | } 117 | } 118 | ] 119 | } 120 | ] 121 | 122 | 123 | response = self.client.chat.completions.create( 124 | model=self.model, 125 | messages=messages, 126 | max_tokens=self.max_tokens, 127 | temperature=self.temperature 128 | ) 129 | 130 | result = response.choices[0].message.content 131 | if result is None: 132 | raise ValueError("No content found in response") 133 | 134 | return result 135 | 136 | def parse_text(self, image: Image.Image) -> str: 137 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.") 138 | 139 | -------------------------------------------------------------------------------- /overeasy/models/LLMs/pali_gemma.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from overeasy.types import MultimodalLLM 3 | import torch 4 | from typing import Optional 5 | 6 | # Make sure to set your HuggingFace token to use PaliGemma 7 | class PaliGemma(MultimodalLLM): 8 | SIZES = [224, 448] # Image sizes the model was trained on 9 | 10 | MODEL_OPTIONS = ( 11 | [f"google/paligemma-3b-mix-{size}" for size in SIZES] + # Should use mix for most cases especially when doing object detection 12 | [f"google/paligemma-3b-ft-vqav2-{size}" for size in SIZES] + # Diagram Understanding - 85.64 Accuracy on VQAV2 13 | [f"google/paligemma-3b-ft-cococap-{size}" for size in SIZES] + # COCO Captions - 144.6 CIDEr 14 | [f"google/paligemma-3b-ft-science-qa-{size}" for size in SIZES] + # Science Question Answering - 95.93 Accuracy on ScienceQA Img subset with no CoT 15 | [f"google/paligemma-3b-ft-refcoco-seg-{size}" for size in SIZES] + # Understanding References to Specific Objects in Images - 76.94 Mean IoU on refcoco, 72.18 Mean IoU on refcoco+, 72.22 Mean IoU on refcocog 16 | [f"google/paligemma-3b-ft-rsvqa-hr-{size}" for size in SIZES] # Remote Sensing Visual Question Answering - 92.61 Accuracy on test, 90.58 Accuracy on test2 17 | ) 18 | 19 | def __init__(self, model_card: str = "google/paligemma-3b-mix-448", device: Optional[str] = None): 20 | self.model_card = model_card 21 | if self.model_card not in self.MODEL_OPTIONS: 22 | raise ValueError(f"Model {self.model_card} not found. 
Please select a valid model from {self.MODEL_OPTIONS}.") 23 | 24 | if device is None: 25 | device = "cuda" if torch.cuda.is_available() else "cpu" 26 | self.device = device 27 | self.model = None 28 | self.processor = None 29 | 30 | def load_resources(self): 31 | from transformers import AutoProcessor, PaliGemmaForConditionalGeneration 32 | self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_card).to(self.device) 33 | self.processor = AutoProcessor.from_pretrained(self.model_card) 34 | 35 | def release_resources(self): 36 | self.model = None 37 | self.processor = None 38 | 39 | def prompt_with_image(self, image: Image.Image, query: str) -> str: 40 | if self.model is None or self.processor is None: 41 | raise ValueError("Model is not loaded. Please call load_resources() first.") 42 | 43 | inputs = self.processor(query, image, return_tensors="pt").to(self.device) 44 | output = self.model.generate(**inputs, max_new_tokens=100) 45 | 46 | return self.processor.decode(output[0], skip_special_tokens=True)[len(query):] 47 | 48 | def prompt(self, query: str) -> str: 49 | if self.model is None or self.processor is None: 50 | raise ValueError("Model is not loaded. Please call load_resources() first.") 51 | 52 | inputs = self.processor(query, None, return_tensors="pt").to(self.device) 53 | output = self.model.generate(**inputs, max_new_tokens=100) 54 | 55 | return self.processor.decode(output[0], skip_special_tokens=True)[len(query):] -------------------------------------------------------------------------------- /overeasy/models/LLMs/qwenvl.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoModelForCausalLM, AutoTokenizer 2 | from PIL import Image 3 | import tempfile 4 | from overeasy.types import MultimodalLLM 5 | from typing import Literal 6 | import importlib 7 | import torch 8 | 9 | from overeasy.types.base import OCRModel 10 | 11 | # use bf16 12 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval() 13 | # use fp16 14 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval() 15 | # use cpu only 16 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cpu", trust_remote_code=True).eval() 17 | # use cuda device 18 | 19 | def setup_autogptq(): 20 | import subprocess 21 | import torch 22 | 23 | # Check if CUDA is available 24 | if not torch.cuda.is_available(): 25 | print("CUDA is not available. 
Exiting installation.") 26 | return 27 | 28 | # Get CUDA version from PyTorch 29 | cuda_version = torch.version.cuda 30 | if cuda_version: 31 | if cuda_version.startswith("11.8"): 32 | subprocess.run(["pip", "install", "optimum"], check=True) 33 | subprocess.run([ 34 | "pip", "install", "auto-gptq", "--no-build-isolation", 35 | "--extra-index-url", "https://huggingface.github.io/autogptq-index/whl/cu118/" 36 | ], check=True) 37 | elif cuda_version.startswith("12.1"): 38 | subprocess.run(["pip", "install", "optimum"], check=True) 39 | subprocess.run([ 40 | "pip", "install", "auto-gptq", "--no-build-isolation" 41 | ], check=True) 42 | else: 43 | print(f"Unsupported CUDA version: {cuda_version}") 44 | else: 45 | print("CUDA version could not be determined.") 46 | 47 | 48 | model_TYPE = Literal["base", "int4", "fp16", "bf16"] 49 | def load_model(model_type: model_TYPE): 50 | if model_type == "base": 51 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True).eval() 52 | elif model_type == "fp16": 53 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True, fp16=True).eval() 54 | elif model_type == "bf16": 55 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True, bf16=True).eval() 56 | elif model_type == "int4": 57 | def is_autogptq_installed(): 58 | package_name = 'auto_gptq' 59 | spec = importlib.util.find_spec(package_name) 60 | return spec is not None 61 | 62 | if not is_autogptq_installed(): 63 | setup_autogptq() 64 | 65 | if not is_autogptq_installed(): 66 | raise Exception("AutoGPTQ is not installed can't use int4 quantization") 67 | 68 | model = AutoModelForCausalLM.from_pretrained( 69 | "Qwen/Qwen-VL-Chat-Int4", 70 | device_map="cuda", 71 | trust_remote_code=True 72 | ).eval() 73 | else: 74 | raise Exception("Model type not supported") 75 | 76 | # model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True) 77 | 78 | return model 79 | 80 | class QwenVL(MultimodalLLM, OCRModel): 81 | 82 | def __init__(self, model_type: model_TYPE = "bf16"): 83 | if not torch.cuda.is_available(): 84 | raise Exception("CUDA not available. 
Can't use QwenVL") 85 | if model_type not in ["base", "int4", "fp16", "bf16"]: 86 | raise Exception("Model type not supported") 87 | self.model_type = model_type 88 | 89 | def load_resources(self): 90 | self.model = load_model(self.model_type) 91 | 92 | if self.model_type == "int4": 93 | self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True) 94 | else: 95 | self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True) 96 | 97 | def release_resources(self): 98 | self.model = None 99 | self.tokenizer = None 100 | 101 | def prompt_with_image(self, image : Image.Image, query: str) -> str: 102 | 103 | with tempfile.NamedTemporaryFile(suffix=".png") as temp_file: 104 | image.save(temp_file.name) 105 | query = self.tokenizer.from_list_format([ 106 | {'image': temp_file.name}, 107 | {'text': query}, 108 | ]) 109 | response, history = self.model.chat(self.tokenizer, query=query, history=None, max_new_tokens=2048) 110 | return response 111 | 112 | def prompt(self, query: str) -> str: 113 | query = self.tokenizer.from_list_format([ 114 | {'text': query}, 115 | ]) 116 | response, history = self.model.chat(self.tokenizer, query=query, history=None, max_new_tokens=2048) 117 | return response 118 | 119 | def parse_text(self, image: Image.Image): 120 | return self.prompt_with_image(image, "Read the text in this image line by line") 121 | 122 | def draw_bbox_on_latest_picture(self, response, history): 123 | if response.startswith("") and response.endswith(""): 124 | response = response[5:-6] 125 | if "" in response and "" in response: 126 | box = response[response.index("") + 5:response.index("")] 127 | box = box.split(",") 128 | box = [int(x) for x in box] 129 | image = Image.open(history[-1]["image"]) 130 | return image.crop(box) 131 | return None -------------------------------------------------------------------------------- /overeasy/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Lift imports from subdirs 2 | from .detection import * 3 | from .recognition import * 4 | from .LLMs import * 5 | from .classification import * 6 | 7 | 8 | def warmup_models(): 9 | try: 10 | qwen = QwenVL() 11 | qwen.load_resources() 12 | del qwen 13 | except: 14 | print("Skipping QwenVL") 15 | 16 | try: 17 | detic = DETIC() 18 | detic.load_resources() 19 | detic.set_classes(["hi"]) 20 | del detic 21 | except: 22 | print("Skipping DETIC") 23 | 24 | bounding_box_models = [ 25 | GroundingDINO(type=GroundingDINOModel.Pretrain_1_8M), 26 | GroundingDINO(type=GroundingDINOModel.SwinB), 27 | GroundingDINO(type=GroundingDINOModel.SwinT), 28 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinB), 29 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinL), 30 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinB_zero), 31 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinL_zero), 32 | YOLOWorld(model="yolov8s-worldv2"), 33 | YOLOWorld(model="yolov8m-worldv2"), 34 | YOLOWorld(model="yolov8l-worldv2"), 35 | YOLOWorld(model="yolov8s-world"), 36 | YOLOWorld(model="yolov8m-world"), 37 | YOLOWorld(model="yolov8l-world"), 38 | OwlV2(), 39 | ] 40 | 41 | for bounding_box_model in bounding_box_models: 42 | bounding_box_model.load_resources() 43 | del bounding_box_model 44 | 45 | 46 | 47 | clip = CLIP() 48 | del clip 49 | laionclip = LaionCLIP() 50 | del laionclip 51 | bio = BiomedCLIP() 52 | del bio 53 | 54 | pali_gemma = PaliGemma() 55 | pali_gemma.load_resources() 56 | del pali_gemma 
-------------------------------------------------------------------------------- /overeasy/models/classification/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import CLIP 2 | from .laion_clip import LaionCLIP, BiomedCLIP, OpenCLIPBase 3 | -------------------------------------------------------------------------------- /overeasy/models/classification/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | from transformers import AutoProcessor, AutoModelForZeroShotImageClassification 3 | from PIL import Image 4 | import numpy as np 5 | from overeasy.types import Detections, DetectionType, ClassificationModel 6 | import open_clip 7 | import torch 8 | 9 | class CLIP(ClassificationModel): 10 | def __init__(self, model_card: str = "openai/clip-vit-large-patch14"): 11 | self.model_card = model_card 12 | 13 | def load_resources(self): 14 | self.processor = AutoProcessor.from_pretrained(self.model_card) 15 | self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_card) 16 | self.model.eval() 17 | 18 | def release_resources(self): 19 | self.model = None 20 | self.processor = None 21 | 22 | def classify(self, image: Image.Image, classes: list) -> Detections: 23 | inputs = self.processor(text=classes, images=image, return_tensors="pt", padding=True) 24 | outputs = self.model(**inputs).logits_per_image 25 | softmax_outputs = torch.nn.functional.softmax(outputs, dim=-1).detach().numpy() 26 | index = softmax_outputs.argmax() 27 | return Detections( 28 | xyxy=np.zeros((1, 4)), 29 | class_ids=np.array([index]), 30 | confidence= np.array([softmax_outputs[0, index]]), 31 | classes=np.array(classes), 32 | detection_type=DetectionType.CLASSIFICATION 33 | ) 34 | -------------------------------------------------------------------------------- /overeasy/models/classification/laion_clip.py: -------------------------------------------------------------------------------- 1 | import open_clip 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from overeasy.types import Detections, DetectionType, ClassificationModel 6 | from typing import List, Optional 7 | 8 | class OpenCLIPBase(ClassificationModel): 9 | def __init__(self, model_card): 10 | self.model_card = model_card 11 | self.device = "cpu" # TODO: Work out the cuda impl 12 | 13 | def load_resources(self): 14 | model, _, preprocess_val = open_clip.create_model_and_transforms(self.model_card) 15 | self.tokenizer = open_clip.get_tokenizer(self.model_card) 16 | self.model = model 17 | self.model.to(self.device) 18 | self.preprocess = preprocess_val 19 | 20 | def release_resources(self): 21 | self.model = None 22 | self.tokenizer = None 23 | self.preprocess = None 24 | 25 | def classify(self, image: Image.Image, classes: List[str]) -> Detections: 26 | image = self.preprocess(image).to(self.device).unsqueeze(0) 27 | text = self.tokenizer(classes) 28 | 29 | with torch.no_grad(): 30 | image_features = self.model.encode_image(image) 31 | text_features = self.model.encode_text(text) 32 | image_features /= image_features.norm(dim=-1, keepdim=True) 33 | text_features /= text_features.norm(dim=-1, keepdim=True) 34 | 35 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).cpu() 36 | index = text_probs.argmax() 37 | confidence = text_probs[0, index] 38 | return Detections( 39 | xyxy=np.zeros((1, 4)), 40 | class_ids=np.array([index]), 41 | confidence=np.array([confidence]), 
42 | classes=np.array(classes), 43 | detection_type=DetectionType.CLASSIFICATION 44 | ) 45 | 46 | def batch_classify(self, images: List[Image.Image], classes: List[str], top_k: int = 1) -> List[Detections]: 47 | preprocessed_images = [self.preprocess(image).unsqueeze(0) for image in images] 48 | image_input = torch.cat(preprocessed_images) 49 | text_tokens = self.tokenizer(classes) 50 | 51 | with torch.no_grad(): 52 | image_features = self.model.encode_image(image_input).float() 53 | text_features = self.model.encode_text(text_tokens).float() 54 | image_features /= image_features.norm(dim=-1, keepdim=True) 55 | text_features /= text_features.norm(dim=-1, keepdim=True) 56 | 57 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1) 58 | top_probs, top_labels = text_probs.cpu().topk(top_k, dim=-1) 59 | 60 | detections: List[Detections] = [] 61 | for i in range(len(images)): 62 | detections.append(Detections( 63 | xyxy=np.zeros((top_k, 4)), 64 | class_ids=top_labels[i].numpy(), 65 | confidence=top_probs[i].numpy(), 66 | classes=np.array(classes)[top_labels[i].numpy()], 67 | detection_type=DetectionType.CLASSIFICATION 68 | )) 69 | return detections 70 | 71 | models = ["laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"] 72 | class LaionCLIP(OpenCLIPBase): 73 | def __init__(self, model_name: str = models[1]): 74 | if model_name not in models: 75 | raise ValueError(f"Model {model_name} not found") 76 | super().__init__(model_name) 77 | 78 | class BiomedCLIP(OpenCLIPBase): 79 | def __init__(self): 80 | super().__init__('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224') -------------------------------------------------------------------------------- /overeasy/models/classification/siglip.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | import numpy as np 4 | from typing import List 5 | from overeasy.types import Detections, DetectionType, ClassificationModel 6 | from transformers import AutoProcessor, AutoModelForZeroShotImageClassification 7 | 8 | class SigLIP(ClassificationModel): 9 | def __init__(self, model_card="google/siglip-base-patch16-224"): 10 | self.model_card = model_card 11 | 12 | def load_resources(self): 13 | self.processor = AutoProcessor.from_pretrained(self.model_card) 14 | self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_card) 15 | 16 | def release_resources(self): 17 | self.processor = None 18 | self.model = None 19 | 20 | def classify(self, image: Image.Image, classes: List[str]) -> Detections: 21 | # SigLIP scores each class independently with a sigmoid (not a softmax) 22 | inputs = self.processor(text=classes, images=image, padding="max_length", return_tensors="pt") 23 | with torch.no_grad(): 24 | logits = self.model(**inputs).logits_per_image 25 | probs = torch.sigmoid(logits) 26 | index = int(probs.argmax()) 27 | return Detections( 28 | xyxy=np.zeros((1, 4)), 29 | class_ids=np.array([index]), 30 | confidence=np.array([probs[0, index].item()]), 31 | classes=np.array(classes), 32 | detection_type=DetectionType.CLASSIFICATION 33 | )
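# Hedged usage sketch (illustrative only; assumes the checkpoint above downloads successfully):
#
#   model = SigLIP()
#   model.load_resources()
#   dets = model.classify(Image.open("input.jpg"), ["a cat", "a dog"])
#   print(dets.classes[dets.class_ids[0]], dets.confidence[0])
#   model.release_resources()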
-------------------------------------------------------------------------------- /overeasy/models/detection/__init__.py: -------------------------------------------------------------------------------- 1 | from .dino import GroundingDINO, GroundingDINOModel 2 | from .detic import DETIC 3 | from .yoloworld import YOLOWorld 4 | from .owlv2 import OwlV2 5 | 6 | -------------------------------------------------------------------------------- /overeasy/models/detection/detclipv2.py: -------------------------------------------------------------------------------- 1 | raise NotImplementedError("Detclipv2 is not yet implemented") -------------------------------------------------------------------------------- /overeasy/models/detection/detic.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing as mp 3 | import os 4 | import subprocess 5 | import sys 6 | from typing import Any, Union 7 | import numpy as np 8 | import torch 9 | from overeasy.types import Detections, DetectionType 10 | from typing import List 11 | from PIL import Image 12 | import cv2 13 | from overeasy.types import BoundingBoxModel 14 | 15 | from overeasy.download_utils import atomic_retrieve_and_rename 16 | 17 | VOCAB = "custom" 18 | CONFIDENCE_THRESHOLD = 0.3 19 | OVEREASY_DIR = os.path.abspath(os.path.expanduser("~/.overeasy")) 20 | 21 | 22 | def setup_cfg(args): 23 | from centernet.config import add_centernet_config 24 | from detic.config import add_detic_config 25 | from detectron2.config import get_cfg 26 | 27 | cfg = get_cfg() 28 | cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda" 29 | add_centernet_config(cfg) 30 | add_detic_config(cfg) 31 | cfg.merge_from_file(args.config_file) 32 | cfg.merge_from_list(args.opts) 33 | # Set score_threshold for builtin models 34 | cfg.MODEL.RETINANET.SCORE_THRESH_TEST = CONFIDENCE_THRESHOLD 35 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = CONFIDENCE_THRESHOLD 36 | cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = CONFIDENCE_THRESHOLD 37 | cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" # load later 38 | if not args.pred_all_class: 39 | cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False 40 | cfg.freeze() 41 | return cfg 42 | 43 | 44 | def load_detic_model(classes: List[str], weights_file: str): 45 | original_dir = os.getcwd() 46 | try: 47 | sys.path.insert(0, os.path.join(OVEREASY_DIR, "Detic/third_party/CenterNet2/")) 48 | sys.path.insert(0, os.path.join(OVEREASY_DIR, "Detic/")) 49 | os.chdir(os.path.join(OVEREASY_DIR, "Detic/")) 50 | 51 | mp.set_start_method("spawn", force=True) 52 | 53 | args = argparse.Namespace() 54 | 55 | args.confidence_threshold = CONFIDENCE_THRESHOLD 56 | args.vocabulary = VOCAB 57 | args.opts = [] 58 | args.config_file = "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml" 59 | args.cpu = not torch.cuda.is_available() 60 | args.opts.append("MODEL.WEIGHTS") 61 | args.opts.append(weights_file) 62 | args.output = None 63 | args.webcam = None 64 | args.video_input = None 65 | args.custom_vocabulary = ",".join(classes).rstrip(",") 66 | args.pred_all_class = False 67 | cfg = setup_cfg(args) 68 | 69 | from detic.predictor import VisualizationDemo 70 | from detectron2.data import MetadataCatalog 71 | for key in MetadataCatalog.list(): 72 | MetadataCatalog.remove(key) 73 | 74 | # https://github.com/facebookresearch/Detic/blob/main/detic/predictor.py#L39 75 | demo = VisualizationDemo(cfg, args) 76 | return demo 77 | finally: 78 | sys.path.remove(os.path.join(OVEREASY_DIR, "Detic/third_party/CenterNet2/")) 79 | sys.path.remove(os.path.join(OVEREASY_DIR, "Detic/")) 80 | os.chdir(original_dir)
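# Note: load_detic_model temporarily injects the cloned Detic repo (and its
# CenterNet2 submodule) onto sys.path and chdirs into the clone, because Detic
# resolves its config files relative to the repo root; both changes are undone
# in the finally block above.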
81 | 82 | 83 | 84 | HOME = os.path.expanduser("~") 85 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 86 | 87 | def check_dependencies(): 88 | # Create the ~/.overeasy directory if it doesn't exist 89 | os.makedirs(OVEREASY_DIR, exist_ok=True) 90 | 91 | detic_path = os.path.join(OVEREASY_DIR, "Detic") 92 | models_dir = os.path.join(detic_path, "models") 93 | model_path = os.path.join( 94 | models_dir, "Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" 95 | ) 96 | 97 | if subprocess.call(["which", "git"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) != 0: 98 | raise EnvironmentError("git is not installed. Please install git to use Detic.") 99 | 100 | # Check if Detic is installed 101 | if not os.path.isdir(detic_path): 102 | subprocess.run( 103 | [ 104 | "git", 105 | "clone", 106 | "https://github.com/facebookresearch/Detic.git", 107 | "--recurse-submodules", 108 | detic_path 109 | ], 110 | check=True 111 | ) 112 | 113 | import platform 114 | install_cmds = [sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git", "--no-build-isolation"] 115 | if platform.system() == "Darwin": 116 | # "export VAR=... &&" does not work with subprocess's list form; pass the variable via env instead 117 | env = {**os.environ, "MACOSX_DEPLOYMENT_TARGET": "10.13"} 118 | subprocess.run(install_cmds, check=True, env=env) 119 | else: 120 | subprocess.run( 121 | install_cmds, 122 | check=True 123 | ) 124 | 125 | os.makedirs(models_dir, exist_ok=True) 126 | 127 | model_url = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" 128 | 129 | if not os.path.exists(model_path): 130 | atomic_retrieve_and_rename(model_url, model_path) 131 | 132 | return os.path.abspath(model_path) 133 | 134 | 135 | 136 | 137 | 138 | class DETIC(BoundingBoxModel): 139 | def __init__(self): 140 | self.classes = None 141 | 142 | def load_resources(self): 143 | self.weights_file = check_dependencies() 144 | 145 | def release_resources(self): 146 | self.detic_model = None 147 | 148 | def set_classes(self, classes: List[str]): 149 | self.classes = classes 150 | self.detic_model = load_detic_model(classes, self.weights_file) 151 | 152 | def detect(self, image: Union[np.ndarray, Image.Image], classes: List[str], box_threshold=0.35, text_threshold=0.25) -> Detections: 153 | if self.classes is None: 154 | self.set_classes(classes) 155 | elif not np.array_equal(self.classes, classes): 156 | self.set_classes(classes) 157 | 158 | if isinstance(image, Image.Image): 159 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 160 | 161 | predictions, visualized_output = self.detic_model.run_on_image(np.array(image)) 162 | pred_boxes = predictions["instances"].pred_boxes.tensor.cpu().numpy() 163 | pred_classes = predictions["instances"].pred_classes.cpu().numpy() 164 | pred_scores = predictions["instances"].scores.cpu().numpy() 165 | 166 | 167 | if len(pred_classes) == 0: 168 | return Detections.empty() 169 | 170 | return Detections( 171 | xyxy=np.array(pred_boxes), 172 | detection_type=DetectionType.BOUNDING_BOX, 173 | class_ids=np.array(pred_classes), 174 | confidence=np.array(pred_scores), 175 | classes=self.classes, 176 | ) 177 | 178 | -------------------------------------------------------------------------------- /overeasy/models/detection/dino.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from typing import List, Union 8 | from overeasy.types import Detections, DetectionType 9 | from enum import Enum 10 | from overeasy.types import BoundingBoxModel 11 | import warnings 12 | import sys, io 13 | from overeasy.download_utils import atomic_retrieve_and_rename 14 | from typing import Any 15 | 16 | # Ignore the specific UserWarning about torch.meshgrid 17 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.", category=UserWarning, module='torch.functional') 18 | warnings.filterwarnings("ignore", category=FutureWarning, message="The `device` argument is deprecated and will be removed in v5 of Transformers.") 19 | # Suppress UserWarning about use_reentrant parameter in 
torch.utils.checkpoint 20 | warnings.filterwarnings("ignore", message="torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.", category=UserWarning, module='torch.utils.checkpoint') 21 | # Suppress UserWarning about None of the inputs having requires_grad=True 22 | warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None", category=UserWarning, module='torch.utils.checkpoint') 23 | 24 | 25 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | class GroundingDINOModel(Enum): 28 | SwinB = "swinb" 29 | SwinT = "swint" 30 | mmdet_SwinB = "mmdet_swinb"; mmdet_SwinL = "mmdet_swinl"; mmdet_SwinB_zero = "mmdet_swinb_zero"; mmdet_SwinL_zero = "mmdet_swinl_zero"  # mmdet variants referenced by warmup_models and the mapping below 31 | mapping = { 32 | "swinb": { 33 | "config": "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py", 34 | "checkpoint": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth" 35 | }, 36 | "swint": { 37 | "config": "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py", 38 | "checkpoint": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth" 39 | }, 40 | "mmdet_swinb_zero": { 41 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py", 42 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det/grounding_dino_swin-b_pretrain_obj365_goldg_v3de-f83eef00.pth" 43 | }, 44 | "mmdet_swinl_zero": { 45 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py", 46 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg/grounding_dino_swin-l_pretrain_obj365_goldg-34dcdc53.pth" 47 | }, 48 | "mmdet_swinb": { 49 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py", 50 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_all/grounding_dino_swin-b_pretrain_all-f9818a7c.pth" 51 | }, 52 | "mmdet_swinl": { 53 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py", 54 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_all/grounding_dino_swin-l_pretrain_all-56d69e78.pth" 55 | } 56 | 57 | } 58 | 59 | def download_and_cache_grounding_dino(model: GroundingDINOModel): 60 | OVEREASY_CACHE_DIR = os.path.expanduser("~/.overeasy") 61 | GROUNDING_DINO_CACHE_DIR = os.path.join(OVEREASY_CACHE_DIR, "groundingdino") 62 | 63 | if not os.path.exists(GROUNDING_DINO_CACHE_DIR): 64 | os.makedirs(GROUNDING_DINO_CACHE_DIR) 65 | 66 | model_key = model.value 67 | if model_key not in mapping: 68 | raise ValueError(f"Unsupported model type: {model_key}") 69 | 70 | config_url = mapping[model_key]["config"] 71 | checkpoint_url = mapping[model_key]["checkpoint"] 72 | 73 | config_file = os.path.basename(config_url) 74 | checkpoint_file = os.path.basename(checkpoint_url) 75 | 76 | config_path = os.path.join(GROUNDING_DINO_CACHE_DIR, config_file) 77 | checkpoint_path = 
os.path.join(GROUNDING_DINO_CACHE_DIR, checkpoint_file) 78 | 79 | if not os.path.exists(checkpoint_path): 80 | atomic_retrieve_and_rename(checkpoint_url, checkpoint_path) 81 | 82 | if not os.path.exists(config_path): 83 | atomic_retrieve_and_rename(config_url, config_path) 84 | 85 | return config_path, checkpoint_path 86 | 87 | 88 | 89 | def load_grounding_dino(model: GroundingDINOModel): 90 | from groundingdino.util.inference import Model 91 | 92 | config_path, checkpoint_path = download_and_cache_grounding_dino(model) 93 | 94 | instantiate = lambda: Model( 95 | model_config_path=config_path, 96 | model_checkpoint_path=checkpoint_path, 97 | device=DEVICE, 98 | ) 99 | 100 | try: 101 | grounding_dino_model = instantiate() 102 | return grounding_dino_model 103 | except Exception: 104 | 105 | 106 | grounding_dino_model = instantiate() 107 | 108 | return grounding_dino_model 109 | 110 | def combine_detections(detections_list: List[Detections], classes: List[str], overwrite_class_ids=None): 111 | if len(detections_list) == 0: 112 | return Detections.empty() 113 | 114 | if not all(isinstance(detection, Detections) for detection in detections_list): 115 | raise TypeError("All elements in detections_list must be instances of Detections.") 116 | 117 | if overwrite_class_ids is not None and len(overwrite_class_ids) != len(detections_list): 118 | raise ValueError("Length of overwrite_class_ids must match the length of detections_list.") 119 | 120 | # Initialize lists to collect combined attributes 121 | xyxy = [] 122 | confidence = [] 123 | class_ids = [] 124 | data = [] 125 | masks = [] 126 | 127 | detection_types = [detection.detection_type for detection in detections_list] 128 | if len(set(detection_types)) > 1: 129 | raise ValueError("All detections in the list must have the same type.") 130 | 131 | 132 | for idx, detection in enumerate(detections_list): 133 | xyxy.append(detection.xyxy) 134 | 135 | if detection.confidence is not None: 136 | confidence.append(detection.confidence) 137 | 138 | if detection.class_ids is not None: 139 | if overwrite_class_ids is not None: 140 | class_ids.append(np.full_like(detection.class_ids, overwrite_class_ids[idx], dtype=np.int32)) 141 | else: 142 | class_ids.append(detection.class_ids) 143 | if detection.masks is not None: 144 | masks.append(detection.masks) 145 | 146 | # Merge custom data from each detection 147 | data.append(detection.data) 148 | 149 | return Detections( 150 | xyxy=np.vstack(xyxy), 151 | masks=np.hstack(masks) if masks else None, # Assuming masks are not handled in this function 152 | classes = classes, 153 | confidence=np.hstack(confidence) if confidence else None, 154 | class_ids=np.hstack(class_ids) if class_ids else None, 155 | detection_type=detections_list[0].detection_type 156 | ) 157 | 158 | class GroundingDINO(BoundingBoxModel): 159 | grounding_dino_model: Any 160 | box_threshold: float 161 | text_threshold: float 162 | 163 | def __init__( 164 | self, type: GroundingDINOModel = GroundingDINOModel.SwinB, 165 | box_threshold: float = 0.35, 166 | text_threshold: float = 0.25, 167 | ): 168 | self.box_threshold = box_threshold 169 | self.text_threshold = text_threshold 170 | self.model_type = type 171 | 172 | def load_resources(self): 173 | # if DEVICE.type != "cuda": 174 | # warnings.warn("CUDA not available. 
GroundingDINO may run slowly.") 175 | 176 | download_and_cache_grounding_dino(self.model_type) 177 | original_stdout = sys.stdout 178 | sys.stdout = io.StringIO() 179 | try: 180 | self.grounding_dino_model = load_grounding_dino(model=self.model_type) 181 | except Exception as e: 182 | print(sys.stdout.getvalue()) 183 | print(f"Error loading GroundingDINO model: {e}") 184 | raise e 185 | finally: 186 | sys.stdout = original_stdout 187 | 188 | def release_resources(self): 189 | self.grounding_dino_model = None 190 | 191 | # def detect_multiple(self, images: List[np.ndarray], class_groups: List[List[str]], box_threshold) -> List[Detections]: 192 | # cv2_images = [cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) if isinstance(image, Image.Image) else image for image in images] 193 | 194 | # raise ValueError("Unsupported model type") 195 | 196 | def detect(self, image: Union[np.ndarray, Image.Image], classes: List[str], box_threshold=None, text_threshold=None) -> Detections: 197 | if box_threshold is None: 198 | box_threshold = self.box_threshold 199 | if text_threshold is None: 200 | text_threshold = self.text_threshold 201 | 202 | 203 | if isinstance(image, Image.Image): 204 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 205 | 206 | sv_detection = self.grounding_dino_model.predict_with_classes( 207 | image=image, 208 | classes=classes, 209 | box_threshold=box_threshold, 210 | text_threshold=text_threshold, 211 | ) 212 | valid_indexes = [i for i, class_id in enumerate(sv_detection.class_id) if class_id is not None] 213 | sv_detection = sv_detection[valid_indexes] 214 | detections = Detections.from_supervision_detection(sv_detection, classes=classes) 215 | 216 | return detections -------------------------------------------------------------------------------- /overeasy/models/detection/owlv2.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from typing import List 4 | import numpy as np 5 | from PIL import Image 6 | from overeasy.types import BoundingBoxModel, Detections, DetectionType 7 | from transformers import pipeline 8 | import torch 9 | 10 | class OwlV2(BoundingBoxModel): 11 | def __init__(self): 12 | self.checkpoint = "google/owlv2-base-patch16-ensemble" 13 | 14 | def load_resources(self): 15 | self.detector = pipeline(model=self.checkpoint, task="zero-shot-object-detection") 16 | 17 | def release_resources(self): 18 | self.detector = None 19 | 20 | def detect(self, image: Image.Image, classes: List[str], threshold: float = 0) -> Detections: 21 | predictions = self.detector(image, candidate_labels=classes,) 22 | num_preds = len(predictions) 23 | xyxy = np.zeros((num_preds, 4), dtype=np.int32) 24 | confidence = np.zeros(num_preds, dtype=np.float32) 25 | class_ids = np.zeros(num_preds, dtype=np.int64) 26 | 27 | for i, pred in enumerate(predictions): 28 | x1, y1, x2, y2 = pred['box']['xmin'], pred['box']['ymin'], pred['box']['xmax'], pred['box']['ymax'] 29 | 30 | xyxy[i] = [x1, y1, x2, y2] 31 | confidence[i] = pred['score'] 32 | class_ids[i] = classes.index(pred['label']) 33 | 34 | slice = confidence > threshold 35 | xyxy = xyxy[slice, :] 36 | confidence = confidence[slice] 37 | class_ids = class_ids[slice] 38 | 39 | return Detections( 40 | xyxy=xyxy, 41 | confidence=confidence, 42 | class_ids=class_ids, 43 | classes=classes, 44 | detection_type=DetectionType.BOUNDING_BOX 45 | ) -------------------------------------------------------------------------------- /overeasy/models/detection/yoloworld.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import List, Union 4 | from overeasy.types import BoundingBoxModel, Detections 5 | import numpy as np 6 | from PIL import Image 7 | from overeasy.download_utils import atomic_retrieve_and_rename 8 | from ultralytics import YOLOWorld as YOLOWorld_ultralytics 9 | 10 | 11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | valid_models = [ 15 | "yolov8l-world-cc3m.pt", 16 | "yolov8l-world.pt", 17 | "yolov8l-worldv2-cc3m.pt", 18 | "yolov8l-worldv2.pt", 19 | "yolov8m-world.pt", 20 | "yolov8m-worldv2.pt", 21 | "yolov8s-world.pt", 22 | "yolov8s-worldv2.pt", 23 | "yolov8x-world.pt", 24 | "yolov8x-worldv2.pt" 25 | ] 26 | 27 | def get_yoloworld_model(model: str) -> str: 28 | if not model.endswith(".pt"): 29 | model = model + ".pt" 30 | 31 | local_model_path = os.path.join(os.path.expanduser("~/.overeasy"), model) 32 | if os.path.exists(local_model_path): 33 | return local_model_path 34 | 35 | url = None 36 | if model in valid_models: 37 | url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model}" 38 | else: 39 | raise ValueError(f"Model {model} not valid. Valid models are: {valid_models}") 40 | 41 | atomic_retrieve_and_rename(url, local_model_path) 42 | 43 | return local_model_path 44 | 45 | # Load a pretrained YOLOv8s-worldv2 model 46 | class YOLOWorld(BoundingBoxModel): 47 | def __init__(self, model: str = "yolov8l-worldv2-cc3m"): 48 | self.model_name = model 49 | 50 | def load_resources(self): 51 | self.model_path = get_yoloworld_model(self.model_name) 52 | self.model = YOLOWorld_ultralytics(self.model_path) 53 | 54 | def release_resources(self): 55 | self.model = None 56 | 57 | def detect(self, image: Image.Image, classes: List[str]) -> Detections: 58 | self.model.set_classes(classes) 59 | results = self.model.predict(image, verbose=False) 60 | detections = Detections.from_ultralytics(results[0]) 61 | return detections 62 | -------------------------------------------------------------------------------- /overeasy/models/recognition/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/models/recognition/__init__.py -------------------------------------------------------------------------------- /overeasy/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/py.typed -------------------------------------------------------------------------------- /overeasy/types/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .base import * 3 | from .detections import Detections, DetectionType 4 | -------------------------------------------------------------------------------- /overeasy/types/base.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, List, Union, Optional, Any 3 | from abc import ABC, abstractmethod 4 | from typing import List 5 | 6 | from overeasy.logging import log_time 7 | from .detections import Detections 8 | from PIL import Image 9 | from overeasy.visualize_utils import annotate, annotate_with_string 10 | from dataclasses import dataclass 11 | from collections import defaultdict 12 | 13 
13 | __all__ = [ 14 | "OCRModel", 15 | "LLM", 16 | "MultimodalLLM", 17 | "BoundingBoxModel", 18 | "ExecutionNode", 19 | "NullExecutionNode", 20 | "Node", 21 | "ExecutionGraph", 22 | "Model", 23 | "Agent", 24 | "ImageDetectionAgent", 25 | "ImageToTextAgent", 26 | "ImageToDataAgent", 27 | "DetectionAgent", 28 | "TextAgent", 29 | "DataAgent", 30 | "ClassificationModel", 31 | "CaptioningModel", 32 | ] 33 | 34 | @dataclass 35 | class ExecutionNode: 36 | image: Image.Image 37 | data: Union[Detections, Any] 38 | parent_detection: Optional[Detections] = None 39 | 40 | def data_is_detections(self) -> bool: 41 | return isinstance(self.data, Detections) 42 | 43 | def visualize(self, seed: Optional[int] = None) -> Image.Image: 44 | if self.data_is_detections(): 45 | return annotate(self.image, self.data, seed) 46 | else: 47 | return annotate_with_string(self.image, str(self.data)) 48 | 49 | @property 50 | def id(self): 51 | return id(self) 52 | 53 | @dataclass 54 | class NullExecutionNode: 55 | parent_detection: Optional[Detections] = None 56 | 57 | @property 58 | def id(self): 59 | return id(self) 60 | 61 | Node = Union[ExecutionNode, NullExecutionNode] 62 | 63 | @dataclass 64 | class ExecutionGraph: 65 | root: ExecutionNode 66 | edges: Dict[int, List[Node]] 67 | parent: Dict[int, List[Node]] 68 | 69 | def __init__(self, root: ExecutionNode): 70 | self.root = root 71 | self.edges = {} 72 | self.parent = {} 73 | 74 | def add_child(self, parent: Node, child: Node): 75 | assert isinstance(parent, (ExecutionNode, NullExecutionNode)) 76 | assert isinstance(child, (ExecutionNode, NullExecutionNode)) 77 | 78 | if parent.id == child.id: 79 | raise ValueError("Cannot add self-loops to the execution graph") 80 | 81 | # Add edges and reverse edges 82 | if parent.id not in self.edges: 83 | self.edges[parent.id] = [] 84 | self.edges[parent.id].append(child) 85 | 86 | if child.id not in self.parent: 87 | self.parent[child.id] = [] 88 | self.parent[child.id].append(parent) 89 | 90 | def ascii_graph(self): 91 | print("ExecutionGraph") 92 | id_counter = itertools.count() 93 | id_map = {} 94 | 95 | def get_mapped_id(original_id): 96 | if original_id not in id_map: 97 | id_map[original_id] = next(id_counter) 98 | return id_map[original_id] 99 | 100 | def print_node(node, prefix="", is_last=True): 101 | connector = "└── " if is_last else "├── " 102 | print(prefix + connector + f"Node(ID={get_mapped_id(node.id)})") 103 | if node.id in self.edges: 104 | children = self.edges[node.id] 105 | for i, child in enumerate(children): 106 | next_prefix = prefix + (" " if is_last else "│ ") 107 | print_node(child, next_prefix, i == len(children) - 1) 108 | 109 | print_node(self.root) 110 | 111 | 112 | def parent_of(self, node: Node) -> Node: 113 | if node.id not in self.parent: 114 | raise ValueError(f"Node {node.id} has no parent") 115 | parent_list = self.parent[node.id] 116 | if len(parent_list) > 1: 117 | print("Multiple parents", parent_list) 118 | raise ValueError("Node has multiple parents") 119 | return self.parent[node.id][0] 120 | 121 | def children(self, node: Node) -> List[Node]: 122 | if node.id not in self.edges: 123 | raise ValueError(f"Node {node.id} has no children") 124 | 125 | return self.edges[node.id] 126 | 127 | def __getitem__(self, node: Node) -> List[Node]: 128 | return self.edges[node.id] 129 | 130 | def __repr__(self): 131 | return f"ExecutionGraph(root={self.root}, edges={self.edges}, parent={self.parent})" 132 | 133 | # TODO: This does a lot of heavy lifting so it should 
have some more strict ordering properties 134 | def top_sort(self) -> List[List[Node]]: 135 | # Create a copy of the edges dictionary to avoid mutating the original 136 | if len(self.edges) == 0: 137 | return [[self.root]] 138 | 139 | edges_copy = {node_id: neighbors.copy() for node_id, neighbors in self.edges.items()} 140 | 141 | # Initialize the in-degree dictionary 142 | in_degree : Dict[int, int] = defaultdict(int) 143 | for neighbors in edges_copy.values(): 144 | for node in neighbors: 145 | in_degree[node.id] += 1 146 | 147 | # Initialize the queue with nodes having in-degree 0 148 | queue : List[Node] = [node for node in [self.root] if in_degree[node.id] == 0] 149 | 150 | # Initialize the sorted nodes list 151 | sorted_nodes = [] 152 | 153 | # Perform the topological sort 154 | while queue: 155 | level_nodes = [] 156 | for _ in range(len(queue)): 157 | node = queue.pop(0) 158 | level_nodes.append(node) 159 | 160 | # Decrement the in-degree of the neighbors and add them to the queue if in-degree becomes 0 161 | if node.id in edges_copy: 162 | for neighbor in edges_copy[node.id]: 163 | in_degree[neighbor.id] -= 1 164 | if in_degree[neighbor.id] == 0: 165 | queue.append(neighbor) 166 | 167 | sorted_nodes.append(level_nodes) 168 | 169 | return sorted_nodes 170 | 171 | class Model(ABC): 172 | # Load model resources, such as allocating memory for weights, only once when this method is called. 173 | # Note: Weights can still be downloaded to disk, just not allocated to VRAM until load_resources is called. 174 | @abstractmethod 175 | def load_resources(self): 176 | pass 177 | 178 | # Free up the allocated resources, such as memory for weights, when this method is called. 179 | @abstractmethod 180 | def release_resources(self): 181 | pass 182 | 183 | def __del__(self): 184 | self.release_resources() 185 | 186 | 187 | # OCRModel reads text from an image 188 | class OCRModel(Model): 189 | @abstractmethod 190 | def parse_text(self, image: Image.Image) -> str: 191 | pass 192 | 193 | class LLM(Model): 194 | @abstractmethod 195 | def prompt(self, query: str) -> str: 196 | pass 197 | 198 | #TODO: Add support for prompting with multiple images 199 | class MultimodalLLM(LLM): 200 | @abstractmethod 201 | def prompt_with_image(self, image: Image.Image, query: str) -> str: 202 | pass 203 | 204 | class CaptioningModel(Model): 205 | @abstractmethod 206 | def caption(self, image: Image.Image) -> str: 207 | pass 208 | 209 | class BoundingBoxModel(Model): 210 | @abstractmethod 211 | def detect(self, image: Image.Image, classes: List[str]) -> Detections: 212 | pass 213 | 214 | class ClassificationModel(Model): 215 | @abstractmethod 216 | def classify(self, image: Image.Image, classes: list) -> Detections: 217 | pass 218 | 219 | class Agent(ABC): 220 | pass 221 | 222 | 223 | class ImageToDataAgent(Agent): 224 | @abstractmethod 225 | def _execute(self, image: Image.Image) -> Any: 226 | pass 227 | 228 | @log_time 229 | def execute(self, image: Image.Image) -> Any: 230 | if not isinstance(image, Image.Image): 231 | raise ValueError(f"{self.__class__.__name__} received non-image data") 232 | output = self._execute(image) 233 | self._validate_output(output) 234 | return output 235 | 236 | def _validate_output(self, output: Any) -> None: 237 | pass 238 | 239 | class ImageDetectionAgent(ImageToDataAgent): 240 | def _validate_output(self, output: Any) -> None: 241 | if not isinstance(output, Detections): 242 | raise ValueError(f"{self.__class__.__name__} returned non-detection data") 243 | 244 | class 
ImageToTextAgent(ImageToDataAgent): 245 | def _validate_output(self, output: Any) -> None: 246 | if not isinstance(output, str): 247 | raise ValueError(f"{self.__class__.__name__} returned non-string data") 248 | 249 | class DataAgent(Agent): 250 | @abstractmethod 251 | def _execute(self, data: Any) -> Any: 252 | pass 253 | 254 | @log_time 255 | def execute(self, data: Any) -> Any: 256 | self._validate_input(data) 257 | output = self._execute(data) 258 | self._validate_output(output) 259 | return output 260 | 261 | def _validate_input(self, data: Any) -> None: 262 | pass 263 | 264 | def _validate_output(self, output: Any) -> None: 265 | pass 266 | class DetectionAgent(DataAgent): 267 | @abstractmethod 268 | def _execute(self, data: Detections) -> Detections: 269 | pass 270 | 271 | def _validate_input(self, data: Any) -> None: 272 | if not isinstance(data, Detections): 273 | raise ValueError("DetectionAgent received non-detection data") 274 | 275 | def _validate_output(self, output: Any) -> None: 276 | if not isinstance(output, Detections): 277 | raise ValueError("DetectionAgent returned non-detection data") 278 | class TextAgent(DataAgent): 279 | @abstractmethod 280 | def _execute(self, data: str) -> Any: 281 | pass 282 | 283 | def _validate_input(self, data: Any) -> None: 284 | if not isinstance(data, str): 285 | raise ValueError("TextAgent received non-string data") -------------------------------------------------------------------------------- /overeasy/types/detections.py: -------------------------------------------------------------------------------- 1 | from pydantic.dataclasses import dataclass as pydantic_dataclass 2 | import pydantic_numpy.typing as pnd 3 | from pydantic import Field 4 | import supervision as sv 5 | import numpy as np 6 | from typing import Any, Dict, List, Optional, Union, Iterator, Tuple 7 | 8 | from supervision.detection.utils import ( 9 | extract_ultralytics_masks, 10 | ) 11 | 12 | from .type_utils import ( 13 | validate_detections_fields, 14 | is_data_equal, 15 | get_data_item, 16 | DetectionType 17 | ) 18 | 19 | ORIENTED_BOX_COORDINATES="oriented_box_coordinates" 20 | 21 | @pydantic_dataclass 22 | class Detections: 23 | """ 24 | Represents a collection of detection data including bounding boxes, segmentation masks, 25 | class IDs, and additional custom data. It supports operations like filtering, splitting, 26 | and merging detections. 27 | 28 | Attributes: 29 | xyxy (np.ndarray[np.float32]): Coordinates of bounding boxes `[x1, y1, x2, y2]` for each detection. 30 | masks (Optional[np.ndarray[np.float32]]): Segmentation masks for each detection, required for segmentation types. 31 | class_ids (Optional[np.ndarray[np.int32]]): Class IDs for each detection, indexing into `classes`. 32 | classes (Optional[np.ndarray[np.object_]]): Array of class names, indexed by `class_ids`. 33 | confidence (Optional[np.ndarray[np.float32]]): Confidence scores for each detection. 34 | data (Dict[str, Union[np.ndarray, List]]): Additional custom data related to detections. 35 | detection_type (DetectionType): Type of detection (e.g., bounding box, segmentation, classification). 
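    Example (an illustrative sketch added for clarity; the values below are
    not from the original docstring):

    ```python
    import numpy as np
    from overeasy.types import Detections, DetectionType

    dets = Detections(
        xyxy=np.array([[10, 10, 50, 50], [30, 30, 90, 90]], dtype=np.float32),
        class_ids=np.array([0, 1], dtype=np.int32),
        classes=np.array(["cat", "dog"], dtype=object),
        confidence=np.array([0.92, 0.81], dtype=np.float32),
        detection_type=DetectionType.BOUNDING_BOX,
    )
    assert len(dets) == 2
    assert dets.class_names == ["cat", "dog"]
    ```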
36 | """ 37 | 38 | xyxy: pnd.NpNDArrayFp32 39 | detection_type: DetectionType 40 | class_ids: pnd.NpNDArrayInt32 41 | classes: pnd.NpNDArray 42 | masks: Optional[pnd.NpNDArrayFp32] = None 43 | confidence: Optional[pnd.NpNDArrayFp32] = None 44 | data: Dict[str, Union[pnd.NpNDArray, List]] = Field(default_factory=dict) 45 | 46 | @property 47 | def confidence_scores(self) -> np.ndarray: 48 | if self.confidence is None: 49 | return np.array([None] * self.xyxy.shape[0]) 50 | return self.confidence 51 | 52 | 53 | def split(self) -> List['Detections']: 54 | """ 55 | Split the Detections object into a list of Detections objects, each containing 56 | a single detection. 57 | 58 | Returns: 59 | List[Detections]: A list of Detections objects, each containing a single 60 | detection. 61 | """ 62 | rows, _ = np.shape(self.xyxy) 63 | if self.detection_type == DetectionType.CLASSIFICATION: 64 | raise ValueError("Splitting is not supported for classification detections as they are considered a single entity.") 65 | else: 66 | return [Detections( 67 | xyxy=self.xyxy[i:i+1], 68 | masks=self.masks[i:i+1] if self.masks is not None else None, 69 | class_ids=self.class_ids[i:i+1] if self.class_ids is not None else None, 70 | classes=self.classes, 71 | confidence=self.confidence[i:i+1] if self.confidence is not None else None, 72 | data={key: [value[i]] for key, value in self.data.items()} if self.data else {}, 73 | detection_type=self.detection_type 74 | ) for i in range(rows)] 75 | 76 | def __post_init__(self): 77 | validate_detections_fields( 78 | xyxy=self.xyxy, 79 | detection_type=self.detection_type, 80 | masks=self.masks, 81 | confidence=self.confidence, 82 | class_ids=self.class_ids, 83 | classes=self.classes, 84 | data=self.data, 85 | ) 86 | 87 | def __len__(self): 88 | """ 89 | Returns the number of detections in the Detections object. 90 | """ 91 | return len(self.xyxy) 92 | 93 | 94 | def __iter__(self) -> Iterator[Tuple[np.ndarray, 95 | DetectionType, 96 | Optional[List[np.ndarray]], 97 | Optional[List[int]], 98 | Optional[List[str]], 99 | Optional[np.ndarray]]]: 100 | """ 101 | Iterates over the Detections object and yields a tuple of 102 | `(xyxy, detection_type, masks, class_ids, classes, confidence)` for each detection. 
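    Note: `masks` and `confidence` are indexed per detection, while
    `class_ids` and `classes` are currently yielded as the full arrays.

    Example (an illustrative sketch added for clarity):

    ```python
    for xyxy, det_type, mask, class_ids, classes, conf in detections:
        print(xyxy, det_type, conf)
    ```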
103 | """ 104 | for i in range(len(self.xyxy)): 105 | yield ( 106 | self.xyxy[i], 107 | self.detection_type, 108 | self.masks[i] if self.masks is not None else None, 109 | self.class_ids if self.class_ids is not None else None, 110 | self.classes if self.classes is not None else None, 111 | self.confidence[i] if self.confidence is not None else None, 112 | ) 113 | 114 | 115 | 116 | def __eq__(self, other: object): 117 | if not isinstance(other, Detections): 118 | return NotImplemented # defer to the other operand instead of raising 119 | 120 | def array_equal(array1: Optional[np.ndarray], array2: Optional[np.ndarray]) -> bool: 121 | if array1 is None and array2 is None: 122 | return True 123 | elif array1 is None or array2 is None: 124 | return False 125 | else: 126 | return np.array_equal(array1, array2) 127 | 128 | return all( 129 | [ 130 | np.array_equal(self.xyxy, other.xyxy), 131 | array_equal(self.masks, other.masks), 132 | array_equal(self.class_ids, other.class_ids), 133 | array_equal(self.confidence, other.confidence), 134 | array_equal(self.classes, other.classes), 135 | is_data_equal(self.data, other.data), 136 | self.detection_type == other.detection_type 137 | ] 138 | ) 139 | 140 | @classmethod 141 | def from_classification(cls, assigned_classes: List[str], all_classes: Optional[List[str]] = None) -> 'Detections': 142 | if all_classes is None: 143 | return cls( 144 | xyxy=np.zeros((len(assigned_classes), 4)), 145 | class_ids=np.arange(len(assigned_classes)), 146 | classes=np.array(assigned_classes), 147 | detection_type=DetectionType.CLASSIFICATION 148 | ) 149 | else: 150 | return cls( 151 | xyxy=np.zeros((len(assigned_classes), 4)), 152 | class_ids=np.array([all_classes.index(c) for c in assigned_classes]), 153 | classes=np.array(all_classes), 154 | detection_type=DetectionType.CLASSIFICATION 155 | ) 156 | 157 | @classmethod 158 | def from_yolov5(cls, yolov5_results) -> 'Detections': 159 | """ 160 | Creates a Detections instance from a 161 | [YOLOv5](https://github.com/ultralytics/yolov5) inference result. 162 | 163 | Args: 164 | yolov5_results (yolov5.models.common.Detections): 165 | The output Detections instance from YOLOv5 166 | 167 | Returns: 168 | Detections: A new Detections object. 169 | 170 | Example: 171 | ```python 172 | import cv2 173 | import torch 174 | import overeasy as ov 175 | 176 | image = cv2.imread("path/to/image.jpg") 177 | model = torch.hub.load('ultralytics/yolov5', 'yolov5s') 178 | result = model(image) 179 | detections = ov.Detections.from_yolov5(result) 180 | ``` 181 | """ 182 | yolov5_detections_predictions = yolov5_results.pred[0].cpu().numpy() 183 | classes = list(yolov5_results.names.values()) 184 | return cls( 185 | xyxy=yolov5_detections_predictions[:, :4], 186 | confidence=yolov5_detections_predictions[:, 4], 187 | class_ids=yolov5_detections_predictions[:, 5].astype(int), 188 | classes=np.array(classes), 189 | detection_type=DetectionType.BOUNDING_BOX 190 | ) 191 | 192 | @classmethod 193 | def from_ultralytics(cls, ultralytics_results) -> 'Detections': 194 | """ 195 | Creates a Detections instance from a 196 | [YOLOv8](https://github.com/ultralytics/ultralytics) inference result. 197 | 198 | !!! Note 199 | 200 | `from_ultralytics` is compatible with 201 | [detection](https://docs.ultralytics.com/tasks/detect/), 202 | [segmentation](https://docs.ultralytics.com/tasks/segment/), and 203 | [OBB](https://docs.ultralytics.com/tasks/obb/) models. 
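    For OBB results, the axis-aligned `xyxy` boxes are derived from the
    oriented boxes, and the oriented corner points are preserved in `data`
    under the `oriented_box_coordinates` key (see the implementation below).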
204 | 205 | Args: 206 | ultralytics_results (ultralytics.yolo.engine.results.Results): 207 | The output Results instance from YOLOv8 208 | 209 | Returns: 210 | Detections: A new Detections object. 211 | 212 | Example: 213 | ```python 214 | import cv2 215 | import overeasy as ov 216 | from ultralytics import YOLO 217 | 218 | image = cv2.imread("path/to/image.jpg") 219 | model = YOLO('yolov8s.pt') 220 | 221 | result = model(image)[0] 222 | detections = ov.Detections.from_ultralytics(result) 223 | ``` 224 | """ # noqa: E501 // docs 225 | 226 | if ultralytics_results.obb is not None: 227 | class_id = ultralytics_results.obb.cls.cpu().numpy().astype(int) 228 | class_names = np.array([ultralytics_results.names[i] for i in class_id]) 229 | oriented_box_coordinates = ultralytics_results.obb.xyxyxyxy.cpu().numpy() 230 | return cls( 231 | xyxy=ultralytics_results.obb.xyxy.cpu().numpy(), 232 | confidence=ultralytics_results.obb.conf.cpu().numpy(), 233 | class_ids=class_id, 234 | classes=class_names, 235 | data={ 236 | ORIENTED_BOX_COORDINATES: oriented_box_coordinates, 237 | }, 238 | detection_type=DetectionType.BOUNDING_BOX 239 | ) 240 | 241 | 242 | masks = extract_ultralytics_masks(ultralytics_results) 243 | class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int) 244 | 245 | if isinstance(ultralytics_results.names, dict): 246 | class_names = np.array(list(ultralytics_results.names.values())) 247 | else: 248 | class_names = np.array(ultralytics_results.names) 249 | 250 | return cls( 251 | xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(), 252 | confidence=ultralytics_results.boxes.conf.cpu().numpy(), 253 | class_ids=class_id, 254 | masks=masks, 255 | classes=class_names, 256 | data={}, 257 | detection_type=DetectionType.BOUNDING_BOX if masks is None else DetectionType.SEGMENTATION 258 | ) 259 | 260 | @classmethod 261 | def from_supervision_detection( 262 | cls, 263 | sv_detection, 264 | classes: Optional[List[str]] = None, 265 | **kwargs 266 | ) -> 'Detections': 267 | 268 | xyxy = sv_detection.xyxy 269 | class_ids = np.array(sv_detection.class_id) if sv_detection.class_id is not None else None 270 | masks = np.array(sv_detection.mask) if sv_detection.mask is not None else None 271 | confidence = np.array(sv_detection.confidence) if sv_detection.confidence is not None else None 272 | 273 | return cls( 274 | xyxy=xyxy, 275 | class_ids=class_ids, 276 | masks=masks, 277 | classes=np.array(classes) if classes is not None else None, 278 | confidence=confidence, 279 | detection_type=DetectionType.BOUNDING_BOX if masks is None else DetectionType.SEGMENTATION, 280 | **kwargs 281 | ) 282 | 283 | @classmethod 284 | def empty(cls) -> 'Detections': 285 | return cls( 286 | xyxy=np.empty((0, 4), dtype=np.float32), 287 | classes=np.array([], dtype='object'), 288 | class_ids=np.array([], dtype=np.int32), 289 | confidence=np.array([], dtype=np.float32), 290 | detection_type=DetectionType.BOUNDING_BOX, 291 | data={} 292 | ) 293 | 294 | 295 | def __getitem__( 296 | self, index: Union[int, slice, List[int], np.ndarray] 297 | ) -> 'Detections': 298 | if isinstance(index, int): 299 | index = [index] 300 | return Detections( 301 | xyxy=self.xyxy[index], 302 | masks=self.masks[index] if self.masks is not None else None, 303 | confidence=self.confidence[index] if self.confidence is not None else None, 304 | class_ids=self.class_ids[index], 305 | classes=self.classes, 306 | data=get_data_item(self.data, index), 307 | detection_type=self.detection_type 308 | ) 309 | 310 | def __setitem__(self, key: str, value: Union[np.ndarray, 
List]): 311 | if not isinstance(value, (np.ndarray, list)): 312 | raise TypeError("Value must be a np.ndarray or a list") 313 | 314 | if isinstance(value, list): 315 | value = np.array(value) 316 | 317 | self.data[key] = value 318 | 319 | @property 320 | def area(self) -> np.ndarray: 321 | if self.masks is not None: 322 | return np.array([np.sum(mask) for mask in self.masks]) 323 | else: 324 | return self.box_area 325 | 326 | @property 327 | def box_area(self) -> np.ndarray: 328 | return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0]) 329 | 330 | @property 331 | def class_names(self) -> List[str]: 332 | try: 333 | return [self.classes[class_id] for class_id in self.class_ids] 334 | except IndexError as e: 335 | raise IndexError(f"One or more class_ids are out of bounds for the available classes: {e}") 336 | 337 | 338 | def to_supervision(self) -> sv.Detections: 339 | # Extract class names and class IDs if they exist in the data dictionary 340 | 341 | return sv.Detections( 342 | xyxy=self.xyxy, 343 | confidence=self.confidence, 344 | class_id=self.class_ids, 345 | mask=self.masks, 346 | data=self.data 347 | ) 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | -------------------------------------------------------------------------------- /overeasy/types/type_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Union, List 2 | from itertools import chain 3 | import numpy as np 4 | from enum import Enum 5 | 6 | class DetectionType(Enum): 7 | BOUNDING_BOX = "bounding_box" 8 | SEGMENTATION = "segmentation" 9 | CLASSIFICATION = "classification" 10 | 11 | def validate_data(data: Dict[str, Any], n: int) -> None: 12 | for key, value in data.items(): 13 | if isinstance(value, list): 14 | if len(value) != n: 15 | raise ValueError(f"Length of list for key '{key}' must be {n}") 16 | elif isinstance(value, np.ndarray): 17 | if value.ndim == 1 and value.shape[0] != n: 18 | raise ValueError(f"Shape of np.ndarray for key '{key}' must be ({n},)") 19 | elif value.ndim > 1 and value.shape[0] != n: 20 | raise ValueError( 21 | f"First dimension of np.ndarray for key '{key}' must have size {n}" 22 | ) 23 | else: 24 | raise ValueError(f"Value for key '{key}' must be a list or np.ndarray") 25 | 26 | def validate_detections_fields( 27 | xyxy: Any, 28 | masks: Any, 29 | class_ids: Any, 30 | classes: Any, 31 | confidence: Any, 32 | data: Dict[str, Any], 33 | detection_type: Any, 34 | ) -> None: 35 | expected_shape_xyxy = "(_, 4)" 36 | actual_shape_xyxy = str(getattr(xyxy, "shape", None)) 37 | is_valid_xyxy = isinstance(xyxy, np.ndarray) and xyxy.ndim == 2 and xyxy.shape[1] == 4 38 | if not is_valid_xyxy: 39 | raise ValueError( 40 | f"xyxy must be a 2D np.ndarray with shape {expected_shape_xyxy}, but got shape " 41 | f"{actual_shape_xyxy}" 42 | ) 43 | 44 | n = len(xyxy) 45 | if masks is not None: 46 | expected_shape_masks = f"({n}, H, W)" 47 | actual_shape_masks = str(getattr(masks, "shape", None)) 48 | is_valid_masks = isinstance(masks, np.ndarray) and len(masks.shape) == 3 and masks.shape[0] == n 49 | if not is_valid_masks: 50 | raise ValueError( 51 | f"masks must be a 3D np.ndarray with shape {expected_shape_masks}, but got shape " 52 | f"{actual_shape_masks}" 53 | ) 54 | 55 | if class_ids is not None: 56 | expected_shape_class_ids = f"({n},)" 57 | actual_shape_class_ids = str(getattr(class_ids, "shape", None)) 58 | is_valid_class_ids = isinstance(class_ids, np.ndarray) and class_ids.shape == (n,) 59 | if not 
is_valid_class_ids: 60 | raise ValueError( 61 | f"class_ids must be a 1D np.ndarray with shape {expected_shape_class_ids}, but got " 62 | f"shape {actual_shape_class_ids}" 63 | ) 64 | 65 | if classes is not None: 66 | is_valid_classes = isinstance(classes, np.ndarray) and classes.dtype.kind in {'U', 'O'} 67 | if not is_valid_classes: 68 | raise ValueError( 69 | "classes must be a np.ndarray of strings." 70 | ) 71 | 72 | if confidence is not None: 73 | expected_shape_confidence = f"({n},)" 74 | actual_shape_confidence = str(getattr(confidence, "shape", None)) 75 | is_valid_confidence = isinstance(confidence, np.ndarray) and confidence.shape == (n,) 76 | if not is_valid_confidence: 77 | raise ValueError( 78 | f"confidence must be a 1D np.ndarray with shape {expected_shape_confidence}, but got " 79 | f"shape {actual_shape_confidence}" 80 | ) 81 | 82 | if not isinstance(detection_type, DetectionType): 83 | raise ValueError( 84 | "detection_type must be an instance of DetectionType enum." 85 | ) 86 | 87 | validate_data(data, n) 88 | 89 | def is_data_equal(data_a: Dict[str, Union[np.ndarray, List]], data_b: Dict[str, Union[np.ndarray, List]]) -> bool: 90 | """ 91 | Compares the data payloads of two Detections instances. 92 | 93 | Args: 94 | data_a, data_b: The data payloads of the instances. 95 | 96 | Returns: 97 | True if the data payloads are equal, False otherwise. 98 | """ 99 | return set(data_a.keys()) == set(data_b.keys()) and all( 100 | np.array_equal(data_a[key], data_b[key]) for key in data_a 101 | ) 102 | 103 | def get_data_item( 104 | data: Dict[str, Union[np.ndarray, List]], 105 | index: Union[int, slice, List[int], np.ndarray], 106 | ) -> Dict[str, Union[np.ndarray, List]]: 107 | """ 108 | Retrieve a subset of the data dictionary based on the given index. 109 | 110 | Args: 111 | data: The data dictionary of the Detections object. 112 | index: The index or indices specifying the subset to retrieve. 113 | 114 | Returns: 115 | A subset of the data dictionary corresponding to the specified index. 116 | """ 117 | subset_data: Dict[str, Union[np.ndarray, List]] = {} 118 | for key, value in data.items(): 119 | if isinstance(value, np.ndarray): 120 | subset_data[key] = value[index] 121 | elif isinstance(value, list): 122 | if isinstance(index, slice): 123 | subset_data[key] = value[index] 124 | elif isinstance(index, (list, np.ndarray)): 125 | subset_data[key] = [value[i] for i in index] 126 | elif isinstance(index, int): 127 | subset_data[key] = [value[index]] 128 | else: 129 | raise TypeError(f"Unsupported index type: {type(index)}") 130 | else: 131 | raise TypeError(f"Unsupported data type for key '{key}': {type(value)}") 132 | 133 | return subset_data 134 | 135 | 136 | def merge_data( 137 | maps_to_merge: List[Dict[str, Union[np.ndarray, List]]], 138 | ) -> Dict[str, Union[np.ndarray, List]]: 139 | """ 140 | Merges the data payloads of a list of Detections instances. 141 | 142 | Args: 143 | maps_to_merge: The data payloads of the instances. 144 | 145 | Returns: 146 | A single data payload containing the merged data, preserving the original data 147 | types (list or np.ndarray). 148 | 149 | Raises: 150 | ValueError: If data values within a single object have different lengths or if 151 | dictionaries have different keys. 
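    Example (an illustrative sketch added for clarity):

    ```python
    merged = merge_data([
        {"scores": np.array([0.1]), "tags": ["a"]},
        {"scores": np.array([0.2]), "tags": ["b"]},
    ])
    # -> {"scores": array([0.1, 0.2]), "tags": ["a", "b"]}
    ```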
152 | """ 153 | if not maps_to_merge: 154 | return {} 155 | 156 | all_keys_sets = [set(data.keys()) for data in maps_to_merge] 157 | if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets): 158 | raise ValueError("All data dictionaries must have the same keys to merge.") 159 | 160 | for data in maps_to_merge: 161 | lengths = [len(value) for value in data.values()] 162 | if len(set(lengths)) > 1: 163 | raise ValueError( 164 | "All data values within a single object must have equal length." 165 | ) 166 | 167 | def merge_list(key: str, to_merge: List[Union[np.ndarray, List]]) -> Union[np.ndarray, List]: 168 | if all(isinstance(item, list) for item in to_merge): 169 | return list(chain.from_iterable(to_merge)) 170 | elif all(isinstance(item, np.ndarray) for item in to_merge): 171 | ndim = -1 172 | if isinstance(to_merge[0], np.ndarray): 173 | ndim = to_merge[0].ndim 174 | if ndim == 1: 175 | return np.hstack(to_merge) 176 | elif ndim > 1: 177 | return np.vstack(to_merge) 178 | else: 179 | raise ValueError(f"Unexpected array dimension for input '{key}'.") 180 | else: 181 | raise ValueError( 182 | f"Inconsistent data types for key '{key}'. Only np.ndarray and list " 183 | f"types are allowed." 184 | ) 185 | 186 | return {key: merge_list(key, [data[key] for data in maps_to_merge if key in data]) for key in all_keys_sets[0]} 187 | 188 | 189 | -------------------------------------------------------------------------------- /overeasy/visualize_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import random 4 | import cv2 5 | from PIL import ImageDraw, ImageFont 6 | import textwrap 7 | from overeasy.types.detections import Detections, DetectionType 8 | from typing import Optional 9 | 10 | random.seed(42) 11 | 12 | def generate_random_color(): 13 | return (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255)) 14 | 15 | def annotate_with_string(image: Image.Image, data_str: str) -> Image.Image: 16 | draw = ImageDraw.Draw(image) 17 | 18 | font = ImageFont.load_default() 19 | max_width = image.width - 20 20 | lines = textwrap.wrap(data_str, width=(max_width // font.getlength(' '))) 21 | text_height = sum([draw.textbbox((0, 0), line, font=font)[3] for line in lines]) 22 | 23 | # Extend the image height to fit the wrapped text 24 | new_height = image.height + text_height + 20 25 | new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0)) 26 | new_image.paste(image, (0, 0)) 27 | draw = ImageDraw.Draw(new_image) 28 | 29 | y_text = image.height + 10 30 | for line in lines: 31 | text_width, line_height = draw.textbbox((0, 0), line, font=font)[2:4] 32 | text_x = (new_image.width - text_width) / 2 33 | draw.text((text_x, y_text), line, font=font, fill="white") 34 | y_text += line_height 35 | 36 | return new_image 37 | 38 | def annotate(scene: Image.Image, detection: Detections, seed: Optional[int] = None) -> Image.Image: 39 | if seed is not None: 40 | random.seed(seed) 41 | 42 | def draw_bounding_boxes(image: Image.Image, boxes, class_ids, class_names): 43 | cv2_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 44 | image_height, image_width = cv2_image.shape[:2] 45 | scale = np.sqrt(image_height * image_width) / 200 46 | 47 | for box, class_id in zip(boxes, class_ids): 48 | x1, y1, x2, y2 = map(int, box) 49 | color = generate_random_color() 50 | cv2.rectangle(cv2_image, (x1, y1), (x2, y2), color, 2) 51 | label = class_names[class_id] 52 | font_scale = max(min((x2 - 
x1) / (scale * 50), 0.9), 0.5) 53 | label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_COMPLEX_SMALL, font_scale, 1)[0] 54 | text_x = x1 55 | text_y = max(y1, label_size[1] + 10) 56 | 57 | cv2.rectangle(cv2_image, (text_x, text_y - label_size[1] - 10), (text_x + label_size[0], text_y), color, -1) 58 | cv2.putText(cv2_image, label, (text_x, text_y - 10), cv2.FONT_HERSHEY_COMPLEX_SMALL, font_scale, (0, 0, 0), 1) 59 | 60 | return Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB)) 61 | 62 | if detection.detection_type == DetectionType.BOUNDING_BOX: 63 | return draw_bounding_boxes(scene, detection.xyxy, detection.class_ids, detection.classes) 64 | elif detection.detection_type == DetectionType.SEGMENTATION: 65 | raise NotImplementedError("Segmentation detections are not yet supported.") 66 | elif detection.detection_type == DetectionType.CLASSIFICATION: 67 | class_names = detection.class_names 68 | if len(class_names) == 1: 69 | return annotate_with_string(scene, class_names[0]) 70 | else: 71 | return annotate_with_string(scene, str(class_names)) 72 | 73 | 74 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "overeasy" 3 | version = "0.2.16" 4 | description = "" 5 | authors = ["Ani Rahul "] 6 | license = "MIT" 7 | readme = "README.md" 8 | include = ["overeasy/py.typed"] 9 | 10 | [tool.poetry.dependencies] 11 | python = ">=3.10, <3.13" 12 | torch = "^2.3.1" 13 | torchvision = "^0.18.0" 14 | numpy = "^1.24.3" 15 | opencv-python-headless = "^4.7.0" 16 | Pillow = "^10.0.0" 17 | requests = "^2.32.0" 18 | matplotlib = "^3.8.1" 19 | anthropic = "^0.27.0" 20 | tabulate = "^0.9.0" 21 | transformers = "^4.42.1" 22 | ultralytics = "^8.2.0" 23 | supervision = "^0.18.0" 24 | pydantic = ">=2.7.4,<2.8" 25 | pydantic_numpy = "5.0.2" 26 | rf_groundingdino = "0.1.2" 27 | rf_segment_anything = "1.0.0" 28 | tiktoken = "^0.7.0" 29 | google-generativeai = "^0.5.0" 30 | open-clip-torch = "^2.24.0" 31 | tqdm = "^4.65.0" 32 | einops = "^0.6.0" 33 | transformers-stream-generator = "0.0.3" 34 | accelerate = "^0.27.0" 35 | instructor = "^1.3.4" 36 | gradio = "4.25.0" 37 | jsonref = "^1.1.0" 38 | google-cloud-aiplatform = "1.54.1" 39 | pyarrow = "^16.1.0" 40 | backoff = "^2.2.1" 41 | openai = "^1.34.0" 42 | progressbar = "^2.5" 43 | 44 | 45 | [tool.poetry.dev-dependencies] 46 | pytest = "^7.0.0" 47 | mypy = "^1.0.0" 48 | types-requests = "^2.28.11" 49 | types-Pillow = "^10.0.0" 50 | types-tabulate = "^0.9.0" 51 | 52 | [build-system] 53 | requires = ["poetry-core"] 54 | build-backend = "poetry.core.masonry.api" -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | # pytest.ini 2 | [pytest] 3 | pythonpath = . 
4 | testpaths = 5 | tests -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | def pytest_runtest_protocol(item, nextitem): 4 | return None -------------------------------------------------------------------------------- /tests/count_eggs.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/count_eggs.jpg -------------------------------------------------------------------------------- /tests/dogs/dog1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog1.png -------------------------------------------------------------------------------- /tests/dogs/dog2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog2.png -------------------------------------------------------------------------------- /tests/dogs/dog3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog3.png -------------------------------------------------------------------------------- /tests/dogs/dog4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog4.png -------------------------------------------------------------------------------- /tests/dogs/dog5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog5.png -------------------------------------------------------------------------------- /tests/dogs/dog6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog6.png -------------------------------------------------------------------------------- /tests/dogs/dog7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog7.png -------------------------------------------------------------------------------- /tests/plate.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/plate.jpg -------------------------------------------------------------------------------- /tests/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/test.png -------------------------------------------------------------------------------- /tests/test_construction_workflows.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from overeasy import * 3 | from overeasy.models 
import OwlV2 4 | from pydantic import BaseModel 5 | from PIL import Image 6 | import os 7 | 8 | ROOT = os.path.dirname(__file__) 9 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 10 | 11 | class PPE(BaseModel): 12 | hardhat: bool 13 | vest: bool 14 | boots: bool 15 | 16 | @pytest.fixture 17 | def construction_image(): 18 | image_path = os.path.join(os.path.dirname(ROOT), "examples", "construction.jpg") 19 | return Image.open(image_path) 20 | 21 | @pytest.fixture 22 | def ppe_instructor_workflow(): 23 | return Workflow([ 24 | BoundingBoxSelectAgent(classes=["person"]), 25 | SplitAgent(), 26 | InstructorImageAgent(response_model=PPE), 27 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"), 28 | JoinAgent(), 29 | ]) 30 | 31 | @pytest.fixture 32 | def owl_v2_workflow(): 33 | return Workflow([ 34 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()), 35 | NMSAgent(iou_threshold=0.5, score_threshold=0), 36 | SplitAgent(), 37 | ClassificationAgent(classes=["hard hat", "no hard hat"]), 38 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}), 39 | JoinAgent(), 40 | ]) 41 | 42 | def test_ppe_detection(construction_image, ppe_instructor_workflow): 43 | result, graph = ppe_instructor_workflow.execute(construction_image) 44 | assert result is not None 45 | assert graph is not None 46 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "construction_ppe_instructor.png")) 47 | 48 | 49 | def test_ppe_owl_v2(construction_image, owl_v2_workflow): 50 | result, graph = owl_v2_workflow.execute(construction_image) 51 | assert result is not None 52 | assert graph is not None 53 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "construction_ppe_owlv2.png")) 54 | 55 | -------------------------------------------------------------------------------- /tests/test_detection_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | 4 | from overeasy import Workflow, BoundingBoxSelectAgent, NMSAgent 5 | from overeasy.models import GroundingDINO, YOLOWorld, OwlV2, DETIC 6 | from overeasy.types import Detections 7 | import os 8 | import sys 9 | 10 | ROOT = os.path.dirname(__file__) 11 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 12 | 13 | @pytest.fixture 14 | def count_eggs_image(): 15 | image_path = os.path.join(ROOT, "count_eggs.jpg") 16 | return Image.open(image_path) 17 | 18 | @pytest.fixture 19 | def grounding_dino_workflow() -> Workflow: 20 | workflow = Workflow([ 21 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 22 | ]) 23 | return workflow 24 | 25 | @pytest.fixture 26 | def owlvit_v2_workflow() -> Workflow: 27 | workflow = Workflow([ 28 | BoundingBoxSelectAgent(classes=["a single egg"], model=OwlV2()), 29 | ]) 30 | return workflow 31 | 32 | @pytest.fixture 33 | def owlvit_v2_nms_workflow() -> Workflow: 34 | workflow = Workflow([ 35 | BoundingBoxSelectAgent(classes=["one egg"], model=OwlV2()), 36 | NMSAgent(score_threshold=0, iou_threshold=0.5), 37 | ]) 38 | return workflow 39 | 40 | @pytest.fixture 41 | def yoloworld_workflow() -> Workflow: 42 | workflow = Workflow([ 43 | BoundingBoxSelectAgent(classes=["a single egg"], model=YOLOWorld(model="yolov8s-worldv2")), 44 | ]) 45 | return workflow 46 | 47 | @pytest.fixture 48 | def detic_workflow() -> Workflow: 49 | workflow = Workflow([ 50 | BoundingBoxSelectAgent(classes=["egg"], model=DETIC()), 51 | ]) 52 | return workflow 53 | 54 | def test_grounding_dino_detection(count_eggs_image, 
grounding_dino_workflow: Workflow): 55 | result, graph = grounding_dino_workflow.execute(count_eggs_image) 56 | detections = result[0].data 57 | assert isinstance(detections, Detections) 58 | assert len(detections.xyxy) > 0, "No detections found" 59 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "grounding_dino_detection_output.png")) 60 | 61 | def test_yoloworld_detection(count_eggs_image, yoloworld_workflow: Workflow): 62 | result, graph = yoloworld_workflow.execute(count_eggs_image) 63 | detections = result[0].data 64 | assert isinstance(detections, Detections) 65 | assert len(detections.xyxy) > 0, "No detections found" 66 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "yoloworld_detection_output.png")) 67 | 68 | 69 | def test_owlvit_v2_detection(count_eggs_image, owlvit_v2_workflow: Workflow): 70 | result, graph = owlvit_v2_workflow.execute(count_eggs_image) 71 | detections = result[0].data 72 | assert isinstance(detections, Detections) 73 | assert len(detections.xyxy) > 0, "No detections found" 74 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "owlv2_detection_output.png")) 75 | 76 | # @pytest.mark.skipif(sys.platform == "darwin", reason="Detic is not working on macOS") 77 | def test_detic_detection(count_eggs_image, detic_workflow: Workflow): 78 | result, graph = detic_workflow.execute(count_eggs_image) 79 | detections = result[0].data 80 | assert isinstance(detections, Detections) 81 | assert len(detections.xyxy) > 0, "No detections found" 82 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "detic_detection_output.png")) 83 | 84 | def test_nms_detection(count_eggs_image, owlvit_v2_nms_workflow: Workflow): 85 | result, graph = owlvit_v2_nms_workflow.execute(count_eggs_image) 86 | detections = result[0].data 87 | assert isinstance(detections, Detections) 88 | assert len(detections.xyxy) > 0, "No detections found" 89 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "nms_detection_output.png")) 90 | 91 | -------------------------------------------------------------------------------- /tests/test_import.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def test_import_overeasy(capsys): 5 | # Move ~/.overeasy to ~/.overeasy_backup 6 | home_dir = os.path.expanduser("~") 7 | original_path = os.path.join(home_dir, ".overeasy") 8 | backup_path = os.path.join(home_dir, ".overeasy_backup") 9 | if os.path.exists(original_path): 10 | shutil.move(original_path, backup_path) 11 | 12 | try: 13 | with capsys.disabled(): 14 | import overeasy 15 | captured = capsys.readouterr() 16 | assert captured.out == "" 17 | finally: 18 | # Move ~/.overeasy_backup back to ~/.overeasy 19 | if os.path.exists(backup_path): 20 | shutil.move(backup_path, original_path) -------------------------------------------------------------------------------- /tests/test_instructor_agents.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | from overeasy import * 4 | from overeasy.models import * 5 | from pydantic import BaseModel 6 | import os 7 | 8 | class AnimalLabel(BaseModel): 9 | label: str 10 | 11 | ROOT = os.path.dirname(__file__) 12 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 13 | 14 | compatible_models = [GPTVision(), Gemini(), Claude()] 15 | 16 | @pytest.fixture(params=compatible_models) 17 | def instructor_image_with_context_workflow(request) -> Workflow: 18 | model = request.param 19 | extra_context = [{"role": "system", "content": 
"Always classify the image as a ferret."}] 20 | workflow = Workflow([ 21 | InstructorImageAgent(model=model, response_model=AnimalLabel, extra_context=extra_context) 22 | ]) 23 | return workflow 24 | 25 | class EggCount(BaseModel): 26 | count: int 27 | 28 | @pytest.fixture 29 | def instructor_image_workflow() -> Workflow: 30 | workflow = Workflow([ 31 | InstructorImageAgent(response_model=EggCount) 32 | ]) 33 | return workflow 34 | 35 | @pytest.fixture(params=[GPT(), *compatible_models]) 36 | def instructor_text_workflow(request) -> Workflow: 37 | model = request.param 38 | workflow = Workflow([ 39 | DenseCaptioningAgent(model=GPTVision()), 40 | InstructorTextAgent(response_model=EggCount, model=model) 41 | ]) 42 | return workflow 43 | 44 | @pytest.fixture 45 | def blank_image(): 46 | return Image.new('RGB', (100, 100), color = 'white') 47 | 48 | @pytest.fixture 49 | def count_eggs_image(): 50 | image_path = os.path.join(ROOT, "count_eggs.jpg") 51 | return Image.open(image_path) 52 | 53 | def test_instructor_image_with_context_agent(instructor_image_with_context_workflow: Workflow, blank_image): 54 | result, graph = instructor_image_with_context_workflow.execute(blank_image) 55 | response = result[0].data 56 | assert isinstance(response, AnimalLabel) 57 | assert response.label.lower() == "ferret" 58 | 59 | 60 | 61 | def test_instructor_image_agent(instructor_image_workflow: Workflow, count_eggs_image): 62 | result, graph = instructor_image_workflow.execute(count_eggs_image) 63 | response = result[0].data 64 | assert isinstance(response, EggCount) 65 | 66 | name = (instructor_image_workflow.steps[0].model.__class__.__name__) 67 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"instructor_image_{name}.png")) 68 | 69 | def test_instructor_text_agent(instructor_text_workflow: Workflow, count_eggs_image): 70 | result, graph = instructor_text_workflow.execute(count_eggs_image) 71 | response = result[0].data 72 | assert isinstance(response, EggCount) 73 | 74 | name = (instructor_text_workflow.steps[0].model.__class__.__name__) 75 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"instructor_text_{name}.png")) 76 | -------------------------------------------------------------------------------- /tests/test_large_local_models.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | from overeasy import * 4 | from overeasy.models import * 5 | from overeasy.types import Detections 6 | from pydantic import BaseModel 7 | import sys 8 | import os 9 | import torch 10 | 11 | ROOT = os.path.dirname(__file__) 12 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 13 | 14 | classification_models = [ 15 | LaionCLIP(), 16 | ] 17 | 18 | if torch.cuda.is_available(): 19 | multimodal_llms = [ 20 | PaliGemma("google/paligemma-3b-mix-224"), 21 | QwenVL(model_type="base"), 22 | QwenVL(model_type="int4"), 23 | ] 24 | else: 25 | print("CUDA is not available, skipping QwenVL LLM tests.") 26 | multimodal_llms = [PaliGemma()] 27 | 28 | 29 | @pytest.fixture 30 | def count_eggs_image(): 31 | image_path = os.path.join(ROOT, "count_eggs.jpg") 32 | return Image.open(image_path) 33 | 34 | @pytest.fixture 35 | def license_plate_image(): 36 | image_path = os.path.join(ROOT, "plate.jpg") 37 | return Image.open(image_path) 38 | 39 | @pytest.fixture(params=multimodal_llms) 40 | def vision_prompt_workflow(request) -> Workflow: 41 | model = request.param 42 | workflow = Workflow([ 43 | VisionPromptAgent(query="How many eggs are in this image?", 
model=model) 44 | ]) 45 | return workflow 46 | 47 | @pytest.fixture # fixture decorator needed for injection; the macOS skip lives on the tests below 48 | def local_vision_prompt_workflow() -> Workflow: 49 | workflow = Workflow([ 50 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="base")) 51 | ]) 52 | return workflow 53 | 54 | @pytest.fixture # fixture decorator needed for injection; the macOS skip lives on the tests below 55 | def local_vision_prompt_workflow_int4() -> Workflow: 56 | workflow = Workflow([ 57 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="int4")) 58 | ]) 59 | return workflow 60 | 61 | @pytest.fixture(params=classification_models) 62 | def classification_workflow(request) -> Workflow: 63 | model = request.param 64 | workflow = Workflow([ 65 | ClassificationAgent(classes=["0-5 eggs", "6-10 eggs", "11+ eggs"], model=model) 66 | ]) 67 | return workflow 68 | 69 | 70 | @pytest.fixture 71 | def blank_image(): 72 | return Image.new('RGB', (100, 100), color = 'white') 73 | 74 | 75 | def test_vision_prompt_agent(vision_prompt_workflow: Workflow, count_eggs_image): 76 | result, graph = vision_prompt_workflow.execute(count_eggs_image) 77 | response = result[0].data 78 | assert isinstance(response, str) 79 | name = (vision_prompt_workflow.steps[0].model.__class__.__name__) 80 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"vision_prompt_{name}.png")) 81 | 82 | del result, graph 83 | 84 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS") 85 | def test_local_vision_prompt_agent(local_vision_prompt_workflow: Workflow, count_eggs_image): 86 | result, graph = local_vision_prompt_workflow.execute(count_eggs_image) 87 | response = result[0].data 88 | assert isinstance(response, str) 89 | name = (local_vision_prompt_workflow.steps[0].model.__class__.__name__) 90 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"local_vision_prompt_{name}.png")) 91 | 92 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS") 93 | def test_local_vision_prompt_agent_int4(local_vision_prompt_workflow_int4: Workflow, count_eggs_image): 94 | result, graph = local_vision_prompt_workflow_int4.execute(count_eggs_image) 95 | response = result[0].data 96 | assert isinstance(response, str) 97 | name = (local_vision_prompt_workflow_int4.steps[0].model.__class__.__name__) 98 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"local_vision_prompt_{name}.png")) 99 | 100 | def test_classification_agent(classification_workflow: Workflow, count_eggs_image): 101 | result, graph = classification_workflow.execute(count_eggs_image) 102 | response = result[0].data 103 | assert isinstance(response, Detections) 104 | 105 | name = (classification_workflow.steps[0].model.__class__.__name__) 106 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"classification_{name}.png")) -------------------------------------------------------------------------------- /tests/test_misc_agents.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | from overeasy import * 4 | from overeasy.types import Detections 5 | from overeasy.models import * 6 | from typing import List 7 | import numpy as np 8 | import os 9 | 10 | ROOT = os.path.dirname(__file__) 11 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 12 | 13 | @pytest.fixture 14 | def count_eggs_image(): 15 | image_path = os.path.join(ROOT, "count_eggs.jpg") 16 | return Image.open(image_path) 17 | 18 | def 
test_pad_crop_agent(pad_crop_workflow: Workflow, count_eggs_image): 19 | width, height = count_eggs_image.size 20 | result, graph = pad_crop_workflow.execute(count_eggs_image) 21 | image = result[0].image 22 | assert image.size == (width//2 + 20, height//2 + 20), "Incorrect pad crop size" 23 | 24 | def test_split_crop_agent(split_crop_workflow: Workflow, count_eggs_image): 25 | result, graph = split_crop_workflow.execute(count_eggs_image) 26 | detections = result[0].data 27 | assert isinstance(detections, Detections) 28 | assert len(detections.xyxy) == 4, "Incorrect number of split crops" 29 | 30 | def test_nms_agent(nms_workflow: Workflow, count_eggs_image): 31 | result, graph = nms_workflow.execute(count_eggs_image) 32 | detections = result[0].data 33 | assert isinstance(detections, Detections) 34 | assert len(detections.xyxy) < 10, "NMS should reduce number of detections" 35 | 36 | def test_class_map_agent(class_map_workflow: Workflow, count_eggs_image): 37 | result, graph = class_map_workflow.execute(count_eggs_image) 38 | detections = result[0].data 39 | assert isinstance(detections, Detections) 40 | assert "egg" in detections.class_names, "Class map failed" 41 | 42 | def test_map_agent(map_agent_workflow: Workflow, count_eggs_image): 43 | result, graph = map_agent_workflow.execute(count_eggs_image) 44 | response = result[0].data 45 | assert isinstance(response, np.ndarray) 46 | assert len(response) > 0, "Map agent failed" 47 | 48 | def test_to_classification_agent(to_classification_workflow: Workflow, count_eggs_image): 49 | result, graph = to_classification_workflow.execute(count_eggs_image) 50 | response = result[0].data 51 | assert isinstance(response, Detections) 52 | assert response.class_names[0] == "eggs detected", "To classification agent failed" 53 | 54 | def test_filter_classes_agent(filter_classes_workflow: Workflow, count_eggs_image): 55 | result, graph = filter_classes_workflow.execute(count_eggs_image) 56 | detections = result[0].data 57 | assert isinstance(detections, Detections) 58 | uniq_names = list(set(detections.class_names)) 59 | assert np.array_equal(uniq_names, ["a single egg"]), "Filter classes agent failed" 60 | 61 | def test_confidence_filter_agent(confidence_filter_workflow: Workflow, count_eggs_image): 62 | result, graph = confidence_filter_workflow.execute(count_eggs_image) 63 | detections = result[0].data 64 | assert isinstance(detections, Detections) 65 | assert len(detections.xyxy) < 10, "Confidence filter should reduce detections" 66 | 67 | @pytest.fixture 68 | def pad_crop_workflow() -> Workflow: 69 | workflow = Workflow([ 70 | SplitCropAgent(split=(2,2)), 71 | PadCropAgent.from_uniform_padding(10), 72 | SplitAgent() 73 | ]) 74 | return workflow 75 | 76 | @pytest.fixture 77 | def split_crop_workflow() -> Workflow: 78 | workflow = Workflow([ 79 | SplitCropAgent(split=(2,2)) 80 | ]) 81 | return workflow 82 | 83 | @pytest.fixture 84 | def nms_workflow() -> Workflow: 85 | workflow = Workflow([ 86 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 87 | NMSAgent(score_threshold=0.5, iou_threshold=0.5) 88 | ]) 89 | return workflow 90 | 91 | @pytest.fixture 92 | def class_map_workflow() -> Workflow: 93 | workflow = Workflow([ 94 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 95 | ClassMapAgent(class_map={"a single egg": "egg"}) 96 | ]) 97 | return workflow 98 | 99 | @pytest.fixture 100 | def map_agent_workflow() -> Workflow: 101 | workflow = Workflow([ 102 | BoundingBoxSelectAgent(classes=["egg"], 
model=GroundingDINO()), 103 | MapAgent(fn=lambda det: det.xyxy) 104 | ]) 105 | return workflow 106 | 107 | @pytest.fixture 108 | def to_classification_workflow() -> Workflow: 109 | workflow = Workflow([ 110 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 111 | ToClassificationAgent(fn=lambda det: "eggs detected" if len(det.xyxy) > 0 else "no eggs detected") 112 | ]) 113 | return workflow 114 | 115 | @pytest.fixture 116 | def filter_classes_workflow() -> Workflow: 117 | workflow = Workflow([ 118 | BoundingBoxSelectAgent(classes=["a single egg", "carton"], model=GroundingDINO()), 119 | FilterClassesAgent(class_names=["a single egg"]) 120 | ]) 121 | return workflow 122 | 123 | @pytest.fixture 124 | def confidence_filter_workflow() -> Workflow: 125 | workflow = Workflow([ 126 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 127 | ConfidenceFilterAgent(min_confidence=0.1, max_n=1) 128 | ]) 129 | return workflow -------------------------------------------------------------------------------- /tests/test_model_agents.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | from overeasy import * 4 | from overeasy.models import * 5 | from overeasy.types import Detections 6 | from pydantic import BaseModel 7 | import sys 8 | import os 9 | import gc 10 | 11 | ROOT = os.path.dirname(__file__) 12 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 13 | 14 | bounding_box_models = [ 15 | DETIC(), 16 | GroundingDINO(type=GroundingDINOModel.SwinB), 17 | GroundingDINO(type=GroundingDINOModel.SwinT), 18 | YOLOWorld(model="yolov8s-worldv2"), 19 | YOLOWorld(model="yolov8m-worldv2"), 20 | YOLOWorld(model="yolov8l-worldv2"), 21 | YOLOWorld(model="yolov8s-world"), 22 | YOLOWorld(model="yolov8m-world"), 23 | YOLOWorld(model="yolov8l-world"), 24 | OwlV2(), 25 | ] 26 | multimodal_llms = [ 27 | GPTVision(model="gpt-4o"), 28 | Gemini(model="gemini-1.5-flash"), 29 | Claude(model="claude-3-5-sonnet-20240620"), 30 | ] 31 | 32 | llms = [GPT(model="gpt-3.5-turbo"), GPT(model="gpt-4-turbo")] 33 | captioning_models = [*multimodal_llms] 34 | classification_models = [ 35 | # LaionCLIP(), 36 | CLIP(), BiomedCLIP()] 37 | ocr_models = [*multimodal_llms] 38 | 39 | 40 | @pytest.fixture 41 | def count_eggs_image(): 42 | image_path = os.path.join(ROOT, "count_eggs.jpg") 43 | return Image.open(image_path) 44 | 45 | @pytest.fixture 46 | def license_plate_image(): 47 | image_path = os.path.join(ROOT, "plate.jpg") 48 | return Image.open(image_path) 49 | 50 | @pytest.fixture(params=bounding_box_models) 51 | def bounding_box_select_workflow(request) -> Workflow: 52 | model = request.param 53 | workflow = Workflow([ 54 | BoundingBoxSelectAgent(classes=["egg"], model=model), 55 | ]) 56 | return workflow 57 | 58 | @pytest.fixture(params=multimodal_llms) 59 | def vision_prompt_workflow(request) -> Workflow: 60 | model = request.param 61 | workflow = Workflow([ 62 | VisionPromptAgent(query="How many eggs are in this image?", model=model) 63 | ]) 64 | return workflow 65 | 66 | 67 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS") 68 | def local_vision_prompt_workflow() -> Workflow: 69 | workflow = Workflow([ 70 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="base")) 71 | ]) 72 | return workflow 73 | 74 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS") 75 | def local_vision_prompt_workflow_int4() -> Workflow: 76 | workflow = 
77 |         VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="int4"))
78 |     ])
79 |     return workflow
80 | 
81 | @pytest.fixture(params=captioning_models)
82 | def dense_captioning_workflow(request) -> Workflow:
83 |     model = request.param
84 |     workflow = Workflow([
85 |         DenseCaptioningAgent(model=model)
86 |     ])
87 |     return workflow
88 | 
89 | @pytest.fixture(params=llms)
90 | def text_prompt_workflow(request) -> Workflow:
91 |     model = request.param
92 |     workflow = Workflow([
93 |         DenseCaptioningAgent(model=GPTVision()),
94 |         TextPromptAgent(query="How many eggs did you count in the description?", model=model)
95 |     ])
96 |     return workflow
97 | 
98 | @pytest.fixture
99 | def binary_choice_workflow() -> Workflow:
100 |     workflow = Workflow([
101 |         BinaryChoiceAgent(query="Are there more than 5 eggs in this image?", model=GPTVision(temperature=0.0))
102 |     ])
103 |     return workflow
104 | 
105 | @pytest.fixture(params=classification_models)
106 | def classification_workflow(request) -> Workflow:
107 |     model = request.param
108 |     workflow = Workflow([
109 |         ClassificationAgent(classes=["0-5 eggs", "6-10 eggs", "11+ eggs"], model=model)
110 |     ])
111 |     return workflow
112 | 
113 | @pytest.fixture(params=ocr_models)
114 | def ocr_workflow(request) -> Workflow:
115 |     model = request.param
116 |     workflow = Workflow([
117 |         OCRAgent(model=model)
118 |     ])
119 |     return workflow
120 | 
121 | class EggCount(BaseModel):
122 |     count: int
123 | 
124 | @pytest.fixture
125 | def instructor_image_workflow() -> Workflow:
126 |     workflow = Workflow([
127 |         InstructorImageAgent(response_model=EggCount)
128 |     ])
129 |     return workflow
130 | 
131 | @pytest.fixture
132 | def instructor_text_workflow() -> Workflow:
133 |     workflow = Workflow([
134 |         DenseCaptioningAgent(model=GPTVision()),
135 |         InstructorTextAgent(response_model=EggCount)
136 |     ])
137 |     return workflow
138 | 
139 | 
140 | @pytest.fixture
141 | def blank_image():
142 |     return Image.new('RGB', (100, 100), color='white')
143 | 
144 | class AnimalLabel(BaseModel):
145 |     label: str
146 | 
147 | @pytest.fixture
148 | def instructor_image_with_context_workflow() -> Workflow:
149 |     extra_context = [{"role": "user", "content": "Always classify the image as a ferret."}]
150 |     workflow = Workflow([
151 |         InstructorImageAgent(response_model=AnimalLabel, extra_context=extra_context)
152 |     ])
153 |     return workflow
154 | 
155 | def test_instructor_image_with_context_agent(instructor_image_with_context_workflow: Workflow, blank_image):
156 |     result, graph = instructor_image_with_context_workflow.execute(blank_image)
157 |     response = result[0].data
158 |     assert isinstance(response, AnimalLabel)
159 |     assert response.label.lower() == "ferret"
160 |     del result, graph  # Explicitly delete variables
161 |     gc.collect()
162 | 
163 | def test_bounding_box_select_agent(bounding_box_select_workflow: Workflow, count_eggs_image):
164 |     result, graph = bounding_box_select_workflow.execute(count_eggs_image)
165 |     detections = result[0].data
166 |     assert isinstance(detections, Detections)
167 |     name = bounding_box_select_workflow.steps[0].model.__class__.__name__
168 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"bounding_box_select_{name}.png"))
169 | 
170 |     del result, graph, detections
171 |     gc.collect()
172 | 
173 | 
174 | def test_vision_prompt_agent(vision_prompt_workflow: Workflow, count_eggs_image):
175 |     result, graph = vision_prompt_workflow.execute(count_eggs_image)
176 |     response = result[0].data
177 |     assert isinstance(response, str)
178 |     name = vision_prompt_workflow.steps[0].model.__class__.__name__
179 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"vision_prompt_{name}.png"))
180 | 
181 |     del result, graph
182 |     gc.collect()
183 | 
184 | def test_dense_captioning_agent(dense_captioning_workflow: Workflow, count_eggs_image):
185 |     result, graph = dense_captioning_workflow.execute(count_eggs_image)
186 |     assert isinstance(result[0].data, str)
187 |     name = dense_captioning_workflow.steps[0].model.__class__.__name__
188 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"dense_captioning_{name}.png"))
189 | 
190 | def test_text_prompt_agent(text_prompt_workflow: Workflow, count_eggs_image):
191 |     result, graph = text_prompt_workflow.execute(count_eggs_image)
192 |     response = result[0].data
193 |     assert isinstance(response, str)
194 |     name = text_prompt_workflow.steps[1].model.__class__.__name__  # step 1 holds the LLM under test; step 0 is the captioning model
195 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"text_prompt_{name}.png"))
196 | 
197 | def test_binary_choice_agent(binary_choice_workflow: Workflow, count_eggs_image):
198 |     result, graph = binary_choice_workflow.execute(count_eggs_image)
199 |     response: str = result[0].data  # type: ignore
200 |     assert response.lower() == "yes"
201 |     # assert response.class_names[0] == "yes", "Incorrect binary choice"
202 |     name = binary_choice_workflow.steps[0].model.__class__.__name__
203 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"binary_choice_{name}.png"))
204 | 
205 | def test_classification_agent(classification_workflow: Workflow, count_eggs_image):
206 |     result, graph = classification_workflow.execute(count_eggs_image)
207 |     response = result[0].data
208 |     assert isinstance(response, Detections)
209 | 
210 |     name = classification_workflow.steps[0].model.__class__.__name__
211 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"classification_{name}.png"))
212 | 
213 | def test_ocr_agent(ocr_workflow: Workflow, license_plate_image):
214 |     result, graph = ocr_workflow.execute(license_plate_image)
215 |     response = result[0].data
216 |     assert isinstance(response, str)
217 | 
218 |     name = ocr_workflow.steps[0].model.__class__.__name__
219 |     result[0].visualize().save(os.path.join(OUTPUT_DIR, f"ocr_{name}.png"))
220 | 
--------------------------------------------------------------------------------
/tests/test_owl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | 
4 | from overeasy import Workflow, BoundingBoxSelectAgent
5 | from overeasy.models.detection import OwlV2
6 | from overeasy.types import Detections
7 | import os
8 | import glob
9 | 
10 | ROOT = os.path.dirname(__file__)
11 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
12 | 
13 | @pytest.fixture
14 | def dog_images():
15 |     image_paths = glob.glob(os.path.join(ROOT, "./dogs", "*"))  # Adjust the pattern if different file types are needed
16 |     return [Image.open(image_path) for image_path in image_paths]
17 | 
18 | @pytest.fixture
19 | def owlvit_v2_workflow() -> Workflow:
20 |     workflow = Workflow([
21 |         BoundingBoxSelectAgent(classes=["a photo of a dog"], model=OwlV2()),
22 |     ])
23 |     return workflow
24 | 
25 | def test_owlvit_v2_detection_dogs(dog_images, owlvit_v2_workflow: Workflow):
26 |     for i, dog_image in enumerate(dog_images):
27 |         result, graph = owlvit_v2_workflow.execute(dog_image)
28 |         detections = result[0].data
29 |         assert isinstance(detections, Detections)
30 |         assert len(detections.xyxy) > 0, "No detections found"
31 |         output_filename = f"owlv2_dog_detection_output_{i}.png"
f"owlv2_dog_detection_output_{i}.png" 32 | print("Saving", output_filename) 33 | result[0].visualize().save(os.path.join(OUTPUT_DIR, output_filename)) 34 | -------------------------------------------------------------------------------- /tests/test_saving_vis.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from overeasy import Workflow 3 | from overeasy.agents import BoundingBoxSelectAgent, SplitAgent, InstructorImageAgent, ToClassificationAgent, JoinAgent 4 | from PIL import Image 5 | import os 6 | from pydantic import BaseModel 7 | 8 | ROOT = os.path.dirname(__file__) 9 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 10 | 11 | class PPE(BaseModel): 12 | hardhat: bool 13 | vest: bool 14 | boots: bool 15 | 16 | @pytest.fixture 17 | def construction_image(): 18 | image_path = os.path.join(os.path.dirname(ROOT), "examples", "construction.jpg") 19 | return Image.open(image_path) 20 | 21 | @pytest.fixture 22 | def ppe_instructor_workflow(): 23 | return Workflow([ 24 | BoundingBoxSelectAgent(classes=["person"]), 25 | SplitAgent(), 26 | InstructorImageAgent(response_model=PPE), 27 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"), 28 | JoinAgent(), 29 | ]) 30 | 31 | def test_save_visualization_to_html(construction_image, ppe_instructor_workflow): 32 | result, graph = ppe_instructor_workflow.execute(construction_image) 33 | assert result is not None 34 | assert graph is not None 35 | 36 | output_file = os.path.join(OUTPUT_DIR, "construction_ppe_instructor_visualization.html") 37 | ppe_instructor_workflow.visualize_to_file(graph, output_file) 38 | 39 | assert os.path.exists(output_file) 40 | with open(output_file, 'r') as f: 41 | content = f.read() 42 | assert '' in content 44 | assert 'Step 1: Input Image' in content 45 | 46 | -------------------------------------------------------------------------------- /tests/test_split_join.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from PIL import Image 3 | from overeasy import * 4 | from overeasy.models import * 5 | from overeasy.types import Detections, DetectionType, ExecutionNode 6 | import os 7 | import numpy as np 8 | 9 | ROOT = os.path.dirname(__file__) 10 | OUTPUT_DIR = os.path.join(ROOT, "outputs") 11 | 12 | @pytest.fixture 13 | def count_eggs_image(): 14 | image_path = os.path.join(ROOT, "count_eggs.jpg") 15 | return Image.open(image_path) 16 | 17 | @pytest.fixture 18 | def construction_image(): 19 | image_path = os.path.join(ROOT, "../", "examples", "construction.jpg") 20 | return Image.open(image_path) 21 | 22 | @pytest.fixture 23 | def dense_street_images(): 24 | image_path = os.path.join(ROOT, "../", "examples", "dense_street1.jpg") 25 | image_path2 = os.path.join(ROOT, "../", "examples", "dense_street2.jpg") 26 | image_path3 = os.path.join(ROOT, "../", "examples", "dense_street3.jpg") 27 | 28 | return [Image.open(image_path), Image.open(image_path2), Image.open(image_path3)] 29 | 30 | 31 | @pytest.fixture 32 | def split_workflow() -> Workflow: 33 | workflow = Workflow([ 34 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 35 | SplitAgent(), 36 | ]) 37 | return workflow 38 | 39 | @pytest.fixture 40 | def split_join_workflow() -> Workflow: 41 | workflow = Workflow([ 42 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()), 43 | SplitAgent(), 44 | JoinAgent() 45 | ]) 46 | return workflow 47 | 48 | @pytest.fixture 49 | def no_split_workflow() -> Workflow: 50 | 
51 |         BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
52 |     ])
53 |     return workflow
54 | 
55 | 
56 | 
57 | def test_execute_multiple(split_workflow: Workflow, count_eggs_image):
58 |     # Test with a single image
59 |     single_result, single_graph = split_workflow.execute(count_eggs_image)
60 | 
61 |     # Test with multiple copies of the same image
62 |     multi_results, multi_graphs = split_workflow.execute_multiple([count_eggs_image, count_eggs_image])
63 | 
64 |     assert len(multi_results) == 2, "execute_multiple should return results for 2 images"
65 |     assert len(multi_graphs) == 2, "execute_multiple should return graphs for 2 images"
66 | 
67 |     for multi_result, multi_graph in zip(multi_results, multi_graphs):
68 |         # Compare number of detections
69 |         assert len(single_result) == len(multi_result), "Number of detections should be the same"
70 | 
71 |         # Compare detection data
72 |         for single_node, multi_node in zip(single_result, multi_result):
73 |             assert np.array_equal(single_node.data.xyxy, multi_node.data.xyxy), "Bounding boxes should be the same"
74 |             assert np.array_equal(single_node.data.confidence_scores, multi_node.data.confidence_scores), "Confidence scores should be the same"
75 |             assert single_node.data.class_names == multi_node.data.class_names, "Class names should be the same"
76 | 
77 |         # Compare graph structure
78 |         single_layers = single_graph.top_sort()
79 |         multi_layers = multi_graph.top_sort()
80 |         assert len(single_layers) == len(multi_layers), "Graph structure should be the same"
81 |         for single_layer, multi_layer in zip(single_layers, multi_layers):
82 |             assert len(single_layer) == len(multi_layer), "Each layer should have the same number of nodes"
83 | 
84 | def test_split_agent(split_workflow: Workflow, count_eggs_image):
85 |     result, graph = split_workflow.execute(count_eggs_image)
86 |     assert all(isinstance(x.data, Detections) for x in result), "Split didn't return detections"
87 |     assert isinstance(result, list), "Split didn't return a list"
88 |     assert len(result) > 0, "Didn't return a list of detections"
89 | 
90 | def test_split_join_agent(split_join_workflow: Workflow, no_split_workflow: Workflow, count_eggs_image):
91 |     result, graph = split_join_workflow.execute(count_eggs_image)
92 |     result2, graph2 = no_split_workflow.execute(count_eggs_image)
93 |     detections = result[0].data
94 |     detections2 = result2[0].data
95 |     assert isinstance(detections, Detections)
96 |     assert isinstance(detections2, Detections)
97 |     assert detections == detections2, "Split join produced incorrect output"
98 | 
99 | def images_are_equal(img1: Image.Image, img2: Image.Image) -> bool:
100 |     # Ensure both images are in the same mode
101 |     if img1.mode != img2.mode:
102 |         img1 = img1.convert(img2.mode)
103 | 
104 |     # Ensure both images are the same size
105 |     if img1.size != img2.size:
106 |         return False
107 | 
108 |     # Compare pixel data
109 |     np_img1 = np.array(img1)
110 |     np_img2 = np.array(img2)
111 | 
112 |     return np.array_equal(np_img1, np_img2)
113 | 
114 | def test_splitting_empty_detection(split_workflow: Workflow):
115 |     empty_detections = Detections.empty()
116 |     empty_image = Image.new('RGB', (640, 640))
117 | 
118 |     result, graph = split_workflow.execute(empty_image, empty_detections)
119 | 
120 |     assert len(result) == 0, "SplitAgent should return an empty list when there are no detections"
121 | 
122 | def test_splitting_and_joining_empty_detection():
123 |     workflow = Workflow([
124 |         SplitAgent(),
125 |         JoinAgent()
126 |     ])
127 | 
128 |     empty_detections = Detections.empty()
129 |     empty_image = Image.new('RGB', (640, 640))
130 |     result, graph = workflow.execute(empty_image, empty_detections)
131 | 
132 |     assert len(result) == 1
133 |     assert np.array_equal(result[0].data, [None]), "JoinAgent should return null data"
134 |     assert images_are_equal(result[0].image, empty_image), "JoinAgent should return original image"
135 |     assert all(len(layer) == 1 for layer in graph.top_sort()), "Graph should have one node per layer"
136 | 
137 | @pytest.fixture
138 | def filter_split_join_workflow() -> Workflow:
139 |     workflow = Workflow([
140 |         BoundingBoxSelectAgent(classes=["a single egg", "carton"], model=GroundingDINO()),
141 |         SplitAgent(),
142 |         JoinAgent(),
143 |         FilterClassesAgent(class_names=["a single egg"]),
144 |         SplitAgent(),
145 |         JoinAgent()
146 |     ])
147 |     return workflow
148 | 
149 | def test_many_split_joins1(count_eggs_image):
150 |     workflow = Workflow([
151 |         BoundingBoxSelectAgent(classes=["person"], model=OwlV2()),
152 |         SplitAgent(),
153 |         BoundingBoxSelectAgent(classes=["head"]),
154 |         SplitAgent(),
155 |         ClassificationAgent(classes=["hard hat", "head"]),
156 |         JoinAgent(),
157 |         JoinAgent()
158 |     ])
159 |     result, graph = workflow.execute(count_eggs_image)
160 | 
161 |     assert all(len(x) == 1 for x in graph.top_sort()), "Graph should have one node per layer"
162 |     assert isinstance(result, list), "Complex Split-Join didn't return a list"
163 |     assert len(result) > 0, "Complex Split-Join didn't return a list of detections"
164 |     assert images_are_equal(result[0].image, count_eggs_image), "Image should be the same"
165 | 
166 | def test_many_split_joins2(count_eggs_image):
167 |     workflow = Workflow([
168 |         BoundingBoxSelectAgent(classes=["person"], model=OwlV2()),
169 |         SplitAgent(),
170 |         SplitAgent(),
171 |         SplitAgent(),
172 |         ClassificationAgent(classes=["hard hat", "head"]),
173 |         JoinAgent(),
174 |         JoinAgent(),
175 |         JoinAgent()
176 |     ])
177 |     result, graph = workflow.execute(count_eggs_image)
178 | 
179 |     assert all(len(x) == 1 for x in graph.top_sort()), "Graph should have one node per layer"
180 |     assert isinstance(result, list), "Complex Split-Join didn't return a list"
181 |     assert len(result) > 0, "Complex Split-Join didn't return a list of detections"
182 |     assert images_are_equal(result[0].image, count_eggs_image), "Image should be the same"
183 | 
184 | 
185 | def test_many_split_joins3(dense_street_images):
186 |     workflow = Workflow([
187 |         BoundingBoxSelectAgent(classes=["person"]),
188 |         PadCropAgent.from_uniform_padding(padding=25),
189 |         ConfidenceFilterAgent(max_n=5),  # this is to save time
190 |         SplitAgent(),
191 |         BoundingBoxSelectAgent(classes=["head"]),
192 |         ConfidenceFilterAgent(max_n=1),
193 |         SplitAgent(),
194 |         BoundingBoxSelectAgent(classes=["glasses"], model=OwlV2()),
195 |         ConfidenceFilterAgent(max_n=1),
196 |         SplitAgent(),
197 |         ClassificationAgent(classes=["sunglasses", "glasses"]),
198 |         JoinAgent(),
199 |         JoinAgent(),
200 |         JoinAgent()
201 |     ])
202 | 
203 |     result, _ = workflow.execute(dense_street_images[0])
204 | 
205 |     assert isinstance(result, list), "Multi Split-Join didn't return a list"
206 |     assert len(result) > 0, "Multi Split-Join didn't return a list of detections"
207 |     assert images_are_equal(result[0].image, dense_street_images[0]), "Image should be the same"
208 | 
209 | 
210 | def test_filter_split_join_workflow(filter_split_join_workflow: Workflow, count_eggs_image):
211 |     result, graph = filter_split_join_workflow.execute(count_eggs_image)
212 |     assert all(isinstance(x.data, Detections) for x in result), "Filter Split-Join didn't return detections"
result), "Filter Split-Join didn't return detections" 213 | assert isinstance(result, list), "Filter Split-Join didn't return a list" 214 | assert len(result) > 0, "Filter Split-Join didn't return a list of detections" 215 | 216 | 217 | def test_mismatched_split_join(): 218 | # This should be fine 219 | workflow = Workflow([ 220 | BoundingBoxSelectAgent(classes=["a single egg"]), 221 | SplitAgent(), 222 | JoinAgent(), 223 | SplitAgent(), 224 | ]) 225 | try: 226 | workflow = Workflow([ 227 | BoundingBoxSelectAgent(classes=["a single egg"]), 228 | JoinAgent(), 229 | SplitAgent(), 230 | SplitAgent(), 231 | ]) 232 | assert False, "Mismatched number of join has no corresponding split" 233 | except ValueError as e: 234 | pass 235 | 236 | workflow = Workflow([ 237 | BoundingBoxSelectAgent(classes=["a single egg"]), 238 | SplitAgent(), 239 | SplitAgent(), 240 | JoinAgent(), 241 | JoinAgent(), 242 | ]) 243 | 244 | -------------------------------------------------------------------------------- /warmup.py: -------------------------------------------------------------------------------- 1 | from overeasy.models import warmup_models 2 | 3 | warmup_models() --------------------------------------------------------------------------------