├── .flake8
├── .github
└── workflows
│ ├── poetry-lint.yml
│ ├── poetry-tests.yml
│ └── publish.yml
├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── Dockerfile
├── Dockerfile.prod
├── LICENSE
├── Makefile
├── README.md
├── cloudbuild.yaml
├── deploy
└── base
│ ├── cert.yaml
│ ├── deployment_api.yaml
│ ├── ingress.yaml
│ ├── kustomization.yaml
│ ├── pg.yaml
│ └── service.yaml
├── logging.conf
├── poetry.lock
├── pyproject.toml
├── pytest.ini
├── scripts
└── lint.py
├── taskara
├── __init__.py
├── agent.py
├── auth
│ ├── default.py
│ ├── key.py
│ ├── provider.py
│ └── transport.py
├── benchmark.py
├── cli
│ └── main.py
├── config.py
├── db
│ ├── conn.py
│ ├── models.py
│ └── redis_connection.py
├── env.py
├── flag.py
├── img.py
├── metrics.py
├── review.py
├── runtime
│ ├── base.py
│ ├── docker.py
│ ├── kube.py
│ ├── load.py
│ └── process.py
├── server
│ ├── app.py
│ ├── models.py
│ └── router
│ │ ├── benchmarks.py
│ │ └── tasks.py
├── task.py
└── util.py
└── tests
├── benchmarks
└── airbnb.yaml
├── test_bench.py
├── test_eval.py
├── test_runtime.py
├── test_task.py
└── test_tpl.py
/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 120
3 | ignore = E203, E266, E501, W503
4 | max-complexity = 18
5 | select = B,C,E,F,W,T4,B9
6 | exclude =
7 | .git,
8 | __pycache__,
9 | build,
10 | dist,
11 | .venv,
12 | .tox,
13 | .mypy_cache,
14 | .pytest_cache,
15 | .vscode,
16 |
--------------------------------------------------------------------------------
/.github/workflows/poetry-lint.yml:
--------------------------------------------------------------------------------
1 | name: Poetry Lint
2 |
3 | on:
4 | push:
5 | branches: [ '**' ]
6 | pull_request:
7 | branches: [ '**' ]
8 |
9 | jobs:
10 | lint:
11 |
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - uses: actions/checkout@v2
16 |
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.10'
21 |
22 | - name: Install Poetry
23 | uses: snok/install-poetry@v1
24 |
25 | - name: Install dependencies
26 | run: |
27 | poetry install
28 |
29 | - name: Run lint
30 | run: |
31 | poetry run lint
32 |
33 |
--------------------------------------------------------------------------------
/.github/workflows/poetry-tests.yml:
--------------------------------------------------------------------------------
1 | name: Poetry Tests
2 |
3 | on:
4 | push:
5 | branches: ["**"]
6 | pull_request:
7 | branches: ["**"]
8 |
9 | jobs:
10 | test:
11 | runs-on: ubuntu-latest
12 |
13 | steps:
14 | - uses: actions/checkout@v2
15 |
16 | - name: Set up Python
17 | uses: actions/setup-python@v2
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Install Poetry
22 | uses: snok/install-poetry@v1
23 |
24 | - name: Install dependencies
25 | run: |
26 | poetry install
27 |
28 | - name: Run tests
29 | run: poetry run pytest --ignore=tests/test_runtime.py # TODO: this doesn't run well in CI
30 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python package to PyPI
2 |
3 | on:
4 | push:
5 | tags:
6 | - 'v*'
7 |
8 | jobs:
9 | build-and-publish:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - uses: actions/checkout@v2
13 | - name: Set up Python
14 | uses: actions/setup-python@v2
15 | with:
16 | python-version: '3.10'
17 | - name: Install dependencies
18 | run: |
19 | python -m pip install --upgrade pip
20 | pip install setuptools wheel twine poetry
21 | - name: Build and publish to PyPI
22 | env:
23 | PYPI_USERNAME: __token__
24 | PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
25 | run: |
26 | poetry build
27 | poetry publish -u $PYPI_USERNAME -p $PYPI_PASSWORD
28 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | .envrc
131 | .vscode/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 | .data
164 | data
165 | .agentsea
166 | scratch.ipynb
167 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as contributors and maintainers pledge to make participation in our project and our community a harassment-free experience for everyone, regardless of age, body size, disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation.
6 |
7 | ## Our Standards
8 |
9 | Examples of behavior that contributes to creating a positive environment include:
10 |
11 | - Using welcoming and inclusive language
12 | - Being respectful of differing viewpoints and experiences
13 | - Gracefully accepting constructive criticism
14 | - Focusing on what is best for the community
15 | - Showing empathy towards other community members
16 |
17 | Examples of unacceptable behavior by participants include:
18 |
19 | - The use of sexualized language or imagery and unwelcome sexual attention or advances
20 | - Trolling, insulting/derogatory comments, and personal or political attacks
21 | - Public or private harassment
22 | - Publishing others' private information, such as a physical or email address, without explicit permission
23 | - Other conduct which could reasonably be considered inappropriate in a professional setting
24 |
25 | ## Our Responsibilities
26 |
27 | Project maintainers are responsible for clarifying the standards of acceptable behavior and are expected to take appropriate and fair corrective action in response to any instances of unacceptable behavior.
28 |
29 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned with this Code of Conduct, or to ban temporarily or permanently any contributor for other behaviors that they deem inappropriate, threatening, offensive, or harmful.
30 |
31 | ## Scope
32 |
33 | This Code of Conduct applies within all project spaces, including GitHub, and in public spaces when an individual is representing the project or its community. Examples of representing a project or community include using an official project e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
34 |
35 | ## Enforcement
36 |
37 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by contacting the project team at github@kentauros.ai. All complaints will be reviewed and investigated and will result in a response that is deemed necessary and appropriate to the circumstances. The project team is obligated to maintain confidentiality regarding the reporter of an incident. Further details of specific enforcement policies may be posted separately.
38 |
39 | Project maintainers who do not follow or enforce the Code of Conduct in good faith may face temporary or permanent repercussions as determined by other members of the project's leadership.
40 |
41 | ## Attribution
42 |
43 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
44 |
45 | Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity).
46 |
47 | For answers to common questions about this code of conduct, see https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations.
48 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | First off, thank you for considering contributing to this project. It's people like you that make it such a great tool.
4 |
5 | ## Code of Conduct
6 |
7 | This project adheres to a Code of Conduct that we expect project participants to adhere to. Please read [the full text](CODE_OF_CONDUCT.md) so that you can understand what actions will and will not be tolerated.
8 |
9 | ## What we are looking for
10 |
11 | This is an open-source project, and we welcome contributions of all kinds: new features, bug fixes, documentation, examples, or enhancements to existing features. We are always thrilled to receive contributions from the community.
12 |
13 | ## How to contribute
14 |
15 | If you've never contributed to an open-source project before, here are a few steps to get you started:
16 |
17 | ### Reporting Issues
18 |
19 | Before submitting a bug report or feature request, check to make sure it hasn't already been submitted. You can search through existing issues and pull requests to see if someone has reported one similar to yours.
20 |
21 | When you are creating a bug report, please include as much detail as possible.
22 |
23 | ### Pull Requests
24 |
25 | - Fork the repository and create your branch from `main`.
26 | - If you've added code that should be tested, add tests.
27 | - If you've changed APIs, update the documentation.
28 | - Ensure the test suite passes.
29 | - Make sure your code lints.
30 | - Issue that pull request!
31 |
32 | ### Getting started
33 |
34 | For something that is bigger than a one or two-line fix:
35 |
36 | 1. Create your own fork of the code.
37 | 2. Do the changes in your fork.
38 | 3. If you like the change and think the project could use it:
39 | - Be sure you have followed the code style for the project.
40 | - Note the Code of Conduct.
41 | - Send a pull request.
42 |
43 | ## How to report a bug
44 |
45 | If you find a security vulnerability, do NOT open an issue. Email github@kentauros.ai instead.
46 |
47 | In order to help us understand and resolve your issue quickly, please include as much information as possible, including:
48 |
49 | - A quick summary and/or background
50 | - Steps to reproduce
51 | - Be specific!
52 | - Give a sample code if you can.
53 | - What you expected would happen
54 | - What actually happens
55 | - Notes (possibly including why you think this might be happening or stuff you tried that didn't work)
56 |
57 | People *love* thorough bug reports. I'm not even kidding.
58 |
59 | ## How to suggest a feature or enhancement
60 |
61 | If you find yourself wishing for a feature that doesn't exist in the project, you are probably not alone. There are bound to be others out there with similar needs. Open an issue on our issues list on GitHub, which describes the feature you would like to see, why you need it, and how it should work.
62 |
63 | ## Code review process
64 |
65 | The core team looks at Pull Requests on a regular basis in a bi-weekly triage meeting. After feedback has been given, we expect responses within two weeks. After two weeks, we may close the pull request if it isn't showing any activity.
66 |
67 | ## Community
68 |
69 | Discussions about the project take place in this repository's Issues and Pull Requests sections. Anybody is welcome to join these conversations.
70 |
71 | Wherever possible, we use GitHub to discuss changes and keep the decision-making process open.
72 |
73 | ## Thank you!
74 |
75 | Thank you for contributing!
76 |
77 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM python:3.10-slim-buster
2 |
3 | RUN apt-get update && apt-get install -y openssh-client ntp gcc python3-dev
4 |
5 | RUN pip install poetry
6 |
7 | COPY . /app
8 | WORKDIR /app
9 |
10 | RUN poetry install
11 |
12 | EXPOSE 9070
13 | CMD ["poetry", "run", "python", "-m", "taskara.server.app"]
--------------------------------------------------------------------------------
/Dockerfile.prod:
--------------------------------------------------------------------------------
1 | FROM thehale/python-poetry:1.8.2-py3.10-slim
2 |
3 | COPY . /app
4 | WORKDIR /app
5 |
6 | RUN apt-get update && apt-get install -y openssh-client ntp
7 | RUN poetry install
8 |
9 | EXPOSE 8080
10 |
11 | CMD ["poetry", "run", "gunicorn", "-k", "uvicorn.workers.UvicornWorker", "taskara.server.app:app", "--workers=4", "--bind", "0.0.0.0:8080", "--log-level", "debug", "--log-config", "logging.conf", "--timeout", "240"]
12 |
13 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Kentauros AI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test
2 | test:
3 | rm -rf .agentsea
4 | poetry run pytest -v -s
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
7 |
8 |
Taskara
9 |
10 |
11 | Task management for AI agents
12 |
13 | Explore the docs »
14 |
15 |
16 | View Demo
17 | ·
18 | Report Bug
19 | ·
20 | Request Feature
21 |
22 |
23 |
24 |
25 | ## Installation
26 |
27 | ```sh
28 | pip install taskara
29 | ```
30 |
31 | ## Usage
32 |
33 | Create a task
34 |
35 | ```python
36 | from taskara import Task
37 |
38 | task = Task(
39 | description="Search for the most common varieties of french ducks",
40 | owner_id="delores@agentsea.ai"
41 | )
42 | ```
43 |
44 | Assign the task to an agent
45 |
46 | ```python
47 | task.assigned_to = "roko@agentsea.ai"
48 | ```
49 |
50 | Post a message to the task thread
51 |
52 | ```python
53 | task.post_message("assistant", "Getting started working on this")
54 | task.status = "in progress"
55 | ```
56 |
57 | Create a custom thread for the task
58 |
59 | ```python
60 | task.create_thread("debug")
61 | task.post_message("assistant", "I'll post debug messages to this thread", thread="debug")
62 | task.post_message("assistant", 'My current screenshot', images=["b64img"], thread="debug")
63 | ```
64 |
65 | Store prompts used to accomplish the task
66 |
67 | ```python
68 | from mllm import RoleThread, RoleMessage
69 |
70 | thread = RoleThread()
71 | thread.post(role="system", msg="I am a helpful assistant")
72 |
73 | response = RoleMessage(
74 | role="assistant",
75 | text="How can I help?"
76 | )
77 | task.store_prompt(thread, response, namespace="actions")
78 | ```
79 |
80 | Store the result
81 |
82 | ```python
83 | task.output = "The most common type of french duck is the Rouen"
84 | task.status = "success"
85 | ```
86 |
87 | Save the task
88 |
89 | ```python
90 | task.save()
91 | ```
92 |
93 | ## Tracker
94 |
95 | Taskara comes with a task tracker server which can be run on docker or kubernetes.
96 |
97 | Install surfkit to create a tracker
98 |
99 | ```
100 | pip install surfkit
101 | ```
102 |
103 | Create a tracker
104 |
105 | ```
106 | surfkit create tracker
107 | ```
108 |
109 | List trackers
110 |
111 | ```
112 | surfkit list trackers
113 | ```
114 |
115 | Get tracker logs
116 |
117 | ```
118 | surfkit logs tracker
119 | ```
120 |
121 | Create a task
122 |
123 | ```
124 | surfkit create task --description "Search for french ducks"
125 | ```
126 |
127 | List tasks
128 |
129 | ```
130 | surfkit list tasks
131 | ```
132 |
133 | Get a task
134 |
135 | ```
136 | surfkit get task
137 | ```
138 |
139 | ## Integrations
140 |
141 | Taskara is integrated with:
142 |
143 | - [Surfkit](https://github.com/agentsea/surfkit) A platform for AI agents
144 | - [MLLM](https://github.com/agentsea/mllm) A prompt management, routing, and schema validation library for multimodal LLMs
145 | - [Skillpacks](https://github.com/agentsea/skillpacks) A library to fine tune AI agents on tasks.
146 | - [Threadmem](https://github.com/agentsea/threadmem) A thread management library for AI agents
147 |
148 | ## Community
149 |
150 | Come join us on [Discord](https://discord.gg/hhaq7XYPS6).
151 |
152 | ## Backends
153 |
154 | Thread and prompt storage can be backed by:
155 |
156 | - Sqlite
157 | - Postgresql
158 |
159 | Sqlite will be used by default. To use postgres simply configure the env vars:
160 |
161 | ```sh
162 | DB_TYPE=postgres
163 | DB_NAME=tasks
164 | DB_HOST=localhost
165 | DB_USER=postgres
166 | DB_PASS=abc123
167 | ```
168 |
169 | By default, thread images are stored in the database. To store them in a GCS bucket instead:
170 |
171 | - Create a bucket with fine grained permissions
172 | - Create a GCP service account JSON with permissions to write to the bucket
173 |
174 | ```sh
175 | export THREAD_STORAGE_SA_JSON='{
176 | "type": "service_account",
177 | ...
178 | }'
179 | export THREAD_STORAGE_BUCKET=my-bucket
180 | ```
181 |
--------------------------------------------------------------------------------
/cloudbuild.yaml:
--------------------------------------------------------------------------------
1 | steps:
2 | - name: "gcr.io/cloud-builders/docker"
3 | entrypoint: "bash"
4 | args:
5 | - "-c"
6 | - |
7 | docker buildx create --name mybuilder --use
8 | docker buildx build \
9 | --platform linux/amd64,linux/arm64 \
10 | -t us-central1-docker.pkg.dev/agentsea-dev/taskara/api:$SHORT_SHA . \
11 | --push \
12 | --cache-from=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache \
13 | --cache-to=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache,mode=max
14 | if [ "$BRANCH_NAME" == "main" ]; then
15 | docker buildx build \
16 | --platform linux/amd64,linux/arm64 \
17 | -t us-central1-docker.pkg.dev/agentsea-dev/taskara/api:latest . \
18 | --push \
19 | --cache-from=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache \
20 | --cache-to=type=registry,ref=us-central1-docker.pkg.dev/agentsea-dev/taskara/api:cache,mode=max
21 | fi
22 |
--------------------------------------------------------------------------------
/deploy/base/cert.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: networking.gke.io/v1
2 | kind: ManagedCertificate
3 | metadata:
4 | name: tasks-cert
5 | spec:
6 | domains:
7 | - api.tasks.my.domain
8 |
--------------------------------------------------------------------------------
/deploy/base/deployment_api.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: tasks-api
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | app: tasks-api
10 | template:
11 | metadata:
12 | labels:
13 | app: tasks-api
14 | spec:
15 | containers:
16 | - name: tasks-api
17 | image: us-central1-docker.pkg.dev/agentsea-dev/taskara/api:latest
18 | imagePullPolicy: Always
19 | ports:
20 | - containerPort: 8080
21 | livenessProbe:
22 | httpGet:
23 | path: /
24 | port: 8080
25 | initialDelaySeconds: 20
26 | periodSeconds: 25
27 | readinessProbe:
28 | httpGet:
29 | path: /
30 | port: 8080
31 | initialDelaySeconds: 10
32 | periodSeconds: 15
33 | env:
34 | - name: AGENTSEA_HUB_URL
35 | value: https://api.hub.dev.agentlabs.xyz
36 | - name: DB_USER
37 | value: postgres
38 | - name: DB_PASS
39 | value: "abc12345"
40 | - name: DB_HOST
41 | value: postgres-tasks-service
42 | - name: DB_NAME
43 | value: tasks
44 | - name: DB_TYPE
45 | value: postgres
46 |
--------------------------------------------------------------------------------
/deploy/base/ingress.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: networking.k8s.io/v1
2 | kind: Ingress
3 | metadata:
4 | name: tasks-ingress
5 | annotations:
6 | kubernetes.io/ingress.class: "gce"
7 | networking.gke.io/managed-certificates: tasks-cert
8 | kubernetes.io/ingress.global-static-ip-name: tasks-develop
9 | spec:
10 | rules:
11 | - host: api.tasks.my.domain
12 | http:
13 | paths:
14 | - path: /
15 | pathType: Prefix
16 | backend:
17 | service:
18 | name: tasks-api-service
19 | port:
20 | number: 80
21 |
--------------------------------------------------------------------------------
/deploy/base/kustomization.yaml:
--------------------------------------------------------------------------------
1 | resources:
2 | - cert.yaml
3 | - deployment_api.yaml
4 | - ingress.yaml
5 | - pg.yaml
6 | - service.yaml
7 |
--------------------------------------------------------------------------------
/deploy/base/pg.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: apps/v1
2 | kind: Deployment
3 | metadata:
4 | name: postgres-tasks
5 | spec:
6 | replicas: 1
7 | selector:
8 | matchLabels:
9 | app: postgres-tasks
10 | template:
11 | metadata:
12 | labels:
13 | app: postgres-tasks
14 | spec:
15 | containers:
16 | - name: postgres-tasks
17 | image: postgres:16.2
18 | ports:
19 | - containerPort: 5432
20 | env:
21 | - name: POSTGRES_USER
22 | value: "postgres"
23 | - name: POSTGRES_PASSWORD
24 | value: "abc12345"
25 | - name: POSTGRES_DB
26 | value: "tasks"
27 | - name: PGDATA
28 | value: "/var/lib/postgresql/data/pgdata"
29 | volumeMounts:
30 | - mountPath: /var/lib/postgresql/data
31 | name: postgres-storage
32 | volumes:
33 | - name: postgres-storage
34 | persistentVolumeClaim:
35 | claimName: postgres-tasks-pvc
36 | ---
37 | apiVersion: v1
38 | kind: Service
39 | metadata:
40 | name: postgres-tasks-service
41 | spec:
42 | type: ClusterIP
43 | selector:
44 | app: postgres-tasks
45 | ports:
46 | - port: 5432
47 | targetPort: 5432
48 | ---
49 | apiVersion: v1
50 | kind: PersistentVolumeClaim
51 | metadata:
52 | name: postgres-tasks-pvc
53 | spec:
54 | accessModes:
55 | - ReadWriteOnce
56 | resources:
57 | requests:
58 | storage: 1Gi
59 |
--------------------------------------------------------------------------------
/deploy/base/service.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: Service
3 | metadata:
4 | name: tasks-api-service
5 | annotations:
6 | beta.cloud.google.com/backend-config: '{"default": "tasks-api-config"}'
7 | cloud.google.com/neg: '{"ingress": true}'
8 | spec:
9 | type: NodePort
10 | selector:
11 | app: tasks-api
12 | ports:
13 | - protocol: TCP
14 | port: 80
15 | targetPort: 8080
16 | ---
17 | apiVersion: cloud.google.com/v1
18 | kind: BackendConfig
19 | metadata:
20 | name: tasks-api-config
21 | spec:
22 | timeoutSec: 21600
23 | healthCheck:
24 | checkIntervalSec: 30
25 | timeoutSec: 5
26 | healthyThreshold: 1
27 | unhealthyThreshold: 2
28 | requestPath: /
29 | port: 8080
30 |
--------------------------------------------------------------------------------
/logging.conf:
--------------------------------------------------------------------------------
1 | [loggers]
2 | keys=root,uvicorn,uvicorn.error,uvicorn.access,uvicorn.asgi
3 |
4 | [handlers]
5 | keys=console
6 |
7 | [formatters]
8 | keys=generic
9 |
10 | [logger_root]
11 | level=DEBUG
12 | handlers=console
13 |
14 | [logger_uvicorn]
15 | level=DEBUG
16 | handlers=console
17 | qualname=uvicorn
18 |
19 | [logger_uvicorn.error]
20 | level=DEBUG
21 | handlers=console
22 | qualname=uvicorn.error
23 | propagate=0
24 |
25 | [logger_uvicorn.asgi]
26 | level=DEBUG
27 | handlers=console
28 | qualname=uvicorn.asgi
29 | propagate=0
30 |
31 | [logger_uvicorn.access]
32 | level=DEBUG
33 | handlers=console
34 | qualname=uvicorn.access
35 | propagate=0
36 |
37 | [handler_console]
38 | class=StreamHandler
39 | level=DEBUG
40 | formatter=generic
41 |
42 | [formatter_generic]
43 | format=%(asctime)s [%(process)d] [%(levelname)s] %(message)s
44 | datefmt=%Y-%m-%d %H:%M:%S
45 | class=logging.Formatter
46 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "taskara"
3 | version = "0.1.246"
4 | description = "Task management for AI agents"
5 | authors = ["Patrick Barker "]
6 | license = "MIT"
7 | readme = "README.md"
8 |
9 | [tool.poetry.dependencies]
10 | python = "^3.10"
11 | sqlalchemy = "^2.0.29"
12 | pydantic = "^2.6.4"
13 | docker = {version = "^7.0.0", optional = true}
14 | kubernetes = {version = "^29.0.0", optional = true}
15 | google-auth = {version = "^2.29.0", optional = true}
16 | google-cloud-container = {version = "^2.45.0", optional = true}
17 | namesgenerator = "^0.3"
18 | typer = {version = "^0.12.3", optional = true}
19 | tabulate = {version = "^0.9.0", optional = true}
20 | shortuuid = "^1.0.13"
21 | tqdm = "^4.66.4"
22 | cryptography = "^43.0.1"
23 | redis = "^5.2.0"
24 | agentcore = "^0.1.3"
25 | skillpacks = "^0.1.128"
26 | openmeter = "^1.0.0b188"
27 |
28 | [tool.poetry.group.dev.dependencies]
29 | pytest = "^8.1.1"
30 | flake8 = "^7.0.0"
31 | black = "^24.2.0"
32 | pytest-env = "^1.1.3"
33 | ipykernel = "^6.29.4"
34 | ruff = "^0.6.5"
35 |
36 | [tool.pyright]
37 | reportUnknownParameterType = false
38 | reportMissingTypeArgument = false
39 | reportUnknownMemberType = false
40 | reportUnknownVariableType = false
41 | reportUnknownArgumentType = false
42 |
43 |
44 | [tool.poetry.extras]
45 | runtime = ["kubernetes", "docker", "google-auth", "google-cloud-container"]
46 | cli = ["typer", "tabulate"]
47 | all = ["kubernetes", "docker", "google-auth", "google-cloud-container", "typer", "tabulate"]
48 |
49 | [build-system]
50 | requires = ["poetry-core"]
51 | build-backend = "poetry.core.masonry.api"
52 |
53 | [tool.poetry.scripts]
54 | lint = "scripts.lint:main"
55 |
--------------------------------------------------------------------------------
/pytest.ini:
--------------------------------------------------------------------------------
1 | [pytest]
2 | pythonpath = .
3 |
4 | env =
5 | D:DB_TYPE=sqlite
6 | D:AGENTSEA_DB_TEST=true
7 | D:AGENTSEA_HOME=./.agentsea
8 | D:AGENTSEA_DB_DIR=./.agentsea/data/test
9 |
--------------------------------------------------------------------------------
/scripts/lint.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 |
3 |
4 | def main():
5 | subprocess.run(["black", "."])
6 | subprocess.run(["flake8", "."])
7 |
--------------------------------------------------------------------------------
/taskara/__init__.py:
--------------------------------------------------------------------------------
1 | from .benchmark import (
2 | Benchmark,
3 | Eval,
4 | TaskTemplate,
5 | V1Benchmark,
6 | V1Eval,
7 | V1TaskTemplate,
8 | )
9 | from .task import Task, TaskClient, TaskStatus, V1Task, V1Tasks, ReviewRequirement
10 |
11 | __all__ = [
12 | "Task",
13 | "V1Task",
14 | "V1Tasks",
15 | "TaskStatus",
16 | "Benchmark",
17 | "V1Benchmark",
18 | "TaskTemplate",
19 | "V1TaskTemplate",
20 | "Eval",
21 | "V1Eval",
22 | "TaskClient",
23 | "ReviewRequirement",
24 | ]
25 |
--------------------------------------------------------------------------------
/taskara/agent.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import List, TypeVar, Type, Generic
3 |
4 | from pydantic import BaseModel
5 |
6 | from devicebay import Device
7 | from taskara import Task
8 |
9 | C = TypeVar("C", bound="BaseModel")
10 | T = TypeVar("T", bound="TaskAgent")
11 |
12 |
13 | class TaskAgent(Generic[C, T], ABC):
14 | """An agent that works on tasks"""
15 |
16 | @abstractmethod
17 | def solve_task(
18 | self,
19 | task: Task,
20 | device: Device,
21 | max_steps: int = 30,
22 | ) -> Task:
23 | """Solve a task on a device
24 |
25 | Args:
26 | task (Task): The task
27 | device (Device): Device to perform the task on.
28 | max_steps (int, optional): Max steps allowed. Defaults to 30.
29 |
30 | Returns:
31 | Task: A task
32 | """
33 | pass
34 |
35 | @classmethod
36 | @abstractmethod
37 | def supported_devices(cls) -> List[Type[Device]]:
38 | """Devices this agent supports
39 |
40 | Returns:
41 | List[Type[Device]]: A list of supported devices
42 | """
43 | pass
44 |
45 | @classmethod
46 | @abstractmethod
47 | def config_type(cls) -> Type[C]:
48 | """Type to configure the agent
49 |
50 | Returns:
51 | Type[C]: A configuration type
52 | """
53 | pass
54 |
55 | @classmethod
56 | @abstractmethod
57 | def from_config(cls, config: C) -> T:
58 | """Create an agent from a config
59 |
60 | Args:
61 | config (C): Config to create the agent from
62 |
63 | Returns:
64 | T: The Agent
65 | """
66 | pass
67 |
68 | @classmethod
69 | @abstractmethod
70 | def default(cls) -> T:
71 | """Create a default agent with no params
72 |
73 | Returns:
74 | T: The Agent
75 | """
76 | pass
77 |
78 | @classmethod
79 | def init(cls) -> None:
80 | """Initialize the Agent type"""
81 | pass
82 |
--------------------------------------------------------------------------------
/taskara/auth/default.py:
--------------------------------------------------------------------------------
1 | COMMON_USER = "common"
2 |
--------------------------------------------------------------------------------
/taskara/auth/key.py:
--------------------------------------------------------------------------------
1 | import os
2 | from abc import ABC, abstractmethod
3 | from typing import Optional
4 |
5 | import requests
6 | from agentcore.models import V1UserProfile
7 | from requests.exceptions import RequestException
8 | from threadmem.db.conn import WithDB
9 |
10 | from taskara.config import AGENTSEA_AUTH_URL
11 |
12 |
13 | class KeyProvider(ABC):
14 | """API key provider"""
15 |
16 | @abstractmethod
17 | def create_key(self) -> str:
18 | pass
19 |
20 | @abstractmethod
21 | def is_key(self, token: str) -> bool:
22 | pass
23 |
24 | @abstractmethod
25 | def validate(self, token: str) -> V1UserProfile:
26 | pass
27 |
28 |
29 | class MockProvider(KeyProvider):
30 | """Mock key provider"""
31 |
32 | _key = "k.mock"
33 |
34 | def create_key(self) -> str:
35 | return self._key
36 |
37 | def is_key(self, token: str) -> bool:
38 | if token.startswith("k."):
39 | return True
40 | return False
41 |
42 | def validate(self, token: str) -> V1UserProfile:
43 | if self._key == token:
44 | return V1UserProfile(
45 | email="tom@myspace.com",
46 | display_name="tom",
47 | picture="https://i.insider.com/4efd9b8b69bedd682c000022?width=750&format=jpeg&auto=webp",
48 | )
49 | raise ValueError("Invalid token")
50 |
51 |
52 | class HubKeyProvider(KeyProvider, WithDB):
53 | """AgentSea Hub provider"""
54 |
55 | def __init__(self) -> None:
56 | self.hub_url = AGENTSEA_AUTH_URL
57 |
58 | def create_key(self) -> str:
59 | raise NotImplementedError("create_key is not implemented")
60 |
61 | def is_key(self, token: str) -> bool:
62 | if token.startswith("k."):
63 | return True
64 | return False
65 |
66 | def validate(self, token: str) -> V1UserProfile:
67 | headers = {"Authorization": f"Bearer {token}"}
68 | try:
69 | response = requests.get(f"{self.hub_url}/v1/users/me", headers=headers)
70 | response.raise_for_status() # Raise an HTTPError if one occurred.
71 |
72 | user_data = response.json()
73 | # print("key response user data: ", user_data)
74 | prof = V1UserProfile(**user_data)
75 | return prof
76 |
77 | except RequestException as e:
78 | raise ValueError(f"Failed to validate token. Error: {e}")
79 |
80 |
81 | def get_key() -> Optional[str]:
82 | return os.environ.get("AGENTSEA_KEY")
83 |
84 |
85 | def ensure_key() -> str:
86 | key = get_key()
87 | if not key:
88 | raise ValueError("$AGENTSEA_KEY must be provided to use hub components")
89 | return key
90 |
91 |
92 | def default_key_provider() -> KeyProvider:
93 | return HubKeyProvider()
94 |
--------------------------------------------------------------------------------
/taskara/auth/provider.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from abc import ABC, abstractmethod
4 | from typing import Optional
5 | from venv import logger
6 |
7 | import requests
8 | from agentcore.models import V1UserProfile
9 |
10 | from .key import KeyProvider, MockProvider, default_key_provider
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 |
15 | class AuthProvider(ABC):
16 | @abstractmethod
17 | def key_provider(self) -> KeyProvider:
18 | pass
19 |
20 | @abstractmethod
21 | def get_user_auth(self, token: str) -> V1UserProfile:
22 | pass
23 |
24 |
25 | class HubAuthProvider(AuthProvider):
26 | """Hub user auth"""
27 |
28 | _key_provider: KeyProvider
29 |
30 | def __init__(self, key_provider: Optional[KeyProvider] = None) -> None:
31 | if not key_provider:
32 | key_provider = default_key_provider()
33 | self.hub_url = os.environ.get("AGENTSEA_AUTH_URL")
34 | if not self.hub_url:
35 | raise ValueError(
36 | "$AGENTSEA_AUTH_URL must be set to user the Hub key provider"
37 | )
38 |
39 | self._key_provider = key_provider
40 |
41 | def key_provider(self) -> KeyProvider:
42 | return self._key_provider
43 |
44 | def get_user_auth(self, token: str) -> V1UserProfile:
45 | try:
46 | if self._key_provider.is_key(token):
47 | user = self._key_provider.validate(token)
48 | logger.debug(f"found user: {user}")
49 |
50 | return user
51 |
52 | else:
53 | headers = {"Authorization": f"Bearer {token}"}
54 | headers.update(
55 | {
56 | "User-Agent": "My User Agent 1.0",
57 | }
58 | )
59 | auth_url = f"{self.hub_url}/v1/users/me"
60 | logger.debug(f"authorizing token with: {auth_url}")
61 | response = requests.get(auth_url, headers=headers)
62 | response.raise_for_status()
63 |
64 | user_data = response.json()
65 | user_schema = V1UserProfile(**user_data)
66 | user_schema.token = token
67 | return user_schema
68 |
69 | except Exception as e:
70 | logging.error(f"Problem fetching user auth {e}")
71 | raise Exception(
72 | "ID token was unauthorized, please log in",
73 | )
74 |
75 |
76 | class MockAuthProvider(AuthProvider):
77 | """Mock user auth"""
78 |
79 | _key_provider: KeyProvider = MockProvider()
80 |
81 | def key_provider(self) -> KeyProvider:
82 | return self._key_provider
83 |
84 | def get_user_auth(self, token: str) -> V1UserProfile:
85 | try:
86 | if self._key_provider.is_key(token):
87 | user = self._key_provider.validate(token)
88 | return user
89 |
90 | else:
91 | return V1UserProfile(
92 | email="tom@myspace.com",
93 | display_name="tom",
94 | picture="https://i.insider.com/4efd9b8b69bedd682c000022?width=750&format=jpeg&auto=webp",
95 | )
96 |
97 | except Exception as e:
98 | logging.error(f"Problem fetching user auth {e}")
99 | raise Exception(
100 | "ID token was unauthorized, please log in",
101 | )
102 |
103 |
104 | def default_auth_provider() -> AuthProvider:
105 | return HubAuthProvider()
106 |
--------------------------------------------------------------------------------
/taskara/auth/transport.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from typing import Annotated
4 |
5 | from agentcore.models import V1UserProfile
6 | from fastapi import Depends, HTTPException
7 | from fastapi.security import OAuth2PasswordBearer
8 |
9 | from .provider import default_auth_provider
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 | if os.getenv("TASK_SERVER_NO_AUTH", "false").lower() == "true":
14 | user_auth = None
15 | else:
16 | user_auth = default_auth_provider()
17 |
18 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
19 |
20 |
21 | async def get_current_user(
22 | token: Annotated[str, Depends(oauth2_scheme)],
23 | ) -> V1UserProfile:
24 | if not user_auth:
25 | raise SystemError("user auth is not configured")
26 | try:
27 | logger.debug(f"checking user token: {token}")
28 | user = user_auth.get_user_auth(token)
29 | except Exception as e:
30 | logging.error(e)
31 | raise HTTPException(
32 | status_code=401,
33 | detail=f"-ID token was unauthorized, please log in: {e}",
34 | )
35 |
36 | return user
37 |
38 |
39 | async def get_user_mock_auth() -> V1UserProfile:
40 | # Return a dummy user profile when authentication is disabled
41 | return V1UserProfile(
42 | email="tom@myspace.com",
43 | display_name="tom",
44 | picture="https://i.insider.com/4efd9b8b69bedd682c000022?width=750&format=jpeg&auto=webp",
45 | )
46 |
47 |
48 | def get_user_dependency():
49 | if os.getenv("TASK_SERVER_NO_AUTH", "false").lower() == "true":
50 | return get_user_mock_auth
51 | else:
52 | return get_current_user
53 |
--------------------------------------------------------------------------------
/taskara/benchmark.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | from typing import Any, Dict, List, Optional, Type
4 |
5 | import shortuuid
6 | from devicebay import V1Device, V1DeviceType
7 | from pydantic import BaseModel
8 |
9 | from taskara.db.conn import WithDB
10 | from taskara.db.models import (
11 | BenchmarkRecord,
12 | EvalRecord,
13 | TaskRecord,
14 | TaskTemplateRecord,
15 | benchmark_task_association,
16 | eval_task_association,
17 | )
18 | from taskara.server.models import V1Benchmark, V1Eval, V1TaskTemplate
19 | from taskara.task import Task
20 |
21 |
22 | class TaskTemplate(WithDB):
23 | """A task template"""
24 |
25 | def __init__(
26 | self,
27 | description: Optional[str] = None,
28 | max_steps: int = 30,
29 | owner_id: Optional[str] = None,
30 | device: Optional[V1Device] = None,
31 | device_type: Optional[V1DeviceType] = None,
32 | expect: Optional[Type[BaseModel]] = None,
33 | parameters: Dict[str, Any] = {},
34 | labels: Dict[str, str] = {},
35 | tags: List[str] = [],
36 | ) -> None:
37 | self._id = shortuuid.uuid()
38 | self._description = description
39 | self._max_steps = max_steps
40 | self._owner_id = owner_id
41 | self._device = device
42 | self._device_type = device_type
43 | self._expect_schema = expect.model_json_schema() if expect else None
44 | self._parameters = parameters
45 | self._labels = labels
46 | self._tags = tags
47 | self._created = time.time()
48 |
49 | @property
50 | def id(self) -> str:
51 | return self._id
52 |
53 | @property
54 | def description(self) -> Optional[str]:
55 | return self._description
56 |
57 | @property
58 | def max_steps(self) -> int:
59 | return self._max_steps
60 |
61 | @property
62 | def owner_id(self) -> Optional[str]:
63 | return self._owner_id
64 |
65 | @property
66 | def device(self) -> Optional[V1Device]:
67 | return self._device
68 |
69 | @property
70 | def device_type(self) -> Optional[V1DeviceType]:
71 | return self._device_type
72 |
73 | @property
74 | def expect_schema(self) -> Optional[Dict[str, Any]]:
75 | return self._expect_schema
76 |
77 | @property
78 | def parameters(self) -> Optional[Dict[str, Any]]:
79 | return self._parameters
80 |
81 | @property
82 | def labels(self) -> Dict[str, str]:
83 | return self._labels
84 |
85 | @property
86 | def tags(self) -> List[str]:
87 | return self._tags
88 |
89 | @property
90 | def created(self) -> float:
91 | return self._created
92 |
93 | def to_task(
94 | self,
95 | assigned_to: Optional[str] = None,
96 | assigned_type: Optional[str] = None,
97 | remote: Optional[str] = None,
98 | owner_id: Optional[str] = None,
99 | ) -> Task:
100 | task = Task(
101 | description=self.description,
102 | max_steps=self.max_steps,
103 | device=self.device,
104 | device_type=self.device_type,
105 | assigned_to=assigned_to,
106 | assigned_type=assigned_type,
107 | remote=remote,
108 | labels=self.labels,
109 | tags=self.tags,
110 | owner_id=owner_id if owner_id else self.owner_id,
111 | )
112 | task._expect_schema = self.expect_schema
113 | return task
114 |
115 | @classmethod
116 | def from_task(cls, task: Task) -> "TaskTemplate":
117 | tpl = cls(
118 | description=task.description,
119 | max_steps=task.max_steps,
120 | device=task.device,
121 | device_type=task.device_type,
122 | parameters=task.parameters if task.parameters else {},
123 | labels=task.labels,
124 | tags=task.tags,
125 | owner_id=task.owner_id,
126 | )
127 | tpl._expect_schema = task._expect_schema
128 | return tpl
129 |
130 | def to_record(self) -> TaskTemplateRecord:
131 |
132 | device = None
133 | if self._device:
134 | device = self._device.model_dump_json()
135 |
136 | device_type = None
137 | if self._device_type:
138 | device_type = self._device_type.model_dump_json()
139 |
140 | expect = None
141 | if self._expect_schema:
142 | expect = json.dumps(self._expect_schema)
143 |
144 | return TaskTemplateRecord(
145 | id=self._id,
146 | owner_id=self._owner_id,
147 | description=self._description,
148 | max_steps=self._max_steps,
149 | device=device,
150 | device_type=device_type,
151 | expect=expect,
152 | parameters=json.dumps(self._parameters),
153 | tags=json.dumps(self.tags),
154 | labels=json.dumps(self.labels),
155 | created=self._created,
156 | )
157 |
158 | @classmethod
159 | def from_record(cls, record: TaskTemplateRecord) -> "TaskTemplate":
160 |
161 | parameters = json.loads(str(record.parameters))
162 |
163 | device = None
164 | if record.device: # type: ignore
165 | device = V1Device.model_validate_json(str(record.device))
166 |
167 | device_type = None
168 | if record.device_type: # type: ignore
169 | device_type = V1DeviceType.model_validate_json(str(record.device_type))
170 |
171 | expect = None
172 | if record.expect: # type: ignore
173 | expect = json.loads(str(record.expect))
174 |
175 | obj = cls.__new__(cls)
176 | obj._id = record.id
177 | obj._owner_id = record.owner_id
178 | obj._description = record.description
179 | obj._max_steps = record.max_steps
180 | obj._device = device
181 | obj._device_type = device_type
182 | obj._expect_schema = expect
183 | obj._parameters = parameters
184 | obj._tags = json.loads(str(record.tags))
185 | obj._labels = json.loads(str(record.labels))
186 | obj._created = record.created
187 | return obj
188 |
189 | def to_v1(self) -> V1TaskTemplate:
190 | return V1TaskTemplate(
191 | id=self._id,
192 | description=self._description if self._description else "",
193 | max_steps=self._max_steps,
194 | device=self._device,
195 | device_type=self.device_type,
196 | expect_schema=self._expect_schema,
197 | parameters=self._parameters,
198 | owner_id=self._owner_id,
199 | tags=self._tags,
200 | labels=self._labels,
201 | created=self._created,
202 | )
203 |
204 | @classmethod
205 | def from_v1(
206 | cls, v1: V1TaskTemplate, owner_id: Optional[str] = None
207 | ) -> "TaskTemplate":
208 | obj = cls.__new__(cls)
209 |
210 | owner_id = owner_id if owner_id else v1.owner_id
211 | if not owner_id:
212 | raise ValueError("Owner id is required in v1 or as parameter")
213 |
214 | obj._id = v1.id if v1.id else shortuuid.uuid()
215 | obj._owner_id = owner_id
216 | obj._description = v1.description
217 | obj._max_steps = v1.max_steps
218 | obj._device = v1.device
219 | obj._device_type = v1.device_type
220 | obj._expect_schema = v1.expect_schema
221 | obj._parameters = v1.parameters
222 | obj._owner_id = owner_id
223 | obj._tags = v1.tags
224 | obj._labels = v1.labels
225 | obj._created = v1.created
226 |
227 | return obj
228 |
229 | def save(self) -> None:
230 | for db in self.get_db():
231 | db.merge(self.to_record())
232 | db.commit()
233 |
234 |
235 | class Benchmark(WithDB):
236 | """An agent benchmark"""
237 |
238 | def __init__(
239 | self,
240 | name: str,
241 | description: str,
242 | tasks: List[TaskTemplate],
243 | owner_id: Optional[str] = None,
244 | labels: Dict[str, str] = {},
245 | tags: List[str] = [],
246 | public: bool = False,
247 | ):
248 | self._tasks = tasks
249 | self._id = shortuuid.uuid()
250 | self._name = name
251 | self._description = description
252 | self._owner_id = owner_id
253 | self._labels = labels
254 | self._tags = tags
255 | self._public = public
256 | self._created = time.time()
257 |
258 | for task in tasks:
259 | task.labels["benchmark"] = self.name
260 |
261 | @property
262 | def name(self) -> str:
263 | return self._name
264 |
265 | @property
266 | def tasks(self) -> List[TaskTemplate]:
267 | return self._tasks
268 |
269 | @property
270 | def description(self) -> str:
271 | return self._description
272 |
273 | @property
274 | def id(self) -> str:
275 | return self._id
276 |
277 | @property
278 | def owner_id(self) -> Optional[str]:
279 | return self._owner_id
280 |
281 | @property
282 | def labels(self) -> Dict[str, str]:
283 | return self._labels
284 |
285 | @property
286 | def tags(self) -> List[str]:
287 | return self._tags
288 |
289 | @property
290 | def public(self) -> bool:
291 | return self._public
292 |
293 | def eval(
294 | self,
295 | assigned_to: str | None = None,
296 | assigned_type: str | None = None,
297 | remote: str | None = None,
298 | owner_id: str | None = None,
299 | ) -> "Eval":
300 | return Eval(
301 | benchmark=self,
302 | assigned_to=assigned_to,
303 | assigned_type=assigned_type,
304 | remote=remote,
305 | owner_id=owner_id,
306 | )
307 |
308 | def to_record(self) -> BenchmarkRecord:
309 | record = BenchmarkRecord(
310 | id=self._id,
311 | owner_id=self._owner_id,
312 | name=self._name,
313 | description=self._description,
314 | public=self._public,
315 | tags=json.dumps(self._tags),
316 | labels=json.dumps(self._labels),
317 | created=self._created,
318 | )
319 | return record
320 |
321 | @classmethod
322 | def from_record(cls, record: BenchmarkRecord, db_session) -> "Benchmark":
323 | # Retrieve task templates associated with the benchmark
324 | task_records = (
325 | db_session.query(TaskTemplateRecord)
326 | .join(
327 | benchmark_task_association,
328 | TaskTemplateRecord.id == benchmark_task_association.c.task_template_id,
329 | )
330 | .filter(benchmark_task_association.c.benchmark_id == record.id)
331 | .all()
332 | )
333 | tasks = [TaskTemplate.from_record(task_record) for task_record in task_records]
334 |
335 | obj = cls.__new__(cls)
336 | obj._id = record.id
337 | obj._owner_id = record.owner_id
338 | obj._name = record.name
339 | obj._description = record.description
340 | obj._labels = json.loads(str(record.labels))
341 | obj._tags = json.loads(str(record.tags))
342 | obj._created = record.created
343 | obj._tasks = tasks
344 | obj._public = record.public
345 | return obj
346 |
347 | def to_v1(self) -> V1Benchmark:
348 | return V1Benchmark(
349 | id=self._id,
350 | name=self._name,
351 | description=self._description,
352 | tasks=[task.to_v1() for task in self._tasks],
353 | owner_id=self._owner_id,
354 | tags=self._tags,
355 | labels=self._labels,
356 | created=self._created,
357 | public=self._public,
358 | )
359 |
360 | @classmethod
361 | def from_v1(cls, v1: V1Benchmark, owner_id: Optional[str] = None) -> "Benchmark":
362 | tasks = [
363 | TaskTemplate.from_v1(task, owner_id=owner_id if owner_id else v1.owner_id)
364 | for task in v1.tasks
365 | ]
366 | for task in tasks:
367 | task.save()
368 |
369 | obj = cls.__new__(cls)
370 | owner_id = owner_id if owner_id else v1.owner_id
371 | if not owner_id:
372 | raise ValueError("Owner id is required in v1 or as parameter")
373 |
374 | obj._id = v1.id if v1.id else shortuuid.uuid()
375 | obj._owner_id = owner_id
376 | obj._name = v1.name
377 | obj._description = v1.description
378 | obj._tasks = tasks
379 | obj._labels = v1.labels
380 | obj._tags = v1.tags
381 | obj._created = v1.created
382 | obj._public = v1.public
383 |
384 | return obj
385 |
386 | @classmethod
387 | def find(cls, remote: Optional[str] = None, **kwargs) -> List["Benchmark"]:
388 | for db in cls.get_db():
389 | records = (
390 | db.query(BenchmarkRecord)
391 | .filter_by(**kwargs)
392 | .order_by(BenchmarkRecord.created.desc())
393 | .all()
394 | )
395 | return [cls.from_record(record, db) for record in records]
396 | raise ValueError("No session")
397 |
398 | def save(self) -> None:
399 | for db in self.get_db():
400 | # Save the benchmark record
401 | benchmark_record = self.to_record()
402 | db.merge(benchmark_record)
403 | db.commit()
404 |
405 | # Save the task records and associations
406 | for task in self._tasks:
407 | task_record = task.to_record()
408 | db.merge(task_record)
409 | db.commit()
410 |
411 | association = benchmark_task_association.insert().values(
412 | benchmark_id=self._id, task_template_id=task.id
413 | )
414 | db.execute(association)
415 | db.commit()
416 |
417 | def delete(self) -> None:
418 | for db in self.get_db():
419 | # Delete the benchmark record
420 | benchmark_record = db.query(BenchmarkRecord).filter_by(id=self._id).first()
421 | if benchmark_record:
422 | db.delete(benchmark_record)
423 | db.commit()
424 |
425 | # Delete the task records and associations
426 | db.execute(
427 | benchmark_task_association.delete().where(
428 | benchmark_task_association.c.benchmark_id == self._id
429 | )
430 | )
431 | db.commit()
432 |
433 | for task in self._tasks:
434 | task_record = db.query(TaskTemplateRecord).filter_by(id=task.id).first()
435 | if task_record:
436 | db.delete(task_record)
437 | db.commit()
438 |
439 |
440 | class Eval(WithDB):
441 | """An agent evaluation on a benchmark"""
442 |
443 | def __init__(
444 | self,
445 | benchmark: Benchmark,
446 | assigned_to: Optional[str] = None,
447 | assigned_type: Optional[str] = None,
448 | remote: Optional[str] = None,
449 | owner_id: Optional[str] = None,
450 | ) -> None:
451 | self._id = shortuuid.uuid()
452 | self._benchmark = benchmark
453 | self._tasks: List[Task] = []
454 | self._owner_id = owner_id
455 | self._assigned_to = assigned_to
456 | self._assigned_type = assigned_type
457 |
458 | for tpl in self._benchmark.tasks:
459 | task = tpl.to_task(
460 | assigned_to=assigned_to,
461 | assigned_type=assigned_type,
462 | remote=remote,
463 | owner_id=owner_id,
464 | )
465 | task.labels["benchmark"] = self._benchmark.name
466 | self._tasks.append(task)
467 |
468 | @property
469 | def tasks(self) -> List[Task]:
470 | return self._tasks
471 |
472 | @property
473 | def benchmark(self) -> Benchmark:
474 | return self._benchmark
475 |
476 | @property
477 | def id(self) -> str:
478 | return self._id
479 |
480 | @property
481 | def owner_id(self) -> Optional[str]:
482 | return self._owner_id
483 |
484 | def to_record(self) -> EvalRecord:
485 | return EvalRecord(
486 | id=self._id,
487 | benchmark_id=self._benchmark.id,
488 | assigned_to=self._assigned_to,
489 | assigned_type=self._assigned_type,
490 | owner_id=self._owner_id,
491 | created=time.time(),
492 | )
493 |
494 | @classmethod
495 | def from_record(cls, record: EvalRecord, db_session) -> "Eval":
496 | benchmark = Benchmark.from_record(
497 | db_session.query(BenchmarkRecord).filter_by(id=record.benchmark_id).first(),
498 | db_session,
499 | )
500 | # Correctly extract task_ids from the association table
501 | task_associations = (
502 | db_session.query(eval_task_association.c.task_id)
503 | .filter_by(eval_id=record.id)
504 | .all()
505 | )
506 | task_ids = [task_id for (task_id,) in task_associations]
507 | tasks = [
508 | Task.from_record(db_session.query(TaskRecord).filter_by(id=task_id).first())
509 | for task_id in task_ids
510 | ]
511 |
512 | obj = cls.__new__(cls)
513 | obj._id = record.id
514 | obj._benchmark = benchmark
515 | obj._tasks = tasks
516 | obj._owner_id = record.owner_id
517 | obj._assigned_to = record.assigned_to
518 | obj._assigned_type = record.assigned_type
519 |
520 | return obj
521 |
522 | def to_v1(self) -> V1Eval:
523 | return V1Eval(
524 | id=self._id,
525 | benchmark=self._benchmark.to_v1(),
526 | tasks=[task.to_v1() for task in self._tasks],
527 | assigned_to=self._assigned_to,
528 | assigned_type=self._assigned_type,
529 | owner_id=self._owner_id,
530 | )
531 |
532 | @classmethod
533 | def from_v1(cls, v1: V1Eval, owner_id: Optional[str] = None) -> "Eval":
534 | benchmark = Benchmark.from_v1(v1.benchmark, owner_id=owner_id)
535 | tasks = [
536 | Task.from_v1(task, owner_id=owner_id if owner_id else v1.owner_id)
537 | for task in v1.tasks
538 | ]
539 |
540 | obj = cls.__new__(cls)
541 | obj._id = v1.id if v1.id else shortuuid.uuid()
542 | obj._benchmark = benchmark
543 | obj._tasks = tasks
544 | obj._owner_id = owner_id if owner_id else v1.owner_id
545 | obj._assigned_to = v1.assigned_to
546 | obj._assigned_type = v1.assigned_type
547 |
548 | return obj
549 |
550 | @classmethod
551 | def find(cls, remote: Optional[str] = None, **kwargs) -> List["Eval"]:
552 | for db in cls.get_db():
553 | records = (
554 | db.query(EvalRecord)
555 | .filter_by(**kwargs)
556 | .order_by(EvalRecord.created.desc())
557 | .all()
558 | )
559 | return [cls.from_record(record, db) for record in records]
560 | raise ValueError("No session")
561 |
562 | def save(self) -> None:
563 | for db in self.get_db():
564 | # Save the evaluation record
565 | eval_record = self.to_record()
566 | db.merge(eval_record)
567 | db.commit()
568 |
569 | # Save the task records and associations
570 | for task in self._tasks:
571 | task_record = task.to_record()
572 | db.merge(task_record)
573 | db.commit()
574 |
575 | association = eval_task_association.insert().values(
576 | eval_id=self._id, task_id=task.id
577 | )
578 | db.execute(association)
579 | db.commit()
580 |
581 | def delete(self) -> None:
582 | for db in self.get_db():
583 | # Delete the evaluation record
584 | eval_record = db.query(EvalRecord).filter_by(id=self._id).first()
585 | if eval_record:
586 | db.delete(eval_record)
587 | db.commit()
588 |
589 | # Delete the task records and associations
590 | db.execute(
591 | eval_task_association.delete().where(
592 | eval_task_association.c.eval_id == self._id
593 | )
594 | )
595 | db.commit()
596 |
--------------------------------------------------------------------------------
/taskara/cli/main.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from importlib.metadata import PackageNotFoundError
3 | from importlib.metadata import version as pkgversion
4 |
5 | import typer
6 | from namesgenerator import get_random_name
7 | from tabulate import tabulate
8 | import rich
9 |
10 |
11 | art = """
12 | _______ _
13 | (_______) | |
14 | _ _____ ___| | _ _____ ____ _____
15 | | (____ |/___) |_/ |____ |/ ___|____ |
16 | | / ___ |___ | _ (/ ___ | | / ___ |
17 | |_\_____(___/|_| \_)_____|_| \_____|
18 |
19 | """
20 |
21 | app = typer.Typer()
22 |
23 | # Sub-command groups
24 | create_group = typer.Typer(help="Create resources")
25 | list_group = typer.Typer(help="List resources")
26 | get_group = typer.Typer(help="Get resources")
27 | view_group = typer.Typer(help="View resources")
28 | delete_group = typer.Typer(help="Delete resources")
29 | clean_group = typer.Typer(help="Clean resources")
30 |
31 | app.add_typer(create_group, name="create")
32 | app.add_typer(list_group, name="list")
33 | app.add_typer(get_group, name="get")
34 | app.add_typer(view_group, name="view")
35 | app.add_typer(delete_group, name="delete")
36 | app.add_typer(clean_group, name="clean")
37 |
38 |
39 | # Callback for showing help
40 | def show_help(ctx: typer.Context, command_group: str):
41 | if ctx.invoked_subcommand is None:
42 | if command_group == "root":
43 | typer.echo(art)
44 | typer.echo(ctx.get_help())
45 | raise typer.Exit()
46 |
47 |
48 | try:
49 | __version__ = pkgversion("surfkit")
50 | except PackageNotFoundError:
51 | # Fallback version or error handling
52 | __version__ = "unknown"
53 |
54 |
55 | @app.command(help="Show the version of the CLI")
56 | def version():
57 | """Show the CLI version."""
58 | typer.echo(__version__)
59 |
60 |
61 | # Root command callback
62 | @app.callback(invoke_without_command=True)
63 | def default(ctx: typer.Context):
64 | show_help(ctx, "root")
65 |
66 |
67 | # 'create' command group callback
68 | @create_group.callback(invoke_without_command=True)
69 | def create_default(ctx: typer.Context):
70 | show_help(ctx, "create")
71 |
72 |
73 | # 'list' command group callback
74 | @list_group.callback(invoke_without_command=True)
75 | def list_default(ctx: typer.Context):
76 | show_help(ctx, "list")
77 |
78 |
79 | # 'get' command group callback
80 | @get_group.callback(invoke_without_command=True)
81 | def get_default(ctx: typer.Context):
82 | show_help(ctx, "get")
83 |
84 |
85 | # 'delete' command group callback
86 | @delete_group.callback(invoke_without_command=True)
87 | def delete_default(ctx: typer.Context):
88 | show_help(ctx, "delete")
89 |
90 |
91 | # 'view' command group callback
92 | @view_group.callback(invoke_without_command=True)
93 | def view_default(ctx: typer.Context):
94 | show_help(ctx, "view")
95 |
96 |
97 | # 'clean' command group callback
98 | @clean_group.callback(invoke_without_command=True)
99 | def clean_default(ctx: typer.Context):
100 | show_help(ctx, "clean")
101 |
102 |
103 | # 'create' sub-commands
104 | @create_group.command("task")
105 | def create_task(
106 | description: str = typer.Option(
107 | ...,
108 | "--description",
109 | "-d",
110 | help="Description of the task. Defaults to a generated name.",
111 | ),
112 | remote: bool = typer.Option(True, "--remote", "-r", help="List tasks from remote"),
113 | ):
114 | from taskara import Task
115 |
116 | typer.echo(f"Creating task '{description}'")
117 | try:
118 | task = Task(description=description)
119 | except KeyboardInterrupt:
120 | print("Keyboard interrupt received, exiting...")
121 | return
122 |
--------------------------------------------------------------------------------
/taskara/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration for taskara
3 | """
4 |
5 | import os
6 | from dataclasses import dataclass
7 | from typing import Optional
8 |
9 | import yaml
10 |
11 | from .env import AGENTSEA_AUTH_URL_ENV, AGENTSEA_HUB_API_URL_ENV, AGENTSEA_HUB_URL_ENV
12 |
13 | AGENTSEA_HOME = os.path.expanduser(os.environ.get("AGENTSEA_HOME", "~/.agentsea"))
14 | AGENTSEA_DB_DIR = os.path.expanduser(
15 | os.environ.get("AGENTSEA_DB_DIR", os.path.join(AGENTSEA_HOME, "data"))
16 | )
17 | AGENTSEA_LOG_DIR = os.path.expanduser(
18 | os.environ.get("AGENTSEA_LOG_DIR", os.path.join(AGENTSEA_HOME, "logs"))
19 | )
20 | AGENTSEA_AUTH_URL = os.getenv(AGENTSEA_AUTH_URL_ENV, "https://auth.hub.agentsea.ai")
21 | AGENTSEA_HUB_URL = os.getenv(AGENTSEA_HUB_URL_ENV, "https://hub.agentsea.ai")
22 | AGENTSEA_HUB_API_URL = os.getenv(
23 | AGENTSEA_HUB_API_URL_ENV, "https://api.hub.agentsea.ai"
24 | )
25 | DB_TEST = os.environ.get("AGENTSEA_DB_TEST", "false") == "true"
26 | DB_NAME = os.environ.get("TASKS_DB_NAME", "tasks.db")
27 | if DB_TEST:
28 | DB_NAME = "tasks_test.db"
29 |
30 |
@dataclass
class GlobalConfig:
    """User-level taskara settings persisted to ``~/.agentsea/config.yaml``.

    The dataclass fields are written and read as a flat YAML mapping.
    """

    # API key for the hub, if one has been configured.
    api_key: Optional[str] = None
    # Hub base URL; defaults to AGENTSEA_HUB_URL (itself env-overridable).
    hub_address: str = AGENTSEA_HUB_URL

    @staticmethod
    def _config_path() -> str:
        """Return the config file path, creating its directory if needed."""
        # NOTE: deliberately ~/.agentsea (not AGENTSEA_HOME) so existing
        # installs keep reading the same file.
        config_dir = os.path.join(os.path.expanduser("~"), ".agentsea")
        os.makedirs(config_dir, exist_ok=True)
        return os.path.join(config_dir, "config.yaml")

    def write(self) -> None:
        """Serialize this config to the config file, overwriting it."""
        # The `with` block flushes and closes the file on exit.
        with open(self._config_path(), "w", encoding="utf-8") as yaml_file:
            yaml.dump(self.__dict__, yaml_file)

    @classmethod
    def read(cls) -> "GlobalConfig":
        """Load the config file, returning defaults if it is missing or empty.

        Keys in the file that are not fields of this dataclass are ignored, so
        a file written by another version of taskara still loads.

        Raises:
            ValueError: if the file's top-level YAML value is not a mapping.
        """
        from dataclasses import fields

        path = cls._config_path()
        if not os.path.exists(path):
            return cls()

        with open(path, "r", encoding="utf-8") as yaml_file:
            config = yaml.safe_load(yaml_file)

        # safe_load returns None for an empty file.
        if config is None:
            return cls()
        if not isinstance(config, dict):
            raise ValueError(f"invalid config file {path}: expected a mapping")

        known = {field.name for field in fields(cls)}
        return cls(**{key: value for key, value in config.items() if key in known})
60 |
--------------------------------------------------------------------------------
/taskara/db/conn.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | from sqlalchemy import Engine, create_engine
5 | from sqlalchemy.orm import sessionmaker
6 |
7 | from taskara.config import AGENTSEA_DB_DIR, DB_NAME
8 |
9 | from .models import Base
10 |
# Module logger, named after this module's import path.
logger = logging.getLogger(__name__)

# Backend selector: "postgres" uses get_pg_conn(); anything else uses SQLite.
DB_TYPE = os.getenv("DB_TYPE", "sqlite")
14 |
15 |
def get_pg_conn() -> Engine:
    """Create a SQLAlchemy engine for the Postgres task database.

    Each setting is read from the environment. The ``TASKS_``-prefixed variable
    wins, falling back to the bare name:
    ``TASKS_DB_USER``/``DB_USER``, ``TASKS_DB_PASS``/``DB_PASS``,
    ``TASKS_DB_HOST``/``DB_HOST`` and ``TASKS_DB_NAME``/``DB_NAME``.

    Returns:
        Engine: a ``postgresql+psycopg2`` engine with ``pool_pre_ping`` and a
        300-second ``pool_recycle``.

    Raises:
        ValueError: if a setting is missing under both names.
    """
    from urllib.parse import quote

    def get_env_var(key: str) -> str:
        # Prefer the service-specific TASKS_* variable, then the generic one.
        task_key = f"TASKS_{key}"
        value = os.environ.get(task_key)
        if value is None:
            value = os.environ.get(key)
        if value is None:
            raise ValueError(f"${task_key} or ${key} must be set")
        return value

    db_user = get_env_var("DB_USER")
    db_password = get_env_var("DB_PASS")
    db_host = get_env_var("DB_HOST")
    db_name = get_env_var("DB_NAME")

    logger.debug(f"connecting to db on postgres host '{db_host}' with db '{db_name}'")

    # Percent-encode the credentials: characters such as '@', ':' or '/' in a
    # password would otherwise corrupt the URL. quote(safe="") emits only %XX
    # escapes (no '+'), which SQLAlchemy decodes back. The host is left as-is
    # so a "host:port" value keeps working.
    user_part = f"{quote(db_user, safe='')}:{quote(db_password, safe='')}"
    engine = create_engine(
        f"postgresql+psycopg2://{user_part}@{db_host}/{db_name}",
        client_encoding="utf8",
        pool_pre_ping=True,
        pool_recycle=300,
    )

    return engine
42 |
43 |
def get_sqlite_conn() -> Engine:
    """Create a SQLAlchemy engine for the local SQLite task database.

    The database file is ``os.path.join(AGENTSEA_DB_DIR, DB_NAME)``, and its
    directory is created if missing. When ``TASKARA_DEBUG`` is set, the
    resolved path is printed.
    """
    db_path = os.path.join(AGENTSEA_DB_DIR, DB_NAME)
    debug_enabled = os.getenv("TASKARA_DEBUG")
    if debug_enabled:
        print(f"connecting to local sqlite db {db_path}", flush=True)
    os.makedirs(AGENTSEA_DB_DIR, exist_ok=True)

    sqlite_url = f"sqlite:///{db_path}"
    try:
        return create_engine(sqlite_url)
    except Exception as err:
        logger.error(f"error connecting to sqlite db {db_path}: {err}")
        raise err
55 |
56 |
# Build the process-wide engine once, at import time; DB_TYPE picks the backend.
engine = get_pg_conn() if DB_TYPE == "postgres" else get_sqlite_conn()

# Session factory bound to that engine; get_db() draws sessions from it.
SessionLocal = sessionmaker(bind=engine)

# Create any tables declared on Base that do not exist yet.
# NOTE(review): create_all does not alter existing tables, so schema changes
# to existing tables need a separate migration.
try:
    Base.metadata.create_all(bind=engine)
except Exception as err:
    logger.error(f"error creating tables: {err} for engine {engine.url}")
    raise err
68 |
69 |
class WithDB:
    """Mixin that gives subclasses a session-yielding ``get_db`` helper."""

    @staticmethod
    def get_db():
        """Yield one database session and close it when iteration ends.

        Delegates to the module-level ``get_db`` generator so both entry
        points share a single implementation.

        Example:
        ```
        for session in self.get_db():
            session.add(foo)
        ```
        """
        # Inside a function body, `get_db` resolves to the module-level
        # function, not this method (class scope is not searched).
        yield from get_db()
86 |
87 |
def get_db():
    """Yield a single SQLAlchemy session from ``SessionLocal``.

    The generator yields once; its ``finally`` clause closes the session when
    the generator finishes or is closed/discarded.

    Example:
    ```
    for session in get_db():
        session.add(foo)
    ```
    """
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
102 |
--------------------------------------------------------------------------------
/taskara/db/models.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import shortuuid
4 | from sqlalchemy import (
5 | Boolean,
6 | Column,
7 | Float,
8 | ForeignKey,
9 | Index,
10 | Integer,
11 | String,
12 | Table,
13 | Text,
14 | )
15 | from sqlalchemy.orm import declarative_base, relationship
16 |
# Declarative base shared by every ORM model in this module; taskara.db.conn
# calls Base.metadata.create_all() on it.
Base = declarative_base()

# Many-to-many link: benchmarks <-> the task templates they contain.
benchmark_task_association = Table(
    "benchmark_task_association",
    Base.metadata,
    Column("benchmark_id", String, ForeignKey("benchmarks.id"), primary_key=True),
    Column(
        "task_template_id", String, ForeignKey("task_templates.id"), primary_key=True
    ),
)

# Many-to-many link: evals <-> the task rows belonging to them.
eval_task_association = Table(
    "eval_task_association",
    Base.metadata,
    Column("eval_id", String, ForeignKey("evals.id"), primary_key=True),
    Column("task_id", String, ForeignKey("tasks.id"), primary_key=True),
)

# Association table for many-to-many between tasks and tags
task_tag_association = Table(
    "task_tag_association",
    Base.metadata,
    Column("task_id", String, ForeignKey("tasks.id"), primary_key=True),
    Column("tag_id", String, ForeignKey("tags.id"), primary_key=True),
)

# Association table for many-to-many between tasks and labels
task_label_association = Table(
    "task_label_association",
    Base.metadata,
    Column("task_id", String, ForeignKey("tasks.id"), primary_key=True),
    Column("label_id", String, ForeignKey("labels.id"), primary_key=True),
)
50 |
51 |
class TagRecord(Base):
    """A unique tag string; linked to tasks through task_tag_association."""

    __tablename__ = "tags"
    # NOTE(review): `tag` is already unique=True; this extra index may be
    # redundant on backends that index unique constraints — confirm.
    __table_args__ = (Index("idx_tags_tag", "tag"),)
    # NOTE(review): presumably wrapped in a lambda so SQLAlchemy does not pass
    # its execution context into shortuuid.uuid's optional argument — confirm
    # before simplifying to `default=shortuuid.uuid`.
    id = Column(String, primary_key=True, default=lambda: shortuuid.uuid())
    tag = Column(String, unique=True, nullable=False)
57 |
58 |
class LabelRecord(Base):
    """A key/value label; linked to tasks through task_label_association.

    Note that (key, value) is indexed but not declared unique.
    """

    __tablename__ = "labels"
    __table_args__ = (Index("idx_labels_key_value", "key", "value"),)
    # Lambda wrapper: see the note on TagRecord.id (same pattern).
    id = Column(String, primary_key=True, default=lambda: shortuuid.uuid())
    key = Column(String, nullable=False)
    value = Column(String, nullable=False)
65 |
66 |
class TaskRecord(Base):
    """ORM row for a task.

    Tags and labels are many-to-many via task_tag_association and
    task_label_association. The indexes support lookups by owner, status and
    skill/assignee combinations.
    """

    __tablename__ = "tasks"
    __table_args__ = (
        Index("idx_tasks_owner_id", "owner_id"),
        Index("idx_tasks_status", "status"),
        Index("idx_tasks_owner_id_skill", "owner_id", "skill"),
        Index("idx_tasks_skill_assigned_type", "skill", "assigned_type"),
        Index("idx_tasks_skill_assigned_type_completed", "skill", "assigned_type", "completed"),
    )
    # No column default: callers must supply the id.
    id = Column(String, primary_key=True)
    owner_id = Column(String, nullable=True)
    description = Column(String, nullable=False)
    max_steps = Column(Integer, nullable=False, default=30)
    device = Column(String, nullable=True)
    device_type = Column(String, nullable=True)
    project = Column(String, nullable=True)
    expect = Column(String, nullable=True)
    assigned_to = Column(String, nullable=True)
    assigned_type = Column(String, nullable=True)
    # NOTE(review): reviews / review_requirements / threads / prompts /
    # parameters look like serialized payloads stored as text — confirm the
    # format against the Task model.
    reviews = Column(Text, nullable=True)
    review_requirements = Column(Text, nullable=True)
    status = Column(String, nullable=False)
    # Float timestamps; started/completed default to 0.0, presumably meaning
    # "not yet" — confirm.
    created = Column(Float, nullable=False)
    started = Column(Float, nullable=False, default=0.0)
    completed = Column(Float, nullable=False, default=0.0)
    created_by = Column(String, nullable=False)
    error = Column(String, nullable=True)
    output = Column(String, nullable=True)
    threads = Column(String, nullable=False)
    prompts = Column(String, nullable=True)
    parent_id = Column(String, nullable=True)
    parameters = Column(String, nullable=True)
    version = Column(String, nullable=True)
    public = Column(Boolean, nullable=False, default=False)
    skill = Column(String, nullable=True)
    episode_id = Column(String, nullable=True)

    # backref adds a `tasks` attribute on TagRecord / LabelRecord.
    tags = relationship("TagRecord", secondary=task_tag_association, backref="tasks")
    labels = relationship(
        "LabelRecord", secondary=task_label_association, backref="tasks"
    )
108 |
109 |
class ReviewRequirementRecord(Base):
    """ORM row for a task's review requirement (see taskara.review)."""

    __tablename__ = "review_requirements"
    __table_args__ = (Index("idx_review_req_task_id", "task_id"),)
    id = Column(String, primary_key=True)
    # Plain column (no ForeignKey to tasks.id).
    task_id = Column(String, nullable=True)
    number_required = Column(Integer, nullable=False)
    # users/agents/groups/types are lists serialized into single text columns.
    users = Column(
        Text, nullable=True
    )  # TODO: split these apart for better lookups
    agents = Column(Text, nullable=True)
    groups = Column(Text, nullable=True)
    types = Column(Text, nullable=True)
    # time.time is called at INSERT; `updated` has no default or onupdate.
    created = Column(Float, default=time.time)
    updated = Column(Float, nullable=True)
124 |
125 |
class TaskTemplateRecord(Base):
    """ORM row for a task template that benchmarks are built from."""

    __tablename__ = "task_templates"

    id = Column(String, primary_key=True)
    owner_id = Column(String, nullable=True)
    description = Column(String, nullable=False)
    max_steps = Column(Integer, nullable=False, default=30)
    device = Column(String, nullable=True)
    device_type = Column(String, nullable=True)
    expect = Column(String, nullable=True)
    # NOTE(review): unlike TaskRecord, tags/labels here are single string
    # columns (presumably serialized), not relationships — confirm format.
    parameters = Column(String, nullable=True)
    tags = Column(String, nullable=True)
    labels = Column(String, nullable=True)
    created = Column(Float, default=time.time)

    # Many-to-many with BenchmarkRecord via benchmark_task_association.
    benchmarks = relationship(
        "BenchmarkRecord",
        secondary=benchmark_task_association,
        back_populates="task_templates",
    )
146 |
147 |
class BenchmarkRecord(Base):
    """ORM row for a benchmark: a named, uniquely-identified set of templates."""

    __tablename__ = "benchmarks"

    id = Column(String, primary_key=True)
    owner_id = Column(String, nullable=True)
    # Benchmark names are globally unique (and indexed).
    name = Column(String, unique=True, index=True)
    description = Column(String, nullable=False)
    public = Column(Boolean, default=False)
    tags = Column(String, nullable=True)
    labels = Column(String, nullable=True)
    created = Column(Float, default=time.time)

    # Many-to-many with TaskTemplateRecord via benchmark_task_association.
    task_templates = relationship(
        "TaskTemplateRecord",
        secondary=benchmark_task_association,
        back_populates="benchmarks",
    )
165 |
166 |
class EvalRecord(Base):
    """ORM row for an eval: a run of a benchmark, holding concrete tasks."""

    __tablename__ = "evals"

    id = Column(String, primary_key=True)
    owner_id = Column(String, nullable=True)
    benchmark_id = Column(String, ForeignKey("benchmarks.id"))
    assigned_to = Column(String, nullable=True)
    assigned_type = Column(String, nullable=True)
    created = Column(Float, default=time.time)

    benchmark = relationship("BenchmarkRecord")
    # Many-to-many with TaskRecord via eval_task_association (no back-reference).
    tasks = relationship(
        "TaskRecord",
        secondary=eval_task_association,
    )
182 |
183 |
class TrackerRecord(Base):
    """ORM row describing a tracker deployment and its runtime settings."""

    __tablename__ = "trackers"

    id = Column(String, primary_key=True)
    name = Column(String, unique=True, index=True)
    runtime_name = Column(String)
    # NOTE(review): presumably a serialized runtime config — confirm format.
    runtime_config = Column(String)
    status = Column(String)
    port = Column(Integer)
    owner_id = Column(String, nullable=True)
    labels = Column(String, nullable=True)
    created = Column(Float, default=time.time)
    # Set on INSERT only: no onupdate is configured, so writers must bump it.
    updated = Column(Float, default=time.time)
197 |
198 |
class FlagRecord(Base):
    """ORM row for a human-review flag (written by taskara.flag.Flag.to_record).

    `type` holds the Flag subclass name, and `flag`/`result` hold the flag and
    result pydantic models serialized as JSON.
    """

    __tablename__ = "flags"

    id = Column(String, primary_key=True)
    type = Column(String)
    flag = Column(Text)
    # NULL until a reviewer result has been recorded.
    result = Column(Text, nullable=True)
    created = Column(Float, default=time.time)
207 |
208 |
class PendingReviewersRecord(Base):
    """ORM row marking a user and/or agent as still owing a review on a task.

    Managed by taskara.review.PendingReviewers; requirement_id optionally
    scopes the row to one review requirement.
    """

    __tablename__ = "pending_reviewers"
    __table_args__ = (Index("idx_pending_reviewers_task_id", "task_id"),)
    id = Column(String, primary_key=True)
    task_id = Column(String, nullable=True)
    user_id = Column(String, nullable=True)
    agent_id = Column(String, nullable=True)
    requirement_id = Column(String, nullable=True)
217 |
--------------------------------------------------------------------------------
/taskara/db/redis_connection.py:
--------------------------------------------------------------------------------
1 | import redis.asyncio as redis
2 | import os
3 |
# Module-level Redis connection state, populated by init_redis_pool().
# TODO: mock redis streams/async functionality (or check whether fakeredis
# supports them) so this module can be exercised in tests.
redis_pool: redis.ConnectionPool | None = None
redis_client: redis.Redis | None = None
# Connection URL; when unset, no pool is created and get_redis_client() is None.
redis_url = os.getenv("REDIS_CACHE_STORAGE")

# Stream / queue key names shared by producers and consumers.
stream_action_recorded = "events:action_recorded"
annotation_failures_dlq = "dlq:annotation_failures"
stream_training_completed = "events:training_completed"
14 |
async def init_redis_pool():
    """Initialize the module-level Redis connection pool from ``redis_url``.

    Idempotent: if a pool already exists it is kept, so repeated calls do not
    orphan (and leak) the connections of an earlier pool. Without a configured
    URL this only prints a notice and leaves ``redis_pool`` as ``None``.
    """
    global redis_pool

    if redis_pool is not None:
        # Already initialized; replacing the pool would leak the old one.
        return

    if redis_url:
        redis_pool = redis.ConnectionPool.from_url(redis_url, max_connections=10)
        print("Redis connection pool initialized.", flush=True)
    else:
        print("No Redis URL", flush=True)
25 |
def get_redis_client():
    """Return a new Redis client backed by the shared pool, or ``None``.

    ``None`` is returned while no pool exists (init_redis_pool() not called,
    or no Redis URL configured).
    """
    # TODO: once Redis can be mocked properly, raise instead of returning None:
    # "Redis connection pool is not initialized or redis connection details
    # don't exist. Call init_redis_pool() first."
    pool = redis_pool
    return redis.Redis(connection_pool=pool) if pool else None
34 |
async def close_redis_pool():
    """Disconnect the shared Redis pool, if any, and reset the module state."""
    global redis_pool, redis_client
    if not redis_pool:
        return
    await redis_pool.disconnect()
    redis_pool = None
    redis_client = None
--------------------------------------------------------------------------------
/taskara/env.py:
--------------------------------------------------------------------------------
# Names of environment variables consulted by taskara. Only the variable
# *names* live here; values are read where they are used.

# Hub / auth endpoints and credentials
HUB_SERVER_ENV = "SURF_HUB_SERVER"
HUB_API_KEY_ENV = "HUB_API_KEY"
AGENTSEA_HUB_URL_ENV = "AGENTSEA_HUB_URL"
AGENTSEA_HUB_API_URL_ENV = "AGENTSEA_HUB_API_URL"
AGENTSEA_AUTH_URL_ENV = "AGENTSEA_AUTH_URL"

# agentd settings
AGENTD_ADDR_ENV = "AGENTD_ADDR"
AGENTD_PRIVATE_SSH_KEY_ENV = "AGENTD_PRIVATE_SSH_KEY"

# Image storage (GCS bucket and service-account credentials)
STORAGE_BUCKET_ENV = "SKILL_STORAGE_BUCKET"
STORAGE_SA_JSON_ENV = "SKILL_STORAGE_SA_JSON"
10 |
--------------------------------------------------------------------------------
/taskara/flag.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar, Optional, Generic, Type, Dict, List
2 | from abc import ABC, abstractmethod
3 | import time
4 | import json
5 |
6 | import shortuuid
7 | from pydantic import BaseModel
8 |
9 | from taskara.server.models import V1BoundingBox, V1BoundingBoxFlag, V1Flag
10 | from taskara.db.models import FlagRecord
11 | from taskara.db.conn import WithDB
12 |
FlagResult = TypeVar("FlagResult", bound="BaseModel")
FlagModel = TypeVar("FlagModel", bound="BaseModel")
FlagType = TypeVar("FlagType", bound="Flag")


class Flag(Generic[FlagResult, FlagModel, FlagType], ABC, WithDB):
    """A flag for human review.

    Type parameters:
        FlagResult: pydantic model holding a reviewer's answer.
        FlagModel: pydantic (V1) model describing the flag itself.
        FlagType: the concrete subclass returned by ``from_v1``/``from_record``.

    Flags persist as ``FlagRecord`` rows. ``type`` is the subclass name, and
    ``flag``/``result`` hold the models serialized as JSON.
    """

    def __init__(self) -> None:
        self.id = shortuuid.uuid()
        # Unset until a reviewer answers via set_result().
        self.result: Optional[FlagResult] = None
        self.created = time.time()

    def set_result(self, result: FlagResult):
        """Attach the reviewer's answer; call ``save()`` to persist it."""
        self.result = result

    @classmethod
    @abstractmethod
    def result_type(cls) -> Type[FlagResult]:
        """Return the pydantic class used to (de)serialize results."""

    @classmethod
    @abstractmethod
    def v1_type(cls) -> Type[FlagModel]:
        """Return the pydantic class of this flag's V1 representation."""

    @abstractmethod
    def to_v1(self) -> FlagModel:
        """Return this flag's V1 model (flag content only; no id/result)."""

    @classmethod
    @abstractmethod
    def from_v1(cls, v1: FlagModel) -> FlagType:
        """Build an instance from its V1 model."""

    def to_v1flag(self) -> V1Flag:
        """Return the type-erased API envelope for this flag."""
        return V1Flag(
            type=self.__class__.__name__,
            id=self.id,
            flag=self.to_v1().model_dump(),
            result=self.result.model_dump() if self.result else None,
            created=self.created,
        )

    def to_record(self) -> FlagRecord:
        """Return the ORM row for this flag, with models serialized to JSON."""
        return FlagRecord(
            id=self.id,
            type=self.__class__.__name__,
            flag=self.to_v1().model_dump_json(),
            result=self.result.model_dump_json() if self.result else None,
            created=self.created,
        )

    @classmethod
    def from_record(cls, record: FlagRecord) -> FlagType:
        """Rebuild a flag from its ORM row, restoring id, result and created."""
        # Deserialize the flag JSON into the V1 model, then build the subclass.
        flag_model = cls.v1_type().model_validate_json(str(record.flag))
        instance = cls.from_v1(flag_model)

        # Overwrite the per-instance fields with the stored values.
        instance.id = record.id
        instance.result = (  # type: ignore
            cls.result_type().model_validate_json(str(record.result))
            if record.result  # type: ignore
            else None
        )
        instance.created = record.created

        return instance

    def save(self) -> None:
        """Insert this flag's row, or update it if it already exists.

        ``merge`` (rather than ``add``) lets a flag be saved again after
        ``set_result`` without failing on the duplicate primary key.
        """
        for db in self.get_db():
            db.merge(self.to_record())
            db.commit()

    @classmethod
    def find(cls, **kwargs) -> List[FlagType]:
        """Return flags of this subclass matching ``kwargs``, newest first.

        ``kwargs`` are passed to ``filter_by`` as ``FlagRecord`` column filters.
        """
        for db in cls.get_db():
            records = (
                db.query(FlagRecord)
                .filter_by(type=cls.__name__, **kwargs)
                .order_by(FlagRecord.created.desc())
                .all()
            )
            return [cls.from_record(record) for record in records]
        return []

    @classmethod
    def find_v1(cls, **kwargs) -> List[V1Flag]:
        """Return matching flags of any subclass as V1 envelopes, newest first.

        Unlike ``find``, this does not restrict ``type``; callers may pass
        ``type=...`` themselves.
        """
        for db in cls.get_db():
            records = (
                db.query(FlagRecord)
                .filter_by(**kwargs)
                .order_by(FlagRecord.created.desc())
                .all()
            )
            return [
                V1Flag(
                    type=str(record.type),
                    id=str(record.id),
                    flag=json.loads(str(record.flag)),
                    result=json.loads(str(record.result)) if record.result else None,  # type: ignore
                    created=record.created,  # type: ignore
                )
                for record in records
            ]
        return []
123 |
124 |
class BoundingBoxFlag(Flag[V1BoundingBox, V1BoundingBoxFlag, "BoundingBoxFlag"]):
    """Flag asking a reviewer to check a bounding box drawn on an image.

    Results are ``V1BoundingBox`` values (presumably the reviewer's corrected
    box — confirm).
    """

    def __init__(
        self,
        img: str,
        target: str,
        bbox: V1BoundingBox,
        metadata: Optional[Dict[str, str]] = None,
    ):
        """
        Args:
            img: the image the box refers to.
            target: what the box is supposed to enclose.
            bbox: the proposed bounding box.
            metadata: optional string metadata. NOTE: ``to_v1`` does not
                include it, so it is not persisted by ``to_record``/``save``.
        """
        super().__init__()
        self.img = img
        self.target = target
        self.bbox = bbox
        self.metadata = metadata

    @classmethod
    def result_type(cls) -> Type[V1BoundingBox]:
        """Results are bounding boxes."""
        return V1BoundingBox

    @classmethod
    def v1_type(cls) -> Type[V1BoundingBoxFlag]:
        """The V1 wire model for this flag."""
        return V1BoundingBoxFlag

    def to_v1(self) -> V1BoundingBoxFlag:
        """Return the V1 model (img, target, bbox)."""
        return V1BoundingBoxFlag(
            img=self.img,
            target=self.target,
            bbox=self.bbox,
        )

    @classmethod
    def from_v1(cls, v1: V1BoundingBoxFlag) -> "BoundingBoxFlag":
        """Build a flag from its V1 model.

        Goes through ``__init__`` so every attribute (``id``, ``result``,
        ``created``, ``metadata``) exists. Previously these were left unset and
        raised AttributeError. ``from_record`` overwrites id/result/created
        with the stored values afterwards.
        """
        return cls(img=v1.img, target=v1.target, bbox=v1.bbox)
163 |
--------------------------------------------------------------------------------
/taskara/img.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import base64
3 | import re
4 | from io import BytesIO
5 | import mimetypes
6 | import os
7 | import secrets
8 | import string
9 | import tempfile
10 | from typing import List, Sequence
11 |
12 | from google.cloud import storage
13 | from PIL import Image
14 |
15 | from .env import STORAGE_BUCKET_ENV, STORAGE_SA_JSON_ENV
16 |
17 |
def image_to_b64(img: Image.Image, image_format="PNG") -> str:
    """Encode a PIL image as a base64 ``data:`` URI.

    Args:
        img (Image.Image): the image to encode.
        image_format (str): PIL save format such as ``"PNG"`` or ``"JPEG"``;
            its lowercase form becomes the MIME subtype.

    Returns:
        str: ``data:image/{format};base64,{payload}``.
    """
    with BytesIO() as buf:
        img.save(buf, format=image_format)
        raw_bytes = buf.getvalue()

    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return "data:image/" + image_format.lower() + ";base64," + payload
38 | def b64_to_image(base64_str: str) -> Image.Image:
39 | """Converts a base64 string to a PIL Image object.
40 |
41 | Args:
42 | base64_str (str): The base64 string, potentially with MIME type as part of a data URI.
43 |
44 | Returns:
45 | Image.Image: The converted PIL Image object.
46 | """
47 | # Strip the MIME type prefix if present
48 | if "," in base64_str:
49 | base64_str = base64_str.split(",")[1]
50 |
51 | image_data = base64.b64decode(base64_str)
52 | image = Image.open(BytesIO(image_data))
53 | return image
54 |
55 |
56 | def parse_image_data(image_data_str: str):
57 | """Parses the image data URL to extract the MIME type and base64 data."""
58 | data_url_pattern = re.compile(
59 | r"data:(?P[^;]+);base64,(?P.+)"
60 | )
61 | match = data_url_pattern.match(image_data_str)
62 | if not match:
63 | raise ValueError("Invalid image data format")
64 | mime_type = match.group("mime_type")
65 | base64_data = match.group("base64_data")
66 | return mime_type, base64_data
67 |
68 |
69 | def generate_random_suffix(length: int = 24) -> str:
70 | """Generates a random suffix for the image file name."""
71 | return "".join(
72 | secrets.choice(string.ascii_letters + string.digits) for _ in range(length)
73 | )
74 |
75 |
76 | def upload_image_to_gcs(image_data: bytes, mime_type: str) -> str:
77 | """Uploads an image to Google Cloud Storage and returns the public URL."""
78 | sa_json = os.getenv(STORAGE_SA_JSON_ENV)
79 | if not sa_json:
80 | raise ValueError(f"Environment variable {STORAGE_SA_JSON_ENV} not set")
81 |
82 | # Check if the service account JSON is a path or a JSON string
83 | if sa_json.startswith("{"):
84 | # Assume it's a JSON string, write to a temporary file
85 | with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as temp_file:
86 | temp_file.write(sa_json.encode())
87 | temp_file_name = temp_file.name
88 | else:
89 | # Assume it's a path to a JSON file
90 | temp_file_name = sa_json
91 |
92 | storage_client = storage.Client.from_service_account_json(temp_file_name)
93 |
94 | bucket_name = os.getenv(STORAGE_BUCKET_ENV)
95 | if not bucket_name:
96 | raise ValueError(f"Environment variable {STORAGE_BUCKET_ENV} not set")
97 |
98 | bucket = storage_client.bucket(bucket_name)
99 |
100 | random_suffix = generate_random_suffix()
101 | extension = mimetypes.guess_extension(mime_type)
102 | blob_name = f"images/{random_suffix}{extension}"
103 | blob = bucket.blob(blob_name)
104 |
105 | # Create a temporary file to write the image data
106 | with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as temp_file:
107 | temp_file.write(image_data)
108 | temp_file_name = temp_file.name
109 |
110 | # Upload the temporary file to Google Cloud Storage
111 | blob.upload_from_filename(temp_file_name)
112 | blob.content_type = mime_type
113 | blob.make_public()
114 |
115 | # Delete the temporary file
116 | os.remove(temp_file_name)
117 |
118 | return blob.public_url
119 |
120 |
121 | def convert_images(images: Sequence[str | Image.Image]) -> List[str]:
122 | sa = os.getenv(STORAGE_SA_JSON_ENV)
123 | new_imgs: List[str] = []
124 | if sa:
125 | for img in images:
126 | if isinstance(img, Image.Image):
127 | new_imgs.append(image_to_b64(img))
128 | elif isinstance(img, str):
129 | if img.startswith("data:"):
130 | mime_type, base64_data = parse_image_data(img)
131 | image_data = base64.b64decode(base64_data)
132 | public_url = upload_image_to_gcs(image_data, mime_type)
133 | new_imgs.append(public_url)
134 | elif img.startswith("https://"):
135 | new_imgs.append(img)
136 | else:
137 | loaded_img = Image.open(img)
138 | b64_img = image_to_b64(loaded_img)
139 | mime_type, base64_data = parse_image_data(b64_img)
140 | image_data = base64.b64decode(base64_data)
141 | public_url = upload_image_to_gcs(image_data, mime_type)
142 | new_imgs.append(public_url)
143 | else:
144 | raise ValueError("unnknown image type")
145 | else:
146 | for img in images:
147 | if isinstance(img, Image.Image):
148 | new_imgs.append(image_to_b64(img))
149 | elif img.startswith("data:") or img.startswith("https://"):
150 | new_imgs.append(img)
151 | else:
152 | loaded_img = Image.open(img)
153 | b64_img = image_to_b64(loaded_img)
154 | new_imgs.append(b64_img)
155 |
156 | return new_imgs
157 |
158 |
async def process_image_async(img: str | Image.Image) -> str:
    """Convert one image in upload mode (service account set), asynchronously.

    PIL images are returned as base64 data URIs, ``https://`` URLs pass
    through, and ``data:`` URIs and local file paths are uploaded to GCS,
    returning the public URL. File loading/encoding and uploads run in
    worker threads via ``asyncio.to_thread``.

    Raises:
        ValueError: if ``img`` is neither ``str`` nor ``PIL.Image.Image``.
    """
    if isinstance(img, Image.Image):
        return image_to_b64(img)
    if not isinstance(img, str):
        raise ValueError("unknown image type")

    if img.startswith("https://"):
        return img

    if img.startswith("data:"):
        print("convert_images function, uploading img to gcs", flush=True)
        mime_type, base64_data = parse_image_data(img)
    else:
        print("convert_images function, uploading img to gcs in else", flush=True)
        # Image.open is lazy, so opening alone in a thread left the real
        # decode/encode on the event loop. Do both in the thread, and close
        # the file afterwards.
        b64_img = await asyncio.to_thread(_load_file_as_b64, img)
        mime_type, base64_data = parse_image_data(b64_img)

    image_data = base64.b64decode(base64_data)
    # Offload the blocking upload to a thread.
    return await asyncio.to_thread(upload_image_to_gcs, image_data, mime_type)


def _load_file_as_b64(path: str) -> str:
    """Open an image file, encode it as a PNG data URI, and close the file."""
    with Image.open(path) as loaded_img:
        return image_to_b64(loaded_img)
182 |
183 |
async def convert_images_async(images: Sequence[str | Image.Image]) -> List[str]:
    """Async counterpart of ``convert_images``.

    With ``$SKILL_STORAGE_SA_JSON`` set, all items are processed concurrently
    via ``process_image_async``; ``gather`` returns results in input order.
    Otherwise items are converted inline: PIL images and local files become
    base64 data URIs, and ``data:``/``https://`` strings pass through.

    Raises:
        ValueError: for items that are neither ``str`` nor ``PIL.Image.Image``.
    """
    if os.getenv(STORAGE_SA_JSON_ENV):
        results = await asyncio.gather(*(process_image_async(img) for img in images))
        return list(results)

    # No service account: convert synchronously.
    new_imgs: List[str] = []
    for img in images:
        print("convert_images function, proceeding with non gcs key path", flush=True)
        if isinstance(img, Image.Image):
            new_imgs.append(image_to_b64(img))
        elif not isinstance(img, str):
            raise ValueError("unknown image type")
        elif img.startswith(("data:", "https://")):
            new_imgs.append(img)
        else:
            # Close the file handle once the image has been encoded.
            with Image.open(img) as loaded_img:
                new_imgs.append(image_to_b64(loaded_img))
    return new_imgs
--------------------------------------------------------------------------------
/taskara/metrics.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 |
class MetricsAggregator:
    """Collect named elapsed-time measurements and counters in memory.

    Timers nest per key: each ``start_timer(key)`` must be matched by a
    ``stop_timer(key)``, which closes the most recently started timer for
    that key. Only completed timings are reported.
    """

    def __init__(self):
        # key -> list of completed durations, in seconds
        self.timings = {}
        # key -> accumulated count
        self.counts = {}
        # key -> stack of perf_counter() start values for timers still running
        self._running = {}

    def start_timer(self, key):
        """Start a (possibly nested) timer for ``key``."""
        # perf_counter is monotonic, so durations are immune to clock changes.
        self._running.setdefault(key, []).append(time.perf_counter())

    def stop_timer(self, key):
        """Stop the most recently started timer for ``key`` and record it.

        Raises:
            KeyError: if no timer for ``key`` is currently running.
        """
        starts = self._running.get(key)
        if not starts:
            raise KeyError(f"timer {key!r} was not started")
        elapsed = time.perf_counter() - starts.pop()
        self.timings.setdefault(key, []).append(elapsed)

    def increment_count(self, key, count=1):
        """Add ``count`` (default 1) to the counter for ``key``."""
        self.counts[key] = self.counts.get(key, 0) + count

    def get_timing_stats(self, key):
        """Return count/total/avg/min/max of ``key``'s timings, or ``None``."""
        times = self.timings.get(key)
        if not times:
            return None
        total = sum(times)
        return {
            "count": len(times),
            "total": total,
            "avg": total / len(times),
            "min": min(times),
            "max": max(times),
        }

    def get_count(self, key):
        """Return the counter for ``key`` (0 if never incremented)."""
        return self.counts.get(key, 0)

    def report(self):
        """Return timing stats per key plus counters under ``{key}_count``."""
        report = {key: self.get_timing_stats(key) for key in self.timings}
        report.update({f"{key}_count": count for key, count in self.counts.items()})
        return report
43 |
44 |
45 | # metrics = MetricsAggregator()
46 |
47 | # # Example of timing a code section
48 | # metrics.start_timer('process_data')
49 | # # Simulate data processing
50 | # time.sleep(0.5)
51 | # metrics.stop_timer('process_data')
52 |
53 | # # Counting occurrences
54 | # metrics.increment_count('api_calls')
55 | # metrics.increment_count('api_calls')
56 | # metrics.increment_count('api_errors')
57 |
58 | # # Getting aggregated results
59 | # print(metrics.report())
60 |
--------------------------------------------------------------------------------
/taskara/review.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, List
2 | import shortuuid
3 | import time
4 | import json
5 |
6 | from skillpacks import Review
7 |
8 | from taskara.db.conn import WithDB
9 | from taskara.db.models import ReviewRequirementRecord, PendingReviewersRecord
10 | from taskara.server.models import (
11 | V1ReviewRequirement,
12 | V1PendingReviewers,
13 | V1PendingReviews,
14 | )
15 |
16 |
class PendingReviewers(WithDB):
    """Track which users/agents still owe a review on which tasks.

    Each ``PendingReviewersRecord`` row links a task to a pending user and/or
    agent, optionally scoped to one review requirement.
    """

    def pending_reviewers(
        self, task_id: str, requirement_id: Optional[str] = None
    ) -> V1PendingReviewers:
        """Return the distinct users and agents still pending on ``task_id``.

        Args:
            task_id: task to inspect.
            requirement_id: if given, only rows for this requirement count.

        Raises:
            SystemError: if no database session was produced.
        """
        for db in self.get_db():
            query = db.query(PendingReviewersRecord).filter_by(task_id=task_id)
            if requirement_id:
                query = query.filter_by(requirement_id=requirement_id)
            records = query.all()

            # dict.fromkeys de-duplicates while keeping query order.
            users = list(
                dict.fromkeys(str(r.user_id) for r in records if r.user_id)  # type: ignore
            )
            agents = list(
                dict.fromkeys(str(r.agent_id) for r in records if r.agent_id)  # type: ignore
            )
            return V1PendingReviewers(task_id=task_id, users=users, agents=agents)

        raise SystemError("no session")

    def pending_reviews(
        self, user: Optional[str] = None, agent: Optional[str] = None
    ) -> V1PendingReviews:
        """Return the distinct tasks awaiting review by ``user`` and/or ``agent``.

        With neither filter, every task with any pending reviewer is returned.

        Raises:
            SystemError: if no database session was produced.
        """
        for db in self.get_db():
            query = db.query(PendingReviewersRecord)
            if user:
                query = query.filter(PendingReviewersRecord.user_id == user)
            if agent:
                query = query.filter(PendingReviewersRecord.agent_id == agent)

            tasks = list(dict.fromkeys(str(r.task_id) for r in query.all()))
            return V1PendingReviews(tasks=tasks)

        raise SystemError("no session")

    def ensure_pending_reviewer(
        self,
        task_id: str,
        user: Optional[str] = None,
        agent: Optional[str] = None,
        requirement_id: Optional[str] = None,
    ) -> None:
        """Add a pending reviewer for a task unless a matching row already exists.

        Raises:
            ValueError: if neither ``user`` nor ``agent`` is given.
        """
        if not user and not agent:
            raise ValueError("Either user or agent must be provided")

        for db in self.get_db():
            # Only the filters that were supplied are applied to the lookup.
            query = db.query(PendingReviewersRecord).filter_by(task_id=task_id)
            if user:
                query = query.filter_by(user_id=user)
            if agent:
                query = query.filter_by(agent_id=agent)
            if requirement_id:
                query = query.filter_by(requirement_id=requirement_id)

            if query.first():
                # Already pending; avoid a duplicate row.
                return

            db.add(
                PendingReviewersRecord(
                    id=shortuuid.uuid(),
                    task_id=task_id,
                    user_id=user,
                    agent_id=agent,
                    requirement_id=requirement_id,
                )
            )
            db.commit()

    def remove_pending_reviewer(
        self,
        task_id: str,
        user: Optional[str] = None,
        agent: Optional[str] = None,
        requirement_id: Optional[str] = None,
    ) -> None:
        """Delete the first pending-reviewer row matching the given filters.

        Raises:
            ValueError: if neither ``user`` nor ``agent`` is given.
        """
        if not user and not agent:
            raise ValueError("Either user or agent must be provided")

        for db in self.get_db():
            query = db.query(PendingReviewersRecord).filter_by(task_id=task_id)
            if user:
                query = query.filter_by(user_id=user)
            if agent:
                query = query.filter_by(agent_id=agent)
            if requirement_id:
                query = query.filter_by(requirement_id=requirement_id)

            record = query.first()
            if record:
                db.delete(record)
                db.commit()

    def task_is_pending(self, task_id: str) -> bool:
        """Return True if ``task_id`` has at least one pending reviewer.

        Raises:
            SystemError: if no database session was produced.
        """
        for db in self.get_db():
            record = db.query(PendingReviewersRecord).filter_by(task_id=task_id).first()
            return bool(record)
        raise SystemError("no session")

    def pending_tasks(
        self, user: Optional[str] = None, agent: Optional[str] = None
    ) -> List[str]:
        """Return distinct pending task ids, optionally for a user and/or agent.

        ``agent`` is a new optional filter; existing ``pending_tasks(user)``
        calls behave as before, except that duplicate task ids (one per
        requirement row) are now collapsed, matching ``pending_reviews``.

        Raises:
            SystemError: if no database session was produced.
        """
        for db in self.get_db():
            query = db.query(PendingReviewersRecord)
            if user:
                query = query.filter_by(user_id=user)
            if agent:
                query = query.filter_by(agent_id=agent)
            return list(dict.fromkeys(str(r.task_id) for r in query.all()))

        raise SystemError("no session")
156 |
157 |
class ReviewRequirement(WithDB):
    """A review requirement for a task"""

    def __init__(
        self,
        task_id: Optional[str] = None,
        number_required: int = 2,
        users: Optional[List[str]] = None,
        agents: Optional[List[str]] = None,
        groups: Optional[List[str]] = None,
        types: Optional[List[str]] = None,
        created: Optional[float] = None,
        updated: Optional[float] = None,
    ) -> None:
        """Create an in-memory requirement with a fresh id (not persisted).

        Args:
            task_id: Task the requirement applies to.
            number_required: Number of reviews required.
            users: Reviewer user ids; defaults to an empty list.
            agents: Reviewer agent ids; defaults to an empty list.
            groups: Reviewer groups; defaults to an empty list.
            types: Requirement types; defaults to an empty list.
            created: Creation timestamp; defaults to ``time.time()``.
            updated: Last-update timestamp, if any.
        """
        self.id = str(shortuuid.uuid())
        self.task_id = task_id
        self.number_required = number_required
        self.users = users or []
        self.agents = agents or []
        self.groups = groups or []
        self.types = types or []
        self.created = created or time.time()
        self.updated = updated

    def to_v1(self) -> V1ReviewRequirement:
        """Convert to the V1 API model."""
        return V1ReviewRequirement(
            id=self.id,
            task_id=self.task_id,
            users=self.users,
            agents=self.agents,
            groups=self.groups,
            number_required=self.number_required,
        )

    @classmethod
    def from_v1(cls, v1: V1ReviewRequirement) -> "ReviewRequirement":
        """Build an instance from its V1 API model, keeping the V1 id.

        Earlier versions left ``types``/``created``/``updated`` unset, so a
        later ``save()`` raised AttributeError. The constructor fills them with
        its usual defaults. ``types`` is taken from the V1 model only if that
        model carries it.
        """
        out = cls(
            task_id=v1.task_id,
            number_required=v1.number_required,
            users=v1.users,
            agents=v1.agents,
            groups=v1.groups,
            types=getattr(v1, "types", None),
        )
        # Keep the caller's id rather than the one __init__ generated.
        out.id = v1.id
        return out

    def save(self) -> None:
        """Saves the review requirement to the database (insert or update)."""

        for db in self.get_db():
            record = self.to_record()
            db.merge(record)
            db.commit()

    def delete(self) -> None:
        """Deletes the review requirement from the database.

        Raises:
            ValueError: If no record with this id exists.
        """

        for db in self.get_db():
            record = (
                db.query(ReviewRequirementRecord)
                .filter(ReviewRequirementRecord.id == self.id)
                .first()
            )
            if record:
                db.delete(record)
                db.commit()
            else:
                raise ValueError("Review requirement not found")

    def to_record(self) -> ReviewRequirementRecord:
        """Converts the review requirement to a database record.

        List fields are stored as JSON strings.
        """

        return ReviewRequirementRecord(
            id=self.id,
            task_id=self.task_id,
            number_required=self.number_required,
            users=json.dumps(self.users),
            agents=json.dumps(self.agents),
            groups=json.dumps(self.groups),
            types=json.dumps(self.types),
            created=self.created,
            updated=self.updated,
        )

    @classmethod
    def from_record(cls, record: ReviewRequirementRecord) -> "ReviewRequirement":
        """Creates a review requirement instance from a database record."""

        review_requirement = cls.__new__(cls)
        review_requirement.id = record.id
        review_requirement.task_id = record.task_id
        review_requirement.number_required = record.number_required
        # List columns are JSON-encoded by to_record.
        review_requirement.users = json.loads(record.users)  # type: ignore
        review_requirement.agents = json.loads(record.agents)  # type: ignore
        review_requirement.groups = json.loads(record.groups)  # type: ignore
        review_requirement.types = json.loads(record.types)  # type: ignore
        review_requirement.created = record.created
        review_requirement.updated = record.updated
        return review_requirement

    @classmethod
    def find(cls, **kwargs) -> List["ReviewRequirement"]:
        """Finds review requirements matching the given column equality filters.

        Raises:
            ValueError: If no database session is available.
        """

        for db in cls.get_db():
            records = db.query(ReviewRequirementRecord).filter_by(**kwargs).all()
            return [cls.from_record(record) for record in records]
        raise ValueError("No database session available")

    @classmethod
    def find_many(
        cls,
        task_ids: Optional[List[str]] = None,
        requirement_ids: Optional[List[str]] = None,
    ) -> List["ReviewRequirement"]:
        """Finds review requirements whose task id / id is in the given lists.

        Args:
            task_ids: If non-empty, restrict to these task ids.
            requirement_ids: If non-empty, restrict to these requirement ids.

        Raises:
            ValueError: If no database session is available.
        """
        for db in cls.get_db():
            query = db.query(ReviewRequirementRecord)
            if task_ids:
                query = query.filter(ReviewRequirementRecord.task_id.in_(task_ids))
            if requirement_ids:
                query = query.filter(ReviewRequirementRecord.id.in_(requirement_ids))
            # Execute the filtered query; previously a fresh unfiltered query
            # was run here, returning every requirement in the table.
            records = query.all()
            return [cls.from_record(record) for record in records]
        raise ValueError("No database session available")
--------------------------------------------------------------------------------
/taskara/runtime/base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | from abc import ABC, abstractmethod
4 | from typing import Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, Union
5 |
6 | import shortuuid
7 | from pydantic import BaseModel
8 |
9 | from taskara.db.conn import WithDB
10 | from taskara.db.models import TrackerRecord
11 | from taskara.server.models import (
12 | V1ResourceLimits,
13 | V1ResourceRequests,
14 | V1Tracker,
15 | V1TrackerRuntimeConnect,
16 | )
17 |
# Concrete runtime type returned by TrackerRuntime.connect().
R = TypeVar("R", bound="TrackerRuntime")
# Pydantic model type describing a runtime's connect configuration.
C = TypeVar("C", bound="BaseModel")
20 |
21 |
class Tracker(WithDB):
    """A task server managed by a ``TrackerRuntime`` and persisted as a ``TrackerRecord``."""

    def __init__(
        self,
        name: str,
        port: int,
        runtime: "TrackerRuntime",
        status: str = "running",
        owner_id: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
    ) -> None:
        """Create a tracker and immediately persist it.

        Construction has a side effect: ``save()`` writes a record with a
        freshly generated id.

        Args:
            name: Name of the task server within its runtime.
            port: Port used to reach the task server.
            runtime: Runtime that hosts the task server.
            status: Status string stored with the record.
            owner_id: Optional owner id.
            labels: Optional labels stored with the record.
        """
        now = time.time()
        self._id = shortuuid.uuid()
        self._name = name
        self._port = port
        self._status = status
        self._runtime = runtime
        self._owner_id = owner_id
        # A single timestamp so created and updated start out identical.
        self._created = now
        self._updated = now
        self._labels = labels

        self.save()

    @property
    def id(self) -> str:
        return self._id

    @property
    def status(self) -> str:
        return self._status

    @property
    def name(self) -> str:
        return self._name

    @property
    def runtime(self) -> "TrackerRuntime":
        return self._runtime

    @property
    def port(self) -> int:
        return self._port

    @property
    def owner_id(self) -> Optional[str]:
        return self._owner_id

    @property
    def created(self) -> float:
        return self._created

    @property
    def updated(self) -> float:
        return self._updated

    @property
    def labels(self) -> Optional[Dict[str, str]]:
        return self._labels

    def proxy(
        self,
        local_port: Optional[int] = None,
        background: bool = True,
    ) -> Optional[int]:
        """Proxy a local port to this task server via its runtime.

        Returns:
            Optional[int]: Whatever the runtime's ``proxy`` returns (the proxy pid, per its contract).
        """
        return self._runtime.proxy(self._name, local_port, self._port, background)

    def delete(self, force: bool = False) -> None:
        """Delete the server instance from the runtime, then its database record.

        Args:
            force: If True, errors from the runtime deletion are ignored and
                the database record is still removed.
        """
        # First, delete the server instance from the runtime.
        try:
            self._runtime.delete(self._name)
        except Exception:
            if not force:
                raise

        # After the runtime deletion, delete the record from the database.
        for db in self.get_db():
            record = db.query(TrackerRecord).filter_by(id=self._id).one()
            db.delete(record)
            db.commit()

    def logs(self, follow: bool = False) -> Union[str, Iterator[str]]:
        """Fetch the task server's logs from its runtime.

        Args:
            follow: If True, return an iterator that streams log lines as
                the runtime produces them instead of a single string.

        Returns:
            Union[str, Iterator[str]]: All logs as one string, or a line iterator when following.
        """
        return self._runtime.logs(self._name, follow)

    def save(self) -> None:
        """Insert or update this tracker's database record."""
        for db in self.get_db():
            record = self.to_record()
            db.merge(record)
            db.commit()

    @classmethod
    def find(cls, **kwargs) -> List["Tracker"]:
        """Find trackers matching column equality filters, newest first.

        Raises:
            ValueError: If no database session is available.
        """
        for db in cls.get_db():
            records = (
                db.query(TrackerRecord)
                .filter_by(**kwargs)
                .order_by(TrackerRecord.created.desc())
                .all()
            )
            return [cls.from_record(record) for record in records]
        raise ValueError("No session")

    @classmethod
    def active_runtimes(cls) -> List["TrackerRuntime"]:
        """Get all runtimes currently being used by a tracker

        Returns:
            List[TrackerRuntime]: One runtime per stored tracker (not de-duplicated).
        """
        trackers = cls.find()
        return [tracker.runtime for tracker in trackers]

    def to_v1(self) -> V1Tracker:
        """Convert to V1 API model"""
        return V1Tracker(
            name=self._name,
            runtime=V1TrackerRuntimeConnect(
                name=self._runtime.name(), connect_config=self._runtime.connect_config()
            ),
            port=self._port,
            status=self._status,
            owner_id=self._owner_id,
            created=self._created,
            updated=self._updated,
            labels=self._labels or {},
        )

    def to_record(self) -> TrackerRecord:
        """Convert to DB model; runtime config and labels are stored as JSON."""
        runtime_cfg = self._runtime.connect_config().model_dump_json()

        return TrackerRecord(
            id=self._id,
            name=self._name,
            runtime_name=self._runtime.name(),
            runtime_config=runtime_cfg,
            port=self._port,
            status=self._status,
            owner_id=self._owner_id,
            created=self._created,
            updated=self._updated,
            labels=json.dumps(self._labels or {}),
        )

    @classmethod
    def from_record(cls, record: TrackerRecord) -> "Tracker":
        """Rebuild a tracker from its record, reconnecting to its runtime.

        Bypasses ``__init__`` so no new record is written.
        """
        # Local import: taskara.runtime.load imports this module.
        from taskara.runtime.load import runtime_from_name

        runtype = runtime_from_name(str(record.runtime_name))
        runcfg = runtype.connect_config_type().model_validate_json(
            str(record.runtime_config)
        )
        runtime = runtype.connect(runcfg)

        obj = cls.__new__(cls)
        obj._id = str(record.id)
        obj._name = str(record.name)
        obj._runtime = runtime
        obj._status = record.status
        obj._port = record.port
        obj._owner_id = record.owner_id
        obj._created = record.created
        obj._updated = record.updated
        obj._labels = json.loads(record.labels) if record.labels else None  # type: ignore

        return obj

    def call(
        self,
        path: str,
        method: str,
        data: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Tuple[int, str]:
        """Call the task server

        Args:
            path (str): Path to call
            method (str): Method to use
            data (Optional[dict], optional): Body data. Defaults to None.
            headers (Optional[dict], optional): Headers. Defaults to None.

        Returns:
            Tuple[int, str]: Status code and response text
        """
        return self._runtime.call(
            name=self._name,
            path=path,
            method=method,
            port=self._port,
            data=data,
            headers=headers,
        )
227 |
228 |
class TrackerRuntime(Generic[R, C], ABC):
    """Abstract backend that launches and manages task servers (trackers).

    ``R`` is the concrete runtime type returned by ``connect``; ``C`` is the
    pydantic model describing how to connect to the runtime.
    """

    @classmethod
    def name(cls) -> str:
        """Identifier stored with each tracker record; defaults to the class name."""
        return cls.__name__

    @classmethod
    @abstractmethod
    def connect_config_type(cls) -> Type[C]:
        """The pydantic model which defines the schema for connecting to this runtime

        Returns:
            Type[C]: The type
        """
        pass

    @abstractmethod
    def connect_config(self) -> C:
        """The connect config for this runtime instance

        Returns:
            C: Connect config
        """
        pass

    @classmethod
    @abstractmethod
    def connect(cls, cfg: C) -> R:
        """Connect to the runtime using this configuration

        Args:
            cfg (C): Connect config

        Returns:
            R: A runtime
        """
        pass

    @abstractmethod
    def run(
        self,
        name: str,
        env_vars: Optional[dict] = None,
        owner_id: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        resource_requests: V1ResourceRequests = V1ResourceRequests(),
        resource_limits: V1ResourceLimits = V1ResourceLimits(),
        auth_enabled: bool = True,
    ) -> Tracker:
        """Run the task server

        Args:
            name (str): Name of the task server
            env_vars (Optional[dict], optional): Env vars to supply. Defaults to None.
            owner_id (Optional[str], optional): Owner ID. Defaults to None.
            labels (Optional[Dict[str, str]], optional): Labels for the task server. Defaults to None.
            resource_requests (V1ResourceRequests, optional): Resource requests. Defaults to V1ResourceRequests().
            resource_limits (V1ResourceLimits, optional): Resource limits. Defaults to V1ResourceLimits().
            auth_enabled (bool, optional): Whether to enable auth. Defaults to True.

        Returns:
            Tracker: A task server instance
        """
        pass

    @abstractmethod
    def list(
        self, owner_id: Optional[str] = None, source: bool = False
    ) -> List[Tracker]:
        """List task server instances

        Args:
            owner_id (Optional[str], optional): An optional owner id. Defaults to None.
            source (bool, optional): Whether to list directly from the source. Defaults to False.

        Returns:
            List[Tracker]: A list of task server instances
        """
        pass

    @abstractmethod
    def get(
        self, name: str, owner_id: Optional[str] = None, source: bool = False
    ) -> Tracker:
        """Get a task server instance

        Args:
            name (str): Name of the task server
            owner_id (Optional[str], optional): Optional owner ID. Defaults to None.
            source (bool, optional): Whether to fetch directly from the source. Defaults to False.

        Returns:
            Tracker: A task server instance
        """
        pass

    @abstractmethod
    def requires_proxy(self) -> bool:
        """Whether this runtime requires a proxy to be used"""
        pass

    @abstractmethod
    def proxy(
        self,
        name: str,
        local_port: Optional[int] = None,
        tracker_port: int = 9070,
        background: bool = True,
        owner_id: Optional[str] = None,
    ) -> Optional[int]:
        """Proxy a port to the task server

        Args:
            name (str): Name of the task server
            local_port (Optional[int], optional): Local port to proxy to. Defaults to None.
            tracker_port (int, optional): The task server's port. Defaults to 9070.
            background (bool, optional): Whether to run the proxy in the background. Defaults to True.
            owner_id (Optional[str], optional): An optional owner ID. Defaults to None.

        Returns:
            Optional[int]: The pid of the proxy
        """
        pass

    @abstractmethod
    def delete(self, name: str, owner_id: Optional[str] = None) -> None:
        """Delete a task server instance

        Args:
            name (str): Name of the task server
            owner_id (Optional[str], optional): An optional owner id. Defaults to None.
        """
        pass

    @abstractmethod
    def clean(self, owner_id: Optional[str] = None) -> None:
        """Delete all task server instances

        Args:
            owner_id (Optional[str], optional): An optional owner ID to scope it to. Defaults to None.
        """
        pass

    @abstractmethod
    def logs(
        self, name: str, follow: bool = False, owner_id: Optional[str] = None
    ) -> Union[str, Iterator[str]]:
        """Fetch the logs from the specified task server.

        Args:
            name (str): The name of the task server.
            follow (bool, optional): If True, return an iterator that streams
                log lines instead of a single string. Defaults to False.
            owner_id (Optional[str], optional): An optional owner ID. Defaults to None.

        Returns:
            Union[str, Iterator[str]]: All logs as one string, or a line iterator when following.
        """
        pass

    @abstractmethod
    def call(
        self,
        name: str,
        path: str,
        method: str,
        port: int = 9070,
        data: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Tuple[int, str]:
        """Call the task server

        Args:
            name (str): Name of the server
            path (str): Path to call
            method (str): Method to use
            port (int, optional): Port to use. Defaults to 9070.
            data (Optional[dict], optional): Body data. Defaults to None.
            headers (Optional[dict], optional): Headers. Defaults to None.

        Returns:
            Tuple[int, str]: Status code and response text
        """
        pass

    @abstractmethod
    def refresh(self, owner_id: Optional[str] = None) -> None:
        """Refresh the runtime

        Args:
            owner_id (Optional[str], optional): Owner id to scope it to. Defaults to None.
        """
        pass

    @abstractmethod
    def runtime_local_addr(self, name: str, owner_id: Optional[str] = None) -> str:
        """
        Returns the local address of the agent with respect to the runtime
        """
        pass
427 |
--------------------------------------------------------------------------------
/taskara/runtime/docker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import signal
5 | import sys
6 | import urllib.error
7 | import urllib.parse
8 | import urllib.request
9 | from typing import Dict, Iterator, List, Optional, Tuple, Type, Union
10 |
11 | import docker
12 | from docker.api.client import APIClient
13 | from docker.errors import NotFound
14 | from pydantic import BaseModel
15 | from tqdm import tqdm
16 |
17 | from taskara.server.models import (
18 | V1ResourceLimits,
19 | V1ResourceRequests,
20 | V1Tracker,
21 | V1TrackerRuntimeConnect,
22 | )
23 | from taskara.util import find_open_port
24 |
25 | from .base import Tracker, TrackerRuntime
26 |
logger = logging.getLogger(__name__)


class DockerConnectConfig(BaseModel):
    """Settings for connecting to a local Docker daemon and choosing the tracker image."""

    # Passed to docker.DockerClient as `timeout` when truthy; otherwise the client default is used.
    timeout: Optional[int] = None
    # Image that DockerTrackerRuntime pulls and runs for each task server.
    image: str = "us-central1-docker.pkg.dev/agentsea-dev/taskara/api:latest"
33 |
34 |
class DockerTrackerRuntime(TrackerRuntime["DockerTrackerRuntime", DockerConnectConfig]):
    """Tracker runtime that runs task servers as containers on a local Docker daemon."""

    def __init__(self, cfg: Optional[DockerConnectConfig] = None) -> None:
        """Connect to the local Docker daemon.

        Args:
            cfg: Connection settings; defaults to ``DockerConnectConfig()``.

        Raises:
            FileNotFoundError: If no Docker socket can be found.
            SystemError: If the Docker Engine version cannot be determined.
        """
        self.docker_socket = self._configure_docker_socket()
        if not cfg:
            cfg = DockerConnectConfig()

        self.img = cfg.image
        self._cfg = cfg
        if cfg.timeout:
            self.client = docker.DockerClient(
                base_url=self.docker_socket, timeout=cfg.timeout
            )
        else:
            self.client = docker.DockerClient(base_url=self.docker_socket)

        # Verify connection and version (fails fast if the daemon is unreachable).
        self._check_version()

    def _configure_docker_socket(self) -> str:
        """Locate the Docker socket and export it as ``DOCKER_HOST``.

        Checks ``/var/run/docker.sock`` first, then the per-user socket at
        ``/Users/<USER>/.docker/run/docker.sock``.

        Raises:
            FileNotFoundError: If neither socket exists.
        """
        if os.path.exists("/var/run/docker.sock"):
            docker_socket = "unix:///var/run/docker.sock"
        else:
            user = os.environ.get("USER")
            user_socket = f"/Users/{user}/.docker/run/docker.sock"
            if not os.path.exists(user_socket):
                raise FileNotFoundError(
                    f"Neither '/var/run/docker.sock' nor '{user_socket}' are available. "
                    "Please make sure you have Docker installed and running."
                )
            docker_socket = f"unix://{user_socket}"
        # Process-wide side effect: other Docker clients pick up the same socket.
        os.environ["DOCKER_HOST"] = docker_socket
        return docker_socket

    def _check_version(self) -> None:
        """Raise SystemError unless the daemon reports an Engine component version."""
        version_info = self.client.version()
        engine_version = next(
            (
                component["Version"]
                for component in version_info.get("Components", [])
                if component["Name"] == "Engine"
            ),
            None,
        )
        if not engine_version:
            raise SystemError("Unable to determine Docker Engine version")
        logger.debug("Connected to Docker Engine version: %s", engine_version)

    @staticmethod
    def _container_port(container) -> int:
        """Return the TASK_SERVER_PORT from a container's env, or 9070 if unset."""
        # Docker may report "Env": null, hence the `or []`.
        env_vars = container.attrs.get("Config", {}).get("Env", []) or []
        for var in env_vars:
            if var.startswith("TASK_SERVER_PORT="):
                return int(var.split("=", 1)[1])
        return 9070

    @classmethod
    def name(cls) -> str:
        return "docker"

    @classmethod
    def connect_config_type(cls) -> Type[DockerConnectConfig]:
        return DockerConnectConfig

    def connect_config(self) -> DockerConnectConfig:
        return self._cfg

    @classmethod
    def connect(cls, cfg: DockerConnectConfig) -> "DockerTrackerRuntime":
        return cls(cfg)

    def call(
        self,
        name: str,
        path: str,
        method: str,
        port: int = 9070,
        data: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Tuple[int, str]:
        """Send an HTTP request to a task server container via its host port.

        GET requests encode ``data`` as query parameters; other methods send
        it as a JSON body. ``headers`` are applied to every method.

        Raises:
            ValueError: If the container does not exist.
            SystemError: If the server answers with an HTTP error status.
        """
        # Verify the container exists before issuing the request.
        try:
            self.client.containers.get(name)
        except NotFound:
            raise ValueError(f"Container '{name}' not found")

        # The container port is published on the host, so talk to localhost.
        url = f"http://localhost:{port}{path}"
        http_method = method.upper()

        if http_method == "GET":
            if data:
                url += f"?{urllib.parse.urlencode(data)}"
            request = urllib.request.Request(url)
        else:
            request = urllib.request.Request(url, method=http_method)
            if data is not None:
                request.add_header("Content-Type", "application/json")
                request.data = json.dumps(data).encode("utf-8")

        # Applied after Content-Type so callers can override it.
        for key, value in (headers or {}).items():
            request.add_header(key, value)

        try:
            with urllib.request.urlopen(request) as response:
                return response.status, response.read().decode("utf-8")
        except urllib.error.HTTPError as e:
            error_message = e.read().decode("utf-8")
            raise SystemError(
                f"Error making HTTP request to Docker container: {e.code}: {error_message}"
            ) from e

    def _ensure_network_exists(self, network_name: str) -> None:
        """Create the Docker network ``network_name`` if it is missing."""
        try:
            self.client.networks.get(network_name)
            logger.debug("Network '%s' already exists.", network_name)
        except NotFound:
            logger.debug("Network '%s' not found. Creating network.", network_name)
            self.client.networks.create(network_name)
            logger.debug("Network '%s' created.", network_name)

    def run(
        self,
        name: str,
        env_vars: Optional[dict] = None,
        owner_id: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        resource_requests: V1ResourceRequests = V1ResourceRequests(),
        resource_limits: V1ResourceLimits = V1ResourceLimits(),
        auth_enabled: bool = True,
    ) -> Tracker:
        """Pull the image and start a task server container on the "agentsea" network.

        ``resource_requests`` and ``resource_limits`` are accepted for interface
        compatibility but are not applied by this runtime.

        Raises:
            ValueError: If no image is configured or no free host port is found.
        """
        # Validate before any network traffic.
        if not self.img:
            raise ValueError("img not found")

        api_client = docker.APIClient(base_url=self.docker_socket)
        try:
            # Pull the image with progress tracking
            pull_image(self.img, api_client)
        finally:
            api_client.close()

        # User-supplied labels may override the defaults.
        _labels = {
            "provisioner": "taskara",
            "server_name": name,
        }
        if labels:
            _labels.update(labels)

        port = find_open_port(9070, 10090)
        if not port:
            raise ValueError("Could not find open port")

        # Copy so the caller's dict is never mutated.
        env_vars = dict(env_vars or {})
        if not auth_enabled:
            env_vars["TASK_SERVER_NO_AUTH"] = "true"

        self._ensure_network_exists("agentsea")
        container = self.client.containers.run(
            self.img,
            network="agentsea",
            ports={9070: port},
            environment=env_vars,
            detach=True,
            labels=_labels,
            name=name,
        )
        if container and not isinstance(container, bytes):
            logger.debug("ran container '%s'", container.id)  # type: ignore

        return Tracker(
            name=name,
            runtime=self,
            status="running",
            port=port,
            owner_id=owner_id,
            labels=labels,
        )

    def runtime_local_addr(self, name: str, owner_id: Optional[str] = None) -> str:
        """
        Returns the local address of the agent with respect to the runtime
        """
        instances = Tracker.find(name=name, owner_id=owner_id, runtime_name=self.name())
        if not instances:
            raise ValueError(f"Task server '{name}' not found")
        instance = instances[0]

        # NOTE(review): for trackers created by run(), instance.port is the
        # host-published port; inside the "agentsea" network the container
        # listens on 9070. Confirm which port peers should use.
        return f"http://{instance.name}:{instance.port}"

    def _handle_logs_with_attach(self, server_name: str, attach: bool) -> None:
        """If ``attach``, stream logs to stdout and delete the container on interrupt."""
        if attach:
            # Setup the signal handler to catch interrupt signals
            signal.signal(signal.SIGINT, self._signal_handler(server_name))

            try:
                for line in self.logs(server_name, follow=True):
                    print(line)
            except KeyboardInterrupt:
                # This block will be executed if SIGINT is caught
                print(f"Interrupt received, stopping logs for '{server_name}'")
                self.delete(server_name)
            except Exception as e:
                print(f"Error while streaming logs: {e}")

    def _signal_handler(self, server_name: str):
        """Build a signal handler that deletes ``server_name`` and exits with status 1."""
        def handle_signal(signum, frame):
            print(f"Signal {signum} received, stopping container '{server_name}'")
            self.delete(server_name)
            sys.exit(1)

        return handle_signal

    def requires_proxy(self) -> bool:
        """Whether this runtime requires a proxy to be used"""
        return False

    def proxy(
        self,
        name: str,
        local_port: Optional[int] = None,
        tracker_port: int = 9070,
        background: bool = True,
        owner_id: Optional[str] = None,
    ) -> Optional[int]:
        """No-op: ports are already published on the host, so no proxy is started."""
        return None

    def list(
        self, owner_id: Optional[str] = None, source: bool = False
    ) -> List[Tracker]:
        """List trackers from the database, or from running containers if ``source``.

        NOTE(review): with ``source=True`` every call constructs (and thereby
        persists) new Tracker rows, and containers are not filtered by owner.
        """
        if not source:
            return Tracker.find(owner_id=owner_id, runtime_name=self.name())

        containers = self.client.containers.list(
            filters={"label": "provisioner=taskara"}
        )
        return [
            Tracker(
                name=container.name,  # type: ignore
                runtime=self,
                port=self._container_port(container),
                status="running",
                owner_id=owner_id,
            )
            for container in containers
        ]

    def get(
        self, name: str, owner_id: Optional[str] = None, source: bool = False
    ) -> Tracker:
        """Get a tracker by name from the database, or from the container if ``source``.

        Raises:
            ValueError: If the container / tracker is not found.
        """
        if not source:
            instances = Tracker.find(
                name=name, owner_id=owner_id, runtime_name=self.name()
            )
            if not instances:
                raise ValueError(f"Task server '{name}' not found")
            return instances[0]

        try:
            container = self.client.containers.get(name)
        except NotFound:
            raise ValueError(f"Container '{name}' not found")

        return Tracker(
            name=name,
            runtime=self,
            status="running",
            port=self._container_port(container),
            owner_id=owner_id,
        )

    def delete(self, name: str, owner_id: Optional[str] = None) -> None:
        """Force-remove the named container.

        Raises:
            docker.errors.NotFound: If the container does not exist.
        """
        try:
            container = self.client.containers.get(name)
            container.remove(force=True)  # type: ignore
            logger.debug("Successfully deleted container: %s", name)
        except NotFound:
            logger.debug("Container '%s' does not exist.", name)
            raise
        except Exception as e:
            logger.error("Failed to delete container '%s': %s", name, e)
            raise

    def clean(self, owner_id: Optional[str] = None) -> None:
        """Force-remove every taskara container (running or stopped); errors are logged.

        Database records are left untouched; ``owner_id`` is not used.
        """
        label_filter = {"label": ["provisioner=taskara"]}
        containers = self.client.containers.list(filters=label_filter, all=True)

        for container in containers:
            try:
                container_name = container.name  # type: ignore
                container.remove(force=True)  # type: ignore
                logger.debug("Deleted container: %s", container_name)
            except Exception as e:
                logger.error("Failed to delete container: %s", e)

    def logs(
        self, name: str, follow: bool = False, owner_id: Optional[str] = None
    ) -> Union[str, Iterator[str]]:
        """
        Fetches the logs from the specified container. Can return all logs as a single string,
        or stream the logs as a generator of strings.

        Parameters:
            name (str): The name of the container.
            follow (bool): Whether to continuously follow the logs.

        Returns:
            Union[str, Iterator[str]]: All logs as a single string, or a generator that yields log lines.
        """
        try:
            container = self.client.containers.get(name)
            if follow:
                log_stream = container.logs(stream=True, follow=True)  # type: ignore
                return (line.decode("utf-8").strip() for line in log_stream)  # type: ignore
            return container.logs().decode("utf-8")  # type: ignore
        except NotFound:
            logger.debug("Container '%s' does not exist.", name)
            raise
        except Exception as e:
            logger.error("Failed to fetch logs for container '%s': %s", name, e)
            raise

    def refresh(self, owner_id: Optional[str] = None) -> None:
        """
        Reconciles the database against the Docker containers running.

        Parameters:
            owner_id (Optional[str]): The owner ID to filter the trackers. If None, refreshes for all owners.
        """
        running_containers = self.client.containers.list(
            filters={"label": "provisioner=taskara"}
        )
        containers_by_name = {
            container.name: container for container in running_containers  # type: ignore
        }

        if owner_id:
            db_trackers = Tracker.find(owner_id=owner_id, runtime_name=self.name())
        else:
            db_trackers = Tracker.find(runtime_name=self.name())
        db_tracker_names = {tracker.name for tracker in db_trackers}

        running_names = set(containers_by_name)
        containers_to_add = running_names - db_tracker_names
        containers_to_remove = db_tracker_names - running_names

        # NOTE(review): running containers are not filtered by owner, so with an
        # owner_id other owners' containers are adopted under this owner.
        for container_name in containers_to_add:
            # Tracker.__init__ persists the record itself.
            Tracker(
                name=container_name,
                runtime=self,
                port=self._container_port(containers_by_name[container_name]),
                status="running",
                owner_id=owner_id,
            )

        # Reuse the trackers already loaded (avoids an owner_id=None re-query
        # that would miss owned trackers). force=True because the container is
        # usually gone and runtime.delete would otherwise raise NotFound.
        for tracker in db_trackers:
            if tracker.name in containers_to_remove:
                tracker.delete(force=True)

        logger.debug(
            "Refresh completed: added %d trackers, removed %d trackers.",
            len(containers_to_add),
            len(containers_to_remove),
        )
459 |
460 |
def pull_image(img: str, api_client: APIClient):
    """
    Pulls a Docker image with progress bars for each layer.

    Error entries in the pull stream are logged as warnings rather than
    raised, so a locally cached image can still be used when the pull fails.

    Args:
        img (str): The Docker image to pull.
        api_client (APIClient): The Docker API client.
    """

    print(f"Pulling Docker image '{img}'...")

    progress_bars = {}
    try:
        for line in api_client.pull(img, stream=True, decode=True):
            if "error" in line:
                logger.warning("Error while pulling image '%s': %s", img, line["error"])
                continue
            if "id" not in line or "progressDetail" not in line:
                continue

            layer_id = line["id"]
            progress_detail = line["progressDetail"]
            current = progress_detail.get("current", 0)
            total = progress_detail.get("total", 0)
            # Status-only updates carry no total; nothing to draw.
            if not total:
                continue

            bar = progress_bars.get(layer_id)
            if bar is None:
                bar = tqdm(
                    total=total,
                    desc=f"Layer {layer_id}",
                    leave=False,
                    ncols=100,
                )
                progress_bars[layer_id] = bar
            bar.n = current
            bar.refresh()

        # Pull finished: show every layer as complete before closing.
        for bar in progress_bars.values():
            bar.n = bar.total
            bar.refresh()
    finally:
        # Always release the bars, even if the stream raised.
        for bar in progress_bars.values():
            bar.close()

    print("")
503 |
--------------------------------------------------------------------------------
/taskara/runtime/load.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, List, Type
2 |
3 | from pydantic import BaseModel
4 |
5 | from .docker import DockerTrackerRuntime, DockerConnectConfig
6 | from .kube import KubeTrackerRuntime, KubeConnectConfig
7 | from .process import ProcessTrackerRuntime, ProcessConnectConfig
8 | from .base import Tracker, TrackerRuntime
9 | from taskara.server.models import V1TrackerRuntimeConnect
10 |
11 |
class AgentRuntimeConfig(BaseModel):
    """Settings from which ``load_tracker_runtime`` picks a tracker runtime.

    A runtime is only eligible when its ``*_config`` field is set; eligible
    runtimes are tried in the order given by ``preference``.
    """

    # Not consulted when selecting a runtime.
    provider: Optional[str] = None
    docker_config: Optional[DockerConnectConfig] = None
    kube_config: Optional[KubeConnectConfig] = None
    process_config: Optional[ProcessConnectConfig] = None
    # Runtime names (as returned by each runtime's name()) in priority order.
    preference: List[str] = ["kube", "docker", "process"]
18 |
19 |
def runtime_from_name(name: str) -> Type[TrackerRuntime]:
    """Look up the registered runtime class whose ``name()`` equals *name*.

    Raises:
        ValueError: If no class in ``RUNTIMES`` has that name.
    """
    match = next((candidate for candidate in RUNTIMES if candidate.name() == name), None)
    if match is None:
        raise ValueError(f"Unknown runtime '{name}'")
    return match
25 |
26 |
def load_tracker_runtime(cfg: AgentRuntimeConfig) -> TrackerRuntime:
    """Connect to the first runtime in ``cfg.preference`` that has a config.

    Args:
        cfg: Runtime settings. ``preference`` is walked in order, and the first
            name whose matching ``*_config`` field is set is connected.

    Returns:
        The connected tracker runtime.

    Raises:
        ValueError: If none of the preferred runtimes has a connect config.
    """
    for pref in cfg.preference:
        if pref == KubeTrackerRuntime.name() and cfg.kube_config:
            return KubeTrackerRuntime.connect(cfg.kube_config)
        elif pref == DockerTrackerRuntime.name() and cfg.docker_config:
            return DockerTrackerRuntime.connect(cfg.docker_config)
        elif pref == ProcessTrackerRuntime.name() and cfg.process_config:
            return ProcessTrackerRuntime.connect(cfg.process_config)
    # `provider` plays no part in the selection above, so report what actually
    # failed: no preferred runtime had a connect config.
    raise ValueError(
        "No tracker runtime could be loaded: none of the preferred runtimes "
        f"{cfg.preference} has a connect config (provider: {cfg.provider})"
    )
36 |
37 |
# Every tracker runtime implementation; runtime_from_name and load_from_connect
# search this list by each class's name().
RUNTIMES: List[Type[TrackerRuntime]] = [  # type: ignore
    DockerTrackerRuntime,
    KubeTrackerRuntime,
    ProcessTrackerRuntime,
]
39 |
40 |
def load_from_connect(connect: V1TrackerRuntimeConnect) -> TrackerRuntime:
    """Rebuild and connect a runtime from its serialized connection info.

    Args:
        connect: The runtime's name plus its connect config. The config is
            validated into the runtime's own ``connect_config_type()``.

    Returns:
        The connected tracker runtime.

    Raises:
        ValueError: If ``connect.name`` matches no class in ``RUNTIMES``.
    """
    for runt in RUNTIMES:
        if connect.name == runt.name():
            # The config is intentionally not echoed to stdout. Connect configs
            # may hold sensitive values; presumably credentials for the
            # kube/docker runtimes (their fields are defined elsewhere).
            cfg = runt.connect_config_type().model_validate(connect.connect_config)
            return runt.connect(cfg)

    raise ValueError(f"Unknown runtime: {connect.name}")
50 |
--------------------------------------------------------------------------------
/taskara/runtime/process.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import os
4 | import signal
5 | import subprocess
6 | import sys
7 | import time
8 | import urllib.error
9 | import urllib.parse
10 | import urllib.request
11 | from typing import Dict, Iterator, List, Optional, Tuple, Type, Union
12 |
13 | import requests
14 | from pydantic import BaseModel
15 |
16 | from taskara.server.models import V1ResourceLimits, V1ResourceRequests
17 | from taskara.util import find_open_port
18 |
19 | from .base import Tracker, TrackerRuntime
20 |
logger = logging.getLogger(__name__)


class ProcessConnectConfig(BaseModel):
    """Connect settings for the local-process runtime.

    Intentionally empty: running task servers as local processes needs no
    connection parameters.
    """
26 |
27 |
class ProcessTrackerRuntime(
    TrackerRuntime["ProcessTrackerRuntime", ProcessConnectConfig]
):
    """Tracker runtime that runs task servers as background local processes.

    Each server is launched through a shell as
    ``poetry run python -m taskara.server.app`` followed by a
    ``TASK_SERVER=<name>`` argument. That argument serves as a marker: it is how
    ``list``, ``delete``, ``clean`` and ``refresh`` find the server again in
    ``ps`` output. Launch metadata is written to ``.data/proc/<name>.json`` and
    server output to ``.data/logs/<lowercased name>.log``, both relative to the
    current working directory.
    """

    @classmethod
    def name(cls) -> str:
        return "process"

    @classmethod
    def connect_config_type(cls) -> Type[ProcessConnectConfig]:
        return ProcessConnectConfig

    def connect_config(self) -> ProcessConnectConfig:
        return ProcessConnectConfig()

    @classmethod
    def connect(cls, cfg: ProcessConnectConfig) -> "ProcessTrackerRuntime":
        # Local processes need no connection state.
        return cls()

    @staticmethod
    def _task_server_processes() -> List[Tuple[int, str]]:
        """Return ``(pid, server_name)`` for every process with a ``TASK_SERVER=<name>`` argument.

        ``ps`` output is parsed in Python rather than piped through ``grep``.
        The server name is therefore compared exactly instead of as a
        substring, and no name is ever interpolated into a shell command.

        Raises:
            subprocess.CalledProcessError / OSError: If ``ps`` cannot be run.
        """
        output = subprocess.check_output(["ps", "ax", "-o", "pid,command"], text=True)
        found: List[Tuple[int, str]] = []
        for line in output.splitlines():
            fields = line.split()
            # Skips the "PID COMMAND" header and any malformed line.
            if len(fields) < 2 or not fields[0].isdigit():
                continue
            for token in fields[1:]:
                if token.startswith("TASK_SERVER="):
                    server_name = token.split("=", 1)[1]
                    if server_name:
                        found.append((int(fields[0]), server_name))
                    break
        return found

    def run(
        self,
        name: str,
        env_vars: Optional[dict] = None,
        owner_id: Optional[str] = None,
        labels: Optional[Dict[str, str]] = None,
        resource_requests: V1ResourceRequests = V1ResourceRequests(),
        resource_limits: V1ResourceLimits = V1ResourceLimits(),
        auth_enabled: bool = True,
    ) -> Tracker:
        """Start a task server in the background and wait until ``/health`` answers 200.

        Args:
            name: Server name, used for the ``TASK_SERVER=`` marker, the metadata
                file and the log file name.
            env_vars: Recorded in the metadata file. NOTE(review): they are not
                added to the spawned process's environment; confirm intended.
            owner_id: Stored in the metadata and on the returned tracker.
            labels: Unused; the returned tracker is labelled with the launch command.
            resource_requests: Unused by this runtime.
            resource_limits: Unused by this runtime.
            auth_enabled: When False, the server is started with
                ``TASK_SERVER_NO_AUTH=true``.

        Returns:
            A ``Tracker`` describing the server. It is not saved here.

        Raises:
            ValueError: If ``find_open_port(9070, 10090)`` finds no port.
            RuntimeError: If the server fails every health check.
        """
        import shlex  # local import: only needed to quote the launch command

        port = find_open_port(9070, 10090)
        if not port:
            raise ValueError("Could not find open port")
        logger.debug("running process")

        metadata = {
            "name": name,
            "port": port,
            "env_vars": env_vars if env_vars else {},
            "owner_id": owner_id,
        }

        # Quote the name-derived parts so a name cannot inject shell syntax.
        # shlex.quote leaves ordinary names untouched, so ps still shows
        # "TASK_SERVER=<name>" as a single argument.
        server_cmd = "poetry run python -m taskara.server.app"
        marker = shlex.quote(f"TASK_SERVER={name}")
        log_path = shlex.quote(f"./.data/logs/{name.lower()}.log")
        command = (
            f"TASK_SERVER_PORT={port} nohup {server_cmd} {marker} "
            f"TASK_SERVER_PORT={port} > {log_path} 2>&1 &"
        )
        if not auth_enabled:
            command = "TASK_SERVER_NO_AUTH=true " + command
        metadata["command"] = command

        os.makedirs(".data/proc", exist_ok=True)
        with open(f".data/proc/{name}.json", "w") as f:
            json.dump(metadata, f, indent=4)

        os.makedirs(".data/logs", exist_ok=True)
        print(f"running server on port {port}")

        # The shell backgrounds the server ("&") and returns immediately.
        # setsid makes that shell a new session / process-group leader, and the
        # backgrounded server stays in that group, which delete() signals.
        process = subprocess.Popen(
            command,
            shell=True,
            preexec_fn=os.setsid,
            env=os.environ.copy(),
            text=True,
        )
        # Output is redirected to the log file by the command itself, so only
        # the launcher's exit status is meaningful here.
        process.wait()
        if process.returncode != 0:
            logger.error(
                f"Error running command (exit code {process.returncode}); "
                f"see ./.data/logs/{name.lower()}.log"
            )

        max_retries = 20
        retry_delay = 1
        health_url = f"http://localhost:{port}/health"

        for _ in range(max_retries):
            try:
                # A timeout keeps one hung connection from stalling startup.
                response = requests.get(health_url, timeout=2)
                if response.status_code == 200:
                    logger.info("Task server is up and running.")
                    break
            except requests.RequestException:
                logger.warning("Task server not yet available, retrying...")
            time.sleep(retry_delay)
        else:
            raise RuntimeError("Failed to start server, it did not pass health checks.")

        return Tracker(
            name=name,
            runtime=self,
            status="running",
            port=port,
            labels={"command": command},
            owner_id=owner_id,
        )

    def _signal_handler(self, server_name: str):
        """Build a handler that stops *server_name*, deletes its tracker record and exits with status 1."""

        def handle_signal(signum, frame):
            print(f"Signal {signum} received, stopping process '{server_name}'")
            self.delete(server_name)
            instances = Tracker.find(name=server_name)
            if instances:
                instances[0].delete()
            sys.exit(1)

        return handle_signal

    def _follow_logs(self, server_name: str):
        """Print lines appended to the server's log until Ctrl+C, then delete the server and re-raise."""
        log_path = f"./.data/logs/{server_name.lower()}.log"
        if not os.path.exists(log_path):
            logger.error("No log file found.")
            return

        with open(log_path, "r") as log_file:
            log_file.seek(0, 2)  # start at the end: only new output is shown
            try:
                while True:
                    line = log_file.readline()
                    if not line:
                        time.sleep(0.5)  # wait briefly for new log entries
                        continue
                    print(line.strip())
            except KeyboardInterrupt:
                print(f"Interrupt received, stopping logs for '{server_name}'")
                self.delete(server_name)
                raise

    def requires_proxy(self) -> bool:
        """Whether this runtime requires a proxy to be used"""
        return False

    def proxy(
        self,
        name: str,
        local_port: Optional[int] = None,
        tracker_port: int = 9070,
        background: bool = True,
        owner_id: Optional[str] = None,
    ) -> Optional[int]:
        """No-op: servers listen on localhost directly, so no proxy is started; returns None."""
        logger.info("no proxy needed")
        return None

    def get(
        self, name: str, owner_id: Optional[str] = None, source: bool = False
    ) -> Tracker:
        """Return the tracker for *name*.

        With ``source=True`` the tracker is rebuilt from ``.data/proc/<name>.json``;
        otherwise it is looked up in the tracker store.

        Raises:
            ValueError: If no metadata file / stored tracker exists for *name*.
        """
        if source:
            try:
                with open(f".data/proc/{name}.json", "r") as f:
                    metadata = json.load(f)
            except FileNotFoundError:
                raise ValueError(f"No metadata found for server {name}") from None
            return Tracker(
                name=metadata["name"],
                runtime=self,
                port=metadata["port"],
            )

        instances = Tracker.find(name=name, owner_id=owner_id, runtime_name=self.name())
        if not instances:
            raise ValueError(f"No running server found with the name {name}")
        return instances[0]

    def list(
        self,
        owner_id: Optional[str] = None,
        source: bool = False,
    ) -> List[Tracker]:
        """List trackers for this runtime.

        With ``source=True`` the metadata files are scanned. A tracker is
        returned for each server that still has a running process, and stale
        metadata files are deleted. ``owner_id`` is not applied in that mode.
        Otherwise the tracker store is queried.
        """
        if not source:
            return Tracker.find(owner_id=owner_id, runtime_name=self.name())

        metadata_dir = ".data/proc"
        if not os.path.isdir(metadata_dir):
            return []  # nothing has been launched from this directory yet

        running_names = {server for _, server in self._task_server_processes()}
        instances = []
        for filename in os.listdir(metadata_dir):
            if not filename.endswith(".json"):
                continue
            path = os.path.join(metadata_dir, filename)
            try:
                with open(path, "r") as file:
                    metadata = json.load(file)

                if metadata["name"] in running_names:
                    instances.append(
                        Tracker(
                            name=metadata["name"],
                            runtime=self,
                            status="running",
                            port=metadata["port"],
                        )
                    )
                else:
                    os.remove(path)
                    logger.info(
                        f"Deleted metadata for non-existing process {metadata['name']}."
                    )
            except Exception as e:
                # Best effort: one unreadable file must not hide the others.
                logger.error(f"Error processing {filename}: {str(e)}")

        return instances

    def call(
        self,
        name: str,
        path: str,
        method: str,
        port: int = 9070,
        data: Optional[dict] = None,
        headers: Optional[dict] = None,
    ) -> Tuple[int, str]:
        """Send an HTTP request to a local server and return ``(status, body)``.

        For GET, ``data`` is sent as query parameters; for other methods it is
        sent as a JSON body. ``headers`` are applied to every method.

        Raises:
            SystemError: If the server answers with an HTTP error status.
        """
        url = f"http://localhost:{port}{path}"
        verb = method.upper()

        if verb == "GET":
            if data:
                url += f"?{urllib.parse.urlencode(data)}"
            request = urllib.request.Request(url)
        else:
            request = urllib.request.Request(url, method=verb)
            if data:
                request.add_header("Content-Type", "application/json")
                request.data = json.dumps(data).encode("utf-8")

        # Caller headers (e.g. auth) must reach GET requests too.
        for key, value in (headers or {}).items():
            request.add_header(key, value)

        try:
            with urllib.request.urlopen(request) as response:
                return response.status, response.read().decode("utf-8")
        except urllib.error.HTTPError as e:
            error_message = e.read().decode("utf-8")
            raise SystemError(
                f"Error making HTTP request to local process: {e.code}: {error_message}"
            ) from e

    def delete(
        self,
        name: str,
        owner_id: Optional[str] = None,
    ) -> None:
        """Terminate the server named *name* and remove its metadata file.

        The metadata file is removed whether or not a process is found. SIGTERM
        is sent to the process group of each matching process. If a match
        shares this interpreter's own group, that process alone is signalled.

        Raises:
            SystemError: If ``ps`` fails, no matching process exists, or
                signalling fails.
        """
        try:
            pids = [
                pid for pid, server in self._task_server_processes() if server == name
            ]
        except (subprocess.CalledProcessError, OSError) as e:
            raise SystemError(
                f"Error while attempting to delete the process: {str(e)}"
            ) from e
        logger.debug(f"Found process list: {pids}")

        metadata_file = f".data/proc/{name}.json"
        try:
            if os.path.exists(metadata_file):
                os.remove(metadata_file)
                logger.info(f"Deleted metadata file for {name}.")
        except OSError as e:
            raise SystemError(f"An unexpected error occurred: {str(e)}") from e

        if not pids:
            raise SystemError(f"No running process found with the name {name}.")

        own_group = os.getpgrp()
        signalled_groups = set()
        for pid in pids:
            try:
                pgid = os.getpgid(pid)
                if pgid == own_group:
                    os.kill(pid, signal.SIGTERM)  # never signal our own group
                elif pgid not in signalled_groups:
                    os.killpg(pgid, signal.SIGTERM)
                    signalled_groups.add(pgid)
            except ProcessLookupError:
                continue  # exited between the ps snapshot and now
            except OSError as e:
                raise SystemError(f"An unexpected error occurred: {str(e)}") from e

        logger.info(f"Process {name} with PID {pids[0]} has been terminated.")

    def runtime_local_addr(self, name: str, owner_id: Optional[str] = None) -> str:
        """
        Returns the local address of the agent with respect to the runtime
        """
        instances = Tracker.find(name=name, owner_id=owner_id, runtime_name=self.name())
        if not instances:
            raise ValueError(f"Task server '{name}' not found")
        return f"http://localhost:{instances[0].port}"

    def clean(
        self,
        owner_id: Optional[str] = None,
    ) -> None:
        """Best-effort SIGTERM of every process carrying a ``TASK_SERVER=`` marker.

        ``owner_id`` is accepted for interface compatibility but not applied.
        Failures are logged, never raised.
        """
        try:
            processes = self._task_server_processes()
        except Exception as e:
            logger.error(f"An unexpected error occurred during cleanup: {str(e)}")
            return

        if not processes:
            logger.info("No relevant processes found.")
            return

        for pid, _ in processes:
            try:
                os.kill(pid, signal.SIGTERM)
                logger.info(f"Terminated process with PID {pid}.")
            except OSError as e:
                logger.error(f"Failed to terminate process with PID {pid}: {str(e)}")
        logger.info("All relevant processes have been terminated.")

    def logs(
        self,
        name: str,
        follow: bool = False,
        owner_id: Optional[str] = None,
    ) -> Union[str, Iterator[str]]:
        """Return the server's whole log, or, with ``follow``, a never-ending generator of new lines."""
        log_path = f"./.data/logs/{name.lower()}.log"
        if not os.path.exists(log_path):
            return "No logs available for this server."

        if follow:

            def follow_logs():
                with open(log_path, "r") as log_file:
                    log_file.seek(0, 2)  # 'tail -f': start at the end
                    while True:
                        line = log_file.readline()
                        if not line:
                            time.sleep(0.5)
                            continue
                        yield line

            return follow_logs()

        with open(log_path, "r") as log_file:
            return log_file.read()

    def refresh(self, owner_id: Optional[str] = None) -> None:
        """
        Reconciles the database against the running processes.

        Parameters:
            owner_id (Optional[str]): The owner ID to filter the trackers. If None, refreshes for all owners.
        """
        # server name -> pid of a process carrying its TASK_SERVER= marker
        running_processes = {
            server_name: pid for pid, server_name in self._task_server_processes()
        }
        running_process_names = set(running_processes)

        if owner_id:
            db_trackers = Tracker.find(owner_id=owner_id, runtime_name=self.name())
        else:
            db_trackers = Tracker.find(runtime_name=self.name())
        db_tracker_names = {tracker.name for tracker in db_trackers}

        # NOTE(review): with owner_id set, servers tracked under another owner
        # also land in processes_to_add; confirm that is intended.
        processes_to_add = running_process_names - db_tracker_names
        processes_to_remove = db_tracker_names - running_process_names

        for process_name in processes_to_add:
            try:
                with open(f".data/proc/{process_name}.json", "r") as f:
                    metadata = json.load(f)
            except FileNotFoundError:
                logger.warning(f"No metadata found for process {process_name}")
                continue

            Tracker(
                name=metadata["name"],
                runtime=self,
                port=metadata["port"],
                status="running",
                owner_id=owner_id,
            ).save()

        for tracker_name in processes_to_remove:
            trackers = Tracker.find(
                name=tracker_name, owner_id=owner_id, runtime_name=self.name()
            )
            if trackers:
                trackers[0].delete()

        logger.debug(
            f"Refresh completed: added {len(processes_to_add)} trackers, removed {len(processes_to_remove)} trackers."
        )
467 |
--------------------------------------------------------------------------------
/taskara/server/app.py:
--------------------------------------------------------------------------------
1 | import logging.config
2 | import os
3 | from contextlib import asynccontextmanager
4 | import time
5 | from fastapi import FastAPI, Request, status
6 | from fastapi.exceptions import RequestValidationError
7 | from fastapi.middleware.cors import CORSMiddleware
8 | from fastapi.responses import JSONResponse
9 | from ..db.redis_connection import init_redis_pool, close_redis_pool
10 |
11 |
12 | from .router.benchmarks import router as benchmarks_router
13 | from .router.tasks import router as tasks_router
14 |
# Path is resolved against the current working directory.
logging.config.fileConfig("logging.conf", disable_existing_loggers=False)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: open the Redis pool at startup and close it at shutdown.

    Shutdown runs in ``finally`` so the pool is also closed when an exception
    is raised into the lifespan context.
    """
    print("initializing boot sequence...", flush=True)
    await init_redis_pool()
    # Reported only after initialisation has actually completed.
    print("boot sequence initialized.", flush=True)
    try:
        yield
    finally:
        print("Shutting Down Fast API server Taskara", flush=True)
        await close_redis_pool()
25 |
# The ASGI application; `lifespan` manages the Redis pool.
app = FastAPI(lifespan=lifespan)

# Receives per-request and validation-error log records.
access_logger = logging.getLogger("access")
29 |
30 |
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Answer 422 with one ``{field, message, type}`` entry per validation error.

    Each error is also printed and logged together with the raw request body.
    NOTE(review): request models such as V1CreateTask carry an ``auth_token``
    field, so this can write tokens to logs; consider redacting.
    """
    # The body is identical for every error: read and decode it once.
    # errors="replace" keeps a non-UTF-8 body from crashing the error handler.
    body = (await request.body()).decode("utf-8", errors="replace")

    errors = []
    for error in exc.errors():
        field = ".".join(str(loc) for loc in error["loc"])
        msg = {"field": field, "message": error["msg"], "type": error["type"]}
        print("\n\n!error: ", msg, "\nrequest data: ", body, "\n")
        access_logger.error(
            f"Validation error for field {field}: {error['msg']} (type: {error['type']}, request data: {body})"
        )
        errors.append(msg)
    return JSONResponse(
        status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, content={"detail": errors}
    )
45 | )
46 |
47 |
@app.middleware("http")
async def log_requests(request: Request, call_next):
    """Log each request when it arrives and again with its status and duration."""
    # perf_counter is monotonic: durations stay correct if the wall clock jumps.
    start_time = time.perf_counter()
    access_logger.info(f"Received request: {request.method} {request.url}")
    response = await call_next(request)
    duration = time.perf_counter() - start_time
    access_logger.info(
        f"Returned response {request.method} {request.url}: {response.status_code} - Duration: {duration:.4f} seconds"
    )
    return response
58 |
59 |
# NOTE(review): wildcard origins together with allow_credentials=True may let
# any site make credentialed requests; confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Route registration order: tasks first, then benchmarks.
app.include_router(tasks_router)
app.include_router(benchmarks_router)


@app.get("/")
async def root():
    """Identify the service."""
    payload = {"message": "A Taskara tracker"}
    return payload


@app.get("/health")
async def health():
    """Liveness endpoint; ProcessTrackerRuntime.run polls it after launching a server."""
    return dict(status="ok")
80 |
81 |
if __name__ == "__main__":
    import uvicorn

    # Listening port and auto-reload come from the environment.
    uvicorn.run(
        app="taskara.server.app:app",
        host="0.0.0.0",
        port=int(os.getenv("TASK_SERVER_PORT", "9070")),
        reload=os.getenv("TASK_SERVER_RELOAD", "false") == "true",
    )
89 |
--------------------------------------------------------------------------------
/taskara/server/models.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Any, Dict, List, Optional
3 |
4 | import shortuuid
5 | from devicebay import V1Device, V1DeviceType
6 | from mllm import V1Prompt
7 | from pydantic import BaseModel, Field
8 | from skillpacks.base import V1Action, V1ActionEvent
9 | from skillpacks.review import ReviewerType, V1Review
10 | from threadmem.server.models import V1RoleThread
11 |
12 |
class V1ReviewRequirement(BaseModel):
    """Which reviewers a task needs, and how many reviews are required."""

    id: Optional[str] = None
    task_id: Optional[str] = None
    users: Optional[List[str]] = None
    agents: Optional[List[str]] = None
    groups: Optional[List[str]] = None
    # Presumably the number of approving reviews needed; confirm with the review logic.
    number_required: int = 2


class V1PendingReviewers(BaseModel):
    """Users and agents whose review of a task is still outstanding."""

    task_id: str
    users: Optional[List[str]] = None
    agents: Optional[List[str]] = None


class V1PendingReviews(BaseModel):
    """Tasks awaiting review."""

    # Presumably task IDs; verify against the producing route.
    tasks: List[str]
34 |
35 |
class V1CreateReview(BaseModel):
    """Request body for reviewing a task."""

    approved: bool
    reviewer_type: str = ReviewerType.HUMAN.value
    reason: Optional[str] = None
    reviewer: Optional[str] = None


class V1CreateReviewAction(BaseModel):
    """Request body for reviewing an action, optionally supplying a corrected action."""

    approved: bool
    reviewer_type: str = ReviewerType.HUMAN.value
    reason: Optional[str] = None
    reviewer: Optional[str] = None
    correction: Optional[V1Action] = None


class V1CreateAnnotationReview(BaseModel):
    """Request body for reviewing an annotation, optionally with corrected text."""

    approved: bool
    reviewer_type: str = ReviewerType.HUMAN.value
    reason: Optional[str] = None
    reviewer: Optional[str] = None
    correction: Optional[str] = None


class V1CreateAnnotationResponse(BaseModel):
    """Response carrying the ID of a newly created annotation."""

    id: str


class V1ReviewMany(BaseModel):
    """Request body for a bulk review."""

    reviewer: Optional[str] = None
    reviewer_type: str = ReviewerType.HUMAN.value
    # Flags controlling treatment of hidden items; semantics live in the router.
    approve_hidden: bool = False
    fail_hidden: bool = False
68 |
69 |
class V1TaskUpdate(BaseModel):
    """Partial update for a task; every field is optional."""

    status: Optional[str] = None
    org: Optional[str] = None
    description: Optional[str] = None
    max_steps: Optional[int] = None
    error: Optional[str] = None
    output: Optional[str] = None
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    completed: Optional[float] = None
    started: Optional[float] = None
    version: Optional[str] = None
    set_labels: Optional[Dict[str, str]] = None


class V1CreateTask(BaseModel):
    """Request body for creating a task; ``id`` and ``created`` are generated if omitted."""

    id: str = Field(default_factory=lambda: shortuuid.uuid())
    description: str
    max_steps: int = 30
    device: Optional[V1Device] = None
    device_type: Optional[V1DeviceType] = None
    expect_schema: Optional[Dict[str, Any]] = None
    status: Optional[str] = None
    threads: Optional[List[V1RoleThread]] = None
    prompts: Optional[List[str]] = None
    reviews: List[V1Review] = []
    review_requirements: List[V1ReviewRequirement] = []
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    created: float = Field(default_factory=time.time)
    started: float = 0.0
    completed: float = 0.0
    created_by: Optional[str] = None
    error: Optional[str] = None
    output: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = {}
    version: Optional[str] = None
    remote: Optional[str] = None
    owner_id: Optional[str] = None
    project: Optional[str] = None
    parent_id: Optional[str] = None
    tags: List[str] = []
    labels: Dict[str, str] = {}
    episode_id: Optional[str] = None
    public: bool = False
    skill: Optional[str] = None
    org: Optional[str] = None
    auth_token: Optional[str] = None


class V1Task(BaseModel):
    """Full wire representation of a task.

    NOTE(review): unlike V1CreateTask and V1TaskUpdate this model has no ``org``
    field, and ``created_by`` defaults to ``""`` instead of None; confirm intended.
    """

    id: str = Field(default_factory=lambda: shortuuid.uuid())
    description: str
    max_steps: int = 30
    device: Optional[V1Device] = None
    device_type: Optional[V1DeviceType] = None
    expect_schema: Optional[Dict[str, Any]] = None
    status: Optional[str] = None
    threads: Optional[List[V1RoleThread]] = None
    prompts: Optional[List[str]] = None
    reviews: List[V1Review] = []
    review_requirements: List[V1ReviewRequirement] = []
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    created: float = Field(default_factory=time.time)
    started: float = 0.0
    completed: float = 0.0
    created_by: Optional[str] = ""
    error: Optional[str] = None
    output: Optional[str] = None
    parameters: Optional[Dict[str, Any]] = {}
    version: Optional[str] = None
    remote: Optional[str] = None
    owner_id: Optional[str] = None
    project: Optional[str] = None
    parent_id: Optional[str] = None
    tags: List[str] = []
    labels: Dict[str, str] = {}
    episode_id: Optional[str] = None
    public: bool = False
    skill: Optional[str] = None
    auth_token: Optional[str] = None
150 |
151 |
class V1SearchTask(BaseModel):
    """Task search filter; every field is optional.

    Device, thread, prompt, review, review-requirement and parameter fields of
    V1Task are not searchable and are deliberately absent.
    """

    id: Optional[str] = None
    description: Optional[str] = None
    max_steps: Optional[int] = None
    status: Optional[str] = None
    owners: Optional[List[str]] = None
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    created: Optional[float] = None
    created_by: Optional[str] = None
    started: Optional[float] = None
    completed: Optional[float] = None
    error: Optional[str] = None
    output: Optional[str] = None
    version: Optional[str] = None
    remote: Optional[str] = None
    owner_id: Optional[str] = None
    project: Optional[str] = None
    parent_id: Optional[str] = None
    tags: Optional[List[str]] = None
    labels: Optional[Dict[str, str]] = None
    skill: Optional[str] = None
    public: Optional[bool] = None
    episode_id: Optional[str] = None
    auth_token: Optional[str] = None


class V1Tasks(BaseModel):
    """A list of tasks."""

    tasks: List[V1Task]


class V1TaskIDs(BaseModel):
    """A list of task IDs."""

    task_ids: List[str]
192 |
193 |
class V1TaskTemplate(BaseModel):
    """Reusable task definition; benchmarks are built from lists of these."""

    id: str = Field(default_factory=lambda: shortuuid.uuid())
    description: str
    max_steps: int = 30
    device: Optional[V1Device] = None
    device_type: Optional[V1DeviceType] = None
    expect_schema: Optional[Dict[str, Any]] = None
    parameters: Optional[Dict[str, Any]] = {}
    owner_id: Optional[str] = None
    tags: List[str] = []
    labels: Dict[str, str] = {}
    created: float = Field(default_factory=time.time)


class V1TaskTemplates(BaseModel):
    """A list of task templates."""

    templates: List[V1TaskTemplate]
210 |
211 |
class V1UserProfile(BaseModel):
    """Profile of the authenticated user; routers scope data by ``email``."""

    email: Optional[str] = None
    display_name: Optional[str] = None
    handle: Optional[str] = None
    organization: Optional[str] = None
    role: Optional[str] = None
    actor: Optional[str] = None
    picture: Optional[str] = None
    created: Optional[int] = None
    updated: Optional[int] = None
    token: Optional[str] = None


class V1AddThread(BaseModel):
    """Request body for attaching a thread."""

    public: bool
    name: Optional[str] = None
    metadata: Optional[dict] = None
    id: Optional[str] = None


class V1RemoveThread(BaseModel):
    """Request body for detaching the thread with the given ``id``."""

    id: str


class V1PostMessage(BaseModel):
    """Request body for posting a message, optionally to a specific thread."""

    role: str
    msg: str
    images: List[str] = []
    thread: Optional[str] = None
241 |
242 |
class V1TrackerRuntimeConnect(BaseModel):
    """A runtime's name together with its connect configuration.

    NOTE(review): ``connect_config`` is typed as plain BaseModel, so fields of a
    subclass are presumably dropped when this model is parsed from JSON;
    verify that round-tripping works with load_from_connect.
    """

    name: str
    connect_config: BaseModel


class V1Tracker(BaseModel):
    """Wire representation of a running tracker (task server).

    NOTE(review): ``port`` defaults to 9090, while the runtimes and server
    default to 9070; confirm intended.
    """

    name: str
    runtime: V1TrackerRuntimeConnect
    version: Optional[str] = None
    port: int = 9090
    labels: Dict[str, str] = {}
    tags: List[str] = []
    status: str
    owner_id: Optional[str] = None
    created: float
    updated: float


class V1Runtime(BaseModel):
    """Runtime type plus an ordered list of preferences."""

    type: str
    preference: List[str] = []
264 |
265 |
class V1ResourceLimits(BaseModel):
    """Upper resource bounds for a tracker, as quantity strings."""

    cpu: str = "2"
    memory: str = "2Gi"


class V1ResourceRequests(BaseModel):
    """Requested resources for a tracker, as quantity strings."""

    cpu: str = "1"
    # NOTE(review): in Kubernetes quantity syntax "500m" is 0.5 *bytes*; if this
    # is passed to Kubernetes, "500Mi" was probably meant. It is left unchanged
    # because the consuming runtimes are not visible here.
    memory: str = "500m"
    gpu: Optional[str] = None


class V1Prompts(BaseModel):
    """A list of prompts."""

    prompts: List[V1Prompt]
279 |
280 |
class V1Benchmark(BaseModel):
    """A named set of task templates that can be turned into an evaluation."""

    id: str = Field(default_factory=lambda: shortuuid.uuid())
    name: str
    description: str
    tasks: List[V1TaskTemplate]
    owner_id: Optional[str] = None
    tags: List[str] = []
    labels: Dict[str, str] = {}
    created: float = Field(default_factory=time.time)
    public: bool = False


class V1BenchmarkEval(BaseModel):
    """Request body for creating an evaluation from a benchmark."""

    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None


class V1Benchmarks(BaseModel):
    """A list of benchmarks."""

    benchmarks: List[V1Benchmark]


class V1Eval(BaseModel):
    """An evaluation: a benchmark together with the concrete tasks created from it."""

    id: Optional[str] = None
    benchmark: V1Benchmark
    tasks: List[V1Task]
    assigned_to: Optional[str] = None
    assigned_type: Optional[str] = None
    owner_id: Optional[str] = None


class V1Evals(BaseModel):
    """A list of evaluations."""

    evals: List[V1Eval]
313 |
314 |
class V1Flag(BaseModel):
    """A typed flag payload with an optional result."""

    type: str
    id: str
    flag: Dict[str, Any]
    result: Optional[Dict[str, Any]] = None
    created: float  # presumably a Unix timestamp, as elsewhere in this module


class V1BoundingBox(BaseModel):
    """Axis-aligned box given by its x0/x1 and y0/y1 edges (presumably pixels)."""

    x0: int
    x1: int
    y0: int
    y1: int


class V1BoundingBoxFlag(BaseModel):
    """Flag payload pairing an image and a target with a bounding box."""

    img: str
    target: str
    bbox: V1BoundingBox
338 |
339 |
class V1ActionRecordedMessage(BaseModel):
    """Message emitted when an action event is recorded on a task."""

    action: V1ActionEvent
    prevAction: Optional[V1ActionEvent] = None
    task: V1Task
    event_number: int = Field(
        ...,
        description="The index of the action event in order on the task. The first action event will be 0",
    )


class V1TrainingCompletedMessage(BaseModel):
    """Message emitted when training for a task has completed."""

    task_id: str
    agent_type: Optional[str] = None
    skill_id: Optional[str] = None
--------------------------------------------------------------------------------
/taskara/server/router/benchmarks.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Annotated
3 |
4 | from fastapi import APIRouter, Depends, HTTPException
5 |
6 | from taskara.auth.transport import get_user_dependency
7 | from taskara.benchmark import Benchmark, Eval
8 | from taskara.server.models import (
9 | V1Benchmark,
10 | V1BenchmarkEval,
11 | V1Benchmarks,
12 | V1Eval,
13 | V1Evals,
14 | V1UserProfile,
15 | )
16 |
router = APIRouter()
logger = logging.getLogger(__name__)


@router.post("/v1/benchmarks", response_model=V1Benchmark)
async def create_benchmark(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    data: V1Benchmark,
):
    """Store a benchmark owned by the calling user and return it."""
    logger.debug(f"Creating benchmark with model: {data}")
    created = Benchmark.from_v1(data, owner_id=current_user.email)
    created.save()
    return created.to_v1()
30 |
31 |
@router.get("/v1/benchmarks", response_model=V1Benchmarks)
async def get_benchmarks(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())]
):
    """List every benchmark owned by the calling user."""
    owned = Benchmark.find(owner_id=current_user.email)
    converted = [item.to_v1() for item in owned]
    return V1Benchmarks(benchmarks=converted)
38 |
39 |
@router.get("/v1/benchmarks/{id}", response_model=V1Benchmark)
async def get_benchmark(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
):
    """Fetch one benchmark by id, restricted to the calling user's benchmarks.

    Raises:
        HTTPException: 404 when no owned benchmark has this id.
    """
    logger.debug(f"Finding benchmark by id: {id}")
    matches = Benchmark.find(id=id, owner_id=current_user.email)
    if matches:
        return matches[0].to_v1()
    raise HTTPException(status_code=404, detail="Benchmark not found")
50 |
51 |
@router.delete("/v1/benchmarks/{id}")
async def delete_benchmark(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
):
    """Delete one of the calling user's benchmarks.

    Raises:
        HTTPException: 404 when no benchmark with this id is owned by the
            caller, matching ``GET /v1/benchmarks/{id}``, instead of reporting
            a successful deletion of something that does not exist.
    """
    logger.debug(f"Deleting benchmark by id: {id}")
    if not Benchmark.find(id=id, owner_id=current_user.email):
        raise HTTPException(status_code=404, detail="Benchmark not found")
    Benchmark.delete(id=id, owner_id=current_user.email)  # type: ignore
    return {"message": "Benchmark deleted successfully"}
59 |
60 |
@router.post("/v1/benchmarks/{id}/eval", response_model=V1Eval)
async def create_eval_from_benchmark(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
    data: V1BenchmarkEval,
):
    """Create and persist an eval from one of the calling user's benchmarks.

    ``data.assigned_to`` / ``data.assigned_type`` are forwarded to
    ``Benchmark.eval`` and the eval is owned by the caller.

    Raises:
        HTTPException: 404 when no owned benchmark has this id.
    """
    logger.debug(f"Finding benchmark by id: {id}")
    benchmarks = Benchmark.find(id=id, owner_id=current_user.email)
    if not benchmarks:
        raise HTTPException(status_code=404, detail="Benchmark not found")
    benchmark = benchmarks[0]

    # Named `eval_instance` (as in create_eval) to avoid shadowing the builtin `eval`.
    eval_instance = benchmark.eval(
        data.assigned_to, data.assigned_type, owner_id=current_user.email
    )
    eval_instance.save()
    return eval_instance.to_v1()
78 |
79 |
@router.post("/v1/evals", response_model=V1Eval)
async def create_eval(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    data: V1Eval,
):
    """Store an eval submitted in V1 form, owned by the calling user, and echo it back."""
    logger.debug(f"Creating eval with model: {data}")
    new_eval = Eval.from_v1(data, owner_id=current_user.email)
    new_eval.save()
    return new_eval.to_v1()
89 |
90 |
@router.get("/v1/evals", response_model=V1Evals)
async def get_evals(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())]
):
    """List every eval owned by the calling user."""
    converted = [owned.to_v1() for owned in Eval.find(owner_id=current_user.email)]
    return V1Evals(evals=converted)
97 |
98 |
@router.get("/v1/evals/{id}", response_model=V1Eval)
async def get_eval(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
):
    """Fetch one eval by id, restricted to the calling user's evals.

    Raises:
        HTTPException: 404 when no owned eval has this id.
    """
    logger.debug(f"Finding eval by id: {id}")
    matches = Eval.find(id=id, owner_id=current_user.email)
    if matches:
        return matches[0].to_v1()
    raise HTTPException(status_code=404, detail="Eval not found")
109 |
110 |
@router.delete("/v1/evals/{id}")
async def delete_eval(
    current_user: Annotated[V1UserProfile, Depends(get_user_dependency())],
    id: str,
):
    """Delete one of the calling user's evals.

    Raises:
        HTTPException: 404 when no eval with this id is owned by the caller,
            matching ``GET /v1/evals/{id}``, instead of reporting a successful
            deletion of something that does not exist.
    """
    logger.debug(f"Deleting eval by id: {id}")
    if not Eval.find(id=id, owner_id=current_user.email):
        raise HTTPException(status_code=404, detail="Eval not found")
    Eval.delete(id=id, owner_id=current_user.email)  # type: ignore
    return {"message": "Eval deleted successfully"}
118 |
--------------------------------------------------------------------------------
/taskara/util.py:
--------------------------------------------------------------------------------
1 | import random
2 | import socket
3 | import string
4 | import subprocess
5 | from typing import Optional
6 | from openmeter import Client
7 | from azure.core.exceptions import ResourceNotFoundError
8 | import os
9 |
# OpenMeter metering is enabled only when OPENMETER_SECRET is set (non-empty).
openmeter_secret = os.getenv("OPENMETER_SECRET", False)
openmeter_agent_task_feature = os.getenv("OPENMETER_AGENT_TASK_FEATURE")

# Always bind the name so importing or inspecting `openmeter_client` works even
# when metering is disabled; it stays None unless a secret is configured.
openmeter_client: Optional[Client] = None

# TODO really figure out if this initiates a connection, I think not but should make sure somehow
if openmeter_secret:
    openmeter_client = Client(
        endpoint="https://openmeter.cloud",
        headers={
            "Accept": "application/json",
            "Authorization": f"Bearer {openmeter_secret}",
        },
    )
21 |
22 |
def generate_random_string(length: int = 8):
    """Return ``length`` characters drawn (with replacement) from ASCII letters and digits.

    Uses the module-level :mod:`random` generator, so it is reproducible under
    ``random.seed`` and is not suitable for secrets or tokens.
    """
    alphabet = "".join((string.ascii_letters, string.digits))
    picked = random.choices(alphabet, k=length)
    return "".join(picked)
27 |
28 |
def _parse_docker_host(inspect_output: str) -> str:
    """Extract the daemon host from ``docker context inspect`` JSON output.

    Prefers the ``docker`` endpoint and falls back to any endpoint with a
    non-empty ``Host``. Returns "" when the output is not JSON or has no host.
    """
    import json

    try:
        contexts = json.loads(inspect_output)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        return ""
    if not isinstance(contexts, list):
        contexts = [contexts]
    for context in contexts:
        endpoints = context.get("Endpoints") if isinstance(context, dict) else None
        if not isinstance(endpoints, dict):
            continue
        # Check the docker endpoint first, then every other endpoint in order.
        for endpoint in [endpoints.get("docker"), *endpoints.values()]:
            if isinstance(endpoint, dict) and endpoint.get("Host"):
                return str(endpoint["Host"])
    return ""


def get_docker_host() -> str:
    """Return the Docker daemon host of the current Docker context.

    Runs ``docker context show`` then ``docker context inspect <context>``
    without a shell, and reads the host from the inspect JSON.

    Returns:
        The host string (e.g. ``unix:///var/run/docker.sock``), or "" when the
        docker CLI is unavailable, a command fails, or no host is present.
    """
    try:
        # Get the current Docker context
        current_context = (
            subprocess.check_output(["docker", "context", "show"]).decode().strip()
        )
        # Inspect it; argv form means the context name is never shell-interpreted.
        context_info = subprocess.check_output(
            ["docker", "context", "inspect", current_context]
        ).decode()
    except subprocess.CalledProcessError as e:
        print(f"Error: {(e.output or b'').decode()}")
        return ""
    except OSError as e:
        # e.g. the docker binary is not installed or not on PATH
        print(f"Error: {e}")
        return ""
    return _parse_docker_host(context_info)
47 |
48 |
def check_port_in_use(port: int) -> bool:
    """Report whether something on this machine accepts TCP connections on ``port``.

    Attempts an IPv4 TCP connect to ``localhost``; a successful connect means
    the port is taken.

    Args:
        port (int): The port number to probe.

    Returns:
        bool: True if the connect succeeded (port in use), False otherwise.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        outcome = probe.connect_ex(("localhost", port))
    finally:
        probe.close()
    return outcome == 0
61 |
62 |
def find_open_port(start_port: int = 1024, end_port: int = 65535) -> Optional[int]:
    """Return the lowest port in ``[start_port, end_port]`` that can be bound.

    Each candidate is tested by binding an IPv4 TCP socket on all interfaces
    and closing it immediately, so the port may be taken again before the
    caller uses it. Returns None when no port in the range can be bound.
    """
    candidate = start_port
    while candidate <= end_port:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.bind(("", candidate))
        except socket.error:
            # Busy (or otherwise unbindable): move on to the next port.
            candidate += 1
            continue
        finally:
            sock.close()
        return candidate
    return None
73 |
def check_openmeter_agent_tasks(owner_id) -> Optional[bool]:
    """Check whether ``owner_id`` is entitled to assign tasks to agents.

    Queries OpenMeter for the entitlement keyed by the
    ``OPENMETER_AGENT_TASK_FEATURE`` feature.

    Args:
        owner_id: OpenMeter subject id or key to check.

    Returns:
        True if the subject has access. False if the feature is not found for
        the subject, or the response is empty or lacks a truthy ``hasAccess``.
        None when ``OPENMETER_SECRET`` is unset (metering disabled, nothing checked).

    Raises:
        ValueError: metering is enabled but the feature key or client is missing.
    """
    if not openmeter_secret:
        # NOTE(review): None is falsy; confirm callers treat "metering disabled"
        # as intended (allow vs. deny).
        return None

    if not openmeter_agent_task_feature or not openmeter_client:
        raise ValueError(
            "Cannot check agent task entitlement: no openmeter secret or client "
            "or openmeter_agent_task_feature to get entitlements from"
        )

    try:
        # Check openmeter for if user has access through an entitlement
        entitlement_value = openmeter_client.get_entitlement_value(
            subject_id_or_key=owner_id,
            entitlement_id_or_feature_key=openmeter_agent_task_feature,
        )
    except ResourceNotFoundError as e:
        print(
            f"#slack-alert Feature {openmeter_agent_task_feature} not found for subject {owner_id}: {e}",
            flush=True,
        )
        return False

    # `.get` so a response without "hasAccess" is a denial rather than a KeyError.
    if not entitlement_value or not entitlement_value.get("hasAccess"):
        print(
            f"entitlement access denied in assigning task to agent for feature {openmeter_agent_task_feature}, for subject {owner_id} it is likely that the entitlement is no longer valid or the user/org has reached their cap #slack-alert",
            flush=True,
        )
        return False
    print(f"user: {owner_id} agent task entitlement values are {entitlement_value}", flush=True)
    return True
--------------------------------------------------------------------------------
/tests/benchmarks/airbnb.yaml:
--------------------------------------------------------------------------------
# AirbnbLiteV1: two booking-search tasks on airbnb.com, each run on a desktop
# device with a 50-step budget.
name: AirbnbLiteV1
description: A simple Airbnb benchmark
tags: ["GUI", "desktop", "airbnb", "booking"]
public: true
tasks:
  - description: |
      Find a highly rated one bedroom apartment available in Barcelona Spain
      from August 5th to August 8th
    max_steps: 50
    device_type:
      name: "desktop"
    # Schema for the task's expected result: an object with a required room_id.
    expect_schema:
      properties:
        room_id:
          # NOTE(review): the example URL stops at "/rooms/"; the id presumably
          # follows that path segment — consider completing the example.
          description: |
            The id of the room, can be found in the URL e.g. https://www.airbnb.com/rooms/
          type: string
      required:
        - room_id
      type: object
    parameters:
      site: https://airbnb.com

  - description: |
      Find a highly rated two bedroom apartment available in Boulder CO
      from December 9th to December 12th
    max_steps: 50
    device_type:
      name: "desktop"
    # Same result schema as the first task.
    expect_schema:
      properties:
        room_id:
          description: |
            The id of the room, can be found in the URL e.g. https://www.airbnb.com/rooms/
          type: string
      required:
        - room_id
      type: object
    parameters:
      site: https://airbnb.com
41 |
--------------------------------------------------------------------------------
/tests/test_bench.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from taskara.benchmark import Benchmark, TaskTemplate
4 | from taskara.server.models import V1Benchmark, V1TaskTemplate
5 |
6 |
def test_benchmark_creation():
    """A freshly constructed Benchmark exposes exactly the values it was built with."""
    owner = "owner@example.com"
    template = TaskTemplate(description="Test Task")
    bench = Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[template],
        owner_id=owner,
    )

    assert bench.name == "Test Benchmark"
    assert bench.description == "Test Benchmark Description"
    assert [t.description for t in bench.tasks] == ["Test Task"]
    assert bench.owner_id == owner
21 |
22 |
def test_benchmark_to_v1():
    """Benchmark.to_v1 copies name, description, tasks and owner into the V1 model."""
    owner = "owner@example.com"
    bench = Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[TaskTemplate(description="Test Task")],
        owner_id=owner,
    )

    wire = bench.to_v1()

    assert wire.name == "Test Benchmark"
    assert wire.description == "Test Benchmark Description"
    assert [t.description for t in wire.tasks] == ["Test Task"]
    assert wire.owner_id == owner
39 |
40 |
def test_benchmark_from_v1():
    """Benchmark.from_v1 rebuilds the domain object from its V1 wire model."""
    owner = "owner@example.com"
    wire = V1Benchmark(
        id="benchmark1",
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[
            V1TaskTemplate(
                id="task1",
                description="Test Task",
                created=0.0,
                owner_id=owner,
            )
        ],
        owner_id=owner,
        created=0.0,
    )

    restored = Benchmark.from_v1(wire)

    assert restored.name == "Test Benchmark"
    assert restored.description == "Test Benchmark Description"
    assert [t.description for t in restored.tasks] == ["Test Task"]
    assert restored.owner_id == owner
64 |
--------------------------------------------------------------------------------
/tests/test_eval.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from taskara.benchmark import Benchmark, Eval, TaskTemplate
4 | from taskara.server.models import V1Benchmark, V1Eval, V1Task, V1TaskTemplate
5 |
6 |
def test_eval_creation():
    """An Eval built from a benchmark exposes that benchmark and one task per template."""
    owner = "owner@example.com"
    template = TaskTemplate(description="Test Task", owner_id=owner)
    bench = Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[template],
        owner_id=owner,
    )
    evaluation = Eval(bench)

    assert evaluation.benchmark.name == "Test Benchmark"
    assert [t.description for t in evaluation.tasks] == ["Test Task"]
    assert evaluation.benchmark.owner_id == owner
21 |
22 |
def test_eval_to_v1():
    """Eval.to_v1 carries the benchmark, the tasks and the owner into the V1 model."""
    owner = "owner@example.com"
    template = TaskTemplate(description="Test Task", owner_id=owner)
    bench = Benchmark(
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[template],
        owner_id=owner,
    )
    wire = Eval(bench, owner_id=owner).to_v1()

    assert wire.benchmark.name == "Test Benchmark"
    assert [t.description for t in wire.tasks] == ["Test Task"]
    assert wire.owner_id == owner
39 |
40 |
def test_eval_from_v1():
    """Eval.from_v1 keeps the V1 benchmark and uses the V1 tasks as given."""
    owner = "owner@example.com"
    wire_benchmark = V1Benchmark(
        id="benchmark1",
        name="Test Benchmark",
        description="Test Benchmark Description",
        tasks=[
            V1TaskTemplate(
                id="task1", description="Test Task", created=0.0, owner_id=owner
            )
        ],
        owner_id=owner,
        created=0.0,
    )
    wire_task = V1Task(id="123", description="Search for french ducks", owner_id=owner)
    wire_eval = V1Eval(
        id="eval1",
        benchmark=wire_benchmark,
        tasks=[wire_task],
        owner_id=owner,
    )

    restored = Eval.from_v1(wire_eval)

    assert restored.benchmark.name == "Test Benchmark"
    assert [t.description for t in restored.tasks] == ["Search for french ducks"]
    assert restored.benchmark.owner_id == owner
70 |
--------------------------------------------------------------------------------
/tests/test_runtime.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import urllib.parse
4 |
5 | from mllm import Prompt, RoleMessage, RoleThread
6 | from namesgenerator import get_random_name
7 | from openai import BaseModel
8 | from skillpacks import ActionEvent, AnnotationReviewable, EnvState, V1Action, V1EnvState
9 | from skillpacks.server.models import V1ActionEvents, V1AnnotationReviewable, V1Episode
10 | from toolfuse.models import V1ToolRef
11 |
12 | from taskara import (
13 | Benchmark,
14 | ReviewRequirement,
15 | Task,
16 | TaskStatus,
17 | TaskTemplate,
18 | V1Benchmark,
19 | V1Task,
20 | V1TaskTemplate,
21 | )
22 | from taskara.runtime.process import ProcessConnectConfig, ProcessTrackerRuntime
23 | from taskara.server.models import (
24 | V1Benchmark,
25 | V1BenchmarkEval,
26 | V1Benchmarks,
27 | V1CreateAnnotationReview,
28 | V1CreateReview,
29 | V1DeviceType,
30 | V1Eval,
31 | V1Evals,
32 | V1PendingReviewers,
33 | V1PendingReviews,
34 | V1Prompt,
35 | V1RoleThread,
36 | V1Tasks,
37 | V1TaskTemplate,
38 | )
39 |
40 |
def test_process_tracker_runtime():
    """End-to-end test of a task server started by ``ProcessTrackerRuntime``.

    Starts a local server with auth disabled and drives its HTTP API through
    tasks, reviews, threads, prompts, action events, annotations, benchmarks
    and evals. On failure the server logs are printed; the server is always
    deleted at the end.
    """
    runtime = ProcessTrackerRuntime()
    assert runtime.name() == "process"
    assert runtime.connect_config_type() == ProcessConnectConfig
    assert runtime.connect_config().model_dump() == {}

    runtime.refresh()

    name = get_random_name("-")
    assert name

    print("running task server ", name)
    server = runtime.run(name, auth_enabled=False)
    print("task server ", server.__dict__)

    try:
        # Create a task
        task_data = {
            "description": "Search for french ducks",
            "assigned_to": "tom@myspace.com",
            "labels": {"test": "true"},
            "review_requirements": [
                {
                    "number_required": 2,
                    "users": ["anonymous@agentsea.ai"],
                    "agents": ["agent1", "agent2"],
                },
                {
                    "number_required": 1,
                    "users": ["tom@myspace.com", "anonymous@agentsea.ai"],
                    "agents": ["agent3"],
                },
            ],
        }
        status, text = server.call(path="/v1/tasks", method="POST", data=task_data)
        print("status: ", status)
        print("task created: ", text)
        assert status == 200

        task = V1Task.model_validate(json.loads(text))
        assert task.description == "Search for french ducks"
        assert task.owner_id == "tom@myspace.com"
        task_id = task.id
        time.sleep(1)

        # Fetch the task with query parameters and labels passed as a JSON string
        labels_query = json.dumps({"test": "true"})  # Encode labels as JSON string
        encoded_labels = urllib.parse.quote(labels_query)

        status, text = server.call(
            path=f"/v1/tasks?labels={encoded_labels}", method="GET"
        )
        print("status: ", status)
        print("tasks fetched: ", text)
        assert status == 200

        tasks = V1Tasks.model_validate(json.loads(text))
        assert any(t.id == task_id for t in tasks.tasks)

        # Get a specific task
        status, text = server.call(path=f"/v1/tasks/{task_id}", method="GET")
        print("status: ", status)
        print("task fetched: ", text)
        assert status == 200
        task = V1Task.model_validate(json.loads(text))
        assert task.id == task_id
        assert task.assigned_to == "tom@myspace.com"

        # Check the review requirements
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 5

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 1

        # Approve the task with the default reviewer, then as agent1
        data = V1CreateReview(approved=True, reason="Approved")
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/review", method="PUT", data=data.model_dump()
        )
        assert status == 200

        data = V1CreateReview(approved=True, reason="Approved", reviewer="agent1")
        status, text = server.call(
            path=f"/v1/tasks/{task_id}/review", method="PUT", data=data.model_dump()
        )
        assert status == 200

        status, text = server.call(
            path=f"/v1/tasks/{task_id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 3

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 0

        # Update the task
        update_data = {
            "description": "Search for german ducks",
            "status": "in progress",
            "set_labels": {"test_set": "true"},
        }
        status, text = server.call(
            path=f"/v1/tasks/{task_id}", method="PUT", data=update_data
        )
        print("status: ", status)
        print("task updated: ", text)
        assert status == 200
        task = V1Task.model_validate(json.loads(text))
        assert task.description == "Search for german ducks"
        assert task.status == "in progress"

        # Post a message to the task
        message_data = {
            "role": "user",
            "msg": "This is a test message.",
            "images": [],
            "thread": None,
        }
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/msg", method="POST", data=message_data
        )
        print("status: ", status)
        assert status == 200

        # Create a thread
        thread_data = {"name": "test-thread", "public": True, "metadata": {}}
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/threads", method="POST", data=thread_data
        )
        print("create thread status: ", status)
        assert status == 200

        # Remove a thread
        remove_thread_data = {"id": "test-thread"}
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/threads",
            method="DELETE",
            data=remove_thread_data,
        )
        print("remove thread status: ", status)
        assert status == 200

        # Store a prompt in the task
        prompt = Prompt(
            thread=RoleThread(
                name="test-thread",
                public=True,
            ),
            response=RoleMessage(
                role="assistant",
                text="This is a test response",
                images=[],
            ),
        )
        print("sending prompt id: ", prompt.id)
        status, resp = server.call(
            path=f"/v1/tasks/{task_id}/prompts",
            method="POST",
            data=prompt.to_v1().model_dump(),
        )
        print("store prompt status: ", status)
        assert status == 200

        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}",
            method="GET",
        )
        print("get task status: ", status)
        assert status == 200
        task = V1Task.model_validate(json.loads(resp_text))
        print("task: ", task)

        print("store prompt response: ", resp)

        # Approve a prompt
        prompt_id = json.loads(resp)["id"]

        print("prompt id: ", prompt_id)
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/prompts/{prompt_id}/approve", method="POST"
        )
        print("approve prompt status: ", status)
        assert status == 200

        # Write a review
        review = V1CreateReview(approved=True, reason="test")
        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/review",
            method="PUT",
            data=review.model_dump(),
        )
        print("store review status: ", status)
        assert status == 200

        # Store an action event
        action_event = ActionEvent(
            state=EnvState(images=["https://test.img"]),
            action=V1Action(name="test", parameters={}),
            tool=V1ToolRef(module="test", type="test"),
            prompt=prompt,
        )

        status, _ = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="POST",
            data=action_event.to_v1().model_dump(),
        )
        print("store action status: ", status)
        assert status == 200

        # Fetch the stored action events and keep the id of the first one
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="GET",
        )
        print("get actions status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate_json(resp_text)
        print("events: ", events)
        assert len(events.events) > 0
        action_id = events.events[0].id

        # Create an annotation
        annotation = V1AnnotationReviewable(
            key="test",
            value="This is a test annotation",
            annotator="tester@example.com",
        )
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions/{action_id}/annotations",
            method="POST",
            data=annotation.model_dump(),
        )
        assert status == 200

        # Retrieve the `annotation_id`
        # Assuming the response contains the `annotation_id` or we can fetch it
        resp_json = json.loads(resp_text)
        annotation_id = resp_json.get("id")  # Adjust based on actual response
        # Fail here, with the response body, rather than later at "annotation is not None"
        assert annotation_id is not None, f"no annotation id in response: {resp_text}"

        # Now, review the annotation
        review = V1CreateAnnotationReview(
            approved=True,
            reviewer="reviewer@example.com",
            reviewer_type="human",
            reason="Annotation looks good",
        )
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions/{action_id}/annotations/{annotation_id}/review",
            method="POST",
            data=review.model_dump(),
        )
        assert status == 200

        # Verify that the annotation review was recorded properly
        # Fetch the action event and check the annotation's review status
        status, resp_text = server.call(
            path=f"/v1/tasks/{task_id}/actions",
            method="GET",
        )
        assert status == 200
        events = V1ActionEvents.model_validate_json(resp_text)
        action_event = next((e for e in events.events if e.id == action_id), None)
        assert action_event is not None

        # Assuming `action_event` has a `reviewables` field which contains annotations
        annotation = next(
            (a for a in action_event.reviewables if a.id == annotation_id), None
        )
        assert annotation is not None
        assert annotation.reviews[0].approved
        assert annotation.reviews[0].reviewer == "reviewer@example.com"
        assert annotation.reviews[0].reason == "Annotation looks good"

        print("getting remote task")
        found_task = Task.get(id=task_id, remote=f"http://localhost:{server.port}")
        assert found_task.id == task_id

        # Delete the task
        status, _ = server.call(path=f"/v1/tasks/{task_id}", method="DELETE")
        print("delete task status: ", status)
        assert status == 200

        print("creating a new task")

        class Expected(BaseModel):
            foo: str
            bar: int

        new_task = Task(
            description="a good test",
            remote=f"http://localhost:{server.port}",
            expect=Expected,
            review_requirements=[
                ReviewRequirement(number_required=1, users=["tom@myspace.com"])
            ],
        )
        print("created a new task: ", new_task.id)

        action_event = ActionEvent(
            state=EnvState(images=["https://test.img"]),
            action=V1Action(name="test", parameters={}),
            tool=V1ToolRef(module="test", type="test"),
            prompt=prompt,
        )
        new_task.record_action_event(action_event)

        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get actions status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events: ", events)
        assert len(events.events) > 0

        # Check the review requirements
        status, text = server.call(
            path=f"/v1/tasks/{new_task.id}/pending_reviewers", method="GET"
        )
        assert status == 200
        pending_reviewers = V1PendingReviewers.model_validate_json(text)
        assert pending_reviewers.users is not None
        assert len(pending_reviewers.users) == 1

        status, text = server.call(path="/v1/pending_reviews", method="GET")
        assert status == 200
        pending_reviews = V1PendingReviews.model_validate_json(text)
        assert len(pending_reviews.tasks) == 1

        # Check that we can update the task
        new_task.refresh()
        new_task.status = TaskStatus.CANCELED
        new_task.save()

        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get actions status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events: ", events)
        assert len(events.events) > 0

        # Delete all actions associated with the task
        status, _ = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="DELETE",
        )
        print("delete all actions status: ", status)
        assert status == 200

        # Verify that no actions remain
        status, resp_text = server.call(
            path=f"/v1/tasks/{new_task.id}/actions",
            method="GET",
        )
        print("get actions after deletion status: ", status)
        assert status == 200
        events = V1ActionEvents.model_validate(json.loads(resp_text))
        print("events after deletion: ", events)
        assert len(events.events) == 0

        # Benchmarks
        tpl0 = TaskTemplate(
            description="A good test 0",
            device_type=V1DeviceType(name="desktop"),
            owner_id="tom@myspace.com",
        )
        tpl1 = TaskTemplate(
            description="A good test 1",
            device_type=V1DeviceType(name="mobile"),
            owner_id="tom@myspace.com",
        )
        bench = Benchmark(
            name="test-bench",
            description="A good benchmark",
            tasks=[tpl0, tpl1],
            owner_id="tom@myspace.com",
        )
        status, _ = server.call(
            path="/v1/benchmarks", method="POST", data=bench.to_v1().model_dump()
        )
        assert status == 200

        status, text = server.call(
            path="/v1/benchmarks",
            method="GET",
        )
        assert status == 200
        benchmarks = V1Benchmarks.model_validate_json(text)
        assert benchmarks.benchmarks[0].description == "A good benchmark"

        status, text = server.call(
            path=f"/v1/benchmarks/{benchmarks.benchmarks[0].id}/eval",
            method="POST",
            data=V1BenchmarkEval(
                assigned_to="test_agent", assigned_type="pizza"
            ).model_dump(),
        )
        assert status == 200

        v1eval = V1Eval.model_validate_json(text)
        assert v1eval.owner_id == "tom@myspace.com"
        assert v1eval.assigned_to == "test_agent"
        assert v1eval.assigned_type == "pizza"

        status, text = server.call(
            path="/v1/evals",
            method="GET",
        )
        assert status == 200
        evals = V1Evals.model_validate_json(text)
        assert evals.evals[0].owner_id == "tom@myspace.com"

    except BaseException:
        # Dump the server logs so failures are debuggable, then re-raise.
        print(server.logs())
        raise

    finally:
        # Ensure the server is deleted; report (rather than silently hide) cleanup failures.
        try:
            server.delete()
        except Exception as e:
            print("failed to delete task server: ", e)
484 |
--------------------------------------------------------------------------------
/tests/test_task.py:
--------------------------------------------------------------------------------
1 | from devicebay import V1Device
2 | from pydantic import BaseModel
3 | from threadmem import RoleThread
4 |
5 | from taskara import Task
6 |
7 |
class TestConnectConfig(BaseModel):
    """Sample device connect config used to check config round-tripping."""

    # The `Test` prefix makes pytest try to collect this model as a test class,
    # and pytest warns because BaseModel defines `__init__`. Opt out explicitly.
    # Pydantic ignores dunder class attributes, so this adds no model field.
    __test__ = False

    a: str
    b: int
11 |
12 |
# Test the thread creation functionality within the Task class
def test_create_thread() -> None:
    """Creating a thread adds it to the task, and the task is found again via Task.find.

    Also checks that the device config, given as a pydantic model, is returned
    by Task.find in a subscriptable (dict-like) form with the same values.
    """

    class Expected(BaseModel):
        foo: str
        bar: int

    task = Task(
        description="Test Task",
        owner_id="owner123",
        id="task123",
        expect=Expected,
        device=V1Device(type="desktop", config=TestConnectConfig(a="a", b=1)),
    )
    # A new task starts with exactly one thread.
    assert len(task.threads) == 1

    # Directly call the method that doesn't involve remote calls
    task.create_thread(name="New Local Thread", public=True)

    # Verify a new thread is added
    assert len(task.threads) == 2
    # Verify the properties of the newly created thread
    new_thread = task.threads[-1]
    assert new_thread.name == "New Local Thread"
    assert new_thread.public is True

    task.refresh()

    # Look the task up again by id and check the device config survived.
    found = Task.find(id=task.id)
    assert len(found) == 1
    print("\nfound: ", found[0].__dict__)
    assert found[0].device.config["a"] == "a"  # type: ignore
    assert found[0].device.config["b"] == 1  # type: ignore
45 |
46 |
# Test posting a message to a thread within the Task class
def test_post_message() -> None:
    """Messages posted to a named thread, or to the default thread, are readable back.

    Reads the named thread both through ``task.messages`` and through
    ``RoleThread.find``. NOTE(review): the ``RoleThread.find(name="Prompt")``
    length check assumes no other "Prompt" thread exists in the backing store.
    """

    class Expected(BaseModel):
        foo: str
        bar: int

    task = Task(
        description="Test Task 2", owner_id="owner1234", id="task1234", expect=Expected
    )

    # Directly call the method that doesn't involve remote calls
    task.create_thread(name="Prompt", public=True)
    messages = task.messages(thread="Prompt")
    print("initial messages: ", messages)
    assert len(messages) == 0

    # Act: Post a message to the thread
    task.post_message(role="user", msg="Test Message", thread="Prompt")
    messages = task.messages(thread="Prompt")
    assert len(messages) == 1
    message = messages[0]
    assert message.text == "Test Message"
    assert message.role == "user"

    # The same message must be visible through the thread itself.
    threads = RoleThread.find(name="Prompt")
    assert len(threads) == 1
    thread = threads[0]

    messages = thread.messages()
    assert len(messages) == 1
    message = messages[0]
    assert message.text == "Test Message"
    assert message.role == "user"

    # Without `thread=`, posting and reading use the same default thread,
    # which so far holds only this message.
    task.post_message(role="moderator", msg="Test Message 5")
    messages = task.messages()
    assert len(messages) == 1
    message = messages[0]

    assert message.text == "Test Message 5"
    assert message.role == "moderator"
88 |
def test_find_many_lite() -> None:
    """Task.find_many_lite returns exactly the tasks whose ids are requested."""
    # Create three tasks
    task1 = Task(description="Task 1", owner_id="owner1", id="task1")
    task2 = Task(description="Task 2", owner_id="owner1", id="task2")
    task3 = Task(description="Task 3", owner_id="owner2", id="task3")

    # Persist the tasks explicitly so the find_many_lite lookups below can see them.
    task1.save()
    task2.save()
    task3.save()

    # Test retrieving tasks by IDs
    found_tasks = Task.find_many_lite(task_ids=["task1", "task2"])
    print(f"found tasks {found_tasks}", flush=True)
    # Verify that only task1 and task2 are returned
    assert len(found_tasks) == 2
    found_ids = [t.id for t in found_tasks]
    assert "task1" in found_ids
    assert "task2" in found_ids
    assert "task3" not in found_ids

    # Test retrieving a subset
    found_task = Task.find_many_lite(task_ids=["task1"])
    assert len(found_task) == 1
    assert found_task[0].id == "task1"

    # Unknown ids yield an empty result rather than an error
    found_none = Task.find_many_lite(task_ids=["nonexistent"])
    assert len(found_none) == 0
119 |
def test_find_many_lite_with_reviews_and_reqs():
    """
    Verify that ``Task.find_many_lite`` returns tasks together with:

    - their parameters,
    - their associated Reviews, and
    - their associated ReviewRequirements,

    in a single batched lookup.
    """
    from skillpacks.review import Resource, Review

    from taskara.review import ReviewRequirement

    # -----------------------------
    # 1) Create sample tasks
    # -----------------------------
    task4 = Task(description="Task 4", owner_id="owner4", id="task4")
    task5 = Task(description="Task 5", owner_id="owner4", id="task5")
    task6 = Task(description="Task 6", owner_id="owner5", id="task6")

    # Give each task some parameters
    task4.parameters = {"foo": "bar4", "alpha": 123}
    task5.parameters = {"foo": "bar5", "beta": 999}
    task6.parameters = {"foo": "bar6", "gamma": 42}

    # -----------------------------
    # 2) Create Reviews for tasks
    # -----------------------------
    review_a = Review(
        reviewer="alice",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task4",  # Associate with task4
        reason="Review A - All good",
    )
    review_b = Review(
        reviewer="bob",
        approved=False,
        resource_type=Resource.TASK.value,
        resource_id="task4",  # Also task4
        reason="Review B - Some issues",
    )
    review_c = Review(
        reviewer="charlie",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task5",  # For task5
        reason="Review C - LGTM",
    )
    review_d = Review(
        reviewer="david",
        approved=True,
        resource_type=Resource.TASK.value,
        resource_id="task6",  # For task6
        reason="Review D - Quick check",
    )

    # Save Reviews to the DB before linking them to tasks.
    review_a.save()
    review_b.save()
    review_c.save()
    review_d.save()

    # Link them to the tasks
    task4.reviews = [review_a, review_b]
    task5.reviews = [review_c]
    task6.reviews = [review_d]

    # -----------------------------
    # 3) Create ReviewRequirements
    # -----------------------------
    req_a = ReviewRequirement(
        task_id="task4",
        number_required=1,
        users=["alice", "bob"],
    )
    req_b = ReviewRequirement(
        task_id="task5",
        number_required=2,
        users=["charlie", "someone_else"],
    )
    req_c = ReviewRequirement(
        task_id="task6",
        number_required=1,
        agents=["agent_1"],
    )

    req_a.save()
    req_b.save()
    req_c.save()

    # Link them
    task4.review_requirements = [req_a]
    task5.review_requirements = [req_b]
    task6.review_requirements = [req_c]

    # -----------------------------
    # 4) Save all tasks to DB
    # -----------------------------
    task4.save()
    task5.save()
    task6.save()

    # -----------------------------
    # 5) Exercise find_many_lite
    # -----------------------------
    found = Task.find_many_lite(task_ids=["task4", "task5", "task6"])
    assert len(found) == 3, "Should return all three tasks"

    # Turn the list into a dict for easy lookup by ID
    found_dict = {t.id: t for t in found}

    # NOTE: each `parameters` block first asserts the attribute is populated.
    # That gives a clear failure message when it is missing, and it narrows
    # the (presumably Optional) type for checkers before indexing. The old
    # `assert x == y if params else None` form parsed as
    # `assert (x == y if params else None)` and failed on a bare `assert None`.

    # -----------------------------
    # 6) Verify Task 4 data
    # -----------------------------
    t4 = found_dict["task4"]
    assert t4.parameters, "Task 4 should have parameters"
    assert t4.parameters["foo"] == "bar4"
    assert t4.parameters["alpha"] == 123
    assert len(t4.reviews) == 2, "Task 4 should have 2 reviews"
    reviewers_t4 = {r.reviewer for r in t4.reviews}
    assert reviewers_t4 == {"alice", "bob"}
    assert len(t4.review_requirements) == 1, "Task 4 should have 1 review requirement"
    assert t4.review_requirements[0].users == ["alice", "bob"]

    # -----------------------------
    # 7) Verify Task 5 data
    # -----------------------------
    t5 = found_dict["task5"]
    assert t5.parameters, "Task 5 should have parameters"
    assert t5.parameters["foo"] == "bar5"
    assert t5.parameters["beta"] == 999
    assert len(t5.reviews) == 1, "Task 5 should have 1 review"
    assert t5.reviews[0].reviewer == "charlie"
    assert len(t5.review_requirements) == 1, "Task 5 should have 1 review requirement"
    req5 = t5.review_requirements[0]
    assert req5.number_required == 2
    assert req5.users == ["charlie", "someone_else"]

    # -----------------------------
    # 8) Verify Task 6 data
    # -----------------------------
    t6 = found_dict["task6"]
    assert t6.parameters, "Task 6 should have parameters"
    assert t6.parameters["foo"] == "bar6"
    assert t6.parameters["gamma"] == 42
    assert len(t6.reviews) == 1, "Task 6 should have 1 review"
    assert t6.reviews[0].reviewer == "david"
    assert len(t6.review_requirements) == 1, "Task 6 should have 1 review requirement"
    req6 = t6.review_requirements[0]
    assert req6.agents == ["agent_1"]
    assert req6.number_required == 1

    print("test_find_many_lite_with_reviews_and_reqs passed!")
--------------------------------------------------------------------------------
/tests/test_tpl.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | from devicebay import V1Device, V1DeviceType
3 | from pydantic import BaseModel
4 |
5 | from taskara.benchmark import TaskTemplate
6 | from taskara.server.models import V1TaskTemplate
7 | from taskara.task import Task
8 |
9 |
def test_task_template_creation():
    """Constructing a TaskTemplate stores every supplied field unchanged."""
    tpl = TaskTemplate(
        description="Test Task",
        max_steps=10,
        owner_id="owner@example.com",
        device=V1Device(type="device1"),
        device_type=V1DeviceType(name="device_type1"),
        parameters={"param1": "value1"},
        labels={"label1": "value1"},
        tags=["tag1", "tag2"],
    )

    # Scalar fields.
    assert (tpl.description, tpl.max_steps, tpl.owner_id) == (
        "Test Task",
        10,
        "owner@example.com",
    )

    # Nested device models; the `type: ignore`s silence the type checker on
    # these attribute accesses.
    assert tpl.device.type == "device1"  # type: ignore
    assert tpl.device_type.name == "device_type1"  # type: ignore

    # Collection-valued fields.
    assert tpl.parameters == {"param1": "value1"}
    assert tpl.labels == {"label1": "value1"}
    assert tpl.tags == ["tag1", "tag2"]
30 |
31 |
def test_task_template_to_task():
    """TaskTemplate.to_task carries description, max_steps and owner_id over."""
    description, steps, owner = "Test Task", 10, "owner@example.com"

    converted = TaskTemplate(
        description=description,
        max_steps=steps,
        owner_id=owner,
    ).to_task()

    assert converted.description == description
    assert converted.max_steps == steps
    assert converted.owner_id == owner
44 |
45 |
def test_task_template_from_task():
    """TaskTemplate.from_task copies description, max_steps and owner_id."""
    description, steps, owner = "Test Task", 10, "owner@example.com"
    source_task = Task(description=description, max_steps=steps, owner_id=owner)

    tpl = TaskTemplate.from_task(source_task)

    assert tpl.description == description
    assert tpl.max_steps == steps
    assert tpl.owner_id == owner
58 |
59 |
def test_task_template_to_v1():
    """TaskTemplate.to_v1 exposes the same core fields on the V1 model."""
    description, steps, owner = "Test Task", 10, "owner@example.com"

    v1 = TaskTemplate(
        description=description,
        max_steps=steps,
        owner_id=owner,
    ).to_v1()

    assert v1.description == description
    assert v1.max_steps == steps
    assert v1.owner_id == owner
72 |
73 |
def test_task_template_from_v1():
    """TaskTemplate.from_v1 restores the core fields from a V1TaskTemplate."""
    description, steps, owner = "Test Task", 10, "owner@example.com"
    v1 = V1TaskTemplate(
        id="task1",
        description=description,
        max_steps=steps,
        owner_id=owner,
        created=0.0,
    )

    restored = TaskTemplate.from_v1(v1)

    assert restored.description == description
    assert restored.max_steps == steps
    assert restored.owner_id == owner
88 |
--------------------------------------------------------------------------------