├── .flake8 ├── .github └── workflows │ ├── poetry-lint.yml │ ├── poetry-tests.yml │ └── publish.yml ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── Dockerfile ├── Dockerfile.prod ├── LICENSE ├── Makefile ├── README.md ├── cloudbuild.yaml ├── deploy └── base │ ├── cert.yaml │ ├── deployment_api.yaml │ ├── ingress.yaml │ ├── kustomization.yaml │ ├── pg.yaml │ └── service.yaml ├── logging.conf ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── scripts └── lint.py ├── taskara ├── __init__.py ├── agent.py ├── auth │ ├── default.py │ ├── key.py │ ├── provider.py │ └── transport.py ├── benchmark.py ├── cli │ └── main.py ├── config.py ├── db │ ├── conn.py │ ├── models.py │ └── redis_connection.py ├── env.py ├── flag.py ├── img.py ├── metrics.py ├── review.py ├── runtime │ ├── base.py │ ├── docker.py │ ├── kube.py │ ├── load.py │ └── process.py ├── server │ ├── app.py │ ├── models.py │ └── router │ │ ├── benchmarks.py │ │ └── tasks.py ├── task.py └── util.py └── tests ├── benchmarks └── airbnb.yaml ├── test_bench.py ├── test_eval.py ├── test_runtime.py ├── test_task.py └── test_tpl.py /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203, E266, E501, W503 4 | max-complexity = 18 5 | select = B,C,E,F,W,T4,B9 6 | exclude = 7 | .git, 8 | __pycache__, 9 | build, 10 | dist, 11 | .venv, 12 | .tox, 13 | .mypy_cache, 14 | .pytest_cache, 15 | .vscode, 16 | -------------------------------------------------------------------------------- /.github/workflows/poetry-lint.yml: -------------------------------------------------------------------------------- 1 | name: Poetry Lint 2 | 3 | on: 4 | push: 5 | branches: [ '**' ] 6 | pull_request: 7 | branches: [ '**' ] 8 | 9 | jobs: 10 | lint: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.10' 21 | 
22 | - name: Install Poetry 23 | uses: snok/install-poetry@v1 24 | 25 | - name: Install dependencies 26 | run: | 27 | poetry install 28 | 29 | - name: Run lint 30 | run: | 31 | poetry run lint 32 | 33 | -------------------------------------------------------------------------------- /.github/workflows/poetry-tests.yml: -------------------------------------------------------------------------------- 1 | name: Poetry Tests 2 | 3 | on: 4 | push: 5 | branches: ["**"] 6 | pull_request: 7 | branches: ["**"] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Install Poetry 22 | uses: snok/install-poetry@v1 23 | 24 | - name: Install dependencies 25 | run: | 26 | poetry install 27 | 28 | - name: Run tests 29 | run: poetry run pytest --ignore=tests/test_runtime.py # TODO: this doesn't run well in CI 30 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish Python package to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | 8 | jobs: 9 | build-and-publish: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - name: Set up Python 14 | uses: actions/setup-python@v2 15 | with: 16 | python-version: '3.10' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install setuptools wheel twine poetry 21 | - name: Build and publish to PyPI 22 | env: 23 | PYPI_USERNAME: __token__ 24 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 25 | run: | 26 | poetry build 27 | poetry publish -u $PYPI_USERNAME -p $PYPI_PASSWORD 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | .envrc 131 | .vscode/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | .data 164 | data 165 | .agentsea 166 | scratch.ipynb 167 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. 6 | 7 | ## Our Standards 8 | 9 | Examples of behavior that contributes to creating a positive environment include: 10 | 11 | - Using welcoming and inclusive language 12 | - Being respectful of differing viewpoints and experiences 13 | - Gracefully accepting constructive criticism 14 | - Focusing on what is best for the community 15 | - Showing empathy towards other community members 16 | 17 | Examples of unacceptable behavior by participants include: 18 | 19 | - The use of sexualized language or imagery and unwelcome sexual attention or advances 20 | - Trolling, insulting/derogatory comments, and personal or political attacks 21 | - Public or private harassment 22 | - Publishing others' private information, such as a physical or email address, without explicit permission 23 | - Other conduct which could reasonably be considered inappropriate in a professional setting 24 | 25 | ## Our Responsibilities 26 | 27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior. 
28 | 29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned with this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful. 30 | 31 | ## Scope 32 | 33 | This Code of Conduct applies within all project spaces, including GitHub, and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. 34 | 35 | ## Enforcement 36 | 37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at github@kentauros.ai. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality regarding the reporter of an incident. Further details of specific enforcement policies may be posted separately. 38 | 39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership. 40 | 41 | ## Attribution 42 | 43 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 44 | 45 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). 46 | 47 | For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq. 
Translations are available at https://www.contributor-covenant.org/translations. 48 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | First off, thank you for considering contributing to this project. It's people like you that make it such a great tool. 4 | 5 | ## Code of Conduct 6 | 7 | This project adheres to a Code of Conduct that we expect project participants to adhere to. Please read [the full text](CODE_OF_CONDUCT.md) so that you can understand what actions will and will not be tolerated. 8 | 9 | ## What we are looking for 10 | 11 | This is an open-source project, and we welcome contributions of all kinds: new features, bug fixes, documentation, examples, or enhancements to existing features. We are always thrilled to receive contributions from the community. 12 | 13 | ## How to contribute 14 | 15 | If you've never contributed to an open-source project before, here are a few steps to get you started: 16 | 17 | ### Reporting Issues 18 | 19 | Before submitting a bug report or feature request, check to make sure it hasn't already been submitted. You can search through existing issues and pull requests to see if someone has reported one similar to yours. 20 | 21 | When you are creating a bug report, please include as much detail as possible. 22 | 23 | ### Pull Requests 24 | 25 | - Fork the repository and create your branch from `main`. 26 | - If you've added code that should be tested, add tests. 27 | - If you've changed APIs, update the documentation. 28 | - Ensure the test suite passes. 29 | - Make sure your code lints. 30 | - Issue that pull request! 31 | 32 | ### Getting started 33 | 34 | For something that is bigger than a one or two-line fix: 35 | 36 | 1. Create your own fork of the code. 37 | 2. Do the changes in your fork. 38 | 3. 
If you like the change and think the project could use it: 39 | - Be sure you have followed the code style for the project. 40 | - Note the Code of Conduct. 41 | - Send a pull request. 42 | 43 | ## How to report a bug 44 | 45 | If you find a security vulnerability, do NOT open an issue. Email github@kentauros.ai instead. 46 | 47 | In order to help us understand and resolve your issue quickly, please include as much information as possible, including: 48 | 49 | - A quick summary and/or background 50 | - Steps to reproduce 51 | - Be specific! 52 | - Give a sample code if you can. 53 | - What you expected would happen 54 | - What actually happens 55 | - Notes (possibly including why you think this might be happening or stuff you tried that didn't work) 56 | 57 | People *love* thorough bug reports. I'm not even kidding. 58 | 59 | ## How to suggest a feature or enhancement 60 | 61 | If you find yourself wishing for a feature that doesn't exist in the project, you are probably not alone. There are bound to be others out there with similar needs. Open an issue on our issues list on GitHub, which describes the feature you would like to see, why you need it, and how it should work. 62 | 63 | ## Code review process 64 | 65 | The core team looks at Pull Requests on a regular basis in a bi-weekly triage meeting. After feedback has been given, we expect responses within two weeks. After two weeks, we may close the pull request if it isn't showing any activity. 66 | 67 | ## Community 68 | 69 | Discussions about the project take place in this repository's Issues and Pull Requests sections. Anybody is welcome to join these conversations. 70 | 71 | Wherever possible, we use GitHub to discuss changes and keep the decision-making process open. 72 | 73 | ## Thank you! 74 | 75 | Thank you for contributing! 
76 | 77 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10-slim-buster 2 | 3 | RUN apt-get update && apt-get install -y openssh-client ntp gcc python3-dev 4 | 5 | RUN pip install poetry 6 | 7 | COPY . /app 8 | WORKDIR /app 9 | 10 | RUN poetry install 11 | 12 | EXPOSE 9070 13 | CMD ["poetry", "run", "python", "-m", "taskara.server.app"] -------------------------------------------------------------------------------- /Dockerfile.prod: -------------------------------------------------------------------------------- 1 | FROM thehale/python-poetry:1.8.2-py3.10-slim 2 | 3 | COPY . /app 4 | WORKDIR /app 5 | 6 | RUN apt-get update && apt-get install -y openssh-client ntp 7 | RUN poetry install 8 | 9 | # Keep in sync with the gunicorn --bind port below and the deployment's containerPort. 10 | EXPOSE 8080 11 | 12 | CMD ["poetry", "run", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "taskara.server.app:app", "--workers=4", "--bind", "0.0.0.0:8080", "--log-level", "debug", "--log-config", "logging.conf", "--timeout", "240"] 13 | 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Kentauros AI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test 2 | test: 3 | rm -rf .agentsea 4 | poetry run pytest -v -s -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 |

4 | 7 | 8 |

Taskara

9 | 10 |

11 | Task management for AI agents 12 |
13 | Explore the docs » 14 |
15 |
16 | View Demo 17 | · 18 | Report Bug 19 | · 20 | Request Feature 21 |

22 |
23 |

