├── .github
│   ├── ISSUE_TEMPLATE
│   │   └── bug_report.md
│   └── workflows
│       ├── test.yml
│       └── wheels.yml
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── graph-diagram-dark.png
│   └── graph-diagram.png
├── example.png
├── examples
│   ├── basketball.png
│   ├── basketball_workflow.py
│   ├── construction.jpg
│   ├── construction_workflow1.py
│   ├── construction_workflow2.py
│   ├── construction_workflow3.py
│   ├── dense_street1.jpg
│   ├── dense_street2.jpg
│   ├── dense_street3.jpg
│   ├── drivers_license.jpg
│   ├── drivers_license_workflow.py
│   ├── eiffel_photo.png
│   ├── eiffel_photo_workflow.py
│   ├── kites.jpg
│   ├── kites_workflow.py
│   └── kites_workflow_local.py
├── gradio_example.html
├── kites_workflow.html
├── mypy.ini
├── overeasy
│   ├── __init__.py
│   ├── agents
│   │   ├── __init__.py
│   │   ├── dark_mode.txt
│   │   ├── misc
│   │   │   ├── __init__.py
│   │   │   ├── class_map.py
│   │   │   ├── filter_class.py
│   │   │   ├── map_agent.py
│   │   │   ├── max_confidence.py
│   │   │   ├── min_crop.py
│   │   │   ├── nms.py
│   │   │   ├── pad_crop.py
│   │   │   ├── split_crop.py
│   │   │   └── to_classification.py
│   │   ├── model_agents.py
│   │   ├── scrape.py
│   │   ├── split_join_agent.py
│   │   └── workflow.py
│   ├── assets
│   │   ├── egg_logo.png
│   │   └── favicon.ico
│   ├── dirs.py
│   ├── download_utils.py
│   ├── logging.py
│   ├── models
│   │   ├── LLMs
│   │   │   ├── __init__.py
│   │   │   ├── anthropic.py
│   │   │   ├── gemini.py
│   │   │   ├── openai.py
│   │   │   ├── pali_gemma.py
│   │   │   └── qwenvl.py
│   │   ├── __init__.py
│   │   ├── classification
│   │   │   ├── __init__.py
│   │   │   ├── clip.py
│   │   │   ├── laion_clip.py
│   │   │   └── siglip.py
│   │   ├── detection
│   │   │   ├── __init__.py
│   │   │   ├── detclipv2.py
│   │   │   ├── detic.py
│   │   │   ├── dino.py
│   │   │   ├── owlv2.py
│   │   │   └── yoloworld.py
│   │   └── recognition
│   │       └── __init__.py
│   ├── py.typed
│   ├── types
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── detections.py
│   │   └── type_utils.py
│   └── visualize_utils.py
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── tests
│   ├── conftest.py
│   ├── count_eggs.jpg
│   ├── dogs
│   │   ├── dog1.png
│   │   ├── dog2.png
│   │   ├── dog3.png
│   │   ├── dog4.png
│   │   ├── dog5.png
│   │   ├── dog6.png
│   │   └── dog7.png
│   ├── plate.jpg
│   ├── test.png
│   ├── test_construction_workflows.py
│   ├── test_detection_models.py
│   ├── test_import.py
│   ├── test_instructor_agents.py
│   ├── test_large_local_models.py
│   ├── test_misc_agents.py
│   ├── test_model_agents.py
│   ├── test_owl.py
│   ├── test_saving_vis.py
│   └── test_split_join.py
└── warmup.py
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: ''
5 | labels: ''
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior:
15 | 1. Go to '...'
16 | 2. Click on '....'
17 | 3. Scroll down to '....'
18 | 4. See error
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Screenshots**
24 | If applicable, add screenshots to help explain your problem.
25 |
26 | **Debugging Info**
27 | Please include all of the following:
28 | - OS: [e.g. Mac, Linux, Windows]
29 | - Python Version
30 | - CUDA Version (nvcc --version)
31 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests
2 | on:
3 | workflow_dispatch: # Only manual triggering
4 |
5 | jobs:
6 | test:
7 | runs-on: ${{ matrix.os }}
8 | strategy:
9 | matrix:
10 | os: [ubuntu-latest, macos-latest]
11 | python-version: ['3.10', '3.11', '3.12']
12 | include:
13 | - os: ubuntu-latest
14 | architecture: x64
15 | - os: macos-latest
16 | architecture: x64
17 | - os: macos-latest
18 | architecture: arm64
19 | steps:
20 | - name: Check out the repository
21 | uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | architecture: ${{ matrix.architecture }}
27 | - name: Install Poetry
28 | run: |
29 | curl -sSL https://install.python-poetry.org | python -
30 | - name: Install dependencies
31 | run: |
32 | poetry install
33 | - name: Run tests
34 | run: |
35 | poetry run pytest
--------------------------------------------------------------------------------
/.github/workflows/wheels.yml:
--------------------------------------------------------------------------------
1 | name: Build and upload to PyPI
2 |
3 | on:
4 | workflow_dispatch:
5 | release:
6 | types:
7 | - published
8 |
9 | jobs:
10 | build_wheels:
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/checkout@v4
14 |
15 | - name: Set up Python
16 | uses: actions/setup-python@v3
17 | with:
18 | python-version: '3.9' # or any version you're targeting
19 |
20 | - name: Build wheel
21 | run: |
22 | python -m pip install build
23 | python -m build --wheel --outdir wheelhouse
24 |
25 | - uses: actions/upload-artifact@v4
26 | with:
27 |           name: cibw-wheels
28 | path: ./wheelhouse/*.whl
29 |
30 | build_sdist:
31 | name: Build source distribution
32 | runs-on: ubuntu-latest
33 | steps:
34 | - uses: actions/checkout@v4
35 |
36 | - name: Build sdist
37 | run: pipx run build --sdist
38 |
39 | - uses: actions/upload-artifact@v4
40 | with:
41 | name: cibw-sdist
42 | path: dist/*.tar.gz
43 |
44 | upload_pypi:
45 | needs: [build_wheels, build_sdist]
46 | runs-on: ubuntu-latest
47 | environment: prod
48 | permissions:
49 | id-token: write
50 | if: github.event_name == 'release' && github.event.action == 'published'
51 | # or, alternatively, upload to PyPI on every tag starting with 'v' (remove on: release above to use this)
52 | # if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
53 | steps:
54 | - uses: actions/download-artifact@v4
55 | with:
56 | # unpacks all CIBW artifacts into dist/
57 | pattern: cibw-*
58 | path: dist
59 | merge-multiple: true
60 |
61 | - uses: pypa/gh-action-pypi-publish@release/v1
62 | # To test: repository-url: https://test.pypi.org/legacy/
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | tests/outputs
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 |
7 | *.DS_Store
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # pdm
107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
108 | #pdm.lock
109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
110 | # in version control.
111 | # https://pdm.fming.dev/#use-with-ide
112 | .pdm.toml
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Overeasy
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🥚 Overeasy
2 |
3 | Create powerful zero-shot vision models!
4 |
17 | Overeasy allows you to chain zero-shot vision models to create custom end-to-end pipelines for tasks like:
18 |
19 | - 📦 Bounding Box Detection
20 | - 🏷️ Classification
21 | - 🖌️ Segmentation (Coming Soon!)
22 |
23 | All of this can be achieved without needing to collect and annotate large training datasets.
24 |
25 | Overeasy makes it simple to combine pre-trained zero-shot models to build powerful custom computer vision solutions.
26 |
27 |
28 | ## Installation
29 | It's as easy as
30 | ```bash
31 | pip install overeasy
32 | ```
33 |
34 | For installing extras refer to our [Docs](https://docs.overeasy.sh/installation/installing-extras).
35 |
36 | ## Key Features
37 | - `🤖 Agents`: Specialized tools that perform specific image processing tasks.
38 | - `🧩 Workflows`: Define a sequence of Agents to process images in a structured manner.
39 | - `🔗 Execution Graphs`: Manage and visualize the image processing pipeline.
40 | - `🔎 Detections`: Represent bounding boxes, segmentation, and classifications.
41 |
42 |
43 | ## Documentation
44 | For more details on types, library structure, and available models please refer to our [Docs](https://docs.overeasy.sh).
45 |
46 | ## Example Usage
47 |
48 | > Note: If you don't have a local GPU, you can run our examples by making a copy of this [Colab notebook](https://colab.research.google.com/drive/1Mkx9S6IG5130wiP9WmwgINiyw0hPsh3c?usp=sharing#scrollTo=L0_U27WJaTNO).
49 |
50 |
51 | Download the example image:
52 | ```bash
53 | !wget https://github.com/overeasy-sh/overeasy/blob/73adbaeba51f532a7023243266da826ed1ced6ec/examples/construction.jpg?raw=true -O construction.jpg
54 | ```
55 |
56 | Example workflow to identify if a person is wearing PPE on a work site:
57 | ```python
58 | from overeasy import *
59 | from overeasy.models import OwlV2
60 | from PIL import Image
61 |
62 | workflow = Workflow([
63 | # Detect each head in the input image
64 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()),
65 | # Applies Non-Maximum Suppression to remove overlapping bounding boxes
66 | NMSAgent(iou_threshold=0.5, score_threshold=0),
67 | # Splits the input image into images of each detected head
68 | SplitAgent(),
69 | # Classifies the split images using CLIP
70 | ClassificationAgent(classes=["hard hat", "no hard hat"]),
71 | # Maps the returned class names
72 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}),
73 | # Combines results back into a BoundingBox Detection
74 | JoinAgent()
75 | ])
76 |
77 | image = Image.open("./construction.jpg")
78 | result, graph = workflow.execute(image)
79 | workflow.visualize(graph)
80 | ```
81 |
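The `result` returned by `execute` is a list of `ExecutionNode` objects; each node's `data` attribute holds whatever the final agent produced (here, a bounding-box `Detections`). A minimal sketch of inspecting it, following the attribute names used in the bundled examples:

```python
dets = result[0].data        # Detections produced by the final JoinAgent
print(dets.class_names)      # mapped class names, e.g. ["has ppe", "no ppe", ...]
print(dets.xyxy)             # bounding boxes as (x1, y1, x2, y2) arrays
print(dets.confidence)       # per-detection confidence scores
```
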
82 | ### Diagram
83 |
84 | Here's a diagram of this workflow. Each layer in the graph represents a step in the workflow:
85 |
86 | ![Workflow diagram](assets/graph-diagram.png)
87 |
94 | The image and data attributes in each node are used together to visualize the current state of the workflow. Calling the `visualize` function on the workflow will spawn a Gradio instance that looks like [this](https://overeasy-sh.github.io/gradio-example/Gradio.html).
95 |
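If you'd rather keep a standalone HTML file than launch a live Gradio app, the bundled examples (see `examples/kites_workflow_local.py`) write the same visualization to disk with `visualize_to_file`:

```python
workflow.visualize_to_file(graph, "workflow_visualization.html")  # output filename here is just an example
```
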
96 | ## Support
97 | If you have any questions or need assistance, please open an issue or reach out to us at help@overeasy.sh.
98 |
99 |
100 | Let's build amazing vision models together 🍳!
101 |
--------------------------------------------------------------------------------
/assets/graph-diagram-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/assets/graph-diagram-dark.png
--------------------------------------------------------------------------------
/assets/graph-diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/assets/graph-diagram.png
--------------------------------------------------------------------------------
/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/example.png
--------------------------------------------------------------------------------
/examples/basketball.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/basketball.png
--------------------------------------------------------------------------------
/examples/basketball_workflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import GPTVision
4 | from PIL import Image
5 |
6 | workflow = Workflow([
7 | BinaryChoiceAgent("Do the Celtics have possession of the ball?", model=GPTVision(model="gpt-4o"))
8 | ])
9 | image_path = os.path.join(os.path.dirname(__file__), "basketball.png")
10 | image = Image.open(image_path)
11 | result, graph = workflow.execute(image)
12 | print(result[0].data.class_names)
--------------------------------------------------------------------------------
/examples/construction.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/construction.jpg
--------------------------------------------------------------------------------
/examples/construction_workflow1.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import Workflow, BoundingBoxSelectAgent, ToClassificationAgent, JoinAgent, SplitAgent, InstructorImageAgent, YOLOWorld
3 | from pydantic import BaseModel
4 | from PIL import Image
5 |
6 | class PPE(BaseModel):
7 | hardhat: bool
8 | vest: bool
9 | boots: bool
10 |
11 |
12 | workflow = Workflow([
13 | BoundingBoxSelectAgent(classes=["person"], model=YOLOWorld()),
14 | SplitAgent(),
15 | InstructorImageAgent(response_model=PPE),
16 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"),
17 | # ClassMapAgent({"has ppe": "has ppe", "no ppe": "no ppe"}),
18 | JoinAgent(),
19 | ])
20 |
21 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg")
22 | image = Image.open(image_path)
23 | result, graph = workflow.execute(image)
24 | workflow.visualize(graph)
25 |
--------------------------------------------------------------------------------
/examples/construction_workflow2.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import OwlV2
4 | from PIL import Image
5 |
6 | workflow = Workflow([
7 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()),
8 | NMSAgent(iou_threshold=0.5, score_threshold=0),
9 | SplitAgent(),
10 | ClassificationAgent(classes=["hard hat", "no hard hat"]),
11 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}),
12 | JoinAgent(),
13 | ])
14 |
15 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg")
16 | image = Image.open(image_path)
17 | result, graph = workflow.execute(image)
18 | workflow.visualize(graph)
19 |
--------------------------------------------------------------------------------
/examples/construction_workflow3.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import OwlV2
4 | from PIL import Image
5 |
6 | workflow = Workflow([
7 | # Detect each head in the input image
8 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()),
9 | # Applies Non-Maximum Suppression to the detected heads
10 |     # It's recommended to use NMS when using OwlV2
11 | NMSAgent(iou_threshold=0.5),
12 | # Splits the input image into an image of each detected head
13 | SplitAgent(),
14 | # Classifies PPE using CLIP
15 | ClassificationAgent(classes=["hard hat", "no hard hat"]),
16 | # Maps the returned class names
17 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}),
18 | # Combines results back into a BoundingBox Detection
19 | JoinAgent()
20 | ])
21 |
22 | image_path = os.path.join(os.path.dirname(__file__), "construction.jpg")
23 | image = Image.open(image_path)
24 | result, graph = workflow.execute(image)
25 | workflow.visualize(graph)
--------------------------------------------------------------------------------
/examples/dense_street1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street1.jpg
--------------------------------------------------------------------------------
/examples/dense_street2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street2.jpg
--------------------------------------------------------------------------------
/examples/dense_street3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/dense_street3.jpg
--------------------------------------------------------------------------------
/examples/drivers_license.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/drivers_license.jpg
--------------------------------------------------------------------------------
/examples/drivers_license_workflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import GPTVision
4 | from PIL import Image
5 | from pydantic import BaseModel
6 |
7 | class License(BaseModel):
8 | license_number: str
9 | expiration_date: str
10 | state: str
11 |     date_of_birth: str
12 | address: str
13 | first_name: str
14 | last_name: str
15 |
16 | workflow = Workflow([
17 | InstructorImageAgent(response_model=License, model=GPTVision())
18 | ])
19 |
20 | image_path = os.path.join(os.path.dirname(__file__), "drivers_license.jpg")
21 | image = Image.open(image_path)
22 | result, graph = workflow.execute(image)
23 | print(repr(result[0].data))
--------------------------------------------------------------------------------
/examples/eiffel_photo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/eiffel_photo.png
--------------------------------------------------------------------------------
/examples/eiffel_photo_workflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import GPTVision
4 | from PIL import Image
5 |
6 | workflow = Workflow([
7 | DenseCaptioningAgent(model=GPTVision(model="gpt-4o"))
8 | ])
9 | image_path = os.path.join(os.path.dirname(__file__), "eiffel_photo.png")
10 | image = Image.open(image_path)
11 | result, graph = workflow.execute(image)
12 | print(result[0].data)
--------------------------------------------------------------------------------
/examples/kites.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/examples/kites.jpg
--------------------------------------------------------------------------------
/examples/kites_workflow.py:
--------------------------------------------------------------------------------
1 | import os
2 | from overeasy import *
3 | from overeasy.models import OwlV2
4 | from PIL import Image
5 |
6 | workflow = Workflow([
7 | BoundingBoxSelectAgent(classes=["butterfly kite"], model=OwlV2()),
8 | NMSAgent(iou_threshold=0.8, score_threshold=0.6),
9 | ])
10 |
11 | image_path = os.path.join(os.path.dirname(__file__), "kites.jpg")
12 | image = Image.open(image_path)
13 | results, graph = workflow.execute(image)
14 | workflow.visualize(graph)
--------------------------------------------------------------------------------
/examples/kites_workflow_local.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../')))
4 |
5 |
6 | from overeasy import *
7 | from overeasy.models import OwlV2
8 | from PIL import Image
9 |
10 | workflow = Workflow([
11 | BoundingBoxSelectAgent(classes=["butterfly kite"], model=OwlV2()),
12 | NMSAgent(iou_threshold=0.8, score_threshold=0.6),
13 | ])
14 |
15 | image_path = os.path.join(os.path.dirname(__file__), "kites.jpg")
16 | image = Image.open(image_path)
17 | results, graph = workflow.execute(image)
18 | workflow.visualize_to_file(graph, "kites_workflow.html")
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = True
3 |
--------------------------------------------------------------------------------
/overeasy/__init__.py:
--------------------------------------------------------------------------------
1 | from overeasy.agents import *
2 | from overeasy.models import *
3 |
4 | __version__ = "0.2.16"
5 |
6 |
7 | from overeasy.dirs import ROOT as _ROOT
8 | import os as _os
9 | _os.makedirs(_ROOT, exist_ok=True)
10 |
--------------------------------------------------------------------------------
/overeasy/agents/__init__.py:
--------------------------------------------------------------------------------
1 | from .misc import *
2 | from .split_join_agent import SplitAgent, JoinAgent
3 | from .workflow import Workflow
4 | from .model_agents import *
5 |
--------------------------------------------------------------------------------
/overeasy/agents/dark_mode.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/overeasy/agents/misc/__init__.py:
--------------------------------------------------------------------------------
1 | from .pad_crop import PadCropAgent
2 | from .split_crop import SplitCropAgent
3 | from .nms import NMSAgent
4 | from .class_map import ClassMapAgent
5 | from .map_agent import MapAgent
6 | from .to_classification import ToClassificationAgent
7 | from .filter_class import FilterClassesAgent
8 | from .max_confidence import ConfidenceFilterAgent
--------------------------------------------------------------------------------
/overeasy/agents/misc/class_map.py:
--------------------------------------------------------------------------------
1 | from overeasy.types import DetectionAgent, Detections
2 | from typing import Dict
3 |
4 | class ClassMapAgent(DetectionAgent):
5 | def __init__(self, class_map: Dict[str, str]):
6 | self.class_map = class_map
7 |
8 | def _execute(self, dets: Detections) -> Detections:
9 | unmapped_classes = set(dets.class_names) - set(self.class_map.keys())
10 | if unmapped_classes:
11 | raise ValueError(f"Class names {unmapped_classes} not mapped")
12 | class_names = [self.class_map[x] for x in dets.class_names]
13 |
14 | new_dets = Detections(
15 | xyxy=dets.xyxy,
16 | class_ids=dets.class_ids,
17 | confidence=dets.confidence,
18 | classes=class_names,
19 | data=dets.data,
20 | detection_type=dets.detection_type
21 | )
22 | return new_dets
23 |
24 | def __repr__(self):
25 | return f"{self.__class__.__name__}(class_map={self.class_map})"
26 |
--------------------------------------------------------------------------------
/overeasy/agents/misc/filter_class.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from overeasy.types import Detections, DetectionAgent
3 | from typing import List, Optional
4 | import numpy as np
5 |
6 | @dataclass
7 | class FilterClassesAgent(DetectionAgent):
8 |     """
9 |     Filter detections by class name or class id.
10 |     """
11 |     class_names: Optional[List[str]]
12 |     class_ids: Optional[List[int]]
13 |
14 | def __init__(self, class_names: Optional[List[str]] = None, class_ids: Optional[List[int]] = None):
15 | self.class_names = class_names
16 | self.class_ids = class_ids
17 |
18 | if class_names is None and class_ids is None:
19 |             raise ValueError("Must specify class_names or class_ids")
20 | if class_names is not None and class_ids is not None:
21 | raise ValueError("Can only specify one of class_names or class_ids")
22 |
23 |
24 | def _execute(self, dets: Detections) -> Detections:
25 | if self.class_ids is not None:
26 | slice = np.isin(dets.class_ids, self.class_ids)
27 | elif self.class_names is not None:
28 | slice = np.isin(dets.class_names, self.class_names)
29 | else:
30 | raise ValueError("No filter specified")
31 |
32 | return dets[slice]
33 |
34 | def __repr__(self):
35 |         return f"{self.__class__.__name__}(class_names={self.class_names}, class_ids={self.class_ids})"
36 |
37 |
--------------------------------------------------------------------------------
/overeasy/agents/misc/map_agent.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Any
2 | from overeasy.types import DataAgent
3 |
4 | class MapAgent(DataAgent):
5 | def __init__(self, fn: Callable[[Any], Any]):
6 | self.fn = fn
7 |
8 | def _execute(self, data: Any) -> Any:
9 | return self.fn(data)
10 |
11 | def __repr__(self):
12 | return f"{self.__class__.__name__}(fn={self.fn})"
--------------------------------------------------------------------------------
/overeasy/agents/misc/max_confidence.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from overeasy.types import Detections, DetectionAgent
3 | from typing import List, Optional
4 | import numpy as np
5 |
6 | @dataclass
7 | class ConfidenceFilterAgent(DetectionAgent):
8 |     """
9 |     Filter detections by confidence, either selecting the top N or those above a minimum confidence threshold.
10 |     """
11 |     max_n: Optional[int] = None
12 |     min_confidence: Optional[float] = None
13 |
14 | def __init__(self, max_n: Optional[int] = None, min_confidence: Optional[float] = None):
15 | self.max_n = max_n
16 | self.min_confidence = min_confidence
17 |
18 | if max_n is None and min_confidence is None:
19 | raise ValueError("Must specify either max_n or min_confidence")
20 | if max_n is not None and max_n < 1:
21 | raise ValueError("max_n must be at least 1")
22 | if min_confidence is not None and (min_confidence < 0 or min_confidence > 1):
23 | raise ValueError("min_confidence must be between 0 and 1")
24 |
25 | def _execute(self, dets: Detections) -> Detections:
26 | if dets.confidence is None:
27 | raise ValueError("Detections must have confidence scores")
28 |
29 | indices = np.arange(len(dets.confidence))
30 |
31 | if self.min_confidence is not None:
32 | indices = indices[dets.confidence >= self.min_confidence]
33 |
34 | if self.max_n is not None:
35 | if len(indices) > self.max_n:
36 | sorted_indices = np.argsort(dets.confidence[indices])[-self.max_n:][::-1]
37 | indices = indices[sorted_indices]
38 | else:
39 |                 indices = indices[np.argsort(dets.confidence[indices])[::-1]]
40 |
41 | if len(indices) == 0:
42 | return Detections.empty()
43 |
44 | return dets[indices.astype(int)]
45 |
46 | def __repr__(self):
47 | return f"{self.__class__.__name__}(max_n={self.max_n}, min_confidence={self.min_confidence})"
48 |
--------------------------------------------------------------------------------
/overeasy/agents/misc/min_crop.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/agents/misc/min_crop.py
--------------------------------------------------------------------------------
/overeasy/agents/misc/nms.py:
--------------------------------------------------------------------------------
1 | from overeasy.types import Detections, DetectionType, DetectionAgent
2 | import cv2
3 |
4 | def do_nms(dets: Detections, nms_threshold: float, score_threshold : float = 0.0) -> Detections:
5 | if dets.detection_type != DetectionType.BOUNDING_BOX:
6 | raise ValueError("Only bounding box detections are supported for NMS.")
7 |
8 | bboxes = dets.xyxy
9 | scores = dets.confidence
10 | if scores is None:
11 | raise ValueError("Confidence scores are required for NMS.")
12 |     boxes_xywh = [[x1, y1, x2 - x1, y2 - y1] for x1, y1, x2, y2 in bboxes]  # cv2.dnn.NMSBoxes expects (x, y, w, h)
13 |     indices = list(cv2.dnn.NMSBoxes(boxes_xywh, scores, score_threshold, nms_threshold))
14 |
15 | return dets[indices]
16 |
17 | class NMSAgent(DetectionAgent):
18 | # IoU (Intersection over Union) Threshold: Determines the minimum overlap between two bounding boxes to consider them as the same object.
19 | # Score Threshold: Filters out detections that have a confidence score below this value before applying NMS.
20 | def __init__(self, iou_threshold: float, score_threshold: float = 0.0):
21 | self.iou_threshold = iou_threshold
22 | self.score_threshold = score_threshold
23 |
24 | def _execute(self, dets: Detections) -> Detections:
25 | if dets.detection_type != DetectionType.BOUNDING_BOX:
26 | raise ValueError("Only bounding box detections are supported for NMS.")
27 | return do_nms(dets, self.iou_threshold, self.score_threshold)
28 |
29 | def __repr__(self):
30 | return f"NMSAgent(iou_threshold={self.iou_threshold}, score_threshold={self.score_threshold})"
31 |
--------------------------------------------------------------------------------
/overeasy/agents/misc/pad_crop.py:
--------------------------------------------------------------------------------
1 | from overeasy.types import Detections, DetectionType, DetectionAgent
2 | from pydantic.dataclasses import dataclass
3 | import numpy as np
4 |
5 | @dataclass
6 | class PadCropAgent(DetectionAgent):
7 | x1_pad: int
8 | y1_pad: int
9 | x2_pad: int
10 | y2_pad: int
11 |
12 | @classmethod
13 | def from_uniform_padding(cls, padding):
14 | return cls(padding, padding, padding, padding)
15 |
16 | @classmethod
17 | def from_xy_padding(cls, x_pad, y_pad):
18 | return cls(x_pad, y_pad, x_pad, y_pad)
19 |
20 | def _execute(self, dets: Detections) -> Detections:
21 | if dets.detection_type != DetectionType.BOUNDING_BOX:
22 | raise ValueError("Only bounding box detections are supported for padding.")
23 |
24 | padded_bboxes = []
25 | for bbox in dets.xyxy:
26 | x1, y1, x2, y2 = bbox
27 | padded_bbox = [
28 | x1 - self.x1_pad,
29 | y1 - self.y1_pad,
30 | x2 + self.x2_pad,
31 | y2 + self.y2_pad
32 | ]
33 | padded_bboxes.append(padded_bbox)
34 |
35 | dets.xyxy = np.array(padded_bboxes)
36 | dets.__post_init__()
37 | return dets
38 |
39 | def __repr__(self):
40 | return f"{self.__class__.__name__}(x1_pad={self.x1_pad}, y1_pad={self.y1_pad}, x2_pad={self.x2_pad}, y2_pad={self.y2_pad})"
--------------------------------------------------------------------------------
/overeasy/agents/misc/split_crop.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from pydantic.dataclasses import dataclass
3 | from typing import Tuple
4 | from overeasy.types import ExecutionNode, ImageDetectionAgent, Detections, DetectionType
5 | import numpy as np
6 |
7 | class SplitCropAgent(ImageDetectionAgent):
8 | # WE HAVE PYDANTIC AT HOME
9 | def __init__(self, split: Tuple[int, int]):
10 | if len(split) != 2:
11 | raise ValueError("Split must be a tuple of two integers")
12 | if split[0] <= 0 or split[1] <= 0:
13 | raise ValueError("Split must be greater than 0")
14 | if type(split[0]) != int or type(split[1]) != int:
15 | raise ValueError("Split must be a tuple of two integers")
16 |
17 | self.rows = split[0]
18 | self.columns = split[1]
19 |
20 |
21 | def _execute(self, image: Image.Image) -> Detections:
22 |         # Get the image dimensions
23 |         width, height = image.width, image.height
24 |
25 | # Calculate the size of each block
26 | block_height = height // self.rows
27 | block_width = width // self.columns
28 |
29 | left_over_height = height - block_height * self.rows
30 | left_over_width = width - block_width * self.columns
31 |
32 | # Initialize a list to hold the bounding boxes
33 | bounding_boxes = []
34 |
35 | # Create bounding boxes
36 | for row in range(self.rows):
37 | for col in range(self.columns):
38 | x1 = col * block_width + min(col, left_over_width)
39 | y1 = row * block_height + min(row, left_over_height)
40 | x2 = x1 + block_width + (1 if col < left_over_width else 0)
41 | y2 = y1 + block_height + (1 if row < left_over_height else 0)
42 | bounding_boxes.append((x1, y1, x2, y2))
43 |
44 | # Convert lists to numpy arrays
45 | bounding_boxes_np = np.array(bounding_boxes)
46 | confidence_np = np.ones(len(bounding_boxes_np))
47 | class_ids_np = np.arange(len(bounding_boxes_np))
48 | classes_np = np.array([f'split_{i+1}' for i in range(len(bounding_boxes_np))])
49 |
50 | # Create Detections object
51 | det = Detections(
52 | xyxy=bounding_boxes_np,
53 | confidence=confidence_np, # Confidence for each box
54 | class_ids=class_ids_np, # Unique class ID for each box
55 | classes=classes_np, # Class names
56 | detection_type=DetectionType.BOUNDING_BOX
57 | )
58 |
59 | return det
60 |
61 | def __repr__(self):
62 |         return f"{self.__class__.__name__}(split=({self.rows}, {self.columns}))"
--------------------------------------------------------------------------------
/overeasy/agents/misc/to_classification.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Callable, Any, List
2 | from overeasy.types import DataAgent, ExecutionNode, Detections
3 |
4 | class ToClassificationAgent(DataAgent):
5 | def __init__(self, fn: Union[Callable[[Any], str], Callable[[Any], List[str]]]):
6 | self.fn = fn
7 |
8 | def _execute(self, data: Any) -> Any:
9 | res = self.fn(data)
10 | if isinstance(res, list) and all(isinstance(x, str) for x in res):
11 | return Detections.from_classification(res)
12 | elif isinstance(res, str):
13 | return Detections.from_classification([res])
14 | else:
15 | raise ValueError(f"{self.__class__.__name__} must return a string or list of strings")
16 |
17 | def __repr__(self):
18 | return f"{self.__class__.__name__}(fn={self.fn})"
19 |
--------------------------------------------------------------------------------
/overeasy/agents/model_agents.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from overeasy.models import *
4 | from overeasy.types import *
5 | from pydantic import BaseModel
6 | from typing import List, Union, Optional, Dict, Any
7 | import instructor
8 | import base64, io
9 | import google.generativeai as genai
10 |
11 | __all__ = [
12 | "BoundingBoxSelectAgent",
13 | "VisionPromptAgent",
14 | "DenseCaptioningAgent",
15 | "TextPromptAgent",
16 | "BinaryChoiceAgent",
17 | "ClassificationAgent",
18 | "OCRAgent",
19 | "InstructorImageAgent",
20 | "InstructorTextAgent"
21 | ]
22 |
23 | class BoundingBoxSelectAgent(ImageDetectionAgent):
24 | def __init__(self, classes: List[str], model: BoundingBoxModel = GroundingDINO()):
25 | self.classes = classes
26 | self.model = model
27 |
28 | def _execute(self, image: Image.Image) -> Detections:
29 | return self.model.detect(image, self.classes)
30 |
31 | def __repr__(self):
32 | model_name = self.model.__class__.__name__ if self.model else "None"
33 | return f"{self.__class__.__name__}(classes={self.classes}, model={model_name})"
34 |
35 | class VisionPromptAgent(ImageToTextAgent):
36 | def __init__(self, query: str, model: MultimodalLLM = GPTVision()):
37 | self.query = query
38 | self.model = model
39 |
40 | def _execute(self, image: Image.Image)-> str:
41 | prompt = f"""{self.query}"""
42 | response = self.model.prompt_with_image(image, prompt)
43 | return response
44 |
45 | def __repr__(self):
46 | model_name = self.model.__class__.__name__
47 | return f"{self.__class__.__name__}(query={self.query}, model={model_name})"
48 |
49 | class DenseCaptioningAgent(ImageToTextAgent):
50 | def __init__(self, model: Union[MultimodalLLM, CaptioningModel] = GPTVision()):
51 | self.model = model
52 |
53 | def _execute(self, image: Image.Image)-> str:
54 | prompt = f"""Describe the following image in detail"""
55 | if isinstance(self.model, MultimodalLLM):
56 | response = self.model.prompt_with_image(image, prompt)
57 | else:
58 | response = self.model.caption(image)
59 |
60 | return response
61 |
62 | def __repr__(self):
63 | model_name = self.model.__class__.__name__
64 | return f"{self.__class__.__name__}(model={model_name})"
65 |
66 | class TextPromptAgent(TextAgent):
67 | def __init__(self, query: str, model: LLM = GPT()):
68 | self.query = query
69 | self.model = model
70 |
71 | def _execute(self, text: str)-> str:
72 | prompt = f"""{text}\n{self.query}"""
73 | response = self.model.prompt(prompt)
74 | return response
75 |
76 | def __repr__(self):
77 | model_name = self.model.__class__.__name__
78 | return f"{self.__class__.__name__}(query={self.query}, model={model_name})"
79 |
80 | # Convert binary_choice into an agent class
81 | class BinaryChoiceAgent(ImageToTextAgent):
82 | def __init__(self, query: str, model: MultimodalLLM = GPTVision()):
83 | self.query = query
84 | self.model = model
85 |
86 | def _execute(self, image: Image.Image)-> str:
87 | prompt = f"""{self.query}"""
88 | response = self.model.prompt_with_image(image, prompt)
89 | truthy = "yes" in response.lower()
90 | assigned_class = "yes" if truthy else "no"
91 | return assigned_class
92 |
93 | def __repr__(self):
94 | model_name = self.model.__class__.__name__
95 | return f"{self.__class__.__name__}(query={self.query}, model={model_name})"
96 |
97 | class ClassificationAgent(ImageDetectionAgent):
98 | def __init__(self, classes, model: ClassificationModel = CLIP()):
99 | self.classes = classes
100 | self.model = model
101 |
102 | def _execute(self, image: Image.Image)-> Detections:
103 | selected_class = self.model.classify(image, self.classes)
104 | return selected_class
105 |
106 | def __repr__(self):
107 | model_name = self.model.__class__.__name__
108 | return f"{self.__class__.__name__}(classes={self.classes}, model={model_name})"
109 |
110 | class OCRAgent(ImageToTextAgent):
111 | def __init__(self, model: Optional[OCRModel] = None):
112 | self.model = model if model is not None else GPTVision()
113 |
114 | def _execute(self, image: Image.Image)-> str:
115 | response = self.model.parse_text(image)
116 | return response
117 |
118 | def __repr__(self):
119 | model_name = self.model.__class__.__name__
120 | return f"{self.__class__.__name__}(model={model_name})"
121 |
122 | import warnings
123 | warnings.filterwarnings("ignore", category=DeprecationWarning)
124 |
125 | options = Union[GPTVision, Gemini, Claude]
126 |
127 |
128 |
129 | class InstructorImageAgent(ImageToDataAgent):
130 |
131 | def __init__(self, response_model: type[BaseModel], model: Union[GPTVision, Gemini, Claude] = GPTVision(), max_tokens: int = 4096, extra_context: Optional[List[Dict[str, str]]] = None):
132 | self.response_model = response_model
133 | self.model = model
134 | self.extra_context = extra_context if extra_context is not None else []
135 | self.max_tokens = max_tokens
136 | if not isinstance(self.model, GPTVision) and not isinstance(self.model, Gemini) and not isinstance(self.model, Claude):
137 | raise ValueError("Model must be a GPTVision, Gemini, or Claude")
138 |
139 | def _execute(self, image: Image.Image)-> Any:
140 | model_name = ""
141 | model_client = self.model.client
142 | if model_client is None:
143 | raise ValueError("No client found. Please call load_resources() on model")
144 |
145 | if isinstance(self.model, GPTVision):
146 | client = instructor.from_openai(model_client)
147 | model_name = self.model.model
148 | elif isinstance(self.model, Gemini):
149 | client = instructor.from_gemini(model_client)
150 | model_name = self.model.model_name
151 | elif isinstance(self.model, Claude):
152 | client = instructor.from_anthropic(model_client)
153 | model_name = self.model.model
154 |
155 |
156 | if isinstance(self.model, Gemini):
157 | return client.chat.completions.create(
158 | response_model=self.response_model,
159 | messages=[
160 | *self.extra_context,
161 | {
162 | "role": "user",
163 | "content": image
164 | }
165 | ],
166 | max_tokens=self.max_tokens,
167 | )
168 | elif isinstance(self.model, GPTVision):
169 | buffered = io.BytesIO()
170 | image.save(buffered, format="PNG")
171 | base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
172 |
173 | messages = [*self.extra_context, {"role": "user", "content": [
174 | {
175 | "type": "image_url",
176 | "image_url": {
177 | "url": f"data:image/png;base64,{base64_image}"
178 | }
179 | }
180 | ]}]
181 | return client.chat.completions.create(
182 | model=model_name,
183 | response_model=self.response_model,
184 | messages=messages,
185 | max_tokens=self.max_tokens,
186 | )
187 | else:
188 | buffered = io.BytesIO()
189 | image.save(buffered, format="PNG")
190 | base64_image = base64.b64encode(buffered.getvalue()).decode('utf-8')
191 |
192 | messages = [*self.extra_context, {"role": "user", "content": [
193 | {
194 | "type": "image",
195 | "source": {
196 | "type": "base64",
197 | "media_type": "image/png",
198 | "data": base64_image
199 | }
200 | }
201 | ]}]
202 |
203 | return client.chat.completions.create(
204 | model=model_name,
205 | response_model=self.response_model,
206 | messages=messages,
207 | max_tokens=self.max_tokens,
208 | )
209 |
210 | def __repr__(self):
211 | model_name = self.model
212 | return f"{self.__class__.__name__}(response_model={self.response_model}, model={repr(model_name)})"
213 |
214 |
215 | class InstructorTextAgent(TextAgent):
216 | def __init__(self, response_model: type[BaseModel], model: Union[GPT, Gemini, Claude] = GPT(), max_tokens: int = 4096, extra_context: Optional[List[Dict[str, str]]] = None):
217 | self.response_model = response_model
218 | self.model = model
219 | self.extra_context = extra_context if extra_context is not None else []
220 | self.max_tokens = max_tokens
221 | if not isinstance(self.model, GPT) and not isinstance(self.model, Gemini) and not isinstance(self.model, Claude):
222 | raise ValueError("Model must be a GPT, Gemini, or Claude")
223 |
224 | def _execute(self, text: str)-> Any:
225 | model_client = self.model.client
226 | if model_client is None:
227 | raise ValueError("No client found. Please call load_resources() on model")
228 |
229 | if isinstance(self.model, GPT):
230 | client = instructor.from_openai(model_client)
231 | model_name = self.model.model
232 | elif isinstance(self.model, Gemini):
233 | client = instructor.from_gemini(model_client)
234 | model_name = self.model.model_name
235 | elif isinstance(self.model, Claude):
236 | client = instructor.from_anthropic(model_client)
237 | model_name = self.model.model
238 |
239 | if isinstance(self.model, Gemini):
240 | structured_response = client.chat.completions.create(
241 | response_model=self.response_model,
242 | messages=[*self.extra_context, {"role": "user", "parts": [
243 | text
244 | ]}],
245 | max_tokens=self.max_tokens,
246 | )
247 | else:
248 | structured_response = client.chat.completions.create(
249 | model=model_name,
250 | messages=[*self.extra_context, {"role": "user", "content": text}],
251 | response_model=self.response_model,
252 | max_tokens=self.max_tokens,
253 | )
254 |
255 |
256 | return structured_response
257 |
258 | def __repr__(self):
259 | model_name = self.model
260 | return f"{self.__class__.__name__}(response_model={self.response_model}, model={repr(model_name)})"
261 |
--------------------------------------------------------------------------------
/overeasy/agents/scrape.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from urllib.parse import urlparse, urljoin
3 | import requests
4 |
5 | import os
6 |
7 |
8 | def _is_valid_xml_char(c):
9 | codepoint = ord(c)
10 | return (
11 | 0x20 <= codepoint <= 0xD7FF or
12 | codepoint in (0x9, 0xA, 0xD) or
13 | 0xE000 <= codepoint <= 0xFFFD or
14 | 0x10000 <= codepoint <= 0x10FFFF
15 | )
16 |
17 | def _sanitize_string(s):
18 | return ''.join(c for c in s if _is_valid_xml_char(c))
19 |
20 | def _get_page_source(url):
21 | from playwright.sync_api import sync_playwright
22 | with sync_playwright() as p:
23 | browser = p.chromium.launch(headless=True)
24 | page = browser.new_page()
25 | page.goto(url)
26 | page.wait_for_load_state('domcontentloaded')
27 | content = page.content()
28 | browser.close()
29 | return content
30 |
31 | def _download_resource(url, base_url):
32 | full_url = urljoin(base_url, url) if not url.startswith(('http://', 'https://')) else url
33 | response = requests.get(full_url)
34 | response.raise_for_status()
35 | return response.content
36 |
37 | def _embed_stylesheets(tree, base_url):
38 | from lxml import html
39 |
40 | for link in tree.xpath('//link[@rel="stylesheet"]'):
41 | href = link.get('href')
42 | if href:
43 | css_content = _download_resource(href, base_url).decode('utf-8')
44 | style = html.Element('style')
45 | style.text = css_content
46 | link.getparent().replace(link, style)
47 |
48 | def _embed_scripts(tree, base_url):
49 | from lxml import html
50 |
51 | for script in tree.xpath('//script[@src]'):
52 | src = script.get('src')
53 | if src:
54 | js_content = _download_resource(src, base_url)
55 | script_tag = html.Element('script')
56 | script_tag.set('type', 'text/javascript')
57 | script_tag.text = _sanitize_string(js_content.decode('utf-8', errors='ignore'))
58 | script.getparent().replace(script, script_tag)
59 |
60 | for link in tree.xpath('//link[@rel="modulepreload"]'):
61 | href = link.get('href')
62 | if href:
63 | js_content = _download_resource(href, base_url)
64 | script_tag = html.Element('script')
65 | script_tag.set('type', 'module')
66 | script_tag.text = _sanitize_string(js_content.decode('utf-8', errors='ignore'))
67 | link.getparent().replace(link, script_tag)
68 |
69 | def _embed_images(tree, base_url):
70 | for img in tree.xpath('//img'):
71 | src = img.get('src')
72 | if src and not src.startswith('data:'):
73 | img_content = _download_resource(src, base_url)
74 | img_base64 = base64.b64encode(img_content).decode('utf-8')
75 | img.set('src', f"data:image/{src.split('.')[-1]};base64,{img_base64}")
76 |
77 | def _remove_footers(tree):
78 | for footer in tree.xpath('//footer'):
79 | footer.getparent().remove(footer)
80 |
81 | def _process_webpage(url):
82 | from lxml import html
83 |
84 | parsed_url = urlparse(url)
85 | base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
86 |
87 | html_content = _get_page_source(url)
88 | tree = html.fromstring(html_content)
89 |
90 | _embed_stylesheets(tree, base_url)
91 | _embed_scripts(tree, base_url)
92 | _embed_images(tree, base_url)
93 | _remove_footers(tree)
94 |
95 | for button in tree.xpath('//button[@aria-label="Download"]'):
96 | button.getparent().remove(button)
97 |
98 | dark_mode_path = os.path.join(os.path.dirname(__file__), 'dark_mode.txt')
99 | with open(dark_mode_path, 'r') as f:
100 | dark_mode_content = f.read()
101 |
102 | html_root = tree.getroottree().getroot()
103 | dark_mode_element = html.fromstring(dark_mode_content)
104 | html_root.append(dark_mode_element)
105 |
106 | return html.tostring(tree, pretty_print=True, encoding='utf-8').decode('utf-8')
107 |
108 | def scrape_and_inline_to_buffer(url):
109 | """
110 | Scrape a webpage, inline all resources, remove footers, and return the result as a string.
111 |
112 | Args:
113 | url (str): The URL of the webpage to scrape.
114 |
115 | Returns:
116 | str: The processed HTML content as a string.
117 | """
118 | return _process_webpage(url)
119 |
120 | def scrape_and_inline_to_file(url, output_file='gradio_visualization.html'):
121 | """
122 | Scrape a webpage, inline all resources, remove footers, and save the result to a file.
123 |
124 | Args:
125 | url (str): The URL of the webpage to scrape.
126 |         output_file (str): The name of the file to save the result. Defaults to 'gradio_visualization.html'.
127 |
128 | Returns:
129 | None
130 | """
131 | processed_html = _process_webpage(url)
132 | with open(output_file, 'w', encoding='utf-8') as file:
133 | file.write(processed_html)
134 |
135 |
--------------------------------------------------------------------------------
/overeasy/agents/split_join_agent.py:
--------------------------------------------------------------------------------
1 | from overeasy.types import Agent, ExecutionNode, Detections, DetectionType, ExecutionGraph, NullExecutionNode
2 | from overeasy.logging import log_time
3 | from typing import List, Tuple, Union
4 | import numpy as np
5 |
6 | Node = Union[ExecutionNode, NullExecutionNode]
7 |
8 | class SplitAgent(Agent):
9 | @log_time
10 | def execute(self, node: ExecutionNode) -> List[Node]:
11 | result : List[Node] = []
12 | if not isinstance(node.data, Detections):
13 | raise ValueError(f"ExecutionNode data must be of type Detections, got {type(node.data)}")
14 | detections: Detections = node.data
15 | if detections.detection_type != DetectionType.BOUNDING_BOX:
16 | raise ValueError("Detection type must be BOUNDING_BOX")
17 |
18 | for split_detection in detections.split():
19 | parent_xyxy = split_detection.xyxy[0]
20 | child_image = node.image.crop(parent_xyxy)
21 | child_detection = Detections(
22 | xyxy=np.array([[0, 0, child_image.width, child_image.height]]),
23 | class_ids=split_detection.class_ids,
24 | confidence=split_detection.confidence,
25 | classes=split_detection.classes,
26 | data=split_detection.data,
27 | detection_type=DetectionType.CLASSIFICATION
28 | )
29 | result.append(ExecutionNode(child_image, child_detection, parent_detection=split_detection))
30 |
31 | if len(result) == 0:
32 | result.append(NullExecutionNode(parent_detection=detections))
33 |
34 | return result
35 |
36 | def __repr__(self):
37 | return f"{self.__class__.__name__}()"
38 |
39 | def combine_detections(dets: List[Detections], parent_dets: List[Detections]) -> Detections:
40 | if not all(isinstance(x, Detections) for x in dets) or not all(isinstance(x, Detections) for x in parent_dets):
41 | raise ValueError("Cannot combine detections")
42 |
43 | det_types = [x.detection_type for x in dets]
44 | if not all(x == det_types[0] for x in det_types):
45 | raise ValueError("Cannot combine detections of different types")
46 |
47 | det_type = det_types[0]
48 | parent_detection_type = parent_dets[0].detection_type
49 |
50 | if det_type == DetectionType.CLASSIFICATION:
51 | if any(len(det.class_ids) > 1 for det in dets):
52 | raise ValueError("Detections with multiple classes are not supported for combination")
53 |
54 | uniq_classes = set()
55 | for det in dets:
56 | for cls in det.classes:
57 | uniq_classes.add(cls)
58 | classes = np.array(list(uniq_classes))
59 | class_id_map = {cls: idx for idx, cls in enumerate(classes)}
60 | xyxy = []
61 | class_ids = []
62 | confidence = []
63 |
64 |
65 | for det, parent_det in zip(dets, parent_dets):
66 | if len(parent_det.class_ids) == 0:
67 | continue
68 | confidence.append(det.confidence[0] if det.confidence is not None else None)
69 | cls = det.class_names[0]
70 | class_ids.append(class_id_map[cls])
71 | xyxy.append(parent_det.xyxy[0])
72 |
73 | return Detections(
74 | xyxy=np.array(xyxy) if len(xyxy) > 0 else np.zeros((0, 4)),
75 | class_ids=np.array(class_ids),
76 | confidence=np.array(confidence),
77 | classes=classes,
78 | detection_type=parent_detection_type
79 | )
80 |
81 | elif det_type == DetectionType.BOUNDING_BOX:
82 | # TODO: Implement this
83 | pass
84 | elif det_type == DetectionType.SEGMENTATION:
85 | # TODO: Implement this
86 | pass
87 |
88 |
89 | return Detections.empty()
90 |
91 |
92 | class JoinAgent(Agent):
93 | # Mutates the input graph by adding a new layer of child nodes
94 | @log_time
95 | def join(self, graph: ExecutionGraph, target_split: int) -> List[Node]:
96 | topsort = graph.top_sort()
97 | parents_all = topsort[target_split]
98 |
99 | if len(parents_all) == 1 and isinstance(parents_all[0], NullExecutionNode):
100 | if len(topsort[-1]) == 1 and isinstance(topsort[-1][0], NullExecutionNode):
101 | null_child = NullExecutionNode()
102 | graph.add_child(topsort[-1][0], null_child)
103 | return [null_child]
104 | else:
105 | raise ValueError("Graph is formatted incorrectly")
106 |
107 | leaves: List[Node] = []
108 | def merge_nodes(node_and_parent_det: List[Tuple[Node, Detections]], parent: ExecutionNode):
109 | if len(node_and_parent_det) == 0:
110 | return
111 | nodes = [x[0] for x in node_and_parent_det]
112 | parent_dets = [x[1] for x in node_and_parent_det]
113 |
114 | original_data = [node.data if isinstance(node, ExecutionNode) else None for node in nodes]
115 | merged_node = ExecutionNode(parent.image, original_data)
116 |
117 | if all(isinstance(x, Detections) for x in original_data):
118 | merged_node.data = combine_detections(original_data, parent_dets) #type: ignore
119 |
120 | for node in nodes:
121 | graph.add_child(node, merged_node)
122 | leaves.append(merged_node)
123 |
124 |         # Filter out non-execution nodes
125 | parents = [p for p in parents_all if isinstance(p, ExecutionNode)]
126 | split_children = [graph.children(p) for p in parents]
127 | to_merge = topsort[-1]
128 |
129 | ind = 0
130 | for (parent_list, parent_node) in zip(split_children, parents):
131 | nodes: List[Tuple[Node, Detections]] = []
132 | for child in parent_list:
133 | if child.parent_detection is None:
134 | raise ValueError("Parent detection is None")
135 | node = to_merge[ind]
136 |
137 | nodes.append((node, child.parent_detection))
138 | ind+=1
139 | if isinstance(parent_node, NullExecutionNode):
140 | raise ValueError("Parent node is NullExecutionNode")
141 | merge_nodes(nodes, parent_node)
142 |
143 |
144 |
145 |
146 | return leaves
147 |
148 | def __repr__(self):
149 | return f"{self.__class__.__name__}()"
--------------------------------------------------------------------------------
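
To make the join step above concrete, here is a small illustrative sketch (not part of the repository) of what `combine_detections` does when a `JoinAgent` folds per-crop classifications back onto their parent boxes. It assumes the package is installed and that `Detections` accepts the same constructor arguments used elsewhere in this file; the labels and coordinates are made-up example data.

```python
import numpy as np
from overeasy.types import Detections, DetectionType
from overeasy.agents.split_join_agent import combine_detections

# Two parent bounding boxes, as a detector followed by a SplitAgent would produce
parent_dets = [
    Detections(xyxy=np.array([[10, 10, 50, 50]]), class_ids=np.array([0]),
               confidence=np.array([0.9]), classes=np.array(["person"]),
               detection_type=DetectionType.BOUNDING_BOX),
    Detections(xyxy=np.array([[60, 10, 120, 80]]), class_ids=np.array([0]),
               confidence=np.array([0.8]), classes=np.array(["person"]),
               detection_type=DetectionType.BOUNDING_BOX),
]

# One classification result per cropped child image
child_dets = [
    Detections(xyxy=np.zeros((1, 4)), class_ids=np.array([0]),
               confidence=np.array([0.95]), classes=np.array(["helmet"]),
               detection_type=DetectionType.CLASSIFICATION),
    Detections(xyxy=np.zeros((1, 4)), class_ids=np.array([0]),
               confidence=np.array([0.40]), classes=np.array(["no helmet"]),
               detection_type=DetectionType.CLASSIFICATION),
]

merged = combine_detections(child_dets, parent_dets)
# merged keeps the parent xyxy boxes but carries the child class labels,
# so each original box is now tagged "helmet" / "no helmet"
print(merged.xyxy, merged.class_names)
```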
/overeasy/agents/workflow.py:
--------------------------------------------------------------------------------
1 | from overeasy.types import *
2 | from overeasy.agents import SplitAgent, JoinAgent
3 | from overeasy.visualize_utils import annotate
4 | from typing import List, Tuple
5 | from PIL import Image
6 | import numpy as np
7 | import gradio as gr
8 | from math import fabs, sqrt
9 | from tqdm import tqdm
10 | from dataclasses import field, dataclass
11 | from overeasy.dirs import FAVICON_PATH
12 | from typing import Optional, Any, Dict, Union
13 | from io import StringIO
14 | from overeasy.agents.scrape import scrape_and_inline_to_buffer
15 | import time
16 | import threading
17 | import sys
18 |
19 | def _visualize_layer(layer: List[Node]) -> List[Tuple[Optional[Image.Image], str]]:
20 | images: List[Tuple[Optional[Image.Image], str]] = []
21 |
22 | if all(isinstance(node, ExecutionNode) and isinstance(node.data, Detections) for node in layer):
23 | for node in layer:
24 | assert isinstance(node, ExecutionNode)
25 | detection = node.data
26 | if detection.detection_type == DetectionType.BOUNDING_BOX:
27 | images.append((annotate(node.image, detection), "Bounding Box"))
28 | elif detection.detection_type == DetectionType.SEGMENTATION:
29 | images.append((annotate(node.image, detection), "Segmentation"))
30 | elif detection.detection_type == DetectionType.CLASSIFICATION:
31 | labels = [f"{name} {score:.2f}" if score is not None else f"{name}" for name, score in zip(detection.class_names, detection.confidence_scores)]
32 | stringified = labels[0] if len(labels)==1 else str(labels)
33 | images.append((node.image, stringified))
34 | else:
35 | images = [(x.image, str(x.data)) if isinstance(x, ExecutionNode) else (None, "None") for x in layer]
36 | return images
37 |
38 | # Can ignore JoinAgents as they are handled earlier in the flow
39 | def handle_node(node: ExecutionNode, agent: Agent) -> List[ExecutionNode]:
40 | if isinstance(agent, SplitAgent):
41 | return agent.execute(node)
42 |     # ImageToDataAgent is a superclass, so we don't have to check each of its subclasses
43 | elif isinstance(agent, ImageToDataAgent) :
44 | return [ExecutionNode(node.image, agent.execute(node.image))]
45 | elif isinstance(agent, DataAgent):
46 | return [ExecutionNode(node.image, agent.execute(node.data))]
47 | else:
48 | raise ValueError(f"Agent {agent} is not a valid agent type")
49 | @dataclass(frozen=True)
50 | class Workflow:
51 | steps: List[Agent]
52 | join_to_split_map: Dict[int, int] = field(default_factory=dict, init=False)
53 |
54 | def __post_init__(self):
55 | splits = []
56 | for i, agent in enumerate(self.steps):
57 | if isinstance(agent, SplitAgent):
58 | splits.append(i)
59 | elif isinstance(agent, JoinAgent):
60 | if len(splits) == 0:
61 | raise ValueError(f"JoinAgent at index {i} has no matching SplitAgent")
62 | self.join_to_split_map[i] = splits.pop()
63 |
64 |
65 | def __repr__(self):
66 | return f"Workflow(steps={self.steps})"
67 |
68 | # Return leaves of the graph and the graph itself
69 | def execute(self, input_image: Image.Image, data: Optional[Any] = None) -> Tuple[List[ExecutionNode], ExecutionGraph]:
70 |
71 | if input_image is None:
72 | raise ValueError("Input image is None")
73 | elif isinstance(input_image, np.ndarray):
74 |             raise ValueError("Input image is a numpy array; please convert it to a PIL.Image first using Image.fromarray() and make sure the color channels are in RGB order!")
75 | elif not isinstance(input_image, Image.Image):
76 | raise ValueError("Input image is not a valid image format, must be a PIL.Image")
77 |
78 | input_image = input_image.convert("RGB")
79 |
80 | root = ExecutionNode(input_image, data)
81 | graph = ExecutionGraph(root)
82 |
83 | intermediate_results : List[Node] = [root]
84 | for ind, agent in enumerate(self.steps):
85 | if hasattr(agent, 'model') and isinstance(agent.model, Model):
86 | agent.model.load_resources()
87 | try:
88 | if isinstance(agent, JoinAgent):
89 | target_split = self.join_to_split_map[ind]
90 | intermediate_results = agent.join(graph, target_split)
91 | continue
92 |
93 | next_results: List[Node] = []
94 | for node in intermediate_results:
95 | if isinstance(node, NullExecutionNode):
96 | null_node = NullExecutionNode()
97 | graph.add_child(node, null_node)
98 | next_results.append(null_node)
99 | continue
100 |                 # Now node must be an ExecutionNode
101 | assert isinstance(node, ExecutionNode)
102 | res = handle_node(node, agent)
103 | for child in res:
104 | graph.add_child(node, child)
105 | next_results.extend(res)
106 |
107 | intermediate_results = next_results
108 | finally:
109 | if hasattr(agent, 'model') and isinstance(agent.model, Model):
110 | agent.model.release_resources()
111 |
112 | return [node for node in intermediate_results if isinstance(node, ExecutionNode)], graph
113 |
114 |
115 | def execute_multiple(self, input_images: List[Image.Image]) -> Tuple[List[List[ExecutionNode]], List[ExecutionGraph]]:
116 | if not all(isinstance(img, Image.Image) for img in input_images):
117 | raise ValueError("All input images must be of type PIL.Image")
118 | # Normalize image format
119 | input_images = [img.convert("RGB") for img in input_images]
120 |
121 | all_graphs = [ExecutionGraph(ExecutionNode(img, None)) for img in input_images]
122 | intermediate_results: List[List[Node]] = [[graph.root] for graph in all_graphs]
123 |
124 | for ind, agent in enumerate(tqdm(self.steps, desc="Processing steps")):
125 | try:
126 | if hasattr(agent, 'model') and isinstance(agent.model, Model):
127 | agent.model.load_resources()
128 |
129 | if isinstance(agent, JoinAgent):
130 | target_split = self.join_to_split_map[ind]
131 | for i, graph in enumerate(all_graphs):
132 | intermediate_results[i] = agent.join(graph, target_split)
133 | else:
134 | for i in range(len(all_graphs)):
135 | next_results: List[Node] = []
136 | for node in intermediate_results[i]:
137 | if isinstance(node, NullExecutionNode):
138 | null_node = NullExecutionNode()
139 | all_graphs[i].add_child(node, null_node)
140 | next_results.append(null_node)
141 | elif isinstance(node, ExecutionNode):
142 | res = handle_node(node, agent)
143 | for child in res:
144 | all_graphs[i].add_child(node, child)
145 | next_results.extend(res)
146 | intermediate_results[i] = next_results
147 |
148 | finally:
149 | if hasattr(agent, 'model') and isinstance(agent.model, Model):
150 | agent.model.release_resources()
151 |
152 | all_leaves = [
153 | [node for node in results if isinstance(node, ExecutionNode)]
154 | for results in intermediate_results
155 | ]
156 |
157 | return all_leaves, all_graphs
158 |
159 | def to_steps(self, graph: ExecutionGraph) -> List[Tuple[str, List[Tuple[Optional[Image.Image], str]], str]]:
160 | workflow_steps_names = ["Input Image"]
161 | code_snippets = [""]
162 | for i in range(len(self.steps)):
163 | workflow_steps_names.append(self.steps[i].__class__.__name__)
164 | code_snippets.append(repr(self.steps[i]))
165 |
166 | layers = graph.top_sort()
167 | steps = []
168 | for i, layer in enumerate(layers):
169 | code = f"```python\n{code_snippets[i]}\n```"
170 | layer_images = _visualize_layer(layer)
171 | steps.append((f"Step {i+1}: {workflow_steps_names[i]}", layer_images, code))
172 |
173 | return steps
174 |
175 | def visualize(self, graph: ExecutionGraph, share: bool = False, prevent_thread_lock: bool = False):
176 | steps = self.to_steps(graph)
177 | css = """
178 | .gradio-container .single-image {
179 | display: flex;
180 | justify-content: center;
181 | align-items: center;
182 | }
183 | """
184 | with gr.Blocks(css=css) as demo:
185 | for step_title, images, code in steps:
186 | with gr.Group():
187 | gr.Markdown(f"### {step_title}")
188 | gr.Markdown(f"{code}")
189 | filtered_images = [(img, label) for img, label in images if img is not None]
190 | if len(filtered_images) > 1:
191 | gr.Gallery(
192 | value=filtered_images,
193 | height="max-content",
194 | object_fit="contain",
195 | show_label=False,
196 | columns=max(3, min(9, round(sqrt(len(images))))),
197 | )
198 | elif len(filtered_images) == 1:
199 | [image] = filtered_images
200 | with gr.Row():
201 | with gr.Column(scale=1):
202 | pass
203 | with gr.Column(scale=4):
204 | gr.Image(
205 | width=640,
206 | value=image[0],
207 | label=image[1],
208 | )
209 | with gr.Column(scale=1):
210 | pass
211 | elif len(filtered_images) == 0:
212 | with gr.Row():
213 | with gr.Column(scale=1):
214 | pass
215 | with gr.Column(scale=4):
216 |                         gr.Text("No image to display; this step was skipped because the previous split was empty")
217 | with gr.Column(scale=1):
218 | pass
219 |
220 |
221 | demo.launch(favicon_path=FAVICON_PATH, share=share, prevent_thread_lock=prevent_thread_lock)
222 |
223 | def visualize_to_html_string(self, graph: ExecutionGraph):
224 | output_buffer = StringIO()
225 | original_stdout = sys.stdout
226 |
227 | def run_gradio():
228 | try:
229 | sys.stdout = output_buffer
230 | self.visualize(graph, prevent_thread_lock=True)
231 | finally:
232 | sys.stdout = original_stdout
233 |
234 | gradio_thread = threading.Thread(target=run_gradio)
235 | gradio_thread.daemon = True
236 | gradio_thread.start()
237 |
238 | # Monitor the output buffer for the "Running on" string
239 | start_time = time.time()
240 | url = None
241 |         while time.time() - start_time < 10: # Wait for up to 10 seconds
242 | output_buffer.seek(0)
243 | for line in output_buffer:
244 | print("thread:", line.strip())
245 | if "Running on local URL:" in line:
246 | url = line.split("Running on local URL:")[1].strip()
247 | break
248 | if url:
249 | break
250 | time.sleep(0.1) # Short sleep to prevent busy-waiting
251 |
252 | sys.stdout = original_stdout
253 | if url:
254 | gradio_thread.join()
255 | buffer = scrape_and_inline_to_buffer(url)
256 | return buffer
257 | else:
258 | gradio_thread.join()
259 | raise RuntimeError("Failed to determine Gradio server URL")
260 |
261 |
262 | def visualize_to_file(self, graph: ExecutionGraph, file_path: str):
263 | html_string = self.visualize_to_html_string(graph)
264 | with open(file_path, "w") as f:
265 | f.write(html_string)
--------------------------------------------------------------------------------
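
As a rough usage sketch of the `Workflow` API defined above (illustrative only; it assumes the package is installed and the script runs from the repository root so the example image path resolves):

```python
from PIL import Image
from overeasy.agents.workflow import Workflow

# An empty step list is enough to exercise the API surface; a real pipeline
# would pass detection / split / classification / join agents here.
workflow = Workflow(steps=[])

image = Image.open("examples/construction.jpg")
leaves, graph = workflow.execute(image)  # leaves: List[ExecutionNode], graph: ExecutionGraph

print(len(leaves), leaves[0].data)       # a single leaf carrying the root node's data payload
workflow.visualize(graph)                # launches the Gradio step-by-step viewer
```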
/overeasy/assets/egg_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/assets/egg_logo.png
--------------------------------------------------------------------------------
/overeasy/assets/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/assets/favicon.ico
--------------------------------------------------------------------------------
/overeasy/dirs.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | ROOT = os.path.expanduser("~/.overeasy")
4 | FAVICON_PATH = os.path.join(os.path.dirname(__file__), "./assets/favicon.ico")
--------------------------------------------------------------------------------
/overeasy/download_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import urllib.request
4 | import progressbar
5 |
6 | class URLProgressBar():
7 | def __init__(self, filename: str):
8 | print(f"\nDownloading {filename} ...")
9 | self.pbar = None
10 |
11 | def __call__(self, block_num, block_size, total_size):
12 | if not self.pbar:
13 | self.pbar=progressbar.ProgressBar(maxval=total_size)
14 | self.pbar.start()
15 |
16 | downloaded = block_num * block_size
17 | if downloaded < total_size:
18 | self.pbar.update(downloaded)
19 | else:
20 | self.pbar.finish()
21 |
22 | # It's important to download files this way to prevent them from being corrupted
23 | # if an error occurs while downloading.
24 | def atomic_retrieve_and_rename(url: str, destination_path: str):
25 | tmp_file_path = None
26 | try:
27 |         with tempfile.NamedTemporaryFile(delete=False, dir=os.path.dirname(destination_path) or None) as tmp_file:
28 | tmp_file_path = tmp_file.name
29 | urllib.request.urlretrieve(url, tmp_file_path, reporthook=URLProgressBar(os.path.basename(destination_path)))
30 | os.rename(tmp_file_path, destination_path)
31 | except Exception as e:
32 | if tmp_file_path and os.path.exists(tmp_file_path):
33 | os.remove(tmp_file_path)
34 | raise e
--------------------------------------------------------------------------------
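
A short, hypothetical usage sketch of the helper above (the URL and destination path are placeholders, not anything the library actually downloads):

```python
import os
from overeasy.download_utils import atomic_retrieve_and_rename

dest = os.path.expanduser("~/.overeasy/example.bin")
os.makedirs(os.path.dirname(dest), exist_ok=True)

# The file is fetched into a temporary file first and only renamed into place
# once the download completes, so an interrupted download never leaves a
# half-written file at `dest`.
atomic_retrieve_and_rename("https://example.com/example.bin", dest)
```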
/overeasy/logging.py:
--------------------------------------------------------------------------------
1 | import time
2 | from tabulate import tabulate
3 | from functools import wraps
4 | from collections import defaultdict
5 | from typing import Callable, Any, Dict
6 |
7 | function_stats: Dict[str, Dict[str, float]] = defaultdict(lambda: {"count": 0, "total_time": 0.0})
8 |
9 | def log_time(func: Callable[..., Any]) -> Callable[..., Any]:
10 | @wraps(func)
11 | def wrapper(*args: Any, **kwargs: Any) -> Any:
12 | start_time = time.perf_counter()
13 | result = func(*args, **kwargs)
14 | end_time = time.perf_counter()
15 |
16 | if args and hasattr(args[0], '__class__'):
17 | name = f"{args[0].__class__.__name__}.{func.__name__}"
18 | else:
19 | name = func.__qualname__
20 |
21 |
22 | function_stats[name]["count"] += 1
23 | function_stats[name]["total_time"] += end_time - start_time
24 |
25 | return result
26 | return wrapper
27 |
28 |
29 | def print_summary():
30 | total_run_time = sum(stats["total_time"] for stats in function_stats.values())
31 | table = []
32 |
33 |     for func_name, stats in function_stats.items():
34 |
35 | average_time = stats["total_time"] / stats["count"]
36 |
37 |         if average_time < 0.1:
38 |             average_time_str = f"{average_time*1000:.2f}ms"
39 |         elif average_time < 60:
40 |             average_time_str = f"{average_time:.2f}s"
41 |         else:
42 |             minutes, seconds = divmod(average_time, 60)
43 |             average_time_str = f"{int(minutes)}m {seconds:.2f}s"
44 | 
45 | 
46 |         proportion = (stats["total_time"] / total_run_time) * 100
47 |         table.append([func_name, stats['count'], average_time_str, f"{proportion:.2f}%"])
48 |
49 |     headers = ["Function Name", "Calls", "Average Time", "Proportion of Total Runtime"]
50 | print(tabulate(table, headers=headers, tablefmt="grid"))
--------------------------------------------------------------------------------
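
A minimal sketch of how the timing helpers above are meant to be used (illustrative; `slow_step` is just a stand-in function):

```python
import time
from overeasy.logging import log_time, print_summary

@log_time
def slow_step():
    time.sleep(0.2)

for _ in range(3):
    slow_step()

# Prints a table with call counts, average time per call, and each function's
# share of the total recorded runtime.
print_summary()
```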
/overeasy/models/LLMs/__init__.py:
--------------------------------------------------------------------------------
1 | from .openai import GPT, GPTVision
2 | from .qwenvl import QwenVL
3 | from .gemini import Gemini
4 | from .anthropic import Claude
5 | from .pali_gemma import PaliGemma
--------------------------------------------------------------------------------
/overeasy/models/LLMs/anthropic.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from overeasy.types import MultimodalLLM, OCRModel, Model
3 | from typing import Optional
4 | import anthropic
5 | import base64
6 | import io
7 | import os
8 |
9 | class Claude(MultimodalLLM, OCRModel):
10 | models = ["claude-3-5-sonnet-20240620", "claude-3-haiku-20240307", "claude-3-sonnet-20240229", "claude-3-opus-20240229"]
11 | def __init__(self, model: str = 'claude-3-5-sonnet-20240620', api_key: Optional[str] = None):
12 | self.api_key = api_key if api_key is not None else os.getenv("ANTHROPIC_API_KEY")
13 | # Support for shortened model names
14 | self.model = model
15 | self.client = None
16 |
17 | def load_resources(self):
18 | if self.api_key is None:
19 | raise ValueError("No API key found. Please provide an API key, or set the ANTHROPIC_API_KEY environment variable.")
20 | self.client = anthropic.Anthropic(api_key=self.api_key)
21 |
22 | super().load_resources()
23 |
24 | def release_resources(self):
25 | self.client = None
26 | super().release_resources()
27 |
28 | def prompt_with_image(self, image: Image.Image, query: str) -> str:
29 | if self.client is None:
30 | raise ValueError("Anthropic client not loaded. Please call load_resources() first.")
31 | buffered = io.BytesIO()
32 | image.save(buffered, format="JPEG")
33 | image_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
34 |
35 | message = self.client.messages.create(
36 | model=self.model,
37 | max_tokens=1024,
38 | messages=[
39 | {
40 | "role": "user",
41 | "content": [
42 | {
43 | "type": "image",
44 | "source": {
45 | "type": "base64",
46 | "media_type": "image/jpeg",
47 | "data": image_data,
48 | },
49 | },
50 | {
51 | "type": "text",
52 | "text": query
53 | }
54 | ],
55 | }
56 | ],
57 | )
58 | return message.content[0].text # type: ignore
59 |
60 | def prompt(self, query: str) -> str:
61 | if self.client is None:
62 | raise ValueError("Anthropic client not loaded. Please call load_resources() first.")
63 | message = self.client.messages.create(
64 | model=self.model,
65 | max_tokens=1024,
66 | messages=[
67 | {"role": "user", "content": query}
68 | ]
69 | )
70 | return message.content[0].text # type: ignore
71 |
72 | def parse_text(self, image: Image.Image) -> str:
73 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.")
74 |
75 |
--------------------------------------------------------------------------------
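
A hedged usage sketch for the wrapper above (requires a valid ANTHROPIC_API_KEY; the prompt and image path are just examples taken from this repository's examples/ folder):

```python
from PIL import Image
from overeasy.models.LLMs.anthropic import Claude

claude = Claude()                 # falls back to the ANTHROPIC_API_KEY env var
claude.load_resources()
try:
    image = Image.open("examples/drivers_license.jpg").convert("RGB")
    print(claude.prompt_with_image(image, "What kind of document is this?"))
    print(claude.parse_text(image))   # OCR-style helper built on the same endpoint
finally:
    claude.release_resources()
```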
/overeasy/models/LLMs/gemini.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from overeasy.types import MultimodalLLM, OCRModel, Model
4 | from typing import Optional
5 | import google.generativeai as genai
6 | import warnings
7 |
8 | models = [
9 | "gemini-1.5-flash",
10 | "gemini-1.5-pro",
11 | "gemini-pro-vision",
12 | ]
13 |
14 | class Gemini(MultimodalLLM, OCRModel):
15 | def __init__(self, api_key: Optional[str] = None, model: str = "gemini-1.5-pro"):
16 | if model not in models:
17 | warnings.warn(f"Model {model} not supported for Gemini. Please use one of the following: {models}")
18 | self.api_key = api_key if api_key is not None else os.getenv("GOOGLE_API_KEY")
19 | self.model_name = model
20 | self.client = None
21 |
22 | def load_resources(self):
23 | if self.api_key is None:
24 | warnings.warn("No API key found. Using Gemini free tier. For higher usage please provide an API key, or set the GOOGLE_API_KEY environment variable.")
25 | else:
26 | genai.configure(api_key=self.api_key)
27 | self.client = genai.GenerativeModel(model_name=self.model_name)
28 |
29 |
30 | def release_resources(self):
31 | self.client = None
32 |
33 | def prompt_with_image(self, image: Image.Image, query: str) -> str:
34 | if self.client is None:
35 | raise ValueError("Client is not loaded. Please call load_resources() first.")
36 | response = self.client.generate_content([query, image], stream=True)
37 | response.resolve()
38 | return response.text
39 |
40 |
41 | def prompt(self, query: str) -> str:
42 | if self.client is None:
43 | raise ValueError("Client is not loaded. Please call load_resources() first.")
44 | response = self.client.generate_content([query])
45 | response.resolve()
46 | return response.text
47 |
48 |
49 | def parse_text(self, image: Image.Image) -> str:
50 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.")
--------------------------------------------------------------------------------
/overeasy/models/LLMs/openai.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from overeasy.types import MultimodalLLM, LLM, OCRModel, Model
4 | from typing import Literal, Optional
5 | import io
6 | import base64
7 | import requests
8 | import warnings
9 | import backoff
10 | import openai
11 |
12 | def encode_image_to_base64(image: Image.Image) -> str:
13 | buffered = io.BytesIO()
14 | image.save(buffered, format="PNG")
15 | return base64.b64encode(buffered.getvalue()).decode('utf-8')
16 |
17 | current_models = [
18 |     "gpt-4o-mini",
19 |     "gpt-4o",
20 |     "gpt-4o-mini-2024-07-18",
21 |     "gpt-4o-2024-05-13",
22 | "gpt-4-turbo",
23 | "gpt-4-turbo-2024-04-09",
24 | "gpt-4-turbo-preview",
25 | "gpt-4-0125-preview",
26 | "gpt-4-1106-preview",
27 | "gpt-4-vision-preview",
28 | "gpt-4-1106-vision-preview",
29 | "gpt-4",
30 | "gpt-4-0613",
31 | "gpt-4-32k",
32 | "gpt-4-32k-0613",
33 | "gpt-3.5-turbo-0125",
34 | "gpt-3.5-turbo",
35 | "gpt-3.5-turbo-1106",
36 | "gpt-3.5-turbo-instruct",
37 | "gpt-3.5-turbo-16k",
38 | "gpt-3.5-turbo-0613",
39 | "gpt-3.5-turbo-16k-0613"
40 | ]
41 |
42 |
43 | @backoff.on_exception(backoff.expo, openai.RateLimitError, max_tries=6)
44 | def _prompt(query: str, model: str, client: openai.OpenAI, temperature: float, max_tokens: int) -> str:
45 | response = client.chat.completions.create(
46 | model=model,
47 | messages=[
48 | {"role": "user", "content": query}
49 | ],
50 | max_tokens=max_tokens,
51 | temperature=temperature
52 | )
53 |
54 | result = response.choices[0].message.content
55 | if result is None:
56 | raise ValueError("No content found in response")
57 |
58 | return result
59 |
60 |
61 |
62 | class GPT(LLM):
63 | def __init__(self, api_key: Optional[str] = None,
64 | model:str = "gpt-3.5-turbo",
65 | temperature: float = 0.5,
66 | max_tokens: int = 1024):
67 | self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY")
68 | self.model = model
69 | self.temperature = temperature
70 | self.max_tokens = max_tokens
71 | self.client = None
72 | if self.model not in current_models:
73 | warnings.warn(f"Model {model} may not be supported. Please provide a valid model.")
74 |
75 | def prompt(self, query: str) -> str:
76 | if self.client is None:
77 | raise ValueError("Client is not loaded. Please call load_resources() first.")
78 | return _prompt(query, self.model, self.client, self.temperature, self.max_tokens)
79 |
80 | def load_resources(self):
81 | self.client = openai.OpenAI(api_key=self.api_key)
82 |
83 | def release_resources(self):
84 | self.client = None
85 |
86 | class GPTVision(MultimodalLLM, OCRModel, GPT):
87 | def __init__(self, api_key: Optional[str] = None,
88 | model : Literal["gpt-4o", "gpt-4o-mini", "gpt-4o-2024-05-13", "gpt-4-turbo", "gpt-4-turbo-2024-04-09"] = "gpt-4o",
89 | temperature: float = 0.5,
90 | max_tokens: int = 1024
91 | ):
92 | self.api_key = api_key if api_key is not None else os.getenv("OPENAI_API_KEY")
93 | self.model = model
94 | self.temperature = temperature
95 | self.max_tokens = max_tokens
96 | self.client = None
97 |
98 | def prompt_with_image(self, image: Image.Image, query: str) -> str:
99 | if self.client is None:
100 | raise ValueError("Client is not loaded. Please call load_resources() first.")
101 |
102 | base64_image = encode_image_to_base64(image)
103 |
104 | messages = [
105 | {
106 | "role": "user",
107 | "content": [
108 | {
109 | "type": "text",
110 | "text": query
111 | },
112 | {
113 | "type": "image_url",
114 | "image_url": {
115 | "url": f"data:image/png;base64,{base64_image}"
116 | }
117 | }
118 | ]
119 | }
120 | ]
121 |
122 |
123 | response = self.client.chat.completions.create(
124 | model=self.model,
125 | messages=messages,
126 | max_tokens=self.max_tokens,
127 | temperature=self.temperature
128 | )
129 |
130 | result = response.choices[0].message.content
131 | if result is None:
132 | raise ValueError("No content found in response")
133 |
134 | return result
135 |
136 | def parse_text(self, image: Image.Image) -> str:
137 | return self.prompt_with_image(image, "Read the text from the image line by line only output the text.")
138 |
139 |
--------------------------------------------------------------------------------
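
A brief, illustrative sketch of the text-only `GPT` wrapper above (assumes OPENAI_API_KEY is set in the environment; rate-limit retries are handled by the backoff decorator on `_prompt`):

```python
from overeasy.models.LLMs.openai import GPT

llm = GPT(model="gpt-3.5-turbo", temperature=0.0)
llm.load_resources()
try:
    print(llm.prompt("Reply with the single word 'ready'."))
finally:
    llm.release_resources()
```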
/overeasy/models/LLMs/pali_gemma.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from overeasy.types import MultimodalLLM
3 | import torch
4 | from typing import Optional
5 |
6 | # Make sure to set your HuggingFace token to use PaliGemma
7 | class PaliGemma(MultimodalLLM):
8 | SIZES = [224, 448] # Image sizes the model was trained on
9 |
10 | MODEL_OPTIONS = (
11 | [f"google/paligemma-3b-mix-{size}" for size in SIZES] + # Should use mix for most cases especially when doing object detection
12 | [f"google/paligemma-3b-ft-vqav2-{size}" for size in SIZES] + # Diagram Understanding - 85.64 Accuracy on VQAV2
13 | [f"google/paligemma-3b-ft-cococap-{size}" for size in SIZES] + # COCO Captions - 144.6 CIDEr
14 | [f"google/paligemma-3b-ft-science-qa-{size}" for size in SIZES] + # Science Question Answering - 95.93 Accuracy on ScienceQA Img subset with no CoT
15 | [f"google/paligemma-3b-ft-refcoco-seg-{size}" for size in SIZES] + # Understanding References to Specific Objects in Images - 76.94 Mean IoU on refcoco, 72.18 Mean IoU on refcoco+, 72.22 Mean IoU on refcocog
16 | [f"google/paligemma-3b-ft-rsvqa-hr-{size}" for size in SIZES] # Remote Sensing Visual Question Answering - 92.61 Accuracy on test, 90.58 Accuracy on test2
17 | )
18 |
19 | def __init__(self, model_card: str = "google/paligemma-3b-mix-448", device: Optional[str] = None):
20 | self.model_card = model_card
21 | if self.model_card not in self.MODEL_OPTIONS:
22 | raise ValueError(f"Model {self.model_card} not found. Please select a valid model from {self.MODEL_OPTIONS}.")
23 |
24 | if device is None:
25 | device = "cuda" if torch.cuda.is_available() else "cpu"
26 | self.device = device
27 | self.model = None
28 | self.processor = None
29 |
30 | def load_resources(self):
31 | from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
32 | self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_card).to(self.device)
33 | self.processor = AutoProcessor.from_pretrained(self.model_card)
34 |
35 | def release_resources(self):
36 | self.model = None
37 | self.processor = None
38 |
39 | def prompt_with_image(self, image: Image.Image, query: str) -> str:
40 | if self.model is None or self.processor is None:
41 | raise ValueError("Model is not loaded. Please call load_resources() first.")
42 |
43 | inputs = self.processor(query, image, return_tensors="pt").to(self.device)
44 | output = self.model.generate(**inputs, max_new_tokens=100)
45 |
46 | return self.processor.decode(output[0], skip_special_tokens=True)[len(query):]
47 |
48 | def prompt(self, query: str) -> str:
49 | if self.model is None or self.processor is None:
50 | raise ValueError("Model is not loaded. Please call load_resources() first.")
51 |
52 | inputs = self.processor(query, None, return_tensors="pt").to(self.device)
53 | output = self.model.generate(**inputs, max_new_tokens=100)
54 |
55 | return self.processor.decode(output[0], skip_special_tokens=True)[len(query):]
--------------------------------------------------------------------------------
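
A rough local-inference sketch for the PaliGemma wrapper above (illustrative; it assumes the PaliGemma license has been accepted on Hugging Face and a token is configured, and the image path comes from this repository's examples/ folder):

```python
from PIL import Image
from overeasy.models.LLMs.pali_gemma import PaliGemma

model = PaliGemma(model_card="google/paligemma-3b-mix-224")
model.load_resources()            # downloads the checkpoint and moves it to CUDA if available

image = Image.open("examples/kites.jpg").convert("RGB")
print(model.prompt_with_image(image, "How many kites are in the image?"))
model.release_resources()
```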
/overeasy/models/LLMs/qwenvl.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoModelForCausalLM, AutoTokenizer
2 | from PIL import Image
3 | import tempfile
4 | from overeasy.types import MultimodalLLM
5 | from typing import Literal
6 | import importlib.util
7 | import torch
8 |
9 | from overeasy.types.base import OCRModel
10 |
11 | # use bf16
12 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, bf16=True).eval()
13 | # use fp16
14 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="auto", trust_remote_code=True, fp16=True).eval()
15 | # use cpu only
16 | # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cpu", trust_remote_code=True).eval()
17 | # use cuda device
18 |
19 | def setup_autogptq():
20 | import subprocess
21 | import torch
22 |
23 | # Check if CUDA is available
24 | if not torch.cuda.is_available():
25 | print("CUDA is not available. Exiting installation.")
26 | return
27 |
28 | # Get CUDA version from PyTorch
29 | cuda_version = torch.version.cuda
30 | if cuda_version:
31 | if cuda_version.startswith("11.8"):
32 | subprocess.run(["pip", "install", "optimum"], check=True)
33 | subprocess.run([
34 | "pip", "install", "auto-gptq", "--no-build-isolation",
35 | "--extra-index-url", "https://huggingface.github.io/autogptq-index/whl/cu118/"
36 | ], check=True)
37 | elif cuda_version.startswith("12.1"):
38 | subprocess.run(["pip", "install", "optimum"], check=True)
39 | subprocess.run([
40 | "pip", "install", "auto-gptq", "--no-build-isolation"
41 | ], check=True)
42 | else:
43 | print(f"Unsupported CUDA version: {cuda_version}")
44 | else:
45 | print("CUDA version could not be determined.")
46 |
47 |
48 | model_TYPE = Literal["base", "int4", "fp16", "bf16"]
49 | def load_model(model_type: model_TYPE):
50 | if model_type == "base":
51 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True).eval()
52 | elif model_type == "fp16":
53 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True, fp16=True).eval()
54 | elif model_type == "bf16":
55 | model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-VL-Chat", device_map="cuda", trust_remote_code=True, bf16=True).eval()
56 | elif model_type == "int4":
57 | def is_autogptq_installed():
58 | package_name = 'auto_gptq'
59 | spec = importlib.util.find_spec(package_name)
60 | return spec is not None
61 |
62 | if not is_autogptq_installed():
63 | setup_autogptq()
64 |
65 | if not is_autogptq_installed():
66 |             raise Exception("AutoGPTQ is not installed; can't use int4 quantization")
67 |
68 | model = AutoModelForCausalLM.from_pretrained(
69 | "Qwen/Qwen-VL-Chat-Int4",
70 | device_map="cuda",
71 | trust_remote_code=True
72 | ).eval()
73 | else:
74 | raise Exception("Model type not supported")
75 |
76 | # model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
77 |
78 | return model
79 |
80 | class QwenVL(MultimodalLLM, OCRModel):
81 |
82 | def __init__(self, model_type: model_TYPE = "bf16"):
83 | if not torch.cuda.is_available():
84 | raise Exception("CUDA not available. Can't use QwenVL")
85 | if model_type not in ["base", "int4", "fp16", "bf16"]:
86 | raise Exception("Model type not supported")
87 | self.model_type = model_type
88 |
89 | def load_resources(self):
90 | self.model = load_model(self.model_type)
91 |
92 | if self.model_type == "int4":
93 | self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
94 | else:
95 | self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat", trust_remote_code=True)
96 |
97 | def release_resources(self):
98 | self.model = None
99 | self.tokenizer = None
100 |
101 | def prompt_with_image(self, image : Image.Image, query: str) -> str:
102 |
103 | with tempfile.NamedTemporaryFile(suffix=".png") as temp_file:
104 | image.save(temp_file.name)
105 | query = self.tokenizer.from_list_format([
106 | {'image': temp_file.name},
107 | {'text': query},
108 | ])
109 | response, history = self.model.chat(self.tokenizer, query=query, history=None, max_new_tokens=2048)
110 | return response
111 |
112 | def prompt(self, query: str) -> str:
113 | query = self.tokenizer.from_list_format([
114 | {'text': query},
115 | ])
116 | response, history = self.model.chat(self.tokenizer, query=query, history=None, max_new_tokens=2048)
117 | return response
118 |
119 | def parse_text(self, image: Image.Image):
120 | return self.prompt_with_image(image, "Read the text in this image line by line")
121 |
122 | def draw_bbox_on_latest_picture(self, response, history):
123 | if response.startswith("[") and response.endswith("]"):
124 | response = response[5:-6]
125 |         if "<box>" in response and "</box>" in response:
126 |             box = response[response.index("<box>") + 5:response.index("</box>")]
127 | box = box.split(",")
128 | box = [int(x) for x in box]
129 | image = Image.open(history[-1]["image"])
130 | return image.crop(box)
131 | return None
--------------------------------------------------------------------------------
/overeasy/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Lift imports from subdirs
2 | from .detection import *
3 | from .recognition import *
4 | from .LLMs import *
5 | from .classification import *
6 |
7 |
8 | def warmup_models():
9 | try:
10 | qwen = QwenVL()
11 | qwen.load_resources()
12 | del qwen
13 |     except Exception:
14 | print("Skipping QwenVL")
15 |
16 | try:
17 | detic = DETIC()
18 | detic.load_resources()
19 | detic.set_classes(["hi"])
20 | del detic
21 |     except Exception:
22 | print("Skipping DETIC")
23 |
24 | bounding_box_models = [
25 | GroundingDINO(type=GroundingDINOModel.Pretrain_1_8M),
26 | GroundingDINO(type=GroundingDINOModel.SwinB),
27 | GroundingDINO(type=GroundingDINOModel.SwinT),
28 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinB),
29 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinL),
30 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinB_zero),
31 | GroundingDINO(type=GroundingDINOModel.mmdet_SwinL_zero),
32 | YOLOWorld(model="yolov8s-worldv2"),
33 | YOLOWorld(model="yolov8m-worldv2"),
34 | YOLOWorld(model="yolov8l-worldv2"),
35 | YOLOWorld(model="yolov8s-world"),
36 | YOLOWorld(model="yolov8m-world"),
37 | YOLOWorld(model="yolov8l-world"),
38 | OwlV2(),
39 | ]
40 |
41 | for bounding_box_model in bounding_box_models:
42 | bounding_box_model.load_resources()
43 | del bounding_box_model
44 |
45 |
46 |
47 | clip = CLIP()
48 | del clip
49 | laionclip = LaionCLIP()
50 | del laionclip
51 | bio = BiomedCLIP()
52 | del bio
53 |
54 | pali_gemma = PaliGemma()
55 | pali_gemma.load_resources()
56 | del pali_gemma
--------------------------------------------------------------------------------
/overeasy/models/classification/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import CLIP
2 | from .laion_clip import LaionCLIP, BiomedCLIP, OpenCLIPBase
3 |
--------------------------------------------------------------------------------
/overeasy/models/classification/clip.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 | from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
3 | from PIL import Image
4 | import numpy as np
5 | from overeasy.types import Detections, DetectionType, ClassificationModel
6 | import open_clip
7 | import torch
8 |
9 | class CLIP(ClassificationModel):
10 | def __init__(self, model_card: str = "openai/clip-vit-large-patch14"):
11 | self.model_card = model_card
12 |
13 | def load_resources(self):
14 | self.processor = AutoProcessor.from_pretrained(self.model_card)
15 | self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_card)
16 | self.model.eval()
17 |
18 | def release_resources(self):
19 | self.model = None
20 | self.processor = None
21 |
22 | def classify(self, image: Image.Image, classes: list) -> Detections:
23 | inputs = self.processor(text=classes, images=image, return_tensors="pt", padding=True)
24 | outputs = self.model(**inputs).logits_per_image
25 | softmax_outputs = torch.nn.functional.softmax(outputs, dim=-1).detach().numpy()
26 | index = softmax_outputs.argmax()
27 | return Detections(
28 | xyxy=np.zeros((1, 4)),
29 | class_ids=np.array([index]),
30 | confidence= np.array([softmax_outputs[0, index]]),
31 | classes=np.array(classes),
32 | detection_type=DetectionType.CLASSIFICATION
33 | )
34 |
--------------------------------------------------------------------------------
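
A small sketch of the zero-shot classification interface above (illustrative; the test image path comes from this repository's tests/ folder):

```python
from PIL import Image
from overeasy.models.classification.clip import CLIP

clip = CLIP()
clip.load_resources()

image = Image.open("tests/dogs/dog1.png").convert("RGB")
detections = clip.classify(image, ["a dog", "a cat", "a bird"])

# `classes` holds the candidate labels, `class_ids` the argmax index,
# and `confidence` the softmax probability of that label.
print(detections.classes[detections.class_ids[0]], detections.confidence[0])
clip.release_resources()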
/overeasy/models/classification/laion_clip.py:
--------------------------------------------------------------------------------
1 | import open_clip
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from overeasy.types import Detections, DetectionType, ClassificationModel
6 | from typing import List, Optional
7 |
8 | class OpenCLIPBase(ClassificationModel):
9 | def __init__(self, model_card):
10 | self.model_card = model_card
11 | self.device = "cpu" # TODO: Work out the cuda impl
12 |
13 | def load_resources(self):
14 | model, _, preprocess_val = open_clip.create_model_and_transforms(self.model_card)
15 | self.tokenizer = open_clip.get_tokenizer(self.model_card)
16 | self.model = model
17 | self.model.to(self.device)
18 | self.preprocess = preprocess_val
19 |
20 | def release_resources(self):
21 | self.model = None
22 | self.tokenizer = None
23 | self.preprocess = None
24 |
25 | def classify(self, image: Image.Image, classes: List[str]) -> Detections:
26 | image = self.preprocess(image).to(self.device).unsqueeze(0)
27 | text = self.tokenizer(classes)
28 |
29 | with torch.no_grad():
30 | image_features = self.model.encode_image(image)
31 | text_features = self.model.encode_text(text)
32 | image_features /= image_features.norm(dim=-1, keepdim=True)
33 | text_features /= text_features.norm(dim=-1, keepdim=True)
34 |
35 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1).cpu()
36 | index = text_probs.argmax()
37 | confidence = text_probs[0, index]
38 | return Detections(
39 | xyxy=np.zeros((1, 4)),
40 | class_ids=np.array([index]),
41 | confidence=np.array([confidence]),
42 | classes=np.array(classes),
43 | detection_type=DetectionType.CLASSIFICATION
44 | )
45 |
46 | def batch_classify(self, images: List[Image.Image], classes: List[str], top_k: int = 1) -> List[Detections]:
47 | preprocessed_images = [self.preprocess(image).unsqueeze(0) for image in images]
48 | image_input = torch.cat(preprocessed_images)
49 | text_tokens = self.tokenizer(classes)
50 |
51 | with torch.no_grad():
52 | image_features = self.model.encode_image(image_input).float()
53 | text_features = self.model.encode_text(text_tokens).float()
54 | image_features /= image_features.norm(dim=-1, keepdim=True)
55 | text_features /= text_features.norm(dim=-1, keepdim=True)
56 |
57 | text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
58 | top_probs, top_labels = text_probs.cpu().topk(top_k, dim=-1)
59 |
60 | detections: List[Detections] = []
61 | for i in range(len(images)):
62 | detections.append(Detections(
63 | xyxy=np.zeros((top_k, 4)),
64 | class_ids=top_labels[i].numpy(),
65 | confidence=top_probs[i].numpy(),
66 | classes=np.array(classes)[top_labels[i].numpy()],
67 | detection_type=DetectionType.CLASSIFICATION
68 | ))
69 | return detections
70 |
71 | models = ["hf-hub:laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "hf-hub:laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"]
72 | class LaionCLIP(OpenCLIPBase):
73 | def __init__(self, model_name: str = models[1]):
74 | if model_name not in models:
75 | raise ValueError(f"Model {model_name} not found")
76 | super().__init__(model_name)
77 |
78 | class BiomedCLIP(OpenCLIPBase):
79 | def __init__(self):
80 | super().__init__('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
--------------------------------------------------------------------------------
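
The OpenCLIP-based wrappers above also support batched top-k classification; a rough sketch (the default LAION checkpoint is large, so treat this as illustrative rather than a quick test):

```python
from PIL import Image
from overeasy.models.classification.laion_clip import LaionCLIP

model = LaionCLIP()
model.load_resources()

images = [Image.open(f"tests/dogs/dog{i}.png").convert("RGB") for i in (1, 2)]
results = model.batch_classify(images, ["a dog", "a cat", "a bird"], top_k=2)

for det in results:
    # Each Detections holds the top-2 labels and their probabilities for one image
    print(det.classes, det.confidence)
model.release_resources()
```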
/overeasy/models/classification/siglip.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import torch
3 | import numpy as np
4 | from typing import List
5 | from overeasy.types import Detections, DetectionType, ClassificationModel
6 | from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
7 | 
8 | class SigLIP(ClassificationModel):
9 |     def __init__(self, model_card="google/siglip-base-patch16-224"):
10 |         self.model_card = model_card
11 |         self.processor = None
12 |         self.model = None
13 | 
14 |     def load_resources(self):
15 |         self.processor = AutoProcessor.from_pretrained(self.model_card)
16 |         self.model = AutoModelForZeroShotImageClassification.from_pretrained(self.model_card)
17 | 
18 |     def release_resources(self):
19 |         self.processor = None
20 |         self.model = None
21 | 
22 |     def classify(self, image: Image.Image, classes: List[str]) -> Detections:
23 |         if self.model is None or self.processor is None:
24 |             raise ValueError("Model is not loaded. Please call load_resources() first.")
25 |         inputs = self.processor(text=classes, images=image, padding="max_length", return_tensors="pt")
26 |         with torch.no_grad():
27 |             logits = self.model(**inputs).logits_per_image
28 |         # SigLIP scores each class independently with a sigmoid rather than a softmax
29 |         probs = torch.sigmoid(logits).numpy()
30 |         index = probs.argmax()
31 |         return Detections(
32 |             xyxy=np.zeros((1, 4)),
33 |             class_ids=np.array([index]),
34 |             confidence=np.array([probs[0, index]]),
35 |             classes=np.array(classes),
36 |             detection_type=DetectionType.CLASSIFICATION
37 |         )
--------------------------------------------------------------------------------
/overeasy/models/detection/__init__.py:
--------------------------------------------------------------------------------
1 | from .dino import GroundingDINO, GroundingDINOModel
2 | from .detic import DETIC
3 | from .yoloworld import YOLOWorld
4 | from .owlv2 import OwlV2
5 |
6 |
--------------------------------------------------------------------------------
/overeasy/models/detection/detclipv2.py:
--------------------------------------------------------------------------------
1 | raise NotImplementedError("Detclipv2 is not yet implemented")
--------------------------------------------------------------------------------
/overeasy/models/detection/detic.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing as mp
3 | import os
4 | import subprocess
5 | import sys
6 | from typing import Any, Union
7 | import numpy as np
8 | import torch
9 | from overeasy.types import Detections, DetectionType
10 | from typing import List
11 | from PIL import Image
12 | import cv2
13 | from overeasy.types import BoundingBoxModel
14 | import subprocess
15 | from overeasy.download_utils import atomic_retrieve_and_rename
16 |
17 | VOCAB = "custom"
18 | CONFIDENCE_THRESHOLD = 0.3
19 | OVEREASY_DIR = os.path.abspath(os.path.expanduser("~/.overeasy"))
20 |
21 |
22 | def setup_cfg(args):
23 | from centernet.config import add_centernet_config
24 | from detic.config import add_detic_config
25 | from detectron2.config import get_cfg
26 |
27 | cfg = get_cfg()
28 | cfg.MODEL.DEVICE = "cpu" if args.cpu else "cuda"
29 | add_centernet_config(cfg)
30 | add_detic_config(cfg)
31 | cfg.merge_from_file(args.config_file)
32 | cfg.merge_from_list(args.opts)
33 | # Set score_threshold for builtin models
34 | cfg.MODEL.RETINANET.SCORE_THRESH_TEST = CONFIDENCE_THRESHOLD
35 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = CONFIDENCE_THRESHOLD
36 | cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = CONFIDENCE_THRESHOLD
37 | cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" # load later
38 | if not args.pred_all_class:
39 | cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = False
40 | cfg.freeze()
41 | return cfg
42 |
43 |
44 | def load_detic_model(classes : List[str], weights_file: str):
45 | original_dir = os.getcwd()
46 | try:
47 | sys.path.insert(0, os.path.join(OVEREASY_DIR, "Detic/third_party/CenterNet2/"))
48 | sys.path.insert(0, os.path.join(OVEREASY_DIR, "Detic/"))
49 | os.chdir(os.path.join(OVEREASY_DIR, "Detic/"))
50 |
51 | mp.set_start_method("spawn", force=True)
52 |
53 | args = argparse.Namespace()
54 |
55 | args.confidence_threshold = CONFIDENCE_THRESHOLD
56 | args.vocabulary = VOCAB
57 | args.opts = []
58 | args.config_file = "configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml"
59 | args.cpu = False if torch.cuda.is_available() else True
60 | args.opts.append("MODEL.WEIGHTS")
61 | args.opts.append(weights_file)
62 | args.output = None
63 | args.webcam = None
64 | args.video_input = None
65 | args.custom_vocabulary = ",".join(classes).rstrip(",")
66 | args.pred_all_class = False
67 | cfg = setup_cfg(args)
68 |
69 | from detic.predictor import VisualizationDemo
70 | from detectron2.data import MetadataCatalog
71 | for key in MetadataCatalog.list():
72 | MetadataCatalog.remove(key)
73 |
74 | # https://github.com/facebookresearch/Detic/blob/main/detic/predictor.py#L39
75 | demo = VisualizationDemo(cfg, args)
76 | return demo
77 | finally:
78 | sys.path.remove(os.path.join(OVEREASY_DIR, "Detic/third_party/CenterNet2/"))
79 | sys.path.remove(os.path.join(OVEREASY_DIR, "Detic/"))
80 | os.chdir(original_dir)
81 |
82 |
83 |
84 | HOME = os.path.expanduser("~")
85 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
86 |
87 | def check_dependencies():
88 |     # Create the ~/.overeasy directory if it doesn't exist
89 | os.makedirs(OVEREASY_DIR, exist_ok=True)
90 |
91 | detic_path = os.path.join(OVEREASY_DIR, "Detic")
92 | models_dir = os.path.join(detic_path, "models")
93 | model_path = os.path.join(
94 | models_dir, "Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth"
95 | )
96 |
97 | if subprocess.call(["which", "git"], stdout=subprocess.PIPE, stderr=subprocess.PIPE) != 0:
98 | raise EnvironmentError("git is not installed. Please install git to use Detic.")
99 |
100 | # Check if Detic is installed
101 | if not os.path.isdir(detic_path):
102 | subprocess.run(
103 | [
104 | "git",
105 | "clone",
106 | "https://github.com/facebookresearch/Detic.git",
107 | "--recurse-submodules",
108 | detic_path
109 | ],
110 | check=True
111 | )
112 |
113 | import platform
114 | install_cmds = [sys.executable, "-m", "pip", "install", "git+https://github.com/facebookresearch/detectron2.git", "--no-build-isolation"]
115 | if platform.system() == "Darwin":
116 |         subprocess.run(
117 |             install_cmds,
118 |             check=True, env={**os.environ, "MACOSX_DEPLOYMENT_TARGET": "10.13"}
119 |         )
120 | else:
121 | subprocess.run(
122 | install_cmds,
123 | check=True
124 | )
125 |
126 | os.makedirs(models_dir, exist_ok=True)
127 |
128 | model_url = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth"
129 |     if not os.path.exists(model_path):
130 |         atomic_retrieve_and_rename(model_url, model_path)
131 |
132 | return os.path.abspath(model_path)
133 |
134 |
135 |
136 |
137 |
138 | class DETIC(BoundingBoxModel):
139 | def __init__(self):
140 | self.classes = None
141 |
142 | def load_resources(self):
143 | self.weights_file = check_dependencies()
144 |
145 | def release_resources(self):
146 | self.detic_model = None
147 |
148 | def set_classes(self, classes: List[str]):
149 | self.classes = classes
150 | self.detic_model = load_detic_model(classes, self.weights_file)
151 |
152 | def detect(self, image: Union[np.ndarray, Image.Image], classes: List[str], box_threshold=0.35, text_threshold=0.25) -> Detections:
153 | if self.classes is None:
154 | self.set_classes(classes)
155 | elif not np.array_equal(self.classes, classes):
156 | self.set_classes(classes)
157 |
158 | if isinstance(image, Image.Image):
159 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
160 |
161 | predictions, visualized_output = self.detic_model.run_on_image(np.array(image))
162 | pred_boxes = predictions["instances"].pred_boxes.tensor.cpu().numpy()
163 | pred_classes = predictions["instances"].pred_classes.cpu().numpy()
164 | pred_scores = predictions["instances"].scores.cpu().numpy()
165 |
166 |
167 | if len(pred_classes) == 0:
168 | return Detections.empty()
169 |
170 | return Detections(
171 | xyxy=np.array(pred_boxes),
172 | detection_type=DetectionType.BOUNDING_BOX,
173 | class_ids=np.array(pred_classes),
174 | confidence=np.array(pred_scores),
175 | classes=self.classes,
176 | )
177 |
178 |
--------------------------------------------------------------------------------
/overeasy/models/detection/dino.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from typing import List, Union
8 | from overeasy.types import Detections, DetectionType
9 | from enum import Enum
10 | from overeasy.types import BoundingBoxModel
11 | import warnings
12 | import sys, io
13 | from overeasy.download_utils import atomic_retrieve_and_rename
14 | from typing import Any
15 |
16 | # Ignore the specific UserWarning about torch.meshgrid
17 | warnings.filterwarnings("ignore", message="torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument.", category=UserWarning, module='torch.functional')
18 | warnings.filterwarnings("ignore", category=FutureWarning, message="The `device` argument is deprecated and will be removed in v5 of Transformers.")
19 | # Suppress UserWarning about use_reentrant parameter in torch.utils.checkpoint
20 | warnings.filterwarnings("ignore", message="torch.utils.checkpoint: the use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. use_reentrant=False is recommended, but if you need to preserve the current default behavior, you can pass use_reentrant=True. Refer to docs for more details on the differences between the two variants.", category=UserWarning, module='torch.utils.checkpoint')
21 | # Suppress UserWarning about None of the inputs having requires_grad=True
22 | warnings.filterwarnings("ignore", message="None of the inputs have requires_grad=True. Gradients will be None", category=UserWarning, module='torch.utils.checkpoint')
23 |
24 |
25 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26 |
27 | class GroundingDINOModel(Enum):
28 | SwinB = "swinb"
29 | SwinT = "swint"
30 |
31 | mapping = {
32 | "swinb": {
33 | "config": "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinB_cfg.py",
34 | "checkpoint": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha2/groundingdino_swinb_cogcoor.pth"
35 | },
36 | "swint": {
37 | "config": "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py",
38 | "checkpoint": "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
39 | },
40 | "mmdet_swinb_zero": {
41 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det.py",
42 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_obj365_goldg_v3det/grounding_dino_swin-b_pretrain_obj365_goldg_v3de-f83eef00.pth"
43 | },
44 | "mmdet_swinl_zero": {
45 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py",
46 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg/grounding_dino_swin-l_pretrain_obj365_goldg-34dcdc53.pth"
47 | },
48 | "mmdet_swinb": {
49 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_obj365_goldg.py",
50 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-b_pretrain_all/grounding_dino_swin-b_pretrain_all-f9818a7c.pth"
51 | },
52 | "mmdet_swinl": {
53 | "config": "./mmdetection/configs/mm_grounding_dino/grounding_dino_swin-l_pretrain_all.py",
54 | "checkpoint": "https://download.openmmlab.com/mmdetection/v3.0/mm_grounding_dino/grounding_dino_swin-l_pretrain_all/grounding_dino_swin-l_pretrain_all-56d69e78.pth"
55 | }
56 |
57 | }
58 |
59 | def download_and_cache_grounding_dino(model: GroundingDINOModel):
60 | OVEREASY_CACHE_DIR = os.path.expanduser("~/.overeasy")
61 | GROUNDING_DINO_CACHE_DIR = os.path.join(OVEREASY_CACHE_DIR, "groundingdino")
62 |
63 | if not os.path.exists(GROUNDING_DINO_CACHE_DIR):
64 | os.makedirs(GROUNDING_DINO_CACHE_DIR)
65 |
66 | model_key = model.value
67 | if model_key not in mapping:
68 | raise ValueError(f"Unsupported model type: {model_key}")
69 |
70 | config_url = mapping[model_key]["config"]
71 | checkpoint_url = mapping[model_key]["checkpoint"]
72 |
73 | config_file = os.path.basename(config_url)
74 | checkpoint_file = os.path.basename(checkpoint_url)
75 |
76 | config_path = os.path.join(GROUNDING_DINO_CACHE_DIR, config_file)
77 | checkpoint_path = os.path.join(GROUNDING_DINO_CACHE_DIR, checkpoint_file)
78 |
79 | if not os.path.exists(checkpoint_path):
80 | atomic_retrieve_and_rename(checkpoint_url, checkpoint_path)
81 |
82 | if not os.path.exists(config_path):
83 | atomic_retrieve_and_rename(config_url, config_path)
84 |
85 | return config_path, checkpoint_path
86 |
87 |
88 |
89 | def load_grounding_dino(model: GroundingDINOModel):
90 | from groundingdino.util.inference import Model
91 |
92 | config_path, checkpoint_path = download_and_cache_grounding_dino(model)
93 |
94 | instantiate = lambda: Model(
95 | model_config_path=config_path,
96 | model_checkpoint_path=checkpoint_path,
97 | device=DEVICE,
98 | )
99 |
100 |     try:
101 |         grounding_dino_model = instantiate()
102 |         return grounding_dino_model
103 |     except Exception:
104 |         # Retry the instantiation once if the first attempt fails
105 |         grounding_dino_model = instantiate()
106 | 
107 |     return grounding_dino_model
108 | 
109 |
110 | def combine_detections(detections_list: List[Detections], classes: List[str], overwrite_class_ids=None):
111 | if len(detections_list) == 0:
112 | return Detections.empty()
113 |
114 | if not all(isinstance(detection, Detections) for detection in detections_list):
115 | raise TypeError("All elements in detections_list must be instances of Detections.")
116 |
117 | if overwrite_class_ids is not None and len(overwrite_class_ids) != len(detections_list):
118 | raise ValueError("Length of overwrite_class_ids must match the length of detections_list.")
119 |
120 | # Initialize lists to collect combined attributes
121 | xyxy = []
122 | confidence = []
123 | class_ids = []
124 | data = []
125 | masks = []
126 |
127 | detection_types = [detection.detection_type for detection in detections_list]
128 | if len(set(detection_types)) > 1:
129 | raise ValueError("All detections in the list must have the same type.")
130 |
131 |
132 | for idx, detection in enumerate(detections_list):
133 | xyxy.append(detection.xyxy)
134 |
135 | if detection.confidence is not None:
136 | confidence.append(detection.confidence)
137 |
138 | if detection.class_ids is not None:
139 | if overwrite_class_ids is not None:
140 | class_ids.append(np.full_like(detection.class_ids, overwrite_class_ids[idx], dtype=np.int32))
141 | else:
142 | class_ids.append(detection.class_ids)
143 | if detection.masks is not None:
144 | masks.append(detection.masks)
145 |
146 | # Merge custom data from each detection
147 | data.append(detection.data)
148 |
149 | return Detections(
150 | xyxy=np.vstack(xyxy),
151 |         masks=np.concatenate(masks, axis=0) if masks else None, # Stack per-detection masks along the first axis
152 | classes = classes,
153 | confidence=np.hstack(confidence) if confidence else None,
154 | class_ids=np.hstack(class_ids) if class_ids else None,
155 | detection_type=detections_list[0].detection_type
156 | )
157 |
158 | class GroundingDINO(BoundingBoxModel):
159 | grounding_dino_model: Any
160 | box_threshold: float
161 | text_threshold: float
162 |
163 | def __init__(
164 | self, type: GroundingDINOModel = GroundingDINOModel.SwinB,
165 | box_threshold: float = 0.35,
166 | text_threshold: float = 0.25,
167 | ):
168 | self.box_threshold = box_threshold
169 | self.text_threshold = text_threshold
170 | self.model_type = type
171 |
172 | def load_resources(self):
173 | # if DEVICE.type != "cuda":
174 | # warnings.warn("CUDA not available. GroundingDINO may run slowly.")
175 |
176 | download_and_cache_grounding_dino(self.model_type)
177 | original_stdout = sys.stdout
178 | sys.stdout = io.StringIO()
179 | try:
180 | self.grounding_dino_model = load_grounding_dino(model=self.model_type)
181 | except Exception as e:
182 | print(sys.stdout.getvalue())
183 | print(f"Error loading GroundingDINO model: {e}")
184 | raise e
185 | finally:
186 | sys.stdout = original_stdout
187 |
188 | def release_resources(self):
189 | self.grounding_dino_model = None
190 |
191 | # def detect_multiple(self, images: List[np.ndarray], class_groups: List[List[str]], box_threshold) -> List[Detections]:
192 | # cv2_images = [cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) if isinstance(image, Image.Image) else image for image in images]
193 |
194 | # raise ValueError("Unsupported model type")
195 |
196 | def detect(self, image: Union[np.ndarray, Image.Image], classes: List[str], box_threshold=None, text_threshold=None) -> Detections:
197 | if box_threshold is None:
198 | box_threshold = self.box_threshold
199 | if text_threshold is None:
200 | text_threshold = self.text_threshold
201 |
202 |
203 | if isinstance(image, Image.Image):
204 | image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
205 |
206 | sv_detection = self.grounding_dino_model.predict_with_classes(
207 | image=image,
208 | classes=classes,
209 | box_threshold=box_threshold,
210 | text_threshold=text_threshold,
211 | )
212 | valid_indexes = [i for i, class_id in enumerate(sv_detection.class_id) if class_id is not None]
213 | sv_detection = sv_detection[valid_indexes]
214 | detections = Detections.from_supervision_detection(sv_detection, classes=classes)
215 |
216 | return detections
--------------------------------------------------------------------------------
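
The `combine_detections` helper above stacks several `Detections` objects into one, optionally relabeling each input's class IDs. A minimal usage sketch, assuming the `GroundingDINO` and `combine_detections` names from this file are importable; the image path and prompts are illustrative:

```python
from PIL import Image

# Run GroundingDINO once per text prompt, then merge the per-prompt results.
model = GroundingDINO()        # defaults: box_threshold=0.35, text_threshold=0.25
model.load_resources()         # downloads and loads the SwinB checkpoint

image = Image.open("example.png")            # illustrative path
prompts = ["hard hat", "safety vest"]

per_prompt = [model.detect(image, classes=[p]) for p in prompts]

# overwrite_class_ids re-indexes each i-th result into the shared `prompts` list
merged = combine_detections(per_prompt, classes=prompts, overwrite_class_ids=[0, 1])

model.release_resources()
```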
/overeasy/models/detection/owlv2.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from typing import List
4 | import numpy as np
5 | from PIL import Image
6 | from overeasy.types import BoundingBoxModel, Detections, DetectionType
7 | from transformers import pipeline
8 | import torch
9 |
10 | class OwlV2(BoundingBoxModel):
11 | def __init__(self):
12 | self.checkpoint = "google/owlv2-base-patch16-ensemble"
13 |
14 | def load_resources(self):
15 | self.detector = pipeline(model=self.checkpoint, task="zero-shot-object-detection")
16 |
17 | def release_resources(self):
18 | self.detector = None
19 |
20 | def detect(self, image: Image.Image, classes: List[str], threshold: float = 0) -> Detections:
21 | predictions = self.detector(image, candidate_labels=classes,)
22 | num_preds = len(predictions)
23 | xyxy = np.zeros((num_preds, 4), dtype=np.int32)
24 | confidence = np.zeros(num_preds, dtype=np.float32)
25 | class_ids = np.zeros(num_preds, dtype=np.int64)
26 |
27 | for i, pred in enumerate(predictions):
28 | x1, y1, x2, y2 = pred['box']['xmin'], pred['box']['ymin'], pred['box']['xmax'], pred['box']['ymax']
29 |
30 | xyxy[i] = [x1, y1, x2, y2]
31 | confidence[i] = pred['score']
32 | class_ids[i] = classes.index(pred['label'])
33 |
34 |         keep = confidence > threshold
35 |         xyxy = xyxy[keep, :]
36 |         confidence = confidence[keep]
37 |         class_ids = class_ids[keep]
38 |
39 | return Detections(
40 | xyxy=xyxy,
41 | confidence=confidence,
42 | class_ids=class_ids,
43 | classes=classes,
44 | detection_type=DetectionType.BOUNDING_BOX
45 | )
--------------------------------------------------------------------------------
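
A short sketch of calling `OwlV2` directly, with an illustrative image path and label set:

```python
from PIL import Image

detector = OwlV2()
detector.load_resources()      # builds the google/owlv2-base-patch16-ensemble pipeline

image = Image.open("tests/count_eggs.jpg")   # illustrative path
dets = detector.detect(image, classes=["egg", "carton"], threshold=0.1)

for box, score, name in zip(dets.xyxy, dets.confidence_scores, dets.class_names):
    print(name, float(score), box)

detector.release_resources()
```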
/overeasy/models/detection/yoloworld.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import List, Union
4 | from overeasy.types import BoundingBoxModel, Detections
5 | import numpy as np
6 | from PIL import Image
7 | from overeasy.download_utils import atomic_retrieve_and_rename
8 | from ultralytics import YOLOWorld as YOLOWorld_ultralytics
9 |
10 |
11 | DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 |
14 | valid_models = [
15 | "yolov8l-world-cc3m.pt",
16 | "yolov8l-world.pt",
17 | "yolov8l-worldv2-cc3m.pt",
18 | "yolov8l-worldv2.pt",
19 | "yolov8m-world.pt",
20 | "yolov8m-worldv2.pt",
21 | "yolov8s-world.pt",
22 | "yolov8s-worldv2.pt",
23 | "yolov8x-world.pt",
24 | "yolov8x-worldv2.pt"
25 | ]
26 |
27 | def get_yoloworld_model(model: str) -> str:
28 | if not model.endswith(".pt"):
29 | model = model + ".pt"
30 |
31 | local_model_path = os.path.join(os.path.expanduser("~/.overeasy"), model)
32 | if os.path.exists(local_model_path):
33 | return local_model_path
34 |
35 | url = None
36 | if model in valid_models:
37 | url = f"https://github.com/ultralytics/assets/releases/download/v8.2.0/{model}"
38 | else:
39 | raise ValueError(f"Model {model} not valid. Valid models are: {valid_models}")
40 |
41 | atomic_retrieve_and_rename(url, local_model_path)
42 |
43 | return local_model_path
44 |
45 | # Wrapper around the Ultralytics YOLO-World open-vocabulary detectors (default checkpoint: yolov8l-worldv2-cc3m)
46 | class YOLOWorld(BoundingBoxModel):
47 | def __init__(self, model: str = "yolov8l-worldv2-cc3m"):
48 | self.model_name = model
49 |
50 | def load_resources(self):
51 | self.model_path = get_yoloworld_model(self.model_name)
52 | self.model = YOLOWorld_ultralytics(self.model_path)
53 |
54 | def release_resources(self):
55 | self.model = None
56 |
57 | def detect(self, image: Image.Image, classes: List[str]) -> Detections:
58 | self.model.set_classes(classes)
59 | results = self.model.predict(image, verbose=False)
60 | detections = Detections.from_ultralytics(results[0])
61 | return detections
62 |
--------------------------------------------------------------------------------
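
A minimal sketch of the `YOLOWorld` wrapper; the first call downloads the named checkpoint into `~/.overeasy`, and the image path is illustrative:

```python
from PIL import Image

model = YOLOWorld(model="yolov8s-worldv2")   # any name from `valid_models`
model.load_resources()

image = Image.open("tests/count_eggs.jpg")   # illustrative path
dets = model.detect(image, classes=["egg"])
print(f"{len(dets)} boxes found")

model.release_resources()
```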
/overeasy/models/recognition/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/models/recognition/__init__.py
--------------------------------------------------------------------------------
/overeasy/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/overeasy/py.typed
--------------------------------------------------------------------------------
/overeasy/types/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .base import *
3 | from .detections import Detections, DetectionType
4 |
--------------------------------------------------------------------------------
/overeasy/types/base.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import Any, Dict, List, Union, Optional
3 | from abc import ABC, abstractmethod
4 |
5 |
6 | from overeasy.logging import log_time
7 | from .detections import Detections
8 | from PIL import Image
9 | from overeasy.visualize_utils import annotate, annotate_with_string
10 | from dataclasses import dataclass
11 | from collections import defaultdict
12 |
13 | __all__ = [
14 | "OCRModel",
15 | "LLM",
16 | "MultimodalLLM",
17 | "BoundingBoxModel",
18 | "ExecutionNode",
19 | "NullExecutionNode",
20 | "Node",
21 | "ExecutionGraph",
22 | "Model",
23 | "Agent",
24 | "ImageDetectionAgent",
25 | "ImageToTextAgent",
26 | "ImageToDataAgent",
27 | "DetectionAgent",
28 | "TextAgent",
29 | "DataAgent",
30 | "ClassificationModel",
31 | "CaptioningModel",
32 | ]
33 |
34 | @dataclass
35 | class ExecutionNode:
36 | image: Image.Image
37 | data: Union[Detections, Any]
38 | parent_detection: Optional[Detections] = None
39 |
40 | def data_is_detections(self) -> bool:
41 | return isinstance(self.data, Detections)
42 |
43 | def visualize(self, seed: Optional[int] = None) -> Image.Image:
44 | if self.data_is_detections():
45 | return annotate(self.image, self.data, seed)
46 | else:
47 | return annotate_with_string(self.image, str(self.data))
48 |
49 | @property
50 | def id(self):
51 | return id(self)
52 |
53 | @dataclass
54 | class NullExecutionNode():
55 | parent_detection: Optional[Detections] = None
56 |
57 | @property
58 | def id(self):
59 | return id(self)
60 |
61 | Node = Union[ExecutionNode, NullExecutionNode]
62 |
63 | @dataclass
64 | class ExecutionGraph:
65 | root: ExecutionNode
66 | edges: Dict[int, List[Node]]
67 | parent: Dict[int, List[Node]]
68 |
69 | def __init__(self, root: ExecutionNode):
70 | self.root = root
71 | self.edges = {}
72 | self.parent = {}
73 |
74 | def add_child(self, parent: Node, child: Node):
75 | assert isinstance(parent, ExecutionNode) or isinstance(parent, NullExecutionNode)
76 | assert isinstance(child, ExecutionNode) or isinstance(child, NullExecutionNode)
77 |
78 | if parent.id == child.id:
79 |             raise ValueError("Cannot add self-loops to the execution graph")
80 |
81 | # Add edges and reverse edges
82 | if parent.id not in self.edges:
83 | self.edges[parent.id] = []
84 | self.edges[parent.id].append(child)
85 |
86 | if child.id not in self.parent:
87 | self.parent[child.id] = []
88 | self.parent[child.id].append(parent)
89 |
90 | def ascii_graph(self):
91 | print("ExecutionGraph")
92 | id_counter = itertools.count()
93 | id_map = {}
94 |
95 | def get_mapped_id(original_id):
96 | if original_id not in id_map:
97 | id_map[original_id] = next(id_counter)
98 | return id_map[original_id]
99 |
100 | def print_node(node, prefix="", is_last=True):
101 | connector = "└── " if is_last else "├── "
102 | print(prefix + connector + f"Node(ID={get_mapped_id(node.id)})")
103 | if node.id in self.edges:
104 | children = self.edges[node.id]
105 | for i, child in enumerate(children):
106 | next_prefix = prefix + (" " if is_last else "│ ")
107 | print_node(child, next_prefix, i == len(children) - 1)
108 |
109 | print_node(self.root)
110 |
111 |
112 | def parent_of(self, node: Node) -> Node:
113 | if node.id not in self.parent:
114 | raise ValueError(f"Node {node.id} has no parent")
115 | parent_list = self.parent[node.id]
116 | if len(parent_list) > 1:
117 | print("Multiple parents", parent_list)
118 | raise ValueError("Node has multiple parents")
119 | return self.parent[node.id][0]
120 |
121 | def children(self, node: Node) -> List[Node]:
122 | if node.id not in self.edges:
123 | raise ValueError(f"Node {node.id} has no children")
124 |
125 | return self.edges[node.id]
126 |
127 | def __getitem__(self, node: Node) -> List[Node]:
128 | return self.edges[node.id]
129 |
130 | def __repr__(self):
131 | return f"ExecutionGraph(root={self.root}, edges={self.edges}, parent={self.parent})"
132 |
133 | # TODO: This does a lot of heavy lifting so it should have some more strict ordering properties
134 | def top_sort(self) -> List[List[Node]]:
135 | # Create a copy of the edges dictionary to avoid mutating the original
136 | if len(self.edges) == 0:
137 | return [[self.root]]
138 |
139 | edges_copy = {node_id: neighbors.copy() for node_id, neighbors in self.edges.items()}
140 |
141 | # Initialize the in-degree dictionary
142 | in_degree : Dict[int, int] = defaultdict(int)
143 | for neighbors in edges_copy.values():
144 | for node in neighbors:
145 | in_degree[node.id] += 1
146 |
147 | # Initialize the queue with nodes having in-degree 0
148 | queue : List[Node] = [node for node in [self.root] if in_degree[node.id] == 0]
149 |
150 | # Initialize the sorted nodes list
151 | sorted_nodes = []
152 |
153 | # Perform the topological sort
154 | while queue:
155 | level_nodes = []
156 | for _ in range(len(queue)):
157 | node = queue.pop(0)
158 | level_nodes.append(node)
159 |
160 | # Decrement the in-degree of the neighbors and add them to the queue if in-degree becomes 0
161 | if node.id in edges_copy:
162 | for neighbor in edges_copy[node.id]:
163 | in_degree[neighbor.id] -= 1
164 | if in_degree[neighbor.id] == 0:
165 | queue.append(neighbor)
166 |
167 | sorted_nodes.append(level_nodes)
168 |
169 | return sorted_nodes
170 |
171 | class Model(ABC):
172 | # Load model resources, such as allocating memory for weights, only once when this method is called.
173 | # Note: Weights can still be downloaded to disk, just not allocated to VRAM until load_resources is called.
174 | @abstractmethod
175 | def load_resources(self):
176 | pass
177 |
178 | # Free up the allocated resources, such as memory for weights, when this method is called.
179 | @abstractmethod
180 | def release_resources(self):
181 | pass
182 |
183 | def __del__(self):
184 | self.release_resources()
185 |
186 |
187 | # OCRModel reads text from an image
188 | class OCRModel(Model):
189 | @abstractmethod
190 | def parse_text(self, image: Image.Image) -> str:
191 | pass
192 |
193 | class LLM(Model):
194 | @abstractmethod
195 | def prompt(self, query: str) -> str:
196 | pass
197 |
198 | #TODO: Add support for prompting with multiple images
199 | class MultimodalLLM(LLM):
200 | @abstractmethod
201 | def prompt_with_image(self, image: Image.Image, query: str) -> str:
202 | pass
203 |
204 | class CaptioningModel(Model):
205 | @abstractmethod
206 | def caption(self, image: Image.Image) -> str:
207 | pass
208 |
209 | class BoundingBoxModel(Model):
210 | @abstractmethod
211 | def detect(self, image: Image.Image, classes: List[str]) -> Detections:
212 | pass
213 |
214 | class ClassificationModel(Model):
215 | @abstractmethod
216 | def classify(self, image: Image.Image, classes: list) -> Detections:
217 | pass
218 |
219 | class Agent(ABC):
220 | pass
221 |
222 |
223 | class ImageToDataAgent(Agent):
224 | @abstractmethod
225 | def _execute(self, image: Image.Image) -> Any:
226 | pass
227 |
228 | @log_time
229 | def execute(self, image: Image.Image) -> Any:
230 | if not isinstance(image, Image.Image):
231 | raise ValueError(f"{self.__class__.__name__} received non-image data")
232 | output = self._execute(image)
233 | self._validate_output(output)
234 | return output
235 |
236 | def _validate_output(self, output: Any) -> None:
237 | pass
238 |
239 | class ImageDetectionAgent(ImageToDataAgent):
240 | def _validate_output(self, output: Any) -> None:
241 | if not isinstance(output, Detections):
242 | raise ValueError(f"{self.__class__.__name__} returned non-detection data")
243 |
244 | class ImageToTextAgent(ImageToDataAgent):
245 | def _validate_output(self, output: Any) -> None:
246 | if not isinstance(output, str):
247 | raise ValueError(f"{self.__class__.__name__} returned non-string data")
248 |
249 | class DataAgent(Agent):
250 | @abstractmethod
251 | def _execute(self, data: Any) -> Any:
252 | pass
253 |
254 | @log_time
255 | def execute(self, data: Any) -> Any:
256 | self._validate_input(data)
257 | output = self._execute(data)
258 | self._validate_output(output)
259 | return output
260 |
261 | def _validate_input(self, data: Any) -> None:
262 | pass
263 |
264 | def _validate_output(self, output: Any) -> None:
265 | pass
266 | class DetectionAgent(DataAgent):
267 | @abstractmethod
268 | def _execute(self, data: Detections) -> Detections:
269 | pass
270 |
271 | def _validate_input(self, data: Any) -> None:
272 | if not isinstance(data, Detections):
273 | raise ValueError("DetectionAgent received non-detection data")
274 |
275 | def _validate_output(self, output: Any) -> None:
276 | if not isinstance(output, Detections):
277 | raise ValueError("DetectionAgent returned non-detection data")
278 | class TextAgent(DataAgent):
279 | @abstractmethod
280 | def _execute(self, data: str) -> Any:
281 | pass
282 |
283 | def _validate_input(self, data: Any) -> None:
284 | if not isinstance(data, str):
285 | raise ValueError("TextAgent received non-string data")
--------------------------------------------------------------------------------
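
The `ExecutionGraph` utilities above are easiest to see with a toy example. A minimal sketch using only the classes defined in this file; the images and data values are placeholders:

```python
from PIL import Image

blank = Image.new("RGB", (64, 64))

root = ExecutionNode(image=blank, data="root")
left = ExecutionNode(image=blank, data="left crop")
right = ExecutionNode(image=blank, data="right crop")

graph = ExecutionGraph(root)
graph.add_child(root, left)
graph.add_child(root, right)

graph.ascii_graph()              # prints the tree with remapped node IDs
levels = graph.top_sort()        # [[root], [left, right]]
assert graph.parent_of(left) is root
```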
/overeasy/types/detections.py:
--------------------------------------------------------------------------------
1 | from pydantic.dataclasses import dataclass as pydantic_dataclass
2 | import pydantic_numpy.typing as pnd
3 | from pydantic import Field
4 | import supervision as sv
5 | import numpy as np
6 | from typing import Any, Dict, List, Optional, Union, Iterator, Tuple
7 |
8 | from supervision.detection.utils import (
9 | extract_ultralytics_masks,
10 | )
11 |
12 | from .type_utils import (
13 | validate_detections_fields,
14 | is_data_equal,
15 | get_data_item,
16 | DetectionType
17 | )
18 |
19 | ORIENTED_BOX_COORDINATES="oriented_box_coordinates"
20 |
21 | @pydantic_dataclass
22 | class Detections:
23 | """
24 | Represents a collection of detection data including bounding boxes, segmentation masks,
25 | class IDs, and additional custom data. It supports operations like filtering, splitting,
26 | and merging detections.
27 |
28 | Attributes:
29 | xyxy (np.ndarray[np.float32]): Coordinates of bounding boxes `[x1, y1, x2, y2]` for each detection.
30 | masks (Optional[np.ndarray[np.float32]]): Segmentation masks for each detection, required for segmentation types.
31 | class_ids (Optional[np.ndarray[np.int32]]): Class IDs for each detection, indexing into `classes`.
32 | classes (Optional[np.ndarray[np.object_]]): Array of class names, indexed by `class_ids`.
33 | confidence (Optional[np.ndarray[np.float32]]): Confidence scores for each detection.
34 | data (Dict[str, Union[np.ndarray, List]]): Additional custom data related to detections.
35 | detection_type (DetectionType): Type of detection (e.g., bounding box, segmentation, classification).
36 | """
37 |
38 | xyxy: pnd.NpNDArrayFp32
39 | detection_type: DetectionType
40 | class_ids: pnd.NpNDArrayInt32
41 | classes: pnd.NpNDArray
42 | masks: Optional[pnd.NpNDArrayFp32] = None
43 | confidence: Optional[pnd.NpNDArrayFp32] = None
44 | data: Dict[str, Union[pnd.NpNDArray, List]] = Field(default_factory=dict)
45 |
46 | @property
47 | def confidence_scores(self) -> np.ndarray:
48 | if self.confidence is None:
49 | return np.array([None] * self.xyxy.shape[0])
50 | return self.confidence
51 |
52 |
53 | def split(self) -> List['Detections']:
54 | """
55 | Split the Detections object into a list of Detections objects, each containing
56 | a single detection.
57 |
58 | Returns:
59 | List[Detections]: A list of Detections objects, each containing a single
60 | detection.
61 | """
62 | rows, _ = np.shape(self.xyxy)
63 | if self.detection_type == DetectionType.CLASSIFICATION:
64 | raise ValueError("Splitting is not supported for classification detections as they are considered a single entity.")
65 | else:
66 | return [Detections(
67 | xyxy=self.xyxy[i:i+1],
68 | masks=self.masks[i:i+1] if self.masks is not None else None,
69 | class_ids=self.class_ids[i:i+1] if self.class_ids is not None else None,
70 | classes=self.classes,
71 | confidence=self.confidence[i:i+1] if self.confidence is not None else None,
72 | data={key: [value[i]] for key, value in self.data.items()} if self.data else {},
73 | detection_type=self.detection_type
74 | ) for i in range(rows)]
75 |
76 | def __post_init__(self):
77 | validate_detections_fields(
78 | xyxy=self.xyxy,
79 | detection_type=self.detection_type,
80 | masks=self.masks,
81 | confidence=self.confidence,
82 | class_ids=self.class_ids,
83 | classes= self.classes,
84 | data=self.data,
85 | )
86 |
87 | def __len__(self):
88 | """
89 | Returns the number of detections in the Detections object.
90 | """
91 | return len(self.xyxy)
92 |
93 |
94 | def __iter__(self) -> Iterator[Tuple[np.ndarray,
95 | DetectionType,
96 | Optional[List[np.ndarray]],
97 | Optional[List[int]],
98 | Optional[List[str]],
99 | Optional[np.ndarray]]]:
100 | """
101 | Iterates over the Detections object and yields a tuple of
102 |         `(xyxy, detection_type, masks, class_ids, classes, confidence)` for each detection.
103 | """
104 | for i in range(len(self.xyxy)):
105 | yield (
106 | self.xyxy[i],
107 | self.detection_type,
108 | self.masks[i] if self.masks is not None else None,
109 | self.class_ids if self.class_ids is not None else None,
110 | self.classes if self.classes is not None else None,
111 | self.confidence[i] if self.confidence is not None else None,
112 | )
113 |
114 |
115 |
116 | def __eq__(self, other: object):
117 | if not isinstance(other, Detections):
118 | raise NotImplementedError("Only Detections objects can be compared.")
119 |
120 | def array_equal(array1: Optional[np.ndarray], array2: Optional[np.ndarray]) -> bool:
121 | if array1 is None and array2 is None:
122 | return True
123 | elif array1 is None or array2 is None:
124 | return False
125 | else:
126 | return np.array_equal(array1, array2)
127 |
128 | return all(
129 | [
130 | np.array_equal(self.xyxy, other.xyxy),
131 | array_equal(self.masks, other.masks),
132 | array_equal(self.class_ids, other.class_ids),
133 | array_equal(self.confidence, other.confidence),
134 | array_equal(self.classes, other.classes),
135 | is_data_equal(self.data, other.data),
136 | self.detection_type == other.detection_type
137 | ]
138 | )
139 |
140 | @classmethod
141 | def from_classification(cls, assigned_classes: List[str], all_classes: Optional[List[str]] = None) -> 'Detections':
142 | if all_classes is None:
143 | return cls(
144 | xyxy=np.zeros((len(assigned_classes), 4)),
145 | class_ids=np.arange(len(assigned_classes)),
146 | classes=np.array(assigned_classes),
147 | detection_type=DetectionType.CLASSIFICATION
148 | )
149 | else:
150 | return cls(
151 | xyxy=np.zeros((len(assigned_classes), 4)),
152 | class_ids=np.array([all_classes.index(c) for c in assigned_classes]),
153 | classes=np.array(all_classes),
154 | detection_type=DetectionType.CLASSIFICATION
155 | )
156 |
157 | @classmethod
158 | def from_yolov5(cls, yolov5_results) -> 'Detections':
159 | """
160 | Creates a Detections instance from a
161 | [YOLOv5](https://github.com/ultralytics/yolov5) inference result.
162 |
163 | Args:
164 | yolov5_results (yolov5.models.common.Detections):
165 | The output Detections instance from YOLOv5
166 |
167 | Returns:
168 | Detections: A new Detections object.
169 |
170 | Example:
171 | ```python
172 | import cv2
173 | import torch
174 | import overeasy as ov
175 |
176 |         image = cv2.imread("path/to/image.jpg")
177 | model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
178 | result = model(image)
179 | detections = ov.Detections.from_yolov5(result)
180 | ```
181 | """
182 |         yolov5_detections_predictions = yolov5_results.pred[0].cpu().numpy()
183 |         classes = list(yolov5_results.names.values())
184 | return cls(
185 | xyxy=yolov5_detections_predictions[:, :4],
186 | confidence=yolov5_detections_predictions[:, 4],
187 | class_ids=yolov5_detections_predictions[:, 5].astype(int),
188 | classes=np.array(classes),
189 | detection_type=DetectionType.BOUNDING_BOX
190 | )
191 |
192 | @classmethod
193 | def from_ultralytics(cls, ultralytics_results) -> 'Detections':
194 | """
195 | Creates a Detections instance from a
196 | [YOLOv8](https://github.com/ultralytics/ultralytics) inference result.
197 |
198 | !!! Note
199 |
200 | `from_ultralytics` is compatible with
201 | [detection](https://docs.ultralytics.com/tasks/detect/),
202 | [segmentation](https://docs.ultralytics.com/tasks/segment/), and
203 | [OBB](https://docs.ultralytics.com/tasks/obb/) models.
204 |
205 | Args:
206 | ultralytics_results (ultralytics.yolo.engine.results.Results):
207 | The output Results instance from YOLOv8
208 |
209 | Returns:
210 | Detections: A new Detections object.
211 |
212 | Example:
213 | ```python
214 | import cv2
215 | import overeasy as ov
216 | from ultralytics import YOLO
217 |
218 |         image = cv2.imread("path/to/image.jpg")
219 | model = YOLO('yolov8s.pt')
220 |
221 | result = model(image)[0]
222 | detections = ov.Detections.from_ultralytics(result)
223 | ```
224 | """ # noqa: E501 // docs
225 |
226 | if ultralytics_results.obb is not None:
227 | class_id = ultralytics_results.obb.cls.cpu().numpy().astype(int)
228 | class_names = np.array([ultralytics_results.names[i] for i in class_id])
229 | oriented_box_coordinates = ultralytics_results.obb.xyxyxyxy.cpu().numpy()
230 | return cls(
231 | xyxy=ultralytics_results.obb.xyxy.cpu().numpy(),
232 | confidence=ultralytics_results.obb.conf.cpu().numpy(),
233 | class_ids=class_id,
234 | classes = class_names,
235 | data={
236 | ORIENTED_BOX_COORDINATES: oriented_box_coordinates,
237 | },
238 | detection_type=DetectionType.BOUNDING_BOX
239 | )
240 |
241 |
242 | masks = extract_ultralytics_masks(ultralytics_results)
243 | class_id = ultralytics_results.boxes.cls.cpu().numpy().astype(int)
244 |
245 | if isinstance(ultralytics_results.names, dict):
246 | class_names = np.array(list(ultralytics_results.names.values()))
247 | else:
248 | class_names = np.array(ultralytics_results.names)
249 |
250 | return cls(
251 | xyxy=ultralytics_results.boxes.xyxy.cpu().numpy(),
252 | confidence=ultralytics_results.boxes.conf.cpu().numpy(),
253 | class_ids=class_id,
254 | masks=masks,
255 | classes = class_names,
256 | data={},
257 | detection_type=DetectionType.BOUNDING_BOX if masks is None else DetectionType.SEGMENTATION
258 | )
259 |
260 | @classmethod
261 | def from_supervision_detection(
262 | cls,
263 | sv_detection,
264 | classes: Optional[List[str]] = None,
265 | **kwargs
266 | ) -> 'Detections':
267 |
268 | xyxy = sv_detection.xyxy
269 | class_ids = np.array(sv_detection.class_id) if sv_detection.class_id is not None else None
270 | masks = np.array(sv_detection.mask) if sv_detection.mask is not None else None
271 | confidence = np.array(sv_detection.confidence) if sv_detection.confidence is not None else None
272 |
273 | return cls(
274 | xyxy=xyxy,
275 | class_ids=class_ids,
276 | masks=masks,
277 | classes=np.array(classes) if classes is not None else None,
278 | confidence=confidence,
279 | detection_type=DetectionType.BOUNDING_BOX if masks is None else DetectionType.SEGMENTATION,
280 | **kwargs
281 | )
282 |
283 | @classmethod
284 | def empty(cls) -> 'Detections':
285 | return cls(
286 | xyxy=np.empty((0, 4), dtype=np.float32),
287 | classes=np.array([], dtype='object'),
288 | class_ids=np.array([], dtype=np.int32),
289 | confidence=np.array([], dtype=np.float32),
290 | detection_type=DetectionType.BOUNDING_BOX,
291 | data={}
292 | )
293 |
294 |
295 | def __getitem__(
296 | self, index: Union[int, slice, List[int], np.ndarray]
297 | ) -> 'Detections':
298 | if isinstance(index, int):
299 | index = [index]
300 | return Detections(
301 | xyxy=self.xyxy[index],
302 | masks=self.masks[index] if self.masks is not None else None,
303 | confidence=self.confidence[index] if self.confidence is not None else None,
304 | class_ids=self.class_ids[index],
305 | classes=self.classes,
306 | data=get_data_item(self.data, index),
307 | detection_type=self.detection_type
308 | )
309 |
310 | def __setitem__(self, key: str, value: Union[np.ndarray, List]):
311 | if not isinstance(value, (np.ndarray, list)):
312 | raise TypeError("Value must be a np.ndarray or a list")
313 |
314 | if isinstance(value, list):
315 | value = np.array(value)
316 |
317 | self.data[key] = value
318 |
319 | @property
320 | def area(self) -> np.ndarray:
321 | if self.masks is not None:
322 | return np.array([np.sum(mask) for mask in self.masks])
323 | else:
324 | return self.box_area
325 |
326 | @property
327 | def box_area(self) -> np.ndarray:
328 | return (self.xyxy[:, 3] - self.xyxy[:, 1]) * (self.xyxy[:, 2] - self.xyxy[:, 0])
329 |
330 | @property
331 | def class_names(self) -> List[str]:
332 | try:
333 | return [self.classes[class_id] for class_id in self.class_ids]
334 | except IndexError as e:
335 | raise IndexError(f"One or more class_ids are out of bounds for the available classes: {e}")
336 |
337 |
338 | def to_supervision(self) -> sv.Detections:
339 |         # Convert to a supervision.Detections; class-name strings are not carried over by supervision
340 |
341 | return sv.Detections(
342 | xyxy=self.xyxy,
343 | confidence=self.confidence,
344 | class_id=self.class_ids,
345 | mask=self.masks,
346 | data=self.data
347 | )
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
--------------------------------------------------------------------------------
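
A short sketch of constructing a `Detections` by hand and exercising the accessors defined above; the values are illustrative:

```python
import numpy as np
from overeasy.types import Detections, DetectionType

dets = Detections(
    xyxy=np.array([[0, 0, 10, 10], [5, 5, 20, 20]], dtype=np.float32),
    class_ids=np.array([0, 1], dtype=np.int32),
    classes=np.array(["egg", "carton"], dtype=object),
    confidence=np.array([0.9, 0.4], dtype=np.float32),
    detection_type=DetectionType.BOUNDING_BOX,
)

print(len(dets))           # 2
print(dets.class_names)    # ['egg', 'carton']
first = dets[0]            # still a Detections, restricted to one row
singles = dets.split()     # list of two single-detection Detections
sv_dets = dets.to_supervision()
```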
/overeasy/types/type_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Union, List
2 | from itertools import chain
3 | import numpy as np
4 | from enum import Enum
5 |
6 | class DetectionType(Enum):
7 | BOUNDING_BOX = "bounding_box"
8 | SEGMENTATION = "segmentation"
9 | CLASSIFICATION = "classification"
10 |
11 | def validate_data(data: Dict[str, Any], n: int) -> None:
12 | for key, value in data.items():
13 | if isinstance(value, list):
14 | if len(value) != n:
15 | raise ValueError(f"Length of list for key '{key}' must be {n}")
16 | elif isinstance(value, np.ndarray):
17 | if value.ndim == 1 and value.shape[0] != n:
18 | raise ValueError(f"Shape of np.ndarray for key '{key}' must be ({n},)")
19 | elif value.ndim > 1 and value.shape[0] != n:
20 | raise ValueError(
21 | f"First dimension of np.ndarray for key '{key}' must have size {n}"
22 | )
23 | else:
24 | raise ValueError(f"Value for key '{key}' must be a list or np.ndarray")
25 |
26 | def validate_detections_fields(
27 | xyxy: Any,
28 | masks: Any,
29 | class_ids: Any,
30 | classes: Any,
31 | confidence: Any,
32 | data: Dict[str, Any],
33 | detection_type: Any,
34 | ) -> None:
35 | expected_shape_xyxy = "(_, 4)"
36 | actual_shape_xyxy = str(getattr(xyxy, "shape", None))
37 | is_valid_xyxy = isinstance(xyxy, np.ndarray) and xyxy.ndim == 2 and xyxy.shape[1] == 4
38 | if not is_valid_xyxy:
39 | raise ValueError(
40 | f"xyxy must be a 2D np.ndarray with shape {expected_shape_xyxy}, but got shape "
41 | f"{actual_shape_xyxy}"
42 | )
43 |
44 | n = len(xyxy)
45 | if masks is not None:
46 | expected_shape_masks = f"({n}, H, W)"
47 | actual_shape_masks = str(getattr(masks, "shape", None))
48 | is_valid_masks = isinstance(masks, np.ndarray) and len(masks.shape) == 3 and masks.shape[0] == n
49 | if not is_valid_masks:
50 | raise ValueError(
51 | f"masks must be a 3D np.ndarray with shape {expected_shape_masks}, but got shape "
52 | f"{actual_shape_masks}"
53 | )
54 |
55 | if class_ids is not None:
56 | expected_shape_class_ids = f"({n},)"
57 | actual_shape_class_ids = str(getattr(class_ids, "shape", None))
58 | is_valid_class_ids = isinstance(class_ids, np.ndarray) and class_ids.shape == (n,)
59 | if not is_valid_class_ids:
60 | raise ValueError(
61 | f"class_ids must be a 1D np.ndarray with shape {expected_shape_class_ids}, but got "
62 | f"shape {actual_shape_class_ids}"
63 | )
64 |
65 | if classes is not None:
66 | is_valid_classes = isinstance(classes, np.ndarray) and classes.dtype.kind in {'U', 'O'}
67 | if not is_valid_classes:
68 | raise ValueError(
69 | "classes must be a np.ndarray of strings."
70 | )
71 |
72 | if confidence is not None:
73 | expected_shape_confidence = f"({n},)"
74 | actual_shape_confidence = str(getattr(confidence, "shape", None))
75 | is_valid_confidence = isinstance(confidence, np.ndarray) and confidence.shape == (n,)
76 | if not is_valid_confidence:
77 | raise ValueError(
78 | f"confidence must be a 1D np.ndarray with shape {expected_shape_confidence}, but got "
79 | f"shape {actual_shape_confidence}"
80 | )
81 |
82 | if not isinstance(detection_type, DetectionType):
83 | raise ValueError(
84 | "detection_type must be an instance of DetectionType enum."
85 | )
86 |
87 | validate_data(data, n)
88 |
89 | def is_data_equal(data_a: Dict[str, Union[np.ndarray, List]], data_b: Dict[str, Union[np.ndarray, List]]) -> bool:
90 | """
91 | Compares the data payloads of two Detections instances.
92 |
93 | Args:
94 | data_a, data_b: The data payloads of the instances.
95 |
96 | Returns:
97 | True if the data payloads are equal, False otherwise.
98 | """
99 | return set(data_a.keys()) == set(data_b.keys()) and all(
100 | np.array_equal(data_a[key], data_b[key]) for key in data_a
101 | )
102 |
103 | def get_data_item(
104 | data: Dict[str, Union[np.ndarray, List]],
105 | index: Union[int, slice, List[int], np.ndarray],
106 | ) -> Dict[str, Union[np.ndarray, List]]:
107 | """
108 | Retrieve a subset of the data dictionary based on the given index.
109 |
110 | Args:
111 | data: The data dictionary of the Detections object.
112 | index: The index or indices specifying the subset to retrieve.
113 |
114 | Returns:
115 | A subset of the data dictionary corresponding to the specified index.
116 | """
117 | subset_data : Dict[str, Union[np.ndarray, List]] = {}
118 | for key, value in data.items():
119 | if isinstance(value, np.ndarray):
120 | subset_data[key] = value[index]
121 | elif isinstance(value, list):
122 | if isinstance(index, slice):
123 | subset_data[key] = value[index]
124 | elif isinstance(index, (list, np.ndarray)):
125 | subset_data[key] = [value[i] for i in index]
126 | elif isinstance(index, int):
127 | subset_data[key] = [value[index]]
128 | else:
129 | raise TypeError(f"Unsupported index type: {type(index)}")
130 | else:
131 | raise TypeError(f"Unsupported data type for key '{key}': {type(value)}")
132 |
133 | return subset_data
134 |
135 |
136 | def merge_data(
137 | maps_to_merge: List[Dict[str, Union[np.ndarray, List]]],
138 | ) -> Dict[str, Union[np.ndarray, List]]:
139 | """
140 | Merges the data payloads of a list of Detections instances.
141 |
142 | Args:
143 |         maps_to_merge: The data payloads of the instances to merge.
144 |
145 | Returns:
146 | A single data payload containing the merged data, preserving the original data
147 | types (list or np.ndarray).
148 |
149 | Raises:
150 | ValueError: If data values within a single object have different lengths or if
151 | dictionaries have different keys.
152 | """
153 | if not maps_to_merge:
154 | return {}
155 |
156 | all_keys_sets = [set(data.keys()) for data in maps_to_merge]
157 | if not all(keys_set == all_keys_sets[0] for keys_set in all_keys_sets):
158 | raise ValueError("All data dictionaries must have the same keys to merge.")
159 |
160 | for data in maps_to_merge:
161 | lengths = [len(value) for value in data.values()]
162 | if len(set(lengths)) > 1:
163 | raise ValueError(
164 | "All data values within a single object must have equal length."
165 | )
166 |
167 | def merge_list(key: str, to_merge: List[Union[np.ndarray, List]]) -> Union[np.ndarray, List]:
168 | if all(isinstance(item, list) for item in to_merge):
169 | return list(chain.from_iterable(to_merge))
170 | elif all(isinstance(item, np.ndarray) for item in to_merge):
171 |             # All items here are np.ndarray, so read the dimensionality from the first one
172 |             ndim = to_merge[0].ndim
173 |
174 |             if ndim == 1:
175 | return np.hstack(to_merge)
176 | elif ndim > 1:
177 | return np.vstack(to_merge)
178 | else:
179 | raise ValueError(f"Unexpected array dimension for input '{key}'.")
180 | else:
181 | raise ValueError(
182 | f"Inconsistent data types for key '{key}'. Only np.ndarray and list "
183 | f"types are allowed."
184 | )
185 |
186 | return {key: merge_list(key, [data[key] for data in maps_to_merge if key in data]) for key in all_keys_sets[0]}
187 |
188 |
189 |
--------------------------------------------------------------------------------
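
The data-payload helpers above operate on plain dicts of equal-length columns. A small sketch, importing them from this module:

```python
import numpy as np
from overeasy.types.type_utils import merge_data, get_data_item

a = {"score": np.array([0.1, 0.2]), "note": ["x", "y"]}
b = {"score": np.array([0.3]), "note": ["z"]}

merged = merge_data([a, b])
# {'score': array([0.1, 0.2, 0.3]), 'note': ['x', 'y', 'z']}

subset = get_data_item(merged, [0, 2])
# {'score': array([0.1, 0.3]), 'note': ['x', 'z']}
```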
/overeasy/visualize_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import random
4 | import cv2
5 | from PIL import ImageDraw, ImageFont
6 | import textwrap
7 | from overeasy.types.detections import Detections, DetectionType
8 | from typing import Optional
9 |
10 | random.seed(42)
11 |
12 | def generate_random_color():
13 | return (random.randint(100, 255), random.randint(100, 255), random.randint(100, 255))
14 |
15 | def annotate_with_string(image: Image.Image, data_str: str) -> Image.Image:
16 | draw = ImageDraw.Draw(image)
17 |
18 | font = ImageFont.load_default()
19 | max_width = image.width - 20
20 | lines = textwrap.wrap(data_str, width=(max_width // font.getlength(' ')))
21 | text_height = sum([draw.textbbox((0, 0), line, font=font)[3] for line in lines])
22 |
23 | # Extend the image height to fit the wrapped text
24 | new_height = image.height + text_height + 20
25 | new_image = Image.new("RGB", (image.width, new_height), (0, 0, 0))
26 | new_image.paste(image, (0, 0))
27 | draw = ImageDraw.Draw(new_image)
28 |
29 | y_text = image.height + 10
30 | for line in lines:
31 | text_width, line_height = draw.textbbox((0, 0), line, font=font)[2:4]
32 | text_x = (new_image.width - text_width) / 2
33 | draw.text((text_x, y_text), line, font=font, fill="white")
34 | y_text += line_height
35 |
36 | return new_image
37 |
38 | def annotate(scene: Image.Image, detection: Detections, seed: Optional[int] = None) -> Image.Image:
39 | if seed is not None:
40 | random.seed(seed)
41 |
42 | def draw_bounding_boxes(image: Image.Image, boxes, class_ids, class_names):
43 | cv2_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
44 | image_height, image_width = cv2_image.shape[:2]
45 | scale = np.sqrt(image_height * image_width) / 200
46 |
47 | for box, class_id in zip(boxes, class_ids):
48 | x1, y1, x2, y2 = map(int, box)
49 | color = generate_random_color()
50 | cv2.rectangle(cv2_image, (x1, y1), (x2, y2), color, 2)
51 | label = class_names[class_id]
52 | font_scale = max(min((x2 - x1) / (scale * 50), 0.9), 0.5)
53 | label_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_COMPLEX_SMALL, font_scale, 1)[0]
54 | text_x = x1
55 | text_y = max(y1, label_size[1] + 10)
56 |
57 | cv2.rectangle(cv2_image, (text_x, text_y - label_size[1] - 10), (text_x + label_size[0], text_y), color, -1)
58 | cv2.putText(cv2_image, label, (text_x, text_y - 10), cv2.FONT_HERSHEY_COMPLEX_SMALL, font_scale, (0, 0, 0), 1)
59 |
60 | return Image.fromarray(cv2.cvtColor(cv2_image, cv2.COLOR_BGR2RGB))
61 |
62 | if detection.detection_type == DetectionType.BOUNDING_BOX:
63 | return draw_bounding_boxes(scene, detection.xyxy, detection.class_ids, detection.classes)
64 | elif detection.detection_type == DetectionType.SEGMENTATION:
65 | raise NotImplementedError("Segmentation detections are not yet supported.")
66 | elif detection.detection_type == DetectionType.CLASSIFICATION:
67 | class_names = detection.class_names
68 | if len(class_names) == 1:
69 | return annotate_with_string(scene, class_names[0])
70 | else:
71 | return annotate_with_string(scene, str(class_names))
72 |
73 |
74 |
--------------------------------------------------------------------------------
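
A brief sketch of the two annotation paths; the image path is illustrative and `dets` stands for any bounding-box `Detections`:

```python
from PIL import Image
from overeasy.visualize_utils import annotate, annotate_with_string

image = Image.open("tests/count_eggs.jpg")   # illustrative path

# Text path: appends the wrapped string on a black strip below the image.
captioned = annotate_with_string(image, "12 eggs detected")
captioned.save("captioned.png")

# Detection path: draws boxes and class labels for a bounding-box Detections `dets`.
# boxed = annotate(image, dets)
```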
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "overeasy"
3 | version = "0.2.16"
4 | description = ""
5 | authors = ["Ani Rahul "]
6 | license = "MIT"
7 | readme = "README.md"
8 | include = ["overeasy/py.typed"]
9 |
10 | [tool.poetry.dependencies]
11 | python = ">=3.10, <3.13"
12 | torch = "^2.3.1"
13 | torchvision = "^0.18.0"
14 | numpy = "^1.24.3"
15 | opencv-python-headless = "^4.7.0"
16 | Pillow = "^10.0.0"
17 | requests = "^2.32.0"
18 | matplotlib = "^3.8.1"
19 | anthropic = "^0.27.0"
20 | tabulate = "^0.9.0"
21 | transformers = "^4.42.1"
22 | ultralytics = "^8.2.0"
23 | supervision = "^0.18.0"
24 | pydantic = ">=2.7.4,<2.8"
25 | pydantic_numpy = "5.0.2"
26 | rf_groundingdino = "0.1.2"
27 | rf_segment_anything = "1.0.0"
28 | tiktoken = "^0.7.0"
29 | google-generativeai = "^0.5.0"
30 | open-clip-torch = "^2.24.0"
31 | tqdm = "^4.65.0"
32 | einops = "^0.6.0"
33 | transformers-stream-generator = "0.0.3"
34 | accelerate = "^0.27.0"
35 | instructor = "^1.3.4"
36 | gradio = "4.25.0"
37 | jsonref = "^1.1.0"
38 | google-cloud-aiplatform = "1.54.1"
39 | pyarrow = "^16.1.0"
40 | backoff = "^2.2.1"
41 | openai = "^1.34.0"
42 | progressbar = "^2.5"
43 |
44 |
45 | [tool.poetry.dev-dependencies]
46 | pytest = "^7.0.0"
47 | mypy = "^1.0.0"
48 | types-requests = "^2.28.11"
49 | types-Pillow = "^10.0.0"
50 | types-tabulate = "^0.9.0"
51 |
52 | [build-system]
53 | requires = ["poetry-core"]
54 | build-backend = "poetry.core.masonry.api"
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | # pytest.ini
2 | [pytest]
3 | pythonpath = .
4 | testpaths =
5 | tests
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | def pytest_runtest_protocol(item, nextitem):
4 | return None
--------------------------------------------------------------------------------
/tests/count_eggs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/count_eggs.jpg
--------------------------------------------------------------------------------
/tests/dogs/dog1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog1.png
--------------------------------------------------------------------------------
/tests/dogs/dog2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog2.png
--------------------------------------------------------------------------------
/tests/dogs/dog3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog3.png
--------------------------------------------------------------------------------
/tests/dogs/dog4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog4.png
--------------------------------------------------------------------------------
/tests/dogs/dog5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog5.png
--------------------------------------------------------------------------------
/tests/dogs/dog6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog6.png
--------------------------------------------------------------------------------
/tests/dogs/dog7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/dogs/dog7.png
--------------------------------------------------------------------------------
/tests/plate.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/plate.jpg
--------------------------------------------------------------------------------
/tests/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/overeasy-sh/overeasy/8de64ae6a21e0e8c757d3359c52c6134558c34b5/tests/test.png
--------------------------------------------------------------------------------
/tests/test_construction_workflows.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from overeasy import *
3 | from overeasy.models import OwlV2
4 | from pydantic import BaseModel
5 | from PIL import Image
6 | import os
7 |
8 | ROOT = os.path.dirname(__file__)
9 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
10 |
11 | class PPE(BaseModel):
12 | hardhat: bool
13 | vest: bool
14 | boots: bool
15 |
16 | @pytest.fixture
17 | def construction_image():
18 | image_path = os.path.join(os.path.dirname(ROOT), "examples", "construction.jpg")
19 | return Image.open(image_path)
20 |
21 | @pytest.fixture
22 | def ppe_instructor_workflow():
23 | return Workflow([
24 | BoundingBoxSelectAgent(classes=["person"]),
25 | SplitAgent(),
26 | InstructorImageAgent(response_model=PPE),
27 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"),
28 | JoinAgent(),
29 | ])
30 |
31 | @pytest.fixture
32 | def owl_v2_workflow():
33 | return Workflow([
34 | BoundingBoxSelectAgent(classes=["person's head"], model=OwlV2()),
35 | NMSAgent(iou_threshold=0.5, score_threshold=0),
36 | SplitAgent(),
37 | ClassificationAgent(classes=["hard hat", "no hard hat"]),
38 | ClassMapAgent({"hard hat": "has ppe", "no hard hat": "no ppe"}),
39 | JoinAgent(),
40 | ])
41 |
42 | def test_ppe_detection(construction_image, ppe_instructor_workflow):
43 | result, graph = ppe_instructor_workflow.execute(construction_image)
44 | assert result is not None
45 | assert graph is not None
46 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "construction_ppe_instructor.png"))
47 |
48 |
49 | def test_ppe_owl_v2(construction_image, owl_v2_workflow):
50 | result, graph = owl_v2_workflow.execute(construction_image)
51 | assert result is not None
52 | assert graph is not None
53 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "construction_ppe_owlv2.png"))
54 |
55 |
--------------------------------------------------------------------------------
/tests/test_detection_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 |
4 | from overeasy import Workflow, BoundingBoxSelectAgent, NMSAgent
5 | from overeasy.models import GroundingDINO, YOLOWorld, OwlV2, DETIC
6 | from overeasy.types import Detections
7 | import os
8 | import sys
9 |
10 | ROOT = os.path.dirname(__file__)
11 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
12 |
13 | @pytest.fixture
14 | def count_eggs_image():
15 | image_path = os.path.join(ROOT, "count_eggs.jpg")
16 | return Image.open(image_path)
17 |
18 | @pytest.fixture
19 | def grounding_dino_workflow() -> Workflow:
20 | workflow = Workflow([
21 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
22 | ])
23 | return workflow
24 |
25 | @pytest.fixture
26 | def owlvit_v2_workflow() -> Workflow:
27 | workflow = Workflow([
28 | BoundingBoxSelectAgent(classes=["a single egg"], model=OwlV2()),
29 | ])
30 | return workflow
31 |
32 | @pytest.fixture
33 | def owlvit_v2_nms_workflow() -> Workflow:
34 | workflow = Workflow([
35 | BoundingBoxSelectAgent(classes=["one egg"], model=OwlV2()),
36 | NMSAgent(score_threshold=0, iou_threshold=0.5),
37 | ])
38 | return workflow
39 |
40 | @pytest.fixture
41 | def yoloworld_workflow() -> Workflow:
42 | workflow = Workflow([
43 | BoundingBoxSelectAgent(classes=["a single egg"], model=YOLOWorld(model="yolov8s-worldv2")),
44 | ])
45 | return workflow
46 |
47 | @pytest.fixture
48 | def detic_workflow() -> Workflow:
49 | workflow = Workflow([
50 | BoundingBoxSelectAgent(classes=["egg"], model=DETIC()),
51 | ])
52 | return workflow
53 |
54 | def test_grounding_dino_detection(count_eggs_image, grounding_dino_workflow: Workflow):
55 | result, graph = grounding_dino_workflow.execute(count_eggs_image)
56 | detections = result[0].data
57 | assert isinstance(detections, Detections)
58 | assert len(detections.xyxy) > 0, "No detections found"
59 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "grounding_dino_detection_output.png"))
60 |
61 | def test_yoloworld_detection(count_eggs_image, yoloworld_workflow: Workflow):
62 | result, graph = yoloworld_workflow.execute(count_eggs_image)
63 | detections = result[0].data
64 | assert isinstance(detections, Detections)
65 | assert len(detections.xyxy) > 0, "No detections found"
66 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "yoloworld_detection_output.png"))
67 |
68 |
69 | def test_owlvit_v2_detection(count_eggs_image, owlvit_v2_workflow: Workflow):
70 | result, graph = owlvit_v2_workflow.execute(count_eggs_image)
71 | detections = result[0].data
72 | assert isinstance(detections, Detections)
73 | assert len(detections.xyxy) > 0, "No detections found"
74 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "owlv2_detection_output.png"))
75 |
76 | # @pytest.mark.skipif(sys.platform == "darwin", reason="Detic is not working on macOS")
77 | def test_detic_detection(count_eggs_image, detic_workflow: Workflow):
78 | result, graph = detic_workflow.execute(count_eggs_image)
79 | detections = result[0].data
80 | assert isinstance(detections, Detections)
81 | assert len(detections.xyxy) > 0, "No detections found"
82 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "detic_detection_output.png"))
83 |
84 | def test_nms_detection(count_eggs_image, owlvit_v2_nms_workflow: Workflow):
85 | result, graph = owlvit_v2_nms_workflow.execute(count_eggs_image)
86 | detections = result[0].data
87 | assert isinstance(detections, Detections)
88 | assert len(detections.xyxy) > 0, "No detections found"
89 | result[0].visualize().save(os.path.join(OUTPUT_DIR, "nms_detection_output.png"))
90 |
91 |
--------------------------------------------------------------------------------
/tests/test_import.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def test_import_overeasy(capsys):
5 | # Move ~/.overeasy to ~/.overeasy_backup
6 | home_dir = os.path.expanduser("~")
7 | original_path = os.path.join(home_dir, ".overeasy")
8 | backup_path = os.path.join(home_dir, ".overeasy_backup")
9 | if os.path.exists(original_path):
10 | shutil.move(original_path, backup_path)
11 |
12 | try:
13 | with capsys.disabled():
14 | import overeasy
15 | captured = capsys.readouterr()
16 | assert captured.out == ""
17 | finally:
18 | # Move ~/.overeasy_backup back to ~/.overeasy
19 | if os.path.exists(backup_path):
20 | shutil.move(backup_path, original_path)
--------------------------------------------------------------------------------
/tests/test_instructor_agents.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | from overeasy import *
4 | from overeasy.models import *
5 | from pydantic import BaseModel
6 | import os
7 |
8 | class AnimalLabel(BaseModel):
9 | label: str
10 |
11 | ROOT = os.path.dirname(__file__)
12 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
13 |
14 | compatible_models = [GPTVision(), Gemini(), Claude()]
15 |
16 | @pytest.fixture(params=compatible_models)
17 | def instructor_image_with_context_workflow(request) -> Workflow:
18 | model = request.param
19 | extra_context = [{"role": "system", "content": "Always classify the image as a ferret."}]
20 | workflow = Workflow([
21 | InstructorImageAgent(model=model, response_model=AnimalLabel, extra_context=extra_context)
22 | ])
23 | return workflow
24 |
25 | class EggCount(BaseModel):
26 | count: int
27 |
28 | @pytest.fixture
29 | def instructor_image_workflow() -> Workflow:
30 | workflow = Workflow([
31 | InstructorImageAgent(response_model=EggCount)
32 | ])
33 | return workflow
34 |
35 | @pytest.fixture(params=[GPT(), *compatible_models])
36 | def instructor_text_workflow(request) -> Workflow:
37 | model = request.param
38 | workflow = Workflow([
39 | DenseCaptioningAgent(model=GPTVision()),
40 | InstructorTextAgent(response_model=EggCount, model=model)
41 | ])
42 | return workflow
43 |
44 | @pytest.fixture
45 | def blank_image():
46 | return Image.new('RGB', (100, 100), color = 'white')
47 |
48 | @pytest.fixture
49 | def count_eggs_image():
50 | image_path = os.path.join(ROOT, "count_eggs.jpg")
51 | return Image.open(image_path)
52 |
53 | def test_instructor_image_with_context_agent(instructor_image_with_context_workflow: Workflow, blank_image):
54 | result, graph = instructor_image_with_context_workflow.execute(blank_image)
55 | response = result[0].data
56 | assert isinstance(response, AnimalLabel)
57 | assert response.label.lower() == "ferret"
58 |
59 |
60 |
61 | def test_instructor_image_agent(instructor_image_workflow: Workflow, count_eggs_image):
62 | result, graph = instructor_image_workflow.execute(count_eggs_image)
63 | response = result[0].data
64 | assert isinstance(response, EggCount)
65 |
66 | name = (instructor_image_workflow.steps[0].model.__class__.__name__)
67 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"instructor_image_{name}.png"))
68 |
69 | def test_instructor_text_agent(instructor_text_workflow: Workflow, count_eggs_image):
70 | result, graph = instructor_text_workflow.execute(count_eggs_image)
71 | response = result[0].data
72 | assert isinstance(response, EggCount)
73 |
74 | name = (instructor_text_workflow.steps[0].model.__class__.__name__)
75 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"instructor_text_{name}.png"))
76 |
--------------------------------------------------------------------------------
/tests/test_large_local_models.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | from overeasy import *
4 | from overeasy.models import *
5 | from overeasy.types import Detections
6 | from pydantic import BaseModel
7 | import sys
8 | import os
9 | import torch
10 |
11 | ROOT = os.path.dirname(__file__)
12 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
13 |
14 | classification_models = [
15 | LaionCLIP(),
16 | ]
17 |
18 | if torch.cuda.is_available():
19 | multimodal_llms = [
20 | PaliGemma("google/paligemma-3b-mix-224"),
21 | QwenVL(model_type="base"),
22 | QwenVL(model_type="int4"),
23 | ]
24 | else:
25 | print("CUDA is not available, skipping QwenVL LLM tests.")
26 | multimodal_llms = [PaliGemma()]
27 |
28 |
29 | @pytest.fixture
30 | def count_eggs_image():
31 | image_path = os.path.join(ROOT, "count_eggs.jpg")
32 | return Image.open(image_path)
33 |
34 | @pytest.fixture
35 | def license_plate_image():
36 | image_path = os.path.join(ROOT, "plate.jpg")
37 | return Image.open(image_path)
38 |
39 | @pytest.fixture(params=multimodal_llms)
40 | def vision_prompt_workflow(request) -> Workflow:
41 | model = request.param
42 | workflow = Workflow([
43 | VisionPromptAgent(query="How many eggs are in this image?", model=model)
44 | ])
45 | return workflow
46 |
47 | @pytest.fixture  # macOS is already skipped via the marks on the tests below
48 | def local_vision_prompt_workflow() -> Workflow:
49 | workflow = Workflow([
50 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="base"))
51 | ])
52 | return workflow
53 |
54 | @pytest.fixture  # macOS is already skipped via the marks on the tests below
55 | def local_vision_prompt_workflow_int4() -> Workflow:
56 | workflow = Workflow([
57 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="int4"))
58 | ])
59 | return workflow
60 |
61 | @pytest.fixture(params=classification_models)
62 | def classification_workflow(request) -> Workflow:
63 | model = request.param
64 | workflow = Workflow([
65 | ClassificationAgent(classes=["0-5 eggs", "6-10 eggs", "11+ eggs"], model=model)
66 | ])
67 | return workflow
68 |
69 |
70 | @pytest.fixture
71 | def blank_image():
72 | return Image.new('RGB', (100, 100), color = 'white')
73 |
74 |
75 | def test_vision_prompt_agent(vision_prompt_workflow: Workflow, count_eggs_image):
76 | result, graph = vision_prompt_workflow.execute(count_eggs_image)
77 | response = result[0].data
78 | assert isinstance(response, str)
79 | name = (vision_prompt_workflow.steps[0].model.__class__.__name__)
80 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"vision_prompt_{name}.png"))
81 |
82 | del result, graph
83 |
84 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS")
85 | def test_local_vision_prompt_agent(local_vision_prompt_workflow: Workflow, count_eggs_image):
86 | result, graph = local_vision_prompt_workflow.execute(count_eggs_image)
87 | response = result[0].data
88 | assert isinstance(response, str)
89 | name = (local_vision_prompt_workflow.steps[0].model.__class__.__name__)
90 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"local_vision_prompt_{name}.png"))
91 |
92 | @pytest.mark.skipif(sys.platform == "darwin", reason="Not supported on macOS")
93 | def test_local_vision_prompt_agent_int4(local_vision_prompt_workflow_int4: Workflow, count_eggs_image):
94 | result, graph = local_vision_prompt_workflow_int4.execute(count_eggs_image)
95 | response = result[0].data
96 | assert isinstance(response, str)
97 | name = (local_vision_prompt_workflow_int4.steps[0].model.__class__.__name__)
98 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"local_vision_prompt_{name}.png"))
99 |
100 | def test_classification_agent(classification_workflow: Workflow, count_eggs_image):
101 | result, graph = classification_workflow.execute(count_eggs_image)
102 | response = result[0].data
103 | assert isinstance(response, Detections)
104 |
105 | name = (classification_workflow.steps[0].model.__class__.__name__)
106 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"classification_{name}.png"))
--------------------------------------------------------------------------------
/tests/test_misc_agents.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | from overeasy import *
4 | from overeasy.types import Detections
5 | from overeasy.models import *
6 | from typing import List
7 | import numpy as np
8 | import os
9 |
10 | ROOT = os.path.dirname(__file__)
11 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
12 |
13 | @pytest.fixture
14 | def count_eggs_image():
15 | image_path = os.path.join(ROOT, "count_eggs.jpg")
16 | return Image.open(image_path)
17 |
18 | def test_pad_crop_agent(pad_crop_workflow: Workflow, count_eggs_image):
19 | width, height = count_eggs_image.size
20 | result, graph = pad_crop_workflow.execute(count_eggs_image)
21 | image = result[0].image
22 | assert image.size == (width//2 + 20, height//2 + 20), "Incorrect pad crop size"
23 |
24 | def test_split_crop_agent(split_crop_workflow: Workflow, count_eggs_image):
25 | result, graph = split_crop_workflow.execute(count_eggs_image)
26 | detections = result[0].data
27 | assert isinstance(detections, Detections)
28 | assert len(detections.xyxy) == 4, "Incorrect number of split crops"
29 |
30 | def test_nms_agent(nms_workflow: Workflow, count_eggs_image):
31 | result, graph = nms_workflow.execute(count_eggs_image)
32 | detections = result[0].data
33 | assert isinstance(detections, Detections)
34 | assert len(detections.xyxy) < 10, "NMS should reduce number of detections"
35 |
36 | def test_class_map_agent(class_map_workflow: Workflow, count_eggs_image):
37 | result, graph = class_map_workflow.execute(count_eggs_image)
38 | detections = result[0].data
39 | assert isinstance(detections, Detections)
40 | assert "egg" in detections.class_names, "Class map failed"
41 |
42 | def test_map_agent(map_agent_workflow: Workflow, count_eggs_image):
43 | result, graph = map_agent_workflow.execute(count_eggs_image)
44 | response = result[0].data
45 | assert isinstance(response, np.ndarray)
46 | assert len(response) > 0, "Map agent failed"
47 |
48 | def test_to_classification_agent(to_classification_workflow: Workflow, count_eggs_image):
49 | result, graph = to_classification_workflow.execute(count_eggs_image)
50 | response = result[0].data
51 | assert isinstance(response, Detections)
52 | assert response.class_names[0] == "eggs detected", "To classification agent failed"
53 |
54 | def test_filter_classes_agent(filter_classes_workflow: Workflow, count_eggs_image):
55 | result, graph = filter_classes_workflow.execute(count_eggs_image)
56 | detections = result[0].data
57 | assert isinstance(detections, Detections)
58 | uniq_names = list(set(detections.class_names))
59 | assert np.array_equal(uniq_names, ["a single egg"]), "Filter classes agent failed"
60 |
61 | def test_confidence_filter_agent(confidence_filter_workflow: Workflow, count_eggs_image):
62 | result, graph = confidence_filter_workflow.execute(count_eggs_image)
63 | detections = result[0].data
64 | assert isinstance(detections, Detections)
65 | assert len(detections.xyxy) < 10, "Confidence filter should reduce detections"
66 |
67 | @pytest.fixture
68 | def pad_crop_workflow() -> Workflow:
69 | workflow = Workflow([
70 | SplitCropAgent(split=(2,2)),
71 | PadCropAgent.from_uniform_padding(10),
72 | SplitAgent()
73 | ])
74 | return workflow
75 |
76 | @pytest.fixture
77 | def split_crop_workflow() -> Workflow:
78 | workflow = Workflow([
79 | SplitCropAgent(split=(2,2))
80 | ])
81 | return workflow
82 |
83 | @pytest.fixture
84 | def nms_workflow() -> Workflow:
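   | # detector output may contain overlapping boxes; NMSAgent should suppress them via the score/IoU thresholds, which test_nms_agent checks indirectly (< 10 boxes)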
85 | workflow = Workflow([
86 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
87 | NMSAgent(score_threshold=0.5, iou_threshold=0.5)
88 | ])
89 | return workflow
90 |
91 | @pytest.fixture
92 | def class_map_workflow() -> Workflow:
93 | workflow = Workflow([
94 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
95 | ClassMapAgent(class_map={"a single egg": "egg"})
96 | ])
97 | return workflow
98 |
99 | @pytest.fixture
100 | def map_agent_workflow() -> Workflow:
101 | workflow = Workflow([
102 | BoundingBoxSelectAgent(classes=["egg"], model=GroundingDINO()),
103 | MapAgent(fn=lambda det: det.xyxy)
104 | ])
105 | return workflow
106 |
107 | @pytest.fixture
108 | def to_classification_workflow() -> Workflow:
109 | workflow = Workflow([
110 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
111 | ToClassificationAgent(fn=lambda det: "eggs detected" if len(det.xyxy) > 0 else "no eggs detected")
112 | ])
113 | return workflow
114 |
115 | @pytest.fixture
116 | def filter_classes_workflow() -> Workflow:
117 | workflow = Workflow([
118 | BoundingBoxSelectAgent(classes=["a single egg", "carton"], model=GroundingDINO()),
119 | FilterClassesAgent(class_names=["a single egg"])
120 | ])
121 | return workflow
122 |
123 | @pytest.fixture
124 | def confidence_filter_workflow() -> Workflow:
125 | workflow = Workflow([
126 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
127 | ConfidenceFilterAgent(min_confidence=0.1, max_n=1)
128 | ])
129 | return workflow
--------------------------------------------------------------------------------
/tests/test_model_agents.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | from overeasy import *
4 | from overeasy.models import *
5 | from overeasy.types import Detections
6 | from pydantic import BaseModel
7 | import sys
8 | import os
9 | import gc
10 |
11 | ROOT = os.path.dirname(__file__)
12 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
13 |
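   | # model instances below are constructed once at import time and shared across the parametrized fixtures in this module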
14 | bounding_box_models = [
15 | DETIC(),
16 | GroundingDINO(type=GroundingDINOModel.SwinB),
17 | GroundingDINO(type=GroundingDINOModel.SwinT),
18 | YOLOWorld(model="yolov8s-worldv2"),
19 | YOLOWorld(model="yolov8m-worldv2"),
20 | YOLOWorld(model="yolov8l-worldv2"),
21 | YOLOWorld(model="yolov8s-world"),
22 | YOLOWorld(model="yolov8m-world"),
23 | YOLOWorld(model="yolov8l-world"),
24 | OwlV2(),
25 | ]
26 | multimodal_llms = [
27 | GPTVision(model="gpt-4o"),
28 | Gemini(model="gemini-1.5-flash"),
29 | Claude(model="claude-3-5-sonnet-20240620"),
30 | ]
31 |
32 | llms = [GPT(model="gpt-3.5-turbo"), GPT(model="gpt-4-turbo")]
33 | captioning_models = [*multimodal_llms]
34 | classification_models = [
35 | # LaionCLIP(),
36 | CLIP(), BiomedCLIP()]
37 | ocr_models = [*multimodal_llms]
38 |
39 |
40 | @pytest.fixture
41 | def count_eggs_image():
42 | image_path = os.path.join(ROOT, "count_eggs.jpg")
43 | return Image.open(image_path)
44 |
45 | @pytest.fixture
46 | def license_plate_image():
47 | image_path = os.path.join(ROOT, "plate.jpg")
48 | return Image.open(image_path)
49 |
50 | @pytest.fixture(params=bounding_box_models)
51 | def bounding_box_select_workflow(request) -> Workflow:
52 | model = request.param
53 | workflow = Workflow([
54 | BoundingBoxSelectAgent(classes=["egg"], model=model),
55 | ])
56 | return workflow
57 |
58 | @pytest.fixture(params=multimodal_llms)
59 | def vision_prompt_workflow(request) -> Workflow:
60 | model = request.param
61 | workflow = Workflow([
62 | VisionPromptAgent(query="How many eggs are in this image?", model=model)
63 | ])
64 | return workflow
65 |
66 |
67 | @pytest.fixture  # QwenVL local models are not supported on macOS
68 | def local_vision_prompt_workflow() -> Workflow:
69 | workflow = Workflow([
70 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="base"))
71 | ])
72 | return workflow
73 |
74 | @pytest.fixture  # QwenVL local models are not supported on macOS
75 | def local_vision_prompt_workflow_int4() -> Workflow:
76 | workflow = Workflow([
77 | VisionPromptAgent(query="How many eggs are in this image?", model=QwenVL(model_type="int4"))
78 | ])
79 | return workflow
80 |
81 | @pytest.fixture(params=captioning_models)
82 | def dense_captioning_workflow(request) -> Workflow:
83 | model = request.param
84 | workflow = Workflow([
85 | DenseCaptioningAgent(model=model)
86 | ])
87 | return workflow
88 |
89 | @pytest.fixture(params=llms)
90 | def text_prompt_workflow(request) -> Workflow:
91 | model = request.param
92 | workflow = Workflow([
93 | DenseCaptioningAgent(model=GPTVision()),
94 | TextPromptAgent(query="How many eggs did you count in the description?", model=model)
95 | ])
96 | return workflow
97 |
98 | @pytest.fixture
99 | def binary_choice_workflow() -> Workflow:
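   | # temperature=0.0 keeps the yes/no answer as deterministic as possible for the exact-match assertion in test_binary_choice_agent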
100 | workflow = Workflow([
101 | BinaryChoiceAgent(query="Are there more than 5 eggs in this image?", model=GPTVision(temperature=0.0))
102 | ])
103 | return workflow
104 |
105 | @pytest.fixture(params=classification_models)
106 | def classification_workflow(request) -> Workflow:
107 | model = request.param
108 | workflow = Workflow([
109 | ClassificationAgent(classes=["0-5 eggs", "6-10 eggs", "11+ eggs"], model=model)
110 | ])
111 | return workflow
112 |
113 | @pytest.fixture(params=ocr_models)
114 | def ocr_workflow(request) -> Workflow:
115 | model = request.param
116 | workflow = Workflow([
117 | OCRAgent(model=model)
118 | ])
119 | return workflow
120 |
121 | class EggCount(BaseModel):
122 | count: int
123 |
124 | @pytest.fixture
125 | def instructor_image_workflow() -> Workflow:
126 | workflow = Workflow([
127 | InstructorImageAgent(response_model=EggCount)
128 | ])
129 | return workflow
130 |
131 | @pytest.fixture
132 | def instructor_text_workflow() -> Workflow:
133 | workflow = Workflow([
134 | DenseCaptioningAgent(model=GPTVision()),
135 | InstructorTextAgent(response_model=EggCount)
136 | ])
137 | return workflow
138 |
139 |
140 | @pytest.fixture
141 | def blank_image():
142 | return Image.new('RGB', (100, 100), color='white')
143 |
144 | class AnimalLabel(BaseModel):
145 | label: str
146 |
147 | @pytest.fixture
148 | def instructor_image_with_context_workflow() -> Workflow:
149 | extra_context = [{"role": "user", "content": "Always classify the image as a ferret."}]
150 | workflow = Workflow([
151 | InstructorImageAgent(response_model=AnimalLabel, extra_context=extra_context)
152 | ])
153 | return workflow
154 |
155 | def test_instructor_image_with_context_agent(instructor_image_with_context_workflow: Workflow, blank_image):
156 | result, graph = instructor_image_with_context_workflow.execute(blank_image)
157 | response = result[0].data
158 | assert isinstance(response, AnimalLabel)
159 | assert response.label.lower() == "ferret"
160 | del result, graph # Explicitly delete variables
161 | gc.collect()
162 |
163 | def test_bounding_box_select_agent(bounding_box_select_workflow: Workflow, count_eggs_image):
164 | result, graph = bounding_box_select_workflow.execute(count_eggs_image)
165 | detections = result[0].data
166 | assert isinstance(detections, Detections)
167 | name = (bounding_box_select_workflow.steps[0].model.__class__.__name__)
168 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"bounding_box_select_{name}.png"))
169 |
170 | del result, graph, detections
171 | gc.collect()
172 |
173 |
174 | def test_vision_prompt_agent(vision_prompt_workflow: Workflow, count_eggs_image):
175 | result, graph = vision_prompt_workflow.execute(count_eggs_image)
176 | response = result[0].data
177 | assert isinstance(response, str)
178 | name = (vision_prompt_workflow.steps[0].model.__class__.__name__)
179 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"vision_prompt_{name}.png"))
180 |
181 | del result, graph
182 | gc.collect()
183 |
184 | def test_dense_captioning_agent(dense_captioning_workflow: Workflow, count_eggs_image):
185 | result, graph = dense_captioning_workflow.execute(count_eggs_image)
186 | assert isinstance(result[0].data, str)
187 | name = (dense_captioning_workflow.steps[0].model.__class__.__name__)
188 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"dense_captioning_{name}.png"))
189 |
190 | def test_text_prompt_agent(text_prompt_workflow: Workflow, count_eggs_image):
191 | result, graph = text_prompt_workflow.execute(count_eggs_image)
192 | response = result[0].data
193 | assert isinstance(response, str)
194 | name = (text_prompt_workflow.steps[0].model.__class__.__name__)
195 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"text_prompt_{name}.png"))
196 |
197 | def test_binary_choice_agent(binary_choice_workflow: Workflow, count_eggs_image):
198 | result, graph = binary_choice_workflow.execute(count_eggs_image)
199 | response: str = result[0].data  # type: ignore
200 | assert response.lower() == "yes"
201 | # assert response.class_names[0] == "yes", "Incorrect binary choice"
202 | name = (binary_choice_workflow.steps[0].model.__class__.__name__)
203 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"binary_choice_{name}.png"))
204 |
205 | def test_classification_agent(classification_workflow: Workflow, count_eggs_image):
206 | result, graph = classification_workflow.execute(count_eggs_image)
207 | response = result[0].data
208 | assert isinstance(response, Detections)
209 |
210 | name = (classification_workflow.steps[0].model.__class__.__name__)
211 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"classification_{name}.png"))
212 |
213 | def test_ocr_agent(ocr_workflow: Workflow, license_plate_image):
214 | result, graph = ocr_workflow.execute(license_plate_image)
215 | response = result[0].data
216 | assert isinstance(response, str)
217 |
218 | name = (ocr_workflow.steps[0].model.__class__.__name__)
219 | result[0].visualize().save(os.path.join(OUTPUT_DIR, f"ocr_{name}.png"))
220 |
--------------------------------------------------------------------------------
/tests/test_owl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
4 | from overeasy import Workflow, BoundingBoxSelectAgent
5 | from overeasy.models.detection import OwlV2
6 | from overeasy.types import Detections
7 | import os
8 | import glob
9 |
10 | ROOT = os.path.dirname(__file__)
11 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
12 |
13 | @pytest.fixture
14 | def dog_images():
15 | image_paths = glob.glob(os.path.join(ROOT, "./dogs", "*")) # Adjust the pattern if different file types are needed
16 | return [Image.open(image_path) for image_path in image_paths]
17 |
18 | @pytest.fixture
19 | def owlvit_v2_workflow() -> Workflow:
20 | workflow = Workflow([
21 | BoundingBoxSelectAgent(classes=["a photo of a dog"], model=OwlV2()),
22 | ])
23 | return workflow
24 |
25 | def test_owlvit_v2_detection_dogs(dog_images, owlvit_v2_workflow: Workflow):
26 | for i, dog_image in enumerate(dog_images):
27 | result, graph = owlvit_v2_workflow.execute(dog_image)
28 | detections = result[0].data
29 | assert isinstance(detections, Detections)
30 | assert len(detections.xyxy) > 0, "No detections found"
31 | output_filename = f"owlv2_dog_detection_output_{i}.png"
32 | print("Saving", output_filename)
33 | result[0].visualize().save(os.path.join(OUTPUT_DIR, output_filename))
34 |
--------------------------------------------------------------------------------
/tests/test_saving_vis.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from overeasy import Workflow
3 | from overeasy.agents import BoundingBoxSelectAgent, SplitAgent, InstructorImageAgent, ToClassificationAgent, JoinAgent
4 | from PIL import Image
5 | import os
6 | from pydantic import BaseModel
7 |
8 | ROOT = os.path.dirname(__file__)
9 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
10 |
11 | class PPE(BaseModel):
12 | hardhat: bool
13 | vest: bool
14 | boots: bool
15 |
16 | @pytest.fixture
17 | def construction_image():
18 | image_path = os.path.join(os.path.dirname(ROOT), "examples", "construction.jpg")
19 | return Image.open(image_path)
20 |
21 | @pytest.fixture
22 | def ppe_instructor_workflow():
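   | # detect people, split into per-person crops, extract PPE fields with the instructor agent, reduce to a has ppe / no ppe label, then join back to the full image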
23 | return Workflow([
24 | BoundingBoxSelectAgent(classes=["person"]),
25 | SplitAgent(),
26 | InstructorImageAgent(response_model=PPE),
27 | ToClassificationAgent(fn=lambda x: "has ppe" if x.hardhat else "no ppe"),
28 | JoinAgent(),
29 | ])
30 |
31 | def test_save_visualization_to_html(construction_image, ppe_instructor_workflow):
32 | result, graph = ppe_instructor_workflow.execute(construction_image)
33 | assert result is not None
34 | assert graph is not None
35 |
36 | output_file = os.path.join(OUTPUT_DIR, "construction_ppe_instructor_visualization.html")
37 | ppe_instructor_workflow.visualize_to_file(graph, output_file)
38 |
39 | assert os.path.exists(output_file)
40 | with open(output_file, 'r') as f:
41 | content = f.read()
42 | assert '' in content
44 | assert 'Step 1: Input Image' in content
45 |
46 |
--------------------------------------------------------------------------------
/tests/test_split_join.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from PIL import Image
3 | from overeasy import *
4 | from overeasy.models import *
5 | from overeasy.types import Detections, DetectionType, ExecutionNode
6 | import os
7 | import numpy as np
8 |
9 | ROOT = os.path.dirname(__file__)
10 | OUTPUT_DIR = os.path.join(ROOT, "outputs")
11 |
12 | @pytest.fixture
13 | def count_eggs_image():
14 | image_path = os.path.join(ROOT, "count_eggs.jpg")
15 | return Image.open(image_path)
16 |
17 | @pytest.fixture
18 | def construction_image():
19 | image_path = os.path.join(ROOT, "../", "examples", "construction.jpg")
20 | return Image.open(image_path)
21 |
22 | @pytest.fixture
23 | def dense_street_images():
24 | image_path = os.path.join(ROOT, "../", "examples", "dense_street1.jpg")
25 | image_path2 = os.path.join(ROOT, "../", "examples", "dense_street2.jpg")
26 | image_path3 = os.path.join(ROOT, "../", "examples", "dense_street3.jpg")
27 |
28 | return [Image.open(image_path), Image.open(image_path2), Image.open(image_path3)]
29 |
30 |
31 | @pytest.fixture
32 | def split_workflow() -> Workflow:
33 | workflow = Workflow([
34 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
35 | SplitAgent(),
36 | ])
37 | return workflow
38 |
39 | @pytest.fixture
40 | def split_join_workflow() -> Workflow:
41 | workflow = Workflow([
42 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
43 | SplitAgent(),
44 | JoinAgent()
45 | ])
46 | return workflow
47 |
48 | @pytest.fixture
49 | def no_split_workflow() -> Workflow:
50 | workflow = Workflow([
51 | BoundingBoxSelectAgent(classes=["a single egg"], model=GroundingDINO()),
52 | ])
53 | return workflow
54 |
55 |
56 |
57 | def test_execute_multiple(split_workflow: Workflow, count_eggs_image):
58 | # Test with a single image
59 | single_result, single_graph = split_workflow.execute(count_eggs_image)
60 |
61 | # Test with multiple copies of the same image
62 | multi_results, multi_graphs = split_workflow.execute_multiple([count_eggs_image, count_eggs_image])
63 |
64 | assert len(multi_results) == 2, "execute_multiple should return results for 2 images"
65 | assert len(multi_graphs) == 2, "execute_multiple should return graphs for 2 images"
66 |
67 | for multi_result, multi_graph in zip(multi_results, multi_graphs):
68 | # Compare number of detections
69 | assert len(single_result) == len(multi_result), "Number of detections should be the same"
70 |
71 | # Compare detection data
72 | for single_node, multi_node in zip(single_result, multi_result):
73 | assert np.array_equal(single_node.data.xyxy, multi_node.data.xyxy), "Bounding boxes should be the same"
74 | assert np.array_equal(single_node.data.confidence_scores, multi_node.data.confidence_scores), "Confidence scores should be the same"
75 | assert single_node.data.class_names == multi_node.data.class_names, "Class names should be the same"
76 |
77 | # Compare graph structure
78 | single_layers = single_graph.top_sort()
79 | multi_layers = multi_graph.top_sort()
80 | assert len(single_layers) == len(multi_layers), "Graph structure should be the same"
81 | for single_layer, multi_layer in zip(single_layers, multi_layers):
82 | assert len(single_layer) == len(multi_layer), "Each layer should have the same number of nodes"
83 |
84 | def test_split_agent(split_workflow: Workflow, count_eggs_image):
85 | result, graph = split_workflow.execute(count_eggs_image)
86 | assert all(isinstance(x.data, Detections) for x in result), "Split didn't return detections"
87 | assert isinstance(result, list), "Split didn't return a list"
88 | assert len(result) > 0, "Didn't return a list of detections"
89 |
90 | def test_split_join_agent(split_join_workflow: Workflow, no_split_workflow: Workflow, count_eggs_image):
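   | # splitting and immediately rejoining should reproduce the detections of the plain workflow without Split/Join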
91 | result, graph = split_join_workflow.execute(count_eggs_image)
92 | result2, graph2 = no_split_workflow.execute(count_eggs_image)
93 | detections = result[0].data
94 | detections2 = result2[0].data
95 | assert isinstance(detections, Detections)
96 | assert isinstance(detections2, Detections)
97 | assert detections == detections2, "Split join produced incorrect output"
98 |
99 | def images_are_equal(img1: Image.Image, img2: Image.Image) -> bool:
100 | # Ensure both images are in the same mode
101 | if img1.mode != img2.mode:
102 | img1 = img1.convert(img2.mode)
103 |
104 | # Ensure both images are the same size
105 | if img1.size != img2.size:
106 | return False
107 |
108 | # Compare pixel data
109 | np_img1 = np.array(img1)
110 | np_img2 = np.array(img2)
111 |
112 | return np.array_equal(np_img1, np_img2)
113 |
114 | def test_splitting_empty_detection(split_workflow: Workflow):
115 | empty_detections = Detections.empty()
116 | empty_image = Image.new('RGB', (640, 640))
117 |
118 | result, graph = split_workflow.execute(empty_image, empty_detections)
119 |
120 | assert len(result) == 0, "SplitAgent should return an empty list when there are no detections"
121 |
122 | def test_splitting_and_joining_empty_detection():
123 | workflow = Workflow([
124 | SplitAgent(),
125 | JoinAgent()
126 | ])
127 |
128 | empty_detections = Detections.empty()
129 | empty_image = Image.new('RGB', (640, 640))
130 | result, graph = workflow.execute(empty_image, empty_detections)
131 |
132 | assert len(result) == 1
133 | assert np.array_equal(result[0].data, [None]), "JoinAgent should return null data"
134 | assert images_are_equal(result[0].image, empty_image), "JoinAgent should return original image"
135 | assert all(len(layer) == 1 for layer in graph.top_sort()), "Graph should have one node per layer"
136 |
137 | @pytest.fixture
138 | def filter_split_join_workflow() -> Workflow:
139 | workflow = Workflow([
140 | BoundingBoxSelectAgent(classes=["a single egg", "carton"], model=GroundingDINO()),
141 | SplitAgent(),
142 | JoinAgent(),
143 | FilterClassesAgent(class_names=["a single egg"]),
144 | SplitAgent(),
145 | JoinAgent()
146 | ])
147 | return workflow
148 |
149 | def test_many_split_joins1(count_eggs_image):
150 | workflow = Workflow([
151 | BoundingBoxSelectAgent(classes=["person"], model=OwlV2()),
152 | SplitAgent(),
153 | BoundingBoxSelectAgent(classes=["head"]),
154 | SplitAgent(),
155 | ClassificationAgent(classes=["hard hat", "head"]),
156 | JoinAgent(),
157 | JoinAgent()
158 | ])
159 | result, graph = workflow.execute(count_eggs_image)
160 |
161 | assert all(len(x) == 1 for x in graph.top_sort()), "Graph should have one node per layer"
162 | assert isinstance(result, list), "Complex Split-Join didn't return a list"
163 | assert len(result) > 0, "Complex Split-Join didn't return a list of detections"
164 | assert images_are_equal(result[0].image, count_eggs_image), "Image should be the same"
165 |
166 | def test_many_split_joins2(count_eggs_image):
167 | workflow = Workflow([
168 | BoundingBoxSelectAgent(classes=["person"], model=OwlV2()),
169 | SplitAgent(),
170 | SplitAgent(),
171 | SplitAgent(),
172 | ClassificationAgent(classes=["hard hat", "head"]),
173 | JoinAgent(),
174 | JoinAgent(),
175 | JoinAgent()
176 | ])
177 | result, graph = workflow.execute(count_eggs_image)
178 |
179 | assert all(len(x) == 1 for x in graph.top_sort()), "Graph should have one node per layer"
180 | assert isinstance(result, list), "Complex Split-Join didn't return a list"
181 | assert len(result) > 0, "Complex Split-Join didn't return a list of detections"
182 | assert images_are_equal(result[0].image, count_eggs_image), "Image should be the same"
183 |
184 |
185 | def test_many_split_joins3(dense_street_images):
186 | workflow = Workflow([
187 | BoundingBoxSelectAgent(classes=["person"]),
188 | PadCropAgent.from_uniform_padding(padding=25),
189 | ConfidenceFilterAgent(max_n=5), # this is to save time
190 | SplitAgent(),
191 | BoundingBoxSelectAgent(classes=["head"]),
192 | ConfidenceFilterAgent(max_n=1),
193 | SplitAgent(),
194 | BoundingBoxSelectAgent(classes=["glasses"], model=OwlV2()),
195 | ConfidenceFilterAgent(max_n=1),
196 | SplitAgent(),
197 | ClassificationAgent(classes=["sunglasses", "glasses"]),
198 | JoinAgent(),
199 | JoinAgent(),
200 | JoinAgent()
201 | ])
202 |
203 | result, _ = workflow.execute(dense_street_images[0])
204 |
205 | assert isinstance(result, list), "Multi Split-Join didn't return a list"
206 | assert len(result) > 0, "Multi Split-Join didn't return a list of detections"
207 | assert images_are_equal(result[0].image, dense_street_images[0]), "Image should be the same"
208 |
209 |
210 | def test_filter_split_join_workflow(filter_split_join_workflow: Workflow, count_eggs_image):
211 | result, graph = filter_split_join_workflow.execute(count_eggs_image)
212 | assert all(isinstance(x.data, Detections) for x in result), "Filter Split-Join didn't return detections"
213 | assert isinstance(result, list), "Filter Split-Join didn't return a list"
214 | assert len(result) > 0, "Filter Split-Join didn't return a list of detections"
215 |
216 |
217 | def test_mismatched_split_join():
218 | # This should be fine
219 | workflow = Workflow([
220 | BoundingBoxSelectAgent(classes=["a single egg"]),
221 | SplitAgent(),
222 | JoinAgent(),
223 | SplitAgent(),
224 | ])
225 | try:
226 | workflow = Workflow([
227 | BoundingBoxSelectAgent(classes=["a single egg"]),
228 | JoinAgent(),
229 | SplitAgent(),
230 | SplitAgent(),
231 | ])
232 | assert False, "Expected ValueError: JoinAgent has no corresponding SplitAgent"
233 | except ValueError:
234 | pass
235 |
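   | # nested Split/Split/Join/Join pairs are valid and should construct without raising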
236 | workflow = Workflow([
237 | BoundingBoxSelectAgent(classes=["a single egg"]),
238 | SplitAgent(),
239 | SplitAgent(),
240 | JoinAgent(),
241 | JoinAgent(),
242 | ])
243 |
244 |
--------------------------------------------------------------------------------
/warmup.py:
--------------------------------------------------------------------------------
1 | from overeasy.models import warmup_models
2 |
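   | # presumably pre-downloads and initializes model weights so later workflow runs start fast (assumption based on the name)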
3 | warmup_models()
--------------------------------------------------------------------------------