├── .devcontainer └── devcontainer.json ├── .dockerignore ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── assets ├── ares_configs.png ├── ares_hero_plot.png ├── ares_hero_plot_extras.png ├── ares_structured_curation.png ├── ares_system_diagram.png ├── ares_trajectories.png ├── ares_unstructured_curation.png ├── pi_demos_eval.png └── pi_demos_fig.png ├── data └── .gitkeep ├── main.py ├── mongo-docker-compose.yml ├── notebooks ├── annotations_nb.ipynb └── eval_nb.ipynb ├── pyproject.toml ├── requirements.txt ├── scripts ├── __init__.py ├── annotating │ ├── run_grounding.py │ ├── run_icl.py │ ├── run_pseudo_ecot.py │ └── run_success_criteria.py ├── db_updaters │ ├── __init__.py │ ├── annotation_db_updater.py │ └── strutured_db_updater.py ├── eval.py ├── pi_demo_ingestion.py ├── release │ ├── README.md │ ├── hf_hub_readme.md │ ├── pull_from_hub.sh │ └── push_to_hub.py ├── run_structured_ingestion.py ├── run_trajectory_embedding_ingestion.py └── self_heal.py ├── setup.py └── src └── ares ├── __init__.py ├── annotating ├── __init__.py ├── annotating_base.py ├── annotating_fn.py ├── modal_base.py ├── modal_grounding.py └── orchestration.py ├── app ├── __init__.py ├── annotation_viz_helpers.py ├── data_analysis.py ├── export_data.py ├── filter_helpers.py ├── hero_display.py ├── init_data.py ├── plot_primitives.py ├── sections.py ├── viz_helpers.py └── webapp.py ├── configs ├── __init__.py ├── annotations.py ├── base.py ├── open_x_embodiment_configs.py └── pydantic_sql_helpers.py ├── constants.py ├── databases ├── __init__.py ├── annotation_database.py ├── embedding_database.py └── structured_database.py ├── extras ├── __init__.py ├── oxe.csv └── pi_demo_utils.py ├── models ├── __init__.py ├── base.py ├── extractor.py ├── grounding.py ├── grounding_utils.py ├── prompts │ ├── extractor_prompt.jinja2 │ ├── grounding_description.jinja2 │ ├── icl.jinja2 │ ├── pseudo_ecot.jinja2 │ ├── simple_video_eval.jinja2 │ ├── success_constraint_generation.jinja2 │ ├── summarization_frame_eval.jinja2 │ ├── summarizing.jinja2 │ └── task_frame_description.jinja2 ├── refusal.py ├── sampling_bias.py └── shortcuts.py ├── training ├── README.md ├── __init__.py ├── preprocess.py └── train.py └── utils ├── __init__.py ├── clustering.py └── image_utils.py /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Python 3", 3 | "mounts": [ 4 | "source=${localEnv:HOME}/cert.pem,target=/etc/ssl/certs/ca-certificates.crt,type=bind,readonly,consistency=cached", 5 | "source=${localEnv:HOME}/.config/gcloud,target=/root/.config/gcloud,type=bind,consistency=cached", 6 | "source=${localWorkspaceFolder}/data,target=/workspaces/ares/data,type=bind,consistency=cached", 7 | "source=/tmp,target=/tmp,type=bind", 8 | "source=${localEnv:HOME}/.cache/huggingface,target=/cache/huggingface,type=bind,consistency=cached", 9 | "source=${localEnv:HOME}/.modal.toml,target=/root/.modal.toml,type=bind,consistency=cached" 10 | ], 11 | "build": { 12 | "dockerfile": "../Dockerfile", 13 | "context": ".." 
14 | }, 15 | "customizations": { 16 | "vscode": { 17 | "extensions": [ 18 | "ms-python.python", 19 | "ms-python.vscode-pylance", 20 | "ms-python.black-formatter", 21 | "ms-toolsai.jupyter", 22 | "ms-azuretools.vscode-docker", 23 | "ms-python.isort", 24 | "matangover.mypy", 25 | "mongodb.mongodb-vscode" 26 | ], 27 | "settings": { 28 | "[python]": { 29 | "editor.defaultFormatter": "ms-python.black-formatter", 30 | "editor.formatOnSave": true 31 | }, 32 | "files.watcherExclude": { 33 | "**/*": true, 34 | "**/data/**": true, 35 | "**/cache/**": true, 36 | "**/tmp/**": true, 37 | "**/.git/objects/**": true, 38 | "**/.git/subtree-cache/**": true 39 | } 40 | } 41 | } 42 | }, 43 | "forwardPorts": [27017], 44 | "remoteUser": "root", 45 | "remoteEnv": { 46 | "SSL_CERT_FILE": "/etc/ssl/certs/ca-certificates.crt", 47 | "REQUESTS_CA_BUNDLE": "/etc/ssl/certs/ca-certificates.crt", 48 | "CURL_CA_BUNDLE": "/etc/ssl/certs/ca-certificates.crt", 49 | "GOOGLE_APPLICATION_CREDENTIALS": "/root/.config/gcloud/application_default_credentials.json", 50 | "OPENAI_API_KEY": "${localEnv:OPENAI_API_KEY}", 51 | "ANTHROPIC_API_KEY": "${localEnv:ANTHROPIC_API_KEY}", 52 | "HUGGINGFACE_API_KEY": "${localEnv:HUGGINGFACE_API_KEY}", 53 | "VERTEX_PROJECT": "${localEnv:VERTEX_PROJECT}", 54 | "VERTEX_LOCATION": "${localEnv:VERTEX_LOCATION}", 55 | "GEMINI_API_KEY": "${localEnv:GEMINI_API_KEY}", 56 | "MONGODB_URI": "mongodb://localhost:27017", 57 | "TRANSFORMERS_CACHE": "/cache/huggingface" 58 | }, 59 | "runArgs": ["--network=host"], 60 | "postCreateCommand": "pip install black pylint jupyter ipykernel mypy isort pymongo --retries 10" 61 | } 62 | -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # ARES-specific files 165 | *.db 166 | data/* 167 | !data/.gitkeep 168 | .DS_Store -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.2.0 4 | hooks: 5 | - id: black 6 | language_version: python3 7 | 8 | - repo: https://github.com/pycqa/isort 9 | rev: 5.13.2 10 | hooks: 11 | - id: isort 12 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10 2 | 3 | WORKDIR /ares 4 | 5 | # Create cache directory and set environment variable 6 | ENV TRANSFORMERS_CACHE=/cache/huggingface 7 | RUN mkdir -p /cache/huggingface 8 | 9 | RUN apt-get update && apt-get install -y \ 10 | libgl1-mesa-glx \ 11 | htop \ 12 | wkhtmltopdf 13 | 14 | COPY requirements.txt requirements.txt 15 | RUN pip install -r requirements.txt --retries 10 16 | 17 | COPY . . 18 | 19 | RUN pip install -e . 20 | 21 | # start in bash for interactive containers 22 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /assets/ares_configs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_configs.png -------------------------------------------------------------------------------- /assets/ares_hero_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_hero_plot.png -------------------------------------------------------------------------------- /assets/ares_hero_plot_extras.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_hero_plot_extras.png -------------------------------------------------------------------------------- /assets/ares_structured_curation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_structured_curation.png -------------------------------------------------------------------------------- /assets/ares_system_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_system_diagram.png -------------------------------------------------------------------------------- /assets/ares_trajectories.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_trajectories.png -------------------------------------------------------------------------------- /assets/ares_unstructured_curation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/ares_unstructured_curation.png -------------------------------------------------------------------------------- 
/assets/pi_demos_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/pi_demos_eval.png -------------------------------------------------------------------------------- /assets/pi_demos_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/assets/pi_demos_fig.png -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/data/.gitkeep -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import typing as t 4 | 5 | import tensorflow_datasets as tfds 6 | from sqlalchemy import Engine 7 | 8 | from ares.annotating.orchestration import orchestrate_annotating 9 | from ares.configs.open_x_embodiment_configs import get_dataset_information 10 | from ares.constants import ARES_DATA_DIR, ARES_OXE_DIR, DATASET_NAMES 11 | from ares.databases.annotation_database import ANNOTATION_DB_PATH 12 | from ares.databases.embedding_database import EMBEDDING_DB_PATH 13 | from ares.databases.structured_database import ( 14 | ROBOT_DB_PATH, 15 | RolloutSQLModel, 16 | setup_database, 17 | setup_rollouts, 18 | ) 19 | from ares.models.base import Embedder 20 | from ares.models.shortcuts import get_nomic_embedder 21 | from scripts.annotating.run_grounding import GroundingModalAnnotatingFn 22 | from scripts.run_structured_ingestion import ( 23 | build_dataset, 24 | run_structured_database_ingestion, 25 | ) 26 | from scripts.run_trajectory_embedding_ingestion import ( 27 | run_embedding_database_ingestion_per_dataset, 28 | ) 29 | 30 | 31 | def run_ingestion_pipeline( 32 | ds: t.Iterator, 33 | dataset_info: dict, 34 | dataset_formalname: str, 35 | vlm_name: str, 36 | engine: Engine, 37 | dataset_filename: str, 38 | embedder: Embedder, 39 | split: str, 40 | ) -> dict[str, list[dict]]: 41 | """ 42 | Helper function to run the ingestion pipeline for a given dataset. 43 | Currently, this means ingesting structured data, embedding rollouts, and annotating rollouts. 
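    The three stages write to separate stores: structured rollout rows go to the SQL database at
    ROBOT_DB_PATH, rollout embeddings go to the FAISS indexes at EMBEDDING_DB_PATH, and grounding
    annotations go to the MongoDB-backed annotation database at ANNOTATION_DB_PATH.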
44 | """ 45 | # run structured ingestion 46 | structured_failures, new_rollout_ids = asyncio.run( 47 | run_structured_database_ingestion( 48 | ds, 49 | dataset_info, 50 | dataset_formalname, 51 | vlm_name, 52 | engine, 53 | dataset_filename, 54 | ) 55 | ) 56 | 57 | # we can't accumulate rollouts and episodes in memory at the same time, so save rollouts 58 | # to db and videos to disk then reconstitute rollouts for indexing 59 | rollouts = setup_rollouts(engine, dataset_formalname) 60 | if new_rollout_ids is not None: 61 | rollouts = [r for r in rollouts if r.id in new_rollout_ids] 62 | 63 | if len(rollouts) == 0: 64 | raise ValueError(f"No rollouts found for {dataset_formalname} in {split}") 65 | run_embedding_database_ingestion_per_dataset( 66 | rollouts, embedder, index_path=EMBEDDING_DB_PATH 67 | ) 68 | 69 | # run grounding annotation with modal 70 | annotation_results, grounding_failures = orchestrate_annotating( 71 | engine_path=ROBOT_DB_PATH, 72 | ann_db_path=ANNOTATION_DB_PATH, 73 | annotating_fn=GroundingModalAnnotatingFn(), 74 | rollout_ids=[str(r.id) for r in rollouts], 75 | failures_path=os.path.join( 76 | ARES_DATA_DIR, 77 | "annotating_failures", 78 | f"grounding_{dataset_filename}_{split}.pkl", 79 | ), 80 | ) 81 | return dict( 82 | structured_failures=structured_failures, 83 | grounding_failures=[f.__dict__ for f in grounding_failures], 84 | ) 85 | 86 | 87 | if __name__ == "__main__": 88 | vlm_name = "gpt-4o" 89 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 90 | embedder = get_nomic_embedder() 91 | 92 | for i, dataset_info in enumerate(DATASET_NAMES): 93 | dataset_filename = dataset_info["dataset_filename"] 94 | dataset_formalname = dataset_info["dataset_formalname"] 95 | builder, dataset_dict = build_dataset(dataset_filename, ARES_OXE_DIR) 96 | print( 97 | f"working on {dataset_formalname} with splits {list(dataset_dict.keys())}" 98 | ) 99 | 100 | for split in dataset_dict.keys(): 101 | ds = dataset_dict[split] 102 | print(f"found {len(ds)} episodes in {split}") 103 | dataset_info = get_dataset_information(dataset_filename) 104 | 105 | # hardcode a few additional fields 106 | dataset_info["Dataset Filename"] = dataset_filename 107 | dataset_info["Dataset Formalname"] = dataset_formalname 108 | dataset_info["Split"] = split 109 | 110 | failures = run_ingestion_pipeline( 111 | ds, 112 | dataset_info, 113 | dataset_formalname, 114 | vlm_name, 115 | engine, 116 | dataset_filename, 117 | embedder, 118 | split, 119 | ) 120 | -------------------------------------------------------------------------------- /mongo-docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | mongodb: 3 | image: mongo:latest 4 | restart: unless-stopped 5 | ports: 6 | - "27017:27017" 7 | volumes: 8 | - mongodb_data:/data/db 9 | 10 | volumes: 11 | mongodb_data: 12 | -------------------------------------------------------------------------------- /notebooks/annotations_nb.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "%load_ext autoreload\n", 10 | "%autoreload 2" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 2, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "import pickle\n", 20 | "from ares.databases.annotation_database import AnnotationDatabase, ANNOTATION_DB_PATH\n", 21 | "from ares.utils.image_utils import 
load_video_frames\n", 22 | "\n", 23 | "from ares.app.annotation_viz_helpers import draw_annotations\n", 24 | "import matplotlib.pyplot as plt\n", 25 | "from ares.configs.annotations import Annotation\n", 26 | "import numpy as np\n", 27 | "\n", 28 | "\n", 29 | "target_fps = 1\n", 30 | "db = AnnotationDatabase(connection_string=ANNOTATION_DB_PATH)\n", 31 | "peek = db.peek_database()" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 48, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "from collections import defaultdict\n", 41 | "\n", 42 | "def display_annotations(frame_to_annotations, frames, frame_indices):\n", 43 | " shared_frame_indices = sorted(set(frame_to_annotations.keys()) & set(frame_indices))\n", 44 | " side_len = int(np.ceil(np.sqrt(len(shared_frame_indices)))) # Calculate grid dimensions\n", 45 | "\n", 46 | " # Create a figure with subplots in a square grid\n", 47 | " fig = plt.figure(figsize=(15, 15))\n", 48 | " for i, frame_idx in enumerate(shared_frame_indices):\n", 49 | " frame_anns = frame_to_annotations[frame_idx]\n", 50 | " frame_num = frame_indices.index(frame_idx)\n", 51 | " im = draw_annotations(frames[frame_num], frame_anns)\n", 52 | " \n", 53 | " # Create subplot in grid\n", 54 | " plt.subplot(side_len, side_len, i + 1)\n", 55 | " plt.title(f\"Frame idx {frame_idx}\")\n", 56 | " plt.axis('off')\n", 57 | " plt.imshow(im)\n", 58 | "\n", 59 | " plt.tight_layout()\n", 60 | " plt.show()" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 46, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def get_and_display(dataset_name, fname, target_fps):\n", 70 | " frames, frame_indices = load_video_frames(dataset_name, fname, target_fps)\n", 71 | " video_id = f\"{dataset_name}/{fname}.mp4\"\n", 72 | " print(f\"searching for {video_id}\")\n", 73 | " annotations = db.get_annotations(video_id=video_id)\n", 74 | " display_annotations(annotations['detection'], frames, frame_indices)" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 41, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "dataset_name = \"ucsd_kitchen_dataset_converted_externally_to_rlds\"\n", 84 | "fnames = [f\"data/train/episode_{i}\" for i in range(5)]" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "for fname in fnames: \n", 94 | " get_and_display(dataset_name, fname, target_fps)" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "kernelspec": { 100 | "display_name": "Python 3", 101 | "language": "python", 102 | "name": "python3" 103 | }, 104 | "language_info": { 105 | "codemirror_mode": { 106 | "name": "ipython", 107 | "version": 3 108 | }, 109 | "file_extension": ".py", 110 | "mimetype": "text/x-python", 111 | "name": "python", 112 | "nbconvert_exporter": "python", 113 | "pygments_lexer": "ipython3", 114 | "version": "3.10.16" 115 | } 116 | }, 117 | "nbformat": 4, 118 | "nbformat_minor": 2 119 | } 120 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=45", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.black] 6 | line-length = 88 7 | target-version = ["py310"] 8 | include = '\.pyi?$' 9 | extend-exclude = ''' 10 | # A regex preceded with ^/ will apply only to files and directories 11 | # in the root of the project. 
12 | /( 13 | \.eggs 14 | | \.git 15 | | \.hg 16 | | \.mypy_cache 17 | | \.tox 18 | | \.venv 19 | | _build 20 | | buck-out 21 | | build 22 | | dist 23 | )/ 24 | ''' 25 | force-exclude = ''' 26 | # Exclude long strings, comments, and docstrings 27 | (?x) 28 | ( 29 | \"\"\".*?\"\"\"| 30 | \'\'\'.*?\'\'\'| 31 | \#.* 32 | ) 33 | ''' 34 | 35 | [tool.isort] 36 | profile = "black" 37 | multi_line_output = 3 38 | line_length = 88 39 | include_trailing_comma = true 40 | force_grid_wrap = 0 41 | use_parentheses = true 42 | ensure_newline_before_comments = true 43 | 44 | [tool.mypy] 45 | python_version = "3.10" 46 | warn_return_any = true 47 | warn_unused_configs = true 48 | disallow_untyped_defs = true 49 | disallow_incomplete_defs = true 50 | exclude = ["data/"] 51 | 52 | [tool.pytest.ini_options] 53 | testpaths = ["tests"] 54 | python_files = ["test_*.py"] 55 | addopts = "-ra -q" 56 | 57 | [tool.flake8] 58 | max-line-length = 88 59 | extend-ignore = ["E501"] 60 | extend-select = ["E501"] 61 | per-file-ignores = ''' 62 | # Ignore E501 only for lines containing strings/comments 63 | *: E501 # type: ignore 64 | ''' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # core libraries 2 | torch>=2.2.0 3 | torchvision>=0.16.0 4 | tensorflow==2.15.0 5 | 6 | # supporting libraries 7 | transformers==4.40.1 8 | tokenizers==0.19.1 9 | sqlmodel 10 | timm==0.9.10 11 | Pillow 12 | imageio 13 | litellm 14 | opencv-python 15 | python-Levenshtein 16 | sentencepiece 17 | einops 18 | pycocotools 19 | sentence-transformers 20 | tenacity 21 | modal 22 | 23 | # datasets and utilities 24 | datasets 25 | tensorflow-datasets==4.9.3 26 | moviepy==1.0.3 27 | pymongo 28 | 29 | # llm providers 30 | vertexai 31 | 32 | # data viz 33 | streamlit 34 | pandas 35 | plotly 36 | umap-learn 37 | hdbscan 38 | kaleido 39 | pdfkit 40 | matplotlib 41 | openpyxl 42 | 43 | # database 44 | faiss-cpu # replace with faiss-gpu if available -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/annotating/run_grounding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Orchestration script to run grounding annotations using Modal. 
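
The flow: `setup_query` loads video frames and asks a VLM for grounding nouns, `GroundingModalWrapper`
runs detection on Modal, and `run_annotate_and_ingest` writes the resulting annotations to the
annotation database. A minimal invocation sketch (mirroring the `__main__` block below; the dataset
filename here is only an illustrative choice):

    orchestrate_annotating(
        engine_path=ROBOT_DB_PATH,
        ann_db_path=ANNOTATION_DB_PATH,
        annotating_fn=GroundingModalAnnotatingFn(),
        dataset_filename="ucsd_kitchen_dataset_converted_externally_to_rlds",
        outer_batch_size=ANNOTATION_OUTER_BATCH_SIZE,
    )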
3 | """ 4 | 5 | import asyncio 6 | import os 7 | import traceback 8 | import typing as t 9 | 10 | import numpy as np 11 | from tqdm import tqdm 12 | 13 | from ares.annotating.annotating_base import ErrorResult, ResultTracker 14 | from ares.annotating.annotating_fn import AnnotatingFn 15 | from ares.annotating.modal_grounding import GroundingModalWrapper 16 | from ares.annotating.orchestration import orchestrate_annotating 17 | from ares.configs.annotations import Annotation 18 | from ares.configs.base import Rollout 19 | from ares.constants import ( 20 | ANNOTATION_GROUNDING_FPS, 21 | ANNOTATION_OUTER_BATCH_SIZE, 22 | ARES_DATA_DIR, 23 | ) 24 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 25 | from ares.databases.structured_database import ROBOT_DB_PATH 26 | from ares.models.base import VLM 27 | from ares.models.grounding_utils import get_grounding_nouns_async 28 | from ares.models.refusal import check_refusal 29 | from ares.models.shortcuts import get_gpt_4o 30 | from ares.utils.image_utils import load_video_frames 31 | 32 | 33 | async def run_annotate_and_ingest( 34 | annotator: GroundingModalWrapper, 35 | rollout_ids: list[str], 36 | annotation_input_futures: list[t.Any], 37 | db: AnnotationDatabase, 38 | rollouts: list[t.Any], 39 | ) -> tuple[ResultTracker, list[ErrorResult]]: 40 | """ 41 | Run annotation tasks in parallel using Modal. 42 | 43 | Args: 44 | annotator (GroundingModalWrapper): Modal wrapper for grounding tasks. 45 | rollout_ids (list[str]): List of rollout IDs. 46 | annotation_input_futures (list[t.Any]): List of asyncio Tasks for preparing annotation inputs. 47 | db (AnnotationDatabase): Database instance for storing annotations. 48 | rollouts (list[t.Any]): List of rollout objects. 49 | 50 | Returns: 51 | tuple[ResultTracker, list[ErrorResult]]: Tracker and list of failures. 
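
    Preparation futures are gathered with return_exceptions=True; any ErrorResult is collected into
    the failure list, the remaining inputs are submitted to Modal via `annotator.annotate_videos`,
    and each annotated video is written back to the annotation database while the tracker tallies
    videos, frames, and annotations.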
52 | """ 53 | id_to_rollout = {r.id: r for r in rollouts} 54 | id_to_annotation_inputs = {} 55 | tracker = ResultTracker() 56 | failures = [] 57 | 58 | # Await all preparation tasks 59 | results = await asyncio.gather(*annotation_input_futures, return_exceptions=True) 60 | 61 | tasks = [] 62 | 63 | for res in results: 64 | if isinstance(res, ErrorResult): 65 | failures.append(res) 66 | continue 67 | else: 68 | rollout_id, frames, frame_indices, label_str = res 69 | print( 70 | f"Received grounding output for {rollout_id}: {len(frames)} frames, " 71 | f"{len(frame_indices)} frame indices, label str: {label_str}" 72 | ) 73 | id_to_annotation_inputs[rollout_id] = ( 74 | rollout_id, 75 | frames, 76 | frame_indices, 77 | label_str, 78 | ) 79 | # Prepare annotation task 80 | tasks.append((rollout_id, frames, label_str)) 81 | 82 | # Submit annotation tasks to Modal 83 | annotation_results = await annotator.annotate_videos(tasks) 84 | for rollout_id, all_frame_annotations in annotation_results: 85 | try: 86 | rollout = id_to_rollout[rollout_id] 87 | _, frames, frame_indices, label_str = id_to_annotation_inputs[rollout_id] 88 | all_frame_annotation_objs = [ 89 | [Annotation(**ann) for ann in frame_annotations] 90 | for frame_annotations in all_frame_annotations 91 | ] 92 | video_id = db.add_video_with_annotations( 93 | dataset_filename=rollout.dataset_filename, 94 | video_path=rollout.filename + ".mp4", 95 | frames=frames, 96 | frame_indices=frame_indices, 97 | annotations=all_frame_annotation_objs, 98 | label_str=label_str, 99 | ) 100 | tracker.update_via_batch( 101 | n_videos=1, 102 | n_frames=len(all_frame_annotations), 103 | n_annotations=sum( 104 | len(frame_annotations) 105 | for frame_annotations in all_frame_annotations 106 | ), 107 | video_ids=[video_id], 108 | ) 109 | except Exception as e: 110 | failures.append( 111 | ErrorResult( 112 | rollout_id=rollout_id, 113 | error_pattern="grounding_failure", 114 | error=traceback.format_exc(), 115 | exception=str(e), 116 | ) 117 | ) 118 | 119 | return tracker, failures 120 | 121 | 122 | async def setup_query( 123 | rollout: t.Any, 124 | vlm: VLM, 125 | target_fps: int = 5, 126 | ) -> tuple[str, list[np.ndarray], list[int], str] | ErrorResult: 127 | """ 128 | Prepare annotation inputs for a rollout. 129 | 130 | Args: 131 | rollout (t.Any): Rollout object. 132 | vlm (VLM): Vision-Language Model instance. 133 | target_fps (int, optional): Target FPS for frame extraction. 134 | 135 | Returns: 136 | tuple[str, list[np.ndarray], list[int], str] | dict[str, t.Any]: Prepared data or error dict. 
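
    Note: on failure (loading video frames, the grounding-noun request, or a refusal from the VLM),
    an ErrorResult is returned in place of the tuple rather than raising.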
137 | """ 138 | try: 139 | frames, frame_indices = load_video_frames( 140 | rollout.dataset_filename, 141 | rollout.filename, 142 | target_fps, 143 | ) 144 | except Exception as e: 145 | return ErrorResult( 146 | rollout_id=str(rollout.id), 147 | error_pattern="grounding_failure", 148 | error=traceback.format_exc(), 149 | exception=str(e), 150 | ) 151 | 152 | try: 153 | label_str = await get_grounding_nouns_async( 154 | vlm, 155 | frames[0], 156 | rollout.task.language_instruction, 157 | ) 158 | except Exception as e: 159 | return ErrorResult( 160 | rollout_id=str(rollout.id), 161 | error_pattern="grounding_request_failure", 162 | error=traceback.format_exc(), 163 | exception=str(e), 164 | ) 165 | 166 | if check_refusal(label_str): 167 | return ErrorResult( 168 | rollout_id=str(rollout.id), 169 | error_pattern="grounding_request_failure", 170 | error=f"Refusal phrase triggered: '{label_str}'", 171 | exception=None, 172 | ) 173 | return rollout.id, frames, frame_indices, label_str 174 | 175 | 176 | async def run_ground_and_annotate( 177 | rollouts: list[t.Any], 178 | vlm: VLM, 179 | ann_db: AnnotationDatabase, 180 | annotator: GroundingModalWrapper, 181 | target_fps: int = ANNOTATION_GROUNDING_FPS, 182 | ) -> tuple[ResultTracker, list[ErrorResult]]: 183 | """ 184 | Process, ground, and annotate list of rollouts. 185 | 186 | Args: 187 | rollouts (list[t.Any]): List of rollout objects. 188 | vlm (VLM): Vision-Language Model instance. 189 | ann_db (AnnotationDatabase): Annotation database instance. 190 | target_fps (int, optional): Target FPS for annotation. 191 | 192 | Returns: 193 | tuple[dict[str, int], list[dict]]: Tracker and list of failures. 194 | """ 195 | rollout_ids = [r.id for r in rollouts] 196 | 197 | # Create and gather the futures properly 198 | annotation_input_futures = [ 199 | asyncio.create_task(setup_query(rollout, vlm, target_fps)) 200 | for rollout in rollouts 201 | ] 202 | 203 | tracker, failures = await run_annotate_and_ingest( 204 | annotator, 205 | rollout_ids, 206 | annotation_input_futures, 207 | ann_db, 208 | rollouts, 209 | ) 210 | return tracker, failures 211 | 212 | 213 | class GroundingModalAnnotatingFn(AnnotatingFn): 214 | def __call__( 215 | self, 216 | rollouts: list[Rollout], 217 | ann_db: AnnotationDatabase, 218 | outer_batch_size: int, 219 | annotation_fps: int = ANNOTATION_GROUNDING_FPS, 220 | ) -> tuple[ResultTracker, list[ErrorResult]]: 221 | """ 222 | Main function to run grounding annotation using Modal. 
223 | """ 224 | # initialize objects for batches 225 | overall_tracker = ResultTracker() 226 | overall_failures = [] 227 | 228 | annotator = GroundingModalWrapper() 229 | with annotator.app.run(): 230 | # Limited by CPU RAM (can't create all requests at once) 231 | for i in tqdm( 232 | range(0, len(rollouts), outer_batch_size), 233 | desc="Processing outer batches", 234 | ): 235 | print( 236 | f"Processing batch {i // outer_batch_size + 1} of {len(rollouts) // outer_batch_size}" 237 | ) 238 | # create VLM outside async as semaphore gets "bound" to async context 239 | vlm = get_gpt_4o() 240 | rollouts_batch = rollouts[i : i + outer_batch_size] 241 | tracker, failures = asyncio.run( 242 | run_ground_and_annotate( 243 | rollouts_batch, 244 | vlm, 245 | ann_db, 246 | annotator, 247 | annotation_fps, 248 | ) 249 | ) 250 | print( 251 | f"Completed batch {i // outer_batch_size + 1} of {max(1, len(rollouts) // outer_batch_size)}" 252 | ) 253 | overall_tracker.update_tracker(tracker) 254 | overall_failures.extend(failures) 255 | return overall_tracker, overall_failures 256 | 257 | 258 | if __name__ == "__main__": 259 | ids_path = ( 260 | "/workspaces/ares/data/heal_info/2025-01-27_22-04-01/update_grounding_ids.txt" 261 | ) 262 | orchestrate_annotating( 263 | engine_path=ROBOT_DB_PATH, 264 | ann_db_path=ANNOTATION_DB_PATH, 265 | annotating_fn=GroundingModalAnnotatingFn(), 266 | ids_path=ids_path, 267 | outer_batch_size=ANNOTATION_OUTER_BATCH_SIZE, 268 | annotating_kwargs=dict( 269 | annotation_fps=ANNOTATION_GROUNDING_FPS, 270 | ), 271 | failures_path=os.path.join( 272 | ARES_DATA_DIR, "annotating_failures", f"grounding_failures.pkl" 273 | ), 274 | ) 275 | -------------------------------------------------------------------------------- /scripts/annotating/run_icl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inspired by `InstantPolicy` and `R+X: Retrieval and Execution from Everyday Human Videos`, we demonstrate creating rollout annotations for 3 | in-context learning by retrieving similar rollouts from the dataset. We can do this for several keys, such as the task, the text 4 | description of the rollout, or the state and action trajectories of the robot. 5 | """ 6 | 7 | import os 8 | import traceback 9 | 10 | from sqlalchemy import Engine 11 | 12 | from ares.annotating.annotating_base import ErrorResult, ResultTracker 13 | from ares.annotating.annotating_fn import APIAnnotatingFn 14 | from ares.annotating.orchestration import orchestrate_annotating 15 | from ares.configs.base import Rollout 16 | from ares.constants import ANNOTATION_OUTER_BATCH_SIZE, ARES_DATA_DIR, DATASET_NAMES 17 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 18 | from ares.databases.embedding_database import ( 19 | EMBEDDING_DB_PATH, 20 | META_INDEX_NAMES, 21 | TRAJECTORY_INDEX_NAMES, 22 | FaissIndex, 23 | IndexManager, 24 | rollout_to_index_name, 25 | ) 26 | from ares.databases.structured_database import ( 27 | ROBOT_DB_PATH, 28 | RolloutSQLModel, 29 | get_rollouts_by_ids, 30 | setup_database, 31 | ) 32 | from ares.models.base import VLM, parse_response 33 | from ares.utils.image_utils import load_video_frames 34 | 35 | 36 | class ICLAnnotatingFn(APIAnnotatingFn): 37 | """ 38 | Object to orchestrate the retrieval of similar rollouts (per key) to facilitate ICL. 39 | Given the databases, we can retrieve similar rollouts for each key and then pull their example_field to use as context. 
40 | 41 | For example, we can retrieve similar rollouts for the `task` key and `states` key and then pull their `description_estimate` to use as context. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | index_manager: IndexManager, 47 | engine: Engine, 48 | keys: list[str], 49 | n_examples_per_key: int = 5, 50 | example_field: str = "description_estimate", 51 | ): 52 | super().__init__(annotation_key="string", annotation_type="icl") 53 | self.index_manager = index_manager 54 | self.engine = engine 55 | self.keys = keys 56 | self.n_examples_per_key = n_examples_per_key 57 | self.example_field = example_field 58 | 59 | def construct_example_values(self, rollout: Rollout) -> dict[str, list[str]]: 60 | output_vals = dict() 61 | # Initialize set of used IDs with the current rollout ID 62 | already_used_ids = {str(rollout.id)} 63 | 64 | for key in self.keys: 65 | if key not in META_INDEX_NAMES: 66 | index_name = rollout_to_index_name(rollout, suffix=key) 67 | else: 68 | index_name = key 69 | 70 | value = self.index_manager.get_matrix_by_id(index_name, str(rollout.id)) 71 | # Request more IDs to account for potential duplicates 72 | k = self.n_examples_per_key + len(already_used_ids) 73 | dists, ids, _ = self.index_manager.search_matrix(index_name, value, k=k) 74 | 75 | # Filter out already used IDs 76 | new_ids = [] 77 | for id_ in ids: 78 | if ( 79 | id_ not in already_used_ids 80 | and len(new_ids) < self.n_examples_per_key 81 | ): 82 | new_ids.append(id_) 83 | already_used_ids.add(id_) 84 | 85 | example_rollouts = get_rollouts_by_ids(self.engine, new_ids) 86 | example_vals = [ 87 | rollout.get_nested_attr(self.example_field) 88 | for rollout in example_rollouts 89 | ] 90 | display_key = key.replace("_", " ").title() 91 | output_vals[display_key] = example_vals 92 | return output_vals 93 | 94 | async def run_query( 95 | self, vlm: VLM, rollout: Rollout, ann_db: AnnotationDatabase 96 | ) -> str | ErrorResult: 97 | try: 98 | frames, frame_indices = load_video_frames( 99 | rollout.dataset_filename, 100 | rollout.filename, 101 | target_fps=0, 102 | ) 103 | except Exception as e: 104 | return ErrorResult( 105 | rollout_id=str(rollout.id), 106 | error_pattern="loading_video_failure", 107 | error=traceback.format_exc(), 108 | exception=str(e), 109 | ) 110 | try: 111 | example_values = self.construct_example_values(rollout) 112 | info = { 113 | "task": rollout.get_nested_attr("task_language_instruction"), 114 | "examples": example_values, 115 | } 116 | messages, res = await vlm.ask_async( 117 | info=info, 118 | prompt_filename="icl.jinja2", 119 | images=[frames[0]], 120 | ) 121 | icl_str = parse_response(res.choices[0], load_json=False) 122 | except Exception as e: 123 | return ErrorResult( 124 | rollout_id=str(rollout.id), 125 | error_pattern="icl_parsing_failure", 126 | error=traceback.format_exc(), 127 | exception=str(e), 128 | ) 129 | return icl_str 130 | 131 | 132 | if __name__ == "__main__": 133 | index_manager = IndexManager(EMBEDDING_DB_PATH, FaissIndex) 134 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 135 | 136 | # dont use description estimate 137 | keys = META_INDEX_NAMES + TRAJECTORY_INDEX_NAMES 138 | keys = [key for key in keys if key != "description_estimate"] 139 | n_examples_per_key = 5 140 | overall_tracker = ResultTracker() 141 | overall_failures = [] 142 | 143 | for dataset_info in DATASET_NAMES: 144 | print(f"Processing {dataset_info['dataset_formalname']}") 145 | dataset_filename = dataset_info["dataset_filename"] 146 | tracker, failures = orchestrate_annotating( 
147 | engine_path=ROBOT_DB_PATH, 148 | ann_db_path=ANNOTATION_DB_PATH, 149 | annotating_fn=ICLAnnotatingFn( 150 | index_manager=index_manager, 151 | engine=engine, 152 | keys=keys, 153 | n_examples_per_key=n_examples_per_key, 154 | ), 155 | dataset_filename=dataset_filename, 156 | outer_batch_size=ANNOTATION_OUTER_BATCH_SIZE, 157 | failures_path=os.path.join( 158 | ARES_DATA_DIR, 159 | "annotating_failures", 160 | f"icl_failures_{dataset_filename}.pkl", 161 | ), 162 | ) 163 | overall_tracker.update_tracker(tracker) 164 | overall_failures.extend(failures) 165 | 166 | print(f"OVERALL STATS") 167 | overall_tracker.print_stats() 168 | print(f"Number of failures: {len(overall_failures)}") 169 | -------------------------------------------------------------------------------- /scripts/annotating/run_pseudo_ecot.py: -------------------------------------------------------------------------------- 1 | """ 2 | Orchestration script to run a simpler version of the 'Embodied Chain of Thought' paper. 3 | See original code https://github.com/MichalZawalski/embodied-CoT/blob/main/scripts/generate_embodied_data/full_reasonings.py 4 | 5 | We utilize the `grounding_string`, `detections`, and `success_criteria` annotations (see other annotating scripts!) + rollout fields to generate a pseudo-ECoT in a similar fashion. 6 | """ 7 | 8 | import os 9 | import traceback 10 | import typing as t 11 | 12 | from ares.annotating.annotating_base import ErrorResult, ResultTracker 13 | from ares.annotating.annotating_fn import APIAnnotatingFn 14 | from ares.annotating.orchestration import orchestrate_annotating 15 | from ares.configs.annotations import Annotation 16 | from ares.configs.base import Rollout 17 | from ares.constants import ANNOTATION_OUTER_BATCH_SIZE, ARES_DATA_DIR, DATASET_NAMES 18 | from ares.databases.annotation_database import ( 19 | ANNOTATION_DB_PATH, 20 | AnnotationDatabase, 21 | get_video_id, 22 | ) 23 | from ares.databases.structured_database import ROBOT_DB_PATH 24 | from ares.models.base import VLM, parse_response 25 | from ares.utils.image_utils import load_video_frames 26 | 27 | 28 | def construct_pseudo_ecot_info( 29 | rollout: Rollout, ann_db: AnnotationDatabase 30 | ) -> dict[str, t.Any]: 31 | """ 32 | Collect all the composable annotations and rollout fields to generate a pseudo-ECoT prompt 33 | """ 34 | video_id = get_video_id(rollout.dataset_filename, rollout.filename) 35 | anns = ann_db.get_annotations(video_id) 36 | 37 | # get the annotations from the database 38 | grounding_string = anns.get("grounding_string")[0].description 39 | grounding_string = ", ".join(grounding_string.split(".")[:-1]) + "." 
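    # the split/join above reformats the period-delimited grounding string (e.g. "robot arm. mug. table.")
    # into a comma-separated phrase ending in a single period, dropping the empty segment left after the trailing period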
40 | success_criteria = anns.get("success_criteria")[0].description 41 | detections = anns.get("detection") 42 | 43 | # separate and parse the detections into a string 44 | frame_0_detections: list[Annotation] = detections[0] if detections else [] 45 | if frame_0_detections: 46 | text_detections = ", ".join( 47 | f"{d.category_name} at {[round(r, 2) for r in d.bbox]} (LTRB format)" 48 | for d in frame_0_detections 49 | ) 50 | else: 51 | text_detections = None 52 | return dict( 53 | task=rollout.task.language_instruction, 54 | complexity_category_estimate=rollout.task.complexity_category_estimate, 55 | grounding_string=grounding_string, 56 | detections=text_detections, 57 | success_criteria=success_criteria, 58 | ) 59 | 60 | 61 | class PseudoECoTAnnotatingFn(APIAnnotatingFn): 62 | def __init__(self) -> None: 63 | super().__init__(annotation_key="string", annotation_type="pseudo_ecot") 64 | 65 | async def run_query( 66 | self, vlm: VLM, rollout: Rollout, ann_db: AnnotationDatabase 67 | ) -> str | ErrorResult: 68 | try: 69 | frames, _ = load_video_frames( 70 | rollout.dataset_filename, 71 | rollout.filename, 72 | target_fps=0, 73 | ) 74 | except Exception as e: 75 | return ErrorResult( 76 | rollout_id=str(rollout.id), 77 | error_pattern="loading_video_failure", 78 | error=traceback.format_exc(), 79 | exception=str(e), 80 | ) 81 | try: 82 | info = construct_pseudo_ecot_info(rollout, ann_db) 83 | messages, res = await vlm.ask_async( 84 | info=info, 85 | prompt_filename="pseudo_ecot.jinja2", 86 | images=[frames[0]], 87 | ) 88 | pseudo_ecot_str = parse_response(res.choices[0], load_json=False) 89 | except Exception as e: 90 | return ErrorResult( 91 | rollout_id=str(rollout.id), 92 | error_pattern="pseudo_ecot_failure", 93 | error=traceback.format_exc(), 94 | exception=str(e), 95 | ) 96 | return pseudo_ecot_str 97 | 98 | 99 | if __name__ == "__main__": 100 | overall_tracker = ResultTracker() 101 | overall_failures = [] 102 | 103 | for dataset_info in DATASET_NAMES: 104 | print(f"Processing {dataset_info['dataset_formalname']}") 105 | dataset_filename = dataset_info["dataset_filename"] 106 | tracker, failures = orchestrate_annotating( 107 | engine_path=ROBOT_DB_PATH, 108 | ann_db_path=ANNOTATION_DB_PATH, 109 | annotating_fn=PseudoECoTAnnotatingFn(), 110 | dataset_filename=dataset_filename, 111 | outer_batch_size=ANNOTATION_OUTER_BATCH_SIZE, 112 | failures_path=os.path.join( 113 | ARES_DATA_DIR, 114 | "annotating_failures", 115 | f"pseudo_ecot_failures_{dataset_filename}.pkl", 116 | ), 117 | ) 118 | overall_tracker.update_tracker(tracker) 119 | overall_failures.extend(failures) 120 | 121 | print(f"OVERALL STATS") 122 | overall_tracker.print_stats() 123 | print(f"Number of failures: {len(overall_failures)}") 124 | -------------------------------------------------------------------------------- /scripts/annotating/run_success_criteria.py: -------------------------------------------------------------------------------- 1 | import os 2 | import traceback 3 | 4 | from ares.annotating.annotating_base import ErrorResult, ResultTracker 5 | from ares.annotating.annotating_fn import APIAnnotatingFn 6 | from ares.annotating.orchestration import orchestrate_annotating 7 | from ares.configs.base import Rollout 8 | from ares.constants import ANNOTATION_OUTER_BATCH_SIZE, ARES_DATA_DIR, DATASET_NAMES 9 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 10 | from ares.databases.structured_database import ROBOT_DB_PATH 11 | from ares.models.base import VLM, parse_response 12 | 
from ares.utils.image_utils import load_video_frames 13 | 14 | 15 | class SuccessCriteriaAnnotatingFn(APIAnnotatingFn): 16 | def __init__(self) -> None: 17 | super().__init__(annotation_key="string", annotation_type="success_criteria") 18 | 19 | async def run_query( 20 | self, vlm: VLM, rollout: Rollout, ann_db: AnnotationDatabase 21 | ) -> str | ErrorResult: 22 | try: 23 | frames, _ = load_video_frames( 24 | rollout.dataset_filename, 25 | rollout.filename, 26 | target_fps=0, 27 | ) 28 | except Exception as e: 29 | return ErrorResult( 30 | rollout_id=str(rollout.id), 31 | error_pattern="loading_video_failure", 32 | error=traceback.format_exc(), 33 | exception=str(e), 34 | ) 35 | try: 36 | _, res = await vlm.ask_async( 37 | info=dict(task=rollout.task.language_instruction), 38 | prompt_filename="success_constraint_generation.jinja2", 39 | images=[frames[0]], 40 | ) 41 | success_criteria = parse_response(res.choices[0], load_json=False) 42 | except Exception as e: 43 | return ErrorResult( 44 | rollout_id=str(rollout.id), 45 | error_pattern="success_constraint_generation_failure", 46 | error=traceback.format_exc(), 47 | exception=str(e), 48 | ) 49 | return success_criteria 50 | 51 | 52 | if __name__ == "__main__": 53 | overall_tracker = ResultTracker() 54 | overall_failures = [] 55 | 56 | for dataset_info in DATASET_NAMES: 57 | print(f"Processing {dataset_info['dataset_formalname']}") 58 | dataset_filename = dataset_info["dataset_filename"] 59 | tracker, failures = orchestrate_annotating( 60 | engine_path=ROBOT_DB_PATH, 61 | ann_db_path=ANNOTATION_DB_PATH, 62 | annotating_fn=SuccessCriteriaAnnotatingFn(), 63 | dataset_filename=dataset_filename, 64 | outer_batch_size=ANNOTATION_OUTER_BATCH_SIZE, 65 | failures_path=os.path.join( 66 | ARES_DATA_DIR, 67 | "annotating_failures", 68 | f"success_criteria_failures_{dataset_filename}.pkl", 69 | ), 70 | ) 71 | overall_tracker.update_tracker(tracker) 72 | overall_failures.extend(failures) 73 | 74 | print(f"OVERALL STATS") 75 | overall_tracker.print_stats() 76 | print(f"Number of failures: {len(overall_failures)}") 77 | -------------------------------------------------------------------------------- /scripts/db_updaters/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/scripts/db_updaters/__init__.py -------------------------------------------------------------------------------- /scripts/db_updaters/annotation_db_updater.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | from ares.configs.annotations import Annotation 4 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 5 | 6 | 7 | def migrate(): 8 | db = AnnotationDatabase(connection_string=ANNOTATION_DB_PATH) 9 | 10 | # Get all videos 11 | video_ids = db.get_video_ids() 12 | print(f"Found {len(video_ids)} videos to process") 13 | 14 | for video_id in tqdm(video_ids): 15 | # Get annotations for this video 16 | annotations = db.get_annotations(video_id, annotation_type="success_criteria") 17 | if not annotations or "success_criteria" not in annotations: 18 | print(f"Skipping {video_id} - no success_criteria annotations found") 19 | continue 20 | 21 | # Get the original string from the list of annotations 22 | string_value = None 23 | for ann in annotations["success_criteria"]: 24 | if isinstance(ann, str): 25 | string_value = ann 26 | break 27 | 28 | if 
string_value is None: 29 | print(f"Skipping {video_id} - no string annotation found") 30 | continue 31 | 32 | # Create new Annotation object with the original string 33 | annotation_obj = Annotation( 34 | description=string_value, annotation_type="success_criteria" 35 | ) 36 | 37 | # Delete all existing success_criteria annotations for this video 38 | db.delete_annotations(video_id, annotation_type="success_criteria") 39 | 40 | # Add the new annotation 41 | db.add_annotation( 42 | video_id=video_id, 43 | key="string", 44 | value=annotation_obj, 45 | annotation_type="success_criteria", 46 | frame=None, 47 | ) 48 | 49 | print(f"Processed {video_id}") 50 | 51 | print("Migration complete!") 52 | 53 | 54 | if __name__ == "__main__": 55 | migrate() 56 | -------------------------------------------------------------------------------- /scripts/db_updaters/strutured_db_updater.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper script to amend values in the structured database. At a high level, we can add a new column or cell to the database by specifying the row identifiers (the id_keys) and the new value. 3 | For example, we can add a new column (such as rollout.environment.data_collection_method) by specifying the id_keys (e.g. dataset_name, path) and the new value (calculated from the dataset_information). 4 | The defaults just ensure that the new column is populated with a default value for all rows that don't have a specific value. 5 | """ 6 | 7 | import numpy as np 8 | 9 | from ares.configs.open_x_embodiment_configs import get_dataset_information 10 | from ares.constants import DATASET_NAMES 11 | from ares.databases.structured_database import ( 12 | ROBOT_DB_PATH, 13 | RolloutSQLModel, 14 | add_column_with_vals_and_defaults, 15 | get_all_rollouts, 16 | setup_database, 17 | ) 18 | 19 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 20 | rollouts = get_all_rollouts(engine) 21 | 22 | DATASET_INFOS = dict() 23 | for data_names in DATASET_NAMES: 24 | dfilename = data_names["dataset_filename"] 25 | dformalname = data_names["dataset_formalname"] 26 | dataset_info = get_dataset_information(dfilename) 27 | DATASET_INFOS[dformalname] = dataset_info 28 | 29 | 30 | if __name__ == "__main__": 31 | id_keys = ["dataset_name", "path"] 32 | 33 | new_cols_flat_names = ["environment_data_collection_method"] 34 | new_cols_flat_types = [str] 35 | default_vals = [None] 36 | 37 | for i in range(len(new_cols_flat_names)): 38 | input_mapping = dict() 39 | for rollout in rollouts: 40 | new_val = DATASET_INFOS[rollout.dataset_formalname]["Data Collect Method"] 41 | input_mapping[tuple(getattr(rollout, k) for k in id_keys)] = new_val 42 | 43 | print(f"prepped {len(input_mapping)} to add to db:") 44 | print(f"e.g. {set(np.random.choice(list(input_mapping.values()), 50))}") 45 | print(f"under new name {new_cols_flat_names[i]}") 46 | print("...confirm? 
Press c to continue.") 47 | breakpoint() # breakpoint to check things look right before updating db 48 | 49 | add_column_with_vals_and_defaults( 50 | engine=engine, 51 | new_column_name=new_cols_flat_names[i], 52 | python_type=new_cols_flat_types[i], 53 | default_value=default_vals[i], 54 | key_mapping_col_names=id_keys, 55 | specific_key_mapping_values=input_mapping, 56 | ) 57 | print(f"added {new_cols_flat_names[i]}") 58 | -------------------------------------------------------------------------------- /scripts/pi_demo_ingestion.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import os 3 | import typing as t 4 | 5 | from tqdm import tqdm 6 | 7 | from ares.constants import ( 8 | ARES_DATA_DIR, 9 | ARES_OXE_DIR, 10 | DATASET_NAMES, 11 | get_dataset_info_by_key, 12 | ) 13 | from ares.databases.annotation_database import ANNOTATION_DB_PATH 14 | from ares.databases.embedding_database import EMBEDDING_DB_PATH 15 | from ares.databases.structured_database import ( 16 | ROBOT_DB_PATH, 17 | RolloutSQLModel, 18 | setup_database, 19 | ) 20 | from ares.extras.pi_demo_utils import PI_DEMO_TASKS 21 | from ares.models.shortcuts import get_nomic_embedder 22 | from ares.utils.image_utils import get_video_frames 23 | from main import run_ingestion_pipeline 24 | 25 | dataset_filename = "pi_demos" 26 | split = "test" 27 | dataset_info = get_dataset_info_by_key("dataset_filename", dataset_filename) 28 | dataset_formalname = dataset_info["dataset_formalname"] 29 | 30 | full_dataset_info = { 31 | "Dataset": dataset_formalname, 32 | "Dataset Filename": dataset_filename, 33 | "Dataset Formalname": dataset_formalname, 34 | "Split": split, 35 | "Robot": None, 36 | "Robot Morphology": None, 37 | "Gripper": None, 38 | "Action Space": None, 39 | "# RGB Cams": None, 40 | "# Depth Cams": None, 41 | "# Wrist Cams": None, 42 | "Language Annotations": "Natural", 43 | "Data Collect Method": "Expert Policy", 44 | "Scene Type": None, 45 | "Citation": "year={2024}", 46 | } 47 | 48 | 49 | def prep_for_oxe_episode(task_info: dict, success_flag: str) -> dict | None: 50 | """ 51 | Force the PI Demo videos and task information into the OpenXEmbodimentEpisode format. 
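
    Since the PI demos only ship raw videos, a task string, and a success flag, the per-step
    `action`, `state`, and `language_embedding` fields are left as None; only the image observation
    and the language instruction are populated.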
52 | """ 53 | filename = f"{task_info['filename_prefix']}_{success_flag}" 54 | try: 55 | frames = get_video_frames(dataset_filename="pi_demos", filename=filename) 56 | except Exception as e: 57 | print(f"Error getting video frames for {filename}: {e}") 58 | return None 59 | metadata = {"file_path": filename, "success": success_flag == "success"} 60 | steps = [] 61 | for i, frame in enumerate(frames): 62 | observation = { 63 | "image": frame, 64 | } 65 | steps.append( 66 | { 67 | "image": frame, 68 | "action": None, 69 | "state": None, 70 | "is_first": i == 0, 71 | "is_last": i == len(frames) - 1, 72 | "is_terminal": False, 73 | "language_embedding": None, 74 | "language_instruction": task_info["task"], 75 | "observation": observation, 76 | } 77 | ) 78 | return {"episode_metadata": metadata, "steps": steps} 79 | 80 | 81 | class PiDemoIngestion: 82 | def __init__(self, task_infos: list[dict], success_flags: list[str]): 83 | self.task_infos = task_infos 84 | self.success_flags = success_flags 85 | self._episodes = [] 86 | for task_info in tqdm(task_infos): 87 | for success_flag in success_flags: 88 | episode = prep_for_oxe_episode(task_info, success_flag) 89 | if episode is not None: 90 | self._episodes.append(episode) 91 | self._index = 0 92 | 93 | def __iter__(self) -> "PiDemoIngestion": 94 | self._index = 0 95 | return self 96 | 97 | def __next__(self) -> dict: 98 | if self._index >= len(self._episodes): 99 | raise StopIteration 100 | episode = self._episodes[self._index] 101 | self._index += 1 102 | return episode 103 | 104 | def __len__(self) -> int: 105 | return len(self._episodes) 106 | 107 | 108 | if __name__ == "__main__": 109 | vlm_name = "gpt-4o" 110 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 111 | embedder = get_nomic_embedder() 112 | task_infos = list(PI_DEMO_TASKS.values()) 113 | # the PI Demo videos are enormous, so we can only ingest them one-at-a-time 114 | for task_info in tqdm(task_infos): 115 | for flag in ["success", "fail"]: 116 | print(task_info) 117 | ds = PiDemoIngestion([task_info], [flag]) 118 | run_ingestion_pipeline( 119 | ds, 120 | full_dataset_info, 121 | dataset_formalname, 122 | vlm_name, 123 | engine, 124 | dataset_filename, 125 | embedder, 126 | split, 127 | ) 128 | -------------------------------------------------------------------------------- /scripts/release/README.md: -------------------------------------------------------------------------------- 1 | # Release Scripts 2 | 3 | This directory contains scripts for managing ARES quick-start [data and databases on the Hugging Face Hub](https://huggingface.co/datasets/jacobphillips99/ares-data). 
The data includes the StructuredDatabase, AnnotationDatabase, EmbeddingDatabase, and videos spanning 5000 Open X-Embodiment demonstrations, structured as follows: 4 | 5 | - `robot_data.db`: StructuredDatabase SQLite database containing structured robot data 6 | - `embedding_data`: IndexManager for the EmbeddingDatabase containing FAISS indexes 7 | - `annotation_mongodump`: MongoDB dump of the AnnotationDatabase 8 | - `videos`: Collection of robot demonstration videos and frames 9 | 10 | ## Prerequisites 11 | 12 | - Set the `HUGGINGFACE_API_KEY` environment variable with your Hugging Face access token 13 | - MongoDB installed and running locally (for pull/restore operations) 14 | - MongoDB database tools (mongorestore/mongodump) installed on your system 15 | - **Ubuntu/Debian**: `sudo apt-get install mongodb-database-tools` 16 | - **macOS**: `brew install mongodb/brew/mongodb-database-tools` 17 | - **Windows**: Download and install from the [MongoDB Download Center](https://www.mongodb.com/try/download/database-tools) 18 | - **From source**: Follow instructions at [MongoDB GitHub repository](https://github.com/mongodb/mongo-tools) 19 | 20 | 21 | ## Scripts 22 | 23 | ### pull_from_hub.sh 24 | 25 | Downloads and restores the ARES data and databases from the Hugging Face Hub (`jacobphillips99/ares-data`). 26 | 27 | ```bash 28 | ./scripts/release/pull_from_hub.sh [output_directory] 29 | ``` 30 | 31 | - Default output directory: `$HOME/ares/data` 32 | - Downloads and extracts all databases and video datasets 33 | - Automatically restores MongoDB dump to your local MongoDB instance 34 | 35 | ### push_to_hub.py 36 | 37 | Uploads the ARES data and databases to the Hugging Face Hub. 38 | 39 | ```bash 40 | python scripts/release/push_to_hub.py 41 | ``` 42 | 43 | - Creates tar archives for folder-based data 44 | - Handles video datasets separately, creating individual archives per dataset 45 | - Supports both direct file uploads and directory uploads 46 | - Creates the Hugging Face repository if it doesn't exist -------------------------------------------------------------------------------- /scripts/release/hf_hub_readme.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | language: 4 | - en 5 | size_categories: 6 | - 1K<n<10K -------------------------------------------------------------------------------- /scripts/release/pull_from_hub.sh: -------------------------------------------------------------------------------- 45 | if ! command -v mongorestore &> /dev/null; then 46 | echo "Error: mongorestore could not be found" 47 | exit 1 48 | fi 49 | mongorestore --uri="mongodb://localhost:27017" "$OUTDIR/annotation_mongodump" 50 | 51 | # Get list of video dataset tars from HF hub and download, unpack, and remove each tar file 52 | echo "downloading videos..." 53 | mkdir -p "$OUTDIR/videos" 54 | echo "fetching video datasets..." 55 | API_URL="https://huggingface.co/api/datasets/$HF_REPO/tree/main/videos" 56 | echo "Fetching from: $API_URL" 57 | for tar_file in $(curl -s -H "Authorization: Bearer $HUGGINGFACE_API_KEY" "$API_URL" | grep -o '"path":"videos/[^"]*\.tar\.gz"' | cut -d'"' -f4 | cut -d'/' -f2); do 58 | echo "Found tar file: $tar_file" 59 | echo "downloading and extracting $tar_file..." 60 | curl -L -H "Authorization: Bearer $HUGGINGFACE_API_KEY" "$HF_DOWNLOAD/videos/$tar_file" -o "$OUTDIR/videos/$tar_file" 61 | tar -xzf "$OUTDIR/videos/$tar_file" -C "$OUTDIR" 62 | rm "$OUTDIR/videos/$tar_file" 63 | done 64 | 65 | echo "done."
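
After a pull completes, one quick way to confirm the restored artifacts are readable is to open each store with the same helpers used elsewhere in this repository. A minimal sketch, assuming the `ares` package is installed and MongoDB is running locally:

```python
from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase
from ares.databases.embedding_database import EMBEDDING_DB_PATH, FaissIndex, IndexManager
from ares.databases.structured_database import (
    ROBOT_DB_PATH,
    RolloutSQLModel,
    get_all_rollouts,
    setup_database,
)

# structured SQLite database of rollouts
engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH)
print(f"structured rollouts: {len(get_all_rollouts(engine))}")

# MongoDB-backed annotation database (restored by the mongorestore step above)
ann_db = AnnotationDatabase(connection_string=ANNOTATION_DB_PATH)
print(ann_db.peek_database())

# FAISS embedding indexes
index_manager = IndexManager(EMBEDDING_DB_PATH, FaissIndex)
```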
-------------------------------------------------------------------------------- /scripts/release/push_to_hub.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper script to push relevant data and databases to the HF Hub for distribution. 3 | All folders are uploaded as tar.gz files. 4 | """ 5 | 6 | import os 7 | import tarfile 8 | from huggingface_hub import HfApi, upload_file 9 | from ares.constants import ARES_DATA_DIR 10 | from ares.databases.structured_database import ( 11 | ROBOT_DB_NAME, 12 | ROBOT_DB_PATH, 13 | RolloutSQLModel, 14 | setup_database, 15 | write_db_to_parquet, 16 | ) 17 | from ares.databases.embedding_database import EMBEDDING_DB_NAME 18 | 19 | # insert your repo name here! 20 | HF_REPO = "jacobphillips99/ares-data" 21 | ANNOTATION_DB_BACKUP_NAME = "annotation_mongodump" 22 | SQL_PARQUET_NAME = "robot_data.parquet" 23 | HF_HUB_README_PATH = "hf_hub_readme.md" 24 | 25 | UPLOAD_CONFIGS = [ 26 | {"type": "file", "source": ROBOT_DB_NAME, "dest": ROBOT_DB_NAME}, 27 | { 28 | "type": "folder", 29 | "source": ANNOTATION_DB_BACKUP_NAME, 30 | "dest": ANNOTATION_DB_BACKUP_NAME, 31 | }, 32 | {"type": "folder", "source": EMBEDDING_DB_NAME, "dest": EMBEDDING_DB_NAME}, 33 | {"type": "folder", "source": "videos", "dest": "videos"}, 34 | {"type": "file", "source": SQL_PARQUET_NAME, "dest": SQL_PARQUET_NAME}, 35 | { 36 | "type": "file", 37 | "source": HF_HUB_README_PATH, 38 | "dest": "README.md", 39 | "dir": os.path.dirname(__file__), 40 | }, 41 | ] 42 | 43 | 44 | def backup_mongodb() -> None: 45 | """Create a MongoDB backup in the data directory.""" 46 | backup_path = os.path.join(ARES_DATA_DIR, ANNOTATION_DB_BACKUP_NAME).replace( 47 | "/workspaces/", "" 48 | ) 49 | print( 50 | f"Please run this command in your shell from root directory outside the container (or press c to skip): \n\n" 51 | f"mongodump --uri=mongodb://localhost:27017 --out={backup_path}\n\n" 52 | f"Press c to continue." 
53 | ) 54 | print("MongoDB backup complete") 55 | 56 | 57 | def backup_sqldb_parquet() -> None: 58 | """Create a SQL database parquet file in the data directory.""" 59 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 60 | write_db_to_parquet(engine, os.path.join(ARES_DATA_DIR, SQL_PARQUET_NAME)) 61 | 62 | 63 | def create_tarfile(source_dir: str, output_filename: str) -> str: 64 | """Create a tar archive of the given directory.""" 65 | output_path = os.path.join(ARES_DATA_DIR, output_filename) 66 | with tarfile.open(output_path, "w:gz") as tar: 67 | tar.add( 68 | os.path.join(ARES_DATA_DIR, source_dir), 69 | arcname=os.path.basename(source_dir), 70 | ) 71 | return output_path 72 | 73 | 74 | def upload_to_hf(config: dict[str, str], token: str) -> None: 75 | source_path = os.path.join(config.get("dir", ARES_DATA_DIR), config["source"]) 76 | 77 | if config["type"] == "file": 78 | # Direct file upload 79 | upload_file( 80 | path_or_fileobj=source_path, 81 | path_in_repo=config["dest"], 82 | repo_id=HF_REPO, 83 | repo_type="dataset", 84 | token=token, 85 | ) 86 | else: 87 | # For folders, create tar.gz and upload 88 | if config["source"] == "videos": 89 | # Handle videos directory specially - tar each subdirectory 90 | video_datasets = [ 91 | d 92 | for d in os.listdir(source_path) 93 | if os.path.isdir(os.path.join(source_path, d)) 94 | ] 95 | 96 | print(f"Found {len(video_datasets)} video datasets to upload") 97 | for idx, dataset in enumerate(video_datasets, 1): 98 | print(f"\nProcessing dataset {idx}/{len(video_datasets)}: {dataset}") 99 | dataset_path = os.path.join("videos", dataset) 100 | tar_filename = f"videos_{dataset}.tar.gz" 101 | tar_path = create_tarfile(dataset_path, tar_filename) 102 | 103 | upload_file( 104 | path_or_fileobj=tar_path, 105 | path_in_repo=f"videos/{tar_filename}", 106 | repo_id=HF_REPO, 107 | repo_type="dataset", 108 | token=token, 109 | ) 110 | os.remove(tar_path) 111 | else: 112 | # For other folders, create single tar.gz 113 | tar_filename = f"{config['source']}.tar.gz" 114 | tar_path = create_tarfile(config["source"], tar_filename) 115 | 116 | upload_file( 117 | path_or_fileobj=tar_path, 118 | path_in_repo=tar_filename, 119 | repo_id=HF_REPO, 120 | repo_type="dataset", 121 | token=token, 122 | ) 123 | os.remove(tar_path) 124 | 125 | print(f"Uploaded {config['source']}") 126 | 127 | 128 | if __name__ == "__main__": 129 | token = os.environ.get("HUGGINGFACE_API_KEY") 130 | if not token: 131 | raise ValueError("Please set HUGGINGFACE_API_KEY environment variable") 132 | 133 | api = HfApi(token=token) 134 | 135 | # Check if HF repo exists 136 | matching_datasets = api.list_datasets(search=HF_REPO) 137 | if not any(d.id == HF_REPO for d in matching_datasets): 138 | print(f"Creating repo {HF_REPO}") 139 | api.create_repo(repo_id=HF_REPO, repo_type="dataset") 140 | else: 141 | print(f"Repo {HF_REPO} already exists") 142 | 143 | # Create MongoDB backup 144 | backup_mongodb() 145 | 146 | # Create SQL parquet 147 | backup_sqldb_parquet() 148 | 149 | # Upload each item in the upload config 150 | for item in UPLOAD_CONFIGS: 151 | upload_to_hf(item, token) 152 | 153 | print("Upload complete.") 154 | -------------------------------------------------------------------------------- /scripts/run_trajectory_embedding_ingestion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper script to ingest a rollout into the embedding database. 
We ingest the trajectory matrices (such as states and actions) as well as description and task_language_instruction embeddings. 3 | This enables us to perform efficient nearest neighbor search on the embeddings in order to find similar rollouts in the physical or language space. 4 | """ 5 | 6 | import time 7 | import typing as t 8 | from collections import defaultdict 9 | 10 | import click 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from ares.configs.base import Rollout 15 | from ares.databases.embedding_database import ( 16 | EMBEDDING_DB_PATH, 17 | META_INDEX_NAMES, 18 | STANDARDIZED_TIME_STEPS, 19 | FaissIndex, 20 | IndexManager, 21 | rollout_to_embedding_pack, 22 | ) 23 | from ares.databases.structured_database import ( 24 | RolloutSQLModel, 25 | get_rollouts_by_ids, 26 | setup_database, 27 | setup_rollouts, 28 | ) 29 | from ares.models.base import Embedder 30 | from ares.models.shortcuts import get_nomic_embedder 31 | 32 | 33 | def ingest_trajectory_matrices_from_rollouts_per_dataset( 34 | rollouts: list[Rollout], index_manager: IndexManager 35 | ) -> None: 36 | # collect all embedding packs for states, actions (pre-existing) 37 | embedding_packs = [] 38 | for rollout in rollouts: 39 | embedding_packs.append(rollout_to_embedding_pack(rollout)) 40 | if len(embedding_packs) == 0: 41 | return 42 | 43 | # collect all the embeddings and get normalizing constants 44 | for k in embedding_packs[0].keys(): 45 | try: 46 | packs = [ 47 | pack[k] 48 | for pack in embedding_packs 49 | if not all(x is None for x in pack[k]) 50 | ] 51 | embeddings = np.concatenate(packs) if packs else None 52 | except Exception as e: 53 | raise ValueError(f"Error concatenating embeddings for {k}: {e}") 54 | 55 | # Check if embeddings array contains all None values 56 | if embeddings is None or all(x is None for x in embeddings.flatten()): 57 | print(f"Skipping {k} - embeddings array contains all None values") 58 | continue 59 | print(f"found {embeddings.shape} for {k}; (N,K)") 60 | 61 | # find normalizing constants 62 | try: 63 | means = np.mean(embeddings, axis=0) 64 | stds = np.std(embeddings, axis=0) 65 | except Exception as e: 66 | raise ValueError(f"Error finding normalizing constants for {k}: {e}") 67 | feature_dim = embeddings.shape[1] 68 | print(f"found means {means.shape} and stds {stds.shape}") 69 | # setup index if not already existing 70 | if k not in index_manager.indices.keys(): 71 | index_manager.init_index( 72 | k, 73 | feature_dim, 74 | STANDARDIZED_TIME_STEPS, 75 | norm_means=means, # normalize with dimension-specific means 76 | norm_stds=stds, # normalize with dimension-specific stds 77 | ) 78 | 79 | # add the embeddings to the index!
these will be normalized 80 | for rollout, pack in tqdm( 81 | zip(rollouts, embedding_packs), desc=f"Ingesting {k} embeddings" 82 | ): 83 | # some datasets do not provide this information 84 | if isinstance(pack.get(k), np.ndarray) and not all( 85 | x is None for x in pack[k] 86 | ): 87 | index_manager.add_matrix(k, pack[k], str(rollout.id)) 88 | 89 | 90 | def ingest_language_embeddings_from_rollouts_per_dataset( 91 | rollouts: list[Rollout], index_manager: IndexManager, embedder: Embedder 92 | ) -> None: 93 | feature_dim = embedder.embed("test").shape[0] 94 | for name in META_INDEX_NAMES: 95 | for rollout in tqdm(rollouts, desc=f"Ingesting {name} embeddings"): 96 | if name not in index_manager.indices.keys(): 97 | index_manager.init_index( 98 | name, 99 | feature_dim, 100 | time_steps=1, # lang embeddings don't get a time dimension 101 | norm_means=None, # no need to normalize 102 | norm_stds=None, # no need to normalize 103 | extra_metadata={"model": embedder.name}, 104 | ) 105 | 106 | inp = rollout.get_nested_attr(name) 107 | # some datasets do not provide this information 108 | if inp is None: 109 | continue 110 | embedding = embedder.embed(inp) 111 | index_manager.add_vector(name, embedding, str(rollout.id)) 112 | 113 | 114 | def run_embedding_database_ingestion_per_dataset( 115 | rollouts: list[Rollout], embedder: Embedder, index_path: str 116 | ) -> None: 117 | index_manager = IndexManager(index_path, index_class=FaissIndex) 118 | 119 | tic = time.time() 120 | # add the trajectory matrices to the index and get normalizing constants for this dataset 121 | # (states and actions) 122 | ingest_trajectory_matrices_from_rollouts_per_dataset(rollouts, index_manager) 123 | index_manager.save() 124 | 125 | # add task and description embeddings to the index 126 | # (task, description) 127 | ingest_language_embeddings_from_rollouts_per_dataset( 128 | rollouts, index_manager, embedder 129 | ) 130 | index_manager.save() 131 | 132 | print(f"Embedding database new rollouts: {len(rollouts)}") 133 | total_time = time.time() - tic 134 | print(f"Embedding database time: {total_time}") 135 | print(f"Embedding database mean time: {total_time / len(rollouts)}") 136 | relevant_metadata = { 137 | k: v 138 | for k, v in index_manager.metadata.items() 139 | if rollouts[0].dataset_formalname in k or k in META_INDEX_NAMES 140 | } 141 | print(f"Metadata: {relevant_metadata}") 142 | 143 | 144 | @click.command() 145 | @click.option( 146 | "--engine-url", 147 | type=str, 148 | required=True, 149 | help="SQLAlchemy database URL", 150 | ) 151 | @click.option( 152 | "--dataset-formalname", 153 | type=str, # click cannot convert a typing.Union; optional str with a None default 154 | required=False, 155 | help="Formal name of the dataset to process", 156 | default=None, 157 | ) 158 | @click.option( 159 | "--from-id-file", 160 | type=str, # click cannot convert a typing.Union; optional str with a None default 161 | required=False, 162 | help="File containing rollout ids to ingest", 163 | default=None, 164 | ) 165 | @click.option( 166 | "--index-path", 167 | type=str, 168 | required=False, 169 | help="Path to the index to ingest", 170 | default=EMBEDDING_DB_PATH, 171 | ) 172 | def main( 173 | engine_url: str, 174 | dataset_formalname: t.Union[str, None], 175 | from_id_file: t.Union[str, None], 176 | index_path: str, 177 | ) -> None: 178 | """Run embedding database ingestion for trajectory data.""" 179 | assert ( 180 | dataset_formalname is not None or from_id_file is not None 181 | ), "Either dataset_formalname or from_id_file must be provided" 182 | engine = setup_database(RolloutSQLModel, path=engine_url) 183 | embedder =
get_nomic_embedder() 184 | if from_id_file is not None: 185 | with open(from_id_file, "r") as f: 186 | rollout_ids = [line.strip() for line in f.readlines()] 187 | 188 | rollouts = get_rollouts_by_ids(engine, rollout_ids) if rollout_ids else [] 189 | else: 190 | rollouts = setup_rollouts(engine, dataset_formalname) 191 | 192 | if len(rollouts) == 0: 193 | print( 194 | f"No rollouts found for dataset_formalname: {dataset_formalname}, from_id_file: {from_id_file}" 195 | ) 196 | return 197 | 198 | dataset_to_rollouts = defaultdict(list) 199 | for rollout in rollouts: 200 | dataset_to_rollouts[rollout.dataset_formalname].append(rollout) 201 | for dataset_formalname, dataset_rollouts in dataset_to_rollouts.items(): 202 | run_embedding_database_ingestion_per_dataset( 203 | dataset_rollouts, embedder, index_path 204 | ) 205 | 206 | 207 | if __name__ == "__main__": 208 | main() 209 | -------------------------------------------------------------------------------- /scripts/self_heal.py: -------------------------------------------------------------------------------- 1 | """ 2 | Errors happen! We want a service to self-heal -- that is, ensure that our databases are synced. 3 | This script is a two-step process: 4 | 1. Run `find-heal` to find which rollouts are missing from the embedding database and annotation database. This saves a list of ids to disk to be used in the next step. 5 | 2. Run `exec-heal` to ingest the missing rollouts into the embedding database and update the annotation database. 6 | This ensures our databases are in sync. 7 | """ 8 | 9 | import os 10 | from datetime import datetime 11 | from pathlib import Path 12 | 13 | import click 14 | import pandas as pd 15 | 16 | from ares.annotating.orchestration import orchestrate_annotating 17 | from ares.constants import ARES_DATA_DIR 18 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 19 | from ares.databases.embedding_database import ( 20 | EMBEDDING_DB_PATH, 21 | META_INDEX_NAMES, 22 | TRAJECTORY_INDEX_NAMES, 23 | FaissIndex, 24 | IndexManager, 25 | rollout_to_index_name, 26 | ) 27 | from ares.databases.structured_database import ( 28 | ROBOT_DB_PATH, 29 | RolloutSQLModel, 30 | get_partial_df, 31 | get_rollout_by_name, 32 | setup_database, 33 | ) 34 | 35 | from .annotating.run_grounding import GroundingModalAnnotatingFn 36 | from .run_trajectory_embedding_ingestion import ( 37 | main as run_trajectory_embedding_ingestion, 38 | ) 39 | 40 | HEALING_EXCEPTIONS = { 41 | "utokyo_saytap_converted_externally_to_rlds": ["grounding"], 42 | "CMU Franka Exploration": ["CMU Franka Exploration-Franka-states"], 43 | "USC Jaco Play": ["USC Jaco Play-Jaco 2-states"], 44 | } 45 | HEAL_INFO_DIR = os.path.join(ARES_DATA_DIR, "heal_info") 46 | 47 | 48 | @click.command("find-heal") 49 | @click.option("--heal-info-dir", type=str, default=HEAL_INFO_DIR) 50 | def find_heal_opportunities(heal_info_dir: str) -> None: 51 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 52 | ann_db = AnnotationDatabase(connection_string=ANNOTATION_DB_PATH) 53 | embedding_db = IndexManager(EMBEDDING_DB_PATH, FaissIndex) 54 | time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 55 | heal_dir = os.path.join(heal_info_dir, time_str) 56 | os.makedirs(heal_dir, exist_ok=True) 57 | 58 | # collect all rollout IDs from structured database (engine) 59 | id_cols = ["id", "dataset_filename", "dataset_formalname", "filename"] 60 | rollout_df = get_partial_df(engine, id_cols) 61 | dataset_formalname_to_df = { 62 | k: v for k, v in 
rollout_df.groupby("dataset_formalname") 63 | } 64 | dataset_filename_to_df = {k: v for k, v in rollout_df.groupby("dataset_filename")} 65 | 66 | # check embedding database 67 | to_update_embedding_index_ids = [] 68 | for dataset_formalname, id_df in dataset_formalname_to_df.items(): 69 | if "embedding" in HEALING_EXCEPTIONS.get(dataset_formalname, []): 70 | continue 71 | example_rollout = get_rollout_by_name( 72 | engine, dataset_formalname, id_df["filename"].iloc[0] 73 | ) 74 | potential_index_names = [ 75 | rollout_to_index_name(example_rollout, suffix) 76 | for suffix in TRAJECTORY_INDEX_NAMES 77 | ] + META_INDEX_NAMES # description, task 78 | for index_name in potential_index_names: 79 | if index_name in HEALING_EXCEPTIONS.get(dataset_formalname, []): 80 | continue 81 | if index_name not in embedding_db.indices: 82 | missing_ids = id_df["id"].tolist() 83 | existing_index_ids = [] 84 | else: 85 | existing_index = embedding_db.indices[index_name] 86 | existing_index_ids = existing_index.get_all_ids() 87 | # add any missing ids to update list 88 | missing_ids = set(id_df["id"].astype(str).tolist()) - set( 89 | existing_index_ids.tolist() 90 | ) 91 | if len(missing_ids) > 0: 92 | n_existing = len(existing_index_ids) 93 | pct_missing = ( 94 | 100 95 | * len(missing_ids) 96 | / (n_existing if n_existing > 0 else len(missing_ids)) 97 | ) 98 | print( 99 | f"Found {len(missing_ids)} missing ids for index {index_name} out of {n_existing} existing ids; {pct_missing:.2f}% missing from dataset {dataset_formalname}" 100 | ) 101 | to_update_embedding_index_ids.extend(missing_ids) 102 | 103 | update_embedding_ids_path = os.path.join(heal_dir, "update_embedding_ids.txt") 104 | with open(update_embedding_ids_path, "w") as f: 105 | for id in to_update_embedding_index_ids: 106 | f.write(f"{id}\n") 107 | print( 108 | f"Found {len(to_update_embedding_index_ids)} ids to update in embedding database; saving to disk at {update_embedding_ids_path}" 109 | ) 110 | 111 | print("\n\n" + "=" * 100 + "\n\n") 112 | # to update grounding 113 | to_update_grounding_ids = [] 114 | existing_video_ids = pd.Series(ann_db.get_video_ids()) 115 | for dataset_filename, id_df in dataset_filename_to_df.items(): 116 | if "grounding" in HEALING_EXCEPTIONS.get(dataset_filename, []): 117 | continue # skip datasets excepted from grounding healing 118 | # check if videos exist -- if not, add to list (will add video and grounding) 119 | found_video_ids = (id_df["dataset_filename"] + "/" + id_df["filename"]).apply( 120 | lambda x: str(Path(x).with_suffix(".mp4")) 121 | ) 122 | mask = ~found_video_ids.isin(existing_video_ids) 123 | if mask.any(): 124 | print(f"Found {mask.sum()} missing videos for dataset {dataset_filename}") 125 | to_update_grounding_ids.extend(id_df[mask]["id"].astype(str).tolist()) 126 | 127 | # Handle videos that exist but are missing annotations 128 | has_video_mask = found_video_ids.isin(existing_video_ids) 129 | videos_with_annotations = pd.Series(ann_db.get_annotation_ids()) 130 | missing_annotations_mask = ~found_video_ids[has_video_mask].isin( 131 | videos_with_annotations 132 | ) 133 | if missing_annotations_mask.any(): 134 | print( 135 | f"Found {missing_annotations_mask.sum()} videos missing annotations for dataset {dataset_filename}" 136 | ) 137 | to_update_grounding_ids.extend( 138 | id_df[has_video_mask][missing_annotations_mask]["id"] 139 | .astype(str) 140 | .tolist() 141 | ) 142 | 143 | update_grounding_ids_path = os.path.join(heal_dir, "update_grounding_ids.txt") 144 | to_update_grounding_ids =
list(set(to_update_grounding_ids)) # remove duplicates 145 | with open(update_grounding_ids_path, "w") as f: 146 | for id in to_update_grounding_ids: 147 | f.write(f"{id}\n") 148 | print( 149 | f"Found {len(to_update_grounding_ids)} ids to update in grounding database; saving to disk at {update_grounding_ids_path}" 150 | ) 151 | print(f"TIME DIR: {time_str}") 152 | 153 | 154 | @click.command("exec-heal") 155 | @click.option("--time-dir", type=str, required=True) 156 | def execute_heal(time_dir: str) -> None: 157 | heal_dir = os.path.join(HEAL_INFO_DIR, time_dir) 158 | 159 | # run embedding ingestion via click's command from our embedding ingestion script 160 | update_embedding_ids_path = os.path.join(heal_dir, "update_embedding_ids.txt") 161 | run_trajectory_embedding_ingestion.callback( 162 | engine_url=ROBOT_DB_PATH, 163 | dataset_formalname=None, 164 | from_id_file=update_embedding_ids_path, 165 | index_path=EMBEDDING_DB_PATH, 166 | ) 167 | 168 | # update grounding database 169 | update_grounding_ids_path = os.path.join(heal_dir, "update_grounding_ids.txt") 170 | orchestrate_annotating( 171 | engine_path=ROBOT_DB_PATH, 172 | ann_db_path=ANNOTATION_DB_PATH, 173 | annotating_fn=GroundingModalAnnotatingFn(), 174 | ids_path=update_grounding_ids_path, 175 | failures_path=os.path.join( 176 | ARES_DATA_DIR, "annotating_failures", f"heal_failures_{time_dir}.pkl" 177 | ), 178 | ) 179 | 180 | print(f"Finished healing") 181 | 182 | 183 | @click.group() 184 | def cli(): 185 | """Self-healing utilities for database synchronization""" 186 | pass 187 | 188 | 189 | cli.add_command(find_heal_opportunities) 190 | cli.add_command(execute_heal) 191 | 192 | if __name__ == "__main__": 193 | cli() 194 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="ares", 5 | version="0.1.0", 6 | description="A system for automatically evaluating robot data", 7 | author="Jacob Phillips", 8 | author_email="jacob.phillips8905@gmail.com", 9 | url="https://github.com/jacobphillips99/ares", 10 | package_dir={"": "src"}, 11 | packages=find_packages(where="src"), 12 | ) 13 | -------------------------------------------------------------------------------- /src/ares/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/__init__.py -------------------------------------------------------------------------------- /src/ares/annotating/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/annotating/__init__.py -------------------------------------------------------------------------------- /src/ares/annotating/annotating_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base classes and functions for annotating rollouts, either by API, Modal, or local. 
3 | """ 4 | 5 | import pickle 6 | from dataclasses import dataclass, field 7 | 8 | from sqlalchemy import Engine 9 | 10 | from ares.configs.base import Rollout 11 | from ares.constants import get_dataset_info_by_key 12 | from ares.databases.structured_database import get_rollouts_by_ids, setup_rollouts 13 | 14 | 15 | @dataclass 16 | class ErrorResult: 17 | rollout_id: str 18 | error_pattern: str 19 | error: str 20 | exception: str | None = None 21 | 22 | 23 | @dataclass 24 | class ResultTracker: 25 | videos: int = 0 26 | frames: int = 0 27 | annotations: int = 0 28 | video_ids: list[str] = field(default_factory=list) 29 | 30 | def update_via_batch( 31 | self, n_videos: int, n_frames: int, n_annotations: int, video_ids: list[str] 32 | ) -> None: 33 | self.videos += n_videos 34 | self.frames += n_frames 35 | self.annotations += n_annotations 36 | self.video_ids.extend(video_ids) 37 | 38 | def update_tracker(self, tracker: "ResultTracker") -> None: 39 | self.videos += tracker.videos 40 | self.frames += tracker.frames 41 | self.annotations += tracker.annotations 42 | self.video_ids.extend(tracker.video_ids) 43 | 44 | def print_stats(self) -> None: 45 | print( 46 | f"Processed {self.videos} videos, {self.frames} frames, {len(self.video_ids)} videos with annotations" 47 | ) 48 | 49 | 50 | def setup_rollouts_from_sources( 51 | engine: Engine, 52 | rollout_ids: list[str] | None = None, 53 | ids_path: str | None = None, 54 | dataset_filename: str | None = None, 55 | split: str | None = None, 56 | ) -> list[Rollout]: 57 | """ 58 | Helper function to setup rollouts from a variety of sources, e.g. a list of rollout IDs, a dataset filename, or a file path to a list of failed rollout IDs. 59 | The file path can be a pickle or a txt file; the pickle file should contain a list of dictionaries with a `rollout_id` key whereas the txt file should contain a list of rollout IDs. 60 | """ 61 | assert ( 62 | ids_path or rollout_ids or dataset_filename 63 | ), f"Must provide either ids_path, rollout_ids, or dataset_filename. 
Received: ids_path={ids_path}, rollout_ids={rollout_ids}, dataset_filename={dataset_filename}" 64 | 65 | if ids_path: 66 | # load rollouts from a file path 67 | # - a pickle implies a list of failed rollout dictionaries 68 | # - a txt implies a list of rollout IDs directly 69 | if ids_path.endswith(".pkl"): 70 | with open(ids_path, "rb") as f: 71 | failures = pickle.load(f) 72 | failed_ids = [str(f["rollout_id"]) for f in failures] 73 | return get_rollouts_by_ids(engine, failed_ids) 74 | elif ids_path.endswith(".txt"): 75 | with open(ids_path, "r") as f: 76 | failed_ids = [line.strip() for line in f.readlines()] 77 | return get_rollouts_by_ids(engine, failed_ids) 78 | else: 79 | raise ValueError(f"Unknown file type: {ids_path}") 80 | elif rollout_ids: 81 | # load rollouts from a list of IDs 82 | return get_rollouts_by_ids(engine, rollout_ids) 83 | else: 84 | # load rollouts from a dataset filename 85 | dataset_info = get_dataset_info_by_key("dataset_filename", dataset_filename) 86 | if not dataset_info: 87 | raise ValueError(f"Dataset filename {dataset_filename} not found.") 88 | dataset_formalname = dataset_info["dataset_formalname"] 89 | rollouts = setup_rollouts(engine, dataset_formalname=dataset_formalname) 90 | if split: 91 | rollouts = [r for r in rollouts if r.split == split] 92 | return rollouts 93 | -------------------------------------------------------------------------------- /src/ares/annotating/annotating_fn.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import traceback 3 | import typing as t 4 | 5 | from tqdm import tqdm 6 | 7 | from ares.annotating.annotating_base import ErrorResult, ResultTracker 8 | from ares.configs.annotations import Annotation 9 | from ares.configs.base import Rollout 10 | from ares.databases.annotation_database import AnnotationDatabase, get_video_id 11 | from ares.models.base import VLM 12 | from ares.models.shortcuts import get_vlm 13 | 14 | 15 | class AnnotatingFn: 16 | """ 17 | Compute primitive to annotate rollouts. AnnotatingFns should conduct their annotation in batches, 18 | whether via API, Modal, local, etc. 19 | """ 20 | 21 | def __call__( 22 | self, 23 | *args: t.Any, 24 | **kwargs: t.Any, 25 | ) -> tuple[ResultTracker, list[ErrorResult]]: 26 | raise NotImplementedError 27 | 28 | 29 | class APIAnnotatingFn(AnnotatingFn): 30 | """ 31 | Base class to create annotating functions that use an API. E.g success criteria, grounding phrases, etc. 32 | """ 33 | 34 | def __init__(self, annotation_key: str, annotation_type: str): 35 | self.annotation_key = annotation_key 36 | self.annotation_type = annotation_type 37 | 38 | async def run_query( 39 | self, 40 | vlm: VLM, 41 | rollout: Rollout, 42 | ann_db: AnnotationDatabase, 43 | ) -> t.Any: 44 | raise NotImplementedError 45 | 46 | async def run_batch( 47 | self, 48 | vlm: VLM, 49 | rollouts_batch: list[Rollout], 50 | ann_db: AnnotationDatabase, 51 | ) -> tuple[ResultTracker, list[ErrorResult]]: 52 | """ 53 | Default function to annotate a batch of rollouts and store annotations in the database with an API-driven annotating function. 
54 | """ 55 | 56 | # Create futures with their corresponding rollouts 57 | futures = [] 58 | for rollout in rollouts_batch: 59 | future = asyncio.create_task(self.run_query(vlm, rollout, ann_db)) 60 | futures.append((future, rollout)) 61 | 62 | tracker = ResultTracker() 63 | failures = [] 64 | 65 | for future, rollout in futures: 66 | try: 67 | result = await future 68 | if isinstance(result, ErrorResult): 69 | # check for error result and record if so 70 | failures.append(result) 71 | else: 72 | # otherwise, get the video id and add the annotaiton to the database 73 | video_id = get_video_id(rollout.dataset_filename, rollout.filename) 74 | ann_db.add_annotation( 75 | video_id=video_id, 76 | key=self.annotation_key, 77 | value=Annotation( 78 | description=result, annotation_type=self.annotation_type 79 | ), 80 | annotation_type=self.annotation_type, 81 | frame=None, 82 | ) 83 | tracker.update_via_batch( 84 | n_videos=1, n_frames=1, n_annotations=1, video_ids=[video_id] 85 | ) 86 | except Exception as e: 87 | failures.append( 88 | ErrorResult( 89 | rollout_id=str(rollout.id), 90 | error_pattern="batch_processing_failure", 91 | error=traceback.format_exc(), 92 | exception=str(e), 93 | ) 94 | ) 95 | 96 | return tracker, failures 97 | 98 | def __call__( 99 | self, 100 | rollouts: list[Rollout], 101 | ann_db: AnnotationDatabase, 102 | outer_batch_size: int, 103 | vlm_name: str = "gpt-4o", 104 | ) -> tuple[ResultTracker, list[ErrorResult]]: 105 | """ 106 | Orchestrating function for this annotating function. The __call__ function instantiates the objects and 107 | create the "outer loop" for annotating batches of rollouts. 108 | """ 109 | overall_tracker = ResultTracker() 110 | overall_failures = [] 111 | 112 | for i in tqdm( 113 | range(0, len(rollouts), outer_batch_size), 114 | desc="Processing outer batches", 115 | ): 116 | print( 117 | f"Processing batch {i // outer_batch_size + 1} of {max(1, len(rollouts) // outer_batch_size)}" 118 | ) 119 | # create VLM outside async as the semaphore gets "bound" to async context 120 | vlm = get_vlm(vlm_name) 121 | 122 | # get batch results 123 | rollouts_batch = rollouts[i : i + outer_batch_size] 124 | tracker, failures = asyncio.run( 125 | self.run_batch( 126 | vlm=vlm, 127 | rollouts_batch=rollouts_batch, 128 | ann_db=ann_db, 129 | ) 130 | ) 131 | 132 | print( 133 | f"Completed batch {i // outer_batch_size + 1} of {max(1, len(rollouts) // outer_batch_size)}" 134 | ) 135 | overall_tracker.update_tracker(tracker) 136 | overall_failures.extend(failures) 137 | return overall_tracker, overall_failures 138 | -------------------------------------------------------------------------------- /src/ares/annotating/modal_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base Modal infrastructure for running serverless compute tasks. In order to prevent copying over the ARES repository and dependencies, 3 | avoid importing any ARES modules in this file; instead, import the necessary modules in the specific Modal app classes. 4 | """ 5 | 6 | import asyncio 7 | import typing as t 8 | 9 | from modal import App, Image, enter, method 10 | 11 | # Base Modal image with common dependencies 12 | base_image = ( 13 | Image.debian_slim() 14 | .apt_install("python3-opencv") 15 | .pip_install("torch", "transformers", "numpy", "opencv-python", "tqdm", "pillow") 16 | ) 17 | 18 | 19 | class BaseWorker: 20 | """ 21 | Base worker class to be decorated by specific Modal apps. 
22 | """ 23 | 24 | @enter() 25 | def setup(self) -> None: 26 | """Override in subclass to initialize resources.""" 27 | pass 28 | 29 | @method() 30 | async def process(self, *args, **kwargs): 31 | """Override in subclass with task-specific logic.""" 32 | pass 33 | 34 | 35 | class BaseModalWrapper: 36 | """ 37 | Base class for Modal task wrappers. 38 | """ 39 | 40 | def __init__( 41 | self, 42 | app_name: str, 43 | worker_cls: t.Type[BaseWorker] = BaseWorker, 44 | image: Image = base_image, 45 | ) -> None: 46 | self.app_name = app_name 47 | self.app = App(app_name) 48 | self.WorkerCls = self.app.cls( 49 | image=image, 50 | gpu="t4", 51 | concurrency_limit=10, 52 | timeout=600, 53 | )(worker_cls) 54 | print(f"Modal app {self.app_name} initialized") 55 | 56 | async def run_batch( 57 | self, items: list[t.Any], batch_size: int = 8 58 | ) -> list[tuple[str, list[list["Annotation"]]]]: 59 | """Run batch processing using Modal.""" 60 | tasks = [] 61 | for i in range(0, len(items), batch_size): 62 | batch = items[i : i + batch_size] 63 | tasks.append(self.WorkerCls().process.remote.aio(batch)) 64 | results = await asyncio.gather(*tasks) 65 | results = [item for batch in results for item in batch] 66 | return results 67 | -------------------------------------------------------------------------------- /src/ares/annotating/modal_grounding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modal wrapper for interacting with GroundingWorker. 3 | """ 4 | 5 | from modal import enter, method 6 | 7 | from ares.annotating.modal_base import BaseModalWrapper, BaseWorker 8 | from ares.models.grounding import GroundingAnnotator 9 | 10 | 11 | class GroundingWorker(BaseWorker): 12 | """Worker class for grounding annotations.""" 13 | 14 | @enter() 15 | def setup(self) -> None: 16 | """Initialize resources needed for grounding.""" 17 | self.worker = GroundingAnnotator(segmenter_id=None) 18 | 19 | @method() 20 | async def process( 21 | self, batch: list[tuple[str, list, str]] 22 | ) -> list[tuple[str, list]]: 23 | """ 24 | Process method to annotate multiple videos in a batch. 25 | 26 | Args: 27 | batch: List of tuples containing (rollout_id, frames, label_str) 28 | 29 | Returns: 30 | list[tuple[str, list]]: List of annotation results 31 | """ 32 | results = [] 33 | for rollout_id, frames, label_str in batch: 34 | try: 35 | result = self.worker.annotate_video(rollout_id, frames, label_str) 36 | results.append(result) 37 | except Exception as e: 38 | print(f"Error processing {rollout_id}: {e}") 39 | # Return empty annotations for failed items 40 | results.append((rollout_id, [])) 41 | return results 42 | 43 | 44 | class GroundingModalWrapper(BaseModalWrapper): 45 | """ 46 | Wrapper class to interact with GroundingWorker via Modal. 47 | """ 48 | 49 | def __init__(self, app_name: str = "grounding_app"): 50 | super().__init__(app_name, worker_cls=GroundingWorker) 51 | 52 | async def annotate_videos( 53 | self, tasks: list[tuple[str, list, str]] 54 | ) -> list[tuple[str, list[list["Annotation"]]]]: 55 | """ 56 | Submit a batch of annotation tasks to the GroundingWorker. 57 | 58 | Args: 59 | tasks (list[tuple[str, list, str]]): List of tuples containing rollout_id, frames, and label_str. 
60 | 61 | Returns: 62 | list[tuple[str, list[list[Annotation]]]]: 63 | """ 64 | return await self.run_batch(tasks) 65 | -------------------------------------------------------------------------------- /src/ares/annotating/orchestration.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper function to orchestrate annotation over datasets in the structured database and send annotations into the annotation database. 3 | See `ares.annotating.annotating_fn.py` for the AnnotatingFn object that gets fulfilled for different annotation methods. 4 | """ 5 | 6 | import os 7 | import pickle 8 | import time 9 | 10 | from ares.annotating.annotating_base import ( 11 | ErrorResult, 12 | ResultTracker, 13 | setup_rollouts_from_sources, 14 | ) 15 | from ares.annotating.annotating_fn import AnnotatingFn 16 | from ares.constants import ANNOTATION_OUTER_BATCH_SIZE 17 | from ares.databases.annotation_database import AnnotationDatabase 18 | from ares.databases.structured_database import RolloutSQLModel, setup_database 19 | 20 | 21 | def orchestrate_annotating( 22 | engine_path: str, 23 | ann_db_path: str, 24 | annotating_fn: AnnotatingFn, 25 | dataset_filename: str | None = None, 26 | split: str | None = None, 27 | rollout_ids: list[str] | None = None, 28 | outer_batch_size: int = ANNOTATION_OUTER_BATCH_SIZE, # RAM limits number of concurrent rollouts formatted into requests 29 | ids_path: ( 30 | str | None 31 | ) = None, # Path to ids to load; may be failed IDs from previous run 32 | annotating_kwargs: dict | None = None, 33 | failures_path: str | None = None, 34 | ) -> tuple[ResultTracker, list[ErrorResult]]: 35 | """ 36 | Main function to run annotations, whether local, API-driven, or on Modal. 37 | 38 | Args: 39 | engine_path (str): Path to the database engine. 40 | ann_db_path (str): Path to the annotation database. 41 | annotating_fn (Callable): Function to run annotation. 42 | dataset_filename (str, optional): Dataset filename. 43 | split (str, optional): Data split. 44 | rollout_ids (list[str], optional): Specific rollout IDs to process. 45 | outer_batch_size (int, optional): Batch size for processing rollouts. 46 | ids_path (str, optional): Path to ids to load. 47 | annotating_kwargs (dict, optional): Additional keyword arguments for annotating_fn. 48 | """ 49 | assert ( 50 | dataset_filename is not None or ids_path is not None or rollout_ids is not None 51 | ), f"Must provide either dataset_filename, ids_path, or rollout_ids. 
Received: dataset_filename={dataset_filename}, ids_path={ids_path}, rollout_ids={rollout_ids}" 52 | 53 | annotating_kwargs = annotating_kwargs or {} 54 | # Initialize databases 55 | ann_db = AnnotationDatabase(connection_string=ann_db_path) 56 | engine = setup_database(RolloutSQLModel, path=engine_path) 57 | rollouts = setup_rollouts_from_sources( 58 | engine, rollout_ids, ids_path, dataset_filename, split 59 | ) 60 | print(f"\n\nFound {len(rollouts)} total rollouts\n\n") 61 | if not rollouts: 62 | print( 63 | f"No rollouts found for dataset filename {dataset_filename}, retry failed path {ids_path}" 64 | ) 65 | return 66 | 67 | tic = time.time() 68 | overall_tracker, overall_failures = annotating_fn( 69 | rollouts, ann_db, outer_batch_size, **annotating_kwargs 70 | ) 71 | print(f"\n\nFailures: {overall_failures}\n\n") 72 | 73 | # Write failures to file to retry 74 | if failures_path and len(overall_failures) > 0: 75 | os.makedirs(os.path.dirname(failures_path), exist_ok=True) 76 | print(f"Writing failures to {failures_path}") 77 | with open(failures_path, "wb") as f: 78 | pickle.dump(overall_failures, f) 79 | 80 | print("Time taken:", time.time() - tic) 81 | print(f"\n\n") 82 | overall_tracker.print_stats() 83 | print(f"\nNumber of failures: {len(overall_failures)}") 84 | return overall_tracker, overall_failures 85 | -------------------------------------------------------------------------------- /src/ares/app/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/app/__init__.py -------------------------------------------------------------------------------- /src/ares/app/annotation_viz_helpers.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | 3 | import cv2 4 | import numpy as np 5 | import streamlit as st 6 | 7 | from ares.configs.annotations import Annotation 8 | from ares.utils.image_utils import choose_and_preprocess_frames, get_video_frames 9 | 10 | 11 | def get_color_mapping(category_str: str) -> tuple[int, int, int]: 12 | """ 13 | Create a consistent color mapping based on hash of a string. 14 | This way, the same strings are mapped to the same colors. 15 | """ 16 | hash_str = hashlib.sha256(category_str.encode()).hexdigest()[:6] 17 | # Convert pairs of hex digits to RGB values (0-255) 18 | r = int(hash_str[0:2], 16) 19 | g = int(hash_str[2:4], 16) 20 | b = int(hash_str[4:6], 16) 21 | return (r, g, b) 22 | 23 | 24 | def draw_legend( 25 | canvas: np.ndarray, 26 | unique_categories: dict, 27 | start_y: int, 28 | legend_spacing: int = 25, 29 | legend_x: int = 10, 30 | ) -> None: 31 | """Draw category legend on the canvas. 
32 | 33 | Args: 34 | canvas: Image to draw legend on 35 | unique_categories: Dictionary mapping category names to colors 36 | start_y: Y coordinate to start drawing legend 37 | legend_spacing: Vertical spacing between legend items 38 | legend_x: X coordinate to start drawing legend 39 | """ 40 | for idx, (category, color) in enumerate(unique_categories.items()): 41 | # Draw color box 42 | cv2.rectangle( 43 | canvas, 44 | (legend_x, start_y + idx * legend_spacing - 15), 45 | (legend_x + 20, start_y + idx * legend_spacing), 46 | color, 47 | -1, 48 | ) 49 | # Draw category text 50 | cv2.putText( 51 | canvas, 52 | category, 53 | (legend_x + 30, start_y + idx * legend_spacing - 3), 54 | cv2.FONT_HERSHEY_SIMPLEX, 55 | 0.5, 56 | (0, 0, 0), 57 | 1, 58 | cv2.LINE_AA, 59 | ) 60 | 61 | 62 | def draw_box( 63 | annotation: Annotation, 64 | annotated_image: np.ndarray, 65 | image: np.ndarray, 66 | overlay: np.ndarray, 67 | color: tuple[int, int, int], 68 | label: str, 69 | show_scores: bool, 70 | ) -> np.ndarray: 71 | if not annotation.bbox: 72 | return annotated_image 73 | x1, y1, x2, y2 = map(int, annotation.bbox) 74 | 75 | cv2.rectangle(annotated_image, (x1, y1), (x2, y2), color, 2) # Border 76 | # Draw rectangle with transparency 77 | if not annotation.segmentation: 78 | cv2.rectangle(overlay, (x1, y1), (x2, y2), color, -1) # Filled box for overlay 79 | 80 | # Prepare label text 81 | if show_scores and annotation.score is not None: 82 | label_text = f"{label} {annotation.score:.2f}" 83 | else: 84 | label_text = label 85 | 86 | # Get text size 87 | (text_w, text_h), _ = cv2.getTextSize(label_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1) 88 | 89 | # Adjust label position to ensure it's within image bounds 90 | text_x = min(max(x1, 0), image.shape[1] - text_w) 91 | text_y = max(y1 - 2, text_h + 4) # Ensure there's room for text 92 | 93 | # Draw label background and text 94 | cv2.rectangle( 95 | annotated_image, 96 | (text_x, text_y - text_h - 4), 97 | (text_x + text_w, text_y), 98 | color, 99 | -1, 100 | ) 101 | cv2.putText( 102 | annotated_image, 103 | label_text, 104 | (text_x, text_y - 2), 105 | cv2.FONT_HERSHEY_SIMPLEX, 106 | 0.5, 107 | (0, 0, 0), 108 | 1, 109 | cv2.LINE_AA, 110 | ) 111 | return annotated_image 112 | 113 | 114 | # Draw annotations 115 | def draw_annotations( 116 | image: np.ndarray, 117 | annotations: list[Annotation], 118 | show_scores: bool = True, 119 | alpha: float = 0.25, 120 | ) -> np.ndarray: 121 | """ 122 | Create quick-and-easy visualizations of annotations. 
123 | """ 124 | annotated_image = image.copy() 125 | overlay = image.copy() 126 | 127 | # Track unique categories for legend 128 | unique_categories = {} 129 | 130 | for annotation in annotations: 131 | label = annotation.category_name or "unknown" 132 | color = get_color_mapping(label) 133 | unique_categories[label] = color 134 | 135 | # Draw bounding box 136 | if annotation.bbox: 137 | annotated_image = draw_box( 138 | annotation, annotated_image, image, overlay, color, label, show_scores 139 | ) 140 | 141 | # Draw segmentation mask 142 | if annotation.segmentation: 143 | colored_mask = np.zeros_like(image, dtype=np.uint8) 144 | colored_mask[annotation.mask == 1] = color 145 | overlay = cv2.addWeighted(overlay, 1, colored_mask, alpha, 0) 146 | 147 | # Calculate legend dimensions 148 | legend_spacing = 25 # Vertical spacing between legend items 149 | legend_height = len(unique_categories) * legend_spacing + 10 150 | legend_padding = 20 # Padding around legend 151 | 152 | # Create extended canvas for image + legend 153 | canvas = np.full( 154 | (image.shape[0] + legend_height + legend_padding, image.shape[1], 3), 155 | 255, # White background 156 | dtype=np.uint8, 157 | ) 158 | 159 | # Place the annotated image at the top 160 | canvas[: image.shape[0]] = cv2.addWeighted( 161 | overlay, alpha, annotated_image, 1 - alpha, 0 162 | ) 163 | 164 | # Draw legend using helper function 165 | legend_y = image.shape[0] + legend_padding 166 | draw_legend(canvas, unique_categories, legend_y) 167 | return canvas 168 | 169 | 170 | def draw_detection_data(detection_data: dict, dataset: str, fname: str) -> None: 171 | # given detection data, lets display the frames and annotations 172 | frame_inds = list(detection_data.keys()) 173 | all_frame_paths = get_video_frames(dataset, fname, n_frames=None, just_path=True) 174 | selected_frames = choose_and_preprocess_frames( 175 | all_frame_paths, 176 | specified_frames=frame_inds, 177 | ) 178 | annotated_frames = [ 179 | draw_annotations(frame, anns) 180 | for frame, anns in zip(selected_frames, detection_data.values()) 181 | ] 182 | # use an expander for visual clarity 183 | with st.expander("Annotated Frames", expanded=False): 184 | max_cols = 3 185 | cols = st.columns(max_cols) 186 | for i, (frame_ind, frame) in enumerate(zip(frame_inds, annotated_frames)): 187 | with cols[i % max_cols]: 188 | st.write(f"Frame {frame_ind}") 189 | st.image(frame) 190 | -------------------------------------------------------------------------------- /src/ares/app/data_analysis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple helpers for inferring data types and generating visualizations. For example, a column of data containing a small number of unique strings can 3 | be inferred to be a bar chart, while a column of data containing a large number of unique floats can be inferred to be a histogram. We also choose 4 | to *not* display certain columns that have poor visualizations, such as a column of data containing a large number of unique strings such as IDs or file paths. 
5 | """ 6 | 7 | import typing as t 8 | 9 | import pandas as pd 10 | 11 | from ares.app.plot_primitives import create_bar_plot, create_histogram 12 | from ares.constants import IGNORE_COLS 13 | 14 | 15 | def infer_visualization_type( 16 | column_name: str, 17 | data: pd.DataFrame, 18 | ignore_cols: list | None = None, 19 | max_str_length: int = 500, 20 | ) -> dict[str, t.Any]: 21 | """ 22 | Heuristic solution for transforming a column of data into a visualization type, 23 | focusing on numeric ranges or category counts. 24 | """ 25 | ignore_cols = ignore_cols or IGNORE_COLS 26 | 27 | dtype = str(data[column_name].dtype) 28 | nunique = data[column_name].nunique() 29 | 30 | result = {"viz_type": None, "dtype": dtype, "nunique": nunique} 31 | 32 | if column_name.lower() in ignore_cols: 33 | return result 34 | 35 | # Add special handling for boolean columns 36 | if pd.api.types.is_bool_dtype(data[column_name]): 37 | result["viz_type"] = "bar" 38 | return result 39 | 40 | if pd.api.types.is_string_dtype(data[column_name]): 41 | if data[column_name].str.len().max() > max_str_length: 42 | return result 43 | 44 | if pd.api.types.is_datetime64_any_dtype(data[column_name]): 45 | return result 46 | 47 | if pd.api.types.is_numeric_dtype(data[column_name]) or ( 48 | dtype == "object" 49 | and len(data[column_name].dropna()) > 0 50 | and pd.to_numeric(data[column_name].dropna(), errors="coerce").notna().all() 51 | ): 52 | # check if lots of unique values or if it's a float between 0 and 1 53 | if nunique > 20 or ( 54 | pd.api.types.is_float_dtype(data[column_name]) 55 | and data[column_name].min() >= 0 56 | and data[column_name].max() <= 1 57 | ): 58 | result["viz_type"] = "histogram" 59 | else: 60 | result["viz_type"] = "bar" 61 | return result 62 | 63 | if pd.api.types.is_string_dtype(data[column_name]) or nunique < 20: 64 | result["viz_type"] = "bar" 65 | return result 66 | 67 | return result 68 | 69 | 70 | def generate_automatic_visualizations( 71 | df: pd.DataFrame, 72 | time_column: str = "creation_time", 73 | ignore_cols: list[str] | None = None, 74 | max_x_bar_options: int = 100, 75 | ) -> list[dict]: 76 | """ 77 | After inferring the 'type' of a column, we can create automatic visualizations. 
78 | """ 79 | ignore_cols = ignore_cols or IGNORE_COLS 80 | visualizations = [] 81 | 82 | # Pre-calculate visualization types for all columns at once 83 | viz_infos = { 84 | col: infer_visualization_type(col, df) 85 | for col in sorted(df.columns) 86 | if col != time_column and col.lower() not in ignore_cols 87 | } 88 | 89 | # Group columns by visualization type 90 | histogram_cols = [] 91 | bar_cols = [] 92 | for col, info in viz_infos.items(): 93 | if not info["nunique"] or ( 94 | info["viz_type"] == "bar" and info["nunique"] > max_x_bar_options 95 | ): 96 | continue 97 | if info["viz_type"] == "histogram": 98 | histogram_cols.append(col) 99 | elif info["viz_type"] == "bar": 100 | bar_cols.append(col) 101 | 102 | # Create histogram visualizations 103 | for col in histogram_cols: 104 | col_title = col.replace("_", " ").replace("-", " ").title() 105 | visualizations.append( 106 | { 107 | "figure": create_histogram( 108 | df, 109 | x=col, 110 | color="#1f77b4", 111 | title=f"Distribution of {col_title}", 112 | labels={col: col_title, "count": "Count"}, 113 | ), 114 | "title": f"{col_title} Distribution", 115 | } 116 | ) 117 | 118 | # Create bar visualizations - handle each column separately 119 | for col in bar_cols: 120 | col_title = col.replace("_", " ").replace("-", " ").title() 121 | 122 | # Create aggregation consistently for both boolean and non-boolean columns 123 | if pd.api.types.is_bool_dtype(df[col]): 124 | value_counts = df[col].astype(str).value_counts() 125 | else: 126 | value_counts = df[col].value_counts() 127 | 128 | agg_data = value_counts.reset_index() 129 | agg_data.columns = [col, "count"] 130 | 131 | visualizations.append( 132 | { 133 | "figure": create_bar_plot( 134 | agg_data, 135 | x=col, 136 | y="count", 137 | color="#1f77b4", 138 | title=f"Count by {col_title}", 139 | labels={col: col_title, "count": "Count"}, 140 | ), 141 | "title": f"{col_title} Distribution", 142 | } 143 | ) 144 | return visualizations 145 | -------------------------------------------------------------------------------- /src/ares/app/export_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple methods for exporting ARES dashboards for a report to external users. 
3 | """ 4 | 5 | import os 6 | import traceback 7 | import typing as t 8 | from datetime import datetime 9 | 10 | import pandas as pd 11 | import pdfkit 12 | import plotly.graph_objects as go 13 | import streamlit as st 14 | 15 | from ares.constants import ARES_DATA_DIR 16 | 17 | 18 | def pdf_from_html(all_html_content: str, full_path: str) -> None: 19 | # Use absolute path for temporary HTML file 20 | html_path = os.path.abspath("/tmp/export.html") 21 | 22 | # Write the HTML content with proper encoding 23 | with open(html_path, "w", encoding="utf-8") as f: 24 | f.write(all_html_content) 25 | 26 | # Configure pdfkit options 27 | options = {"enable-local-file-access": None, "quiet": None} 28 | 29 | try: 30 | pdfkit.from_file(html_path, full_path, options=options) 31 | finally: 32 | # Clean up temp file even if conversion fails 33 | if os.path.exists(html_path): 34 | os.remove(html_path) 35 | 36 | 37 | def data_only_export(df: pd.DataFrame, export_path: str, format: str) -> str: 38 | # Simple data-only export 39 | full_path = f"{export_path}.{format}" 40 | print(f"Exporting data-only format to {full_path}") 41 | if format == "csv": 42 | df.to_csv(full_path, index=False) 43 | elif format == "parquet": 44 | # parquet doesn't like uuid 45 | df.id = df.id.astype(str) 46 | df.to_parquet(full_path, index=False) 47 | else: 48 | df.to_excel(full_path, index=False) 49 | return full_path 50 | 51 | 52 | def pretty_dashboard_export( 53 | df: pd.DataFrame, 54 | export_path: str, 55 | title: str, 56 | structured_filters: dict[str, t.Any], 57 | visualizations: list[dict], 58 | format: str, 59 | go_figs: dict[str, go.Figure], 60 | ) -> str: 61 | # Full dashboard export 62 | full_path = f"{export_path}.{format}" 63 | img_dir = f"{export_path}_files" 64 | os.makedirs(img_dir, exist_ok=True) 65 | 66 | # Generate HTML content 67 | html_content = [ 68 | "", 69 | "", 76 | "", 77 | f"
<h1>{title}</h1>", 78 | f"<p>Generated on: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>", 79 | ] 80 | 81 | # add any selected structured filters 82 | if structured_filters: 83 | html_content.extend( 84 | [ 85 | "<h2>Selected Filters</h2>", 86 | "", 92 | ] 93 | ) 94 | else: 95 | html_content.extend(["<h2>Selected Filters</h2>", "<p>No filters selected</p>"]) 96 | 97 | # Modified plotly graph objects visualization section 98 | if go_figs: 99 | for name, fig in go_figs.items(): 100 | im_path = os.path.join(img_dir, f"{name}_go_fig.png") 101 | fig.write_image(im_path) 102 | html_content.extend( 103 | [ 104 | f"<h2>{name} Analysis</h2>", 105 | f'<img src="{im_path}">', 106 | ] 107 | ) 108 | 109 | # Add all visualizations 110 | html_content.append("<h2>Visualizations</h2>") 111 | for i, viz in enumerate(visualizations): 112 | img_path = os.path.join(img_dir, f"plot_{i}.png") 113 | if "figure" not in viz: 114 | raise ValueError(f"No figure found for visualization {viz['title']}") 115 | viz["figure"].write_image(img_path) 116 | html_content.extend( 117 | [ 118 | f"", 119 | f"<h3>{viz['title']}</h3>", 120 | f'<img src="{img_path}">', 121 | "", 122 | ] 123 | ) 124 | 125 | # Add summary statistics 126 | html_content.extend( 127 | [ 128 | "<h2>Summary Statistics</h2>", 129 | "", 130 | df.describe().to_html(), 131 | "", 132 | ] 133 | ) 134 | 135 | # Add data table 136 | # Truncate any long text cells to 1000 chars 137 | df_truncated = df.copy() 138 | truncate_length = 200 139 | for col in df_truncated.select_dtypes(["object"]): 140 | df_truncated[col] = ( 141 | df_truncated[col] 142 | .astype(str) 143 | .apply( 144 | lambda x: ( 145 | x[:truncate_length] + "..." if len(x) > truncate_length else x 146 | ) 147 | ) 148 | ) 149 | html_content.extend( 150 | [ 151 | "<h2>Data Sample</h2>", 152 | "", 153 | df_truncated.head(100).to_html(), # First 100 rows 154 | "
", 155 | "", 156 | ] 157 | ) 158 | 159 | all_html_content = "\n".join(html_content) 160 | if format == "html": 161 | with open(full_path, "w") as f: 162 | f.write(all_html_content) 163 | else: 164 | pdf_from_html(all_html_content, full_path) 165 | return full_path 166 | 167 | 168 | def export_dashboard( 169 | df: pd.DataFrame, 170 | structured_filters: dict[str, t.Any], 171 | visualizations: list[dict], 172 | base_path: str, 173 | title: str, 174 | go_figs: dict[str, go.Figure], 175 | format: str = "html", 176 | ) -> str: 177 | """ 178 | Export dashboard including data, visualizations, and analytics. 179 | 180 | Args: 181 | df: DataFrame containing all data 182 | visualizations: List of visualization dictionaries with figures and titles 183 | base_path: Base directory path where file should be saved 184 | title: Title of the dashboard 185 | format: Export format ("html", "pdf", "csv", "xlsx") 186 | go_figs: Optional dict of plotly graph objects to include in export as images 187 | 188 | Note on formats: 189 | - data-only export formats ('csv', 'xlsx') will just dump a dataframe with the selected rows, no visualizations 190 | - "pretty" export formats ('html', 'pdf') will render the entire dashboard and save as a single artifact 191 | 192 | Returns: 193 | Path where file was saved 194 | """ 195 | timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") 196 | export_dir = os.path.join(base_path, "exports") 197 | os.makedirs(export_dir, exist_ok=True) 198 | 199 | base_filename = f"dashboard_export_{timestamp}" 200 | export_path = os.path.join(export_dir, base_filename) 201 | 202 | if format in ["csv", "xlsx", "parquet"]: 203 | return data_only_export(df, export_path, format) 204 | else: 205 | return pretty_dashboard_export( 206 | df=df, 207 | export_path=export_path, 208 | title=title, 209 | structured_filters=structured_filters, 210 | visualizations=visualizations, 211 | format=format, 212 | go_figs=go_figs, 213 | ) 214 | 215 | 216 | def export_options( 217 | filtered_df: pd.DataFrame, 218 | structured_filters: dict[str, t.Any], 219 | visualizations: list[dict], 220 | title: str, 221 | go_figs: dict[str, go.Figure], 222 | ) -> None: 223 | """Display and handle export controls for the dashboard. 224 | 225 | Args: 226 | filtered_df: DataFrame to be exported 227 | visualizations: List of visualization dictionaries 228 | cluster_fig: Optional plotly figure for cluster visualization 229 | """ 230 | st.header("Export Options") 231 | export_col1, export_col2, export_col3, _ = st.columns([1, 1, 1, 1]) 232 | 233 | with export_col1: 234 | export_path = st.text_input( 235 | "Export Directory", 236 | value=ARES_DATA_DIR, 237 | help="Directory where exported files will be saved", 238 | ) 239 | 240 | with export_col2: 241 | export_format = st.selectbox( 242 | "Export Format", 243 | options=["html", "pdf", "csv", "xlsx", "parquet"], 244 | help="Choose the format for your export. HTML/PDF include visualizations. 
CSV/XLSX include filtered data only.", 245 | ) 246 | 247 | with export_col3: 248 | if st.button("Export Dashboard"): 249 | try: 250 | with st.spinner(f"Exporting dashboard as {export_format}..."): 251 | export_path = export_dashboard( 252 | df=filtered_df, 253 | structured_filters=structured_filters, 254 | visualizations=visualizations, 255 | base_path=export_path, 256 | title=title, 257 | go_figs=go_figs, 258 | format=export_format, 259 | ) 260 | st.success(f"Dashboard exported successfully to: {export_path}") 261 | except Exception as e: 262 | st.error( 263 | f"Failed to export dashboard: {str(e)}\n{traceback.format_exc()}" 264 | ) 265 | -------------------------------------------------------------------------------- /src/ares/app/hero_display.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import streamlit as st 6 | 7 | from ares.app.annotation_viz_helpers import draw_annotations, draw_detection_data 8 | from ares.app.viz_helpers import ( 9 | create_embedding_similarity_visualization, 10 | create_similarity_tabs, 11 | create_text_similarity_visualization, 12 | generate_robot_array_plot_visualizations, 13 | get_video_annotation_data, 14 | ) 15 | from ares.databases.annotation_database import get_video_id 16 | from ares.databases.embedding_database import ( 17 | META_INDEX_NAMES, 18 | TRAJECTORY_INDEX_NAMES, 19 | IndexManager, 20 | rollout_to_index_name, 21 | ) 22 | from ares.utils.image_utils import ( 23 | choose_and_preprocess_frames, 24 | get_video_frames, 25 | get_video_mp4, 26 | ) 27 | 28 | 29 | def setup_zero_distance_checkbox_with_state() -> str: 30 | zero_distance_filter_key = "filter_zero_distance_matches" 31 | if zero_distance_filter_key not in st.session_state: 32 | st.session_state[zero_distance_filter_key] = False 33 | 34 | if st.checkbox( 35 | "Filter out zero-distance matches", 36 | value=st.session_state[zero_distance_filter_key], 37 | ): 38 | st.session_state[zero_distance_filter_key] = True 39 | else: 40 | st.session_state[zero_distance_filter_key] = False 41 | return zero_distance_filter_key 42 | 43 | 44 | def display_hero_annotations( 45 | db_data: dict, video_id: str, dataset: str, fname: str 46 | ) -> None: 47 | """ 48 | Create a nice display of all the annotations in the database. 49 | This include grounding annotations (e.g. detections) as well as composed datasets like embodied-chain-of-thought. 
50 | """ 51 | # check for annotation data 52 | annotation_data = db_data.get("annotations") 53 | if not annotation_data: 54 | st.warning(f"No annotation data found for video {video_id}") 55 | else: 56 | detection_data = annotation_data.get("detection") 57 | if detection_data: 58 | draw_detection_data(detection_data, dataset, fname) 59 | 60 | # show other top-level annotations (not frame-based) 61 | other_keys = [k for k in annotation_data.keys() if k != "detection"] 62 | with st.expander("Annotation Description Data", expanded=False): 63 | for key in other_keys: 64 | if isinstance(annotation_data[key], list): 65 | this_data = [ 66 | ann.description 67 | for ann in annotation_data[key] 68 | if ann.description 69 | ] 70 | if this_data: 71 | st.write(f"**{key.replace('_', ' ').title()}**") 72 | st.write(this_data) 73 | 74 | # display all other annotation data, e.g. pseudo-ECoT, in-context-learning datasets, etc. 75 | with st.expander("Raw Annotation Data (as JSON)", expanded=False): 76 | if "video_data" in db_data: 77 | st.write("Video Data:") 78 | st.json(db_data["video_data"], expanded=False) 79 | if "annotations" in db_data: 80 | st.write("Annotations:") 81 | st.json(db_data["annotations"], expanded=False) 82 | 83 | 84 | def create_similarity_viz_objects( 85 | row: pd.Series, 86 | df: pd.DataFrame, 87 | index_manager: IndexManager, 88 | retrieve_n_most_similar: int, 89 | ) -> tuple[list[str], list[dict]]: 90 | """ 91 | Use our embedding indices to retrieve similar examples based on meta indices (task, description) or trajectories (states, actions) per dataset. 92 | Right now we construct a pure-text comparison only for the task, as descriptions are too long for efficient calculation. 93 | """ 94 | # some robotics datasets have lots of overlap, e.g. the same task instruction. 
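# (illustrative: two distinct rollouts that both carry the instruction "close the drawer" embed to identical vectors, so their pairwise distance is exactly 0 even though their IDs differ)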
95 | # we may want to filter out zero-distance matches, even if they aren't the same ID 96 | # and will need to persist selection in state 97 | zero_distance_filter_key = setup_zero_distance_checkbox_with_state() 98 | 99 | # text comparison of levenshtein distance over task 100 | text_distance_fn = "levenshtein" 101 | text_distance_data_key = "task_language_instruction" 102 | text_viz_data = create_text_similarity_visualization( 103 | row, 104 | df, 105 | n_most_similar=retrieve_n_most_similar, 106 | data_key=text_distance_data_key, 107 | distance_fn_name=text_distance_fn, 108 | filter_zero_distance_matches=st.session_state[zero_distance_filter_key], 109 | ) 110 | 111 | # embedding retrieval over META_INDEX_NAMES 112 | text_embedding_viz_data = [ 113 | create_embedding_similarity_visualization( 114 | row, 115 | name=key, 116 | index_manager=index_manager, 117 | n_most_similar=retrieve_n_most_similar, 118 | filter_zero_distance_matches=st.session_state[zero_distance_filter_key], 119 | ) 120 | for key in META_INDEX_NAMES 121 | ] 122 | 123 | # embedding retrieval over TRAJECTORY_INDEX_NAMES 124 | trajectory_viz_data = [ 125 | create_embedding_similarity_visualization( 126 | row, 127 | name=rollout_to_index_name(row, k), 128 | index_manager=index_manager, 129 | n_most_similar=retrieve_n_most_similar, 130 | filter_zero_distance_matches=st.session_state[zero_distance_filter_key], 131 | ) 132 | for k in TRAJECTORY_INDEX_NAMES 133 | ] 134 | 135 | # organize tab names and visualizations for tabs 136 | tab_names = [ 137 | f"{text_distance_data_key.replace('_',' ').title()} - {text_distance_fn.title()}", 138 | *[f"{t.replace('_', ' ')} - Embedding".title() for t in META_INDEX_NAMES], 139 | *[t.title() for t in TRAJECTORY_INDEX_NAMES], 140 | ] 141 | 142 | similarity_viz = [ 143 | text_viz_data, 144 | *text_embedding_viz_data, 145 | *trajectory_viz_data, 146 | ] 147 | return tab_names, similarity_viz 148 | 149 | 150 | def show_hero_display( 151 | df: pd.DataFrame, 152 | row: pd.Series, 153 | all_vecs: dict, 154 | index_manager: IndexManager, 155 | traj_array_show_n: int = 100, 156 | retrieve_n_most_similar: int = 5, 157 | lazy_load: bool = False, 158 | max_cols: int = 5, 159 | ) -> None: 160 | """ 161 | Row 1: text 162 | Row 2: video col, detail + robot array plots 163 | Row 3: n tabs covering most similar based on state, action, video, text (embedding), text (metric) 164 | Returns: visualization figures to be included in export 165 | """ 166 | 167 | dataset, fname = ( 168 | row["dataset_filename"], 169 | row["filename"], 170 | ) 171 | 172 | col1, col2 = st.columns(2) 173 | with col1: 174 | # display the video 175 | if lazy_load: 176 | frame = get_video_frames(dataset, fname, n_frames=1)[0] 177 | st.image(frame) 178 | if st.button("Load Video"): 179 | st.video(get_video_mp4(dataset, fname)) 180 | else: 181 | st.video(get_video_mp4(dataset, fname)) 182 | with col2: 183 | # display a few key pieces of information, e.g. 
task and success 184 | if row.task_language_instruction: 185 | st.write(f"**Task:** {row.task_language_instruction}") 186 | if not np.isnan(row.task_success) and row.task_success: 187 | st.write(f"**Success:** {row.task_success:.2f}") 188 | if isinstance(row.trajectory_reward_step, str): 189 | if int(row.trajectory_reward_step) >= 0: 190 | st.write( 191 | f"**Reward Step:** {row.trajectory_reward_step} ({100*int(row.trajectory_reward_step) / int(row.length):.2f}% through rollout)" 192 | ) 193 | else: 194 | st.write(f"Failure episode!") 195 | # optionally display the rest of the row details as a truncated json 196 | with st.expander("Row Details", expanded=False): 197 | json_repr = { 198 | k: (v if len(str(v)) < 1000 else str(v)[:1000] + "...") 199 | for k, v in sorted(row.to_dict().items(), key=lambda x: x[0]) 200 | } 201 | st.json(json_repr, expanded=False) 202 | 203 | if st.button("Generate Robot Array Plots", key="robot_array_plots_button_hero"): 204 | array_figs = generate_robot_array_plot_visualizations( 205 | row, all_vecs, traj_array_show_n, highlight_row=True 206 | ) 207 | else: 208 | array_figs = [] 209 | 210 | # Add annotation data retrieval button 211 | if st.button("Retrieve Annotation Data"): 212 | try: 213 | video_id = get_video_id(dataset, fname) 214 | db_data = get_video_annotation_data(video_id) 215 | if db_data is not None: 216 | display_hero_annotations(db_data, video_id, dataset, fname) 217 | else: 218 | st.warning(f"No video or annotation data found for {video_id}") 219 | except Exception as e: 220 | st.error(f"Error retrieving annotation data: {str(e)}") 221 | st.error(traceback.format_exc()) 222 | 223 | # Row 3: n tabs covering most similar based on state, action, text 224 | st.write(f"**Similar Examples**") 225 | st.write(f"Most similar examples to {row['id']}, based on:") 226 | 227 | tab_names, similarity_viz = create_similarity_viz_objects( 228 | row, df, index_manager, retrieve_n_most_similar 229 | ) 230 | 231 | # Create the tabs with the data 232 | create_similarity_tabs( 233 | similarity_viz, 234 | tab_names, 235 | df, 236 | max_cols_in_tab=max_cols, 237 | ) 238 | -------------------------------------------------------------------------------- /src/ares/app/init_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import streamlit as st 6 | from sqlalchemy import select 7 | from sqlalchemy.orm import Session 8 | 9 | from ares.databases.annotation_database import ANNOTATION_DB_PATH, AnnotationDatabase 10 | from ares.databases.embedding_database import ( 11 | EMBEDDING_DB_PATH, 12 | META_INDEX_NAMES, 13 | FaissIndex, 14 | IndexManager, 15 | ) 16 | from ares.databases.structured_database import ( 17 | ROBOT_DB_PATH, 18 | RolloutSQLModel, 19 | setup_database, 20 | ) 21 | from ares.models.base import VLM 22 | from ares.utils.clustering import cluster_embeddings 23 | 24 | 25 | def load_cached_embeddings( 26 | tmp_dump_dir: str, index_name: str, stored_embeddings: np.ndarray 27 | ) -> tuple | None: 28 | """ 29 | Settig up embedding visualizations can be expensive, so we locally cache some of the generated arrays. 
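Returns a tuple of (embeddings, reduced, labels, ids) when a cache matching `stored_embeddings` exists on disk, otherwise None.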
30 | """ 31 | embeddings_path = os.path.join(tmp_dump_dir, f"{index_name}_embeddings.npy") 32 | clusters_path = os.path.join(tmp_dump_dir, f"{index_name}_clusters.npz") 33 | ids_path = os.path.join(tmp_dump_dir, f"{index_name}_ids.npy") # New path for IDs 34 | 35 | if not ( 36 | os.path.exists(embeddings_path) 37 | and os.path.exists(clusters_path) 38 | and os.path.exists(ids_path) 39 | ): 40 | return None 41 | 42 | loaded_embeddings = np.load(embeddings_path) 43 | if not ( 44 | len(loaded_embeddings) == len(stored_embeddings) 45 | and np.allclose(loaded_embeddings, stored_embeddings) 46 | ): 47 | # this means we have new embeddings 48 | return None 49 | 50 | # Valid cached data found - load everything 51 | clusters_data = np.load(clusters_path) 52 | loaded_ids = np.load(ids_path) # Load IDs 53 | return ( 54 | loaded_embeddings, 55 | clusters_data["reduced"], 56 | clusters_data["labels"], 57 | loaded_ids, # Return IDs as well 58 | ) 59 | 60 | 61 | def save_embeddings( 62 | tmp_dump_dir: str, 63 | index_name: str, 64 | embeddings: np.ndarray, 65 | reduced: np.ndarray, 66 | labels: np.ndarray, 67 | ids: np.ndarray, # Add IDs parameter 68 | ) -> None: 69 | """ 70 | Save reduced embeddings, clusters, and IDs to disk 71 | """ 72 | embeddings_path = os.path.join(tmp_dump_dir, f"{index_name}_embeddings.npy") 73 | clusters_path = os.path.join(tmp_dump_dir, f"{index_name}_clusters.npz") 74 | ids_path = os.path.join(tmp_dump_dir, f"{index_name}_ids.npy") # New path for IDs 75 | 76 | np.save(embeddings_path, embeddings) 77 | np.savez(clusters_path, reduced=reduced, labels=labels) 78 | np.save(ids_path, ids) # Save IDs 79 | 80 | 81 | def store_in_session( 82 | index_name: str, 83 | embeddings: np.ndarray, 84 | reduced: np.ndarray, 85 | labels: np.ndarray, 86 | stored_ids: np.ndarray, 87 | ) -> None: 88 | """ 89 | Store embeddings, clusters, other info in session state 90 | """ 91 | st.session_state[f"{index_name}_embeddings"] = embeddings 92 | st.session_state[f"{index_name}_reduced"] = reduced 93 | st.session_state[f"{index_name}_labels"] = labels 94 | st.session_state[f"{index_name}_ids"] = stored_ids 95 | 96 | 97 | def initialize_data(tmp_dump_dir: str) -> None: 98 | """ 99 | Initialize database connection, load data and create embeddings with caching. 
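Populates st.session_state with the SQLAlchemy engine/session, the rollout dataframe, the FAISS-backed IndexManager, all stored vectors and IDs, per-index embeddings/clusters, a summarizer VLM, and the annotation database handle plus its stats.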
100 | """ 101 | # Skip if already initialized 102 | if all( 103 | key in st.session_state for key in ["ENGINE", "SESSION", "df", "INDEX_MANAGER"] 104 | ): 105 | print("Data already initialized") 106 | return 107 | 108 | # Initialize database and session 109 | print("Initializing database and session") 110 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 111 | sess = Session(engine) 112 | st.session_state.ENGINE = engine 113 | st.session_state.SESSION = sess 114 | 115 | # Load dataframe 116 | print("Loading dataframe") 117 | query = select(RolloutSQLModel) 118 | df = pd.read_sql(query, engine) 119 | # Filter out unnamed columns 120 | df = df[[c for c in df.columns if "unnamed" not in c.lower()]] 121 | st.session_state.df = df 122 | 123 | # Initialize index manager 124 | print("Initializing index manager") 125 | index_manager = IndexManager(base_dir=EMBEDDING_DB_PATH, index_class=FaissIndex) 126 | st.session_state.INDEX_MANAGER = index_manager 127 | 128 | # Get all vectors and their IDs 129 | print("Getting all vectors and their IDs") 130 | all_data = index_manager.get_all_matrices() 131 | st.session_state.all_vecs = { 132 | name: data["arrays"] for name, data in all_data.items() 133 | } 134 | st.session_state.all_ids = {name: data["ids"] for name, data in all_data.items()} 135 | 136 | # Create tmp directory if it doesn't exist 137 | os.makedirs(tmp_dump_dir, exist_ok=True) 138 | 139 | # Process each index type 140 | for index_name in META_INDEX_NAMES: 141 | print(f"Processing {index_name} index") 142 | stored_embeddings = index_manager.indices[index_name].get_all_vectors() 143 | stored_ids = index_manager.indices[index_name].get_all_ids() 144 | # Try loading from cache first 145 | cached_data = load_cached_embeddings( 146 | tmp_dump_dir, index_name, stored_embeddings 147 | ) 148 | if cached_data is not None: 149 | embeddings, reduced, labels, ids = cached_data # Unpack IDs from cache 150 | else: 151 | # Create new embeddings and clusters 152 | embeddings = stored_embeddings 153 | reduced, labels, _ = cluster_embeddings(embeddings) 154 | ids = stored_ids # Use the IDs from the index 155 | save_embeddings(tmp_dump_dir, index_name, embeddings, reduced, labels, ids) 156 | 157 | # Store in session state 158 | store_in_session(index_name, embeddings, reduced, labels, stored_ids) 159 | 160 | print("Setting up models") 161 | st.session_state.models = dict() 162 | st.session_state.models["summarizer"] = VLM(provider="openai", name="gpt-4o-mini") 163 | print("Setting up annotations database") 164 | st.session_state.annotations_db = AnnotationDatabase( 165 | connection_string=ANNOTATION_DB_PATH 166 | ) 167 | print("Getting annotations database stats") 168 | st.session_state.annotation_db_stats = ( 169 | st.session_state.annotations_db.get_database_stats() 170 | ) 171 | 172 | 173 | def display_state_info() -> None: 174 | """ 175 | Helpful debugging state info displayed as the first rendered item in the streamlit object 176 | """ 177 | with st.expander("Session State Data Overview"): 178 | """Display information about data stored in streamlit session state.""" 179 | st.header("Session State Data Overview") 180 | 181 | # Database Info 182 | st.subheader("Database Connections") 183 | st.write( 184 | "- ENGINE:", 185 | "Connected" if "ENGINE" in st.session_state else "Not Connected", 186 | ) 187 | st.write( 188 | "- SESSION:", 189 | "Connected" if "SESSION" in st.session_state else "Not Connected", 190 | ) 191 | st.write( 192 | "- ANNOTATIONS DB", 193 | "Connected" if "annotations_db" in 
st.session_state else "Not Connected", 194 | ) 195 | # Index Manager Info 196 | st.subheader("Index Manager") 197 | if "INDEX_MANAGER" in st.session_state: 198 | index_manager = st.session_state.INDEX_MANAGER 199 | st.write( 200 | "Available indices: " + ", ".join(list(index_manager.indices.keys())) 201 | ) 202 | with st.popover("Index Details"): 203 | for name, index in index_manager.indices.items(): 204 | st.write(f"\n**{name} Index:**") 205 | st.write(f"- Feature dimension: {index.feature_dim}") 206 | st.write(f"- Time steps: {index.time_steps}") 207 | st.write(f"- Total entries: {index.n_entries}") 208 | 209 | # Vectors and IDs 210 | st.subheader("Stored Vectors and IDs") 211 | if "all_vecs" in st.session_state: 212 | with st.popover("Vector Details"): 213 | for name, vecs in st.session_state.all_vecs.items(): 214 | if vecs is not None: 215 | st.write(f"\n**{name}:**") 216 | st.write(f"- Vector shape: {vecs.shape}") 217 | st.write( 218 | f"- Number of IDs: {len(st.session_state.all_ids[name])}" 219 | ) 220 | 221 | st.subheader("Annotations DB") 222 | if "annotations_db" in st.session_state: 223 | if "annotation_db_stats" not in st.session_state: 224 | st.session_state.annotation_db_stats = ( 225 | st.session_state.annotations_db.get_database_stats() 226 | ) 227 | st.json(st.session_state.annotation_db_stats, expanded=False) 228 | 229 | # Embeddings Info 230 | st.subheader("Embedding Data") 231 | for index_name in META_INDEX_NAMES: 232 | st.write(f"\n**{index_name}:**") 233 | 234 | emb_key = f"{index_name}_embeddings" 235 | red_key = f"{index_name}_reduced" 236 | lab_key = f"{index_name}_labels" 237 | 238 | if emb_key in st.session_state: 239 | embeddings = st.session_state[emb_key] 240 | reduced = st.session_state[red_key] 241 | labels = st.session_state[lab_key] 242 | 243 | st.write(f"- Original embeddings shape: {embeddings.shape}") 244 | st.write(f"- Reduced embeddings shape: {reduced.shape}") 245 | st.write(f"- Number of labels: {len(labels)}") 246 | st.write(f"- Unique clusters: {len(np.unique(labels))}") 247 | -------------------------------------------------------------------------------- /src/ares/app/sections.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions for building the Streamlit app. This file exists to abstract out the code for building the app into smaller, more manageable components 3 | and leave the main `ares.app.webapp.py` file clean and readable, exclusively for display, metrics, and export functionality. 
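Each section helper below renders one block of the dashboard; several return the dataframes or visualization dicts that webapp.py threads into later sections and the export step.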
4 | """ 5 | 6 | import typing as t 7 | 8 | import pandas as pd 9 | import streamlit as st 10 | 11 | from ares.app.data_analysis import generate_automatic_visualizations 12 | from ares.app.filter_helpers import ( 13 | create_embedding_data_filter_display, 14 | select_row_from_df_user, 15 | structured_data_filters_display, 16 | ) 17 | from ares.app.hero_display import show_hero_display 18 | from ares.app.init_data import display_state_info, initialize_data 19 | from ares.app.plot_primitives import show_dataframe 20 | from ares.app.viz_helpers import ( 21 | annotation_statistics, 22 | create_tabbed_visualizations, 23 | display_video_grid, 24 | generate_robot_array_plot_visualizations, 25 | generate_success_rate_visualizations, 26 | generate_time_series_visualizations, 27 | total_statistics, 28 | ) 29 | from ares.databases.embedding_database import META_INDEX_NAMES 30 | 31 | 32 | def load_data(tmp_dump_dir: str) -> pd.DataFrame: 33 | """Load data from session state or initialize if needed.""" 34 | initialize_data(tmp_dump_dir) 35 | return st.session_state.df 36 | 37 | 38 | def loading_data_section(title: str, tmp_dump_dir: str) -> pd.DataFrame: 39 | st.set_page_config(page_title=title, page_icon="📊", layout="wide") 40 | st.title(title) 41 | return load_data(tmp_dump_dir) 42 | 43 | 44 | def state_info_section(df: pd.DataFrame) -> None: 45 | display_state_info() 46 | total_statistics(df) 47 | annotation_statistics(st.session_state.annotations_db) 48 | 49 | 50 | def structured_data_filters_section( 51 | df: pd.DataFrame, 52 | ) -> tuple[pd.DataFrame, dict[str, t.Any]]: 53 | st.header("Data Filters") 54 | structured_filtered_df, active_filters = structured_data_filters_display( 55 | df, debug=False 56 | ) 57 | st.write( 58 | f"Selected {len(structured_filtered_df)} rows out of {len(df)} total via structured data filters" 59 | ) 60 | if len(structured_filtered_df) == 0: 61 | st.warning("No data matches the structured filters!") 62 | return structured_filtered_df, active_filters 63 | 64 | 65 | def embedding_data_filters_section( 66 | df: pd.DataFrame, 67 | structured_filtered_df: pd.DataFrame, 68 | ) -> tuple[pd.DataFrame, dict]: 69 | st.subheader("Unstructured Data Filters") 70 | embedding_figs = dict() 71 | embedding_filtered_dfs = [] 72 | 73 | # Get filtered dataframes for each embedding 74 | for raw_data_key in META_INDEX_NAMES: 75 | st.write(f"**Filtering on {raw_data_key.replace('_', ' ').title()}**") 76 | filtered_df, cluster_fig = create_embedding_data_filter_display( 77 | df=df, # Pass original df each time 78 | id_key="id", 79 | raw_data_key=raw_data_key, 80 | kept_ids=structured_filtered_df["id"].apply(str).tolist(), 81 | ) 82 | embedding_filtered_dfs.append(filtered_df) 83 | embedding_figs[raw_data_key] = cluster_fig 84 | 85 | # Combine all filtered dataframes (AND operation) 86 | if embedding_filtered_dfs: 87 | all_filtered_ids = set(embedding_filtered_dfs[0]["id"]) 88 | for filtered_df in embedding_filtered_dfs[1:]: 89 | all_filtered_ids &= set(filtered_df["id"]) 90 | 91 | # Final filtered dataframe combines structured and embedding filters 92 | filtered_df = structured_filtered_df[ 93 | structured_filtered_df["id"].isin(all_filtered_ids) 94 | ] 95 | else: 96 | filtered_df = structured_filtered_df 97 | return filtered_df, embedding_figs 98 | 99 | 100 | def data_distributions_section(filtered_df: pd.DataFrame) -> list[dict]: 101 | max_x_bar_options = 100 102 | # Create overview of all data 103 | st.header("Distribution Analytics") 104 | general_visualizations = generate_automatic_visualizations( 
105 | filtered_df, 106 | time_column="ingestion_time", 107 | max_x_bar_options=max_x_bar_options, 108 | ) 109 | general_visualizations = sorted(general_visualizations, key=lambda x: x["title"]) 110 | create_tabbed_visualizations( 111 | general_visualizations, [viz["title"] for viz in general_visualizations] 112 | ) 113 | return general_visualizations 114 | 115 | 116 | def success_rate_analytics_section(filtered_df: pd.DataFrame) -> list[dict]: 117 | st.header("Success Estimate Analytics") 118 | success_visualizations = generate_success_rate_visualizations(filtered_df) 119 | create_tabbed_visualizations( 120 | success_visualizations, [viz["title"] for viz in success_visualizations] 121 | ) 122 | return success_visualizations 123 | 124 | 125 | def time_series_analytics_section(filtered_df: pd.DataFrame) -> list[dict]: 126 | st.header("Time Series Trends") 127 | time_series_visualizations = generate_time_series_visualizations( 128 | filtered_df, time_column="ingestion_time" 129 | ) 130 | create_tabbed_visualizations( 131 | time_series_visualizations, 132 | [viz["title"] for viz in time_series_visualizations], 133 | ) 134 | return time_series_visualizations 135 | 136 | 137 | def video_grid_section(filtered_df: pd.DataFrame) -> None: 138 | # show video cards of first 5 rows in a horizontal layout 139 | st.header("Rollout Examples") 140 | n_videos = 5 141 | display_rows = pd.concat( 142 | {k: v.head(1) for k, v in filtered_df.groupby("dataset_name")} 143 | ) 144 | if len(display_rows) < n_videos: 145 | # get enough videos to fill n_videos that arent already in display_rows 146 | extra_rows = filtered_df.head(n_videos) 147 | # remove rows that are already in display_rows 148 | extra_rows = extra_rows[~extra_rows.id.isin(display_rows.id)] 149 | display_rows = pd.concat([display_rows, extra_rows]) 150 | display_video_grid(display_rows, lazy_load=True) 151 | 152 | 153 | def plot_hero_section(df: pd.DataFrame, filtered_df: pd.DataFrame) -> pd.Series: 154 | st.header("Rollout Display") 155 | # initialize or persist selected row in state 156 | select_row_from_df_user(filtered_df) 157 | selected_row = st.session_state.get("selected_row") 158 | 159 | if selected_row is not None: 160 | show_dataframe( 161 | pd.DataFrame([selected_row]), title="Selected Row", add_refresh_button=False 162 | ) 163 | st.write(f"Selected row ID: {selected_row.id}") 164 | show_hero_display( 165 | df, # compare selected row from filtered_df to all rows in df 166 | selected_row, 167 | st.session_state.all_vecs, 168 | index_manager=st.session_state.INDEX_MANAGER, 169 | lazy_load=False, 170 | retrieve_n_most_similar=10, 171 | ) 172 | else: 173 | st.info("Please select a row to display details") 174 | return selected_row 175 | 176 | 177 | def robot_array_section(selected_row: pd.Series) -> list[dict]: 178 | if st.button("Generate Robot Array Plots", key="robot_array_plots_button"): 179 | st.header("Robot Array Display") 180 | # Number of trajectories to display in plots 181 | robot_array_visualizations = generate_robot_array_plot_visualizations( 182 | selected_row, # need row to select dataset/robot embodiment of trajectories 183 | st.session_state.all_vecs, 184 | show_n=1000, 185 | ) 186 | else: 187 | st.write("No robot array plots generated") 188 | robot_array_visualizations = [] 189 | return robot_array_visualizations 190 | -------------------------------------------------------------------------------- /src/ares/app/webapp.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main file 
for displaying the Streamlit app. This file contains the main function that defines the order of the sections in the app as well as 3 | state management, error handling, timing, and data export functionality. 4 | """ 5 | 6 | import os 7 | import time 8 | import traceback 9 | import typing as t 10 | from collections import defaultdict 11 | from contextlib import contextmanager 12 | 13 | import streamlit as st 14 | 15 | from ares.app.export_data import export_options 16 | from ares.app.plot_primitives import show_dataframe 17 | from ares.app.sections import ( 18 | data_distributions_section, 19 | embedding_data_filters_section, 20 | loading_data_section, 21 | plot_hero_section, 22 | robot_array_section, 23 | state_info_section, 24 | structured_data_filters_section, 25 | success_rate_analytics_section, 26 | time_series_analytics_section, 27 | video_grid_section, 28 | ) 29 | from ares.constants import ARES_DATA_DIR 30 | 31 | # top level global variables 32 | title = "ARES Dashboard" 33 | tmp_dump_dir = os.path.join(ARES_DATA_DIR, "webapp_tmp") 34 | section_times: dict[str, float] = defaultdict(float) 35 | 36 | 37 | ###################################################################### 38 | # Context managers for error handling and timing 39 | # - `error_context` is used to catch errors in computation and render the error in the app 40 | # - `timer_context` is used to time the execution of a section and print the timing to the console 41 | ###################################################################### 42 | @contextmanager 43 | def error_context(section_name: str) -> t.Any: 44 | """ 45 | Context manager for gracefully handling errors in computation of a section. 46 | Catch the error and render it in the app for easy debugging and readability. 47 | """ 48 | print(section_name) 49 | try: 50 | yield 51 | except Exception as e: 52 | st.error(f"Error in {section_name}: {str(e)}\n{traceback.format_exc()}") 53 | st.write("Stopping execution") 54 | st.stop() 55 | 56 | 57 | @contextmanager 58 | def timer_context(section_name: str) -> t.Any: 59 | """ 60 | Context manager for timing sections, helpful for debugging and performance analysis. 61 | """ 62 | start_time = time.time() 63 | try: 64 | yield 65 | finally: 66 | elapsed_time = time.time() - start_time 67 | section_times[section_name] += elapsed_time 68 | 69 | 70 | # Main function defining the order of the streamlit subsections 71 | # Note: streamlit displays standalone-strings like `"""..."""` as markdown! 72 | # Use `#` for comments in the streamlit context. 73 | def main() -> None: 74 | ###################################################################### 75 | # Load data and setup state info 76 | ###################################################################### 77 | section_loading = "loading data" 78 | with error_context(section_loading), timer_context(section_loading): 79 | df = loading_data_section(title, tmp_dump_dir) 80 | 81 | # simple expander for st.session_state information; helpful for debugging 82 | section_state_info = "state info" 83 | with error_context(section_state_info), timer_context(section_state_info): 84 | state_info_section(df) 85 | st.divider() 86 | 87 | ###################################################################### 88 | # Filter data using structured (selected via buttons, dropdowns, etc.) 
and embedding (selected via pointer, boxes) filters 89 | ###################################################################### 90 | section_filters = "structured data filters" 91 | with error_context(section_filters), timer_context(section_filters): 92 | structured_filtered_df, active_structured_filters = ( 93 | structured_data_filters_section(df) 94 | ) 95 | 96 | section_embedding_filters = "embedding data filters" 97 | with ( 98 | error_context(section_embedding_filters), 99 | timer_context(section_embedding_filters), 100 | ): 101 | filtered_df, embedding_figs = embedding_data_filters_section( 102 | df, structured_filtered_df 103 | ) 104 | if filtered_df.empty: 105 | st.warning( 106 | "No data available for the selected points! Try adjusting your selection to receive analytics." 107 | ) 108 | return 109 | st.divider() 110 | 111 | ###################################################################### 112 | # Display a section of the data and the distributions of the data, covering: 113 | # - general data distribution 114 | # - success rate 115 | # - time series trends 116 | # - video grid of examples 117 | ###################################################################### 118 | section_data_sample = "data sample" 119 | with error_context(section_data_sample), timer_context(section_data_sample): 120 | show_dataframe( 121 | filtered_df.sample(min(5, len(filtered_df))), title="Data Sample" 122 | ) 123 | st.divider() 124 | 125 | section_display = "data distributions" 126 | with error_context(section_display), timer_context(section_display): 127 | data_distributation_visualizations = data_distributions_section(filtered_df) 128 | 129 | section_success_rate = "success estimate analytics" 130 | with ( 131 | error_context(section_success_rate), 132 | timer_context(section_success_rate), 133 | ): 134 | success_rate_visualizations = success_rate_analytics_section(filtered_df) 135 | st.divider() 136 | 137 | section_time_series = "time series trends" 138 | with error_context(section_time_series), timer_context(section_time_series): 139 | time_series_visualizations = time_series_analytics_section(filtered_df) 140 | st.divider() 141 | 142 | section_video_grid = "video grid" 143 | with error_context(section_video_grid), timer_context(section_video_grid): 144 | video_grid_section(filtered_df) 145 | st.divider() 146 | 147 | ###################################################################### 148 | # Create a centralized focus on a single row of data with a 'hero' display 149 | # - Show the video, annotations, and other relevant data 150 | # - Create a tabbed interface for different views of the data 151 | # - Retrieve similar examples based on different metrics 152 | ###################################################################### 153 | section_plot_hero = "plot hero display" 154 | with error_context(section_plot_hero), timer_context(section_plot_hero): 155 | selected_row = plot_hero_section(df, filtered_df) 156 | st.divider() 157 | 158 | ###################################################################### 159 | # Plot robot arrays showing the distribution of robot actions and states relative to the rest 160 | # of the dataset. Useful for finding outliers and other interesting patterns. 
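# (illustrative: a rollout whose states or actions sit far outside the dataset's typical envelope will stand out against the overlaid trajectories in these plots)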
161 | ###################################################################### 162 | section_plot_robots = "plot robot arrays" 163 | with error_context(section_plot_robots), timer_context(section_plot_robots): 164 | robot_array_visualizations = robot_array_section(selected_row) 165 | st.divider() 166 | 167 | ###################################################################### 168 | # Export the data and all visualizations to a file or training format. 169 | # Note: we don't export video grids due to file size. 170 | ###################################################################### 171 | section_export = "exporting data" 172 | with error_context(section_export), timer_context(section_export): 173 | all_visualizations = [ 174 | *data_distributation_visualizations, 175 | *success_rate_visualizations, 176 | *time_series_visualizations, 177 | *robot_array_visualizations, 178 | ] 179 | export_options( 180 | filtered_df, 181 | active_structured_filters, 182 | all_visualizations, 183 | title, 184 | go_figs=embedding_figs, 185 | ) 186 | 187 | ###################################################################### 188 | # Display the timing report found by the timer context manager 189 | ###################################################################### 190 | print("\n=== Timing Report ===") 191 | print(f"Total time: {sum(section_times.values()):.2f} seconds") 192 | for section, elapsed_time in section_times.items(): 193 | print(f"{section}: {elapsed_time:.2f} seconds") 194 | print("==================\n") 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /src/ares/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/configs/__init__.py -------------------------------------------------------------------------------- /src/ares/configs/annotations.py: -------------------------------------------------------------------------------- 1 | import json 2 | import typing as t 3 | 4 | import cv2 5 | import numpy as np 6 | from pycocotools import mask as mask_utils 7 | from pydantic import BaseModel, Field, model_validator 8 | 9 | 10 | def rle_to_binary_mask(rle: dict) -> np.ndarray: 11 | """Convert RLE format to binary mask.""" 12 | rle = {"counts": rle["counts"].encode("utf-8"), "size": rle["size"]} 13 | return mask_utils.decode(rle) 14 | 15 | 16 | def binary_mask_to_rle(mask: np.ndarray) -> dict: 17 | """Convert binary mask to RLE format.""" 18 | rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8))) 19 | return {"counts": rle["counts"].decode("utf-8"), "size": rle["size"]} 20 | 21 | 22 | class Annotation(BaseModel): 23 | """ 24 | Base object to hold annotation data. 
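Example (illustrative): Annotation(category_name="mug", bbox=[0.1, 0.2, 0.4, 0.5], score=0.87). Per the validator below, at least one of description, bbox, segmentation, or attributes must be present.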
25 | """ 26 | 27 | # Core detection attributes 28 | description: str | None = None 29 | bbox: list[float] | None = None # [x1, y1, x2, y2] / LTRB format 30 | category_id: int | None = None 31 | category_name: str | None = None 32 | # denotes confidence of the detection if float else None if ground truth 33 | score: float | None = None 34 | 35 | # Segmentation attributes 36 | segmentation: t.Optional[t.Union[dict, list[list[float]]]] = ( 37 | None # RLE or polygon format 38 | ) 39 | 40 | # Tracking and metadata 41 | track_id: t.Optional[int] = None 42 | attributes: dict[str, t.Any] = Field(default_factory=dict) 43 | annotation_type: str | None = None 44 | 45 | model_config = { 46 | "arbitrary_types_allowed": True, 47 | "json_encoders": { 48 | np.ndarray: lambda x: x.tolist(), 49 | np.integer: lambda x: int(x), 50 | np.floating: lambda x: float(x), 51 | }, 52 | } 53 | 54 | @model_validator(mode="after") 55 | def sanity_check(self) -> "Annotation": 56 | # some portion of the annotation must be present! 57 | if ( 58 | self.description is None 59 | and self.bbox is None 60 | and self.segmentation is None 61 | and self.attributes is None 62 | ): 63 | raise ValueError( 64 | "Annotation must have at least one attribute; description, bbox, segmentation, or attributes" 65 | ) 66 | return self 67 | 68 | @property 69 | def bbox_xyxy(self) -> tuple[float, float, float, float]: 70 | """Get bbox in xyxy format.""" 71 | return self.bbox 72 | 73 | @property 74 | def bbox_xywh(self) -> tuple[float, float, float, float]: 75 | """Get bbox in xywh format.""" 76 | x1, y1, x2, y2 = self.bbox 77 | return x1, y1, x2 - x1, y2 - y1 78 | 79 | # Add this validation method 80 | def model_post_init(self, __context: t.Any) -> None: 81 | """Validate bbox format after initialization.""" 82 | if self.bbox is not None: 83 | x1, y1, x2, y2 = self.bbox 84 | if x1 > x2: 85 | raise ValueError(f"Invalid bbox: x1 ({x1}) must be <= x2 ({x2})") 86 | if y1 > y2: 87 | raise ValueError(f"Invalid bbox: y1 ({y1}) must be <= y2 ({y2})") 88 | 89 | @property 90 | def mask(self) -> np.ndarray | None: 91 | """Convert RLE or polygon segmentation to binary mask.""" 92 | if self.segmentation is None: 93 | return None 94 | 95 | if isinstance(self.segmentation, dict): # RLE format 96 | return rle_to_binary_mask(self.segmentation) 97 | else: # Polygon format 98 | mask = np.zeros( 99 | ( 100 | int(max(p[1] for p in self.segmentation[0])) + 1, 101 | int(max(p[0] for p in self.segmentation[0])) + 1, 102 | ), 103 | dtype=np.uint8, 104 | ) 105 | points = np.array(self.segmentation[0]).reshape((-1, 2)) 106 | cv2.fillPoly(mask, [points.astype(np.int32)], 1) 107 | return mask 108 | 109 | @classmethod 110 | def from_mask( 111 | cls, 112 | mask: np.ndarray, 113 | bbox: list[float], 114 | category_id: int, 115 | category_name: str, 116 | score: float, 117 | **kwargs, 118 | ) -> "Annotation": 119 | """Create annotation from binary mask.""" 120 | return cls( 121 | bbox=bbox, 122 | category_id=category_id, 123 | category_name=category_name, 124 | score=score, 125 | segmentation=binary_mask_to_rle(mask), 126 | **kwargs, 127 | ) 128 | 129 | def compute_iou(self, other: "Annotation") -> float: 130 | """Compute IoU between this annotation and another.""" 131 | if self.mask is None or other.mask is None: 132 | # Fall back to bbox IoU if masks aren't available 133 | return self.compute_bbox_iou(other) 134 | 135 | intersection = np.logical_and(self.mask, other.mask).sum() 136 | t.Union = np.logical_or(self.mask, other.mask).sum() 137 | return float(intersection) / 
float(t.Union) if t.Union > 0 else 0.0 138 | 139 | def compute_bbox_iou(self, other: "Annotation") -> float: 140 | """Compute IoU between bounding boxes.""" 141 | # Extract coordinates 142 | x1, y1, x2, y2 = self.bbox 143 | x1_, y1_, x2_, y2_ = other.bbox 144 | 145 | # Compute intersection 146 | x_left = max(x1, x1_) 147 | y_top = max(y1, y1_) 148 | x_right = min(x2, x2_) 149 | y_bottom = min(y2, y2_) 150 | 151 | if x_right < x_left or y_bottom < y_top: 152 | return 0.0 153 | 154 | intersection = (x_right - x_left) * (y_bottom - y_top) 155 | 156 | # Compute areas 157 | area1 = (x2 - x1) * (y2 - y1) 158 | area2 = (x2_ - x1_) * (y2_ - y1_) 159 | 160 | # Compute IoU 161 | t.Union = area1 + area2 - intersection 162 | return intersection / t.Union if t.Union > 0 else 0.0 163 | 164 | def transform( 165 | self, 166 | scale_x: float = 1.0, 167 | scale_y: float = 1.0, 168 | flip_horizontal: bool = False, 169 | flip_vertical: bool = False, 170 | ) -> "Annotation": 171 | """Transform the annotation coordinates.""" 172 | # Transform bbox 173 | x1, y1, x2, y2 = self.bbox 174 | if flip_horizontal: 175 | x1, x2 = 1 - x2, 1 - x1 176 | if flip_vertical: 177 | y1, y2 = 1 - y2, 1 - y1 178 | 179 | transformed_bbox = [x1 * scale_x, y1 * scale_y, x2 * scale_x, y2 * scale_y] 180 | 181 | # Transform segmentation if it exists 182 | transformed_segmentation = None 183 | if isinstance(self.segmentation, list): # Polygon format 184 | transformed_segmentation = [] 185 | for polygon in self.segmentation: 186 | transformed_polygon = [] 187 | for i in range(0, len(polygon), 2): 188 | x, y = polygon[i], polygon[i + 1] 189 | if flip_horizontal: 190 | x = 1 - x 191 | if flip_vertical: 192 | y = 1 - y 193 | transformed_polygon.extend([x * scale_x, y * scale_y]) 194 | transformed_segmentation.append(transformed_polygon) 195 | 196 | # Create new instance with transformed coordinates 197 | return Annotation( 198 | bbox=transformed_bbox, 199 | category_id=self.category_id, 200 | category_name=self.category_name, 201 | score=self.score, 202 | segmentation=transformed_segmentation or self.segmentation, 203 | track_id=self.track_id, 204 | attributes=self.attributes, 205 | ) 206 | 207 | def to_dict(self) -> dict: 208 | """Convert annotation to dictionary format suitable for JSON serialization.""" 209 | base_dict = self.model_dump(exclude_none=True) 210 | return base_dict 211 | 212 | @classmethod 213 | def from_dict(cls, data: dict) -> "Annotation": 214 | """Create annotation from dictionary.""" 215 | return cls(**data) 216 | 217 | def save_json(self, filepath: str) -> None: 218 | """Save annotation to JSON file.""" 219 | with open(filepath, "w") as f: 220 | json.dump(self.to_dict(), f) 221 | 222 | @classmethod 223 | def load_json(cls, filepath: str) -> "Annotation": 224 | """Load annotation from JSON file.""" 225 | with open(filepath, "r") as f: 226 | data = json.load(f) 227 | return cls.from_dict(data) 228 | 229 | def __json__(self): 230 | """ 231 | Helper method for JSON serialization 232 | Used as `json.dumps(..., default=lambda x: x.__json__() if hasattr(x, "__json__") else x)` 233 | """ 234 | return self.model_dump() 235 | -------------------------------------------------------------------------------- /src/ares/configs/open_x_embodiment_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import typing as t 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import tensorflow as tf 7 | from pydantic import BaseModel, model_validator 8 | 9 | from ares.constants import 
ARES_OXE_DIR 10 | 11 | 12 | class TensorConverterMixin(BaseModel): 13 | """ 14 | TFDS returns tensors; we want everything in numpy arrays or 15 | base python types to work with other parts of the codebase. 16 | """ 17 | 18 | model_config = {"arbitrary_types_allowed": True} 19 | 20 | @model_validator(mode="before") 21 | @classmethod 22 | def convert_tensors_to_python(cls, data: dict) -> dict: 23 | def convert_value(value: t.Any) -> t.Any: 24 | if isinstance(value, tf.Tensor): 25 | # Convert to numpy first 26 | value = value.numpy() 27 | # Convert to base Python type if it's a scalar 28 | if np.isscalar(value): 29 | if isinstance(value, (np.bool_)): 30 | return bool(value) 31 | elif isinstance(value, np.floating): 32 | return float(value) 33 | elif isinstance(value, np.integer): 34 | return int(value) 35 | elif isinstance(value, dict): 36 | return {k: convert_value(v) for k, v in value.items()} 37 | elif isinstance(value, list): 38 | return [convert_value(v) for v in value] 39 | return value 40 | 41 | return {k: convert_value(v) for k, v in data.items()} 42 | 43 | 44 | class OpenXEmbodimentEpisodeMetadata(TensorConverterMixin, BaseModel): 45 | file_path: str 46 | success: bool | None = None 47 | 48 | 49 | class OpenXEmbodimentStepObservation(TensorConverterMixin, BaseModel): 50 | image: np.ndarray 51 | state: np.ndarray | None = None 52 | depth: np.ndarray | None = None 53 | wrist_image: np.ndarray | None = None 54 | end_effector_state: np.ndarray | None = None 55 | 56 | @model_validator(mode="before") 57 | def get_image(cls, data: dict) -> "dict": 58 | if "highres_image" in data: 59 | data["image"] = data.pop("highres_image") 60 | elif "hand_image" in data and "image" not in data: 61 | data["image"] = data.pop("hand_image") 62 | elif "agentview_rgb" in data: 63 | data["image"] = data.pop("agentview_rgb") 64 | 65 | if "eye_in_hand_rgb" in data: 66 | data["wrist_image"] = data.pop("eye_in_hand_rgb") 67 | return data 68 | 69 | @model_validator(mode="before") 70 | def get_state(cls, data: dict) -> dict: 71 | if "state" not in data: 72 | extra_state_keys = [ 73 | "gripper", 74 | "gripper_states", 75 | "end_effector_cartesian_pos", 76 | "end_effector_cartesian_velocity", 77 | "joint_pos", 78 | "joint_states", 79 | "pose", 80 | ] 81 | state_arrays = [] 82 | for k in extra_state_keys: 83 | if k in data: 84 | value = data[k] 85 | if isinstance(value, bool): 86 | state_arrays.append(np.array([float(value)])) 87 | elif hasattr(value, "shape"): 88 | if value.shape == (): 89 | state_arrays.append(value.numpy().reshape(1)) 90 | else: 91 | state_arrays.append(value) 92 | if state_arrays: 93 | data["state"] = np.concatenate(state_arrays) 94 | else: 95 | data["state"] = None 96 | if "end_effector_state" not in data: 97 | if "ee_state" in data: 98 | data["end_effector_state"] = data.pop("ee_state") 99 | return data 100 | 101 | 102 | class OpenXEmbodimentStep(TensorConverterMixin, BaseModel): 103 | action: np.ndarray | None 104 | discount: float | None = None 105 | is_first: bool 106 | is_last: bool 107 | is_terminal: bool 108 | language_embedding: np.ndarray | None = None 109 | language_instruction: str | None = None 110 | observation: OpenXEmbodimentStepObservation 111 | reward: float | None = None 112 | 113 | @model_validator(mode="before") 114 | @classmethod 115 | def remap_fields(cls, data: dict) -> dict: 116 | # Handle observation field remapping 117 | if "observation" in data and isinstance(data["observation"], dict): 118 | obs = data["observation"] 119 | 120 | # Move natural_language_instruction if it 
exists in observation 121 | if "natural_language_instruction" in obs: 122 | data["language_instruction"] = obs.pop("natural_language_instruction") 123 | if "natural_language_embedding" in obs: 124 | data["language_embedding"] = obs.pop("natural_language_embedding") 125 | 126 | # Add more field remapping here as needed 127 | action = data["action"] 128 | if isinstance(action, dict): 129 | extra_action_keys = [ 130 | "rotation_delta", 131 | "world_vector", 132 | "gripper_closedness_action", 133 | "terminate_episode", 134 | ] 135 | action_arrays = [] 136 | for k in extra_action_keys: 137 | if k in action: 138 | value = action[k] 139 | if isinstance(value, (int, float)): 140 | action_arrays.append(np.array([float(value)])) 141 | elif hasattr(value, "shape"): 142 | if value.shape == (): 143 | action_arrays.append(value.numpy().reshape(1)) 144 | else: 145 | action_arrays.append(value) 146 | 147 | if action_arrays: 148 | data["action"] = np.concatenate(action_arrays) 149 | else: 150 | data["action"] = None 151 | return data 152 | 153 | 154 | class OpenXEmbodimentEpisode(TensorConverterMixin, BaseModel): 155 | episode_metadata: OpenXEmbodimentEpisodeMetadata 156 | steps: list[OpenXEmbodimentStep] 157 | 158 | 159 | # hardcoded path to OXE spreadsheet 160 | # see original version at https://docs.google.com/spreadsheets/d/1rPBD77tk60AEIGZrGSODwyyzs5FgCU9Uz3h-3_t2A9g/edit?gid=0#gid=0 161 | PATH_TO_OXE_SPREADSHEET = "/workspaces/ares/src/ares/extras/oxe.csv" 162 | HEADER_ROW = 16 163 | 164 | 165 | def get_oxe_dataframe() -> pd.DataFrame: 166 | return pd.read_csv(PATH_TO_OXE_SPREADSHEET, header=HEADER_ROW) 167 | 168 | 169 | def get_dataset_information(dataset_filename: str) -> pd.DataFrame: 170 | df = get_oxe_dataframe() 171 | return dict(df[df["Registered Dataset Name"] == dataset_filename].iloc[0]) 172 | 173 | 174 | def construct_openxembodiment_episode(ep: dict, i: int) -> OpenXEmbodimentEpisode: 175 | if "episode_metadata" not in ep: 176 | ep["episode_metadata"] = dict(file_path=f"episode_{i}.npy") 177 | episode = OpenXEmbodimentEpisode(**ep) 178 | return episode 179 | -------------------------------------------------------------------------------- /src/ares/configs/pydantic_sql_helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | We use pydantic configs as the base unit of information in ARES. We use SQLModel classes to store these configs in a database, which requires 3 | flattening the pydantic model into a SQLModel. This file contains helpers to do this, as well as a helper to reconstruct a pydantic model from a 4 | flattened SQLModel. This allows us to use the same configs in both the frontend and backend of the application and easily convert between the two. 5 | """ 6 | 7 | import typing as t 8 | import uuid 9 | 10 | from pydantic import BaseModel 11 | from sqlmodel import Field, SQLModel 12 | 13 | 14 | def create_flattened_model( 15 | data_model: t.Type[BaseModel], non_nullable_fields: list[str] = ["id"] 16 | ) -> t.Type[SQLModel]: 17 | """ 18 | Create a flattened SQLModel class from a pydantic model. This allows us to store the pydantic model in a database. This requires recursively 19 | extracting all the fields from the pydantic model and adding them to the SQLModel class and inferring the type of the field in the SQLModel. 20 | 21 | For example, a config field like `rollout.robot.environment.lighting_estimate` will be flattened into `rollout_robot_environment_lighting_estimate` 22 | and the appropriate type will be inferred. 
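Illustrative sketch (hypothetical models, not from this codebase): given class Robot(BaseModel) with a single field `name: str` and class Rollout(BaseModel) with field `robot: Robot`, create_flattened_model(Rollout) yields a SQLModel table named "rollout" with an explicit `id` uuid primary key plus a nullable `robot_name` string column.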
23 | """ 24 | fields: dict[str, t.Any] = { 25 | "__annotations__": {}, 26 | "__tablename__": "rollout", 27 | } 28 | 29 | # Add id field explicitly as primary key 30 | fields["__annotations__"]["id"] = uuid.UUID 31 | fields["id"] = Field(default_factory=uuid.uuid4, primary_key=True) 32 | 33 | # recursively extract fields 34 | def flatten_fields(prefix: str, model: t.Type[BaseModel]) -> None: 35 | for field_name, field in model.model_fields.items(): 36 | if field_name == "id": # Skip id field as we've handled it above 37 | continue 38 | 39 | field_type = field.annotation 40 | if field_type is None: 41 | continue 42 | 43 | # Handle list types by converting them to JSON strings 44 | origin_type = t.get_origin(field_type) 45 | if origin_type is not None and origin_type in (list, t.List): 46 | fields["__annotations__"][f"{prefix}{field_name}"] = str 47 | if f"{prefix}{field_name}" not in non_nullable_fields: 48 | fields[f"{prefix}{field_name}"] = Field(default=None, nullable=True) 49 | else: 50 | fields[f"{prefix}{field_name}"] = Field() 51 | continue 52 | elif isinstance(field_type, type) and issubclass(field_type, BaseModel): 53 | # Handle nested BaseModel 54 | flatten_fields(f"{prefix}{field_name}_", field_type) 55 | continue 56 | 57 | # Handle the field 58 | field_key = f"{prefix}{field_name}" 59 | fields["__annotations__"][field_key] = field_type 60 | if field_key not in non_nullable_fields: 61 | fields[field_key] = Field(nullable=True) 62 | else: 63 | fields[field_key] = Field() 64 | 65 | flatten_fields("", data_model) 66 | return type("RolloutSQLModel", (SQLModel,), fields, table=True) 67 | 68 | 69 | ModelCls = t.TypeVar("ModelCls", bound=BaseModel) 70 | 71 | 72 | def recreate_model(sql_model_instance: SQLModel, model_cls: type[ModelCls]) -> ModelCls: 73 | """Recreate a Pydantic model object from a flattened SQLModel instance. 
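Acts as the rough inverse of create_flattened_model, e.g. recreate_model(sql_row, Rollout) rebuilds a nested Rollout config from its flattened columns (illustrative usage; Rollout stands in for any pydantic config).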
74 | 75 | Args: 76 | sql_model_instance: Instance of the flattened SQLModel 77 | model_cls: The Pydantic model class to recreate 78 | 79 | Returns: 80 | BaseModel: Reconstructed Pydantic model object 81 | """ 82 | # Convert SQLModel instance to dict 83 | flat_dict = { 84 | k: v for k, v in sql_model_instance.__dict__.items() if not k.startswith("_") 85 | } 86 | 87 | # Build nested structure 88 | nested_dict = {} 89 | # Get the field types from the model class 90 | fields = model_cls.model_fields 91 | 92 | # Group fields by model structure 93 | for key, value in flat_dict.items(): 94 | # Handle non-nested fields 95 | if key in fields: 96 | nested_dict[key] = value 97 | continue 98 | 99 | # Handle nested fields by matching against model fields 100 | for field_name, field in fields.items(): 101 | # Check if field is a nested model 102 | if hasattr(field.annotation, "model_fields"): 103 | # If key starts with field_name + "_", it belongs to this nested model 104 | if key.startswith(f"{field_name}_"): 105 | # Initialize nested dict if needed 106 | if field_name not in nested_dict: 107 | nested_dict[field_name] = {} 108 | # Remove prefix to get the nested field name 109 | nested_field = key[len(field_name) + 1 :] 110 | nested_dict[field_name][nested_field] = value 111 | break 112 | # Build kwargs dict automatically 113 | kwargs = {} 114 | for field_name, field in fields.items(): 115 | if field_name in nested_dict: 116 | # If field has nested model, instantiate it 117 | if hasattr(field.annotation, "model_fields"): 118 | kwargs[field_name] = field.annotation(**nested_dict[field_name]) 119 | else: 120 | kwargs[field_name] = nested_dict[field_name] 121 | return model_cls(**kwargs) 122 | -------------------------------------------------------------------------------- /src/ares/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import defaultdict 3 | 4 | ARES_DATA_DIR = "/workspaces/ares/data" 5 | ARES_OXE_DIR = os.path.join(ARES_DATA_DIR, "oxe") 6 | ARES_VIDEO_DIR = os.path.join(ARES_DATA_DIR, "videos") 7 | 8 | # using oxe-downloader 9 | # oxe-download --dataset "name" --path $ARES_OXE_DIR 10 | DATASET_NAMES: list[dict[str, str]] = [ 11 | { 12 | "dataset_filename": "nyu_rot_dataset_converted_externally_to_rlds", 13 | "dataset_formalname": "NYU ROT", 14 | }, 15 | { 16 | "dataset_filename": "ucsd_kitchen_dataset_converted_externally_to_rlds", 17 | "dataset_formalname": "UCSD Kitchen", 18 | }, 19 | { 20 | "dataset_filename": "cmu_franka_exploration_dataset_converted_externally_to_rlds", 21 | "dataset_formalname": "CMU Franka Exploration", 22 | }, 23 | { 24 | "dataset_filename": "berkeley_fanuc_manipulation", 25 | "dataset_formalname": "Berkeley Fanuc Manipulation", 26 | }, 27 | { 28 | "dataset_filename": "cmu_stretch", 29 | "dataset_formalname": "CMU Stretch", 30 | }, 31 | # {"dataset_filename": "cmu_play_fusion", "dataset_formalname": "CMU Play Fusion"}, 32 | # { 33 | # "dataset_filename": "jaco_play", 34 | # "dataset_formalname": "USC Jaco Play", 35 | # }, 36 | # { 37 | # "dataset_filename": "dlr_edan_shared_control_converted_externally_to_rlds", 38 | # "dataset_formalname": "DLR Wheelchair Shared Control", 39 | # }, 40 | # { 41 | # "dataset_filename": "imperialcollege_sawyer_wrist_cam", 42 | # "dataset_formalname": "Imperial Wrist Cam", 43 | # }, 44 | # { 45 | # "dataset_filename": "tokyo_u_lsmo_converted_externally_to_rlds", 46 | # "dataset_formalname": "LSMO Dataset", 47 | # }, 48 | # { 49 | # "dataset_filename": 
"ucsd_pick_and_place_dataset_converted_externally_to_rlds", 50 | # "dataset_formalname": "UCSD Pick Place", 51 | # }, 52 | # { 53 | # "dataset_filename": "asu_table_top_converted_externally_to_rlds", 54 | # "dataset_formalname": "ASU TableTop Manipulation", 55 | # }, 56 | # { 57 | # "dataset_filename": "viola", 58 | # "dataset_formalname": "Austin VIOLA", 59 | # }, 60 | # { 61 | # "dataset_filename": "kaist_nonprehensile_converted_externally_to_rlds", 62 | # "dataset_formalname": "KAIST Nonprehensile Objects", 63 | # }, 64 | # { 65 | # "dataset_filename": "berkeley_mvp_converted_externally_to_rlds", 66 | # "dataset_formalname": "Berkeley MVP Data", 67 | # }, 68 | # { 69 | # "dataset_filename": "pi_demos", 70 | # "dataset_formalname": "Physical Intelligence Demos", 71 | # }, 72 | ] 73 | # Saytap does not have pixel data, so we exclude it 74 | # { 75 | # "dataset_filename": "utokyo_saytap_converted_externally_to_rlds", 76 | # "dataset_formalname": "Saytap", 77 | # }, 78 | 79 | DATASET_KEY_TO_DATASET_INFO: dict[str, dict[str, dict[str, str]]] = defaultdict(dict) 80 | keys = ["dataset_filename", "dataset_formalname"] 81 | for dataset_info in DATASET_NAMES: 82 | for key in keys: 83 | DATASET_KEY_TO_DATASET_INFO[key][dataset_info[key]] = dataset_info 84 | 85 | 86 | def get_dataset_info_by_key(key_type: str, key: str) -> dict[str, str]: 87 | # allows us to get dataset info by key (like filename or formalname) 88 | if key_type not in DATASET_KEY_TO_DATASET_INFO: 89 | raise ValueError(f"Invalid key type: {key_type}") 90 | if key not in DATASET_KEY_TO_DATASET_INFO[key_type]: 91 | raise ValueError(f"Invalid key: {key}") 92 | return DATASET_KEY_TO_DATASET_INFO[key_type][key] 93 | 94 | 95 | # for ingestion operations, we're loading large amounts of data into memory at once. 96 | # this is a hard limit on the number of rollouts/requests to avoid memory issues. 97 | OUTER_BATCH_SIZE = 20 98 | 99 | # for annotation operations, the objects in memory are smaller (eg no point clouds), 100 | # so we can load more into memory at once. 
101 | ANNOTATION_OUTER_BATCH_SIZE = 100 102 | 103 | # for grounding, annotate the frames at set FPS 104 | ANNOTATION_GROUNDING_FPS = 5 105 | 106 | # for displays, we want to ignore some columns consistently 107 | IGNORE_COLS = [ 108 | "dataset_filename", 109 | "dataset_formalname", 110 | "id", 111 | "path", 112 | "filename", 113 | ] 114 | -------------------------------------------------------------------------------- /src/ares/databases/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/databases/__init__.py -------------------------------------------------------------------------------- /src/ares/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/extras/__init__.py -------------------------------------------------------------------------------- /src/ares/extras/pi_demo_utils.py: -------------------------------------------------------------------------------- 1 | PI_DEMO_TASKS = { 2 | "Eggs in carton": { 3 | "task": "The robot must place every egg into the carton and then close and secure the lid.", 4 | "filename_prefix": "processed_eggs", 5 | }, 6 | "Grocery Bagging": { 7 | "task": "The robot must place every item into the bag.", 8 | "filename_prefix": "processed_grocery_bagging", 9 | }, 10 | "Toast out of toaster": { 11 | "task": "The robot must remove the both pieces of toast from the toaster and place them both on the plate.", 12 | "filename_prefix": "processed_toast", 13 | }, 14 | "Towel fold": { 15 | "task": "The robot must fold the towel.", 16 | "filename_prefix": "processed_towel_fold", 17 | }, 18 | "Stack bowls": { 19 | "task": "The robot must stack all of the bowls.", 20 | "filename_prefix": "processed_stack", 21 | }, 22 | "Tupperware in microwave": { 23 | "task": "The robot must place the tupperware in the microwave and close the door.", 24 | "filename_prefix": "processed_tupperware", 25 | }, 26 | "Items in drawer": { 27 | "task": "The robot must place all of the items in the drawer and close the drawer.", 28 | "filename_prefix": "processed_drawer", 29 | }, 30 | "Laundry fold (shirts)": { 31 | "task": "The robot must fold the shirt.", 32 | "filename_prefix": "processed_folding_single_shirt", 33 | }, 34 | "Laundry fold (shorts)": { 35 | "task": "The robot must fold the shorts.", 36 | "filename_prefix": "processed_fold_single_shorts", 37 | }, 38 | "Paper towel in holder": { 39 | "task": "The robot must discard the old paper towel roll and place the new white paper towel roll completely onto the holder.", 40 | "filename_prefix": "processed_towel", 41 | }, 42 | "Food in to go box": { 43 | "task": "The robot must place the food in the to go box and properly close the lid.", 44 | "filename_prefix": "processed_togo", 45 | }, 46 | } 47 | -------------------------------------------------------------------------------- /src/ares/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/models/__init__.py -------------------------------------------------------------------------------- /src/ares/models/grounding.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import 
torch 5 | from PIL import Image 6 | from transformers import ( 7 | AutoModelForMaskGeneration, 8 | AutoModelForZeroShotObjectDetection, 9 | AutoProcessor, 10 | ) 11 | 12 | 13 | class GroundingAnnotator: 14 | def __init__( 15 | self, 16 | detector_id: str = "IDEA-Research/grounding-dino-tiny", 17 | segmenter_id: str | None = "facebook/sam-vit-base", 18 | detector_thresholds: dict[str, float] | None = None, 19 | device: str = "cuda" if torch.cuda.is_available() else "cpu", 20 | ): 21 | self.device = device 22 | self.detector_processor, self.detector_model = self.setup_detector(detector_id) 23 | self.segmentor_processor, self.segmentor_model = self.setup_segmenter( 24 | segmenter_id 25 | ) 26 | self.detector_thresholds = detector_thresholds or { 27 | "box_threshold": 0.4, 28 | "text_threshold": 0.3, 29 | } 30 | print( 31 | f"Loaded detector {detector_id}" 32 | + ( 33 | f"and segmenter {segmenter_id} on device {device}" 34 | if segmenter_id 35 | else "" 36 | ) 37 | ) 38 | 39 | def setup_detector( 40 | self, model_id: str 41 | ) -> tuple[AutoProcessor, AutoModelForZeroShotObjectDetection]: 42 | processor = AutoProcessor.from_pretrained(model_id) 43 | print(f"Downloading model {model_id}...") 44 | model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to( 45 | self.device 46 | ) 47 | return processor, model 48 | 49 | def setup_segmenter( 50 | self, model_id: str | None 51 | ) -> tuple[AutoProcessor, AutoModelForMaskGeneration]: 52 | if model_id is None: 53 | return None, None 54 | processor = AutoProcessor.from_pretrained(model_id) 55 | print(f"Downloading model {model_id}...") 56 | model = AutoModelForMaskGeneration.from_pretrained( 57 | model_id, token=os.environ.get("HUGGINGFACE_API_KEY") 58 | ).to(self.device) 59 | return processor, model 60 | 61 | def run_detector( 62 | self, 63 | images: list[Image.Image], 64 | labels_str: str, 65 | ) -> list[list[dict]]: 66 | # Process all images in a single batch 67 | inputs = self.detector_processor( 68 | images=images, text=[labels_str] * len(images), return_tensors="pt" 69 | ).to(self.device) 70 | 71 | with torch.no_grad(): 72 | outputs = self.detector_model(**inputs) 73 | 74 | target_sizes = [[img.size[1], img.size[0]] for img in images] # [height, width] 75 | 76 | results = self.detector_processor.post_process_grounded_object_detection( 77 | outputs, 78 | inputs.input_ids, 79 | box_threshold=self.detector_thresholds["box_threshold"], 80 | text_threshold=self.detector_thresholds["text_threshold"], 81 | target_sizes=target_sizes, 82 | ) 83 | 84 | all_annotations = [] 85 | for image_idx, result in enumerate(results): 86 | frame_annotations = [] 87 | if "boxes" in result: 88 | for box_idx in range(len(result["boxes"])): 89 | ann_dict = { 90 | "bbox": result["boxes"][box_idx].tolist(), 91 | "category_name": result["labels"][box_idx], 92 | "score": result["scores"][box_idx].item(), 93 | } 94 | frame_annotations.append(ann_dict) 95 | all_annotations.append(frame_annotations) 96 | 97 | return all_annotations 98 | 99 | def run_segmenter( 100 | self, 101 | images: list[Image.Image], 102 | annotations: list[list[dict]], 103 | ) -> list[list[dict]]: 104 | # Process each image's annotations 105 | all_points = [] 106 | all_labels = [] 107 | max_points = max(len(frame_anns) for frame_anns in annotations) 108 | 109 | for frame_anns in annotations: 110 | frame_points = [ 111 | [ 112 | (box["bbox"][0] + box["bbox"][2]) / 2, 113 | (box["bbox"][1] + box["bbox"][3]) / 2, 114 | ] 115 | for box in frame_anns 116 | ] 117 | # Pad points and labels to ensure 
consistent shape 118 | while len(frame_points) < max_points: 119 | frame_points.append([0.0, 0.0]) # Add dummy points 120 | 121 | frame_labels = [1] * len(frame_anns) 122 | frame_labels.extend([0] * (max_points - len(frame_anns))) # Pad with zeros 123 | 124 | all_points.append(frame_points) 125 | all_labels.append(frame_labels) 126 | 127 | if not any(all_points): # Handle case with no detections 128 | return annotations 129 | 130 | inputs = self.segmentor_processor( 131 | images=images, 132 | input_points=all_points, 133 | input_labels=all_labels, 134 | return_tensors="pt", 135 | ).to(self.device) 136 | 137 | with torch.no_grad(): 138 | outputs = self.segmentor_model(**inputs) 139 | 140 | scores = outputs["iou_scores"] 141 | masks = self.segmentor_processor.post_process_masks( 142 | masks=outputs.pred_masks, 143 | original_sizes=inputs.original_sizes, 144 | reshaped_input_sizes=inputs.reshaped_input_sizes, 145 | ) 146 | 147 | # Process results for each frame 148 | for frame_idx, (frame_masks, frame_scores, frame_anns) in enumerate( 149 | zip(masks, scores, annotations) 150 | ): 151 | for obj_idx, (mask, score, ann) in enumerate( 152 | zip(frame_masks, frame_scores, frame_anns) 153 | ): 154 | best_mask = mask[score.argmax()] 155 | ann["segmentation"] = best_mask.numpy() 156 | 157 | return annotations 158 | 159 | def process_batch( 160 | self, 161 | images: list[Image.Image], 162 | labels_str: str, 163 | ) -> list[list[dict]]: 164 | """Process a batch of images with detection and segmentation.""" 165 | box_annotations = self.run_detector(images, labels_str) 166 | if not any(box_annotations): 167 | return box_annotations 168 | 169 | if self.segmentor_model is not None: 170 | segment_annotations = self.run_segmenter(images, box_annotations) 171 | return segment_annotations 172 | else: 173 | return box_annotations 174 | 175 | def annotate_video( 176 | self, 177 | rollout_id: str, 178 | frames: list, 179 | labels_str: str, 180 | batch_size: int = 8, 181 | ) -> tuple[str, list[list[dict]]]: 182 | """Annotate video frames in batches.""" 183 | all_annotations = [] 184 | 185 | # Convert frames to PIL Images using PIL instead of cv2 186 | frames = [np.array(f) for f in frames] 187 | pil_frames = [Image.fromarray(f).convert("RGB") for f in frames] 188 | 189 | # Process in batches 190 | for i in range(0, len(pil_frames), batch_size): 191 | batch_frames = pil_frames[i : i + batch_size] 192 | batch_annotations = self.process_batch(batch_frames, labels_str) 193 | all_annotations.extend(batch_annotations) 194 | return rollout_id, all_annotations 195 | -------------------------------------------------------------------------------- /src/ares/models/grounding_utils.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | 3 | import numpy as np 4 | 5 | from ares.configs.annotations import Annotation, binary_mask_to_rle 6 | from ares.models.base import VLM, parse_response 7 | 8 | 9 | async def get_grounding_nouns_async( 10 | vlm: VLM, 11 | image: np.ndarray, 12 | task_instructions: str, 13 | prompt_filename: str = "grounding_description.jinja2", 14 | ) -> str: 15 | """Get object labels from VLM asynchronously.""" 16 | if task_instructions is None: 17 | task_instructions = "" 18 | _, response = await vlm.ask_async( 19 | info=dict(task_instructions=task_instructions), 20 | prompt_filename=prompt_filename, 21 | images=[image], 22 | ) 23 | label_str = parse_response(response.choices[0], load_json=False) 24 | label_str = label_str.replace("a ", "").replace("an ", 
"") 25 | return label_str 26 | 27 | 28 | def get_grounding_nouns( 29 | vlm: VLM, 30 | image: np.ndarray, 31 | task_instructions: str, 32 | prompt_filename: str = "grounding_description.jinja2", 33 | ) -> str: 34 | return asyncio.run( 35 | get_grounding_nouns_async(vlm, image, task_instructions, prompt_filename) 36 | ) 37 | 38 | 39 | def convert_to_annotations( 40 | detection_results: list[list[dict]], 41 | ) -> list[list[Annotation]]: 42 | """Convert detection results from dictionaries to Annotation objects.""" 43 | # convert masks to rle 44 | outputs = [] 45 | for frame_anns in detection_results: 46 | frame_anns = [] 47 | for ann in frame_anns: 48 | if ann.get("segmentation") is not None: 49 | ann["segmentation"] = binary_mask_to_rle(ann["segmentation"]) 50 | frame_anns.append(ann) 51 | outputs.append(frame_anns) 52 | return outputs 53 | -------------------------------------------------------------------------------- /src/ares/models/prompts/extractor_prompt.jinja2: -------------------------------------------------------------------------------- 1 | You are an AI assistant helping to analyze robot task execution data. Please examine the provided images and task information to extract relevant details. 2 | 3 | Task Instruction: {{ task }} 4 | 5 | Required Fields: 6 | {{field_instructions}} 7 | 8 | Example Response Format: 9 | {{ example_response_format | tojson(indent=2) }} where "..." represents the actual information. 10 | 11 | Please provide a structured response following the example format above. 12 | Make sure to adhere to any additional instructions for each field, such as the following the pattern for the field (including case sensitivity). 13 | 14 | Your response should be valid JSON matching the example format. Ignore any text superimposed on the images. 15 | Begin the response with ```json and end with ```. 16 | -------------------------------------------------------------------------------- /src/ares/models/prompts/grounding_description.jinja2: -------------------------------------------------------------------------------- 1 | The task is to look at an image and list the important objects in the image which will later be used to query bounding boxes. 2 | The image will show a robot performing a task. Pay special attention to the robot and the objects it interacts with. 3 | 4 | Use general purpose nouns for objects, such as "a cup" or "a bowl"; however, don't overgeneralize to the point that the label is not specific. For example, say "a red bowl", not "a red object". 5 | Do not provide background information like "floor" or "wall"; focus on the objects that are important to the task and scene. Do provide information about the robot, the objects getting manipulated in the scene, and anything that is relevant to the task. 6 | The output format is a simple period-separated list of simple descriptions of objects, such as "a cup. a bowl. a robot." 7 | 8 | You may also be provided with task instructions. If so, try to return the objects that are relevant to the task. 9 | 10 | For example, imagine an image of a robot picking up a square red cup and pouring it into a round blue bowl, and task instructions saying "pour the cup into the bowl". The important objects are the robot, the cup, and the bowl. 11 | The output should be "a red cup. a blue bowl. a robot." 12 | 13 | Do not just say "a red object"; always use the actual name of the object! If an object is important, try to disambiguate it from other objects. 
14 | Look at the following image and list the important objects in the image. Please be extensive and include all objects that are relevant to the task. 15 | 16 | Do not include any pleasantries or other text in the output, just the period-separated list of objects; no need to include quotation marks. 17 | Again, do not include any comments like "I'm unable to see the image" or "I'm not sure what's in the image", just return the period-separated list of objects. 18 | 19 | 20 | {% if task_instructions %}Task instructions: {{ task_instructions }}. Most nouns in the instructions are important objects.{% endif %} 21 | 22 | Images: 23 | -------------------------------------------------------------------------------- /src/ares/models/prompts/icl.jinja2: -------------------------------------------------------------------------------- 1 | Your objective is to create a plan for a robot to complete a task given an image and some examples of similar tasks. 2 | Each example is a description of how a robot completed a similar task, where "similar" is defined by the "key" accompanying the example. 3 | Task: {{ task }} 4 | 5 | Examples: 6 | {% for key, example_list in examples.items() %} 7 | Similarity key: {{ key }} 8 | {% for example_item in example_list %} 9 | - {{ example_item }} 10 | {% endfor %} 11 | {% endfor %} 12 | 13 | Following the image below, output a plan for the robot to complete the task. 14 | You should first summarize the examples above and then create a plan for the robot to complete the task. 15 | Think carefully and look carefully at the image and the information above. 16 | Your response should be a well-reasoned and detailed plan for the robot to complete the task. 17 | -------------------------------------------------------------------------------- /src/ares/models/prompts/pseudo_ecot.jinja2: -------------------------------------------------------------------------------- 1 | Your goal is to create a plan for a robot to complete a task given an image and information about the task. 2 | 3 | Information about the task: 4 | - The overarching task is: {{ task }}, which is a {{ complexity_category_estimate }} task. 5 | - The environment contains the following objects that relate to the task: {{ grounding_string }} 6 | {% if detections %}- These objects have the following positions in the environment: {{ detections }}{% endif %} 7 | - The success criteria for the task are: {{ success_criteria }} 8 | 9 | Following the image below, output a plan for the robot to complete the task. 10 | Think carefully and look carefully at the image and the information above. 11 | Your response should be a well-reasoned and detailed plan for the robot to complete the task. 12 | -------------------------------------------------------------------------------- /src/ares/models/prompts/simple_video_eval.jinja2: -------------------------------------------------------------------------------- 1 | You will be provided with a task description, a list of frames from a video, and a set of success constraints. 2 | Your objective is to describe the images and determine the robot's success at the end of the images according to the success constraints. 3 | 4 | First, describe the images. Make sure to include the end state of the scene from the last images. 5 | Second, analyze the robot's success at the task according to the success constraints. Make sure to consider the end state of the images! 6 | Third, determine the float score representing the robot's success at the task according to the success constraints.
7 | Output the answer as a JSON object with the following fields: 8 | {{output_format}} 9 | 10 | Here is the task: 11 | {{task}} 12 | 13 | Here are the success constraints: 14 | {{success_constraints}} 15 | 16 | Remember, since you are returning a JSON object, you must return a valid JSON object. Begin with ```json and end with ```. 17 | Throughout your response, reference images by their index which is provided between images. 18 | 19 | Images: -------------------------------------------------------------------------------- /src/ares/models/prompts/success_constraint_generation.jinja2: -------------------------------------------------------------------------------- 1 | The task is to generate a set of success constraints for a robot to perform a task, including the goal end state of the task. 2 | Given the task description and an image describing the scene, output a string listing the criteria for successful completion of the task. 3 | The success constraints should be specific to the task and image, but should represent what a person would consider for the task to be generally successful. 4 | Focus on whether the critical steps of the task are successful. 5 | 6 | For example, consider the task description "Pour the coffee into the cup" and an image showing a robot holding a pot of coffee and a cup on a table. 7 | The example output should be "The robot must pour the coffee from the pot into the cup on the table. The end state should be a cup with coffee in it on the table." 8 | 9 | Please output the success constraint as a simple string; do not include any other text or formatting. 10 | Make sure to be very specific and unambiguous when describing the success constraints and objects in the scene. 11 | 12 | Here is the task description: 13 | {{task}} 14 | -------------------------------------------------------------------------------- /src/ares/models/prompts/summarization_frame_eval.jinja2: -------------------------------------------------------------------------------- 1 | You will be provided with a task description, a set of success constraints, and a list of paragraphs describing each frame from a video. 2 | Your objective is to summarize the descriptions and determine the robot's success at the task according to the success constraints. 3 | 4 | First, summarize the descriptions. 5 | Second, analyze the robot's success at the task according to the success constraints. Make sure to consider the end state of the descriptions! 6 | Third, determine the float score representing the robot's success at the task according to the success constraints. 7 | Output the answer as a JSON object with the following fields: 8 | {{output_format}} 9 | 10 | Here is the task: 11 | {{task}} 12 | 13 | Here are the success constraints: 14 | {{success_constraints}} 15 | 16 | Remember, since you are returning a JSON object, you must return a valid JSON object. Begin with ```json and end with ```. 17 | Throughout your response, reference descriptions by their index which is provided between descriptions. 18 | 19 | Descriptions: 20 | {{descriptions}} -------------------------------------------------------------------------------- /src/ares/models/prompts/summarizing.jinja2: -------------------------------------------------------------------------------- 1 | Your task is to summarize the following data. The data represents {{description}}. 2 | 3 | Data: 4 | {{data}} 5 | 6 | Please provide a concise summary of the data and do not include any other information or pleasantries.
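The two *_eval templates above share one contract: the caller renders {{output_format}} with a description of the expected JSON fields, and the model must answer with a single fenced ```json block (the repository's own response handling goes through parse_response in ares.models.base). A minimal sketch of rendering and parsing that contract follows; the template search path and the field names ("description", "analysis", "success_score") are illustrative assumptions, not values fixed by the repository.

```python
# Illustrative sketch of the prompt/response contract used by simple_video_eval.jinja2
# and summarization_frame_eval.jinja2. Field names and paths here are assumptions.
import json
import re

from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("src/ares/models/prompts"))
template = env.get_template("simple_video_eval.jinja2")

# hypothetical output_format: the fields the model is asked to return
output_format = {
    "description": "string describing the frames",
    "analysis": "string analyzing success against the constraints",
    "success_score": "float in [0, 1]",
}
prompt = template.render(
    output_format=json.dumps(output_format, indent=2),
    task="Pour the coffee into the cup",
    success_constraints="The end state should be a cup with coffee in it on the table.",
)


def parse_json_block(response_text: str) -> dict:
    """Pull the fenced ```json ... ``` block out of a model response and load it."""
    match = re.search(r"```json\s*(.*?)\s*```", response_text, flags=re.DOTALL)
    if match is None:
        raise ValueError("No fenced JSON block found in response")
    return json.loads(match.group(1))
```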
7 | -------------------------------------------------------------------------------- /src/ares/models/prompts/task_frame_description.jinja2: -------------------------------------------------------------------------------- 1 | Your task is to extensively describe a frame from a video, paying special attention to the provided task. 2 | Make sure to include all relevant information about the robot, the task, any objects, and the scene. 3 | Here are the task details: {{task}} and success constraints: {{success_constraints}}. 4 | 5 | Return a paragraph describing the frame; do not include any pleasantries, formatting, or other text. 6 | -------------------------------------------------------------------------------- /src/ares/models/refusal.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | REFUSAL_PHRASES = ["I'm sorry", "I'm unable"] 4 | 5 | 6 | def check_refusal(text: str, refusal_phrases: list[str] | None = None) -> bool: 7 | if refusal_phrases is None: 8 | refusal_phrases = REFUSAL_PHRASES 9 | pattern = "|".join(map(re.escape, refusal_phrases)) 10 | return bool(re.search(pattern, text)) 11 | -------------------------------------------------------------------------------- /src/ares/models/sampling_bias.py: -------------------------------------------------------------------------------- 1 | """ 2 | We are interested in sampling frames from a given episode. 3 | We have a thesis that frames towards the end of an episode are more interesting than those in the beginning. 4 | Below, we implement a few strategies to present this strategy with a focus on an upper bound of F total frames. 5 | 6 | Note: individual sampling strategies are not guaranteed to return F frames. Use the `sampling_bias` function to ensure F frames are returned. 7 | The `sampling_bias` function will sample extra frames if necessary to reach the target number of frames or uniformly subsample if more frames are sampled than necessary. 8 | """ 9 | 10 | import typing as t 11 | 12 | import numpy as np 13 | 14 | 15 | # linear sampling bias 16 | def linear_sampling_bias( 17 | input_n_frames: int, total_desired_frames: int, **kwargs: t.Any 18 | ) -> t.Sequence[int]: 19 | """ 20 | Sample frames with a probability linearly proportional to their index. 21 | """ 22 | return [i for i in range(input_n_frames) if np.random.random() < i / input_n_frames] 23 | 24 | 25 | def exponential_sampling_bias( 26 | input_n_frames: int, total_desired_frames: int, rate: float = 5.0, **kwargs: t.Any 27 | ) -> t.Sequence[int]: 28 | """ 29 | Sample frames with exponentially increasing probability. 30 | 31 | The probability of selecting a frame increases exponentially with its index, 32 | controlled by the rate parameter. Higher rates lead to stronger bias towards 33 | later frames. 34 | """ 35 | return [ 36 | i 37 | for i in range(input_n_frames) 38 | if np.random.rand() < (1 - np.exp(-rate * i / total_desired_frames)) 39 | ] 40 | 41 | 42 | def threshold_sampling_bias( 43 | input_n_frames: int, 44 | total_desired_frames: int, 45 | frame_threshold: float = 0.75, 46 | bias_rate: float = 0.5, 47 | ) -> t.Sequence[int]: 48 | """ 49 | Uniformly sample frames with a proportion coming from before the frame_threshold and the rest coming from after. 50 | Control the bias rate to control the relative number of frames sampled from before and after. 
51 | """ 52 | desired_n_frames_before = int(bias_rate * total_desired_frames) 53 | desired_n_frames_after = total_desired_frames - desired_n_frames_before 54 | 55 | threshold_n = int(input_n_frames * frame_threshold) 56 | frames_before = np.random.choice( 57 | range(threshold_n), desired_n_frames_before, replace=False 58 | ) 59 | frames_after = np.random.choice( 60 | range(threshold_n, input_n_frames), desired_n_frames_after, replace=False 61 | ) 62 | return sorted(list(frames_before) + list(frames_after)) 63 | 64 | 65 | def sampling_bias( 66 | input_n_frames: int, 67 | total_desired_frames: int, 68 | strategy: str = "linear", 69 | **kwargs: t.Any, 70 | ) -> t.Sequence[int]: 71 | if input_n_frames < total_desired_frames: 72 | raise ValueError( 73 | f"Input number of frames ({input_n_frames}) is less than the desired number of frames ({total_desired_frames})." 74 | ) 75 | 76 | if strategy == "linear": 77 | sampled = linear_sampling_bias(input_n_frames, total_desired_frames) 78 | elif strategy == "exponential": 79 | sampled = exponential_sampling_bias( 80 | input_n_frames, total_desired_frames, **kwargs 81 | ) 82 | elif strategy == "threshold": 83 | sampled = threshold_sampling_bias( 84 | input_n_frames, total_desired_frames, **kwargs 85 | ) 86 | if len(sampled) < total_desired_frames: 87 | remaining_indices = list(set(range(input_n_frames)) - set(sampled)) 88 | extra_samples = np.random.choice( 89 | remaining_indices, 90 | size=total_desired_frames - len(sampled), 91 | replace=False, 92 | ) 93 | sampled = sorted(sampled + extra_samples) 94 | elif len(sampled) > total_desired_frames: 95 | sampled = np.random.choice( 96 | np.array(sampled), total_desired_frames, replace=False 97 | ) 98 | return sorted(sampled) 99 | -------------------------------------------------------------------------------- /src/ares/models/shortcuts.py: -------------------------------------------------------------------------------- 1 | import typing as t 2 | 3 | from ares.models.base import VLM, Embedder, SentenceTransformerEmbedder 4 | 5 | 6 | def get_siglip_embedder() -> Embedder: 7 | return Embedder(provider="google", name="siglip-base-patch16-224") 8 | 9 | 10 | def get_nomic_embedder() -> SentenceTransformerEmbedder: 11 | return SentenceTransformerEmbedder(provider="nomic-ai", name="nomic-embed-text-v1") 12 | 13 | 14 | def get_all_embedders() -> dict[str, Embedder]: 15 | return { 16 | "siglip": get_siglip_embedder(), 17 | "nomic": get_nomic_embedder(), 18 | } 19 | 20 | 21 | def get_gemini_15_flash() -> VLM: 22 | return VLM(provider="gemini", name="gemini-1.5-flash") 23 | 24 | 25 | def get_gemini_15_pro() -> VLM: 26 | return VLM(provider="gemini", name="gemini-1.5-pro") 27 | 28 | 29 | def get_gemini_2_flash() -> VLM: 30 | return VLM(provider="gemini", name="gemini-2.0-flash") 31 | 32 | 33 | def get_gemini_2_pro() -> VLM: 34 | return VLM(provider="gemini", name="gemini-2.0-pro-exp-02-05") 35 | 36 | 37 | def get_gpt_4o_mini() -> VLM: 38 | return VLM(provider="openai", name="gpt-4o-mini") 39 | 40 | 41 | def get_gpt_4o() -> VLM: 42 | return VLM(provider="openai", name="gpt-4o") 43 | 44 | 45 | def get_gpt_o1_mini() -> VLM: 46 | return VLM(provider="openai", name="o1-preview") 47 | 48 | 49 | def get_gpt_4_turbo() -> VLM: 50 | return VLM(provider="openai", name="gpt-4-turbo") 51 | 52 | 53 | def get_claude_3_5_sonnet() -> VLM: 54 | return VLM(provider="anthropic", name="claude-3-5-sonnet-20240620") 55 | 56 | 57 | name_to_vlm_fn_mapping = { 58 | "gemini-1.5-pro": get_gemini_15_pro, 59 | "gemini-2-flash": get_gemini_2_flash, 
60 | "gemini-1.5-flash": get_gemini_15_flash, 61 | "gpt-4o-mini": get_gpt_4o_mini, 62 | "gpt-4o": get_gpt_4o, 63 | "gpt-o1-mini": get_gpt_o1_mini, 64 | "claude-3-5-sonnet": get_claude_3_5_sonnet, 65 | "gpt-4-turbo": get_gpt_4_turbo, 66 | } 67 | 68 | 69 | def get_all_vlm_fns() -> dict[str, t.Callable[[], VLM]]: 70 | return name_to_vlm_fn_mapping 71 | 72 | 73 | def get_vlm(name: str) -> VLM: 74 | if name not in name_to_vlm_fn_mapping: 75 | raise ValueError( 76 | f"VLM {name} not found from name_to_vlm_fn_mapping: {name_to_vlm_fn_mapping.keys()}" 77 | ) 78 | return name_to_vlm_fn_mapping[name]() 79 | 80 | 81 | def summarize(vlm: VLM, data: list[str], description: str) -> str: 82 | info = {"data": "\n".join(data), "description": description} 83 | messages, response = vlm.ask( 84 | info, 85 | prompt_filename="summarizing.jinja2", 86 | ) 87 | return response.choices[0].message.content 88 | -------------------------------------------------------------------------------- /src/ares/training/README.md: -------------------------------------------------------------------------------- 1 | # ARES Training Setup 2 | 3 | This directory contains the (MOCK) training pipeline for using ARES. The process is split into two main steps: 4 | Note: this is a mock pipeline demonstrating how the ARES platform could be used for training. 5 | ## 1. Data Preprocessing 6 | 7 | The first step uses `preprocess.py` to: 8 | - Load a dataframe of desired rollout IDs to train on 9 | - Query the database engine to fetch full rollout data 10 | - Preload annotations from annotations_db for the specified key 11 | - Save everything as a parquet file for efficient loading 12 | 13 | ```bash 14 | python preprocess.py --ids-path path/to/ids.csv --output-path data/processed.parquet --annotation-key detection 15 | ``` 16 | 17 | ## 2. Training 18 | 19 | The second step uses `train.py` which provides: 20 | - Custom PyTorch Dataset that loads the parquet file 21 | - Efficient DataLoader for batching 22 | - Constructs Rollout objects and annotation dictionaries 23 | 24 | ```bash 25 | python train.py --data-path data/processed.parquet 26 | ``` 27 | 28 | ## Data Structure 29 | 30 | The preprocessed parquet file contains: 31 | - All fields from the original rollouts 32 | - Preloaded annotations under the 'annotations' column 33 | - Frame indices and metadata needed for training 34 | 35 | The PyTorch Dataset returns tuples of: 36 | - Rollout object 37 | - Dictionary of annotations 38 | - Additional metadata needed for training 39 | -------------------------------------------------------------------------------- /src/ares/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/training/__init__.py -------------------------------------------------------------------------------- /src/ares/training/preprocess.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mock simple script to preprocess rollouts and annotations into the train format. 3 | This is a mock to show how to construct a preprocessed artifact to conduct training using the ARES platform. 4 | We construct a training artifact in order to derisk loading errors and avoid massive database queries during training. 5 | 6 | See ares/train/README.md for more details and ares/train/train.py for how to use the preprocessed artifact. 
7 | 8 | Preprocess assumes a list of IDs has been selected via curation on the ARES platform. 9 | Extra info cols could be "grounding_string", "detections", "embodied_cot", etc. 10 | 11 | Example usage: 12 | python preprocess.py --ids-df-path path/to/ids.csv --extra-info-cols col1 --extra-info-cols col2 --output-path output.parquet 13 | 14 | # Or, over all rollouts in the database, dropping rows with missing annotations: 15 | python preprocess.py --extra-info-cols col1 --extra-info-cols col2 --drop-nones --output-path output.parquet 16 | """ 17 | 18 | import json 19 | import os 20 | from collections import defaultdict 21 | 22 | import click 23 | import numpy as np 24 | import pandas as pd 25 | from tqdm import tqdm 26 | 27 | from ares.databases.annotation_database import ( 28 | ANNOTATION_DB_PATH, 29 | AnnotationDatabase, 30 | get_video_id, 31 | ) 32 | from ares.databases.structured_database import ( 33 | ROBOT_DB_PATH, 34 | RolloutSQLModel, 35 | get_partial_df, 36 | get_rollouts_by_ids, 37 | setup_database, 38 | ) 39 | 40 | 41 | def setup_extra_info_col( 42 | df: pd.DataFrame, col: str, ann_db: AnnotationDatabase 43 | ) -> list[str | None]: 44 | raw_anns: list[str | None] = [] 45 | for _, row in tqdm(df.iterrows(), desc=f"Collecting annotations for {col}"): 46 | video_id = get_video_id(row["dataset_filename"], row["filename"]) 47 | anns = ann_db.get_annotations(video_id, annotation_type=col) 48 | if anns and col in anns: 49 | raw_anns.append( 50 | json.dumps( 51 | anns[col], 52 | default=lambda x: x.__json__() if hasattr(x, "__json__") else x, 53 | ) 54 | ) 55 | else: 56 | raw_anns.append(None) 57 | 58 | # Check if all annotations are None -- probably an error 59 | if all(ann is None for ann in raw_anns): 60 | raise ValueError(f"No annotations found for column {col}") 61 | return raw_anns 62 | 63 | 64 | @click.command() 65 | @click.option("--output-path", type=str, help="Path to save the preprocessed artifact") 66 | @click.option( 67 | "--ids-df-path", 68 | type=str, 69 | help="Path to CSV file containing rollout IDs", 70 | required=False, 71 | default=None, 72 | ) 73 | @click.option( 74 | "--extra-info-cols", 75 | type=str, 76 | multiple=True, 77 | help="Extra info columns to collect annotations for.
Can be specified multiple times for different columns", 78 | required=False, 79 | default=None, 80 | ) 81 | @click.option( 82 | "--drop-nones", 83 | is_flag=True, 84 | default=False, 85 | help="If True, drop rows where any extra_info_col contains None values", 86 | ) 87 | def preprocess( 88 | output_path: str, 89 | ids_df_path: str | None, 90 | extra_info_cols: list[str] | None, 91 | drop_nones: bool, 92 | ) -> None: 93 | engine = setup_database(RolloutSQLModel, path=ROBOT_DB_PATH) 94 | 95 | if ids_df_path: 96 | ids_df = pd.read_csv(ids_df_path) 97 | else: 98 | ids_df = get_partial_df(engine, ["id"]) 99 | 100 | if extra_info_cols is None: 101 | extra_info_cols = [] 102 | 103 | # collect rollouts 104 | rollout_df: pd.DataFrame = get_rollouts_by_ids( 105 | engine, ids_df.id.tolist(), return_df=True 106 | ) 107 | 108 | # collect annotations from annotation database via extra_info_cols 109 | if extra_info_cols: 110 | ann_db = AnnotationDatabase(connection_string=ANNOTATION_DB_PATH) 111 | extra_info_cols_to_anns = defaultdict(list) 112 | for col in extra_info_cols: 113 | raw_anns = setup_extra_info_col(rollout_df, col, ann_db) 114 | extra_info_cols_to_anns[col] = raw_anns 115 | 116 | # construct train df 117 | train_df = rollout_df.copy() 118 | # parquet doesnt like uuids, so we convert to str 119 | train_df.id = train_df.id.astype(str) 120 | for col in extra_info_cols: 121 | train_df[col] = extra_info_cols_to_anns[col] 122 | 123 | # Drop rows with None values in extra info columns if requested 124 | if drop_nones and extra_info_cols: 125 | initial_len = len(train_df) 126 | train_df = train_df.dropna(subset=extra_info_cols) 127 | train_df = train_df.reset_index(drop=True) 128 | dropped_count = initial_len - len(train_df) 129 | print( 130 | f"Dropped {dropped_count} rows with None values in {extra_info_cols}: {dropped_count/initial_len*100:.2f}%" 131 | ) 132 | 133 | # save to parquet 134 | print(f"Saving to {output_path}") 135 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 136 | train_df.to_parquet(output_path) 137 | 138 | 139 | if __name__ == "__main__": 140 | preprocess() 141 | -------------------------------------------------------------------------------- /src/ares/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobphillips99/ares/21850d69242f9106c186338b5883667562f4442f/src/ares/utils/__init__.py --------------------------------------------------------------------------------
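Neither train.py nor the Rollout reconstruction it performs is shown above, so the block below is only a sketch of the consumer side of the parquet artifact that preprocess.py writes: one row per rollout, ids cast to strings, and each extra-info column holding a JSON string (or None) as produced by setup_extra_info_col. The file path and the "grounding_string" column are hypothetical examples taken from the preprocess.py docstring, not fixed names.

```python
# Sketch of a consumer for the parquet artifact written by preprocess.py.
# Assumes extra-info columns hold JSON strings or None; not the repository's actual train.py.
import json

import pandas as pd
from torch.utils.data import DataLoader, Dataset


class PreprocessedRolloutDataset(Dataset):
    def __init__(self, parquet_path: str, extra_info_cols: list[str]):
        self.df = pd.read_parquet(parquet_path)
        self.extra_info_cols = extra_info_cols

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int) -> dict:
        row = self.df.iloc[idx]
        # decode the preloaded annotation JSON; missing annotations stay None
        annotations = {
            col: json.loads(row[col]) if isinstance(row[col], str) else None
            for col in self.extra_info_cols
        }
        return {"rollout": row.to_dict(), "annotations": annotations}


# usage: items are plain dicts, so a passthrough collate_fn keeps batches as lists
dataset = PreprocessedRolloutDataset("data/processed.parquet", ["grounding_string"])
loader = DataLoader(dataset, batch_size=8, collate_fn=lambda items: items)
```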