24 | 25 | ## Installation 26 | 27 | ```sh 28 | pip install taskara 29 | ``` 30 | 31 | ## Usage 32 | 33 | Create a task 34 | 35 | ```python 36 | from taskara import Task 37 | 38 | task = Task( 39 | description="Search for the most common varieties of french ducks", 40 | owner_id="delores@agentsea.ai" 41 | ) 42 | ``` 43 | 44 | Assign the task to an agent 45 | 46 | ```python 47 | task.assigned_to = "roko@agentsea.ai" 48 | ``` 49 | 50 | Post a message to the task thread 51 | 52 | ```python 53 | task.post_message("assistant", "Getting started working on this") 54 | task.status = "in progress" 55 | ``` 56 | 57 | Create a custom thread for the task 58 | 59 | ```python 60 | task.create_thread("debug") 61 | task.post_message("assistant", "I'll post debug messages to this thread", thread="debug") 62 | task.post_message("assistant", 'My current screenshot', images=["b64img"], thread="debug") 63 | ``` 64 | 65 | Store prompts used to accomplish the task 66 | 67 | ```python 68 | from mllm import RoleThread, RoleMessage 69 | 70 | thread = RoleThread() 71 | thread.post(role="system", msg="I am a helpful assistant") 72 | 73 | response = RoleMessage( 74 | role="assistant", 75 | text="How can I help?" 76 | ) 77 | task.store_prompt(thread, response, namespace="actions") 78 | ``` 79 | 80 | Store the result 81 | 82 | ```python 83 | task.output = "The most common type of french duck is the Rouen" 84 | task.status = "success" 85 | ``` 86 | 87 | Save the task 88 | 89 | ```python 90 | task.save() 91 | ``` 92 | 93 | ## Tracker 94 | 95 | Taskara comes with a task tracker server which can be run on docker or kubernetes. 
96 | 97 | Install surfkit to create a tracker 98 | 99 | ``` 100 | pip install surfkit 101 | ``` 102 | 103 | Create a tracker 104 | 105 | ``` 106 | surfkit create tracker 107 | ``` 108 | 109 | List trackers 110 | 111 | ``` 112 | surfkit list trackers 113 | ``` 114 | 115 | Get tracker logs 116 | 117 | ``` 118 | surfkit logs tracker 119 | ``` 120 | 121 | Create a task 122 | 123 | ``` 124 | surfkit create task --description "Search for french ducks" 125 | ``` 126 | 127 | List tasks 128 | 129 | ``` 130 | surfkit list tasks 131 | ``` 132 | 133 | Get a task 134 | 135 | ``` 136 | surfkit get task 137 | ``` 138 | 139 | ## Integrations 140 | 141 | Taskara is integrated with: 142 | 143 | - [Surfkit](https://github.com/agentsea/surfkit) A platform for AI agents 144 | - [MLLM](https://github.com/agentsea/mllm) A prompt management, routing, and schema validation library for multimodal LLMs 145 | - [Skillpacks](https://github.com/agentsea/skillpacks) A library to fine-tune AI agents on tasks. 146 | - [Threadmem](https://github.com/agentsea/threadmem) A thread management library for AI agents 147 | 148 | ## Community 149 | 150 | Come join us on [Discord](https://discord.gg/hhaq7XYPS6). 151 | 152 | ## Backends 153 | 154 | Thread and prompt storage can be backed by: 155 | 156 | - SQLite 157 | - PostgreSQL 158 | 159 | SQLite is used by default. To use PostgreSQL, set these environment variables: 160 | 161 | ```sh 162 | DB_TYPE=postgres 163 | DB_NAME=tasks 164 | DB_HOST=localhost 165 | DB_USER=postgres 166 | DB_PASS=abc123 167 | ``` 168 | 169 | By default, thread images are stored in the database. To store them in a GCS bucket instead: 170 | 171 | - Create a bucket with fine-grained permissions 172 | - Create a GCP service account JSON with permissions to write to the bucket 173 | 174 | ```sh 175 | export THREAD_STORAGE_SA_JSON='{ 176 | "type": "service_account", 177 | ... 
178 | }' 179 | export THREAD_STORAGE_BUCKET=my-bucket 180 | ``` 181 | -------------------------------------------------------------------------------- /cloudbuild.yaml: -------------------------------------------------------------------------------- 1 | steps: 2 | - name: "gcr.io/cloud-builders/docker" 3 | entrypoint: "bash" 4 | args: 5 | - "-c" 6 | - | 7 | docker buildx create --name mybuilder --use 8 | docker buildx build \ 9 | --platform linux/amd64,linux/arm64 \ 10 | -t us-central1-docker.pkg.dev/agentsea-dev/taskara/api:$SHORT_SHA . \ 11 | --push \ 12 | --cache-from=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache \ 13 | --cache-to=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache,mode=max 14 | if [ "$BRANCH_NAME" == "main" ]; then 15 | docker buildx build \ 16 | --platform linux/amd64,linux/arm64 \ 17 | -t us-central1-docker.pkg.dev/agentsea-dev/taskara/api:latest . \ 18 | --push \ 19 | --cache-from=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache \ 20 | --cache-to=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache,mode=max 21 | fi 22 | -------------------------------------------------------------------------------- /deploy/base/cert.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: networking.gke.io/v1 2 | kind: ManagedCertificate 3 | metadata: 4 | name: tasks-cert 5 | spec: 6 | domains: 7 | - api.tasks.my.domain 8 | -------------------------------------------------------------------------------- /deploy/base/deployment_api.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: tasks-api 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | app: tasks-api 10 | template: 11 | metadata: 12 | labels: 13 | app: tasks-api 14 | spec: 15 | containers: 16 | - name: tasks-api 17 | image: 
us-central1-docker.pkg.dev/agentsea-dev/taskara/api:latest 18 | imagePullPolicy: Always 19 | ports: 20 | - containerPort: 8080 21 | livenessProbe: 22 | httpGet: 23 | path: / 24 | port: 8080 25 | initialDelaySeconds: 20 26 | periodSeconds: 25 27 | readinessProbe: 28 | httpGet: 29 | path: / 30 | port: 8080 31 | initialDelaySeconds: 10 32 | periodSeconds: 15 33 | env: 34 | - name: AGENTSEA_HUB_URL 35 | value: https://api.hub.dev.agentlabs.xyz 36 | - name: DB_USER 37 | value: postgres 38 | - name: DB_PASS 39 | value: "abc12345" 40 | - name: DB_HOST 41 | value: postgres-tasks-service 42 | - name: DB_NAME 43 | value: tasks 44 | - name: DB_TYPE 45 | value: postgres 46 | -------------------------------------------------------------------------------- /deploy/base/ingress.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: networking.k8s.io/v1 2 | kind: Ingress 3 | metadata: 4 | name: tasks-ingress 5 | annotations: 6 | kubernetes.io/ingress.class: "gce" # GKE picks the ingress controller from this annotation key 7 | networking.gke.io/managed-certificates: tasks-cert 8 | kubernetes.io/ingress.global-static-ip-name: tasks-develop 9 | spec: 10 | rules: 11 | - host: api.tasks.my.domain 12 | http: 13 | paths: 14 | - path: / 15 | pathType: Prefix 16 | backend: 17 | service: 18 | name: tasks-api-service 19 | port: 20 | number: 80 21 | -------------------------------------------------------------------------------- /deploy/base/kustomization.yaml: -------------------------------------------------------------------------------- 1 | resources: 2 | - cert.yaml 3 | - deployment_api.yaml 4 | - ingress.yaml 5 | - pg.yaml 6 | - service.yaml 7 | -------------------------------------------------------------------------------- /deploy/base/pg.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: apps/v1 2 | kind: Deployment 3 | metadata: 4 | name: postgres-tasks 5 | spec: 6 | replicas: 1 7 | selector: 8 | matchLabels: 9 | app: postgres-tasks 10 | template: 11 | 
metadata: 12 | labels: 13 | app: postgres-tasks 14 | spec: 15 | containers: 16 | - name: postgres-tasks 17 | image: postgres:16.2 18 | ports: 19 | - containerPort: 5432 20 | env: 21 | - name: POSTGRES_USER 22 | value: "postgres" 23 | - name: POSTGRES_PASSWORD 24 | value: "abc12345" 25 | - name: POSTGRES_DB 26 | value: "tasks" 27 | - name: PGDATA 28 | value: "/var/lib/postgresql/data/pgdata" 29 | volumeMounts: 30 | - mountPath: /var/lib/postgresql/data 31 | name: postgres-storage 32 | volumes: 33 | - name: postgres-storage 34 | persistentVolumeClaim: 35 | claimName: postgres-tasks-pvc # must match the PersistentVolumeClaim defined later in this file 36 | --- 37 | apiVersion: v1 38 | kind: Service 39 | metadata: 40 | name: postgres-tasks-service 41 | spec: 42 | type: ClusterIP 43 | selector: 44 | app: postgres-tasks 45 | ports: 46 | - port: 5432 47 | targetPort: 5432 48 | --- 49 | apiVersion: v1 50 | kind: PersistentVolumeClaim 51 | metadata: 52 | name: postgres-tasks-pvc 53 | spec: 54 | accessModes: 55 | - ReadWriteOnce 56 | resources: 57 | requests: 58 | storage: 1Gi 59 | -------------------------------------------------------------------------------- /deploy/base/service.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: Service 3 | metadata: 4 | name: tasks-api-service 5 | annotations: 6 | beta.cloud.google.com/backend-config: '{"default": "tasks-api-config"}' 7 | cloud.google.com/neg: '{"ingress": true}' 8 | spec: 9 | type: NodePort 10 | selector: 11 | app: tasks-api 12 | ports: 13 | - protocol: TCP 14 | port: 80 15 | targetPort: 8080 16 | --- 17 | apiVersion: cloud.google.com/v1 18 | kind: BackendConfig 19 | metadata: 20 | name: tasks-api-config 21 | spec: 22 | timeoutSec: 21600 23 | healthCheck: 24 | checkIntervalSec: 30 25 | timeoutSec: 5 26 | healthyThreshold: 1 27 | unhealthyThreshold: 2 28 | requestPath: / 29 | port: 8080 30 | -------------------------------------------------------------------------------- /logging.conf: 
-------------------------------------------------------------------------------- 1 | [loggers] 2 | keys=root,uvicorn,uvicorn.error,uvicorn.access,uvicorn.asgi 3 | 4 | [handlers] 5 | keys=console 6 | 7 | [formatters] 8 | keys=generic 9 | 10 | [logger_root] 11 | level=DEBUG 12 | handlers=console 13 | 14 | [logger_uvicorn] 15 | level=DEBUG 16 | handlers=console 17 | qualname=uvicorn 18 | 19 | [logger_uvicorn.error] 20 | level=DEBUG 21 | handlers=console 22 | qualname=uvicorn.error 23 | propagate=0 24 | 25 | [logger_uvicorn.asgi] 26 | level=DEBUG 27 | handlers=console 28 | qualname=uvicorn.asgi 29 | propagate=0 30 | 31 | [logger_uvicorn.access] 32 | level=DEBUG 33 | handlers=console 34 | qualname=uvicorn.access 35 | propagate=0 36 | 37 | [handler_console] 38 | class=StreamHandler 39 | level=DEBUG 40 | formatter=generic 41 | 42 | [formatter_generic] 43 | format=%(asctime)s [%(process)d] [%(levelname)s] %(message)s 44 | datefmt=%Y-%m-%d %H:%M:%S 45 | class=logging.Formatter 46 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "taskara" 3 | version = "0.1.246" 4 | description = "Task management for AI agents" 5 | authors = ["Patrick Barker "] 6 | license = "MIT" 7 | readme = "README.md" 8 | 9 | [tool.poetry.dependencies] 10 | python = "^3.10" 11 | sqlalchemy = "^2.0.29" 12 | pydantic = "^2.6.4" 13 | docker = {version = "^7.0.0", optional = true} 14 | kubernetes = {version = "^29.0.0", optional = true} 15 | google-auth = {version = "^2.29.0", optional = true} 16 | google-cloud-container = {version = "^2.45.0", optional = true} 17 | namesgenerator = "^0.3" 18 | typer = {version = "^0.12.3", optional = true} 19 | tabulate = {version = "^0.9.0", optional = true} 20 | shortuuid = "^1.0.13" 21 | tqdm = "^4.66.4" 22 | cryptography = "^43.0.1" 23 | redis = "^5.2.0" 24 | agentcore = "^0.1.3" 25 | skillpacks = "^0.1.128" 26 
| openmeter = "^1.0.0b188" 27 | 28 | [tool.poetry.group.dev.dependencies] 29 | pytest = "^8.1.1" 30 | flake8 = "^7.0.0" 31 | black = "^24.2.0" 32 | pytest-env = "^1.1.3" 33 | ipykernel = "^6.29.4" 34 | ruff = "^0.6.5" 35 | 36 | [tool.pyright] 37 | reportUnknownParameterType = false 38 | reportMissingTypeArgument = false 39 | reportUnknownMemberType = false 40 | reportUnknownVariableType = false 41 | reportUnknownArgumentType = false 42 | 43 | 44 | [tool.poetry.extras] 45 | runtime = ["kubernetes", "docker", "google-auth", "google-cloud-container"] 46 | cli = ["typer", "tabulate"] 47 | all = ["kubernetes", "docker", "google-auth", "google-cloud-container", "typer", "tabulate"] 48 | 49 | [build-system] 50 | requires = ["poetry-core"] 51 | build-backend = "poetry.core.masonry.api" 52 | 53 | [tool.poetry.scripts] 54 | lint = "scripts.lint:main" 55 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | pythonpath = . 
3 | 4 | env = 5 | D:DB_TYPE=sqlite 6 | D:AGENTSEA_DB_TEST=true 7 | D:AGENTSEA_HOME=./.agentsea 8 | D:AGENTSEA_DB_DIR=./.agentsea/data/test 9 | -------------------------------------------------------------------------------- /scripts/lint.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | 4 | def main(): 5 | subprocess.run(["black", "."]) 6 | subprocess.run(["flake8", "."]) 7 | -------------------------------------------------------------------------------- /taskara/__init__.py: -------------------------------------------------------------------------------- 1 | from .benchmark import ( 2 | Benchmark, 3 | Eval, 4 | TaskTemplate, 5 | V1Benchmark, 6 | V1Eval, 7 | V1TaskTemplate, 8 | ) 9 | from .task import Task, TaskClient, TaskStatus, V1Task, V1Tasks, ReviewRequirement 10 | 11 | __all__ = [ 12 | "Task", 13 | "V1Task", 14 | "V1Tasks", 15 | "TaskStatus", 16 | "Benchmark", 17 | "V1Benchmark", 18 | "TaskTemplate", 19 | "V1TaskTemplate", 20 | "Eval", 21 | "V1Eval", 22 | "TaskClient", 23 | "ReviewRequirement", 24 | ] 25 | -------------------------------------------------------------------------------- /taskara/agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import List, TypeVar, Type, Generic 3 | 4 | from pydantic import BaseModel 5 | 6 | from devicebay import Device 7 | from taskara import Task 8 | 9 | C = TypeVar("C", bound="BaseModel") 10 | T = TypeVar("T", bound="TaskAgent") 11 | 12 | 13 | class TaskAgent(Generic[C, T], ABC): 14 | """An agent that works on tasks""" 15 | 16 | @abstractmethod 17 | def solve_task( 18 | self, 19 | task: Task, 20 | device: Device, 21 | max_steps: int = 30, 22 | ) -> Task: 23 | """Solve a task on a device 24 | 25 | Args: 26 | task (Task): The task 27 | device (Device): Device to perform the task on. 28 | max_steps (int, optional): Max steps allowed. Defaults to 30. 
29 | 30 | Returns: 31 | Task: A task 32 | """ 33 | pass 34 | 35 | @classmethod 36 | @abstractmethod 37 | def supported_devices(cls) -> List[Type[Device]]: 38 | """Devices this agent supports 39 | 40 | Returns: 41 | List[Type[Device]]: A list of supported devices 42 | """ 43 | pass 44 | 45 | @classmethod 46 | @abstractmethod 47 | def config_type(cls) -> Type[C]: 48 | """Type to configure the agent 49 | 50 | Returns: 51 | Type[C]: A configuration type 52 | """ 53 | pass 54 | 55 | @classmethod 56 | @abstractmethod 57 | def from_config(cls, config: C) -> T: 58 | """Create an agent from a config 59 | 60 | Args: 61 | config (C): Config to create the agent from 62 | 63 | Returns: 64 | T: The Agent 65 | """ 66 | pass 67 | 68 | @classmethod 69 | @abstractmethod 70 | def default(cls) -> T: 71 | """Create a default agent with no params 72 | 73 | Returns: 74 | T: The Agent 75 | """ 76 | pass 77 | 78 | @classmethod 79 | def init(cls) -> None: 80 | """Initialize the Agent type""" 81 | pass 82 | -------------------------------------------------------------------------------- /taskara/auth/default.py: -------------------------------------------------------------------------------- 1 | COMMON_USER = "common" 2 | -------------------------------------------------------------------------------- /taskara/auth/key.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from typing import Optional 4 | 5 | import requests 6 | from agentcore.models import V1UserProfile 7 | from requests.exceptions import RequestException 8 | from threadmem.db.conn import WithDB 9 | 10 | from taskara.config import AGENTSEA_AUTH_URL 11 | 12 | 13 | class KeyProvider(ABC): 14 | """API key provider""" 15 | 16 | @abstractmethod 17 | def create_key(self) -> str: 18 | pass 19 | 20 | @abstractmethod 21 | def is_key(self, token: str) -> bool: 22 | pass 23 | 24 | @abstractmethod 25 | def validate(self, token: str) -> V1UserProfile: 
class MockProvider(KeyProvider):
    """Key provider for tests: hands out and accepts a single fixed key."""

    _key = "k.mock"

    def create_key(self) -> str:
        """Return the fixed mock key."""
        return self._key

    def is_key(self, token: str) -> bool:
        """Report whether ``token`` carries the ``k.`` API-key prefix."""
        return token.startswith("k.")

    def validate(self, token: str) -> V1UserProfile:
        """Return the canned user profile for the mock key.

        Args:
            token (str): The token to check.

        Returns:
            V1UserProfile: The fixed test profile.

        Raises:
            ValueError: If ``token`` is not the mock key.
        """
        # Guard clause: anything other than the exact mock key is rejected.
        if token != self._key:
            raise ValueError("Invalid token")

        return V1UserProfile(
            email="tom@myspace.com",
            display_name="tom",
            picture="https://i.insider.com/4efd9b8b69bedd682c000022?width=750&format=jpeg&auto=webp",
        )
class HubAuthProvider(AuthProvider):
    """Authenticate users against the AgentSea Hub.

    API keys (as recognised by the key provider) are validated through the
    key provider; any other token is sent to the hub's ``/v1/users/me``
    endpoint as a bearer token.
    """

    _key_provider: KeyProvider

    # Seconds to wait for the hub to respond. Without a timeout `requests`
    # waits indefinitely, which would hang every authenticated request if the
    # hub became unreachable.
    _request_timeout: float = 30.0

    def __init__(self, key_provider: Optional[KeyProvider] = None) -> None:
        """Create the provider.

        Args:
            key_provider (Optional[KeyProvider]): Provider used for API keys.
                Defaults to ``default_key_provider()``.

        Raises:
            ValueError: If ``$AGENTSEA_AUTH_URL`` is not set.
        """
        if not key_provider:
            key_provider = default_key_provider()
        self.hub_url = os.environ.get("AGENTSEA_AUTH_URL")
        if not self.hub_url:
            raise ValueError(
                "$AGENTSEA_AUTH_URL must be set to use the Hub key provider"
            )

        self._key_provider = key_provider

    def key_provider(self) -> KeyProvider:
        """Return the key provider used for API-key tokens."""
        return self._key_provider

    def get_user_auth(self, token: str) -> V1UserProfile:
        """Resolve a token to a user profile.

        Args:
            token (str): An API key or a bearer token.

        Returns:
            V1UserProfile: The authenticated user. For bearer tokens the
                profile's ``token`` attribute is set to ``token``.

        Raises:
            Exception: If the token cannot be validated for any reason; the
                underlying error is attached as ``__cause__``.
        """
        try:
            if self._key_provider.is_key(token):
                user = self._key_provider.validate(token)
                logger.debug(f"found user: {user}")
                return user

            headers = {
                "Authorization": f"Bearer {token}",
                "User-Agent": "My User Agent 1.0",
            }
            auth_url = f"{self.hub_url}/v1/users/me"
            logger.debug(f"authorizing token with: {auth_url}")
            response = requests.get(
                auth_url, headers=headers, timeout=self._request_timeout
            )
            response.raise_for_status()

            user_schema = V1UserProfile(**response.json())
            user_schema.token = token
            return user_schema

        except Exception as e:
            # Log through the module logger (not the root logger) and keep the
            # original error chained so the real cause stays debuggable.
            logger.error(f"Problem fetching user auth {e}")
            raise Exception(
                "ID token was unauthorized, please log in",
            ) from e
def __init__(
    self,
    description: Optional[str] = None,
    max_steps: int = 30,
    owner_id: Optional[str] = None,
    device: Optional[V1Device] = None,
    device_type: Optional[V1DeviceType] = None,
    expect: Optional[Type[BaseModel]] = None,
    parameters: Optional[Dict[str, Any]] = None,
    labels: Optional[Dict[str, str]] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """Create a task template with a fresh id and creation timestamp.

    Args:
        description (Optional[str]): Description of the task.
        max_steps (int): Max steps allowed. Defaults to 30.
        owner_id (Optional[str]): Owner of the template.
        device (Optional[V1Device]): Device for the task.
        device_type (Optional[V1DeviceType]): Device type for the task.
        expect (Optional[Type[BaseModel]]): Pydantic model class; its JSON
            schema is stored as the expected-output schema.
        parameters (Optional[Dict[str, Any]]): Template parameters. Defaults
            to a new empty dict.
        labels (Optional[Dict[str, str]]): Labels. Defaults to a new empty
            dict.
        tags (Optional[List[str]]): Tags. Defaults to a new empty list.
    """
    self._id = shortuuid.uuid()
    self._description = description
    self._max_steps = max_steps
    self._owner_id = owner_id
    self._device = device
    self._device_type = device_type
    self._expect_schema = expect.model_json_schema() if expect else None
    # Build new containers per instance when none are supplied. A literal
    # `{}` / `[]` default is created once and shared by every template, so an
    # in-place change to one template's labels (e.g. stamping a "benchmark"
    # label) would show up on all others. Containers passed in explicitly are
    # stored by reference, as before.
    self._parameters = parameters if parameters is not None else {}
    self._labels = labels if labels is not None else {}
    self._tags = tags if tags is not None else []
    self._created = time.time()
device_type=self.device_type, 105 | assigned_to=assigned_to, 106 | assigned_type=assigned_type, 107 | remote=remote, 108 | labels=self.labels, 109 | tags=self.tags, 110 | owner_id=owner_id if owner_id else self.owner_id, 111 | ) 112 | task._expect_schema = self.expect_schema 113 | return task 114 | 115 | @classmethod 116 | def from_task(cls, task: Task) -> "TaskTemplate": 117 | tpl = cls( 118 | description=task.description, 119 | max_steps=task.max_steps, 120 | device=task.device, 121 | device_type=task.device_type, 122 | parameters=task.parameters if task.parameters else {}, 123 | labels=task.labels, 124 | tags=task.tags, 125 | owner_id=task.owner_id, 126 | ) 127 | tpl._expect_schema = task._expect_schema 128 | return tpl 129 | 130 | def to_record(self) -> TaskTemplateRecord: 131 | 132 | device = None 133 | if self._device: 134 | device = self._device.model_dump_json() 135 | 136 | device_type = None 137 | if self._device_type: 138 | device_type = self._device_type.model_dump_json() 139 | 140 | expect = None 141 | if self._expect_schema: 142 | expect = json.dumps(self._expect_schema) 143 | 144 | return TaskTemplateRecord( 145 | id=self._id, 146 | owner_id=self._owner_id, 147 | description=self._description, 148 | max_steps=self._max_steps, 149 | device=device, 150 | device_type=device_type, 151 | expect=expect, 152 | parameters=json.dumps(self._parameters), 153 | tags=json.dumps(self.tags), 154 | labels=json.dumps(self.labels), 155 | created=self._created, 156 | ) 157 | 158 | @classmethod 159 | def from_record(cls, record: TaskTemplateRecord) -> "TaskTemplate": 160 | 161 | parameters = json.loads(str(record.parameters)) 162 | 163 | device = None 164 | if record.device: # type: ignore 165 | device = V1Device.model_validate_json(str(record.device)) 166 | 167 | device_type = None 168 | if record.device_type: # type: ignore 169 | device_type = V1DeviceType.model_validate_json(str(record.device_type)) 170 | 171 | expect = None 172 | if record.expect: # type: ignore 173 
def __init__(
    self,
    name: str,
    description: str,
    tasks: List[TaskTemplate],
    owner_id: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    tags: Optional[List[str]] = None,
    public: bool = False,
):
    """Create a benchmark with a fresh id and creation timestamp.

    Args:
        name (str): Benchmark name.
        description (str): Benchmark description.
        tasks (List[TaskTemplate]): Task templates in the benchmark. Each
            template's labels are updated in place with a ``"benchmark"``
            entry set to the benchmark name.
        owner_id (Optional[str]): Owner of the benchmark.
        labels (Optional[Dict[str, str]]): Labels. Defaults to a new empty
            dict.
        tags (Optional[List[str]]): Tags. Defaults to a new empty list.
        public (bool): Whether the benchmark is public. Defaults to False.
    """
    self._tasks = tasks
    self._id = shortuuid.uuid()
    self._name = name
    self._description = description
    self._owner_id = owner_id
    # Fresh containers per instance: a literal `{}` / `[]` default would be
    # one object shared by every benchmark created without labels or tags.
    self._labels = labels if labels is not None else {}
    self._tags = tags if tags is not None else []
    self._public = public
    self._created = time.time()

    # Stamp every template with the benchmark it belongs to.
    for task in tasks:
        task.labels["benchmark"] = self.name
db_session) -> "Benchmark": 323 | # Retrieve task templates associated with the benchmark 324 | task_records = ( 325 | db_session.query(TaskTemplateRecord) 326 | .join( 327 | benchmark_task_association, 328 | TaskTemplateRecord.id == benchmark_task_association.c.task_template_id, 329 | ) 330 | .filter(benchmark_task_association.c.benchmark_id == record.id) 331 | .all() 332 | ) 333 | tasks = [TaskTemplate.from_record(task_record) for task_record in task_records] 334 | 335 | obj = cls.__new__(cls) 336 | obj._id = record.id 337 | obj._owner_id = record.owner_id 338 | obj._name = record.name 339 | obj._description = record.description 340 | obj._labels = json.loads(str(record.labels)) 341 | obj._tags = json.loads(str(record.tags)) 342 | obj._created = record.created 343 | obj._tasks = tasks 344 | obj._public = record.public 345 | return obj 346 | 347 | def to_v1(self) -> V1Benchmark: 348 | return V1Benchmark( 349 | id=self._id, 350 | name=self._name, 351 | description=self._description, 352 | tasks=[task.to_v1() for task in self._tasks], 353 | owner_id=self._owner_id, 354 | tags=self._tags, 355 | labels=self._labels, 356 | created=self._created, 357 | public=self._public, 358 | ) 359 | 360 | @classmethod 361 | def from_v1(cls, v1: V1Benchmark, owner_id: Optional[str] = None) -> "Benchmark": 362 | tasks = [ 363 | TaskTemplate.from_v1(task, owner_id=owner_id if owner_id else v1.owner_id) 364 | for task in v1.tasks 365 | ] 366 | for task in tasks: 367 | task.save() 368 | 369 | obj = cls.__new__(cls) 370 | owner_id = owner_id if owner_id else v1.owner_id 371 | if not owner_id: 372 | raise ValueError("Owner id is required in v1 or as parameter") 373 | 374 | obj._id = v1.id if v1.id else shortuuid.uuid() 375 | obj._owner_id = owner_id 376 | obj._name = v1.name 377 | obj._description = v1.description 378 | obj._tasks = tasks 379 | obj._labels = v1.labels 380 | obj._tags = v1.tags 381 | obj._created = v1.created 382 | obj._public = v1.public 383 | 384 | return obj 385 | 386 | 
def save(self) -> None:
    """Persist the benchmark, its task templates and their associations.

    Records are merged, so calling this again updates existing rows. An
    association row is only inserted when the benchmark/template pair is not
    linked yet, which makes repeated saves safe.
    """
    for db in self.get_db():
        # Save the benchmark record
        db.merge(self.to_record())
        db.commit()

        # Save the task records and associations
        for task in self._tasks:
            db.merge(task.to_record())
            db.commit()

            # The association table uses (benchmark_id, task_template_id) as
            # its primary key; inserting an existing pair again would raise
            # an integrity error on the second save.
            already_linked = db.execute(
                benchmark_task_association.select().where(
                    benchmark_task_association.c.benchmark_id == self._id,
                    benchmark_task_association.c.task_template_id == task.id,
                )
            ).first()
            if already_linked is None:
                db.execute(
                    benchmark_task_association.insert().values(
                        benchmark_id=self._id, task_template_id=task.id
                    )
                )
            db.commit()


def delete(self) -> None:
    """Delete the benchmark, its associations and its task templates."""
    for db in self.get_db():
        # Remove the association rows first: they hold foreign keys to both
        # the benchmark row and the template rows, so deleting those rows
        # while associations still exist fails on databases that enforce
        # foreign keys.
        db.execute(
            benchmark_task_association.delete().where(
                benchmark_task_association.c.benchmark_id == self._id
            )
        )
        db.commit()

        # Delete the benchmark record
        benchmark_record = db.query(BenchmarkRecord).filter_by(id=self._id).first()
        if benchmark_record:
            db.delete(benchmark_record)
            db.commit()

        # Delete the task template records
        for task in self._tasks:
            task_record = db.query(TaskTemplateRecord).filter_by(id=task.id).first()
            if task_record:
                db.delete(task_record)
                db.commit()
List[Task] = [] 454 | self._owner_id = owner_id 455 | self._assigned_to = assigned_to 456 | self._assigned_type = assigned_type 457 | 458 | for tpl in self._benchmark.tasks: 459 | task = tpl.to_task( 460 | assigned_to=assigned_to, 461 | assigned_type=assigned_type, 462 | remote=remote, 463 | owner_id=owner_id, 464 | ) 465 | task.labels["benchmark"] = self._benchmark.name 466 | self._tasks.append(task) 467 | 468 | @property 469 | def tasks(self) -> List[Task]: 470 | return self._tasks 471 | 472 | @property 473 | def benchmark(self) -> Benchmark: 474 | return self._benchmark 475 | 476 | @property 477 | def id(self) -> str: 478 | return self._id 479 | 480 | @property 481 | def owner_id(self) -> Optional[str]: 482 | return self._owner_id 483 | 484 | def to_record(self) -> EvalRecord: 485 | return EvalRecord( 486 | id=self._id, 487 | benchmark_id=self._benchmark.id, 488 | assigned_to=self._assigned_to, 489 | assigned_type=self._assigned_type, 490 | owner_id=self._owner_id, 491 | created=time.time(), 492 | ) 493 | 494 | @classmethod 495 | def from_record(cls, record: EvalRecord, db_session) -> "Eval": 496 | benchmark = Benchmark.from_record( 497 | db_session.query(BenchmarkRecord).filter_by(id=record.benchmark_id).first(), 498 | db_session, 499 | ) 500 | # Correctly extract task_ids from the association table 501 | task_associations = ( 502 | db_session.query(eval_task_association.c.task_id) 503 | .filter_by(eval_id=record.id) 504 | .all() 505 | ) 506 | task_ids = [task_id for (task_id,) in task_associations] 507 | tasks = [ 508 | Task.from_record(db_session.query(TaskRecord).filter_by(id=task_id).first()) 509 | for task_id in task_ids 510 | ] 511 | 512 | obj = cls.__new__(cls) 513 | obj._id = record.id 514 | obj._benchmark = benchmark 515 | obj._tasks = tasks 516 | obj._owner_id = record.owner_id 517 | obj._assigned_to = record.assigned_to 518 | obj._assigned_type = record.assigned_type 519 | 520 | return obj 521 | 522 | def to_v1(self) -> V1Eval: 523 | return V1Eval( 
def save(self) -> None:
    """Persist the evaluation, its tasks and their associations.

    Records are merged, so calling this again updates existing rows. An
    association row is only inserted when the eval/task pair is not linked
    yet, which makes repeated saves safe.
    """
    for db in self.get_db():
        # Save the evaluation record
        db.merge(self.to_record())
        db.commit()

        # Save the task records and associations
        for task in self._tasks:
            db.merge(task.to_record())
            db.commit()

            # The association table uses (eval_id, task_id) as its primary
            # key; inserting an existing pair again would raise an integrity
            # error on the second save.
            already_linked = db.execute(
                eval_task_association.select().where(
                    eval_task_association.c.eval_id == self._id,
                    eval_task_association.c.task_id == task.id,
                )
            ).first()
            if already_linked is None:
                db.execute(
                    eval_task_association.insert().values(
                        eval_id=self._id, task_id=task.id
                    )
                )
            db.commit()
# Resolve the installed version reported by the `version` command.
# NOTE(review): this looks up this project's own distribution, assumed to be
# published as "taskara" (the package directory name) - confirm against
# pyproject.toml. The previous lookup used "surfkit", a different package.
try:
    __version__ = pkgversion("taskara")
except PackageNotFoundError:
    # Not installed as a distribution (e.g. running from a source checkout).
    __version__ = "unknown"
the CLI") 56 | def version(): 57 | """Show the CLI version.""" 58 | typer.echo(__version__) 59 | 60 | 61 | # Root command callback 62 | @app.callback(invoke_without_command=True) 63 | def default(ctx: typer.Context): 64 | show_help(ctx, "root") 65 | 66 | 67 | # 'create' command group callback 68 | @create_group.callback(invoke_without_command=True) 69 | def create_default(ctx: typer.Context): 70 | show_help(ctx, "create") 71 | 72 | 73 | # 'list' command group callback 74 | @list_group.callback(invoke_without_command=True) 75 | def list_default(ctx: typer.Context): 76 | show_help(ctx, "list") 77 | 78 | 79 | # 'get' command group callback 80 | @get_group.callback(invoke_without_command=True) 81 | def get_default(ctx: typer.Context): 82 | show_help(ctx, "get") 83 | 84 | 85 | # 'delete' command group callback 86 | @delete_group.callback(invoke_without_command=True) 87 | def delete_default(ctx: typer.Context): 88 | show_help(ctx, "delete") 89 | 90 | 91 | # 'view' command group callback 92 | @view_group.callback(invoke_without_command=True) 93 | def view_default(ctx: typer.Context): 94 | show_help(ctx, "view") 95 | 96 | 97 | # 'clean' command group callback 98 | @clean_group.callback(invoke_without_command=True) 99 | def clean_default(ctx: typer.Context): 100 | show_help(ctx, "clean") 101 | 102 | 103 | # 'create' sub-commands 104 | @create_group.command("task") 105 | def create_task( 106 | description: str = typer.Option( 107 | ..., 108 | "--description", 109 | "-d", 110 | help="Description of the task. 
@dataclass
class GlobalConfig:
    """User-level settings stored in ``~/.agentsea/config.yaml``."""

    api_key: Optional[str] = None
    hub_address: str = AGENTSEA_HUB_URL

    @staticmethod
    def _config_path() -> str:
        """Return the config file path, creating ``~/.agentsea`` if needed."""
        config_dir = os.path.join(os.path.expanduser("~"), ".agentsea")
        os.makedirs(config_dir, exist_ok=True)
        return os.path.join(config_dir, "config.yaml")

    def write(self) -> None:
        """Write this config to ``~/.agentsea/config.yaml`` as YAML."""
        # The `with` block flushes and closes the file on exit.
        with open(self._config_path(), "w") as yaml_file:
            yaml.dump(self.__dict__, yaml_file)

    @classmethod
    def read(cls) -> "GlobalConfig":
        """Load the config from ``~/.agentsea/config.yaml``.

        Returns:
            GlobalConfig: The stored config, or a default config when the
                file is missing or empty.
        """
        path = cls._config_path()

        if not os.path.exists(path):
            return cls()

        with open(path, "r") as yaml_file:
            config = yaml.safe_load(yaml_file)

        # `safe_load` returns None for an empty file; `cls(**None)` would
        # raise TypeError, so fall back to the defaults instead.
        if not config:
            return cls()
        return cls(**config)
class WithDB:
    """Mixin that gives a class access to short-lived database sessions."""

    @staticmethod
    def get_db():
        """Yield one session from ``SessionLocal`` and close it afterwards.

        The generator yields exactly once; the session is closed when the
        generator is exhausted, closed, or garbage collected.

        Example:
            ```
            for session in self.get_db():
                session.add(foo)
            ```
        """
        session = SessionLocal()
        try:
            yield session
        finally:
            # Runs on normal exhaustion as well as on early exit from the loop.
            session.close()
class TagRecord(Base):
    """ORM row for a single tag string, stored in the ``tags`` table.

    Linked to tasks through ``task_tag_association`` (many-to-many).
    """

    __tablename__ = "tags"
    # Explicit index on the tag text, in addition to the unique constraint below.
    __table_args__ = (Index("idx_tags_tag", "tag"),)
    # Primary key; a short UUID is generated when no id is supplied.
    id = Column(String, primary_key=True, default=lambda: shortuuid.uuid())
    # The tag text itself; required and unique across the table.
    tag = Column(String, unique=True, nullable=False)
class ReviewRequirementRecord(Base):
    """ORM row for a review requirement, stored in ``review_requirements``.

    The list-valued fields (``users``, ``agents``, ``groups``, ``types``) are
    written by ``ReviewRequirement.to_record`` as JSON-encoded text.
    """

    __tablename__ = "review_requirements"
    # Requirements are looked up by the task they belong to.
    __table_args__ = (Index("idx_review_req_task_id", "task_id"),)
    id = Column(String, primary_key=True)
    # Id of the task this requirement applies to (plain string, no foreign key).
    task_id = Column(String, nullable=True)
    number_required = Column(Integer, nullable=False)
    users = Column(
        Text, nullable=True
    )  # We need to split these apart for better lookups
    agents = Column(Text, nullable=True)
    groups = Column(Text, nullable=True)
    types = Column(Text, nullable=True)
    # Creation time as a float timestamp; defaults to time.time() at insert.
    created = Column(Float, default=time.time)
    updated = Column(Float, nullable=True)
class BenchmarkRecord(Base):
    """ORM row for a benchmark, stored in the ``benchmarks`` table.

    A benchmark groups task templates through ``benchmark_task_association``.
    """

    __tablename__ = "benchmarks"

    id = Column(String, primary_key=True)
    owner_id = Column(String, nullable=True)
    # Benchmark name; unique and indexed.
    name = Column(String, unique=True, index=True)
    description = Column(String, nullable=False)
    public = Column(Boolean, default=False)
    # Stored as plain strings here.
    # NOTE(review): presumably serialized collections -- confirm against the writer.
    tags = Column(String, nullable=True)
    labels = Column(String, nullable=True)
    # Creation time as a float timestamp; defaults to time.time() at insert.
    created = Column(Float, default=time.time)

    # Many-to-many link to task templates; mirrored by
    # TaskTemplateRecord.benchmarks via back_populates.
    task_templates = relationship(
        "TaskTemplateRecord",
        secondary=benchmark_task_association,
        back_populates="benchmarks",
    )
class FlagRecord(Base):
    """ORM row for a flag raised for human review, stored in ``flags``.

    Rows are produced by ``Flag.to_record``: ``type`` holds the flag class
    name, ``flag`` the JSON dump of its V1 model, and ``result`` the JSON dump
    of the review result when one exists.
    """

    __tablename__ = "flags"

    id = Column(String, primary_key=True)
    # Name of the concrete Flag subclass that produced this row.
    type = Column(String)
    # JSON-encoded flag payload.
    flag = Column(Text)
    # JSON-encoded result; NULL until a result has been set.
    result = Column(Text, nullable=True)
    # Creation time as a float timestamp; defaults to time.time() at insert.
    created = Column(Float, default=time.time)
def get_redis_client():
    """Return a Redis client bound to the shared connection pool.

    Returns:
        A new ``redis.Redis`` instance using ``redis_pool``, or ``None`` when
        the pool has not been initialized (no Redis URL configured or
        ``init_redis_pool()`` not yet called).
    """
    pool = redis_pool
    if pool:
        return redis.Redis(connection_pool=pool)
    # TODO need to have proper mocking of redis in order to do this
    # raise ValueError("Redis connection pool is not initialized or redis connection details don't exist. Call init_redis_pool() first.")
    return None
WithDB 12 | 13 | FlagResult = TypeVar("FlagResult", bound="BaseModel") 14 | FlagModel = TypeVar("FlagModel", bound="BaseModel") 15 | FlagType = TypeVar("FlagType", bound="Flag") 16 | 17 | 18 | class Flag(Generic[FlagResult, FlagModel, FlagType], ABC, WithDB): 19 | """A flag for human review""" 20 | 21 | def __init__(self) -> None: 22 | self.id = shortuuid.uuid() 23 | self.result: Optional[FlagResult] = None 24 | self.created = time.time() 25 | self.result = None 26 | 27 | def set_result(self, result: FlagResult): 28 | self.result = result 29 | 30 | @classmethod 31 | @abstractmethod 32 | def result_type(cls) -> Type[FlagResult]: 33 | pass 34 | 35 | @classmethod 36 | @abstractmethod 37 | def v1_type(cls) -> Type[FlagModel]: 38 | pass 39 | 40 | @abstractmethod 41 | def to_v1(cls) -> FlagModel: 42 | pass 43 | 44 | @classmethod 45 | @abstractmethod 46 | def from_v1(cls, v1: FlagModel) -> FlagType: 47 | pass 48 | 49 | def to_v1flag(self) -> V1Flag: 50 | return V1Flag( 51 | type=self.__class__.__name__, 52 | id=self.id, 53 | flag=self.to_v1().model_dump(), 54 | result=self.result.model_dump() if self.result else None, 55 | created=self.created, 56 | ) 57 | 58 | def to_record(self) -> FlagRecord: 59 | return FlagRecord( 60 | id=self.id, 61 | type=self.__class__.__name__, 62 | flag=self.to_v1().model_dump_json(), 63 | result=self.result.model_dump_json() if self.result else None, 64 | created=self.created, 65 | ) 66 | 67 | @classmethod 68 | def from_record(cls, record: FlagRecord) -> FlagType: 69 | # Deserialize the flag JSON to a FlagModel (like V1BoundingBoxFlag) 70 | flag_model = cls.v1_type().model_validate_json(str(record.flag)) 71 | 72 | # Use the from_v1 method to create the specific FlagType (like BoundingBoxFlag) 73 | instance = cls.from_v1(flag_model) 74 | 75 | # Set additional fields from the record 76 | instance.id = record.id 77 | instance.result = ( # type: ignore 78 | cls.result_type().model_validate_json(str(record.result)) 79 | if record.result # type: 
class BoundingBoxFlag(Flag[V1BoundingBox, V1BoundingBoxFlag, "BoundingBoxFlag"]):
    """A flag asking a human to review a bounding box drawn on an image."""

    def __init__(
        self,
        img: str,
        target: str,
        bbox: V1BoundingBox,
        metadata: Optional[Dict[str, str]] = None,
    ):
        """Create a new bounding box flag.

        Args:
            img: The image the box refers to.
            target: Description of what the box is supposed to contain.
            bbox: The bounding box under review.
            metadata: Optional free-form string metadata kept on the instance.
        """
        super().__init__()
        self.img = img
        self.target = target
        self.bbox = bbox
        self.metadata = metadata

    @classmethod
    def result_type(cls) -> Type[V1BoundingBox]:
        """Return the model type used for this flag's review result."""
        return V1BoundingBox

    @classmethod
    def v1_type(cls) -> Type[V1BoundingBoxFlag]:
        """Return the V1 model type this flag serializes to."""
        return V1BoundingBoxFlag

    def to_v1(self) -> V1BoundingBoxFlag:
        """Convert this flag to its V1 model (img, target and bbox only)."""
        return V1BoundingBoxFlag(
            img=self.img,
            target=self.target,
            bbox=self.bbox,
        )

    @classmethod
    def from_v1(cls, v1: V1BoundingBoxFlag) -> "BoundingBoxFlag":
        """Build a flag from its V1 model.

        The instance is created without calling this class's ``__init__`` (so
        subclasses with other constructor signatures keep working), but the
        base ``Flag`` initializer is run so ``id``, ``created`` and ``result``
        always exist. Previously they were left unset and ``to_v1flag()`` /
        ``to_record()`` on the returned object raised ``AttributeError``.

        Args:
            v1: The V1 model to convert.

        Returns:
            A flag with a fresh id and creation time, no result, and
            ``metadata`` set to ``None``.
        """
        out = cls.__new__(cls)
        Flag.__init__(out)  # sets id, created and result
        out.img = v1.img
        out.target = v1.target
        out.bbox = v1.bbox
        out.metadata = None
        return out
def parse_image_data(image_data_str: str):
    """Split a base64 data URL into its MIME type and base64 payload.

    Args:
        image_data_str: A string of the form ``data:<mime>;base64,<payload>``.

    Returns:
        tuple[str, str]: ``(mime_type, base64_data)``. The payload is returned
        still base64-encoded.

    Raises:
        ValueError: If the string does not start with a base64 data URL.
    """
    # The groups must be named: they are read back by name below.
    data_url_pattern = re.compile(
        r"data:(?P<mime_type>[^;]+);base64,(?P<base64_data>.+)"
    )
    match = data_url_pattern.match(image_data_str)
    if not match:
        raise ValueError("Invalid image data format")
    mime_type = match.group("mime_type")
    base64_data = match.group("base64_data")
    return mime_type, base64_data
def convert_images(images: Sequence[str | Image.Image]) -> List[str]:
    """Normalize a mixed list of images into strings.

    Uploading is enabled when the environment variable named by
    ``STORAGE_SA_JSON_ENV`` is set to a non-empty value.

    Per item:
      * ``PIL.Image.Image`` -> base64 data URL (never uploaded).
      * ``"https://..."`` string -> returned unchanged.
      * ``"data:..."`` string -> uploaded and replaced by its public URL when
        uploading is enabled, otherwise returned unchanged.
      * any other string -> treated as a local file path, loaded and encoded
        as a PNG data URL; uploaded when uploading is enabled.

    Args:
        images: Images given as PIL objects, URLs, data URLs or file paths.

    Returns:
        List[str]: One string per input image, in input order.

    Raises:
        ValueError: If an item is neither a string nor a PIL image, or a data
            URL cannot be parsed.
    """
    upload_enabled = bool(os.getenv(STORAGE_SA_JSON_ENV))
    new_imgs: List[str] = []

    for img in images:
        if isinstance(img, Image.Image):
            new_imgs.append(image_to_b64(img))
            continue
        if not isinstance(img, str):
            # Same error in both modes; the no-upload branch used to fail with
            # AttributeError on `img.startswith`.
            raise ValueError("unknown image type")
        if img.startswith("https://"):
            new_imgs.append(img)
            continue

        if img.startswith("data:"):
            data_url = img
        else:
            # Image.open keeps the file handle open until the image is closed;
            # encode inside the context manager so the handle is released.
            with Image.open(img) as loaded_img:
                data_url = image_to_b64(loaded_img)

        if not upload_enabled:
            new_imgs.append(data_url)
            continue

        mime_type, base64_data = parse_image_data(data_url)
        image_data = base64.b64decode(base64_data)
        new_imgs.append(upload_image_to_gcs(image_data, mime_type))

    return new_imgs
async def process_image_async(img: str | Image.Image) -> str:
    """Convert one image for the upload-enabled path without blocking the loop.

    PIL images are returned as base64 data URLs and ``https://`` strings are
    returned as-is. Data URLs and local file paths are uploaded through
    ``upload_image_to_gcs`` on a worker thread, and the value it returns is
    passed back.

    Raises:
        ValueError: If ``img`` is neither a string nor a PIL image.
    """
    if isinstance(img, Image.Image):
        return image_to_b64(img)
    if not isinstance(img, str):
        raise ValueError("unknown image type")
    if img.startswith("https://"):
        return img

    if img.startswith("data:"):
        print("convert_images function, uploading img to gcs", flush=True)
        data_url = img
    else:
        print("convert_images function, uploading img to gcs in else", flush=True)
        # Opening the file is blocking I/O, so it runs on a worker thread.
        opened = await asyncio.to_thread(Image.open, img)
        data_url = image_to_b64(opened)

    mime, payload = parse_image_data(data_url)
    raw_bytes = base64.b64decode(payload)
    # The upload is blocking as well; keep it off the event loop.
    return await asyncio.to_thread(upload_image_to_gcs, raw_bytes, mime)
import time


class MetricsAggregator:
    """Collect named timings and counters in memory.

    Timers are keyed by name and may be nested: each ``stop_timer(key)``
    closes the most recent ``start_timer(key)`` that is still running. Only
    completed measurements appear in ``timings`` and in the statistics.

    Example:
        metrics = MetricsAggregator()

        # Timing a code section
        metrics.start_timer('process_data')
        time.sleep(0.5)
        metrics.stop_timer('process_data')

        # Counting occurrences
        metrics.increment_count('api_calls')
        metrics.increment_count('api_calls')
        metrics.increment_count('api_errors')

        # Getting aggregated results
        print(metrics.report())
    """

    def __init__(self):
        # key -> list of completed durations in seconds
        self.timings = {}
        # key -> accumulated count
        self.counts = {}
        # key -> stack of start timestamps for timers that are still running
        self._running = {}

    def start_timer(self, key):
        """Start (or nest) a timer under ``key``."""
        # perf_counter is monotonic, so durations are immune to clock changes.
        self._running.setdefault(key, []).append(time.perf_counter())

    def stop_timer(self, key):
        """Stop the most recently started timer for ``key`` and record it.

        Raises:
            KeyError: If no timer is currently running for ``key``.
        """
        pending = self._running.get(key)
        if not pending:
            raise KeyError(f"timer {key!r} is not running")
        elapsed = time.perf_counter() - pending.pop()
        self.timings.setdefault(key, []).append(elapsed)

    def increment_count(self, key, count=1):
        """Add ``count`` (default 1) to the counter named ``key``."""
        self.counts[key] = self.counts.get(key, 0) + count

    def get_timing_stats(self, key):
        """Return count/total/avg/min/max for ``key``.

        Returns:
            dict | None: The statistics, or ``None`` when no measurement has
            been completed for ``key``.
        """
        times = self.timings.get(key)
        if not times:
            return None
        total = sum(times)
        return {
            "count": len(times),
            "total": total,
            "avg": total / len(times),
            "min": min(times),
            "max": max(times),
        }

    def get_count(self, key):
        """Return the counter value for ``key`` (0 when never incremented)."""
        return self.counts.get(key, 0)

    def report(self):
        """Return all timing statistics and counters in one dict.

        Timing entries are stored under their key; counters are stored under
        ``"<key>_count"``.
        """
        summary = {key: self.get_timing_stats(key) for key in self.timings}
        for key, count in self.counts.items():
            summary[f"{key}_count"] = count
        return summary
def pending_reviewers(
    self, task_id: str, requirement_id: Optional[str] = None
) -> V1PendingReviewers:
    """Collect the distinct users and agents still expected to review a task.

    Args:
        task_id: Task whose pending reviewers are requested.
        requirement_id: When given (and non-empty), only reviewers attached to
            this review requirement are considered.

    Returns:
        V1PendingReviewers: The task id with de-duplicated user and agent ids.

    Raises:
        SystemError: If no database session could be obtained.
    """
    for db in self.get_db():
        lookup = db.query(PendingReviewersRecord).filter_by(task_id=task_id)
        if requirement_id:
            lookup = lookup.filter_by(requirement_id=requirement_id)
        rows = lookup.all()

        # A row carries a user id, an agent id, or both; empty ones are skipped.
        user_ids = {str(row.user_id) for row in rows if row.user_id}  # type: ignore
        agent_ids = {str(row.agent_id) for row in rows if row.agent_id}  # type: ignore

        return V1PendingReviewers(
            task_id=task_id,
            users=list(user_ids),
            agents=list(agent_ids),
        )

    raise SystemError("no session")
def ensure_pending_reviewer(
    self,
    task_id: str,
    user: Optional[str] = None,
    agent: Optional[str] = None,
    requirement_id: Optional[str] = None,
) -> None:
    """Record a pending reviewer for a task unless an equal record exists.

    Args:
        task_id: Task awaiting review.
        user: Reviewing user id, if the reviewer is a user.
        agent: Reviewing agent id, if the reviewer is an agent.
        requirement_id: Review requirement this reviewer belongs to, if any.

    Raises:
        ValueError: If neither ``user`` nor ``agent`` is provided.
    """
    if not (user or agent):
        raise ValueError("Either user or agent must be provided")

    # Only the supplied (non-empty) fields take part in the duplicate check.
    criteria = {"task_id": task_id}
    if user:
        criteria["user_id"] = user
    if agent:
        criteria["agent_id"] = agent
    if requirement_id:
        criteria["requirement_id"] = requirement_id

    for db in self.get_db():
        already_there = db.query(PendingReviewersRecord).filter_by(**criteria).first()
        if already_there:
            # Nothing to do: this reviewer is already pending for the task.
            return

        db.add(
            PendingReviewersRecord(
                id=shortuuid.uuid(),
                task_id=task_id,
                user_id=user,
                agent_id=agent,
                requirement_id=requirement_id,
            )
        )
        db.commit()
class ReviewRequirement(WithDB):
    """A review requirement for a task.

    Describes how many approvals a task needs and which users, agents, groups
    or reviewer types may provide them. Instances are persisted as
    ``ReviewRequirementRecord`` rows, with list fields stored as JSON text.
    """

    def __init__(
        self,
        task_id: Optional[str] = None,
        number_required: int = 2,
        users: Optional[List[str]] = None,
        agents: Optional[List[str]] = None,
        groups: Optional[List[str]] = None,
        types: Optional[List[str]] = None,
        created: Optional[float] = None,
        updated: Optional[float] = None,
    ) -> None:
        """Create a requirement with a freshly generated id.

        Args:
            task_id: Task this requirement applies to.
            number_required: Number of reviews required.
            users: Eligible user ids (defaults to an empty list).
            agents: Eligible agent ids (defaults to an empty list).
            groups: Eligible groups (defaults to an empty list).
            types: Eligible reviewer types (defaults to an empty list).
            created: Creation timestamp; defaults to the current time.
            updated: Last update timestamp, if any.
        """
        self.id = str(shortuuid.uuid())
        self.task_id = task_id
        self.number_required = number_required
        self.users = users or []
        self.agents = agents or []
        self.groups = groups or []
        self.types = types or []
        self.created = created or time.time()
        self.updated = updated

    def to_v1(self) -> V1ReviewRequirement:
        """Convert to the V1 API model (id, task, users, agents, groups, count)."""
        return V1ReviewRequirement(
            id=self.id,
            task_id=self.task_id,
            users=self.users,
            agents=self.agents,
            groups=self.groups,
            number_required=self.number_required,
        )

    @classmethod
    def from_v1(cls, v1: V1ReviewRequirement) -> "ReviewRequirement":
        """Build a requirement from its V1 API model.

        Every attribute read by ``to_record`` is initialized here, so the
        result can be saved directly. Previously ``types``, ``created`` and
        ``updated`` were left unset and ``save()`` raised ``AttributeError``.
        """
        out = cls.__new__(cls)
        out.id = v1.id
        out.task_id = v1.task_id
        out.number_required = v1.number_required
        out.users = v1.users or []
        out.agents = v1.agents or []
        out.groups = v1.groups or []
        # NOTE(review): to_v1 does not pass `types`, so the V1 model may not
        # carry it -- read it defensively. Confirm against V1ReviewRequirement.
        out.types = getattr(v1, "types", None) or []
        out.created = time.time()
        out.updated = None
        return out

    def save(self) -> None:
        """Saves the review requirement to the database (insert or update)."""
        for db in self.get_db():
            db.merge(self.to_record())
            db.commit()

    def delete(self) -> None:
        """Deletes the review requirement from the database.

        Raises:
            ValueError: If no record with this id exists.
        """
        for db in self.get_db():
            record = (
                db.query(ReviewRequirementRecord)
                .filter(ReviewRequirementRecord.id == self.id)
                .first()
            )
            if not record:
                raise ValueError("Review requirement not found")
            db.delete(record)
            db.commit()

    def to_record(self) -> ReviewRequirementRecord:
        """Converts the review requirement to a database record."""
        return ReviewRequirementRecord(
            id=self.id,
            task_id=self.task_id,
            number_required=self.number_required,
            users=json.dumps(self.users),
            agents=json.dumps(self.agents),
            groups=json.dumps(self.groups),
            types=json.dumps(self.types),
            created=self.created,
            updated=self.updated,
        )

    @staticmethod
    def _load_list(raw) -> List[str]:
        """Decode a JSON list column, treating NULL/empty/``null`` as ``[]``."""
        if not raw:
            return []
        return json.loads(raw) or []

    @classmethod
    def from_record(cls, record: ReviewRequirementRecord) -> "ReviewRequirement":
        """Creates a review requirement instance from a database record.

        The list columns are nullable; NULL values are read as empty lists
        instead of failing inside ``json.loads``.
        """
        review_requirement = cls.__new__(cls)
        review_requirement.id = record.id
        review_requirement.task_id = record.task_id
        review_requirement.number_required = record.number_required
        review_requirement.users = cls._load_list(record.users)
        review_requirement.agents = cls._load_list(record.agents)
        review_requirement.groups = cls._load_list(record.groups)
        review_requirement.types = cls._load_list(record.types)
        review_requirement.created = record.created
        review_requirement.updated = record.updated
        return review_requirement

    @classmethod
    def find(cls, **kwargs) -> List["ReviewRequirement"]:
        """Finds review requirements whose columns equal the given keyword filters.

        Raises:
            ValueError: If no database session is available.
        """
        for db in cls.get_db():
            records = db.query(ReviewRequirementRecord).filter_by(**kwargs).all()
            return [cls.from_record(record) for record in records]
        raise ValueError("No database session available")

    @classmethod
    def find_many(
        cls,
        task_ids: Optional[List[str]] = None,
        requirement_ids: Optional[List[str]] = None,
    ) -> List["ReviewRequirement"]:
        """Finds review requirements by sets of task ids and/or requirement ids.

        Args:
            task_ids: Keep only requirements whose task is in this list. An
                empty or missing list applies no task filter.
            requirement_ids: Keep only requirements whose id is in this list.
                An empty or missing list applies no id filter.

        Returns:
            Requirements matching all supplied filters.

        Raises:
            ValueError: If no database session is available.
        """
        for db in cls.get_db():
            query = db.query(ReviewRequirementRecord)
            if task_ids:
                query = query.filter(ReviewRequirementRecord.task_id.in_(task_ids))
            if requirement_ids:
                query = query.filter(ReviewRequirementRecord.id.in_(requirement_ids))
            # Run the filtered query; a fresh db.query(...).all() here would
            # return every row regardless of the filters.
            records = query.all()
            return [cls.from_record(record) for record in records]
        raise ValueError("No database session available")
def delete(self, force: bool = False) -> None:
    """Delete the server instance from the runtime and then from the database.

    Args:
        force: When True, deletion is best effort: a failure to delete the
            instance from the runtime is ignored and a database record that
            is already gone is not treated as an error. When False, both
            conditions raise.
    """
    # First, delete the server instance from the runtime.
    try:
        self._runtime.delete(self._name)
    except Exception:
        if not force:
            # Bare raise keeps the original traceback intact.
            raise

    # After the runtime deletion, proceed to delete the record from the database.
    for db in self.get_db():
        query = db.query(TrackerRecord).filter_by(id=self._id)
        # .one() raises when the row is missing; under force a missing row
        # simply means there is nothing left to clean up.
        record = query.first() if force else query.one()
        if record is not None:
            db.delete(record)
            db.commit()
@classmethod
def from_record(cls, record: TrackerRecord) -> "Tracker":
    """Rebuild a Tracker from its database record.

    The runtime is reconnected from the stored runtime name and JSON connect
    config. ``__init__`` is bypassed, so nothing is written back to the
    database.
    """
    # Imported here rather than at module level (the loader module itself
    # imports from this one).
    from taskara.runtime.load import runtime_from_name

    runtime_cls = runtime_from_name(str(record.runtime_name))
    config_cls = runtime_cls.connect_config_type()
    connect_cfg = config_cls.model_validate_json(str(record.runtime_config))

    if record.labels:
        labels = json.loads(record.labels)  # type: ignore
    else:
        labels = None

    tracker = cls.__new__(cls)
    tracker._id = str(record.id)
    tracker._name = str(record.name)
    tracker._runtime = runtime_cls.connect(connect_cfg)
    tracker._status = record.status
    tracker._port = record.port
    tracker._owner_id = record.owner_id
    tracker._created = record.created
    tracker._updated = record.updated
    tracker._labels = labels

    return tracker
class TrackerRuntime(Generic[R, C], ABC):
    """Abstract interface for a runtime that hosts task servers (trackers).

    ``R`` is the concrete runtime type returned by ``connect`` and ``C`` is
    the pydantic model describing how to connect to that runtime.
    """

    @classmethod
    def name(cls) -> str:
        """The name of this runtime; defaults to the class name."""
        return cls.__name__

    @classmethod
    @abstractmethod
    def connect_config_type(cls) -> Type[C]:
        """The pydantic model which defines the schema for connecting to this runtime

        Returns:
            Type[C]: The type
        """
        pass

    @abstractmethod
    def connect_config(self) -> C:
        """The connect config for this runtime instance

        Returns:
            C: Connect config
        """
        pass

    @classmethod
    @abstractmethod
    def connect(cls, cfg: C) -> R:
        """Connect to the runtime using this configuration

        Args:
            cfg (C): Connect config

        Returns:
            R: A runtime
        """
        pass

    @abstractmethod
    def run(
        self,
        name: str,
        env_vars: Optional[dict] = None,
        owner_id: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        resource_requests: V1ResourceRequests = V1ResourceRequests(),
        resource_limits: V1ResourceLimits = V1ResourceLimits(),
        auth_enabled: bool = True,
    ) -> Tracker:
        """Run the task server

        Args:
            name (str): Name of the task server
            env_vars (Optional[dict], optional): Env vars to supply. Defaults to None.
            owner_id (Optional[str], optional): Owner ID. Defaults to None.
            labels (Optional[Dict[str, str]], optional): Labels for the task server. Defaults to None.
            resource_requests (V1ResourceRequests, optional): Resource requests. Defaults to V1ResourceRequests().
                The default instance is created once at definition time; implementations should not mutate it.
            resource_limits (V1ResourceLimits, optional): Resource limits. Defaults to V1ResourceLimits().
                The default instance is created once at definition time; implementations should not mutate it.
            auth_enabled (bool, optional): Whether to enable auth. Defaults to True.

        Returns:
            Tracker: A task server instance
        """
        pass

    @abstractmethod
    def list(
        self, owner_id: Optional[str] = None, source: bool = False
    ) -> List[Tracker]:
        """List task server instances

        Args:
            owner_id (Optional[str], optional): An optional owner id. Defaults to None.
            source (bool, optional): Whether to list directly from the source. Defaults to False.

        Returns:
            List[Tracker]: A list of task server instances
        """
        pass

    @abstractmethod
    def get(
        self, name: str, owner_id: Optional[str] = None, source: bool = False
    ) -> Tracker:
        """Get a task server instance

        Args:
            name (str): Name of the task server
            owner_id (Optional[str], optional): Optional owner ID. Defaults to None.
            source (bool, optional): Whether to fetch directly from the source. Defaults to False.

        Returns:
            Tracker: A task server instance
        """
        pass

    @abstractmethod
    def requires_proxy(self) -> bool:
        """Whether this runtime requires a proxy to be used"""
        pass

    @abstractmethod
    def proxy(
        self,
        name: str,
        local_port: Optional[int] = None,
        tracker_port: int = 9070,
        background: bool = True,
        owner_id: Optional[str] = None,
    ) -> Optional[int]:
        """Proxy a port to the task server

        Args:
            name (str): Name of the task server
            local_port (Optional[int], optional): Local port to proxy to. Defaults to None.
            tracker_port (int, optional): The task server's port. Defaults to 9070.
            background (bool, optional): Whether to run the proxy in the background. Defaults to True.
            owner_id (Optional[str], optional): An optional owner ID. Defaults to None.

        Returns:
            Optional[int]: The pid of the proxy
        """
        pass

    @abstractmethod
    def delete(self, name: str, owner_id: Optional[str] = None) -> None:
        """Delete a task server instance

        Args:
            name (str): Name of the task server
            owner_id (Optional[str], optional): An optional owner id. Defaults to None.
        """
        pass

    @abstractmethod
    def clean(self, owner_id: Optional[str] = None) -> None:
        """Delete all task server instances

        Args:
            owner_id (Optional[str], optional): An optional owner ID to scope it to. Defaults to None.
        """
        pass

    @abstractmethod
    def logs(
        self, name: str, follow: bool = False, owner_id: Optional[str] = None
    ) -> Union[str, Iterator[str]]:
        """Fetch the logs from the specified task server.

        Args:
            name (str): The name of the task server.
            follow (bool, optional): Whether to follow the logs. Defaults to False.
            owner_id (Optional[str], optional): An optional owner ID. Defaults to None.

        Returns:
            Union[str, Iterator[str]]: The logs as a single string, or an iterator of log lines.
        """
        pass

    @abstractmethod
    def call(
        self,
        name: str,
        path: str,
        method: str,
        port: int = 9070,
        data: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Tuple[int, str]:
        """Call the task server

        Args:
            name (str): Name of the server
            path (str): Path to call
            method (str): Method to use
            port (int, optional): Port to use. Defaults to 9070.
            data (Optional[dict], optional): Body data. Defaults to None.
            headers (Optional[dict], optional): Headers. Defaults to None.

        Returns:
            Tuple[int, str]: Status code and response text
        """
        pass

    @abstractmethod
    def refresh(self, owner_id: Optional[str] = None) -> None:
        """Refresh the runtime

        Args:
            owner_id (Optional[str], optional): Owner id to scope it to. Defaults to None.
        """
        pass

    @abstractmethod
    def runtime_local_addr(self, name: str, owner_id: Optional[str] = None) -> str:
        """
        Returns the local address of the agent with respect to the runtime
        """
        pass
def call(
    self,
    name: str,
    path: str,
    method: str,
    port: int = 9070,
    data: Optional[dict] = None,
    headers: Optional[dict] = None,
) -> Tuple[int, str]:
    """Make an HTTP request to a task server container via localhost.

    Args:
        name (str): Name of the container; it must exist.
        path (str): Path to call, appended to ``http://localhost:{port}``.
        method (str): HTTP method (case-insensitive).
        port (int, optional): Port to use. Defaults to 9070.
        data (Optional[dict], optional): Sent as query parameters for GET,
            otherwise as a JSON body. Defaults to None.
        headers (Optional[dict], optional): Extra request headers, sent for
            every method whether or not ``data`` is given. Defaults to None.

    Returns:
        Tuple[int, str]: Status code and response text

    Raises:
        ValueError: If the container does not exist.
        SystemError: If the server answers with an HTTP error status.
    """
    # Only used as an existence check; the request itself goes to localhost.
    try:
        self.client.containers.get(name)
    except NotFound:
        raise ValueError(f"Container '{name}' not found")

    url = f"http://localhost:{port}{path}"
    verb = method.upper()

    body = None
    if data:
        if verb == "GET":
            url += f"?{urllib.parse.urlencode(data)}"
        else:
            body = json.dumps(data).encode("utf-8")

    request = urllib.request.Request(url, data=body, method=verb)
    if body is not None:
        request.add_header("Content-Type", "application/json")
    # Caller headers are added last so they can override Content-Type.
    for key, value in (headers or {}).items():
        request.add_header(key, value)

    try:
        # The context manager closes the response on every path.
        with urllib.request.urlopen(request) as response:
            return response.code, response.read().decode("utf-8")
    except urllib.error.HTTPError as e:
        status_code = e.code
        error_message = e.read().decode("utf-8")
        raise SystemError(
            f"Error making HTTP request to Docker container: {status_code}: {error_message}"
        ) from e
def run(
    self,
    name: str,
    env_vars: Optional[dict] = None,
    owner_id: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    resource_requests: V1ResourceRequests = V1ResourceRequests(),
    resource_limits: V1ResourceLimits = V1ResourceLimits(),
    auth_enabled: bool = True,
) -> Tracker:
    """Pull the configured image and run it as a task server container.

    The container is attached to the "agentsea" network (created on demand)
    and its port 9070 is published on a free host port in 9070-10090.

    Args:
        name (str): Name of the container / task server.
        env_vars (Optional[dict], optional): Environment for the container.
            The caller's dict is not modified. Defaults to None.
        owner_id (Optional[str], optional): Owner ID. Defaults to None.
        labels (Optional[Dict[str, str]], optional): Extra container labels,
            also stored on the returned Tracker. Defaults to None.
        resource_requests (V1ResourceRequests, optional): Not used by this runtime.
        resource_limits (V1ResourceLimits, optional): Not used by this runtime.
        auth_enabled (bool, optional): When False, TASK_SERVER_NO_AUTH=true
            is set in the container. Defaults to True.

    Returns:
        Tracker: The tracker for the started container.

    Raises:
        ValueError: If no image is configured or no open port is found.
    """
    # Validate before doing the (slow) pull rather than after it.
    if not self.img:
        raise ValueError("img not found")

    api_client = docker.APIClient(base_url=self.docker_socket)
    try:
        # Pull the image with progress tracking
        pull_image(self.img, api_client)
    finally:
        api_client.close()

    _labels = {
        "provisioner": "taskara",
        "server_name": name,
    }
    if labels:
        _labels.update(labels)

    port = find_open_port(9070, 10090)
    if not port:
        raise ValueError("Could not find open port")

    # Work on a copy so the caller's dict is never mutated.
    environment = dict(env_vars) if env_vars else {}
    if not auth_enabled:
        environment["TASK_SERVER_NO_AUTH"] = "true"

    self._ensure_network_exists("agentsea")
    container = self.client.containers.run(
        self.img,
        network="agentsea",
        ports={9070: port},
        environment=environment,
        detach=True,
        labels=_labels,
        name=name,
    )
    if container and not isinstance(container, bytes):
        logger.debug(f"ran container '{container.id}'")  # type: ignore

    return Tracker(
        name=name,
        runtime=self,
        status="running",
        port=port,
        owner_id=owner_id,
        labels=labels,
    )
def list(
    self, owner_id: Optional[str] = None, source: bool = False
) -> List[Tracker]:
    """List task servers, from the database or directly from Docker.

    With ``source=False`` the trackers stored for this runtime and
    ``owner_id`` are returned. With ``source=True`` a Tracker is built for
    every container labelled ``provisioner=taskara``, using the container's
    TASK_SERVER_PORT environment variable (9070 when unset) as its port.
    """
    if not source:
        return Tracker.find(owner_id=owner_id, runtime_name=self.name())

    found = []
    for ctr in self.client.containers.list(filters={"label": "provisioner=taskara"}):
        # Env entries are "KEY=value" strings; the first TASK_SERVER_PORT wins.
        env_entries = ctr.attrs.get("Config", {}).get("Env", [])
        server_port = 9070
        for entry in env_entries:
            if entry.startswith("TASK_SERVER_PORT="):
                server_port = int(entry.split("=")[1])
                break

        found.append(
            Tracker(
                name=ctr.name,
                runtime=self,
                port=server_port,
                status="running",
                owner_id=owner_id,
            )
        )

    return found
def refresh(self, owner_id: Optional[str] = None) -> None:
    """
    Reconciles the database against the Docker containers running.

    Running containers labelled ``provisioner=taskara`` that have no tracker
    record get one; tracker records whose container is not running are
    deleted.

    Parameters:
        owner_id (Optional[str]): The owner ID to filter the trackers. If None, refreshes for all owners.
    """
    # List all Docker containers with the specific label
    label_filter = {"label": "provisioner=taskara"}
    running_containers = self.client.containers.list(filters=label_filter)
    running_by_name = {container.name: container for container in running_containers}  # type: ignore

    # List all trackers in the database
    if owner_id:
        db_trackers = Tracker.find(owner_id=owner_id, runtime_name=self.name())
    else:
        db_trackers = Tracker.find(runtime_name=self.name())

    db_tracker_names = {tracker.name for tracker in db_trackers}

    # Add new containers to the database
    added = 0
    for container_name, container in running_by_name.items():
        if container_name in db_tracker_names:
            continue
        env_vars = container.attrs.get("Config", {}).get("Env", [])
        port = next(
            (
                int(var.split("=")[1])
                for var in env_vars
                if var.startswith("TASK_SERVER_PORT=")
            ),
            9070,
        )
        # Constructing a Tracker already persists it; no extra save() needed.
        Tracker(
            name=container_name,
            runtime=self,
            port=port,
            status="running",
            owner_id=owner_id,
        )
        added += 1

    # Remove trackers from the database whose container is no longer running.
    # The records fetched above are reused instead of being re-queried, and
    # force=True is required: the container may be gone entirely, in which
    # case the runtime delete raises and would otherwise abort the refresh
    # before the stale record is removed.
    removed = 0
    for tracker in db_trackers:
        if tracker.name in running_by_name:
            continue
        tracker.delete(force=True)
        removed += 1

    logger.debug(
        f"Refresh completed: added {added} trackers, removed {removed} trackers."
    )
464 | 465 | Args: 466 | img (str): The Docker image to pull. 467 | api_client (APIClient): The Docker API client. 468 | """ 469 | 470 | print(f"Pulling Docker image '{img}'...") 471 | 472 | progress_bars = {} 473 | layers = {} 474 | 475 | for line in api_client.pull(img, stream=True, decode=True): 476 | if "id" in line and "progressDetail" in line: 477 | layer_id = line["id"] 478 | progress_detail = line["progressDetail"] 479 | current = progress_detail.get("current", 0) 480 | total = progress_detail.get("total", 0) 481 | 482 | if total: 483 | if layer_id not in layers: 484 | progress_bars[layer_id] = tqdm( 485 | total=total, 486 | desc=f"Layer {layer_id}", 487 | leave=False, 488 | ncols=100, 489 | ) 490 | layers[layer_id] = 0 491 | 492 | layers[layer_id] = current 493 | progress_bars[layer_id].n = current 494 | progress_bars[layer_id].refresh() 495 | 496 | # Close all progress bars 497 | for bar in progress_bars.values(): 498 | bar.n = bar.total # Ensure the progress bar is full before closing 499 | bar.refresh() 500 | bar.close() 501 | 502 | print("") 503 | -------------------------------------------------------------------------------- /taskara/runtime/load.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Type 2 | 3 | from pydantic import BaseModel 4 | 5 | from .docker import DockerTrackerRuntime, DockerConnectConfig 6 | from .kube import KubeTrackerRuntime, KubeConnectConfig 7 | from .process import ProcessTrackerRuntime, ProcessConnectConfig 8 | from .base import Tracker, TrackerRuntime 9 | from taskara.server.models import V1TrackerRuntimeConnect 10 | 11 | 12 | class AgentRuntimeConfig(BaseModel): 13 | provider: Optional[str] = None 14 | docker_config: Optional[DockerConnectConfig] = None 15 | kube_config: Optional[KubeConnectConfig] = None 16 | process_config: Optional[ProcessConnectConfig] = None 17 | preference: List[str] = ["kube", "docker", "process"] 18 | 19 | 20 | def 
def load_from_connect(connect: V1TrackerRuntimeConnect) -> TrackerRuntime:
    """Connect to the runtime described by a V1 connect model.

    The runtime class is looked up by ``connect.name`` among ``RUNTIMES``;
    ``connect.connect_config`` is validated against that runtime's connect
    config type and used to connect.

    Raises:
        ValueError: If no registered runtime has the given name.
    """
    for runt in RUNTIMES:
        if connect.name == runt.name():
            # The connect config is deliberately not printed: it was debug
            # output and may hold connection details.
            cfg = runt.connect_config_type().model_validate(connect.connect_config)
            return runt.connect(cfg)

    raise ValueError(f"Unknown runtime: {connect.name}")
def run(
    self,
    name: str,
    env_vars: Optional[dict] = None,
    owner_id: Optional[str] = None,
    labels: Optional[Dict[str, str]] = None,
    resource_requests: V1ResourceRequests = V1ResourceRequests(),
    resource_limits: V1ResourceLimits = V1ResourceLimits(),
    auth_enabled: bool = True,
) -> Tracker:
    """Start a task server as a detached local process and wait until it is healthy.

    A free port in 9070-10090 is picked, launch metadata is written to
    ``.data/proc/<name>.json``, the server is started through the shell with
    its output redirected to ``.data/logs/<name lowercased>.log``, and the
    ``/health`` endpoint is polled until it answers 200.

    Args:
        name: Server name; embedded in the command line as ``TASK_SERVER=<name>``.
        env_vars: Extra environment variables for the server process.
        owner_id: Owner recorded in the metadata and on the returned tracker.
        labels: Extra labels for the returned tracker (``command`` is always set).
        resource_requests: Accepted for interface parity; not used by this runtime.
        resource_limits: Accepted for interface parity; not used by this runtime.
        auth_enabled: When False, ``TASK_SERVER_NO_AUTH=true`` is prepended to the command.

    Returns:
        A ``Tracker`` in status ``running`` bound to the chosen port.

    Raises:
        ValueError: If no open port is available in the range.
        RuntimeError: If the server never passes its health check.
    """
    port = find_open_port(9070, 10090)
    if not port:
        raise ValueError("Could not find open port")
    logger.debug("running process")

    metadata = {
        "name": name,
        "port": port,
        "env_vars": env_vars if env_vars else {},
        "owner_id": owner_id,
    }

    server_cmd = "poetry run python -m taskara.server.app"
    command = f"TASK_SERVER_PORT={port} nohup {server_cmd} TASK_SERVER={name} TASK_SERVER_PORT={port} > ./.data/logs/{name.lower()}.log 2>&1 &"
    if not auth_enabled:
        command = "TASK_SERVER_NO_AUTH=true " + command
    metadata["command"] = command

    # Persist the launch metadata so the process can be rediscovered later.
    os.makedirs(".data/proc", exist_ok=True)
    with open(f".data/proc/{name}.json", "w") as f:
        json.dump(metadata, f, indent=4)

    # The shell redirect in `command` needs the log directory to exist already.
    os.makedirs(".data/logs", exist_ok=True)
    print(f"running server on port {port}")

    environment = os.environ.copy()
    if env_vars:
        # Previously env_vars were only written to the metadata file and never
        # reached the server process. The OS environment only holds strings.
        environment.update(
            {str(k): str(v) for k, v in env_vars.items() if v is not None}
        )

    process = subprocess.Popen(
        command,
        shell=True,
        start_new_session=True,  # same effect as preexec_fn=os.setsid
        env=environment,
        text=True,
    )

    # The command ends with '&', so the launching shell returns immediately.
    # No pipes are attached (output goes to the log file), so there is nothing
    # to read back: only the shell's exit status is meaningful here.
    returncode = process.wait()
    if returncode != 0:
        logger.error(f"Error running command: shell exited with code {returncode}")

    # Health check logic
    max_retries = 20
    retry_delay = 1
    request_timeout = 5
    health_url = f"http://localhost:{port}/health"

    for _ in range(max_retries):
        try:
            response = requests.get(health_url, timeout=request_timeout)
            if response.status_code == 200:
                logger.info("Task server is up and running.")
                break
            logger.warning(
                f"Task server health check returned {response.status_code}, retrying..."
            )
        except requests.RequestException:
            # Covers refused connections as well as timeouts while booting.
            logger.warning("Task server not yet available, retrying...")
        time.sleep(retry_delay)
    else:
        raise RuntimeError("Failed to start server, it did not pass health checks.")

    return Tracker(
        name=name,
        runtime=self,
        status="running",
        port=port,
        labels={**(labels or {}), "command": command},
        owner_id=owner_id,
    )
def get(
    self, name: str, owner_id: Optional[str] = None, source: bool = False
) -> Tracker:
    """Look up a tracker by name.

    With ``source=True`` the tracker is rebuilt from the metadata file
    ``.data/proc/<name>.json``; otherwise it is fetched through
    ``Tracker.find`` filtered by name, owner and this runtime's name.

    Raises:
        ValueError: If the metadata file is missing (``source=True``) or no
            matching tracker is found (``source=False``).
    """
    if not source:
        matches = Tracker.find(
            name=name, owner_id=owner_id, runtime_name=self.name()
        )
        if not matches:
            raise ValueError(f"No running server found with the name {name}")
        return matches[0]

    try:
        # Read the metadata written when the process was started
        with open(f".data/proc/{name}.json", "r") as handle:
            stored = json.load(handle)

        tracker = Tracker(
            name=stored["name"],
            runtime=self,
            port=stored["port"],
        )
        return tracker
    except FileNotFoundError:
        raise ValueError(f"No metadata found for server {name}")
def call(
    self,
    name: str,
    path: str,
    method: str,
    port: int = 9070,
    data: Optional[dict] = None,
    headers: Optional[dict] = None,
) -> Tuple[int, str]:
    """Send an HTTP request to a task server listening on localhost.

    For GET, ``data`` is URL-encoded into the query string. For every other
    method, ``data`` is sent as a JSON body with a JSON content type.
    ``headers`` are applied to every request, after the content type, so a
    caller-supplied ``Content-Type`` wins.

    Args:
        name: Server name; not used to build the request.
        path: Request path appended to ``http://localhost:<port>``.
        method: HTTP method, case-insensitive.
        port: Port the server listens on.
        data: Query parameters (GET) or JSON payload (other methods).
        headers: Extra request headers.

    Returns:
        ``(status_code, response_text)`` for a successful response.

    Raises:
        SystemError: If the server answers with an HTTP error status.
    """
    http_method = method.upper()
    url = f"http://localhost:{port}{path}"

    body = None
    if data:
        if http_method == "GET":
            url += f"?{urllib.parse.urlencode(data)}"
        else:
            body = json.dumps(data).encode("utf-8")

    request = urllib.request.Request(url, data=body, method=http_method)
    if body is not None:
        request.add_header("Content-Type", "application/json")
    # Headers used to be attached only to non-GET requests that carried data,
    # which silently dropped e.g. Authorization on GETs and body-less calls.
    if headers:
        for k, v in headers.items():
            request.add_header(k, v)

    try:
        # The context manager closes the response on every path.
        with urllib.request.urlopen(request) as response:
            return response.status, response.read().decode("utf-8")
    except urllib.error.HTTPError as e:
        status_code = e.code
        error_message = e.read().decode("utf-8")
        e.close()
        raise SystemError(
            f"Error making HTTP request to local process: {status_code}: {error_message}"
        ) from e
def delete(
    self,
    name: str,
    owner_id: Optional[str] = None,
) -> None:
    """Terminate the task server process called ``name`` and drop its metadata.

    Running processes are listed with ``ps`` and matched on the exact
    command-line token ``TASK_SERVER=<name>``. Each matching process group is
    sent SIGTERM, then ``.data/proc/<name>.json`` is removed if present.

    Args:
        name: Server name as passed to ``run``.
        owner_id: Accepted for interface parity; not used by this runtime.

    Raises:
        SystemError: If listing or killing fails, or if no process with that
            name is running. In the latter case the stale metadata file has
            already been removed.
    """
    marker = f"TASK_SERVER={name}"
    metadata_file = f".data/proc/{name}.json"

    try:
        # No shell and no grep: `name` never reaches a shell, and a missing
        # match is an empty result rather than a non-zero exit status.
        ps_output = subprocess.check_output(
            ["ps", "ax", "-o", "pid,command"], text=True
        )

        # pgid -> first matching pid, so each process group is signalled once.
        process_groups = {}
        for line in ps_output.splitlines():
            fields = line.split()
            # Whole-token comparison, so "foo" does not also match "foo2".
            if marker not in fields[1:]:
                continue
            pid = int(fields[0])
            try:
                process_groups.setdefault(os.getpgid(pid), pid)
            except ProcessLookupError:
                # Exited between the ps snapshot and the lookup.
                continue
        logger.debug(f"Found process groups: {process_groups}")

        for pgid, pid in process_groups.items():
            try:
                os.killpg(pgid, signal.SIGTERM)
            except ProcessLookupError:
                continue
            logger.info(f"Process {name} with PID {pid} has been terminated.")

        # Delete the metadata file whether or not the process was found
        if os.path.exists(metadata_file):
            os.remove(metadata_file)
            logger.info(f"Deleted metadata file for {name}.")

    except subprocess.CalledProcessError as e:
        raise SystemError(
            f"Error while attempting to delete the process: {str(e)}"
        ) from e
    except ValueError as e:
        raise SystemError(f"Error parsing process ID: {str(e)}") from e
    except Exception as e:
        # Callers rely on SystemError as the single failure type.
        raise SystemError(f"An unexpected error occurred: {str(e)}") from e

    if not process_groups:
        raise SystemError(f"No running process found with the name {name}.")
def logs(
    self,
    name: str,
    follow: bool = False,
    owner_id: Optional[str] = None,
) -> Union[str, Iterator[str]]:
    """Return the log output of the server called ``name``.

    The log file is ``./.data/logs/<name lowercased>.log``. If it does not
    exist, a fixed placeholder message is returned. Without ``follow`` the
    whole file is returned as one string. With ``follow`` a generator is
    returned that starts at the current end of the file and yields each new
    line as it appears, polling every half second (like ``tail -f``).
    """
    path = f"./.data/logs/{name.lower()}.log"
    if not os.path.exists(path):
        return "No logs available for this server."

    if not follow:
        # One-shot read of everything logged so far
        with open(path, "r") as fh:
            return fh.read()

    def _tail():
        with open(path, "r") as fh:
            # Skip what is already there; only new entries are streamed
            fh.seek(0, os.SEEK_END)
            while True:
                entry = fh.readline()
                if entry:
                    yield entry
                else:
                    time.sleep(0.5)  # Wait briefly for new log entries

    return _tail()
403 | """ 404 | # List all running processes with the specific environment variable 405 | all_processes = subprocess.check_output( 406 | "ps ax -o pid,command", shell=True, text=True 407 | ) 408 | running_processes = {} 409 | for line in all_processes.splitlines(): 410 | if "TASK_SERVER=" in line: 411 | pid, command = line.split(maxsplit=1) 412 | server_name = next( 413 | ( 414 | part.split("=")[1] 415 | for part in command.split() 416 | if part.startswith("TASK_SERVER=") 417 | ), 418 | None, 419 | ) 420 | if server_name: 421 | running_processes[server_name] = pid 422 | 423 | running_process_names = set(running_processes.keys()) 424 | 425 | # List all trackers in the database 426 | if owner_id: 427 | db_trackers = Tracker.find(owner_id=owner_id, runtime_name=self.name()) 428 | else: 429 | db_trackers = Tracker.find(runtime_name=self.name()) 430 | 431 | db_tracker_names = {tracker.name for tracker in db_trackers} 432 | 433 | # Determine trackers to add or remove from the database 434 | processes_to_add = running_process_names - db_tracker_names 435 | processes_to_remove = db_tracker_names - running_process_names 436 | 437 | # Add new processes to the database 438 | for process_name in processes_to_add: 439 | try: 440 | with open(f".data/proc/{process_name}.json", "r") as f: 441 | metadata = json.load(f) 442 | 443 | new_tracker = Tracker( 444 | name=metadata["name"], 445 | runtime=self, 446 | port=metadata["port"], 447 | status="running", 448 | owner_id=owner_id, 449 | ) 450 | new_tracker.save() 451 | except FileNotFoundError: 452 | logger.warning(f"No metadata found for process {process_name}") 453 | 454 | # Remove processes from the database that are no longer running 455 | for tracker_name in processes_to_remove: 456 | trackers = Tracker.find( 457 | name=tracker_name, owner_id=owner_id, runtime_name=self.name() 458 | ) 459 | if not trackers: 460 | continue 461 | tracker = trackers[0] 462 | tracker.delete() 463 | 464 | logger.debug( 465 | f"Refresh completed: added 
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: open the Redis pool on startup, close it on shutdown.

    The ``yield`` is wrapped in ``try/finally`` so the pool is closed even
    when an exception (including cancellation) is thrown into the context
    while the application is running.
    """
    print("initializing boot sequence...", flush=True)
    await init_redis_pool()
    # Reported only after the pool has actually been initialized.
    print("boot sequence initialized.", flush=True)
    try:
        yield
    finally:
        print("Shutting Down Fast API server Taskara", flush=True)
        await close_redis_pool()
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log every request on arrival and again with its status code and duration.

    The duration is measured with ``time.perf_counter`` (monotonic) instead of
    the wall clock, so system clock adjustments cannot skew it or make it
    negative.
    """
    started = time.perf_counter()
    access_logger.info(f"Received request: {request.method} {request.url}")
    response = await call_next(request)
    duration = time.perf_counter() - started
    access_logger.info(
        f"Returned response {request.method} {request.url}: {response.status_code} - Duration: {duration:.4f} seconds"
    )
    return response
# Update payload for a task: every field is optional and defaults to None.
# NOTE(review): presumably a None field means "leave unchanged" — confirm
# against the handler that consumes this model.
class V1TaskUpdate(BaseModel):
    status: Optional[str] = None
    org: Optional[str] = None
    description: Optional[str] = None
    max_steps: Optional[int] = None
    error: Optional[str] = None
    output: Optional[str] = None
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    # Float timestamps, matching `started`/`completed` on V1Task.
    completed: Optional[float] = None
    started: Optional[float] = None
    version: Optional[str] = None
    # NOTE(review): looks like labels to set on the task; whether existing
    # labels are merged or replaced is not visible here — verify in the router.
    set_labels: Optional[Dict[str, str]] = None
Optional[V1DeviceType] = None 90 | expect_schema: Optional[Dict[str, Any]] = None 91 | status: Optional[str] = None 92 | threads: Optional[List[V1RoleThread]] = None 93 | prompts: Optional[List[str]] = None 94 | reviews: List[V1Review] = [] 95 | review_requirements: List[V1ReviewRequirement] = [] 96 | assigned_to: Optional[str] = None 97 | assigned_type: Optional[str] = None 98 | created: float = Field(default_factory=time.time) 99 | started: float = 0.0 100 | completed: float = 0.0 101 | created_by: Optional[str] = None 102 | error: Optional[str] = None 103 | output: Optional[str] = None 104 | parameters: Optional[Dict[str, Any]] = {} 105 | version: Optional[str] = None 106 | remote: Optional[str] = None 107 | owner_id: Optional[str] = None 108 | project: Optional[str] = None 109 | parent_id: Optional[str] = None 110 | tags: List[str] = [] 111 | labels: Dict[str, str] = {} 112 | episode_id: Optional[str] = None 113 | public: bool = False 114 | skill: Optional[str] = None 115 | org: Optional[str] = None 116 | auth_token: Optional[str] = None 117 | 118 | class V1Task(BaseModel): 119 | id: str = Field(default_factory=lambda: shortuuid.uuid()) 120 | description: str 121 | max_steps: int = 30 122 | device: Optional[V1Device] = None 123 | device_type: Optional[V1DeviceType] = None 124 | expect_schema: Optional[Dict[str, Any]] = None 125 | status: Optional[str] = None 126 | threads: Optional[List[V1RoleThread]] = None 127 | prompts: Optional[List[str]] = None 128 | reviews: List[V1Review] = [] 129 | review_requirements: List[V1ReviewRequirement] = [] 130 | assigned_to: Optional[str] = None 131 | assigned_type: Optional[str] = None 132 | created: float = Field(default_factory=time.time) 133 | started: float = 0.0 134 | completed: float = 0.0 135 | created_by: Optional[str] = '' 136 | error: Optional[str] = None 137 | output: Optional[str] = None 138 | parameters: Optional[Dict[str, Any]] = {} 139 | version: Optional[str] = None 140 | remote: Optional[str] = None 141 | 
# Task search criteria: mirrors the scalar fields of V1Task, all optional and
# defaulting to None.
# NOTE(review): presumably None means "do not filter on this field" — confirm
# against the code that consumes this model.
class V1SearchTask(BaseModel):
    id: Optional[str] = None
    description: Optional[str] = None
    max_steps: Optional[int] = None
    # device: Optional[V1Device]
    # device_type: Optional[V1DeviceType]
    status: Optional[str] = None
    # Plural counterpart of `owner_id` below; has no equivalent on V1Task.
    owners: Optional[List[str]] = None
    # threads: Optional[List[V1RoleThread]]
    # prompts: Optional[List[str]]
    # reviews: Optional[List[V1Review]]
    # review_requirements: Optional[List[V1ReviewRequirement]]
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    created: Optional[float] = None
    created_by: Optional[str] = None
    started: Optional[float] = None
    completed: Optional[float] = None
    error: Optional[str] = None
    output: Optional[str] = None
    # parameters: Optional[Dict[str, Any]]
    version: Optional[str] = None
    remote: Optional[str] = None
    owner_id: Optional[str] = None
    project: Optional[str] = None
    parent_id: Optional[str] = None
    tags: Optional[List[str]] = None
    labels: Optional[Dict[str, str]] = None
    skill: Optional[str] = None
    public: Optional[bool] = None
    episode_id: Optional[str] = None
    auth_token: Optional[str] = None
# Resource requests for a tracker; used as the `resource_requests` argument of
# the runtimes' `run` methods.
class V1ResourceRequests(BaseModel):
    cpu: str = "1"
    # NOTE(review): "500m" is ambiguous. In Kubernetes quantity syntax the "m"
    # suffix means milli, so this is 0.5 bytes of memory; Docker reads "500m"
    # as 500 megabytes. The sibling V1ResourceLimits default is "2Gi"
    # (Kubernetes style). Confirm which notation each runtime expects — "500Mi"
    # may have been intended.
    memory: str = "500m"
    gpu: Optional[str] = None
class V1BoundingBoxFlag(BaseModel):
    """A flag that pairs an image and a target with a bounding box"""

    # NOTE(review): the encoding of `img` (URL, path or inline data) is not
    # visible here — confirm with the producer of this model.
    img: str
    # NOTE(review): presumably a description of what the box should contain —
    # verify against callers.
    target: str
    bbox: V1BoundingBox
@router.get("/v1/benchmarks", response_model=V1Benchmarks)
async def get_benchmarks(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())]
):
    # List the benchmarks owned by the authenticated user (matched on email)
    # and wrap their V1 representations in the response model.
    owned = Benchmark.find(owner_id=current_user.email)
    serialized = [item.to_v1() for item in owned]
    return V1Benchmarks(benchmarks=serialized)
@router.post("/v1/benchmarks/{id}/eval", response_model=V1Eval)
async def create_eval_from_benchmark(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
    data: V1BenchmarkEval,
):
    # Look up the caller's benchmark, derive an eval from it for the requested
    # assignee, persist the eval and return its V1 representation.
    # Responds 404 when the caller owns no benchmark with this id.
    logger.debug(f"Finding benchmark by id: {id}")
    owner = current_user.email
    found = Benchmark.find(id=id, owner_id=owner)
    if not found:
        raise HTTPException(status_code=404, detail="Benchmark not found")

    # Named `eval_instance` like the sibling handlers, which also avoids
    # shadowing the `eval` builtin.
    eval_instance = found[0].eval(
        data.assigned_to, data.assigned_type, owner_id=current_user.email
    )
    eval_instance.save()
    return eval_instance.to_v1()
def generate_random_string(length: int = 8):
    """Return ``length`` characters drawn at random from ASCII letters and digits.

    Characters are sampled with replacement using the ``random`` module, so
    the result is not suitable as a security token.
    """
    alphabet = string.ascii_letters + string.digits
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
43 | return "" 44 | except subprocess.CalledProcessError as e: 45 | print(f"Error: {e.output.decode()}") 46 | return "" 47 | 48 | 49 | def check_port_in_use(port: int) -> bool: 50 | """ 51 | Check if the specified port is currently in use on the local machine. 52 | 53 | Args: 54 | port (int): The port number to check. 55 | 56 | Returns: 57 | bool: True if the port is in use, False otherwise. 58 | """ 59 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 60 | return s.connect_ex(("localhost", port)) == 0 61 | 62 | 63 | def find_open_port(start_port: int = 1024, end_port: int = 65535) -> Optional[int]: 64 | """Finds an open port on the machine""" 65 | for port in range(start_port, end_port + 1): 66 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 67 | try: 68 | s.bind(("", port)) 69 | return port # Port is open 70 | except socket.error: 71 | continue # Port is in use, try the next one 72 | return None # No open port found 73 | 74 | def check_openmeter_agent_tasks(owner_id) -> bool: 75 | if openmeter_secret: 76 | if not openmeter_agent_task_feature or not openmeter_client: 77 | raise ValueError('Cannot create desktop no openmeter secret or client or openmeter_agent_task_feature to get entitlements from') 78 | 79 | entitlement_value = {} 80 | try: 81 | # Check openmeter for if user has access through an entitlement 82 | entitlement_value = openmeter_client.get_entitlement_value( 83 | subject_id_or_key=owner_id, 84 | entitlement_id_or_feature_key=openmeter_agent_task_feature 85 | ) 86 | 87 | except ResourceNotFoundError as e: 88 | print( 89 | f"#slack-alert Feature {openmeter_agent_task_feature} not found for subject {owner_id}: {e}" 90 | ) 91 | return False 92 | if not entitlement_value or not entitlement_value["hasAccess"]: 93 | print(f"entitlement access denied in assigning task to agent for feature {openmeter_agent_task_feature}, for subject {owner_id} it is likely that the entitlement is no longer valid or the user/org has reached their 
cap #slack-alert") 94 | return False 95 | print(f"user: {owner_id} agent task entitlement values are {entitlement_value}", flush=True) 96 | return True -------------------------------------------------------------------------------- /tests/benchmarks/airbnb.yaml: -------------------------------------------------------------------------------- 1 | name: AirbnbLiteV1 2 | description: A simple Airbnb benchmark 3 | tags: ["GUI", "desktop", "airbnb", "booking"] 4 | public: true 5 | tasks: 6 | - description: | 7 | Find a highly rated one bedroom apartment available in Barcelona Spain 8 | from August 5th to August 8th 9 | max_steps: 50 10 | device_type: 11 | name: "desktop" 12 | expect_schema: 13 | properties: 14 | room_id: 15 | description: | 16 | The id of the room, can be found in the URL e.g. https://www.airbnb.com/rooms/ 17 | type: string 18 | required: 19 | - room_id 20 | type: object 21 | parameters: 22 | site: https://airbnb.com 23 | 24 | - description: | 25 | Find a highly rated two bedroom apartment available in Boulder CO 26 | from December 9th to December 12th 27 | max_steps: 50 28 | device_type: 29 | name: "desktop" 30 | expect_schema: 31 | properties: 32 | room_id: 33 | description: | 34 | The id of the room, can be found in the URL e.g. 
def _single_task_benchmark():
    """Build the one-template benchmark used by the creation and to_v1 tests."""
    template = TaskTemplate(description="Test Task")
    return Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[template],
        owner_id="owner@example.com",
    )


def test_benchmark_creation():
    """A freshly built Benchmark exposes the values it was constructed with."""
    bench = _single_task_benchmark()

    assert bench.name == "Test Benchmark"
    assert bench.description == "Test Benchmark Description"
    assert len(bench.tasks) == 1
    assert bench.tasks[0].description == "Test Task"
    assert bench.owner_id == "owner@example.com"


def test_benchmark_to_v1():
    """Converting to the V1 model keeps name, description, tasks and owner."""
    as_v1 = _single_task_benchmark().to_v1()

    assert as_v1.name == "Test Benchmark"
    assert as_v1.description == "Test Benchmark Description"
    assert len(as_v1.tasks) == 1
    assert as_v1.tasks[0].description == "Test Task"
    assert as_v1.owner_id == "owner@example.com"


def test_benchmark_from_v1():
    """Loading from a V1 model keeps name, description, tasks and owner."""
    source = V1Benchmark(
        id="benchmark1",
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[
            V1TaskTemplate(
                id="task1",
                description="Test Task",
                created=0.0,
                owner_id="owner@example.com",
            )
        ],
        owner_id="owner@example.com",
        created=0.0,
    )

    loaded = Benchmark.from_v1(source)

    assert loaded.name == "Test Benchmark"
    assert loaded.description == "Test Benchmark Description"
    assert len(loaded.tasks) == 1
    assert loaded.tasks[0].description == "Test Task"
    assert loaded.owner_id == "owner@example.com"
def _owned_benchmark():
    """Build the one-template benchmark used by the creation and to_v1 tests."""
    template = TaskTemplate(description="Test Task", owner_id="owner@example.com")
    return Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[template],
        owner_id="owner@example.com",
    )


def test_eval_creation():
    """An Eval built from a benchmark carries one task per template."""
    created = Eval(_owned_benchmark())

    assert created.benchmark.name == "Test Benchmark"
    assert len(created.tasks) == 1
    assert created.tasks[0].description == "Test Task"
    assert created.benchmark.owner_id == "owner@example.com"


def test_eval_to_v1():
    """Converting an Eval to V1 keeps the benchmark, tasks and owner."""
    as_v1 = Eval(_owned_benchmark(), owner_id="owner@example.com").to_v1()

    assert as_v1.benchmark.name == "Test Benchmark"
    assert len(as_v1.tasks) == 1
    assert as_v1.tasks[0].description == "Test Task"
    assert as_v1.owner_id == "owner@example.com"


def test_eval_from_v1():
    """Loading an Eval from V1 keeps the supplied benchmark and tasks."""
    bench_v1 = V1Benchmark(
        id="benchmark1",
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[
            V1TaskTemplate(
                id="task1",
                description="Test Task",
                created=0.0,
                owner_id="owner@example.com",
            )
        ],
        owner_id="owner@example.com",
        created=0.0,
    )
    task_v1 = V1Task(
        id="123", description="Search for french ducks", owner_id="owner@example.com"
    )
    source = V1Eval(
        id="eval1",
        benchmark=bench_v1,
        tasks=[task_v1],
        owner_id="owner@example.com",
    )

    loaded = Eval.from_v1(source)

    assert loaded.benchmark.name == "Test Benchmark"
    assert len(loaded.tasks) == 1
    assert loaded.tasks[0].description == "Search for french ducks"
    assert loaded.benchmark.owner_id == "owner@example.com"
def test_process_tracker_runtime():
    """End-to-end exercise of a task server started by ProcessTrackerRuntime.

    Starts a server with auth disabled, then walks through the HTTP API:
    task create/list/get/update, review requirements, messages, threads,
    prompts, action events, annotations, the ``Task`` client against the
    remote, and finally benchmarks and evals. The server is deleted in the
    ``finally`` clause whether or not the assertions pass.
    """
    runtime = ProcessTrackerRuntime()
    assert runtime.name() == "process"
    assert runtime.connect_config_type() == ProcessConnectConfig
    assert runtime.connect_config().model_dump() == {}

    runtime.refresh()

    name = get_random_name("-")
    assert name

    print("running task server ", name)
    server = runtime.run(name, auth_enabled=False)
    print("task server ", server.__dict__)

    try:
        # Create a task
        task_data = {
            "description": "Search for french ducks",
            "assigned_to": "tom@myspace.com",
            "labels": {"test": "true"},
            "review_requirements": [
                {
                    "number_required": 2,
                    "users": ["anonymous@agentsea.ai"],
                    "agents": ["agent1", "agent2"],
                },
                {
                    "number_required": 1,
                    "users": ["tom@myspace.com", "anonymous@agentsea.ai"],
                    "agents": ["agent3"],
                },
            ],
        }
        status, text = server.call(path="/v1/tasks", method="POST", data=task_data)
        print("status: ", status)
        print("task created: ", text)
        assert status == 200

        task = V1Task.model_validate(json.loads(text))
        assert task.description == "Search for french ducks"
        assert task.owner_id == "tom@myspace.com"
        task_id = task.id
        time.sleep(1)

        # Fetch the task with query parameters and labels passed as a JSON string
        labels_query = json.dumps({"test": "true"})  # Encode labels as JSON string
        encoded_labels = urllib.parse.quote(labels_query)

        status, text = server.call(
            path=f"/v1/tasks?labels={encoded_labels}", method="GET"
        )
        print("status: ", status)
        print("tasks fetched: ", text)
        assert status == 200

        tasks = V1Tasks.model_validate(json.loads(text))
        assert any(t.id == task_id for t in tasks.tasks)

        # Get a specific task
        status, text = server.call(path=f"/v1/tasks/{task_id}", method="GET")
        print("status: ", status)
        print("task fetched: ", text)
        assert status == 200
        task = V1Task.model_validate(json.loads(text))
        assert task.id == task_id
        assert task.assigned_to == "tom@myspace.com"

        # Check the review requirements
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 5

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 1

        # Approve the task
        data = V1CreateReview(approved=True, reason="Approved")
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/review", method="PUT", data=data.model_dump()
        )

        data = V1CreateReview(approved=True, reason="Approved", reviewer="agent1")
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/review", method="PUT", data=data.model_dump()
        )

        status, text = server.call(
            path=f"/v1/tasks/{task_id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 3

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 0

        # Update the task
        update_data = {
            "description": "Search for german ducks",
            "status": "in progress",
            "set_labels": {"test_set": "true"},
        }
        status, text = server.call(
            path=f"/v1/tasks/{task_id}", method="PUT", data=update_data
        )
        print("status: ", status)
        print("task updated: ", text)
        assert status == 200
        task = V1Task.model_validate(json.loads(text))
        assert task.description == "Search for german ducks"
        assert task.status == "in progress"

        # Post a message to the task
        message_data = {
            "role": "user",
            "msg": "This is a test message.",
            "images": [],
            "thread": None,
        }
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/msg", method="POST", data=message_data
        )
        print("status: ", status)
        assert status == 200

        # Create a thread
        thread_data = {"name": "test-thread", "public": True, "metadata": {}}
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/threads", method="POST", data=thread_data
        )
        print("create thread status: ", status)
        assert status == 200

        # Remove a thread
        remove_thread_data = {"id": "test-thread"}
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/threads",
            method="DELETE",
            data=remove_thread_data,
        )
        print("remove thread status: ", status)
        assert status == 200

        # Store a prompt in the task
        prompt = Prompt(
            thread=RoleThread(
                name="test-thread",
                public=True,
            ),
            response=RoleMessage(
                role="assistant",
                text="This is a test response",
                images=[],
            ),
        )
        print("sending prompt id: ", prompt.id)
        status, resp = server.call(
            path=f"/v1/tasks/{task_id}/prompts",
            method="POST",
            data=prompt.to_v1().model_dump(),
        )
        print("store prompt status: ", status)
        assert status == 200

        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}",
            method="GET",
        )
        print("get task status: ", status)
        assert status == 200
        task = V1Task.model_validate(json.loads(resp_text))
        print("task: ", task)

        print("store prompt response: ", resp)

        # Approve a prompt
        prompt_id = json.loads(resp)["id"]

        print("prompt id: ", prompt_id)
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/prompts/{prompt_id}/approve", method="POST"
        )
        print("approve prompt status: ", status)
        assert status == 200

        # Write a review
        review = V1CreateReview(approved=True, reason="test")
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/review",
            method="PUT",
            data=review.model_dump(),
        )
        print("store review status: ", status)
        assert status == 200

        # Store an action event
        action_event = ActionEvent(
            state=EnvState(images=["https://test.img"]),
            action=V1Action(name="test", parameters={}),
            tool=V1ToolRef(module="test", type="test"),
            prompt=prompt,
        )

        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="POST",
            data=action_event.to_v1().model_dump(),
        )
        print("store action status: ", status)
        assert status == 200

        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="GET",
        )
        print("get task status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events: ", events)
        assert len(events.events) > 0

        # First, find the `action_id` from previously stored action event
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="GET",
        )
        assert status == 200
        events = V1ActionEvents.model_validate_json(resp_text)
        assert len(events.events) > 0
        action_id = events.events[0].id

        # Create an annotation
        annotation = V1AnnotationReviewable(
            key="test",
            value="This is a test annotation",
            annotator="tester@example.com",
        )
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions/{action_id}/annotations",
            method="POST",
            data=annotation.model_dump(),
        )
        assert status == 200

        # Retrieve the `annotation_id`
        # Assuming the response contains the `annotation_id` or we can fetch it
        resp_json = json.loads(resp_text)
        annotation_id = resp_json.get("id")  # Adjust based on actual response

        # Now, review the annotation
        review = V1CreateAnnotationReview(
            approved=True,
            reviewer="reviewer@example.com",
            reviewer_type="human",
            reason="Annotation looks good",
        )
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions/{action_id}/annotations/{annotation_id}/review",
            method="POST",
            data=review.model_dump(),
        )
        assert status == 200

        # Verify that the annotation review was recorded properly
        # Fetch the action event and check the annotation's review status
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="GET",
        )
        assert status == 200
        events = V1ActionEvents.model_validate_json(resp_text)
        action_event = next((e for e in events.events if e.id == action_id), None)
        assert action_event is not None

        # Assuming `action_event` has a `reviewables` field which contains annotations
        annotation = next(
            (a for a in action_event.reviewables if a.id == annotation_id), None
        )
        assert annotation is not None
        assert annotation.reviews[0].approved
        assert annotation.reviews[0].reviewer == "reviewer@example.com"
        assert annotation.reviews[0].reason == "Annotation looks good"

        print("getting remote task")
        found_task = Task.get(id=task_id, remote=f"http://localhost:{server.port}")

        # Delete the task
        status, _ = server.call(path=f"/v1/tasks/{task_id}", method="DELETE")
        print("delete task status: ", status)
        assert status == 200

        print("creating a new task")

        class Expected(BaseModel):
            foo: str
            bar: int

        new_task = Task(
            description="a good test",
            remote=f"http://localhost:{server.port}",
            expect=Expected,
            review_requirements=[
                ReviewRequirement(number_required=1, users=["tom@myspace.com"])
            ],
        )
        print("created a new task: ", new_task.id)

        action_event = ActionEvent(
            state=EnvState(images=["https://test.img"]),
            action=V1Action(name="test", parameters={}),
            tool=V1ToolRef(module="test", type="test"),
            prompt=prompt,
        )
        new_task.record_action_event(action_event)

        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get task status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events: ", events)
        assert len(events.events) > 0

        # Check the review requirements
        status, text = server.call(
            path=f"/v1/tasks/{new_task.id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 1

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 1

        # Check that we can update the task
        new_task.refresh()
        new_task.status = TaskStatus.CANCELED
        new_task.save()

        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get task status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events: ", events)
        assert len(events.events) > 0

        # Delete all actions associated with the task
        status, _ = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="DELETE",
        )
        print("delete all actions status: ", status)
        assert status == 200

        # Verify that no actions remain
        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get actions after deletion status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events after deletion: ", events)
        assert len(events.events) == 0

        # Benchmarks
        tpl0 = TaskTemplate(
            description="A good test 0",
            device_type=V1DeviceType(name="desktop"),
            owner_id="tom@myspace.com",
        )
        tpl1 = TaskTemplate(
            description="A good test 1",
            device_type=V1DeviceType(name="mobile"),
            owner_id="tom@myspace.com",
        )
        bench = Benchmark(
            name="test-bench",
            description="A good benchmark",
            tasks=[tpl0, tpl1],
            owner_id="tom@myspace.com",
        )
        status, _ = server.call(
            path="/v1/benchmarks", method="POST", data=bench.to_v1().model_dump()
        )
        assert status == 200

        status, text = server.call(
            path="/v1/benchmarks",
            method="GET",
        )
        benchmarks = V1Benchmarks.model_validate_json(text)
        assert benchmarks.benchmarks[0].description == "A good benchmark"

        status, text = server.call(
            path=f"/v1/benchmarks/{benchmarks.benchmarks[0].id}/eval",
            method="POST",
            data=V1BenchmarkEval(
                assigned_to="test_agent", assigned_type="pizza"
            ).model_dump(),
        )
        assert status == 200

        v1eval = V1Eval.model_validate_json(text)
        assert v1eval.owner_id == "tom@myspace.com"
        assert v1eval.assigned_to == "test_agent"
        assert v1eval.assigned_type == "pizza"

        status, text = server.call(
            path="/v1/evals",
            method="GET",
        )
        evals = V1Evals.model_validate_json(text)
        assert evals.evals[0].owner_id == "tom@myspace.com"

    except Exception:
        # Dump the server log so a failing assertion can be diagnosed, then
        # let the original failure propagate.
        print(server.logs())
        raise

    finally:
        # Ensure the server is deleted; cleanup stays best-effort, but a
        # failure is reported instead of being silently discarded.
        try:
            server.delete()
        except Exception as cleanup_error:
            print("failed to delete task server: ", cleanup_error)
class TestConnectConfig(BaseModel):
    # Minimal device config payload used by test_create_thread.
    a: str
    b: int


# Test the thread creation functionality within the Task class
def test_create_thread():
    """Creating a thread appends it to the task, and the device config is findable."""

    class Expected(BaseModel):
        foo: str
        bar: int

    task = Task(
        description="Test Task",
        owner_id="owner123",
        id="task123",
        expect=Expected,
        device=V1Device(type="desktop", config=TestConnectConfig(a="a", b=1)),
    )
    assert len(task.threads) == 1

    # Directly call the method that doesn't involve remote calls
    task.create_thread(name="New Local Thread", public=True)

    # Verify a new thread is added
    assert len(task.threads) == 2
    # Verify the properties of the newly created thread
    new_thread = task.threads[-1]
    assert new_thread.name == "New Local Thread"
    assert new_thread.public is True

    task.refresh()

    found = Task.find(id=task.id)
    assert len(found) == 1
    print("\nfound: ", found[0].__dict__)
    assert found[0].device.config["a"] == "a"  # type: ignore
    assert found[0].device.config["b"] == 1  # type: ignore


# Test posting a message to a thread within the Task class
def test_post_message():
    """Messages posted to a named thread and to the default thread can be read back."""

    class Expected(BaseModel):
        foo: str
        bar: int

    task = Task(
        description="Test Task 2", owner_id="owner1234", id="task1234", expect=Expected
    )

    # Directly call the method that doesn't involve remote calls
    task.create_thread(name="Prompt", public=True)
    messages = task.messages(thread="Prompt")
    print("initial messages: ", messages)
    assert len(messages) == 0

    # Act: Post a message to the thread
    task.post_message(role="user", msg="Test Message", thread="Prompt")
    messages = task.messages(thread="Prompt")
    assert len(messages) == 1
    message = messages[0]
    assert message.text == "Test Message"
    assert message.role == "user"

    threads = RoleThread.find(name="Prompt")
    assert len(threads) == 1
    thread = threads[0]

    messages = thread.messages()
    assert len(messages) == 1
    message = messages[0]
    assert message.text == "Test Message"
    assert message.role == "user"

    task.post_message(role="moderator", msg="Test Message 5")
    messages = task.messages()
    assert len(messages) == 1
    message = messages[0]

    assert message.text == "Test Message 5"
    assert message.role == "moderator"


def test_find_many_lite():
    """find_many_lite returns exactly the tasks whose ids were requested."""
    # Create three tasks
    task1 = Task(description="Task 1", owner_id="owner1", id="task1")
    task2 = Task(description="Task 2", owner_id="owner1", id="task2")
    task3 = Task(description="Task 3", owner_id="owner2", id="task3")

    # Manually save tasks to ensure they're committed to the database
    # Note: The Task class's save method should handle this, so just calling them is enough.
    task1.save()
    task2.save()
    task3.save()

    # Test retrieving tasks by IDs
    found_tasks = Task.find_many_lite(task_ids=["task1", "task2"])
    print(f"found tasks {found_tasks}", flush=True)
    # Verify that only task1 and task2 are returned
    assert len(found_tasks) == 2
    found_ids = [t.id for t in found_tasks]
    assert "task1" in found_ids
    assert "task2" in found_ids
    assert "task3" not in found_ids

    # Test retrieving a subset
    found_task = Task.find_many_lite(task_ids=["task1"])
    assert len(found_task) == 1
    assert found_task[0].id == "task1"

    # Test retrieving no tasks
    found_none = Task.find_many_lite(task_ids=["nonexistent"])
    assert len(found_none) == 0


def test_find_many_lite_with_reviews_and_reqs():
    """
    Verifies that find_many_lite correctly returns tasks along with:
      - Their parameters
      - Their associated Reviews
      - Their associated ReviewRequirements
    in a single batched lookup.
    """

    # -----------------------------
    # 1) Create sample tasks
    # -----------------------------
    task4 = Task(description="Task 4", owner_id="owner4", id="task4")
    task5 = Task(description="Task 5", owner_id="owner4", id="task5")
    task6 = Task(description="Task 6", owner_id="owner5", id="task6")

    # Give each task some parameters
    task4.parameters = {"foo": "bar4", "alpha": 123}
    task5.parameters = {"foo": "bar5", "beta": 999}
    task6.parameters = {"foo": "bar6", "gamma": 42}

    # -----------------------------
    # 2) Create Reviews for tasks
    # -----------------------------
    from skillpacks.review import Review, Resource

    review_a = Review(
        reviewer="alice",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task4",  # Associate with task4
        reason="Review A - All good",
    )
    review_b = Review(
        reviewer="bob",
        approved=False,
        resource_type=Resource.TASK.value,
        resource_id="task4",  # Also task4
        reason="Review B - Some issues",
    )
    review_c = Review(
        reviewer="charlie",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task5",  # For task5
        reason="Review C - LGTM",
    )
    review_d = Review(
        reviewer="david",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task6",  # For task6
        reason="Review D - Quick check",
    )

    # Save Reviews to DB so they have real IDs
    review_a.save()
    review_b.save()
    review_c.save()
    review_d.save()

    # Link them to the tasks
    task4.reviews = [review_a, review_b]
    task5.reviews = [review_c]
    task6.reviews = [review_d]

    # -----------------------------
    # 3) Create ReviewRequirements
    # -----------------------------
    from taskara.review import ReviewRequirement

    req_a = ReviewRequirement(
        task_id="task4",
        number_required=1,
        users=["alice", "bob"],
    )
    req_b = ReviewRequirement(
        task_id="task5",
        number_required=2,
        users=["charlie", "someone_else"],
    )
    req_c = ReviewRequirement(
        task_id="task6",
        number_required=1,
        agents=["agent_1"],
    )

    req_a.save()
    req_b.save()
    req_c.save()

    # Link them
    task4.review_requirements = [req_a]
    task5.review_requirements = [req_b]
    task6.review_requirements = [req_c]

    # -----------------------------
    # 4) Save all tasks to DB
    # -----------------------------
    task4.save()
    task5.save()
    task6.save()

    # -----------------------------
    # 5) Exercise find_many_lite
    # -----------------------------
    found = Task.find_many_lite(task_ids=["task4", "task5", "task6"])
    assert len(found) == 3, "Should return all three tasks"

    # Turn the list into a dict for easy lookup by ID
    found_dict = {t.id: t for t in found}

    # -----------------------------
    # 6) Verify Task 4 data
    # -----------------------------
    t4 = found_dict["task4"]
    # Parameters must be present; checked explicitly rather than through a
    # conditional expression inside the assert.
    assert t4.parameters, "Task 4 should have parameters"
    assert t4.parameters["foo"] == "bar4"
    assert t4.parameters["alpha"] == 123
    assert len(t4.reviews) == 2, "Task 4 should have 2 reviews"
    reviewers_t4 = {r.reviewer for r in t4.reviews}
    assert reviewers_t4 == {"alice", "bob"}
    assert len(t4.review_requirements) == 1, "Task 4 should have 1 review requirement"
    assert t4.review_requirements[0].users == ["alice", "bob"]

    # -----------------------------
    # 7) Verify Task 5 data
    # -----------------------------
    t5 = found_dict["task5"]
    assert t5.parameters, "Task 5 should have parameters"
    assert t5.parameters["foo"] == "bar5"
    assert t5.parameters["beta"] == 999
    assert len(t5.reviews) == 1, "Task 5 should have 1 review"
    assert t5.reviews[0].reviewer == "charlie"
    assert len(t5.review_requirements) == 1, "Task 5 should have 1 review requirement"
    req5 = t5.review_requirements[0]
    assert req5.number_required == 2
    assert req5.users == ["charlie", "someone_else"]

    # -----------------------------
    # 8) Verify Task 6 data
    # -----------------------------
    t6 = found_dict["task6"]
    assert t6.parameters, "Task 6 should have parameters"
    assert t6.parameters["foo"] == "bar6"
    assert t6.parameters["gamma"] == 42
    assert len(t6.reviews) == 1, "Task 6 should have 1 review"
    assert t6.reviews[0].reviewer == "david"
    assert len(t6.review_requirements) == 1, "Task 6 should have 1 review requirement"
    req6 = t6.review_requirements[0]
    assert req6.agents == ["agent_1"]
    assert req6.number_required == 1

    print("test_find_many_lite_with_reviews_and_reqs passed!")
def test_task_template_creation():
    """Every constructor argument is readable back from the template."""
    tpl = TaskTemplate(
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
        device=V1Device(type="device1"),
        device_type=V1DeviceType(name="device_type1"),
        parameters={"param1": "value1"},
        labels={"label1": "value1"},
        tags=["tag1", "tag2"],
    )

    assert tpl.description == "Test Task"
    assert tpl.max_steps == 10
    assert tpl.owner_id == "owner@example.com"
    assert tpl.device.type == "device1"  # type: ignore
    assert tpl.device_type.name == "device_type1"  # type: ignore
    assert tpl.parameters == {"param1": "value1"}
    assert tpl.labels == {"label1": "value1"}
    assert tpl.tags == ["tag1", "tag2"]


def test_task_template_to_task():
    """A template converts to a Task with the same description, steps and owner."""
    produced = TaskTemplate(
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
    ).to_task()

    assert produced.description == "Test Task"
    assert produced.max_steps == 10
    assert produced.owner_id == "owner@example.com"


def test_task_template_from_task():
    """A Task converts to a template with the same description, steps and owner."""
    source_task = Task(
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
    )

    tpl = TaskTemplate.from_task(source_task)

    assert tpl.description == "Test Task"
    assert tpl.max_steps == 10
    assert tpl.owner_id == "owner@example.com"


def test_task_template_to_v1():
    """The V1 model carries the template's description, steps and owner."""
    as_v1 = TaskTemplate(
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
    ).to_v1()

    assert as_v1.description == "Test Task"
    assert as_v1.max_steps == 10
    assert as_v1.owner_id == "owner@example.com"


def test_task_template_from_v1():
    """A template loaded from V1 carries the model's description, steps and owner."""
    source = V1TaskTemplate(
        id="task1",
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
        created=0.0,
    )

    tpl = TaskTemplate.from_v1(source)

    assert tpl.description == "Test Task"
    assert tpl.max_steps == 10
    assert tpl.owner_id == "owner@example.com"