├── aana
├── __init__.py
├── core
│ ├── __init__.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── file.py
│ │ ├── captions.py
│ │ ├── time.py
│ │ ├── exception.py
│ │ ├── custom_config.py
│ │ ├── api_service.py
│ │ ├── types.py
│ │ ├── speaker.py
│ │ ├── task.py
│ │ ├── vad.py
│ │ └── api.py
│ ├── chat
│ │ ├── templates
│ │ │ ├── __init__.py
│ │ │ └── llama2.jinja
│ │ └── chat_template.py
│ └── libraries
│ │ ├── audio.py
│ │ └── image.py
├── tests
│ ├── __init__.py
│ ├── units
│ │ ├── __init__.py
│ │ ├── test_question.py
│ │ ├── test_typed_dict.py
│ │ ├── test_media_id.py
│ │ ├── test_rate_limiter.py
│ │ ├── test_deployment_retry.py
│ │ ├── test_app_deploy.py
│ │ ├── test_merge_options.py
│ │ ├── test_whisper_params.py
│ │ ├── test_sampling_params.py
│ │ ├── test_deployment_restart.py
│ │ ├── test_settings.py
│ │ ├── test_event_manager.py
│ │ ├── test_speaker.py
│ │ ├── test_app_upload.py
│ │ └── test_app.py
│ ├── projects
│ │ ├── __init__.py
│ │ └── lowercase
│ │ │ ├── __init__.py
│ │ │ └── app.py
│ ├── integration
│ │ └── __init__.py
│ ├── const.py
│ ├── files
│ │ ├── expected
│ │ │ ├── vad
│ │ │ │ └── squirrel.json
│ │ │ ├── hf_blip2
│ │ │ │ └── blip2_deployment_Starry_Night.jpeg.json
│ │ │ ├── idefics
│ │ │ │ └── idefics_2_8b_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ ├── image_text_generation
│ │ │ │ ├── phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ ├── pixtral_12b_2409_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ ├── qwen2_vl_7b_instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ ├── qwen2_vl_3b_gemlite_onthefly_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ │ └── qwen2_vl_3b_gemlite_onthefly_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json
│ │ │ ├── hf_pipeline
│ │ │ │ └── hf_pipeline_blip2_deployment_Starry_Night.jpeg.json
│ │ │ ├── text_generation
│ │ │ │ ├── phi-3.5-vision-instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── phi3_mini_4k_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── phi3_mini_4k_instruct_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── qwen2_vl_7b_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── pixtral_12b_2409_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── qwen1.5_1.8b_chat_vllm_deployment_with_lora_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── gemma_3_1b_it_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ └── qwen2_vl_3b_gemlite_onthefly_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ ├── whisper
│ │ │ │ ├── whisper_medium_squirrel.wav.json
│ │ │ │ ├── whisper_tiny_squirrel.wav.json
│ │ │ │ ├── whisper_turbo_squirrel.wav.json
│ │ │ │ ├── whisper_medium_squirrel.wav_batched.json
│ │ │ │ ├── whisper_tiny_squirrel.wav_batched.json
│ │ │ │ └── whisper_turbo_squirrel.wav_batched.json
│ │ │ ├── hqq_generation
│ │ │ │ ├── meta-llama
│ │ │ │ │ └── Meta-Llama-3.1-8B-Instruct_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ │ └── mobiuslabsgmbh
│ │ │ │ │ └── Llama-3.1-8b-instruct_4bitgs64_hqq_calib_2f8c1d10dab7f75e25cc2d4f29c54469.json
│ │ │ └── sd
│ │ │ │ └── sd_sample.json
│ │ ├── audios
│ │ │ ├── squirrel.wav
│ │ │ ├── physicsworks.wav
│ │ │ └── ATTRIBUTION.md
│ │ ├── videos
│ │ │ ├── squirrel.mp4
│ │ │ ├── physicsworks.webm
│ │ │ ├── squirrel_no_audio.mp4
│ │ │ ├── physicsworks_audio.webm
│ │ │ └── ATTRIBUTION.md
│ │ └── images
│ │ │ ├── Starry_Night.jpeg
│ │ │ └── ATTRIBUTION.md
│ ├── db
│ │ └── datastore
│ │ │ ├── test_config.py
│ │ │ ├── test_video_repo.py
│ │ │ ├── test_transcript_repo.py
│ │ │ └── test_caption_repo.py
│ └── deployments
│ │ ├── test_vad_deployment.py
│ │ ├── test_pyannote_speaker_diarization_deployment.py
│ │ ├── test_hf_blip2_deployment.py
│ │ └── test_hf_pipeline_deployment.py
├── utils
│ ├── __init__.py
│ ├── openapi_code_templates
│ │ ├── __init__.py
│ │ ├── curl_form.j2
│ │ ├── python_form.j2
│ │ └── python_form_streaming.j2
│ ├── gpu.py
│ ├── streamer.py
│ ├── file.py
│ ├── typing.py
│ ├── asyncio.py
│ ├── lazy_import.py
│ └── json.py
├── alembic
│ ├── __init__.py
│ ├── README
│ ├── script.py.mako
│ └── versions
│ │ ├── b9860676dd49_set_server_default_for_task_completed_.py
│ │ ├── d40eba8ebc4c_added_user_id_to_tasks.py
│ │ └── acb40dabc2c0_added_webhooks.py
├── processors
│ ├── __init__.py
│ ├── remote.py
│ └── video.py
├── routers
│ └── __init__.py
├── storage
│ ├── __init__.py
│ ├── repository
│ │ ├── __init__.py
│ │ ├── video.py
│ │ ├── transcript.py
│ │ ├── caption.py
│ │ └── media.py
│ ├── session.py
│ ├── models
│ │ ├── media.py
│ │ ├── video.py
│ │ ├── __init__.py
│ │ ├── webhook.py
│ │ ├── caption.py
│ │ ├── api_key.py
│ │ ├── task.py
│ │ ├── base.py
│ │ └── transcript.py
│ └── types.py
├── integrations
│ ├── __init__.py
│ ├── external
│ │ └── __init__.py
│ └── haystack
│ │ ├── __init__.py
│ │ └── remote_haystack_component.py
├── api
│ ├── __init__.py
│ ├── event_handlers
│ │ ├── event_handler.py
│ │ ├── rate_limit_handler.py
│ │ └── event_manager.py
│ ├── responses.py
│ ├── app.py
│ └── security.py
├── exceptions
│ ├── __init__.py
│ ├── core.py
│ ├── db.py
│ └── api_service.py
├── deployments
│ ├── __init__.py
│ └── haystack_component_deployment.py
└── configs
│ ├── __init__.py
│ └── db.py
├── docs
├── reference
│ ├── sdk.md
│ ├── settings.md
│ ├── endpoint.md
│ ├── storage
│ │ ├── models.md
│ │ └── repositories.md
│ ├── models
│ │ ├── asr.md
│ │ ├── chat.md
│ │ ├── time.md
│ │ ├── vad.md
│ │ ├── speaker.md
│ │ ├── types.md
│ │ ├── video.md
│ │ ├── captions.md
│ │ ├── custom_config.md
│ │ ├── sampling.md
│ │ ├── whisper.md
│ │ ├── image_chat.md
│ │ ├── multimodal_chat.md
│ │ └── media.md
│ ├── utils.md
│ ├── exceptions.md
│ ├── processors.md
│ ├── integrations.md
│ └── index.md
├── images
│ ├── favicon.ico
│ ├── white_logo.png
│ ├── AanaSDK_logo_dark_theme.png
│ └── AanaSDK_logo_light_theme.png
├── pages
│ ├── model_hub
│ │ ├── vad.md
│ │ ├── hf_pipeline.md
│ │ ├── index.md
│ │ └── speaker_recognition.md
│ ├── dev_environment.md
│ ├── code_standards.md
│ ├── settings.md
│ ├── serve_config_files.md
│ ├── docker.md
│ ├── code_overview.md
│ └── openai_api.md
└── stylesheets
│ └── extra.css
├── .env
├── .vscode
├── extensions.json
└── settings.json
├── mypy.ini
├── test_volume.dstack.yml
├── .devcontainer
├── Dockerfile
└── devcontainer.json
├── .github
├── ISSUE_TEMPLATE
│ ├── enhancement.md
│ ├── question.md
│ ├── testing.md
│ ├── feature_request.md
│ ├── documentation.md
│ └── bug_report.md
└── workflows
│ ├── publish_docs.yml
│ ├── run_tests_with_gpu.yml
│ ├── tests.yml
│ └── publish.yml
├── pg_ctl.sh
└── tests.dstack.yml
/aana/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/core/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/alembic/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/core/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/processors/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/routers/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/storage/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/tests/units/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/integrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/tests/projects/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/core/chat/templates/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/integrations/external/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/storage/repository/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/tests/projects/lowercase/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/utils/openapi_code_templates/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/aana/alembic/README:
--------------------------------------------------------------------------------
1 | Generic single-database configuration.
--------------------------------------------------------------------------------
/docs/reference/sdk.md:
--------------------------------------------------------------------------------
1 | # SDK
2 |
3 | ::: aana.sdk.AanaSDK
--------------------------------------------------------------------------------
/aana/tests/const.py:
--------------------------------------------------------------------------------
1 | ALLOWED_LEVENSTEIN_ERROR_RATE = 0.15
2 |
--------------------------------------------------------------------------------
/docs/reference/settings.md:
--------------------------------------------------------------------------------
1 | # Settings
2 |
3 | ::: aana.configs
--------------------------------------------------------------------------------
/docs/reference/endpoint.md:
--------------------------------------------------------------------------------
1 | # Endpoint
2 |
3 | ::: aana.api.Endpoint
--------------------------------------------------------------------------------
/aana/tests/files/expected/vad/squirrel.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": []
3 | }
--------------------------------------------------------------------------------
/docs/reference/storage/models.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | ::: aana.storage.models
--------------------------------------------------------------------------------
/docs/reference/models/asr.md:
--------------------------------------------------------------------------------
1 | # ASR Models
2 |
3 | ::: aana.core.models.asr
--------------------------------------------------------------------------------
/.env:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=""
2 | HF_HUB_ENABLE_HF_TRANSFER=1
3 | HF_TOKEN=""
4 |
--------------------------------------------------------------------------------
/docs/reference/models/chat.md:
--------------------------------------------------------------------------------
1 | # Chat Models
2 |
3 | ::: aana.core.models.chat
4 |
--------------------------------------------------------------------------------
/docs/reference/models/time.md:
--------------------------------------------------------------------------------
1 | # Time Models
2 |
3 | ::: aana.core.models.time
4 |
--------------------------------------------------------------------------------
/docs/reference/models/vad.md:
--------------------------------------------------------------------------------
1 | # VAD Models
2 |
3 | ::: aana.core.models.vad
4 |
--------------------------------------------------------------------------------
/docs/reference/models/speaker.md:
--------------------------------------------------------------------------------
1 | # Speaker Models
2 |
3 | ::: aana.core.models.speaker
--------------------------------------------------------------------------------
/docs/reference/models/types.md:
--------------------------------------------------------------------------------
1 | # Types Models
2 |
3 | ::: aana.core.models.types
4 |
--------------------------------------------------------------------------------
/docs/reference/models/video.md:
--------------------------------------------------------------------------------
1 | # Video Models
2 |
3 | ::: aana.core.models.video
4 |
--------------------------------------------------------------------------------
/docs/reference/models/captions.md:
--------------------------------------------------------------------------------
1 | # Caption Models
2 |
3 | ::: aana.core.models.captions
4 |
--------------------------------------------------------------------------------
/docs/images/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/favicon.ico
--------------------------------------------------------------------------------
/docs/reference/models/custom_config.md:
--------------------------------------------------------------------------------
1 | # Custom Config
2 |
3 | ::: aana.core.models.custom_config
--------------------------------------------------------------------------------
/docs/reference/models/sampling.md:
--------------------------------------------------------------------------------
1 | # Sampling Models
2 |
3 | ::: aana.core.models.sampling
4 |
--------------------------------------------------------------------------------
/docs/reference/models/whisper.md:
--------------------------------------------------------------------------------
1 | # Whisper Models
2 |
3 | ::: aana.core.models.whisper
4 |
5 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "charliermarsh.ruff"
4 | ]
5 | }
--------------------------------------------------------------------------------
/aana/api/__init__.py:
--------------------------------------------------------------------------------
1 | from aana.api.api_generation import Endpoint
2 |
3 | __all__ = ["Endpoint"]
4 |
--------------------------------------------------------------------------------
/docs/reference/models/image_chat.md:
--------------------------------------------------------------------------------
1 | # Image Chat Models
2 |
3 | ::: aana.core.models.image_chat
4 |
--------------------------------------------------------------------------------
/docs/images/white_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/white_logo.png
--------------------------------------------------------------------------------
/aana/tests/files/expected/hf_blip2/blip2_deployment_Starry_Night.jpeg.json:
--------------------------------------------------------------------------------
1 | "the starry night by vincent van gogh"
--------------------------------------------------------------------------------
/aana/exceptions/__init__.py:
--------------------------------------------------------------------------------
1 | from aana.exceptions.core import BaseException
2 |
3 | __all__ = ["BaseException"]
4 |
--------------------------------------------------------------------------------
/docs/pages/model_hub/vad.md:
--------------------------------------------------------------------------------
1 | # Voice Activity Detection (VAD) Models
2 |
3 | <!-- TODO: Make VAD deployment more generic -->
--------------------------------------------------------------------------------
/docs/reference/models/multimodal_chat.md:
--------------------------------------------------------------------------------
1 | # Multimodal Chat Models
2 |
3 | ::: aana.core.models.multimodal_chat
4 |
--------------------------------------------------------------------------------
/aana/tests/files/audios/squirrel.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/audios/squirrel.wav
--------------------------------------------------------------------------------
/aana/tests/files/expected/idefics/idefics_2_8b_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "Van gogh."
--------------------------------------------------------------------------------
/aana/tests/files/videos/squirrel.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/squirrel.mp4
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | ignore_missing_imports = False
3 | allow_redefinition = True
4 | mypy_path = $MYPY_CONFIG_FILE_DIR/aana
5 |
--------------------------------------------------------------------------------
/aana/tests/files/audios/physicsworks.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/audios/physicsworks.wav
--------------------------------------------------------------------------------
/docs/images/AanaSDK_logo_dark_theme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/AanaSDK_logo_dark_theme.png
--------------------------------------------------------------------------------
/docs/images/AanaSDK_logo_light_theme.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/AanaSDK_logo_light_theme.png
--------------------------------------------------------------------------------
/test_volume.dstack.yml:
--------------------------------------------------------------------------------
1 | type: volume
2 |
3 | name: test-models-cache
4 |
5 | backend: runpod
6 | region: EU-SE-1
7 |
8 | size: 150GB
--------------------------------------------------------------------------------
/aana/tests/files/images/Starry_Night.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/images/Starry_Night.jpeg
--------------------------------------------------------------------------------
/aana/tests/files/videos/physicsworks.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/physicsworks.webm
--------------------------------------------------------------------------------
/aana/tests/files/videos/squirrel_no_audio.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/squirrel_no_audio.mp4
--------------------------------------------------------------------------------
/aana/tests/files/videos/physicsworks_audio.webm:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/physicsworks_audio.webm
--------------------------------------------------------------------------------
/docs/reference/utils.md:
--------------------------------------------------------------------------------
1 | # Utility functions
2 |
3 | ::: aana.utils.asyncio
4 | ::: aana.utils.json
5 | ::: aana.utils.download
6 | ::: aana.utils.gpu
--------------------------------------------------------------------------------
/docs/reference/exceptions.md:
--------------------------------------------------------------------------------
1 | # Exceptions
2 |
3 | ::: aana.exceptions
4 | ::: aana.exceptions.runtime
5 | ::: aana.exceptions.io
6 | ::: aana.exceptions.db
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | " Vincent van Gogh"
--------------------------------------------------------------------------------
/aana/integrations/haystack/__init__.py:
--------------------------------------------------------------------------------
1 | from aana.integrations.haystack.deployment_component import AanaDeploymentComponent
2 |
3 | __all__ = ["AanaDeploymentComponent"]
4 |
--------------------------------------------------------------------------------
/aana/tests/files/expected/hf_pipeline/hf_pipeline_blip2_deployment_Starry_Night.jpeg.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "generated_text": "the starry night by van gogh\n"
4 | }
5 | ]
--------------------------------------------------------------------------------
/docs/reference/models/media.md:
--------------------------------------------------------------------------------
1 | # Media Models
2 |
3 | The `aana.core.models` provides models for such media types as audio, video, and images.
4 |
5 | ::: aana.core.models
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/pixtral_12b_2409_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent Van Gogh."
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent van Gogh."
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/qwen2_vl_7b_instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent van Gogh."
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent van Gogh."
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent van Gogh."
--------------------------------------------------------------------------------
/docs/reference/processors.md:
--------------------------------------------------------------------------------
1 | # Processors
2 |
3 | ::: aana.processors.remote
4 | ::: aana.processors.video
5 | ::: aana.processors.batch
6 | ::: aana.processors.speaker.PostProcessingForDiarizedAsr
--------------------------------------------------------------------------------
/aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json:
--------------------------------------------------------------------------------
1 | "The painter of the image is Vincent van Gogh."
--------------------------------------------------------------------------------
/aana/core/models/file.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import TypedDict
3 |
4 |
5 | class PathResult(TypedDict):
6 | """Represents a path result describing a file on disk."""
7 |
8 | path: Path
9 |
--------------------------------------------------------------------------------
/aana/utils/openapi_code_templates/curl_form.j2:
--------------------------------------------------------------------------------
1 | curl {{ base_url }}{{ path }} \
2 | --request POST \
3 | --header 'Content-Type: application/x-www-form-urlencoded' \
4 | --header 'x-api-key: YOUR_API_KEY_HERE' \
5 | --data 'body={{ body }}'
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/phi-3.5-vision-instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | " Elon Musk is a business magnate, industrial designer, and engineer. He is the CEO and lead designer of SpaceX, CEO and product"
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/phi3_mini_4k_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | " Elon Musk is a business magnate, industrial designer, and engineer. He is the founder, CEO, CTO, and chief designer of Space"
--------------------------------------------------------------------------------
/docs/reference/integrations.md:
--------------------------------------------------------------------------------
1 | # Integrations
2 |
3 | ::: aana.integrations.haystack
4 | ::: aana.integrations.external.av
5 | ::: aana.integrations.external.decord
6 | ::: aana.integrations.external.opencv
7 | ::: aana.integrations.external.yt_dlp
--------------------------------------------------------------------------------
/aana/deployments/__init__.py:
--------------------------------------------------------------------------------
1 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
2 | from aana.deployments.base_deployment import BaseDeployment
3 |
4 | __all__ = [
5 | "AanaDeploymentHandle",
6 | "BaseDeployment",
7 | ]
8 |
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/phi3_mini_4k_instruct_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a business magnate, industrial designer, and engineer. He is the founder, CEO, CTO, and chief designer of Space"
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/qwen2_vl_7b_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is the CEO of SpaceX, the founder of Tesla, Inc., and"
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_medium_squirrel.wav.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription_info": {
4 | "language": "silence",
5 | "language_confidence": 1.0
6 | },
7 | "transcription": {
8 | "text": ""
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_tiny_squirrel.wav.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription_info": {
4 | "language": "silence",
5 | "language_confidence": 1.0
6 | },
7 | "transcription": {
8 | "text": ""
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/hqq_generation/meta-llama/Meta-Llama-3.1-8B-Instruct_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious ventures in the fields of space exploration,"
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/pixtral_12b_2409_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a South African-born Canadian-American business magnate, industrial designer, and engineer. He is known for his work in the fields of electric vehicles"
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/qwen1.5_1.8b_chat_vllm_deployment_with_lora_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a South African-American entrepreneur, investor, engineer, and inventor who co-founded SpaceX, Tesla, Neuralink, The Boring Company, and"
--------------------------------------------------------------------------------
/.devcontainer/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
2 | RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 ffmpeg locales
3 |
4 | # Set the locale
5 | RUN locale-gen en_US.UTF-8
6 | ENV LANG="en_US.UTF-8" LANGUAGE="en_US:en" LC_ALL="en_US.UTF-8"
7 |
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/gemma_3_1b_it_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Okay, let's break down who Elon Musk is. He's a truly fascinating and incredibly influential figure, often described as a disruptive innovator and a visionary"
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_turbo_squirrel.wav.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription_info": {
4 | "language": "silence",
5 | "language_confidence": 1.0
6 | },
7 | "transcription": {
8 | "text": ""
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/qwen2_vl_3b_gemlite_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is an American entrepreneur, investor, and inventor. He is the founder and CEO of SpaceX, a private aerospace manufacturer and space transportation company. He"
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_medium_squirrel.wav_batched.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription": {
4 | "text": ""
5 | },
6 | "transcription_info": {
7 | "language": "silence",
8 | "language_confidence": 1.0
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_tiny_squirrel.wav_batched.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription": {
4 | "text": ""
5 | },
6 | "transcription_info": {
7 | "language": "silence",
8 | "language_confidence": 1.0
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/whisper/whisper_turbo_squirrel.wav_batched.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [],
3 | "transcription": {
4 | "text": ""
5 | },
6 | "transcription_info": {
7 | "language": "silence",
8 | "language_confidence": 1.0
9 | }
10 | }
--------------------------------------------------------------------------------
/aana/tests/files/expected/hqq_generation/mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq_calib_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious ventures in the fields of space exploration,"
--------------------------------------------------------------------------------
/aana/tests/files/expected/text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json:
--------------------------------------------------------------------------------
1 | "Elon Musk is an American entrepreneur, investor, and engineer who has made significant contributions to the fields of space exploration, electric vehicles, and renewable energy. He"
--------------------------------------------------------------------------------
/aana/tests/units/test_question.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | from aana.core.models.chat import Question
4 |
5 |
6 | def test_question_creation():
7 | """Test that a question can be created."""
8 | question = Question("What is the capital of France?")
9 | assert question == "What is the capital of France?"
10 |
--------------------------------------------------------------------------------
/docs/stylesheets/extra.css:
--------------------------------------------------------------------------------
1 | :root {
2 | --md-primary-fg-color: #A66CFF;
3 | --md-primary-fg-color--light: #CFF500;
4 | --md-primary-fg-color--dark: #3E3E3E;
5 | --md-primary-bg-color: #F0F0F0;
6 | --md-primary-bg-color--light: #FFFFFF;
7 | }
8 |
9 | /* .md-header__topic {
10 | display: none;
11 | } */
--------------------------------------------------------------------------------
/docs/pages/dev_environment.md:
--------------------------------------------------------------------------------
1 | # Dev Environment
2 |
3 | If you are using Visual Studio Code, you can run this repository in a
4 | [dev container](https://code.visualstudio.com/docs/devcontainers/containers). This lets you install and
5 | run everything you need for the repo in an isolated environment via docker on a host system.
6 |
7 |
--------------------------------------------------------------------------------
/docs/reference/storage/repositories.md:
--------------------------------------------------------------------------------
1 | # Repositories
2 |
3 | ::: aana.storage.repository.base
4 | ::: aana.storage.repository.media
5 | ::: aana.storage.repository.video
6 | ::: aana.storage.repository.caption
7 | ::: aana.storage.repository.transcript
8 | ::: aana.storage.repository.task
9 | ::: aana.storage.repository.webhook
10 |
--------------------------------------------------------------------------------
/aana/configs/__init__.py:
--------------------------------------------------------------------------------
1 | from aana.configs.db import DbSettings, DbType, PostgreSQLConfig, SQLiteConfig
2 | from aana.configs.settings import Settings, TaskQueueSettings, TestSettings
3 |
4 | __all__ = [
5 | "DbSettings",
6 | "DbType",
7 | "PostgreSQLConfig",
8 | "SQLiteConfig",
9 | "Settings",
10 | "TaskQueueSettings",
11 | "TestSettings",
12 | ]
13 |
--------------------------------------------------------------------------------
/aana/core/models/captions.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pydantic import Field
4 |
5 | __all__ = ["Caption", "CaptionsList"]
6 |
7 | Caption = Annotated[str, Field(description="A caption.")]
8 | """
9 | A caption.
10 | """
11 |
12 | CaptionsList = Annotated[list[Caption], Field(description="A list of captions.")]
13 | """
14 | A list of captions.
15 | """
16 |
--------------------------------------------------------------------------------
/aana/utils/gpu.py:
--------------------------------------------------------------------------------
1 | __all__ = ["get_gpu_memory"]
2 |
3 |
4 | def get_gpu_memory(gpu: int = 0) -> int:
5 | """Get the total memory of a GPU in bytes.
6 |
7 | Args:
8 | gpu (int): the GPU index. Defaults to 0.
9 |
10 | Returns:
11 | int: the total memory of the GPU in bytes
12 | """
13 | import torch
14 |
15 | return torch.cuda.get_device_properties(gpu).total_memory
16 |
--------------------------------------------------------------------------------
/aana/utils/openapi_code_templates/python_form.j2:
--------------------------------------------------------------------------------
1 | import json, requests
2 |
3 | payload = {{ body }}
4 |
5 | headers = {
6 | "Content-Type": "application/x-www-form-urlencoded",
7 | "x-api-key": "YOUR_API_KEY_HERE"
8 | }
9 |
10 | response = requests.post("{{ base_url }}{{ path }}",
11 | headers=headers,
12 | data={"body": json.dumps(payload)}
13 | )
14 |
15 | data = response.json()
16 | print(data)
--------------------------------------------------------------------------------
/aana/utils/streamer.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from queue import Empty
3 |
4 |
5 | async def async_streamer_adapter(streamer):
6 | """Adapt the TextIteratorStreamer to an async generator."""
7 | while True:
8 | try:
9 | for item in streamer:
10 | yield item
11 | break
12 | except Empty:
13 | # wait for the next item
14 | await asyncio.sleep(0.01)
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/enhancement.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Enhancement
about: Recommend improvements to existing features and suggest improvements to code quality.
4 | title: "[ENHANCEMENT]"
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Enhancement Description
11 | - Overview of the enhancement
12 |
13 | ### Advantages
14 | - Benefits of implementing this enhancement
15 |
16 | ### Possible Implementation
17 | - Suggested methods for implementing the enhancement
18 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/question.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Question
3 | about: Ask questions to clarify project-related queries, seek further information,
4 | or understand functionalities better.
5 | title: "[QUESTION]"
6 | labels: question
7 | assignees: ''
8 |
9 | ---
10 |
11 | ### Context
12 | - Background or context of the question
13 |
14 | ### Question
15 | - Specific question being asked
16 |
17 | ### What You've Tried
18 | - List any solutions or research already conducted
19 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/testing.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Testing
3 | about: Address needs for creating new tests, enhancing existing tests, or reporting
4 | test failures.
5 | title: "[TESTS]"
6 | labels: tests
7 | assignees: ''
8 |
9 | ---
10 |
11 | ### Testing Requirement
12 | - Describe the testing need or issue
13 |
14 | ### Test Scenarios
15 | - Detail specific test scenarios to be addressed
16 |
17 | ### Acceptance Criteria
18 | - What are the criteria for the test to be considered successful?
19 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "notebook.formatOnSave.enabled": true,
3 | "[python]": {
4 | "editor.defaultFormatter": "charliermarsh.ruff",
5 | "editor.formatOnSave": true,
6 | },
7 | "python.testing.pytestArgs": [
8 | "aana"
9 | ],
10 | "python.testing.unittestEnabled": false,
11 | "python.testing.pytestEnabled": true,
12 | "python.testing.pytestPath": "poetry run pytest",
13 | "ruff.fixAll": true,
14 | "ruff.organizeImports": true,
15 | }
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest new functionalities or modifications to enhance the application's capabilities.
4 | title: "[FEATURE REQUEST]"
5 | labels: feature request
6 | assignees: ''
7 |
8 | ---
9 |
10 | ### Feature Summary
11 | - Concise description of the feature
12 |
13 | ### Justification/Rationale
14 | - Why is the feature beneficial?
15 |
16 | ### Proposed Implementation (if any)
17 | - How do you envision this feature's implementation?
18 |
--------------------------------------------------------------------------------
/aana/utils/openapi_code_templates/python_form_streaming.j2:
--------------------------------------------------------------------------------
1 | import json, requests
2 | import ijson # pip install ijson
3 |
4 | payload = {{ body }}
5 |
6 | headers = {
7 | "Content-Type": "application/x-www-form-urlencoded",
8 | "x-api-key": "YOUR_API_KEY_HERE"
9 | }
10 |
11 | response = requests.post("{{ base_url }}{{ path }}",
12 | headers=headers,
13 | data={"body": json.dumps(payload)},
14 | stream=True
15 | )
16 |
17 | for obj in ijson.items(response.raw, '', multiple_values=True, buf_size=32):
18 | print(obj)
--------------------------------------------------------------------------------
/pg_ctl.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Wrapper around pg_ctl: fixes directory permissions and drops to the
# postgres user so the server can be controlled from a root shell.

# Get the PostgreSQL major version (e.g., 14)
PG_VERSION=$(psql --version | awk '{print $3}' | cut -d '.' -f 1)

if [ "$2" = "--pgdata" ]; then
  # Recursively change parent directories permissions so that pg_ctl can access the necessary files as non-root.
  # Walk upward from the data directory to /, opening each ancestor to "other".
  f=$(dirname "$3")
  while [[ $f != / ]]; do chmod o+rwx "$f"; f=$(dirname "$f"); done;
fi

# Execute pg_ctl as postgres instead of root.
sudo -u postgres "/usr/lib/postgresql/$PG_VERSION/bin/pg_ctl" "$@"
14 |
--------------------------------------------------------------------------------
/aana/api/event_handlers/event_handler.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 |
class EventHandler(ABC):
    """Base class for event handlers. Not guaranteed to be thread safe."""

    @abstractmethod
    def handle(self, event_name: str, *args, **kwargs):
        """Handles an event of the given name.

        Subclasses must override this method; the base implementation does nothing.

        Arguments:
            event_name (str): name of the event to handle
            *args (list): specific, context-dependent args
            **kwargs (dict): specific, context-dependent args
        """
        pass
17 |
--------------------------------------------------------------------------------
/aana/tests/units/test_typed_dict.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | from typing import TypedDict
4 |
5 | from aana.utils.typing import is_typed_dict
6 |
7 |
def test_is_typed_dict() -> None:
    """Test that is_typed_dict recognizes a TypedDict subclass."""

    class Foo(TypedDict):
        foo: str
        bar: int

    # `== True` is non-idiomatic (ruff E712); assert the truthy value directly.
    assert is_typed_dict(Foo)
16 |
17 |
def test_is_not_typed_dict():
    """Test that is_typed_dict rejects plain classes and builtin dict."""

    class Foo:
        pass

    # `== False` is non-idiomatic (ruff E712); assert the negation directly.
    assert not is_typed_dict(Foo)
    assert not is_typed_dict(dict)
26 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/documentation.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Documentation
3 | about: Propose updates or corrections to both internal development documentation and
4 | client-facing documentation.
5 | title: "[DOCS]"
6 | labels: documentation
7 | assignees: ''
8 |
9 | ---
10 |
11 | ### Documentation Area (Development/Client)
12 | - Specify the area of documentation
13 |
14 | ### Current Content
15 | - Quote the current content or describe the issue
16 |
17 | ### Proposed Changes
18 | - Detail the proposed changes
19 |
20 | ### Reasons for Changes
21 | - Why these changes will improve the documentation
22 |
--------------------------------------------------------------------------------
/aana/core/models/time.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, ConfigDict, Field
2 |
3 | __all__ = ["TimeInterval"]
4 |
5 |
class TimeInterval(BaseModel):
    """Pydantic schema for TimeInterval.

    Attributes:
        start (float): Start time in seconds
        end (float): End time in seconds
    """

    # Both bounds are validated as non-negative (ge=0.0).
    # NOTE(review): end >= start is NOT enforced here — confirm callers
    # guarantee interval ordering before relying on it.
    start: float = Field(ge=0.0, description="Start time in seconds")
    end: float = Field(ge=0.0, description="End time in seconds")
    # extra="forbid" makes unknown fields a validation error.
    model_config = ConfigDict(
        json_schema_extra={
            "description": "Time interval in seconds",
        },
        extra="forbid",
    )
22 |
--------------------------------------------------------------------------------
/aana/tests/files/images/ATTRIBUTION.md:
--------------------------------------------------------------------------------
1 | # Third-Party Media Attribution
2 |
3 | This folder contains third-party media used for testing. No endorsement implied.
4 |
5 | _Last updated: 2025-08-22_
6 |
7 | | File | Source | License/Status | Author/Owner | Changes Made |
8 | |---|---|---|---|---|
9 | | `Starry_Night.jpeg` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg (faithful reproduction; Google Arts & Culture — bgEuwDxel93-Pg) | Public domain (faithful reproduction of a 2D public-domain work) | Vincent van Gogh (original work); reproduction per Commons record | Downloaded and saved as JPEG for testing. |
10 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Report a malfunction, glitch, or error in the application. This includes any
4 | performance-related issues that may arise.
5 | title: "[BUG]"
6 | labels: bug
7 | assignees: ''
8 |
9 | ---
10 |
11 | ### Bug Description
12 | - Brief summary of the issue
13 |
14 | ### Steps to Reproduce
15 | 1.
16 | 2.
17 | 3.
18 |
19 | ### Expected Behavior
20 | - What should have happened?
21 |
22 | ### Actual Behavior
23 | - What actually happened?
24 |
25 | ### Performance Details (if applicable)
26 | - Specifics of the performance issue encountered
27 |
28 | ### Environment
29 | - Version (commit hash, tag, branch)
30 |
--------------------------------------------------------------------------------
/aana/core/models/exception.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from pydantic import BaseModel, ConfigDict
4 |
5 |
class ExceptionResponseModel(BaseModel):
    """This class is used to represent an exception response for 400 errors.

    Attributes:
        error (str): The error that occurred.
        message (str): The message of the error.
        data (Optional[Any]): The extra data that helps to debug the error.
        stacktrace (Optional[str]): The stacktrace of the error.
    """

    # Name/type of the error that occurred.
    error: str
    # Human-readable description of the error.
    message: str
    # Extra debugging payload; shape depends on the raising code.
    data: Any | None = None
    # Formatted stack trace, when available.
    stacktrace: str | None = None
    # Unknown fields are rejected on validation.
    model_config = ConfigDict(extra="forbid")
21 |
--------------------------------------------------------------------------------
/.github/workflows/publish_docs.yml:
--------------------------------------------------------------------------------
1 | name: Publish Docs
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 |
8 | jobs:
9 | publish:
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - name: Checkout code
14 | uses: actions/checkout@v3
15 |
16 | - name: Set up Python 3.10
17 | uses: actions/setup-python@v5
18 | with:
19 | python-version: "3.10"
20 |
21 | - name: Bootstrap uv
22 | uses: astral-sh/setup-uv@v6
23 | with:
24 | version: "latest"
25 |
26 | - name: Install dependencies
27 | run: uv sync --only-group docs
28 |
29 | - name: Deploy docs
30 | run: uv run mkdocs gh-deploy --force
31 |
--------------------------------------------------------------------------------
/aana/processors/remote.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from collections.abc import Callable
3 |
4 | import ray
5 |
6 | __all__ = ["run_remote"]
7 |
8 |
def run_remote(func: Callable) -> Callable:
    """Wrap a function to run it remotely on Ray.

    Args:
        func (Callable): the function to wrap

    Returns:
        Callable: the wrapped function
    """

    async def generator_wrapper(*args, **kwargs):
        # Ray streams generator results as object refs; await each one as it arrives.
        async for item in ray.remote(func).remote(*args, **kwargs):
            yield await item

    # inspect.isgeneratorfunction does NOT match async generator functions,
    # so check both kinds; otherwise an async generator would fall through to
    # the plain remote call and lose its streaming semantics.
    if inspect.isgeneratorfunction(func) or inspect.isasyncgenfunction(func):
        return generator_wrapper
    else:
        return ray.remote(func).remote
27 |
--------------------------------------------------------------------------------
/docs/pages/code_standards.md:
--------------------------------------------------------------------------------
1 | # Code Standards
2 |
This project uses Ruff for linting and formatting. If you want to
run Ruff on the codebase manually using uv, it's
5 |
6 | ```sh
7 | uv run ruff check aana
8 | ```
9 |
10 | You can automatically fix some issues with the `--fix`
11 | and `--unsafe-fixes` options. (Be sure to install the dev
12 | dependencies: `uv sync --group dev`. )
13 |
14 | To run the auto-formatter, it's
15 |
16 | ```sh
17 | uv run ruff format aana
18 | ```
19 |
20 | (If you are running code in a non-uv environment, just leave off `uv run`.)
21 |
22 | For users of VS Code, the included `settings.json` should ensure
23 | that Ruff problems appear while you edit, and formatting is applied
24 | automatically on save.
25 |
--------------------------------------------------------------------------------
/aana/utils/file.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | from pathlib import Path
3 |
4 |
def get_sha256_hash_file(filename: Path) -> str:
    """Compute SHA-256 hash of a file without loading it entirely in memory.

    Args:
        filename (Path): Path to the file to be hashed.

    Returns:
        str: SHA-256 hash of the file in hexadecimal format.
    """
    digest = hashlib.sha256()
    # Stream the file in 4 KiB chunks so memory use stays constant
    # regardless of file size; the loop ends on the empty read at EOF.
    with filename.open("rb") as stream:
        while chunk := stream.read(4096):
            digest.update(chunk)
    return digest.hexdigest()
25 |
--------------------------------------------------------------------------------
/aana/core/chat/templates/llama2.jinja:
--------------------------------------------------------------------------------
{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}
--------------------------------------------------------------------------------
/aana/tests/files/audios/ATTRIBUTION.md:
--------------------------------------------------------------------------------
1 | # Third-Party Media Attribution
2 |
3 | This folder contains third-party media used for testing. No endorsement implied.
4 |
5 | _Last updated: 2025-08-22_
6 |
7 | | File | Source | License | Author/Owner | Changes Made |
8 | |---|---|---|---|---|
9 | | `squirrel.wav` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Converted from source video to WAV for testing. |
10 | | `physicsworks.wav` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Extracted audio and transcoded to WAV for testing; trimmed for length. |
11 |
--------------------------------------------------------------------------------
/aana/alembic/script.py.mako:
--------------------------------------------------------------------------------
1 | """${message}
2 |
3 | Revision ID: ${up_revision}
4 | Revises: ${down_revision | comma,n}
5 | Create Date: ${create_date}
6 |
7 | """
8 | from typing import Sequence
9 |
10 | from alembic import op
11 | import sqlalchemy as sa
12 | ${imports if imports else ""}
13 |
14 | # revision identifiers, used by Alembic.
15 | revision: str = ${repr(up_revision)}
16 | down_revision: str | None = ${repr(down_revision)}
17 | branch_labels: str | Sequence[str] | None = ${repr(branch_labels)}
18 | depends_on: str | Sequence[str] | None = ${repr(depends_on)}
19 |
20 |
21 | def upgrade() -> None:
22 | """Upgrade database to this revision from previous."""
23 | ${upgrades if upgrades else "pass"}
24 |
25 |
26 | def downgrade() -> None:
27 | """Downgrade database from this revision to previous."""
28 | ${downgrades if downgrades else "pass"}
29 |
--------------------------------------------------------------------------------
/aana/processors/video.py:
--------------------------------------------------------------------------------
1 | from aana.core.models.audio import Audio
2 | from aana.core.models.video import Video
3 | from aana.integrations.external.av import load_audio
4 |
5 | __all__ = ["extract_audio"]
6 |
7 |
def extract_audio(video: Video) -> Audio:
    """Extract the audio file from a Video and return an Audio object.

    Args:
        video (Video): The video file to extract audio.

    Returns:
        Audio: an Audio object containing the extracted audio.
    """
    # Decode the audio track into WAV bytes.
    wav_bytes = load_audio(video.path)

    # Carry the source video's metadata over to the Audio object; the media_id
    # gets an "audio_" prefix so the two objects stay distinct.
    return Audio(
        url=video.url,
        media_id=f"audio_{video.media_id}",
        content=wav_bytes,
        title=video.title,
        description=video.description,
    )
28 |
--------------------------------------------------------------------------------
/tests.dstack.yml:
--------------------------------------------------------------------------------
1 | type: task
2 |
3 | name: aana-tests
4 |
5 | backends: [runpod]
6 |
7 | image: nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04
8 |
9 | env:
10 | - HF_TOKEN
11 |
12 | commands:
13 | - apt-get update
14 | - DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends tzdata
15 | - apt-get install -y libgl1 libglib2.0-0 ffmpeg python3 python3-dev postgresql sudo
16 | - locale-gen en_US.UTF-8
17 | - export LANG="en_US.UTF-8" LANGUAGE="en_US:en" LC_ALL="en_US.UTF-8"
18 | - curl -LsSf https://astral.sh/uv/install.sh | sh
19 | - export PATH=$PATH:/root/.cargo/bin
20 | - uv sync --group tests --extra all
21 | - HF_HUB_CACHE="/models_cache" CUDA_VISIBLE_DEVICES="0" uv run pytest -vv -s
22 |
23 | volumes:
24 | - name: test-models-cache
25 | path: /models_cache
26 |
27 | max_price: 1.0
28 |
29 | resources:
30 | cpu: 9..
31 | memory: 32GB..
32 | gpu: 40GB..
--------------------------------------------------------------------------------
/aana/utils/typing.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from collections.abc import AsyncGenerator
3 |
4 |
5 | def is_typed_dict(argument: type) -> bool:
6 | """Checks if a argument is a TypedDict.
7 |
8 | Arguments:
9 | argument (type): the type to check
10 |
11 | Returns:
12 | bool: True if the argument type is a TypedDict anf False if it is not.
13 | """
14 | return bool(
15 | argument and getattr(argument, "__orig_bases__", None) == (typing.TypedDict,)
16 | )
17 |
18 |
def is_async_generator(argument: type) -> bool:
    """Checks if an argument is an AsyncGenerator.

    Arguments:
        argument (type): the type to check

    Returns:
        bool: True if the argument type is an AsyncGenerator and False if it is not.
    """
    # Plain classes and non-generic types carry no __origin__ attribute.
    if not hasattr(argument, "__origin__"):
        return False
    # For generic aliases like AsyncGenerator[int, None], __origin__ is the
    # runtime class (collections.abc.AsyncGenerator for async generators).
    return issubclass(argument.__origin__, AsyncGenerator)
31 |
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "Ubuntu",
3 | "build": {
4 | "dockerfile": "Dockerfile"
5 | },
6 | "capAdd": [
7 | "SYS_PTRACE"
8 | ],
9 | "features": {
10 | "ghcr.io/devcontainers/features/python:1": {
11 | "installTools": true,
12 | "version": "3.10"
13 | },
14 | "ghcr.io/itsmechlark/features/postgresql:1": {
15 | "version": "16"
16 | },
17 | "ghcr.io/va-h/devcontainers-features/uv:1": {}
18 | },
19 | "hostRequirements": {
20 | "gpu": "optional"
21 | },
22 | "securityOpt": [
23 | "seccomp=unconfined"
24 | ],
25 | "postCreateCommand": "uv sync --all-groups --extra all --extra api_service",
26 | "postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}",
27 | "customizations": {
28 | "vscode": {
29 | "extensions": [
30 | "charliermarsh.ruff",
31 | "ms-python.python",
32 | "ms-python.mypy-type-checker",
33 | "ms-toolsai.jupyter"
34 | ]
35 | }
36 | }
37 | }
--------------------------------------------------------------------------------
/aana/storage/session.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import Annotated
3 |
4 | from fastapi import Depends
5 | from sqlalchemy.ext.asyncio import AsyncSession
6 |
7 | from aana.configs.settings import settings
8 | from aana.storage.op import DatabaseSessionManager
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | __all__ = ["GetDbDependency", "get_db", "get_session"]
13 |
14 | session_manager = DatabaseSessionManager(settings)
15 |
16 |
def get_session() -> AsyncSession:
    """Get a new SQLAlchemy Session object.

    Returns:
        AsyncSession: SQLAlchemy async session
    """
    # NOTE(review): delegates to the module-level session_manager; the caller
    # is responsible for closing the session (see get_db for the managed variant).
    return session_manager.session()
24 |
25 |
async def get_db():
    """Get a database session.

    Async generator dependency: yields one session and closes it when the
    consumer is finished (used via FastAPI `Depends`).

    Returns:
        AsyncSession: SQLAlchemy async session
    """
    # The async context manager guarantees the session is released even if
    # the request handler raises.
    async with session_manager.session() as session:
        yield session
34 |
35 |
36 | GetDbDependency = Annotated[AsyncSession, Depends(get_db)]
37 | """ Dependency to get a database session. """
38 |
--------------------------------------------------------------------------------
/aana/storage/models/media.py:
--------------------------------------------------------------------------------
1 | from uuid import uuid4
2 |
3 | from sqlalchemy.orm import Mapped, mapped_column
4 |
5 | from aana.core.models.media import MediaId
6 | from aana.storage.models.base import BaseEntity, TimeStampEntity
7 |
8 |
class MediaEntity(BaseEntity, TimeStampEntity):
    """Base ORM class for media (e.g. videos, images, etc.).

    This class is meant to be subclassed by other media types.

    Attributes:
        id (MediaId): Unique identifier for the media.
        media_type (str): The type of media (populated automatically by ORM based on `polymorphic_identity` of subclass).
    """

    __tablename__ = "media"
    # A random UUID string is generated when the caller supplies no id.
    id: Mapped[MediaId] = mapped_column(
        primary_key=True,
        default=lambda: str(uuid4()),
        comment="Unique identifier for the media",
    )
    media_type: Mapped[str] = mapped_column(comment="The type of media")

    # media_type is the polymorphic discriminator: each subclass sets its own
    # polymorphic_identity, which SQLAlchemy stores in this column.
    __mapper_args__ = {  # noqa: RUF012
        "polymorphic_identity": "media",
        "polymorphic_on": "media_type",
    }
31 |
--------------------------------------------------------------------------------
/aana/storage/models/video.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import ForeignKey
2 | from sqlalchemy.orm import Mapped, mapped_column
3 |
4 | from aana.core.models.media import MediaId
5 | from aana.storage.models.media import MediaEntity
6 |
7 |
class VideoEntity(MediaEntity):
    """Base ORM class for videos.

    Attributes:
        id (MediaId): Unique identifier for the video.
        path (str): Path to the video file.
        url (str): URL to the video file.
        title (str): Title of the video.
        description (str): Description of the video.
    """

    __tablename__ = "video"

    # Joined-table inheritance: the primary key also references the parent media row.
    id: Mapped[MediaId] = mapped_column(ForeignKey("media.id"), primary_key=True)
    # Descriptive columns are all nullable — a video may come from a local
    # path or a URL, and title/description are optional metadata.
    path: Mapped[str] = mapped_column(comment="Path", nullable=True)
    url: Mapped[str] = mapped_column(comment="URL", nullable=True)
    title: Mapped[str] = mapped_column(comment="Title", nullable=True)
    description: Mapped[str] = mapped_column(comment="Description", nullable=True)

    # Rows of this subclass store "video" in the media_type discriminator column.
    __mapper_args__ = {  # noqa: RUF012
        "polymorphic_identity": "video",
    }
30 |
--------------------------------------------------------------------------------
/aana/api/responses.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import orjson
4 | from fastapi.responses import JSONResponse
5 |
6 | from aana.utils.json import jsonify
7 |
8 |
class AanaJSONResponse(JSONResponse):
    """Response class that uses orjson to serialize data.

    It has additional support for numpy arrays.
    """

    media_type = "application/json"
    # orjson option flags consumed by render(); overwritten per-instance in __init__.
    option = None

    def __init__(
        self,
        content: Any,
        status_code: int = 200,
        option: int | None = orjson.OPT_SERIALIZE_NUMPY,
        **kwargs,
    ):
        """Initialize the response class with the orjson option.

        Args:
            content (Any): The response payload to serialize.
            status_code (int): HTTP status code. Defaults to 200.
            option (int | None): orjson option flags; numpy support is enabled by default.
            **kwargs: Forwarded to `JSONResponse.__init__`.
        """
        # Must be set before super().__init__, which calls self.render(content).
        self.option = option
        super().__init__(content, status_code, **kwargs)

    def render(self, content: Any) -> bytes:
        """Override the render method to use orjson.dumps instead of json.dumps."""
        return jsonify(content, option=self.option, as_bytes=True)
32 |
33 |
class AanaConcatenatedJSONResponse(AanaJSONResponse):
    """Response class that uses orjson to serialize data to Concatenated JSON format."""

    # NOTE(review): "application/json-seq" is the RFC 7464 JSON-text-sequence media
    # type; concatenated JSON has no registered media type — confirm this is intended.
    media_type = "application/json-seq"
38 |
--------------------------------------------------------------------------------
/aana/storage/models/__init__.py:
--------------------------------------------------------------------------------
# ruff: noqa: F401
# We need to import all db models here and, other than in the class definitions
# themselves, only import them from aana.models.db directly. The reason for
# this is the way SQLAlchemy's declarative base works. You can use forward
# references like `parent = reference("Parent", backreferences="child")`, but the
# forward reference needs to have been resolved before the first constructor
# is called so that SqlAlchemy "knows" about it.
# See:
# https://docs.pylonsproject.org/projects/pyramid_cookbook/en/latest/database/sqlalchemy.html#importing-all-sqlalchemy-models
# (even if not using Pyramid)

from aana.storage.models.base import BaseEntity
from aana.storage.models.caption import CaptionEntity
from aana.storage.models.media import MediaEntity
from aana.storage.models.task import TaskEntity
from aana.storage.models.transcript import TranscriptEntity
from aana.storage.models.video import VideoEntity
from aana.storage.models.webhook import WebhookEntity

# Keep __all__ in sync with the imports above: TaskEntity and WebhookEntity
# were imported but previously missing from the public API list.
__all__ = [
    "BaseEntity",
    "CaptionEntity",
    "MediaEntity",
    "TaskEntity",
    "TranscriptEntity",
    "VideoEntity",
    "WebhookEntity",
]
27 |
--------------------------------------------------------------------------------
/aana/core/models/custom_config.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | from typing import Annotated
3 |
4 | from pydantic import BeforeValidator, PlainSerializer
5 |
6 | __all__ = ["CustomConfig"]
7 |
CustomConfig = Annotated[
    dict,
    # Serialize by pickling and decoding with latin1 so the raw bytes survive as a
    # str (latin1 maps all 256 byte values 1:1 to code points).
    PlainSerializer(lambda x: pickle.dumps(x).decode("latin1"), return_type=str),
    # Accept either a plain dict (pass through) or a previously serialized string.
    # NOTE: pickle.loads on untrusted input can execute arbitrary code; this field
    # must only receive data produced by trusted callers.
    BeforeValidator(
        lambda x: x if isinstance(x, dict) else pickle.loads(x.encode("latin1"))  # noqa: S301
    ),
]
"""
A custom configuration field that can be used to pass arbitrary configuration to the deployment.

For example, you can define a custom configuration field in a deployment configuration like this:

```python
class HfPipelineConfig(BaseModel):
    model_id: str
    task: str | None = None
    model_kwargs: CustomConfig = {}
    pipeline_kwargs: CustomConfig = {}
    generation_kwargs: CustomConfig = {}
```

Then you can use the custom configuration field to pass a configuration to the deployment:

```python
HfPipelineConfig(
    model_id="Salesforce/blip2-opt-2.7b",
    model_kwargs={
        "quantization_config": BitsAndBytesConfig(
            load_in_8bit=False, load_in_4bit=True
        ),
    },
)
```
"""
42 |
--------------------------------------------------------------------------------
/aana/core/models/api_service.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pydantic import BaseModel, Field
4 | from pydantic.json_schema import SkipJsonSchema
5 |
6 |
class ApiKey(BaseModel):
    """Pydantic model for API key entity.

    Attributes:
        api_key (str): The API key.
        user_id (str): ID of the user who owns this API key.
        subscription_id (str): ID of the associated subscription.
        is_subscription_active (bool): Whether the subscription is active (credits are available).
        is_admin (bool): Whether the user is an admin.
        hmac_secret (str | None): The secret key for HMAC signature generation.
    """

    api_key: str
    user_id: str
    subscription_id: str
    is_subscription_active: bool
    is_admin: bool
    # May be None when HMAC signature generation is not configured for this key.
    hmac_secret: str | None
25 |
26 |
ApiKeyType = SkipJsonSchema[Annotated[ApiKey, Field(default=None)]]
"""
Type with optional API key information.

Can be None if API service is disabled. Otherwise, it will be an instance of `ApiKey`.

Attributes:
    api_key (str): The API key.
    user_id (str): The user ID.
    subscription_id (str): The subscription ID.
    is_subscription_active (bool): Flag indicating if the subscription is active.
    is_admin (bool): Whether the user is an admin.
    hmac_secret (str | None): The secret key for HMAC signature generation.
"""
39 |
--------------------------------------------------------------------------------
/.github/workflows/run_tests_with_gpu.yml:
--------------------------------------------------------------------------------
1 | name: Run Tests with GPU
2 |
3 | on:
4 | workflow_dispatch: # Allows for manual triggering
5 |
6 | concurrency:
7 | group: run-tests-gpu # Fixed group name to ensure only one instance runs
8 | cancel-in-progress: false
9 |
10 | jobs:
11 | test:
12 | runs-on: ubuntu-latest
13 |
14 | steps:
15 | - name: Set SSH permissions
16 | run: |
17 | mkdir -p ~/.ssh
18 | chmod 700 ~/.ssh
19 | sudo chown $USER:$USER ~/.ssh
20 |
21 | - name: Checkout code
22 | uses: actions/checkout@v3
23 |
24 | - name: Set up Python 3.10
25 | uses: actions/setup-python@v5
26 | with:
27 | python-version: "3.10"
28 |
29 | - name: Install and configure dstack
30 | run: |
31 | pip install dstack
32 | dstack config --url https://sky.dstack.ai --project ${{ secrets.DSTACK_PROJECT }} --token ${{ secrets.DSTACK_TOKEN }}
33 | dstack init
34 |
35 | - name: Run tests with GPU
36 | run: |
37 | DSTACK_CLI_LOG_LEVEL=DEBUG HF_TOKEN=${{ secrets.HF_TOKEN }} dstack apply -f tests.dstack.yml --force -y
38 |
39 | - name: Extract pytest logs
40 | if: ${{ always() }}
41 | run: |
42 | dstack logs aana-tests | sed -n '/============================= test session starts ==============================/,$p'
43 |
--------------------------------------------------------------------------------
/docs/pages/settings.md:
--------------------------------------------------------------------------------
1 | ---
2 | hide:
3 | - navigation
4 | ---
5 |
6 |
12 |
13 |
14 | # Settings
15 |
Here are the environment variables that can be used to configure the Aana SDK:
17 |
18 | - TMP_DATA_DIR: The directory to store temporary data. Default: `/tmp/aana`.
19 | - NUM_WORKERS: The number of request workers. Default: `2`.
20 | - DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`.
- HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, the HuggingFace Transformers will use the HF Transfer library to download the models from HuggingFace Hub to speed up the process. Recommended to always set it to `1`. Default: `0`.
22 | - HF_TOKEN: The HuggingFace API token to download the models from HuggingFace Hub, required for private or gated models.
23 | - TEST__SAVE_EXPECTED_OUTPUT: If set to `True`, the expected output will be saved when running the tests. Useful for creating new development tests. Default: `False`.
24 |
25 |
26 | See [reference documentation](./../reference/settings.md#aana.configs.Settings) for more advanced settings.
27 |
--------------------------------------------------------------------------------
/aana/core/models/types.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import torch
4 |
5 | __all__ = ["Dtype"]
6 |
7 |
8 | class Dtype(str, Enum):
9 | """Data types.
10 |
11 | Possible values are "auto", "float32", "float16", "bfloat16" and "int8".
12 |
13 | Attributes:
14 | AUTO (str): auto
15 | FLOAT32 (str): float32
16 | FLOAT16 (str): float16
17 | BFLOAT16 (str): bfloat16
18 | INT8 (str): int8
19 | """
20 |
21 | AUTO = "auto"
22 | FLOAT32 = "float32"
23 | FLOAT16 = "float16"
24 | BFLOAT16 = "bfloat16"
25 | INT8 = "int8"
26 |
27 | def to_torch(self) -> torch.dtype | str:
28 | """Convert the instance's dtype to a torch dtype.
29 |
30 | Returns:
31 | Union[torch.dtype, str]: the torch dtype or "auto"
32 |
33 | Raises:
34 | ValueError: if the dtype is unknown
35 | """
36 | match self.value:
37 | case self.AUTO:
38 | return "auto"
39 | case self.FLOAT32:
40 | return torch.float32
41 | case self.FLOAT16:
42 | return torch.float16
43 | case self.BFLOAT16:
44 | return torch.bfloat16
45 | case self.INT8:
46 | return torch.int8
47 | case _:
48 | raise ValueError(f"Unknown dtype: {self}") # noqa: TRY003
49 |
--------------------------------------------------------------------------------
/.github/workflows/tests.yml:
--------------------------------------------------------------------------------
1 | name: Tests and Linting
2 |
3 | on:
4 | push:
5 | branches:
6 | - '**' # Runs on push to any branch
7 | pull_request:
8 | branches:
9 | - '**' # Runs on pull requests to any branch
10 | workflow_dispatch: # Allows for manual triggering
11 |
12 | jobs:
13 | build:
14 |
15 | runs-on: ubuntu-latest
16 | strategy:
17 | fail-fast: false
18 | matrix:
19 | python-version: ["3.10", "3.11", "3.12"]
20 |
21 | steps:
22 | - name: Set up Python ${{ matrix.python-version }}
23 | uses: actions/setup-python@v5
24 | with:
25 | python-version: ${{ matrix.python-version }}
26 | - name: Display Python version
27 | run: python -c "import sys; print(sys.version)"
28 | - name: Checkout code
29 | uses: actions/checkout@v3
30 | - name: Bootstrap uv
31 | uses: astral-sh/setup-uv@v6
32 | with:
33 | version: "latest"
34 | - name: Install dependencies
35 | run: |
36 | uv sync --group dev --group tests --extra all
37 | sudo apt-get update
38 | sudo apt-get install ffmpeg
39 | - name: Install postgres
40 | uses: ikalnytskyi/action-setup-postgres@v6
41 | - name: Run Ruff Check
42 | run: uv run ruff check
43 | - name: Test with pytest
44 | if: always()
45 | env:
46 | HF_TOKEN: ${{ secrets.HF_TOKEN }}
47 | run: uv run pytest -vv
48 |
--------------------------------------------------------------------------------
/aana/storage/types.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import TypeAlias
3 |
4 | from sqlalchemy import String
5 | from sqlalchemy.types import DateTime, TypeDecorator
6 |
# SQL column type used for media IDs; 36 chars fits a canonical UUID string.
MediaIdSqlType: TypeAlias = String(36)
8 |
9 |
class TimezoneAwareDateTime(TypeDecorator):
    """A custom SQLAlchemy type decorator for timezone-aware datetime objects.

    This implementation converts naive datetime objects to UTC-aware datetime
    objects when binding parameters and when loading from the database.
    """

    impl = DateTime
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Convert naive datetime objects to UTC-aware datetime before storing."""
        if value is None:
            return value
        utc = datetime.timezone.utc
        # Naive datetimes are assumed to already be in UTC; aware ones are
        # normalized to UTC before hitting the database.
        return value.replace(tzinfo=utc) if value.tzinfo is None else value.astimezone(utc)

    def process_result_value(self, value, dialect):
        """Ensure that datetime objects loaded from the database are UTC-aware."""
        if value is not None and value.tzinfo is None:
            # Some backends (e.g. SQLite) return naive datetimes; treat them as UTC.
            value = value.replace(tzinfo=datetime.timezone.utc)
        return value
40 |
--------------------------------------------------------------------------------
/aana/tests/db/datastore/test_config.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | import pytest
3 |
4 | from aana.configs.db import DbSettings, DbType
5 | from aana.configs.settings import settings as aana_settings
6 | from aana.storage.op import DatabaseSessionManager
7 |
8 |
def test_datastore_config(db_session_manager):
    """Tests datastore config for PostgreSQL and SQLite."""
    engine = db_session_manager._engine
    datastore_type = db_session_manager._db_config.datastore_type
    # Pick expected engine name and URL scheme per configured backend.
    if datastore_type == DbType.POSTGRESQL:
        expected_name, expected_prefix = "postgresql", "postgresql+asyncpg://"
    elif datastore_type == DbType.SQLITE:
        expected_name, expected_prefix = "sqlite", "sqlite+aiosqlite://"
    else:
        raise AssertionError("Unsupported database type")  # noqa: TRY003
    assert engine.name == expected_name
    assert str(engine.url).startswith(expected_prefix)
20 |
21 |
def test_nonexistent_datastore_config():
    """Tests that datastore config errors on unsupported DB types."""
    # "oracle" is not a supported datastore type, so constructing the
    # session manager must raise.
    db_settings = DbSettings(
        datastore_type="oracle",
        datastore_config={
            "host": "0.0.0.0",  # noqa: S104
            "port": "5432",
            "database": "oracle",
            "user": "oracle",
            "password": "bogus",
        },
    )
    aana_settings.db_config = db_settings
    with pytest.raises(ValueError):
        DatabaseSessionManager(aana_settings)
39 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish Python Package
2 |
3 | on:
4 | release:
5 | types: [published]
6 | workflow_dispatch:
7 | inputs:
8 | publish_target:
9 | description: 'Select the target PyPI repository'
10 | required: true
11 | default: 'testpypi'
12 | type: choice
13 | options:
14 | - pypi
15 | - testpypi
16 |
17 | jobs:
18 | publish:
19 | runs-on: ubuntu-latest
20 |
21 | steps:
22 | - name: Checkout code
23 | uses: actions/checkout@v3
24 |
25 | - name: Set up Python 3.10
26 | uses: actions/setup-python@v5
27 | with:
28 | python-version: "3.10"
29 |
30 | - name: Bootstrap uv
31 | uses: astral-sh/setup-uv@v6
32 | with:
33 | version: "latest"
34 |
35 | - name: Install dependencies
36 | run: uv sync
37 |
38 | - name: Build the package
39 | run: uv build
40 |
41 | - name: Publish to PyPI
42 | if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_target == 'pypi')
43 | env:
44 | UV_PUBLISH_USERNAME: "__token__"
45 | UV_PUBLISH_PASSWORD: ${{ secrets.PYPI_API_TOKEN }}
46 | run: |
47 | uv publish
48 |
49 | - name: Publish to Test PyPI
50 | if: github.event_name == 'workflow_dispatch' && github.event.inputs.publish_target == 'testpypi'
51 | env:
52 | UV_PUBLISH_USERNAME: "__token__"
53 | UV_PUBLISH_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }}
54 | run: |
55 | uv publish --publish-url https://test.pypi.org/legacy/
56 |
--------------------------------------------------------------------------------
/aana/alembic/versions/b9860676dd49_set_server_default_for_task_completed_.py:
--------------------------------------------------------------------------------
1 | """Set server default for task.completed_at and task.assigned_at to none and add num_retries.
2 |
3 | Revision ID: b9860676dd49
4 | Revises: 5ad873484aa3
5 | Create Date: 2024-08-22 07:54:55.921710
6 |
7 | """
8 | from collections.abc import Sequence
9 |
10 | import sqlalchemy as sa
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision: str = "b9860676dd49"
15 | down_revision: str | None = "5ad873484aa3"
16 | branch_labels: str | Sequence[str] | None = None
17 | depends_on: str | Sequence[str] | None = None
18 |
19 |
def upgrade() -> None:
    """Upgrade database to this revision from previous.

    Removes the server-side defaults from `tasks.completed_at` and
    `tasks.assigned_at`, and adds the `tasks.num_retries` column.
    """
    with op.batch_alter_table("tasks", schema=None) as batch_op:
        # Clear the server defaults so new rows start with NULL timestamps.
        batch_op.alter_column(
            "completed_at",
            server_default=None,
        )
        batch_op.alter_column(
            "assigned_at",
            server_default=None,
        )
        # Existing rows are backfilled with 0 via the server default.
        batch_op.add_column(
            sa.Column(
                "num_retries",
                sa.Integer(),
                nullable=False,
                comment="Number of retries",
                server_default=sa.text("0"),
            )
        )

    # ### end Alembic commands ###
42 |
43 |
def downgrade() -> None:
    """Downgrade database from this revision to previous.

    Drops `tasks.num_retries`.

    NOTE(review): the server-default changes made in `upgrade()` are not
    reverted here — confirm the previous defaults are not needed on downgrade.
    """
    with op.batch_alter_table("tasks", schema=None) as batch_op:
        batch_op.drop_column("num_retries")

    # ### end Alembic commands ###
50 |
--------------------------------------------------------------------------------
/aana/tests/files/videos/ATTRIBUTION.md:
--------------------------------------------------------------------------------
1 | # Third-Party Media Attribution
2 |
3 | This folder contains third-party media used for testing. No endorsement implied.
4 |
5 | _Last updated: 2025-08-22_
6 |
7 | | File | Source | License | Author/Owner | Changes Made |
8 | |---|---|---|---|---|
9 | | `squirrel.mp4` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Downloaded for testing (may have been re-encoded; no substantive changes). |
10 | | `squirrel_no_audio.mp4` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Audio track removed for testing; re-encoded. |
11 | | `physicsworks.webm` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Transcoded to WEBM for testing; may be trimmed for length. |
12 | | `physicsworks_audio.webm` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Extracted audio only and transcoded to WEBM/Opus for testing; trimmed for length. |
13 |
14 | **Additional note (Commons VRT):** The Commons file page indicates a VRT permission ticket (#2011051010013473) confirming publication under the stated terms.
15 |
--------------------------------------------------------------------------------
/aana/integrations/haystack/remote_haystack_component.py:
--------------------------------------------------------------------------------
1 |
2 | from haystack import component
3 |
4 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
5 | from aana.utils.asyncio import run_async
6 |
7 |
@component
class RemoteHaystackComponent:
    """A component that connects to a remote Haystack component created by HaystackComponentDeployment.

    Attributes:
        deployment_name (str): The name of the deployment to use.
    """

    def __init__(
        self,
        deployment_name: str,
    ):
        """Initialize the component.

        Args:
            deployment_name (str): The name of the deployment to use.
        """
        self.deployment_name = deployment_name

    def warm_up(self):
        """Warm up the component.

        This will properly initialize the component by creating a handle to the deployment
        and setting the input and output types.
        """
        if hasattr(self, "handle"):
            # Already warmed up; warm_up must be idempotent.
            return
        self.handle = run_async(AanaDeploymentHandle.create(self.deployment_name))
        socket_map = run_async(self.handle.get_sockets())
        # Mirror the remote component's sockets onto this local proxy.
        input_types = {s.name: s.type for s in socket_map["input"].values()}
        output_types = {s.name: s.type for s in socket_map["output"].values()}
        component.set_input_types(self, **input_types)
        component.set_output_types(self, **output_types)

    def run(self, **data):
        """Run the component on the input data."""
        return run_async(self.handle.run(**data))
47 |
--------------------------------------------------------------------------------
/aana/storage/models/webhook.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from enum import Enum
3 |
4 | from sqlalchemy import JSON, UUID
5 | from sqlalchemy.dialects.postgresql import JSONB
6 | from sqlalchemy.orm import Mapped, mapped_column
7 |
8 | from aana.storage.models.base import BaseEntity, TimeStampEntity
9 |
10 |
class WebhookEventType(str, Enum):
    """Enum for webhook event types.

    Values are the dotted event names delivered to subscribed webhooks.
    """

    TASK_COMPLETED = "task.completed"
    TASK_FAILED = "task.failed"
    TASK_STARTED = "task.started"
17 |
18 |
class WebhookEntity(BaseEntity, TimeStampEntity):
    """Table for webhook items.

    Attributes:
        id (uuid.UUID): Webhook ID.
        user_id (str | None): The user ID associated with the webhook.
        url (str): The URL to which the webhook will send requests.
        events (list[str]): Events the webhook is subscribed to (empty = all events).
    """

    __tablename__ = "webhooks"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID, primary_key=True, default=uuid.uuid4, comment="Webhook ID"
    )
    user_id: Mapped[str | None] = mapped_column(
        nullable=True, index=True, comment="The user ID associated with the webhook"
    )
    url: Mapped[str] = mapped_column(
        nullable=False, comment="The URL to which the webhook will send requests"
    )
    events: Mapped[list[str]] = mapped_column(
        JSON().with_variant(JSONB, "postgresql"),
        nullable=False,
        comment="List of events the webhook is subscribed to. If the list is empty, the webhook is subscribed to all events.",
    )

    def __repr__(self) -> str:
        """String representation of the webhook."""
        # Fix: the previous implementation returned an empty f-string, which made
        # webhooks indistinguishable in logs; include the key fields instead.
        return (
            f"<WebhookEntity(id={self.id}, user_id={self.user_id}, "
            f"url={self.url}, events={self.events})>"
        )
46 |
--------------------------------------------------------------------------------
/aana/alembic/versions/d40eba8ebc4c_added_user_id_to_tasks.py:
--------------------------------------------------------------------------------
1 | """Added user_id to tasks.
2 |
3 | Revision ID: d40eba8ebc4c
4 | Revises: b9860676dd49
5 | Create Date: 2025-01-23 14:17:04.394863
6 |
7 | """
8 | from collections.abc import Sequence
9 |
10 | import sqlalchemy as sa
11 | from alembic import op
12 |
13 | # revision identifiers, used by Alembic.
14 | revision: str = "d40eba8ebc4c"
15 | down_revision: str | None = "b9860676dd49"
16 | branch_labels: str | Sequence[str] | None = None
17 | depends_on: str | Sequence[str] | None = None
18 |
19 |
def upgrade() -> None:
    """Upgrade database to this revision from previous.

    Adds the nullable `tasks.user_id` column and indexes on
    `tasks.status` and `tasks.user_id`.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("tasks", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "user_id",
                sa.String(),
                nullable=True,
                comment="ID of the user who launched the task",
            )
        )
        # Index on the pre-existing `status` column is introduced in this
        # revision alongside the new `user_id` index.
        batch_op.create_index(batch_op.f("ix_tasks_status"), ["status"], unique=False)
        batch_op.create_index(batch_op.f("ix_tasks_user_id"), ["user_id"], unique=False)

    # ### end Alembic commands ###
36 |
37 |
def downgrade() -> None:
    """Downgrade database from this revision to previous.

    Reverses `upgrade()`: drops both indexes, then the `user_id` column.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table("tasks", schema=None) as batch_op:
        batch_op.drop_index(batch_op.f("ix_tasks_user_id"))
        batch_op.drop_index(batch_op.f("ix_tasks_status"))
        batch_op.drop_column("user_id")

    # ### end Alembic commands ###
47 |
--------------------------------------------------------------------------------
/aana/storage/repository/video.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from sqlalchemy.ext.asyncio import AsyncSession
4 |
5 | from aana.core.models.media import MediaId
6 | from aana.core.models.video import Video, VideoMetadata
7 | from aana.storage.models import VideoEntity
8 | from aana.storage.repository.media import MediaRepository
9 |
10 | V = TypeVar("V", bound=VideoEntity)
11 |
12 |
class VideoRepository(MediaRepository[V]):
    """Repository for videos."""

    def __init__(self, session: AsyncSession, model_class: type[V] = VideoEntity):
        """Constructor."""
        super().__init__(session, model_class)

    async def save(self, video: Video) -> dict:
        """Saves a video to datastore.

        Args:
            video (Video): The video object.

        Returns:
            dict: The dictionary with media ID.
        """
        # NOTE(review): str(video.path) stores the literal "None" when the video
        # has no path — confirm whether NULL is intended instead.
        entity = VideoEntity(
            id=video.media_id,
            path=str(video.path),
            url=video.url,
            title=video.title,
            description=video.description,
        )
        await self.create(entity)
        return {"media_id": entity.id}

    async def get_metadata(self, media_id: MediaId) -> VideoMetadata:
        """Get the metadata of a video.

        Args:
            media_id (MediaId): The media ID.

        Returns:
            VideoMetadata: The video metadata.
        """
        video_entity: VideoEntity = await self.read(media_id)
        return VideoMetadata(
            title=video_entity.title,
            description=video_entity.description,
        )
53 |
--------------------------------------------------------------------------------
/aana/utils/asyncio.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import threading
3 | from collections.abc import Coroutine
4 | from typing import Any, TypeVar
5 |
6 | __all__ = ["run_async"]
7 |
T = TypeVar("T")


def run_async(coro: Coroutine[Any, Any, T]) -> T:
    """Run a coroutine in a thread if the current thread is running an event loop.

    Otherwise, run the coroutine in the current asyncio loop.

    Useful when you want to run an async function in a non-async context.

    From: https://stackoverflow.com/a/75094151

    Args:
        coro (Coroutine): The coroutine to run.

    Returns:
        T: The result of the coroutine.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    # No event loop running in this thread: drive the coroutine directly.
    if running_loop is None or not running_loop.is_running():
        return asyncio.run(coro)

    # A loop is already running here, so asyncio.run() would fail; run the
    # coroutine to completion on a fresh loop in a worker thread instead.
    outcome: dict[str, Any] = {}

    def _worker() -> None:
        try:
            outcome["result"] = asyncio.run(coro)
        except Exception as exc:  # re-raised in the caller's thread below
            outcome["error"] = exc

    worker = threading.Thread(target=_worker)
    worker.start()
    worker.join()
    if "error" in outcome:
        raise outcome["error"]
    return outcome["result"]
58 |
--------------------------------------------------------------------------------
/aana/tests/units/test_media_id.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import pytest
4 | from pydantic import BaseModel, ValidationError
5 |
6 | from aana.core.models.media import MediaId
7 |
8 |
class TestModel(BaseModel):
    """Test model for media id."""

    # Single field under test; MediaId validation runs at model construction.
    media_id: MediaId
13 |
14 |
def test_media_id_creation():
    """Test that a media id can be created."""
    assert TestModel(media_id="foo").media_id == "foo"

    # Validation only happens when the model is created
    # because MediaId is just an annotated string
    with pytest.raises(ValueError):
        TestModel().media_id  # noqa: B018

    # An empty string is not a valid media id.
    with pytest.raises(ValidationError):
        TestModel(media_id="").media_id  # noqa: B018

    # MediaId is a string with a maximum length of 36.
    with pytest.raises(ValidationError):
        TestModel(media_id="a" * 37).media_id  # noqa: B018
26 | TestModel(media_id="").media_id # noqa: B018
27 |
28 | # MediaId is a string with a maximum length of 36
29 | with pytest.raises(ValidationError):
30 | TestModel(media_id="a" * 37).media_id # noqa: B018
31 |
32 |
def test_valid_media_ids():
    """Test that valid media ids are accepted."""
    # Letters, digits, hyphens and underscores are all allowed.
    for candidate in ("abc123", "abc-123", "abc_123", "A1B2_C3", "123456", "a_b-c"):
        assert TestModel(media_id=candidate).media_id == candidate
39 |
40 |
def test_invalid_media_ids():
    """Test that invalid media ids are rejected."""
    # space, @, # and . are all outside the allowed character set
    for candidate in ("abc 123", "abc@123", "abc#123", "abc.123"):
        with pytest.raises(ValidationError):
            TestModel(media_id=candidate)
52 |
--------------------------------------------------------------------------------
/aana/storage/repository/transcript.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from sqlalchemy.ext.asyncio import AsyncSession
4 |
5 | from aana.core.models.asr import (
6 | AsrSegments,
7 | AsrTranscription,
8 | AsrTranscriptionInfo,
9 | )
10 | from aana.storage.models.transcript import TranscriptEntity
11 | from aana.storage.repository.base import BaseRepository
12 |
13 | T = TypeVar("T", bound=TranscriptEntity)
14 |
15 |
class TranscriptRepository(BaseRepository[T]):
    """Repository for Transcripts."""

    def __init__(self, session: AsyncSession, model_class: type[T] = TranscriptEntity):
        """Constructor."""
        super().__init__(session, model_class)

    async def save(
        self,
        model_name: str,
        transcription_info: AsrTranscriptionInfo,
        transcription: AsrTranscription,
        segments: AsrSegments,
    ) -> TranscriptEntity:
        """Save transcripts.

        Args:
            model_name (str): The name of the model used to generate the transcript.
            transcription_info (AsrTranscriptionInfo): The ASR transcription info.
            transcription (AsrTranscription): The ASR transcription.
            segments (AsrSegments): The ASR segments.

        Returns:
            TranscriptEntity: The transcript entity.
        """
        # Build the entity from the ASR output, then persist it immediately.
        entity = TranscriptEntity.from_asr_output(
            model_name=model_name,
            transcription=transcription,
            segments=segments,
            info=transcription_info,
        )
        self.session.add(entity)
        await self.session.commit()
        return entity
50 |
--------------------------------------------------------------------------------
/aana/exceptions/core.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 |
class BaseException(Exception):  # noqa: A001
    """Base class for SDK exceptions.

    Arbitrary keyword arguments are stored in `self.extra` and returned
    to the client as part of the error response.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialise Exception.

        Args:
            **kwargs (Any): extra data attached to the exception.
        """
        super().__init__()
        self.extra = kwargs

    def __str__(self) -> str:
        """Return a string representation of the exception.

        String is defined as follows:
        ```
        ClassName(extra_key1=extra_value1, extra_key2=extra_value2, ...)
        ```
        """
        class_name = self.__class__.__name__
        # join + generator replaces the previous manual append loop (same output).
        extra_str = ", ".join(f"{key}={value}" for key, value in self.extra.items())
        return f"{class_name}({extra_str})"

    def get_data(self) -> dict[str, Any]:
        """Get the data to be returned to the client.

        Returns:
            Dict[str, Any]: data to be returned to the client
        """
        # Return a copy so callers cannot mutate the exception's own state.
        data = self.extra.copy()
        return data

    def add_extra(self, data: dict[str, Any]) -> None:
        """Add extra data to the exception.

        This data will be returned to the user as part of the response.

        How to use: in the exception handler, add the extra data to the exception and raise it again.

        Example:
            ```
            try:
                ...
            except BaseException as e:
                e.add_extra({'extra_key': 'extra_value'})
                raise e
            ```

        Args:
            data (dict[str, Any]): dictionary containing the extra data
        """
        self.extra.update(data)
56 |
--------------------------------------------------------------------------------
/docs/pages/serve_config_files.md:
--------------------------------------------------------------------------------
1 | # Serve Config Files
2 |
3 | The [Serve Config Files](https://docs.ray.io/en/latest/serve/production-guide/config.html#serve-config-files) is the recommended way to deploy and update your applications in production. Aana SDK provides a way to build the Serve Config Files for the Aana applications.
4 |
5 | ## Building Serve Config Files
6 |
7 | To build the Serve config file, run the following command:
8 |
9 | ```bash
10 | aana build <module>:<app>
11 | ```
12 |
13 | For example:
14 |
15 | ```bash
16 | aana build aana_chat_with_video.app:aana_app
17 | ```
18 |
19 | The command will generate the Serve Config file and App Config file and save them in the project directory. You can then use these files to deploy the application using the Ray Serve CLI.
20 |
21 | ## Deploying with Serve Config Files
22 |
23 | When you are running the Aana application using the Serve config files, you need to run the migrations to create the database tables for the application. To run the migrations, use the following command:
24 |
25 | ```bash
26 | aana migrate <module>:<app>
27 | ```
28 |
29 | For example:
30 |
31 | ```bash
32 | aana migrate aana_chat_with_video.app:aana_app
33 | ```
34 |
35 | Before deploying the application, make sure you have the Ray cluster running. If you want to start a new Ray cluster on a single machine, you can use the following command:
36 |
37 | ```bash
38 | ray start --head
39 | ```
40 |
41 | For more info on how to start a Ray cluster, see the [Ray documentation](https://docs.ray.io/en/latest/ray-core/starting-ray.html#starting-ray-via-the-cli-ray-start).
42 |
43 | To deploy the application using the Serve config files, use [`serve deploy`](https://docs.ray.io/en/latest/serve/advanced-guides/deploy-vm.html#serve-in-production-deploying) command provided by Ray Serve. For example:
44 |
45 | ```bash
46 | serve deploy config.yaml
47 | ```
48 |
--------------------------------------------------------------------------------
/docs/reference/index.md:
--------------------------------------------------------------------------------
1 | # Reference Documentation (Code API)
2 |
3 | This section contains the reference documentation for the public API of the project. Quick links to the most important classes and functions are provided below.
4 |
5 | ## SDK
6 |
7 | [`aana.AanaSDK`](./sdk.md#aana.AanaSDK) - The main class for interacting with the Aana SDK. Use it to register endpoints and deployments and to start the server.
8 |
9 | ## Endpoint
10 |
11 | [`aana.api.Endpoint`](./endpoint.md#aana.api.Endpoint) - The base class for defining endpoints in the Aana SDK.
12 |
13 | ## Deployments
14 |
15 | [Deployments](./deployments.md) contains information about how to deploy models with a number of predefined deployments for such models as Whisper, LLMs, Hugging Face models, and more.
16 |
17 | ## Models
18 |
19 | - [Media Models](./models/media.md) - Models for working with media types like audio, video, and images.
20 | - [Automatic Speech Recognition (ASR) Models](./models/asr.md) - Models for working with automatic speech recognition (ASR) models.
21 | - [Caption Models](./models/captions.md) - Models for working with captions.
22 | - [Chat Models](./models/chat.md) - Models for working with chat models.
23 | - [Image Chat Models](./models/image_chat.md) - Models for working with visual-text content for visual-language models.
24 | - [Custom Config](./models/custom_config.md) - Custom Config model can be used to pass arbitrary configuration to the deployment.
25 | - [Sampling Models](./models/sampling.md) - Contains Sampling Parameters model which can be used to pass sampling parameters to the LLM models.
26 | - [Time Models](./models/time.md) - Contains time models like TimeInterval.
27 | - [Types Models](./models/types.md) - Contains types models like Dtype.
28 | - [Video Models](./models/video.md) - Models for working with video files.
29 | - [Whisper Models](./models/whisper.md) - Models for working with the Whisper speech-to-text model.
30 |
31 |
--------------------------------------------------------------------------------
/aana/tests/units/test_rate_limiter.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | import pytest
4 |
5 | from aana.api.event_handlers.event_manager import EventManager
6 | from aana.api.event_handlers.rate_limit_handler import RateLimitHandler
7 | from aana.exceptions.runtime import TooManyRequestsException
8 |
9 |
def test_rate_limiter_single():
    """Tests that the rate limiter raises if the rate limit is exceeded."""
    event_manager = EventManager()
    # Allow at most 1 event per 1.0-second window.
    rate_limiter = RateLimitHandler(1, 1.0)

    event_manager.register_handler_for_events(rate_limiter, ["foo"])
    event_manager.handle("foo")
    # The second event inside the same window must be rejected.
    with pytest.raises(TooManyRequestsException):
        event_manager.handle("foo")
19 |
20 |
def test_rate_limiter_multiple():
    """Tests that a single rate limiter shared by several events raises when their combined rate exceeds the limit."""
    event_manager = EventManager()
    rate_limiter = RateLimitHandler(1, 1.0)

    # One handler registered for two events: the limit is shared across both.
    event_manager.register_handler_for_events(rate_limiter, ["foo", "bar"])
    event_manager.handle("foo")
    with pytest.raises(TooManyRequestsException):
        event_manager.handle("bar")
30 |
31 |
def test_rate_limiter_noraise():
    """Tests that the rate limiter does not raise when the rate limit is never exceeded."""
    event_manager = EventManager()
    # Smallest possible value such that 1+x>1, should never run into rate limit
    rate_limiter = RateLimitHandler(1, sys.float_info.epsilon)
    event_manager.register_handler_for_events(rate_limiter, ["foo"])
    for _ in range(10):
        event_manager.handle("foo")
40 |
41 |
def test_event_manager_discriminates():
    """Tests that the event manager only invokes handlers registered for the triggered event."""
    event_manager = EventManager()
    # Handler is registered for "bar" only; "foo" must bypass it entirely,
    # even though the limiter would block any event it actually sees.
    rate_limiter = RateLimitHandler(1, sys.float_info.max)
    event_manager.register_handler_for_events(rate_limiter, ["bar"])
    # Should not raise
    event_manager.handle("foo")
49 |
--------------------------------------------------------------------------------
/aana/deployments/haystack_component_deployment.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from pydantic import BaseModel
4 | from ray import serve
5 |
6 | from aana.deployments.base_deployment import BaseDeployment, exception_handler
7 | from aana.utils.core import import_from_path
8 |
9 |
class HaystackComponentDeploymentConfig(BaseModel):
    """Configuration for the HaystackComponentDeployment.

    Attributes:
        component (str): The dotted import path of the Haystack component class to deploy
            (resolved with `import_from_path` at deployment time).
        params (dict): The parameters to pass to the component on initialization (model etc).
    """

    component: str
    params: dict[str, Any]
20 |
21 |
@serve.deployment
class HaystackComponentDeployment(BaseDeployment):
    """Deployment that wraps a single Haystack component."""

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration.

        The method is called when the deployment is created or updated.

        It instantiates the configured Haystack component and warms it up.

        The configuration should conform to the HaystackComponentDeploymentConfig schema.
        """
        parsed = HaystackComponentDeploymentConfig(**config)

        self.params = parsed.params
        self.component_path = parsed.component
        # Resolve the component class from its import path and instantiate it.
        component_cls = import_from_path(parsed.component)
        self.component = component_cls(**parsed.params)

        self.component.warm_up()

    @exception_handler
    async def run(self, **data: dict[str, Any]) -> dict[str, Any]:
        """Run the wrapped component on the given keyword arguments."""
        return self.component.run(**data)

    async def get_sockets(self):
        """Get the input and output sockets of the wrapped component."""
        return {
            "output": self.component.__haystack_output__._sockets_dict,
            "input": self.component.__haystack_input__._sockets_dict,
        }
54 |
--------------------------------------------------------------------------------
/docs/pages/docker.md:
--------------------------------------------------------------------------------
1 | # Run with Docker
2 |
3 | We provide a docker-compose configuration to run the application in a Docker container in [Aana App Template](https://github.com/mobiusml/aana_app_template/blob/main?tab=readme-ov-file#running-with-docker).
4 |
5 | Requirements:
6 |
7 | - Docker Engine >= 26.1.0
8 | - Docker Compose >= 1.29.2
9 | - NVIDIA Driver >= 525.60.13
10 |
11 | You can edit the [Dockerfile](https://github.com/mobiusml/aana_app_template/blob/main/Dockerfile) to assemble the image as you desire
12 | and the [docker-compose file](https://github.com/mobiusml/aana_app_template/blob/main/docker-compose.yaml) for container instances and their environment variables.
13 |
14 | To run the application, simply run the following command:
15 |
16 | ```bash
17 | docker-compose up
18 | ```
19 |
20 | The application will be accessible at `http://localhost:8000` on the host server.
21 |
22 |
23 | !!! warning
24 |
25 | If your applications requires GPU to run, you need to specify which GPU to use.
26 |
27 | The applications will detect the available GPU automatically but you need to make sure that `CUDA_VISIBLE_DEVICES` is set correctly.
28 |
29 | Sometimes `CUDA_VISIBLE_DEVICES` is set to an empty string and the application will not be able to detect the GPU. Use `unset CUDA_VISIBLE_DEVICES` to unset the variable.
30 |
31 | You can also set the `CUDA_VISIBLE_DEVICES` environment variable to the GPU index you want to use: `CUDA_VISIBLE_DEVICES=0 docker-compose up`.
32 |
33 |
34 | !!! Tip
35 |
36 | Some models use Flash Attention for better performance. You can set the build argument `INSTALL_FLASH_ATTENTION` to `true` to install Flash Attention.
37 |
38 | ```bash
39 | INSTALL_FLASH_ATTENTION=true docker-compose build
40 | ```
41 |
42 | After building the image, you can use `docker-compose up` command to run the application.
43 |
44 | You can also set the `INSTALL_FLASH_ATTENTION` environment variable to `true` in the `docker-compose.yaml` file.
45 |
46 |
--------------------------------------------------------------------------------
/aana/utils/lazy_import.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/deepset-ai/haystack
2 | #
3 | # SPDX-FileCopyrightText: 2022-present deepset GmbH
4 | #
5 | # SPDX-License-Identifier: Apache-2.0
6 |
7 | from types import TracebackType
8 |
9 | from lazy_imports.try_import import _DeferredImportExceptionContextManager
10 |
11 | DEFAULT_IMPORT_ERROR_MSG = "Try 'pip install {}'"
12 |
13 |
class LazyImport(_DeferredImportExceptionContextManager):
    """Wrapper on top of lazy_import's _DeferredImportExceptionContextManager.

    It adds the possibility to customize the error messages.
    """

    def __init__(self, message: str = DEFAULT_IMPORT_ERROR_MSG) -> None:
        """Initialize the context manager with a customizable error message template."""
        super().__init__()
        self.import_error_msg = message

    def __exit__(
        self,
        exc_type: type[Exception] | None = None,
        exc_value: Exception | None = None,
        traceback: TracebackType | None = None,
    ) -> bool | None:
        """Exit the context manager.

        Args:
            exc_type:
                Raised exception type. :obj:`None` if nothing is raised.
            exc_value:
                Raised exception object. :obj:`None` if nothing is raised.
            traceback:
                Associated traceback. :obj:`None` if nothing is raised.

        Returns:
            :obj:`None` if nothing is deferred, otherwise :obj:`True`.
            :obj:`True` will suppress any exceptions avoiding them from propagating.

        """
        # Only ImportError is deferred; anything else propagates normally.
        if not isinstance(exc_value, ImportError):
            return None
        name = exc_value.name
        message = (
            f"Failed to import '{name}'. {self.import_error_msg.format(name)}. "
            f"Original error: {exc_value}"
        )
        self._deferred = (exc_value, message)
        return True
54 |
--------------------------------------------------------------------------------
/aana/exceptions/db.py:
--------------------------------------------------------------------------------
1 | from aana.core.models.media import MediaId
2 | from aana.exceptions.core import BaseException
3 |
4 | __all__ = [
5 | "MediaIdAlreadyExistsException",
6 | "NotFoundException",
7 | ]
8 |
9 |
class DatabaseException(BaseException):
    """Raised for general database errors."""

    def __init__(self, message: str):
        """Constructor.

        Args:
            message: (str): the error message.
        """
        super().__init__(message=message)
        self.message = message

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, (self.message,))
23 |
24 |
class NotFoundException(BaseException):
    """Raised when an item searched by id is not found."""

    def __init__(self, table_name: str, id: int | MediaId):  # noqa: A002
        """Constructor.

        Args:
            table_name (str): the name of the table being queried.
            id (int | MediaId): the id of the item that was not found.
        """
        super().__init__(table=table_name, id=id)
        self.id = id
        self.table_name = table_name
        # HTTP status code for the API layer (404 Not Found).
        self.http_status_code = 404

    def __reduce__(self):
        """Support pickling by reconstructing from table name and id."""
        return (self.__class__, (self.table_name, self.id))
43 |
44 |
class MediaIdAlreadyExistsException(BaseException):
    """Raised when a media_id already exists."""

    def __init__(self, table_name: str, media_id: MediaId):
        """Constructor.

        Args:
            table_name (str): the name of the table being queried.
            media_id (MediaId): the media id that already exists.
        """
        super().__init__(table=table_name, id=media_id)
        self.media_id = media_id
        self.table_name = table_name

    def __reduce__(self):
        """Support pickling by reconstructing from table name and media id."""
        return (self.__class__, (self.table_name, self.media_id))
62 |
--------------------------------------------------------------------------------
/aana/core/libraries/audio.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 |
5 |
class AbstractAudioLibrary:
    """Abstract class for audio libraries."""

    @classmethod
    def read_file(cls, path: Path) -> np.ndarray:
        """Read an audio file from path and return as numpy audio array.

        Args:
            path (Path): The path of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def read_from_bytes(cls, content: bytes) -> np.ndarray:
        """Read audio content from bytes using the audio library.

        Args:
            content (bytes): The content of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def write_file(cls, path: Path, audio: np.ndarray):
        """Write a file using the audio library.

        Args:
            path (Path): The path of the file to write.
            audio (np.ndarray): The audio to write.
        """
        raise NotImplementedError

    @classmethod
    def write_to_bytes(cls, audio: np.ndarray) -> bytes:
        """Write bytes using the audio library.

        Args:
            audio (np.ndarray): The audio to write.

        Returns:
            bytes: The audio as bytes.
        """
        raise NotImplementedError

    @classmethod
    def write_audio_bytes(cls, path: Path, audio: bytes, sample_rate: int = 16000):
        """Write a wav file from normalized audio bytes.

        Args:
            path (Path): The path of the file to write.
            audio (bytes): The audio content as 16-bit PCM bytes.
            sample_rate (int): The sample rate of the audio. Defaults to 16000.
        """
        raise NotImplementedError
65 |
--------------------------------------------------------------------------------
/aana/tests/deployments/test_vad_deployment.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | from importlib import resources
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from aana.core.models.audio import Audio
8 | from aana.core.models.base import pydantic_to_dict
9 | from aana.core.models.vad import VadParams
10 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
11 | from aana.deployments.vad_deployment import VadConfig, VadDeployment
12 | from aana.tests.utils import verify_deployment_results
13 |
# Deployment configurations under test: a single CPU-only VAD deployment
# using the pyannote segmentation model.
deployments = [
    (
        "vad_deployment",
        VadDeployment.options(
            num_replicas=1,
            max_ongoing_requests=1000,
            ray_actor_options={"num_gpus": 0},
            user_config=VadConfig(
                model_id="pyannote/segmentation",
                onset=0.5,
                sample_rate=16000,
            ).model_dump(mode="json"),
        ),
    )
]
29 |
30 |
@pytest.mark.parametrize("setup_deployment", deployments, indirect=True)
class TestVadDeployment:
    """Test VAD deployment."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("audio_file", ["physicsworks.wav", "squirrel.wav"])
    async def test_vad(self, setup_deployment, audio_file):
        """Test VAD on sample audio files against stored expected outputs."""
        deployment_name, handle_name, _ = setup_deployment

        handle = await AanaDeploymentHandle.create(handle_name)

        # Expected output is stored as JSON, keyed by the audio file's stem.
        audio_file_name = Path(audio_file).stem
        expected_output_path = (
            resources.files("aana.tests.files.expected")
            / "vad"
            / f"{audio_file_name}.json"
        )

        path = resources.files("aana.tests.files.audios") / audio_file
        assert path.exists(), f"Audio not found: {path}"

        audio = Audio(path=path)

        # Run VAD preprocessing with default parameters and compare the
        # serialized output against the expected results file.
        output = await handle.asr_preprocess_vad(audio=audio, params=VadParams())
        output = pydantic_to_dict(output)
        verify_deployment_results(expected_output_path, output)
58 |
--------------------------------------------------------------------------------
/aana/alembic/versions/acb40dabc2c0_added_webhooks.py:
--------------------------------------------------------------------------------
1 | """Added webhooks.
2 |
3 | Revision ID: acb40dabc2c0
4 | Revises: d40eba8ebc4c
5 | Create Date: 2025-01-30 14:32:16.596842
6 |
7 | """
8 | from collections.abc import Sequence
9 |
10 | import sqlalchemy as sa
11 | from alembic import op
12 | from sqlalchemy.dialects import postgresql
13 |
14 | # revision identifiers, used by Alembic.
15 | revision: str = "acb40dabc2c0"
16 | down_revision: str | None = "d40eba8ebc4c"
17 | branch_labels: str | Sequence[str] | None = None
18 | depends_on: str | Sequence[str] | None = None
19 |
20 |
def upgrade() -> None:
    """Upgrade database to this revision from previous.

    Creates the `webhooks` table and a non-unique index on `user_id`.
    """
    # fmt: off
    op.create_table('webhooks',
    sa.Column('id', sa.UUID(), nullable=False, comment='Webhook ID'),
    sa.Column('user_id', sa.String(), nullable=True, comment='The user ID associated with the webhook'),
    sa.Column('url', sa.String(), nullable=False, comment='The URL to which the webhook will send requests'),
    sa.Column('events', sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), 'postgresql'), nullable=False, comment='List of events the webhook is subscribed to. If None, the webhook is subscribed to all events.'),
    sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is inserted'),
    sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is updated'),
    sa.PrimaryKeyConstraint('id', name=op.f('pk_webhooks'))
    )
    with op.batch_alter_table('webhooks', schema=None) as batch_op:
        batch_op.create_index(batch_op.f('ix_webhooks_user_id'), ['user_id'], unique=False)
    # fmt: on
36 |
37 |
def downgrade() -> None:
    """Downgrade database from this revision to previous.

    Drops the `webhooks` table and its `user_id` index.
    """
    with op.batch_alter_table("webhooks", schema=None) as batch_op:
        batch_op.drop_index(batch_op.f("ix_webhooks_user_id"))

    op.drop_table("webhooks")
44 |
--------------------------------------------------------------------------------
/aana/tests/units/test_deployment_retry.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import pytest
4 | from ray import serve
5 |
6 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
7 | from aana.deployments.base_deployment import BaseDeployment, exception_handler
8 |
9 |
@serve.deployment(health_check_period_s=1, health_check_timeout_s=30)
class Lowercase(BaseDeployment):
    """Ray deployment that returns the lowercase version of a text."""

    def __init__(self):
        """Initialize the deployment with a request counter."""
        super().__init__()
        self.num_requests = 0

    @exception_handler
    async def lower(self, text: str) -> dict:
        """Lowercase the text.

        Fails on purpose for two out of every three calls so that
        retry behavior can be exercised by the tests.

        Args:
            text (str): The text to lowercase

        Returns:
            dict: The lowercase text
        """
        self.num_requests += 1
        # Only every 3rd request should be successful.
        if self.num_requests % 3:
            raise Exception("Random exception")  # noqa: TRY002, TRY003
        return {"text": text.lower()}
35 |
36 |
# Deployments under test: a single flaky Lowercase deployment that only
# succeeds on every 3rd request.
deployments = [
    {
        "name": "lowercase_deployment",
        "instance": Lowercase,
    }
]
43 |
44 |
@pytest.mark.asyncio
async def test_deployment_retry(create_app):
    """Test that AanaDeploymentHandle retries failed requests when retry_exceptions is enabled."""
    create_app(deployments, [])

    text = "Hello, World!"

    # Get deployment handle without retries
    handle = await AanaDeploymentHandle.create(
        "lowercase_deployment", retry_exceptions=False
    )

    # test the lowercase deployment fails (first call always raises)
    with pytest.raises(Exception):  # noqa: B017
        await handle.lower(text=text)

    # Get deployment handle with retries
    handle = await AanaDeploymentHandle.create(
        "lowercase_deployment", retry_exceptions=True
    )

    # test the lowercase deployment works: retries continue until the
    # deployment's every-3rd-request success path is hit
    response = await handle.lower(text=text)
    assert response == {"text": text.lower()}
69 |
--------------------------------------------------------------------------------
/aana/utils/json.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any
3 |
4 | import orjson
5 | from pydantic import BaseModel
6 | from sqlalchemy import Engine
7 |
8 | __all__ = ["json_serializer_default", "jsonify"]
9 |
10 |
def json_serializer_default(obj: object) -> object:
    """Default function for json serializer to handle custom objects.

    If json serializer does not know how to serialize an object, it calls the default function.

    For example, if we see that the object is a pydantic model,
    we call the dict method to get the dictionary representation of the model
    that json serializer can deal with.

    If the object is not supported, we raise a TypeError.

    Args:
        obj (object): The object to serialize.

    Returns:
        object: The serializable object.

    Raises:
        TypeError: If the object is not a pydantic model, Path, or Media object.
    """
    if isinstance(obj, Engine):
        # SQLAlchemy engines are not serializable; drop them from the output.
        return None
    if isinstance(obj, BaseModel):
        return obj.model_dump()
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, type):
        # Serialize class objects as their string representation.
        # (Fix: previously returned str(type) — the builtin `type` itself —
        # for every class instead of the class being serialized.)
        return str(obj)
    if isinstance(obj, bytes):
        return obj.decode()

    # Imported lazily to avoid a circular import with aana.core.models.media.
    from aana.core.models.media import Media

    if isinstance(obj, Media):
        return str(obj)

    raise TypeError(type(obj))
48 |
49 |
def jsonify(
    data: Any,
    option: int | None = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS,
    as_bytes: bool = False,
) -> str | bytes:
    """Serialize content using orjson.

    Args:
        data (Any): The content to serialize.
        option (int | None): The option for orjson.dumps.
        as_bytes (bool): Return output as bytes instead of string

    Returns:
        bytes | str: The serialized data in the desired format.
    """
    serialized = orjson.dumps(data, option=option, default=json_serializer_default)
    if as_bytes:
        return serialized
    return serialized.decode()
67 |
--------------------------------------------------------------------------------
/aana/tests/units/test_app_deploy.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import pytest
4 | from ray import serve
5 |
6 | from aana.deployments.base_deployment import BaseDeployment
7 | from aana.exceptions.runtime import FailedDeployment, InsufficientResources
8 |
9 |
@serve.deployment
class DummyFailingDeployment(BaseDeployment):
    """Deployment stub whose initialization always fails."""

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration to the deployment and initialize it."""
        # Always fail so tests can observe deployment-failure handling.
        raise Exception("Dummy exception")  # noqa: TRY002, TRY003
17 |
18 |
@serve.deployment
class Lowercase(BaseDeployment):
    """Simple deployment that lowercases the text."""

    async def apply_config(self, config: dict[str, Any]):
        """Apply the configuration to the deployment and initialize it."""
        pass

    async def lower(self, text: str) -> dict:
        """Lowercase the text.

        Args:
            text (str): The text to lowercase

        Returns:
            dict: The lowercase text
        """
        # Fix: previously iterated the string and returned a list of
        # lowercased characters, contradicting the docstring/type hints and
        # the identical Lowercase deployment in test_deployment_retry.py.
        return {"text": text.lower()}
37 |
38 |
def test_failed_deployment(create_app):
    """Test that a failed deployment raises a FailedDeployment exception."""
    failing_deployments = [
        {
            "name": "deployment",
            "instance": DummyFailingDeployment.options(num_replicas=1, user_config={}),
        }
    ]
    # apply_config raises, so app creation must surface FailedDeployment.
    with pytest.raises(FailedDeployment):
        create_app(failing_deployments, [])
49 |
50 |
def test_insufficient_resources(create_app):
    """Test that deployment fails when there are insufficient resources to deploy."""
    # Request far more GPUs than any test machine provides.
    gpu_hungry_deployments = [
        {
            "name": "deployment",
            "instance": Lowercase.options(
                num_replicas=1,
                ray_actor_options={"num_gpus": 100},  # requires 100 GPUs
                user_config={},
            ),
        }
    ]
    with pytest.raises(InsufficientResources):
        create_app(gpu_hungry_deployments, [])
65 |
--------------------------------------------------------------------------------
/aana/tests/files/expected/sd/sd_sample.json:
--------------------------------------------------------------------------------
1 | {
2 | "segments": [
3 | {
4 | "time_interval": {
5 | "start": 6.730343750000001,
6 | "end": 7.16909375
7 | },
8 | "speaker": "SPEAKER_01"
9 | },
10 | {
11 | "time_interval": {
12 | "start": 7.16909375,
13 | "end": 7.185968750000001
14 | },
15 | "speaker": "SPEAKER_02"
16 | },
17 | {
18 | "time_interval": {
19 | "start": 7.59096875,
20 | "end": 8.316593750000003
21 | },
22 | "speaker": "SPEAKER_01"
23 | },
24 | {
25 | "time_interval": {
26 | "start": 8.316593750000003,
27 | "end": 9.919718750000001
28 | },
29 | "speaker": "SPEAKER_02"
30 | },
31 | {
32 | "time_interval": {
33 | "start": 9.919718750000001,
34 | "end": 10.93221875
35 | },
36 | "speaker": "SPEAKER_01"
37 | },
38 | {
39 | "time_interval": {
40 | "start": 10.93221875,
41 | "end": 14.745968750000003
42 | },
43 | "speaker": "SPEAKER_02"
44 | },
45 | {
46 | "time_interval": {
47 | "start": 14.745968750000003,
48 | "end": 17.88471875
49 | },
50 | "speaker": "SPEAKER_00"
51 | },
52 | {
53 | "time_interval": {
54 | "start": 18.01971875,
55 | "end": 21.512843750000002
56 | },
57 | "speaker": "SPEAKER_02"
58 | },
59 | {
60 | "time_interval": {
61 | "start": 21.512843750000002,
62 | "end": 28.49909375
63 | },
64 | "speaker": "SPEAKER_00"
65 | },
66 | {
67 | "time_interval": {
68 | "start": 28.49909375,
69 | "end": 29.96721875
70 | },
71 | "speaker": "SPEAKER_02"
72 | }
73 | ]
74 | }
--------------------------------------------------------------------------------
/aana/tests/units/test_merge_options.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import pytest
4 | from pydantic import BaseModel
5 |
6 | from aana.core.models.base import merged_options
7 |
8 |
class MyOptions(BaseModel):
    """Test option class used to exercise merged_options."""

    field1: str  # required, no default
    field2: int | None = None  # optional; None falls back to the default options' value
    field3: bool  # required boolean
    field4: str = "default"  # has a default and may be left unset
16 |
17 |
def test_merged_options_same_type():
    """Test merged_options with options of the same type as default_options."""
    default = MyOptions(field1="default1", field2=2, field3=True)
    to_merge = MyOptions(field1="merge1", field2=None, field3=False)
    merged = merged_options(default, to_merge)

    assert merged.field1 == "merge1"
    assert (
        merged.field2 == 2
    )  # Should retain value from default_options as it's None in options
    # `is False` instead of `== False` (PEP 8 / ruff E712).
    assert merged.field3 is False
29 |
30 |
def test_merged_options_none():
    """Test merged_options with options=None."""
    defaults = MyOptions(field1="default1", field2=2, field3=True)
    result = merged_options(defaults, None)

    # Merging with None must be a no-op.
    assert result.model_dump() == defaults.model_dump()
37 |
38 |
def test_merged_options_type_mismatch():
    """Test merged_options with options of a different type from default_options."""

    class AnotherOptions(BaseModel):
        another_field: str

    defaults = MyOptions(field1="default1", field2=2, field3=True)
    other = AnotherOptions(another_field="test")

    # Mixing unrelated option types must be rejected.
    with pytest.raises(ValueError):
        merged_options(defaults, other)
50 |
51 |
def test_merged_options_unset():
    """Test merged_options with unset fields."""
    default = MyOptions(field1="default1", field2=2, field3=True, field4="new_default")
    to_merge = MyOptions(field1="merge1", field3=False)  # field4 is not set
    merged = merged_options(default, to_merge)

    assert merged.field1 == "merge1"
    assert merged.field2 == 2
    # `is False` instead of `== False` (PEP 8 / ruff E712).
    assert merged.field3 is False
    assert (
        merged.field4 == "new_default"
    )  # Should retain value from default_options as it's not set in options
64 |
--------------------------------------------------------------------------------
/aana/tests/units/test_whisper_params.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | import pytest
3 |
4 | from aana.core.models.whisper import WhisperParams
5 |
6 |
def test_whisper_params_default():
    """Test the default values of WhisperParams object.

    Keeping the default parameters of a function or object is important
    in case other code relies on them.

    If you need to change the default parameters, think twice before doing so.
    """
    params = WhisperParams()

    # Pin the documented defaults so accidental changes are caught.
    assert params.language is None
    assert params.beam_size == 5
    assert params.best_of == 5
    assert params.temperature == (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)
    assert params.word_timestamps is False
    assert params.vad_filter is True
22 | assert params.vad_filter is True
23 |
24 |
@pytest.mark.parametrize(
    "language, beam_size, best_of, temperature, word_timestamps, vad_filter",
    [
        ("en", 5, 5, 0.5, True, True),
        ("fr", 3, 3, 0.2, False, False),
        (None, 1, 1, [0.8, 0.9], True, False),
    ],
)
def test_whisper_params(
    language, beam_size, best_of, temperature, word_timestamps, vad_filter
):
    """Test function for the WhisperParams class with valid parameters.

    All constructor arguments should round-trip to the matching attributes.
    """
    params = WhisperParams(
        language=language,
        beam_size=beam_size,
        best_of=best_of,
        temperature=temperature,
        word_timestamps=word_timestamps,
        vad_filter=vad_filter,
    )

    assert params.language == language
    assert params.beam_size == beam_size
    assert params.best_of == best_of
    assert params.temperature == temperature
    assert params.word_timestamps == word_timestamps
    assert params.vad_filter == vad_filter
52 |
53 |
@pytest.mark.parametrize(
    "temperature",
    [
        [-1.0, 0.5, 1.5],
        [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0],
        "invalid_temperature",
        2,
    ],
)
def test_whisper_params_invalid_temperature(temperature):
    """Ensure out-of-range or wrongly-typed temperatures raise a ValueError."""
    # Cases cover: out-of-range floats in a list, too many values,
    # a non-numeric string, and a bare int above the valid range.
    with pytest.raises(ValueError):
        WhisperParams(temperature=temperature)
67 |
--------------------------------------------------------------------------------
/aana/tests/db/datastore/test_video_repo.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import uuid
4 | from importlib import resources
5 |
6 | import pytest
7 |
8 | from aana.core.models.video import Video, VideoMetadata
9 | from aana.exceptions.db import MediaIdAlreadyExistsException, NotFoundException
10 | from aana.storage.repository.video import VideoRepository
11 |
12 |
@pytest.fixture(scope="function")
def dummy_video():
    """Creates a dummy video for testing."""
    # Fresh random media_id per test so runs never collide in the datastore.
    return Video(
        path=resources.files("aana.tests.files.videos") / "squirrel.mp4",
        media_id=str(uuid.uuid4()),
    )
20 |
21 |
@pytest.mark.asyncio
async def test_save_video(db_session_manager, dummy_video):
    """Tests saving a video."""
    async with db_session_manager.session() as session:
        repo = VideoRepository(session)
        await repo.save(dummy_video)

        # The saved entity must be readable back by its media_id.
        saved = await repo.read(dummy_video.media_id)
        assert saved
        assert saved.id == dummy_video.media_id

        # Saving the same media_id twice must be rejected.
        with pytest.raises(MediaIdAlreadyExistsException):
            await repo.save(dummy_video)

        # After deletion the video must no longer be readable.
        await repo.delete(dummy_video.media_id)
        with pytest.raises(NotFoundException):
            await repo.read(dummy_video.media_id)
40 |
41 |
@pytest.mark.asyncio
async def test_get_metadata(db_session_manager, dummy_video):
    """Tests getting video metadata."""
    async with db_session_manager.session() as session:
        video_repo = VideoRepository(session)
        await video_repo.save(dummy_video)

        metadata = await video_repo.get_metadata(dummy_video.media_id)
        assert isinstance(metadata, VideoMetadata)
        assert metadata.title == dummy_video.title
        assert metadata.description == dummy_video.description
        # Identity check for None (ruff E711) instead of `== None`.
        assert metadata.duration is None

        await video_repo.delete(dummy_video.media_id)
        with pytest.raises(NotFoundException):
            await video_repo.get_metadata(dummy_video.media_id)
58 |
--------------------------------------------------------------------------------
/aana/tests/deployments/test_pyannote_speaker_diarization_deployment.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | from importlib import resources
3 | from pathlib import Path
4 |
5 | import pytest
6 |
7 | from aana.core.models.audio import Audio
8 | from aana.core.models.base import pydantic_to_dict
9 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
10 | from aana.deployments.pyannote_speaker_diarization_deployment import (
11 | PyannoteSpeakerDiarizationConfig,
12 | PyannoteSpeakerDiarizationDeployment,
13 | )
14 | from aana.tests.utils import verify_deployment_results
15 |
# Deployment configurations consumed by the `setup_deployment` fixture:
# (handle name, configured Ray Serve deployment) pairs.
deployments = [
    (
        "sd_deployment",
        PyannoteSpeakerDiarizationDeployment.options(
            num_replicas=1,
            max_ongoing_requests=1000,
            ray_actor_options={"num_gpus": 0.05},
            user_config=PyannoteSpeakerDiarizationConfig(
                model_id=("pyannote/speaker-diarization-3.1"),
                sample_rate=16000,
            ).model_dump(mode="json"),
        ),
    )
]
30 |
31 |
@pytest.mark.skip(reason="The test is temporary disabled")
@pytest.mark.parametrize("setup_deployment", deployments, indirect=True)
class TestPyannoteSpeakerDiarizationDeployment:
    """Test pyannote Speaker Diarization deployment."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("audio_file", ["sd_sample.wav"])
    async def test_speaker_diarization(self, setup_deployment, audio_file):
        """Diarize a sample audio and compare against the stored expected output."""
        _deployment_name, handle_name, _ = setup_deployment
        handle = await AanaDeploymentHandle.create(handle_name)

        # Expected results live alongside the test data, keyed by audio stem.
        expected_output_path = (
            resources.files("aana.tests.files.expected")
            / "sd"
            / f"{Path(audio_file).stem}.json"
        )

        path = resources.files("aana.tests.files.audios") / audio_file
        assert path.exists(), f"Audio not found: {path}"

        result = await handle.diarize(audio=Audio(path=path))
        verify_deployment_results(expected_output_path, pydantic_to_dict(result))
61 |
--------------------------------------------------------------------------------
/aana/tests/deployments/test_hf_blip2_deployment.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | from importlib import resources
3 |
4 | import pytest
5 |
6 | from aana.core.models.image import Image
7 | from aana.core.models.types import Dtype
8 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
9 | from aana.deployments.hf_blip2_deployment import HFBlip2Config, HFBlip2Deployment
10 | from aana.tests.utils import verify_deployment_results
11 |
# Deployment configurations consumed by the `setup_deployment` fixture:
# (handle name, configured Ray Serve deployment) pairs.
deployments = [
    (
        "blip2_deployment",
        HFBlip2Deployment.options(
            num_replicas=1,
            max_ongoing_requests=1000,
            ray_actor_options={"num_gpus": 0.25},
            user_config=HFBlip2Config(
                model="Salesforce/blip2-opt-2.7b",
                dtype=Dtype.FLOAT16,
                batch_size=2,
                num_processing_threads=2,
            ).model_dump(mode="json"),
        ),
    )
]
28 |
29 |
@pytest.mark.parametrize("setup_deployment", deployments, indirect=True)
class TestHFBlip2Deployment:
    """Test HuggingFace BLIP2 deployment."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("image_name", ["Starry_Night.jpeg"])
    async def test_methods(self, setup_deployment, image_name):
        """Exercise single and batched caption generation against expected output."""
        deployment_name, handle_name, _ = setup_deployment
        handle = await AanaDeploymentHandle.create(handle_name)

        expected_output_path = (
            resources.files("aana.tests.files.expected")
            / "hf_blip2"
            / f"{deployment_name}_{image_name}.json"
        )

        image = Image(
            path=resources.files("aana.tests.files.images") / image_name,
            save_on_disk=False,
            media_id=image_name,
        )

        # Single-image generation must match the stored expected caption.
        single_output = await handle.generate(image=image)
        verify_deployment_results(expected_output_path, single_output["caption"])

        # Batched generation must return one caption per input image,
        # each matching the same expected result.
        batch = [image] * 8
        batch_output = await handle.generate_batch(images=batch)
        captions = batch_output["captions"]
        assert len(captions) == 8
        for caption in captions:
            verify_deployment_results(expected_output_path, caption)
63 |
--------------------------------------------------------------------------------
/aana/api/event_handlers/rate_limit_handler.py:
--------------------------------------------------------------------------------
from collections import deque
from time import monotonic

from typing_extensions import override

from aana.api.event_handlers.event_handler import EventHandler
from aana.exceptions.runtime import TooManyRequestsException
7 |
8 |
class RateLimitHandler(EventHandler):
    """Event handler that raises TooManyRequestsException if the rate limit is exceeded.

    Implements a sliding-window limiter: timestamps of accepted calls are kept
    for `interval` seconds; a new call is rejected once `capacity` of them are
    still inside the window.

    Attributes:
        capacity (int): number of resources (requests) per interval
        interval (float): the interval for the limit in seconds
    """

    capacity: int
    interval: float

    # Monotonic timestamps of accepted calls, oldest first.
    _calls: deque

    def __init__(self, capacity: int, interval: float):
        """Constructor.

        Arguments:
            capacity (int): number of resources (requests) per interval
            interval (float): the interval for the limit in seconds
        """
        self.capacity = capacity
        self.interval = interval
        # deque gives O(1) popleft; list.pop(0) shifts the whole list (O(n)).
        self._calls = deque()

    def _clear_expired(self, expired: float):
        """Removes expired items from the record of accepted calls.

        Arguments:
            expired: timestamp before which to clear, as output from time.monotonic()
        """
        while self._calls and self._calls[0] < expired:
            self._calls.popleft()

    def _acquire(self):
        """Checks if we can acquire (process) a resource.

        Raises:
            TooManyRequestsException: if the rate limit has been reached
        """
        now = monotonic()
        self._clear_expired(now - self.interval)
        if len(self._calls) < self.capacity:
            self._calls.append(now)
        else:
            raise TooManyRequestsException(self.capacity, self.interval)

    @override
    def handle(self, event_name: str, *args, **kwargs):
        """Handle the event by checking against rate limiting parameters.

        Arguments:
            event_name (str): the name of the event to handle
            *args (list): args for the event
            **kwargs (dict): keyword args for the event

        Raises:
            TooManyRequestsException: if the rate limit has been reached
        """
        # if the endpoint execution is deferred, we don't want to rate limit it
        defer = kwargs.get("defer", False)
        if not defer:
            self._acquire()
67 |
--------------------------------------------------------------------------------
/docs/pages/model_hub/hf_pipeline.md:
--------------------------------------------------------------------------------
1 | # Hugging Face Pipeline Models
2 |
3 | [Hugging Face Pipeline deployment](./../../reference/deployments.md#aana.deployments.hf_pipeline_deployment.HfPipelineDeployment) allows you to serve *almost* any model from the [Hugging Face Hub](https://huggingface.co/models). It is a wrapper for [Hugging Face Pipelines](https://huggingface.co/transformers/main_classes/pipelines.html) so you can deploy and scale *almost* any model from the Hugging Face Hub with a few lines of code.
4 |
5 | !!! Tip
6 | To use HF Pipeline deployment, install required libraries with `pip install transformers` or include extra dependencies using `pip install aana[transformers]`.
7 |
8 |
9 | [HfPipelineConfig](./../../reference/deployments.md#aana.deployments.hf_pipeline_deployment.HfPipelineConfig) is used to configure the Hugging Face Pipeline deployment.
10 |
11 | ::: aana.deployments.hf_pipeline_deployment.HfPipelineConfig
12 | options:
13 | show_bases: false
14 | heading_level: 4
15 | show_docstring_description: false
16 | docstring_section_style: list
17 |
18 | ### Example Configurations
19 |
20 | As an example, let's see how to configure the Hugging Face Pipeline deployment to serve [Salesforce BLIP-2 OPT-2.7b model](https://huggingface.co/Salesforce/blip2-opt-2.7b).
21 |
22 | !!! example "BLIP-2 OPT-2.7b"
23 |
24 | ```python
25 | from transformers import BitsAndBytesConfig
26 | from aana.deployments.hf_pipeline_deployment import HfPipelineConfig, HfPipelineDeployment
27 |
28 | HfPipelineDeployment.options(
29 | num_replicas=1,
30 | ray_actor_options={"num_gpus": 0.25},
31 | user_config=HfPipelineConfig(
32 | model_id="Salesforce/blip2-opt-2.7b",
33 | task="image-to-text",
34 | model_kwargs={
35 | "quantization_config": BitsAndBytesConfig(load_in_8bit=False, load_in_4bit=True),
36 | },
37 | ).model_dump(mode="json"),
38 | )
39 | ```
40 |
41 | Model ID is the Hugging Face model ID. `task` is one of the [Hugging Face Pipelines tasks](https://huggingface.co/transformers/main_classes/pipelines.html) that the model can perform. We deploy the model with 4-bit quantization by setting `quantization_config` in the `model_kwargs` dictionary. You can pass extra arguments to the model in the `model_kwargs` dictionary.
42 |
--------------------------------------------------------------------------------
/aana/tests/db/datastore/test_transcript_repo.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import pytest
4 |
5 | from aana.core.models.asr import AsrSegment, AsrTranscription, AsrTranscriptionInfo
6 | from aana.core.models.time import TimeInterval
7 | from aana.exceptions.db import NotFoundException
8 | from aana.storage.models.transcript import TranscriptEntity
9 | from aana.storage.repository.transcript import TranscriptRepository
10 |
# NOTE(review): this module-level entity appears unused — the name
# `transcript_entity` is shadowed by a local variable inside
# test_save_transcript. Presumably a leftover; confirm before removing.
transcript_entity = TranscriptEntity.from_asr_output(
    model_name="whisper",
    transcription=AsrTranscription(text="This is a transcript"),
    segments=[],
    info=AsrTranscriptionInfo(),
)
17 |
18 |
@pytest.fixture(scope="function")
def dummy_transcript():
    """Creates a dummy transcript for testing.

    Returns:
        tuple: (AsrTranscription, list[AsrSegment], AsrTranscriptionInfo)
    """
    return (
        AsrTranscription(text="This is a transcript"),
        [
            AsrSegment(
                text="This is a segment",
                time_interval=TimeInterval(start=0, end=1),
            )
        ],
        AsrTranscriptionInfo(language="en", language_confidence=0.9),
    )
28 |
29 |
@pytest.mark.asyncio
async def test_save_transcript(db_session_manager, dummy_transcript):
    """Tests saving a transcript."""
    transcript, segments, info = dummy_transcript
    model_name = "whisper"

    async with db_session_manager.session() as session:
        repo = TranscriptRepository(session)
        saved_entity = await repo.save(
            model_name=model_name,
            transcription_info=info,
            transcription=transcript,
            segments=segments,
        )
        transcript_id = saved_entity.id

        # Re-read the entity and check every persisted field round-trips.
        loaded = await repo.read(transcript_id)
        assert loaded
        assert loaded.id == transcript_id
        assert loaded.model == model_name
        assert loaded.transcript == transcript.text
        assert len(loaded.segments) == len(segments)
        assert loaded.language == info.language
        assert loaded.language_confidence == info.language_confidence

        # Deleted transcripts must no longer be readable.
        await repo.delete(transcript_id)
        with pytest.raises(NotFoundException):
            await repo.read(transcript_id)
59 |
--------------------------------------------------------------------------------
/aana/core/libraries/image.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import numpy as np
4 |
5 |
class AbstractImageLibrary:
    """Interface that concrete image-library backends must implement."""

    @classmethod
    def read_file(cls, path: Path) -> np.ndarray:
        """Load an image file from disk.

        Args:
            path (Path): The path of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def read_from_bytes(cls, content: bytes) -> np.ndarray:
        """Decode an image from an in-memory byte string.

        Args:
            content (bytes): The content of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def write_file(
        cls,
        path: Path,
        img: np.ndarray,
        format: str = "bmp",  # noqa: A002
        quality: int = 95,
        compression: int = 3,
    ):
        """Write an image to disk in BMP, PNG or JPEG format.

        Args:
            path (Path): Base path (extension will be enforced).
            img (np.ndarray): RGB image array.
            format (str): One of "bmp", "png", "jpeg" (or "jpg").
            quality (int): JPEG quality (0-100; higher is better). Only used if format is JPEG.
            compression (int): PNG compression level (0-9; higher is smaller). Only used if format is PNG.
        """
        raise NotImplementedError

    @classmethod
    def write_to_bytes(
        cls,
        img: np.ndarray,
        format: str = "jpeg",  # noqa: A002
        quality: int = 95,
        compression: int = 3,
    ) -> bytes:
        """Encode an image into bytes in the requested format.

        Args:
            img (np.ndarray): The image to write.
            format (str): The format to use for encoding. Default is "jpeg".
            quality (int): The quality to use for encoding. Default is 95.
            compression (int): The compression level to use for encoding. Default is 3.

        Returns:
            bytes: The image as bytes.
        """
        raise NotImplementedError
73 |
--------------------------------------------------------------------------------
/aana/tests/units/test_sampling_params.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | import pytest
3 | from pydantic import ValidationError
4 |
5 | from aana.core.models.sampling import SamplingParams
6 |
7 |
def test_valid_sampling_params():
    """Test valid sampling parameters."""
    expected = {
        "temperature": 0.5,
        "top_p": 0.9,
        "top_k": 10,
        "max_tokens": 50,
        "repetition_penalty": 1.5,
    }
    params = SamplingParams(**expected)
    for field_name, value in expected.items():
        assert getattr(params, field_name) == value

    # Defaults when no arguments are passed.
    defaults = SamplingParams()
    assert defaults.temperature == 1.0
    assert defaults.top_p == 1.0
    assert defaults.top_k is None
    assert defaults.max_tokens is None
    assert defaults.repetition_penalty == 1.0
26 |
27 |
def test_invalid_temperature():
    """Test invalid temperature values."""
    # Negative temperatures are out of range and must be rejected.
    with pytest.raises(ValueError):
        SamplingParams(temperature=-1.0)
32 |
33 |
def test_invalid_top_p():
    """Test invalid top_p values."""
    # Both ends outside the valid (0, 1] range must be rejected.
    for bad_top_p in (0.0, 1.1):
        with pytest.raises(ValueError):
            SamplingParams(top_p=bad_top_p)
40 |
41 |
def test_invalid_top_k():
    """Test invalid top_k values."""
    # top_k must be a positive integer; zero and negatives are rejected.
    for bad_top_k in (0, -2):
        with pytest.raises(ValueError):
            SamplingParams(top_k=bad_top_k)
48 |
49 |
def test_invalid_max_tokens():
    """Test invalid max_tokens values."""
    # max_tokens must be at least 1.
    with pytest.raises(ValueError):
        SamplingParams(max_tokens=0)
54 |
55 |
def test_kwargs():
    """Test extra keyword arguments."""
    extra = {"presence_penalty": 2.0, "frequency_penalty": 1.0}
    params = SamplingParams(temperature=0.5, kwargs=extra)

    assert params.kwargs == extra
    assert params.temperature == 0.5
    # All remaining fields keep their defaults.
    assert params.top_p == 1.0
    assert params.top_k is None
    assert params.max_tokens is None
    assert params.repetition_penalty == 1.0
67 |
68 |
def test_disallowed_extra_fields():
    """Test that extra fields are not allowed."""
    # Unknown field names must be rejected by the model configuration.
    with pytest.raises(ValidationError):
        SamplingParams(temperature=0.5, extra_field="extra_value")
73 |
--------------------------------------------------------------------------------
/aana/storage/repository/caption.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from sqlalchemy.ext.asyncio import AsyncSession
4 |
5 | from aana.core.models.captions import Caption, CaptionsList
6 | from aana.storage.models.caption import CaptionEntity
7 | from aana.storage.repository.base import BaseRepository
8 |
9 | T = TypeVar("T", bound=CaptionEntity)
10 |
11 |
class CaptionRepository(BaseRepository[T]):
    """Repository for Captions."""

    def __init__(self, session: AsyncSession, model_class: type[T] = CaptionEntity):
        """Constructor."""
        super().__init__(session, model_class)

    async def save(
        self, model_name: str, caption: Caption, timestamp: float, frame_id: int
    ):
        """Save a caption.

        Args:
            model_name (str): The name of the model used to generate the caption.
            caption (Caption): The caption.
            timestamp (float): The timestamp.
            frame_id (int): The frame ID.
        """
        caption_entity = CaptionEntity.from_caption_output(
            model_name=model_name,
            frame_id=frame_id,
            timestamp=timestamp,
            caption=caption,
        )
        await self.create(caption_entity)
        return caption_entity

    async def save_all(
        self,
        model_name: str,
        captions: CaptionsList,
        timestamps: list[float],
        frame_ids: list[int],
    ) -> list[CaptionEntity]:
        """Save captions.

        Args:
            model_name (str): The name of the model used to generate the captions.
            captions (CaptionsList): The captions.
            timestamps (list[float]): The timestamps.
            frame_ids (list[int]): The frame IDs.

        Returns:
            list[CaptionEntity]: The list of caption entities.
        """
        # One entity per (caption, timestamp, frame) triple; strict zip
        # guards against mismatched input lengths.
        entities = [
            CaptionEntity.from_caption_output(
                model_name=model_name,
                frame_id=fid,
                timestamp=ts,
                caption=cap,
            )
            for cap, ts, fid in zip(captions, timestamps, frame_ids, strict=True)
        ]
        return await self.create_multiple(entities)
70 |
--------------------------------------------------------------------------------
/aana/tests/units/test_deployment_restart.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | import asyncio
3 |
4 | import pytest
5 | from ray import serve
6 |
7 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
8 | from aana.deployments.base_deployment import BaseDeployment, exception_handler
9 | from aana.exceptions.runtime import InferenceException
10 |
11 |
@serve.deployment(health_check_period_s=1, health_check_timeout_s=30)
class Lowercase(BaseDeployment):
    """Ray deployment that returns the lowercase version of a text."""

    def __init__(self):
        """Initialize the deployment."""
        super().__init__()
        self.active = True

    @exception_handler
    async def lower(self, text: str) -> dict:
        """Lowercase the text.

        Args:
            text (str): The text to lowercase

        Returns:
            dict: The lowercase text
        """
        # The magic input permanently deactivates the deployment so the
        # restart behavior can be exercised by the test.
        should_fail = text == "inference_exception" or not self.active
        if should_fail:
            self.active = False
            raise InferenceException(model_name="lowercase_deployment")
        return {"text": text.lower()}
36 |
37 |
# Deployment definitions passed to `create_app` in the test below.
deployments = [
    {
        "name": "lowercase_deployment",
        "instance": Lowercase,
    }
]
44 |
45 |
@pytest.mark.asyncio
async def test_deployment_restart(create_app):
    """Test that a failed deployment restarts and recovers.

    Triggers an InferenceException that leaves the deployment permanently
    failing, then waits for Ray Serve's health check to restart it.
    """
    create_app(deployments, [])

    handle = await AanaDeploymentHandle.create("lowercase_deployment")

    text = "Hello, World!"

    # test the lowercase deployment works
    response = await handle.lower(text=text)
    assert response == {"text": text.lower()}

    # Cause an InferenceException in the deployment and make it inactive.
    # After the deployment is inactive, the deployment should always raise an InferenceException.
    with pytest.raises(InferenceException):
        await handle.lower(text="inference_exception")

    # The deployment should restart and work again, wait for around 60 seconds for the deployment to restart.
    for _ in range(60):
        await asyncio.sleep(1)
        try:
            response = await handle.lower(text=text)
            if response == {"text": text.lower()}:
                break
        except Exception:  # noqa: S110, BLE001
            # `except Exception` (not bare `except:`) keeps the wait loop
            # interruptible by KeyboardInterrupt/SystemExit/CancelledError.
            pass

    assert response == {"text": text.lower()}
75 |
--------------------------------------------------------------------------------
/aana/configs/db.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from os import PathLike
3 |
4 | from pydantic_settings import BaseSettings
5 | from typing_extensions import TypedDict
6 |
7 |
class DbType(str, Enum):
    """Supported engine types for the relational database.

    Attributes:
        POSTGRESQL: PostgreSQL database.
        SQLITE: SQLite database.
    """

    POSTGRESQL = "postgresql"
    SQLITE = "sqlite"
18 |
19 |
20 | class SQLiteConfig(TypedDict):
21 | """Config values for SQLite.
22 |
23 | Attributes:
24 | path (PathLike): The path to the SQLite database file.
25 | """
26 |
27 | path: PathLike | str
28 |
29 |
class PostgreSQLConfig(TypedDict):
    """Configuration values for a PostgreSQL datastore.

    Attributes:
        host (str): The host of the PostgreSQL server.
        port (int): The port of the PostgreSQL server.
        user (str): The user to connect to the PostgreSQL server.
        password (str): The password to connect to the PostgreSQL server.
        database (str): The database name.
    """

    host: str
    port: int
    user: str
    password: str
    database: str
46 |
47 |
class DbSettings(BaseSettings):
    """Database configuration.

    Attributes:
        datastore_type (DbType | str): The type of the datastore. Default is DbType.SQLITE.
        datastore_config (SQLiteConfig | PostgreSQLConfig): The configuration for the datastore.
            Default is SQLiteConfig(path="/var/lib/aana_data").
        pool_size (int): The number of connections to keep in the pool. Default is 5.
        max_overflow (int): The number of connections that can be created when the pool is exhausted.
            Default is 10.
        pool_recycle (int): The number of seconds a connection can be idle in the pool before it is invalidated.
            Default is 3600.
        query_timeout (int): The timeout for database queries in seconds. Default is 30. Set to 0 for no timeout.
        connection_timeout (int): The timeout for database connections in seconds. Default is 10.
    """

    datastore_type: DbType | str = DbType.SQLITE
    datastore_config: SQLiteConfig | PostgreSQLConfig = SQLiteConfig(
        path="/var/lib/aana_data"
    )
    pool_size: int = 5
    max_overflow: int = 10
    pool_recycle: int = 3600  # seconds
    query_timeout: int = 30  # seconds; 0 disables the timeout
    connection_timeout: int = 10  # seconds
73 |
--------------------------------------------------------------------------------
/aana/tests/units/test_settings.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101, S108
2 | from pathlib import Path
3 |
4 | from aana.configs.settings import Settings
5 |
6 |
def test_default_tmp_data_dir():
    """Test that the default temporary data directory is set correctly."""
    assert Settings().tmp_data_dir == Path("/tmp/aana_data")
11 |
12 |
def test_custom_tmp_data_dir(monkeypatch):
    """Test that the custom temporary data directory with environment variable is set correctly."""
    override_path = "/tmp/override/path"
    monkeypatch.setenv("TMP_DATA_DIR", override_path)
    assert Settings().tmp_data_dir == Path(override_path)
19 |
20 |
def test_changing_tmp_data_dir():
    """Test that changing the temporary data directory is reflected in the other directories."""
    new_tmp_data_dir = Path("/tmp/new_tmp_data_dir")
    settings = Settings(tmp_data_dir=new_tmp_data_dir)

    # All derived directories should follow the new base directory.
    assert settings.tmp_data_dir == new_tmp_data_dir
    assert settings.image_dir == new_tmp_data_dir / "images"
    assert settings.video_dir == new_tmp_data_dir / "videos"
    assert settings.audio_dir == new_tmp_data_dir / "audios"
    assert settings.model_dir == new_tmp_data_dir / "models"

    # Each directory can also be overridden independently of tmp_data_dir.
    overrides = {
        "image_dir": Path("/tmp/new_image_dir"),
        "video_dir": Path("/tmp/new_video_dir"),
        "audio_dir": Path("/tmp/new_audio_dir"),
        "model_dir": Path("/tmp/new_model_dir"),
    }
    for field, custom_dir in overrides.items():
        settings = Settings(tmp_data_dir=new_tmp_data_dir, **{field: custom_dir})
        assert settings.tmp_data_dir == new_tmp_data_dir
        assert getattr(settings, field) == custom_dir
55 |
--------------------------------------------------------------------------------
/aana/tests/units/test_event_manager.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 | from collections.abc import Callable
3 |
4 | import pytest
5 | from typing_extensions import override
6 |
7 | from aana.api.event_handlers.event_handler import EventHandler
8 | from aana.api.event_handlers.event_manager import EventManager
9 | from aana.exceptions.runtime import (
10 | HandlerNotRegisteredException,
11 | )
12 |
13 |
class CallbackHandler(EventHandler):
    """A test event handler that just invokes a callback function."""

    def __init__(self, callback: Callable):
        """Constructor."""
        super().__init__()
        self.callback = callback

    @override
    def handle(self, event_name: str, *args, **kwargs):
        """Forward the event straight to the wrapped callback."""
        self.callback(event_name, *args, **kwargs)
25 |
26 |
def test_event_dispatch():
    """Tests that event dispatch works correctly.

    Records invocations instead of asserting inside the callback: with the
    original in-callback asserts, the test silently passed if the handler
    was never invoked at all.
    """
    event_manager = EventManager()
    seen = []

    def record(event_name, *args, **kwargs):
        seen.append((event_name, args, kwargs))

    event_manager.register_handler_for_events(CallbackHandler(record), ["foo"])
    event_manager.handle("foo", 1, 2, 3, 4, 5, a="A", b="B")

    # The handler must have been invoked exactly once with the dispatched payload.
    assert seen == [("foo", (1, 2, 3, 4, 5), {"a": "A", "b": "B"})]
42 |
43 |
def test_remove_all_raises():
    """Tests that removing handler not added from all events raises an error."""
    manager = EventManager()
    unregistered = CallbackHandler(lambda _, *_args, **_kwargs: None)
    # Deregistering a handler that was never registered must raise.
    with pytest.raises(HandlerNotRegisteredException):
        manager.deregister_handler_from_all_events(unregistered)
50 |
51 |
def test_remove_works():
    """Tests that removing a handler works."""
    manager = EventManager()
    handler = CallbackHandler(lambda _, *_args, **_kwargs: None)
    manager.register_handler_for_events(handler, ["foo"])
    manager.deregister_handler_from_event(handler, "foo")
    # No handlers should remain registered for the event.
    assert not manager._handlers["foo"]
59 |
60 |
def test_remove_all_works():
    """Tests that removing all handlers works."""
    manager = EventManager()
    handler = CallbackHandler(lambda _, *_args, **_kwargs: None)
    manager.register_handler_for_events(handler, ["foo"])
    manager.deregister_handler_from_all_events(handler)
    # The handler must be gone from every event it was registered for.
    assert not manager._handlers["foo"]
68 |
--------------------------------------------------------------------------------
/aana/tests/units/test_speaker.py:
--------------------------------------------------------------------------------
1 | import json
2 | from importlib import resources
3 | from pathlib import Path
4 | from typing import Literal
5 |
6 | import pytest
7 |
8 | from aana.core.models.asr import AsrSegment
9 | from aana.core.models.speaker import SpeakerDiarizationSegment
10 | from aana.processors.speaker import PostProcessingForDiarizedAsr
11 | from aana.tests.utils import verify_deployment_results
12 |
13 |
@pytest.mark.skip(reason="The test is temporary disabled")
@pytest.mark.parametrize("audio_file", ["sd_sample.wav"])
def test_asr_diarization_post_process(audio_file: Literal["sd_sample.wav"]):
    """Test that the ASR output can be processed to generate diarized transcription and an invalid ASR output leads to ValueError."""
    # Paths to precomputed ASR, diarization, and expected combined outputs.
    asr_path = (
        resources.files("aana.tests.files.expected.whisper")
        / f"whisper_medium_{audio_file}.json"
    )
    diar_path = (
        resources.files("aana.tests.files.expected.sd")
        / f"{Path(audio_file).stem}.json"
    )
    expected_results_path = (
        resources.files("aana.tests.files.expected.whisper")
        / f"whisper_medium_{audio_file}_diar.json"
    )

    # Parse the stored JSON into AsrSegment / SpeakerDiarizationSegment models.
    with Path.open(asr_path, "r") as json_file:
        asr_op = json.load(json_file)

    asr_segments = [
        AsrSegment.model_validate(segment) for segment in asr_op["segments"]
    ]

    with Path.open(diar_path, "r") as json_file:
        diar_op = json.load(json_file)

    diarized_segments = [
        SpeakerDiarizationSegment.model_validate(segment)
        for segment in diar_op["segments"]
    ]
    # Combine ASR and diarization, then compare against the expected fixture.
    asr_op["segments"] = PostProcessingForDiarizedAsr.process(
        diarized_segments=diarized_segments, transcription_segments=asr_segments
    )
    verify_deployment_results(expected_results_path, asr_op)

    # Raise error if the ASR output is an invalid input for combining with diarization.

    # Setting words to empty list makes the segments invalid (no word timestamps).
    for segment in asr_segments:
        segment.words = []

    # Expect ValueError with the specific error message
    with pytest.raises(
        ValueError, match="Word-level timestamps are required for diarized ASR."
    ):
        PostProcessingForDiarizedAsr.process(
            diarized_segments=diarized_segments, transcription_segments=asr_segments
        )
65 |
--------------------------------------------------------------------------------
/aana/core/chat/chat_template.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from functools import lru_cache
3 | from importlib import resources
4 |
5 | from aana.core.models.chat import ChatDialog
6 |
7 | if typing.TYPE_CHECKING:
8 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
9 |
10 |
@lru_cache(maxsize=128)
def load_chat_template(chat_template_name: str) -> str:
    """Loads a chat template from the chat templates directory.

    Results are cached (LRU, 128 entries), so each template file is read at most once.

    Args:
        chat_template_name (str): The name of the chat template to load.

    Returns:
        str: The loaded chat template.

    Raises:
        ValueError: If the chat template does not exist.
    """
    template_path = (
        resources.files("aana.core.chat.templates") / f"{chat_template_name}.jinja"
    )
    if template_path.exists():
        return template_path.read_text()
    raise ValueError(f"Chat template {chat_template_name} does not exist.")  # noqa: TRY003
29 |
30 |
def apply_chat_template(
    tokenizer: "PreTrainedTokenizerBase",
    dialog: ChatDialog | list[dict],
    chat_template_name: str | None = None,
) -> str:
    """Applies a chat template to a list of messages to generate a prompt for the model.

    If the chat template is not specified, the tokenizer's default chat template is used.
    If the chat template is specified, the template with the given name is loaded from
    the chat templates directory and assigned to the tokenizer before rendering.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to use.
        dialog (ChatDialog | list[dict]): The dialog to generate a prompt for.
        chat_template_name (str, optional): The name of the chat template to use. Defaults to None, which uses the tokenizer's default chat template.

    Returns:
        str: The generated prompt.

    Raises:
        ValueError: If the tokenizer does not have a chat template.
        ValueError: If the chat template does not exist.
    """
    # Normalize the dialog into a plain list of message dicts.
    messages = (
        dialog.model_dump()["messages"] if isinstance(dialog, ChatDialog) else dialog
    )

    # An explicit template name overrides whatever the tokenizer ships with.
    if chat_template_name is not None:
        tokenizer.chat_template = load_chat_template(chat_template_name)

    if tokenizer.chat_template is None:
        raise ValueError("Tokenizer does not have a chat template.")  # noqa: TRY003

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
68 |
--------------------------------------------------------------------------------
/aana/core/models/speaker.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 | from aana.core.models.time import TimeInterval
6 |
7 | __all__ = [
8 | "PyannoteSpeakerDiarizationParams",
9 | "SpeakerDiarizationSegment",
10 | "SpeakerDiarizationSegments",
11 | ]
12 |
13 |
class PyannoteSpeakerDiarizationParams(BaseModel):
    """A model for the pyannote Speaker Diarization model parameters.

    Attributes:
        min_speakers (int | None): The minimum number of speakers present in the audio.
        max_speakers (int | None): The maximum number of speakers present in the audio.
    """

    # None lets the diarization model decide the lower bound itself.
    min_speakers: int | None = Field(
        default=None,
        description="The minimum number of speakers present in the audio.",
    )

    # None lets the diarization model decide the upper bound itself.
    max_speakers: int | None = Field(
        default=None,
        description="The maximum number of speakers present in the audio.",
    )

    # extra="forbid" rejects unknown keys so typos in request payloads fail loudly.
    model_config = ConfigDict(
        json_schema_extra={
            "description": "Parameters for the pyannote speaker diarization model.",
        },
        extra="forbid",
    )
38 |
39 |
class SpeakerDiarizationSegment(BaseModel):
    """Pydantic schema for a single segment produced by the Speaker Diarization model.

    Attributes:
        time_interval (TimeInterval): The start and end time of the segment.
        speaker (str): Speaker assignment of the model in the format "SPEAKER_XX".
    """

    time_interval: TimeInterval = Field(description="Time interval of the segment")
    speaker: str = Field(description="speaker assignment from the model")

    def to_dict(self) -> dict:
        """Flatten this segment into a dictionary.

        Returns:
            dict: Dictionary with ``start``, ``end`` and ``speaker`` keys.
        """
        interval = self.time_interval
        return dict(
            start=interval.start,
            end=interval.end,
            speaker=self.speaker,
        )

    model_config = ConfigDict(
        json_schema_extra={
            "description": "Speaker Diarization Segment",
        },
        extra="forbid",
    )
69 |
70 |
# Annotated alias: a validated list of diarization segments that defaults to [].
SpeakerDiarizationSegments = Annotated[
    list[SpeakerDiarizationSegment],
    Field(description="List of Speaker Diarization segments", default_factory=list),
]
"""
List of SpeakerDiarizationSegment objects.
"""
78 |
--------------------------------------------------------------------------------
/aana/tests/deployments/test_hf_pipeline_deployment.py:
--------------------------------------------------------------------------------
1 | from importlib import resources
2 |
3 | import pytest
4 | from transformers import BitsAndBytesConfig
5 |
6 | from aana.core.models.image import Image
7 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
8 | from aana.deployments.hf_pipeline_deployment import (
9 | HfPipelineConfig,
10 | HfPipelineDeployment,
11 | )
12 | from aana.tests.utils import verify_deployment_results
13 |
# Deployment configurations under test: BLIP-2 served through the generic
# HF pipeline deployment, loaded with 4-bit quantization on a single GPU.
deployments = [
    (
        "hf_pipeline_blip2_deployment",
        HfPipelineDeployment.options(
            num_replicas=1,
            ray_actor_options={"num_gpus": 1},
            user_config=HfPipelineConfig(
                model_id="Salesforce/blip2-opt-2.7b",
                # PEP 8: no spaces around "=" in keyword arguments (was `task = ...`).
                task="image-to-text",
                model_kwargs={
                    "quantization_config": BitsAndBytesConfig(
                        load_in_8bit=False, load_in_4bit=True
                    ),
                },
            ).model_dump(mode="json"),
        ),
    )
]
32 |
33 |
@pytest.mark.parametrize("setup_deployment", deployments, indirect=True)
class TestHFPipelineDeployment:
    """Test HuggingFace Pipeline deployment."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("image_name", ["Starry_Night.jpeg"])
    async def test_call(self, setup_deployment, image_name):
        """Test call method with every supported input form (single image, path string, lists)."""
        deployment_name, handle_name, _ = setup_deployment

        handle = await AanaDeploymentHandle.create(handle_name)

        expected_output_path = (
            resources.files("aana.tests.files.expected")
            / "hf_pipeline"
            / f"{deployment_name}_{image_name}.json"
        )
        path = resources.files("aana.tests.files.images") / image_name
        image = Image(path=path, save_on_disk=False, media_id=image_name)

        # Single Image passed as a keyword argument.
        output = await handle.call(images=image)
        verify_deployment_results(expected_output_path, output)

        # Single Image passed positionally.
        output = await handle.call(image)
        verify_deployment_results(expected_output_path, output)

        # List with a single path string; result is wrapped to match the fixture.
        output = await handle.call(images=[str(path)])
        verify_deployment_results(expected_output_path, [output])

        # List with a single Image object, passed as a keyword argument.
        output = await handle.call(images=[image])
        verify_deployment_results(expected_output_path, [output])

        # List with a single Image object, passed positionally.
        output = await handle.call([image])
        verify_deployment_results(expected_output_path, [output])
--------------------------------------------------------------------------------
/aana/tests/projects/lowercase/app.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, TypedDict
2 |
3 | from pydantic import Field
4 | from ray import serve
5 |
6 | from aana.api.api_generation import Endpoint
7 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
8 | from aana.deployments.base_deployment import BaseDeployment
9 | from aana.sdk import AanaSDK
10 |
11 |
@serve.deployment
class Lowercase(BaseDeployment):
    """Ray deployment that returns the lowercase version of a list of texts."""

    async def lower(self, text: list[str]) -> dict:
        """Lowercase each text in the list.

        Bug fix: the annotation claimed ``text: str``, but the body iterates over
        the argument element-wise and the caller (LowercaseEndpoint) passes a
        ``TextList`` (``list[str]``), so the correct type is ``list[str]``.

        Args:
            text (list[str]): The texts to lowercase.

        Returns:
            dict: ``{"text": [...]}`` with the lowercased texts, in input order.
        """
        return {"text": [t.lower() for t in text]}
26 |
27 |
# Request-side type: a plain list of strings with an OpenAPI description.
TextList = Annotated[list[str], Field(description="List of text to lowercase.")]


class LowercaseEndpointOutput(TypedDict):
    """The output of the lowercase endpoint."""

    # Lowercased texts, one entry per input text.
    text: list[str]
35 |
36 |
class LowercaseEndpoint(Endpoint):
    """Endpoint that lowercases a list of texts via the Lowercase deployment."""

    async def initialize(self):
        """Create the deployment handle, then finish base-class initialization."""
        self.lowercase_handle = await AanaDeploymentHandle.create(
            "lowercase_deployment"
        )
        await super().initialize()

    async def run(self, text: TextList) -> LowercaseEndpointOutput:
        """Lowercase the given texts.

        Args:
            text (TextList): The list of text to lowercase.

        Returns:
            LowercaseEndpointOutput: The lowercased texts.
        """
        deployment_response = await self.lowercase_handle.lower(text=text)
        return {"text": deployment_response["text"]}
58 |
59 |
deployments = [
    {
        "name": "lowercase_deployment",
        "instance": Lowercase,
    }
]

endpoints = [
    {
        "name": "lowercase",
        "path": "/lowercase",
        "summary": "Lowercase text",
        "endpoint_cls": LowercaseEndpoint,
    }
]

aana_app = AanaSDK(name="lowercase_app")

# Each spec dict's keys match the keyword parameters of the register_* methods,
# so the dicts can be splatted directly.
for deployment_spec in deployments:
    aana_app.register_deployment(**deployment_spec)

for endpoint_spec in endpoints:
    aana_app.register_endpoint(**endpoint_spec)
91 |
--------------------------------------------------------------------------------
/aana/core/models/task.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated, Any
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 | from aana.core.models.api_service import ApiKey
6 | from aana.storage.models.task import Status as TaskStatus
7 | from aana.storage.models.task import TaskEntity
8 |
# Annotated alias for task identifiers (UUID-style strings) used in API schemas.
TaskId = Annotated[
    str,
    Field(description="The task ID.", example="11111111-1111-1111-1111-111111111111"),
]
13 |
14 |
class TaskInfo(BaseModel):
    """Task information.

    Attributes:
        id (str): The task ID.
        endpoint (str): The endpoint to which the task is assigned.
        data (Any): The task data.
        status (TaskStatus): The task status.
        result (Any): The task result.
    """

    id: TaskId
    endpoint: str = Field(
        ..., description="The endpoint to which the task is assigned."
    )
    data: Any = Field(..., description="The task data.")
    status: TaskStatus = Field(..., description="The task status.")
    result: Any = Field(None, description="The task result.")

    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "id": "11111111-1111-1111-1111-111111111111",
                    "endpoint": "/index",
                    "data": {
                        "image": {
                            "url": "https://example.com/image.jpg",
                            "media_id": "abc123",
                        }
                    },
                    "status": "running",
                    "result": None,
                }
            ]
        },
        extra="forbid",
    )

    @classmethod
    def from_entity(cls, task: TaskEntity, is_admin: bool = False) -> "TaskInfo":
        """Create a TaskInfo from a TaskEntity.

        Args:
            task (TaskEntity): The task entity to convert.
            is_admin (bool): Whether the caller is an admin. Non-admins do not
                receive the stacktrace in the result.

        Returns:
            TaskInfo: The converted task info.
        """
        # Prepare data: drop None values and ApiKey objects (never expose keys).
        task_data = {
            key: value
            for key, value in task.data.items()
            if value is not None and not isinstance(value, ApiKey)
        }

        # Remove stacktrace from result if not admin.
        # NOTE: this mutates the entity's result dict in place, as before.
        if not is_admin and task.result and "stacktrace" in task.result:
            task.result.pop("stacktrace", None)

        # Use cls (classmethod idiom) so subclasses construct themselves correctly.
        return cls(
            id=str(task.id),
            endpoint=task.endpoint,
            data=task_data,
            status=task.status,
            result=task.result,
        )
75 |
--------------------------------------------------------------------------------
/aana/core/models/vad.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 |
5 | from aana.core.models.time import TimeInterval
6 |
7 | __all__ = ["VadParams", "VadSegment", "VadSegments"]
8 |
9 |
class VadParams(BaseModel):
    """A model for the Voice Activity Detection model parameters.

    Attributes:
        chunk_size (float): The maximum length of each vad output chunk.
        merge_onset (float): Onset to be used for the merging operation.
        merge_offset (float | None): Optional offset to be used for the merging operation.
    """

    # Lower bound of 10 seconds enforced by the field constraint below.
    chunk_size: float = Field(
        default=30, ge=10.0, description="The maximum length of each vad output chunk."
    )

    merge_onset: float = Field(
        default=0.5, ge=0.0, description="Onset to be used for the merging operation."
    )

    merge_offset: float | None = Field(
        default=None,
        description="Optional offset to be used for the merging operation.",
    )

    # extra="forbid" rejects unknown keys so typos in request payloads fail loudly.
    model_config = ConfigDict(
        json_schema_extra={
            "description": "Parameters for the voice activity detection model."
        },
        extra="forbid",
    )
38 |
39 |
class VadSegment(BaseModel):
    """Pydantic schema for a single segment from the Voice Activity Detection model.

    Attributes:
        time_interval (TimeInterval): The start and end time of the segment.
        segments (list[tuple[float, float]]): Smaller voiced segments within a merged vad segment.
    """

    time_interval: TimeInterval = Field(description="Time interval of the segment")
    segments: list[tuple[float, float]] = Field(
        description="List of voiced segments within a Segment for ASR"
    )

    def to_whisper_dict(self) -> dict:
        """Flatten this segment into the dictionary form expected by faster-whisper.

        Returns:
            dict: Dictionary with ``start``, ``end`` and ``segments`` keys.
        """
        interval = self.time_interval
        return dict(
            start=interval.start,
            end=interval.end,
            segments=self.segments,
        )

    model_config = ConfigDict(
        json_schema_extra={
            "description": "VAD Segment for ASR",
        },
        extra="forbid",
    )
71 |
72 |
# Annotated alias: a validated list of VAD segments that defaults to [].
VadSegments = Annotated[
    list[VadSegment], Field(description="List of VAD segments", default_factory=list)
]
"""
List of VadSegment objects.
"""
79 |
--------------------------------------------------------------------------------
/aana/api/app.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime, timezone
2 |
3 | from fastapi import FastAPI, Request
4 | from fastapi.exceptions import RequestValidationError
5 | from fastapi.middleware.cors import CORSMiddleware
6 | from pydantic import ValidationError
7 | from sqlalchemy import select
8 |
9 | from aana.api.exception_handler import (
10 | aana_exception_handler,
11 | validation_exception_handler,
12 | )
13 | from aana.configs.settings import settings as aana_settings
14 | from aana.exceptions.api_service import (
15 | ApiKeyExpired,
16 | ApiKeyNotFound,
17 | ApiKeyNotProvided,
18 | ApiKeyValidationFailed,
19 | )
20 | from aana.storage.models.api_key import ApiKeyEntity
21 | from aana.storage.session import get_session
22 |
app = FastAPI()

# CORS policy is fully driven by application settings.
app.add_middleware(
    CORSMiddleware,
    allow_origins=aana_settings.cors.allow_origins,
    allow_origin_regex=aana_settings.cors.allow_origin_regex,
    allow_credentials=aana_settings.cors.allow_credentials,
    allow_methods=aana_settings.cors.allow_methods,
    allow_headers=aana_settings.cors.allow_headers,
)

# Convert validation failures and any uncaught exception into JSON error responses.
app.add_exception_handler(ValidationError, validation_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)
app.add_exception_handler(Exception, aana_exception_handler)
37 |
38 |
@app.middleware("http")
async def api_key_check(request: Request, call_next):
    """Middleware to check the API key and subscription status."""
    # Docs/health endpoints and CORS preflight requests bypass the check.
    excluded_paths = ["/openapi.json", "/docs", "/redoc", "/api/ready"]
    if request.method == "OPTIONS" or request.url.path in excluded_paths:
        return await call_next(request)

    if aana_settings.api_service.enabled:
        api_key = request.headers.get("x-api-key")
        if not api_key:
            raise ApiKeyNotProvided()

        async with get_session() as session:
            try:
                lookup = await session.execute(
                    select(ApiKeyEntity).where(ApiKeyEntity.api_key == api_key)
                )
                api_key_info = lookup.scalars().first()
            except Exception as e:
                raise ApiKeyValidationFailed() from e

            if not api_key_info:
                raise ApiKeyNotFound(key=api_key)

            # NOTE(review): assumes expired_at is timezone-aware — confirm against the model.
            if api_key_info.expired_at < datetime.now(timezone.utc):
                raise ApiKeyExpired(key=api_key)

            request.state.api_key_info = api_key_info.to_model()

    return await call_next(request)
71 |
--------------------------------------------------------------------------------
/aana/storage/repository/media.py:
--------------------------------------------------------------------------------
1 | from typing import TypeVar
2 |
3 | from sqlalchemy.ext.asyncio import AsyncSession
4 |
5 | from aana.core.models.media import MediaId
6 | from aana.exceptions.db import MediaIdAlreadyExistsException, NotFoundException
7 | from aana.storage.models.media import MediaEntity
8 | from aana.storage.repository.base import BaseRepository
9 |
# Type variable bound to MediaEntity so repository subclasses can be typed
# against their concrete media entity class.
M = TypeVar("M", bound=MediaEntity)
11 |
12 |
class MediaRepository(BaseRepository[M]):
    """Repository for media files."""

    def __init__(self, session: AsyncSession, model_class: type[M] = MediaEntity):
        """Constructor."""
        super().__init__(session, model_class)

    async def check_media_exists(self, media_id: MediaId) -> bool:
        """Checks if a media file exists in the database.

        Args:
            media_id (MediaId): The media ID.

        Returns:
            bool: True if the media exists, False otherwise.
        """
        try:
            await self.read(media_id)
        except NotFoundException:
            return False
        else:
            return True

    async def create(self, entity: MediaEntity) -> MediaEntity:
        """Inserts a single new entity.

        Args:
            entity (MediaEntity): The entity to insert.

        Returns:
            MediaEntity: The inserted entity.

        Raises:
            MediaIdAlreadyExistsException: If the media ID already exists.
        """
        # TODO: raise MediaIdAlreadyExistsException without a separate existence
        # check. Catching IntegrityError around super().create() would be one
        # query and race-free, but IntegrityError is much broader than a unique
        # constraint violation: we would have to inspect the error to confirm it
        # is a unique violation on the media_id column, and the message format
        # differs between DBMSes. Until that is handled, we pre-check instead
        # (two queries, and a race window between check and insert).
        already_exists = await self.check_media_exists(entity.id)
        if already_exists:
            raise MediaIdAlreadyExistsException(self.table_name, entity.id)

        return await super().create(entity)
64 |
--------------------------------------------------------------------------------
/docs/pages/code_overview.md:
--------------------------------------------------------------------------------
1 | # Code overview
2 |
3 | ```
4 | aana/ | top level source code directory for the project
5 | ├── alembic/ | directory for database migrations
6 | │ └── versions/ | individual migrations
7 | ├── api/ | API functionality
8 | │ ├── api_generation.py | API generation code, defines Endpoint class
9 | │ ├── request_handler.py | request handler routes requests to endpoints
10 | │ ├── exception_handler.py | exception handler to process exceptions and return them as JSON
11 | │ ├── responses.py | custom responses for the API
12 | │ └── app.py | defines the FastAPI app and connects exception handlers
├── configs/                 | various configuration objects, including settings and preconfigured deployments
14 | │ ├── db.py | config for the database
│   ├── deployments.py       | preconfigured deployments
16 | │ └── settings.py | app settings
17 | ├── core/ | core models and functionality
18 | │ ├── models/ | core data models
19 | │ ├── libraries/ | base libraries for audio, images etc.
20 | │ └── chat/ | LLM chat templates
21 | ├── deployments/ | classes for predefined deployments (e.g. Hugging Face Transformers, Whisper, vLLM)
22 | ├── exceptions/ | custom exception classes
23 | ├── integrations/ | integrations with 3rd party libraries
24 | │ ├── external/ | integrations with 3rd party libraries for example image, video, audio processing, download youtube videos, etc.
25 | │ └── haystack/ | integrations with Deepset Haystack
26 | ├── processors/ | utility functions for processing data
27 | ├── storage/ | storage functionality
28 | │ ├── models/ | database models
29 | │ ├── repository/ | repository classes for storage
30 | │ └── services/ | utility functions for storage
31 | ├── tests/ | automated tests for the SDK
32 | │ ├── db/ | tests for database functions
33 | │ ├── deployments/ | tests for model deployments
34 | │ ├── files/ | assets for testing
35 | │ ├── integrations/ | tests for integrations
36 | │ ├── projects/ | test projects
37 | │ └── units/ | unit tests
38 | ├── utils/ | various utility functionality
39 | ├── cli.py | command-line interface to build and deploy the SDK
40 | └── sdk.py | base class to create an SDK instance
41 | ```
42 |
43 |
--------------------------------------------------------------------------------
/aana/tests/units/test_app_upload.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101, S113
2 | import io
3 | import json
4 | from typing import TypedDict
5 |
6 | import requests
7 | from pydantic import BaseModel, ConfigDict, Field
8 |
9 | from aana.api.api_generation import Endpoint
10 | from aana.exceptions.runtime import UploadedFileNotFound
11 |
12 |
class FileUploadModel(BaseModel):
    """Model for a file upload input."""

    content: str | None = Field(
        None,
        description="The name of the file to upload.",
    )
    # Raw bytes of the uploaded file, populated by set_files(); not part of the schema.
    _file: bytes | None = None

    def set_files(self, files: dict[str, bytes]):
        """Set files.

        Looks up this model's `content` (the filename) in the uploaded files
        mapping and stores the matching bytes in `_file`.

        Raises:
            UploadedFileNotFound: If `content` names a file that was not uploaded.
        """
        if self.content:
            if self.content not in files:
                raise UploadedFileNotFound(self.content)
            self._file = files[self.content]

    model_config = ConfigDict(extra="forbid")
30 |
31 |
class FileUploadEndpointOutput(TypedDict):
    """The output of the file upload endpoint."""

    # The uploaded file's content decoded as text.
    text: str
36 |
37 |
class FileUploadEndpoint(Endpoint):
    """File upload endpoint."""

    async def run(self, file: FileUploadModel) -> FileUploadEndpointOutput:
        """Upload a file.

        Args:
            file (FileUploadModel): The file to upload.

        Returns:
            FileUploadEndpointOutput: The uploaded file content decoded as text.
        """
        raw_bytes = file._file
        return {"text": raw_bytes.decode()}
52 |
53 |
# This app needs no deployments — the endpoint works on the uploaded bytes directly.
deployments = []

endpoints = [
    {
        "name": "file_upload",
        "path": "/file_upload",
        "summary": "Upload a file",
        "endpoint_cls": FileUploadEndpoint,
    }
]
64 |
65 |
def test_file_upload_app(create_app):
    """Test the app with a file upload endpoint."""
    aana_app = create_app(deployments, endpoints)

    port = aana_app.port
    route_prefix = ""

    # Check that the server is ready
    response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
    assert response.status_code == 200
    assert response.json() == {"ready": True}

    # Test the file upload endpoint: the JSON body references the file by name,
    # and the bytes travel in the multipart `files` payload under that name.
    data = {"file": {"content": "file.txt"}}
    file = b"Hello world! This is a test."
    files = {"file.txt": io.BytesIO(file)}
    response = requests.post(
        f"http://localhost:{port}{route_prefix}/file_upload",
        data={"body": json.dumps(data)},
        files=files,
    )
    assert response.status_code == 200, response.text
    text = response.json().get("text")
    assert text == file.decode()
91 |
--------------------------------------------------------------------------------
/aana/storage/models/caption.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations # Let classes use themselves in type annotations
2 |
3 | import typing
4 |
5 | from sqlalchemy import CheckConstraint
6 | from sqlalchemy.orm import Mapped, mapped_column
7 |
8 | from aana.storage.models.base import BaseEntity, TimeStampEntity
9 |
10 | if typing.TYPE_CHECKING:
11 | from aana.core.models.captions import Caption
12 |
13 |
class CaptionEntity(BaseEntity, TimeStampEntity):
    """ORM model for video captions.

    Attributes:
        id (int): Unique identifier for the caption.
        model (str): Name of the model used to generate the caption.
        frame_id (int): The 0-based frame id of video for caption.
        caption (str): Frame caption.
        timestamp (float): Frame timestamp in seconds.
        caption_type (str): The type of caption (populated automatically by ORM based on `polymorphic_identity` of subclass).
    """

    __tablename__ = "caption"

    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    model: Mapped[str] = mapped_column(
        nullable=False, comment="Name of model used to generate the caption"
    )
    frame_id: Mapped[int] = mapped_column(
        # Pass the constraint name as a keyword, consistent with `timestamp` below
        # (it was previously passed positionally).
        CheckConstraint("frame_id >= 0", name="frame_id_positive"),
        comment="The 0-based frame id of video for caption",
    )
    caption: Mapped[str] = mapped_column(comment="Frame caption")
    timestamp: Mapped[float] = mapped_column(
        CheckConstraint("timestamp >= 0", name="timestamp_positive"),
        comment="Frame timestamp in seconds",
    )
    caption_type: Mapped[str] = mapped_column(comment="The type of caption")

    __mapper_args__ = {  # noqa: RUF012
        "polymorphic_identity": "caption",
        "polymorphic_on": "caption_type",
    }

    @classmethod
    def from_caption_output(
        cls,
        model_name: str,
        caption: Caption,
        frame_id: int,
        timestamp: float,
    ) -> CaptionEntity:
        """Converts a Caption pydantic model to a CaptionEntity.

        Args:
            model_name (str): Name of the model used to generate the caption.
            caption (Caption): Caption pydantic model.
            frame_id (int): The 0-based frame id of video for caption.
            timestamp (float): Frame timestamp in seconds.

        Returns:
            CaptionEntity: ORM model for video captions.
        """
        return CaptionEntity(
            model=model_name,
            frame_id=frame_id,
            caption=str(caption),
            timestamp=timestamp,
        )
73 |
--------------------------------------------------------------------------------
/docs/pages/model_hub/index.md:
--------------------------------------------------------------------------------
1 | # Model Hub
2 |
3 |
4 |
Model deployment is a crucial part of the machine learning workflow. Aana SDK uses the concept of deployments to serve models.
6 |
7 | The deployments are "recipes" that can be used to deploy models. With the same deployment, you can deploy multiple different models by providing specific configurations.
8 |
9 | Aana SDK comes with a set of predefined deployments, like [VLLMDeployment](./../../reference/deployments.md#aana.deployments.VLLMDeployment) for serving Large Language Models (LLMs) with [vLLM](https://github.com/vllm-project/vllm/) library or [WhisperDeployment](./../../reference/deployments.md#aana.deployments.WhisperDeployment) for automatic Speech Recognition (ASR) based on the [faster-whisper](https://github.com/SYSTRAN/faster-whisper) library.
10 |
11 | Each deployment has its own configuration class that specifies which model to deploy and with which parameters.
12 |
13 | The model hub provides a collection of configurations for different models that can be used with the predefined deployments.
14 |
The full list of predefined deployments can be found in the [Deployments](./../integrations.md) section.
16 |
17 | !!! tip
18 |
19 | The Model Hub provides only a subset of the available models. You can deploy a lot more models using predefined deployments. For example, [Hugging Face Pipeline Deployment](./../../reference/deployments.md#aana.deployments.HfPipelineDeployment) is a generic deployment that can be used to deploy any model from the [Hugging Face Model Hub](https://huggingface.co/models) that can be used with [Hugging Face Pipelines](https://huggingface.co/transformers/main_classes/pipelines.html). It would be impossible to list all the models that can be deployed with this deployment.
20 |
21 | !!! tip
22 |
23 | The SDK is not limited to the predefined deployments. You can create your own deployment.
24 |
25 | ## How to Use the Model Hub
26 |
27 | There are a few ways to use the Model Hub (from the simplest to the most advanced):
28 |
29 | - Find the model configuration you are interested in and copy the configuration code to your project.
30 |
31 | - Use the provided examples as a starting point to create your own configurations for existing deployments.
32 |
33 | - Create a new deployment with your own configuration.
34 |
35 | See [Tutorial](./../tutorial.md#deployments) for more information on how to use the deployments.
36 |
37 | ## Models by Category
38 |
39 | - [Text Generation Models (LLMs)](./text_generation.md)
40 | - [Image-to-Text Models](./image_to_text.md)
41 | - [Half-Quadratic Quantization Models](./hqq.md)
42 | - [Automatic Speech Recognition (ASR) Models](./asr.md)
43 | - [Hugging Face Pipeline Models](./hf_pipeline.md)
44 |
--------------------------------------------------------------------------------
/aana/api/event_handlers/event_manager.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from collections.abc import MutableMapping
3 |
4 | from aana.api.event_handlers.event_handler import EventHandler
5 | from aana.exceptions.runtime import (
6 | HandlerNotRegisteredException,
7 | )
8 |
9 |
class EventManager:
    """Dispatches named events to registered handlers. Not guaranteed to be thread safe."""

    # Maps an event name to the list of handlers registered for that event.
    # NOTE: the previous annotation said `EventHandler`, but the values are
    # lists (see `defaultdict(list)` and `.append` below).
    _handlers: MutableMapping[str, list[EventHandler]]

    def __init__(self):
        """Constructor."""
        self._handlers = defaultdict(list)

    def handle(self, event_name: str, *args, **kwargs):
        """Trigger event handlers for `event_name`.

        Arguments:
            event_name (str): name of event
            *args (list): positional args forwarded to each handler
            **kwargs (dict): keyword args forwarded to each handler
        """
        for handler in self._handlers[event_name]:
            handler.handle(event_name, *args, **kwargs)

    def register_handler_for_events(
        self, handler: EventHandler, event_names: list[str]
    ):
        """Adds a handler to the event handler list for each given event.

        Registration is idempotent: a handler already registered for an
        event is not added a second time.

        Arguments:
            handler (EventHandler): the handler to register
            event_names (list[str]): the events for which this handler is to be registered
        """
        for event_name in event_names:
            if handler not in self._handlers[event_name]:
                self._handlers[event_name].append(handler)

    def deregister_handler_from_event(self, handler: EventHandler, event_name: str):
        """Removes a handler from the event handler list.

        Arguments:
            handler (EventHandler): the handler to remove
            event_name (str): the name of the event from which the handler should be removed

        Raises:
            HandlerNotRegisteredException: if the handler isn't registered. (embed in try-except to suppress)
        """
        try:
            self._handlers[event_name].remove(handler)
        except ValueError as e:
            raise HandlerNotRegisteredException() from e

    def deregister_handler_from_all_events(self, handler: EventHandler):
        """Removes a handler from all event handlers.

        Arguments:
            handler (EventHandler): the exact instance of the handler to remove.

        Raises:
            HandlerNotRegisteredException: if the handler isn't registered. (embed in try-except to suppress)
        """
        has_removed = False
        for handler_list in self._handlers.values():
            if handler in handler_list:
                handler_list.remove(handler)
                has_removed = True
        if not has_removed:
            raise HandlerNotRegisteredException()
73 |
--------------------------------------------------------------------------------
/aana/storage/models/api_key.py:
--------------------------------------------------------------------------------
1 | from sqlalchemy import Boolean
2 | from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
3 |
4 | from aana.core.models.api_service import ApiKey
5 | from aana.storage.models.base import TimeStampEntity, timestamp
6 |
7 |
class ApiServiceBase(DeclarativeBase):
    """Declarative base for API service ORM models (e.g. `ApiKeyEntity`)."""

    pass
12 |
13 |
class ApiKeyEntity(ApiServiceBase, TimeStampEntity):
    """Table for API keys."""

    __tablename__ = "api_keys"
    __entity_name__ = "api key"

    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    key_id: Mapped[str] = mapped_column(
        nullable=False, unique=True, comment="The API key id in api gateway"
    )
    user_id: Mapped[str] = mapped_column(
        nullable=False, index=True, comment="ID of the user who owns this API key"
    )
    api_key: Mapped[str] = mapped_column(
        nullable=False, index=True, unique=True, comment="The API key"
    )
    is_admin: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, comment="Whether the user is an admin"
    )
    subscription_id: Mapped[str] = mapped_column(
        nullable=False, comment="ID of the associated subscription"
    )
    is_subscription_active: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=True,
        comment="Whether the subscription is active (credits are available)",
    )
    hmac_secret: Mapped[str] = mapped_column(
        nullable=True, comment="The secret key for HMAC signature generation"
    )
    expired_at: Mapped[timestamp] = mapped_column(
        nullable=False, comment="The expiration date of the API key"
    )

    def __repr__(self) -> str:
        """String representation of the API key.

        The secret values (`api_key`, `hmac_secret`) are deliberately
        omitted so they never leak into logs.
        """
        # NOTE(review): the original repr body was lost/empty (`f""`);
        # reconstructed from the mapped columns above.
        return (
            f"<ApiKeyEntity(id={self.id}, "
            f"key_id={self.key_id}, "
            f"user_id={self.user_id}, "
            f"is_admin={self.is_admin}, "
            f"subscription_id={self.subscription_id}, "
            f"is_subscription_active={self.is_subscription_active}, "
            f"expired_at={self.expired_at})>"
        )

    def to_model(self) -> ApiKey:
        """Convert the entity to an `ApiKey` Pydantic model."""
        return ApiKey(
            api_key=self.api_key,
            user_id=self.user_id,
            is_admin=self.is_admin,
            subscription_id=self.subscription_id,
            is_subscription_active=self.is_subscription_active,
            hmac_secret=self.hmac_secret,
        )
75 |
--------------------------------------------------------------------------------
/aana/api/security.py:
--------------------------------------------------------------------------------
1 | from typing import Annotated
2 |
3 | from fastapi import Depends, Request
4 |
5 | from aana.configs.settings import settings as aana_settings
6 | from aana.core.models.api_service import ApiKey
7 | from aana.exceptions.api_service import AdminOnlyAccess, InactiveSubscription
8 |
9 |
def is_admin(request: Request) -> bool:
    """Check if the user is an admin.

    Args:
        request (Request): The request object

    Returns:
        bool: True if the user is an admin, False otherwise
    """
    if not aana_settings.api_service.enabled:
        # Without the API service enabled, every request counts as admin.
        return True
    api_key_info: ApiKey = request.state.api_key_info
    if not api_key_info:
        return False
    return api_key_info.is_admin
23 |
24 |
def require_admin_access(request: Request) -> bool:
    """Check if the user is an admin. If not, raise an exception.

    Args:
        request (Request): The request object

    Raises:
        AdminOnlyAccess: If the user is not an admin
    """
    if is_admin(request):
        return True
    raise AdminOnlyAccess()
38 |
39 |
def extract_api_key_info(request: Request) -> ApiKey | None:
    """Return the API key info attached to the request state, or None if absent."""
    state = request.state
    return state.api_key_info if hasattr(state, "api_key_info") else None
43 |
44 |
def extract_user_id(request: Request) -> str | None:
    """Return the user ID from the request's API key info, or None if absent."""
    info = extract_api_key_info(request)
    if not info:
        return None
    return info.user_id
49 |
50 |
def require_active_subscription(request: Request) -> bool:
    """Check if the user has an active subscription. If not, raise an exception.

    Args:
        request (Request): The request object

    Raises:
        InactiveSubscription: If the user does not have an active subscription
    """
    if aana_settings.api_service.enabled:
        # Mirror is_admin(): api_key_info may be absent from the request state
        # (previously this raised AttributeError on a missing/None value);
        # with no key info there is no subscription to validate here.
        api_key_info: ApiKey | None = getattr(request.state, "api_key_info", None)
        if api_key_info and not api_key_info.is_subscription_active:
            raise InactiveSubscription(key=api_key_info.api_key)
    return True
65 |
66 |
# FastAPI dependency aliases wrapping the security functions above.
AdminAccessDependency = Annotated[bool, Depends(require_admin_access)]
""" Dependency to check if the user is an admin. If not, it will raise an exception. """

IsAdminDependency = Annotated[bool, Depends(is_admin)]
""" Dependency to check if the user is an admin. """

UserIdDependency = Annotated[str | None, Depends(extract_user_id)]
""" Dependency to get the user ID. """

ApiKeyInfoDependency = Annotated[ApiKey | None, Depends(extract_api_key_info)]
""" Dependency to get the API key info. """

ActiveSubscriptionRequiredDependency = Annotated[
    bool, Depends(require_active_subscription)
]
""" Dependency to check if the user has an active subscription. If not, it will raise an exception. """
83 |
--------------------------------------------------------------------------------
/aana/storage/models/task.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | from enum import Enum
3 |
4 | from sqlalchemy import (
5 | JSON,
6 | UUID,
7 | PickleType,
8 | )
9 | from sqlalchemy.orm import Mapped, mapped_column
10 |
11 | from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp
12 |
13 |
class Status(str, Enum):
    """Enum for task status (string-valued for easy serialization)."""

    CREATED = "created"  # the task is created
    ASSIGNED = "assigned"  # the task is assigned to a worker
    COMPLETED = "completed"  # the task is completed
    RUNNING = "running"  # the task is running
    FAILED = "failed"  # the task has failed
    NOT_FINISHED = "not_finished"  # the task is not finished
32 |
33 |
class TaskEntity(BaseEntity, TimeStampEntity):
    """Table for task items."""

    __tablename__ = "tasks"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID, primary_key=True, default=uuid.uuid4, comment="Task ID"
    )
    endpoint: Mapped[str] = mapped_column(
        nullable=False, comment="The endpoint to which the task is assigned"
    )
    data = mapped_column(PickleType, nullable=False, comment="Data for the task")
    status: Mapped[Status] = mapped_column(
        default=Status.CREATED,
        comment="Status of the task",
        index=True,
    )
    priority: Mapped[int] = mapped_column(
        nullable=False, default=0, comment="Priority of the task (0 is the lowest)"
    )
    assigned_at: Mapped[timestamp | None] = mapped_column(
        comment="Timestamp when the task was assigned",
    )
    completed_at: Mapped[timestamp | None] = mapped_column(
        server_default=None,
        comment="Timestamp when the task was completed",
    )
    progress: Mapped[float] = mapped_column(
        nullable=False, default=0.0, comment="Progress of the task in percentage"
    )
    result: Mapped[dict | None] = mapped_column(
        JSON, comment="Result of the task in JSON format"
    )
    num_retries: Mapped[int] = mapped_column(
        nullable=False, default=0, comment="Number of retries"
    )
    user_id: Mapped[str] = mapped_column(
        nullable=True,
        comment="ID of the user who launched the task",
        index=True,
    )

    def __repr__(self) -> str:
        """String representation of the task."""
        # NOTE(review): the original repr body was lost/empty (`f""`);
        # reconstructed from the mapped columns above. `data` and `result`
        # are omitted as they may be large.
        return (
            f"<TaskEntity(id={self.id}, "
            f"endpoint={self.endpoint}, "
            f"status={self.status}, "
            f"priority={self.priority}, "
            f"progress={self.progress}, "
            f"num_retries={self.num_retries})>"
        )
86 |
--------------------------------------------------------------------------------
/aana/core/models/api.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from pydantic import BaseModel, ConfigDict, Field
4 | from ray.serve.schema import ApplicationStatus
5 |
6 |
class SDKStatus(str, Enum):
    """Possible statuses of the SDK application."""

    UNHEALTHY = "UNHEALTHY"
    RUNNING = "RUNNING"
    DEPLOYING = "DEPLOYING"
13 |
14 |
class DeploymentStatus(BaseModel):
    """The status of a deployment."""

    # Ray Serve application status (see `ray.serve.schema.ApplicationStatus`).
    status: ApplicationStatus = Field(description="The status of the deployment.")
    message: str = Field(
        description="The message for more information like error message."
    )

    # Reject unknown fields in the payload.
    model_config = ConfigDict(extra="forbid")
24 |
25 |
class SDKStatusResponse(BaseModel):
    """The response for the SDK status endpoint.

    Attributes:
        status (SDKStatus): The status of the SDK.
        message (str): The message for more information like error message.
        deployments (dict[str, DeploymentStatus]): The status of each deployment in the Aana app.
    """

    status: SDKStatus = Field(description="The status of the SDK.")
    message: str = Field(
        description="The message for more information like error message."
    )
    deployments: dict[str, DeploymentStatus] = Field(
        description="The status of each deployment in the Aana app."
    )

    # OpenAPI schema examples for the docs UI; `extra="forbid"` rejects
    # unknown fields in the payload.
    model_config = ConfigDict(
        json_schema_extra={
            "description": "The response for the SDK status endpoint.",
            "examples": [
                {
                    "status": "RUNNING",
                    "message": "",
                    "deployments": {
                        "app": {
                            "status": "RUNNING",
                            "message": "",
                        },
                        "lowercase_deployment": {
                            "status": "RUNNING",
                            "message": "",
                        },
                    },
                },
                {
                    "status": "UNHEALTHY",
                    "message": "Error: Lowercase (lowercase_deployment): A replica's health check failed. "
                    "This deployment will be UNHEALTHY until the replica recovers or a new deploy happens.",
                    "deployments": {
                        "app": {
                            "status": "RUNNING",
                            "message": "",
                        },
                        "lowercase_deployment": {
                            "status": "UNHEALTHY",
                            "message": "A replica's health check failed. This deployment will be UNHEALTHY "
                            "until the replica recovers or a new deploy happens.",
                        },
                    },
                },
            ],
        },
        extra="forbid",
    )
81 |
--------------------------------------------------------------------------------
/aana/storage/models/base.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | from typing import Annotated, Any, TypeVar
3 |
4 | from sqlalchemy import MetaData, String, func
5 | from sqlalchemy.orm import (
6 | DeclarativeBase,
7 | Mapped,
8 | mapped_column,
9 | object_mapper,
10 | registry,
11 | )
12 |
13 | from aana.core.models.media import MediaId
14 | from aana.storage.types import TimezoneAwareDateTime
15 |
# Reusable annotated type for timezone-aware datetime columns.
timestamp = Annotated[
    datetime.datetime,
    mapped_column(TimezoneAwareDateTime(timezone=True)),
]

# Type variable for InheritanceReuseMixin.from_parent's return type.
T = TypeVar("T", bound="InheritanceReuseMixin")
22 |
23 |
class InheritanceReuseMixin:
    """Mixin for instantiating child classes from parent instances."""

    @classmethod
    def from_parent(cls: type[T], parent_instance: Any, **kwargs: Any) -> T:
        """Create a new instance of the child class, reusing attributes from the parent instance.

        Args:
            parent_instance (Any): An instance of the parent class
            kwargs (Any): Additional keyword arguments to set on the new instance

        Returns:
            T: A new instance of the child class
        """
        # Get the mapped attributes of the parent class
        mapper = object_mapper(parent_instance)
        # `polymorphic_on` is None for non-polymorphic mappers; the previous
        # unconditional `.name` access raised AttributeError in that case.
        polymorphic_key = (
            mapper.polymorphic_on.name if mapper.polymorphic_on is not None else None
        )
        attributes = {
            prop.key: getattr(parent_instance, prop.key)
            for prop in mapper.iterate_properties
            if hasattr(parent_instance, prop.key)
            # don't copy the polymorphic_on attribute from the parent
            and prop.key != polymorphic_key
        }

        # Update attributes with any additional kwargs (kwargs win)
        attributes.update(kwargs)

        # Create and return a new instance of the child class
        return cls(**attributes)
53 |
54 |
class BaseEntity(DeclarativeBase, InheritanceReuseMixin):
    """Base for all ORM classes."""

    # Naming convention so constraint/index names are generated
    # deterministically instead of being backend-dependent.
    metadata = MetaData(
        naming_convention={
            "ix": "ix_%(column_0_label)s",
            "uq": "uq_%(table_name)s_%(column_0_name)s",
            "ck": "ck_%(table_name)s_`%(constraint_name)s`",
            "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
            "pk": "pk_%(table_name)s",
        }
    )

    # Map the MediaId annotation to a 36-character string column.
    registry = registry(
        type_annotation_map={
            MediaId: String(36),
        }
    )

    def __repr__(self) -> str:
        """Get the representation of the entity."""
        # NOTE(review): assumes every subclass defines an `id` column —
        # confirm when adding new entities.
        return f"{self.__class__.__name__}(id={self.id})"
77 |
78 |
class TimeStampEntity:
    """Mixin for database entities that will have create/update timestamps."""

    # Set by the database when the row is inserted.
    created_at: Mapped[timestamp] = mapped_column(
        server_default=func.now(),
        comment="Timestamp when row is inserted",
    )
    # Defaults to insert time; refreshed on every UPDATE via `onupdate`.
    updated_at: Mapped[timestamp] = mapped_column(
        onupdate=func.now(),
        server_default=func.now(),
        comment="Timestamp when row is updated",
    )
91 |
--------------------------------------------------------------------------------
/docs/pages/model_hub/speaker_recognition.md:
--------------------------------------------------------------------------------
1 | # Speaker Recognition
2 |
3 | ## Speaker Diarization (SD) Models
4 |
5 | [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationDeployment) allows you to diarize the audio for speakers audio with pyannote models. The deployment is based on the [pyannote.audio](https://github.com/pyannote/pyannote-audio) library.
6 |
7 | !!! Tip
    To use the Pyannote Speaker Diarization deployment, install required libraries with `pip install pyannote-audio` or include extra dependencies using `pip install aana[asr]`.
9 |
10 | [PyannoteSpeakerDiarizationConfig](./../../reference/deployments.md#aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationConfig) is used to configure the Speaker Diarization deployment.
11 |
12 | ::: aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationConfig
13 | options:
14 | show_bases: false
15 | heading_level: 4
16 | show_docstring_description: false
17 | docstring_section_style: list
18 |
19 |
20 | ## Accessing Gated Models
21 |
22 | The PyAnnote speaker diarization models are gated, requiring special access. To use these models:
23 |
24 | 1. **Request Access**:
   Visit the [PyAnnote Speaker Diarization 3.1 model page](https://huggingface.co/pyannote/speaker-diarization-3.1) and [Pyannote Speaker Segmentation 3.0 model page](https://huggingface.co/pyannote/segmentation-3.0) on Hugging Face. Log in, fill out the forms, and request access.
26 |
27 | 2. **Approval**:
28 | - If automatic, access is granted immediately.
29 | - If manual, wait for the model authors to approve your request.
30 |
31 | 3. **Set Up the SDK**:
32 | After approval, add your Hugging Face access token to your `.env` file by setting the `HF_TOKEN` variable:
33 |
34 | ```plaintext
35 | HF_TOKEN=your_huggingface_access_token
36 | ```
37 |
38 | To get your Hugging Face access token, visit the [Hugging Face Settings - Tokens](https://huggingface.co/settings/tokens).
39 |
40 |
41 | ## Example Configurations
42 |
43 | As an example, let's see how to configure the Pyannote Speaker Diarization deployment for the [Speaker Diarization-3.1 model](https://huggingface.co/pyannote/speaker-diarization-3.1).
44 |
45 | !!! example "Speaker diarization-3.1"
46 |
47 | ```python
48 | from aana.deployments.pyannote_speaker_diarization_deployment import PyannoteSpeakerDiarizationDeployment, PyannoteSpeakerDiarizationConfig
49 |
50 | PyannoteSpeakerDiarizationDeployment.options(
51 | num_replicas=1,
52 | max_ongoing_requests=1000,
53 | ray_actor_options={"num_gpus": 0.05},
54 | user_config=PyannoteSpeakerDiarizationConfig(
55 | model_name=("pyannote/speaker-diarization-3.1"),
56 | sample_rate=16000,
57 | ).model_dump(mode="json"),
58 | )
59 | ```
60 |
61 | ## Diarized ASR
62 |
Speaker Diarization output can be combined with ASR to generate transcription with speaker information. Further details and a code snippet are available in the [ASR model hub](./asr.md#diarized-asr).
64 |
65 |
--------------------------------------------------------------------------------
/docs/pages/openai_api.md:
--------------------------------------------------------------------------------
1 | # OpenAI-compatible API
2 |
3 | Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to integrate Aana with any OpenAI-compatible application.
4 |
5 | Chat Completions API is available at the `/chat/completions` endpoint.
6 |
7 | !!! Tip
8 | The endpoint is enabled by default but can be disabled by setting the environment variable: `OPENAI_ENDPOINT_ENABLED=False`.
9 |
10 | It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for OpenAI API.
11 |
12 | ```python
13 | from openai import OpenAI
14 |
15 | client = OpenAI(
    api_key="token",  # Any non-empty string will work, we don't require an API key
17 | base_url="http://localhost:8000",
18 | )
19 |
20 | messages = [
21 | {"role": "user", "content": "What is the capital of France?"}
22 | ]
23 |
24 | completion = client.chat.completions.create(
25 | messages=messages,
26 | model="llm_deployment",
27 | )
28 |
29 | print(completion.choices[0].message.content)
30 | ```
31 |
32 | The API also supports streaming:
33 |
34 | ```python
35 | from openai import OpenAI
36 |
37 | client = OpenAI(
    api_key="token",  # Any non-empty string will work, we don't require an API key
39 | base_url="http://localhost:8000",
40 | )
41 |
42 | messages = [
43 | {"role": "user", "content": "What is the capital of France?"}
44 | ]
45 |
46 | stream = client.chat.completions.create(
47 | messages=messages,
48 | model="llm_deployment",
49 | stream=True,
50 | )
51 | for chunk in stream:
52 | print(chunk.choices[0].delta.content or "", end="")
53 | ```
54 |
55 | The API requires an LLM deployment. Aana SDK provides support for [vLLM](integrations.md#vllm) and [Hugging Face Transformers](integrations.md#hugging-face-transformers).
56 |
57 | The name of the model matches the name of the deployment. For example, if you registered a vLLM deployment with the name `llm_deployment`, you can use it with the OpenAI API as `model="llm_deployment"`.
58 |
59 | ```python
60 | import os
61 |
62 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
63 |
64 | from aana.core.models.sampling import SamplingParams
65 | from aana.core.models.types import Dtype
66 | from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment
67 | from aana.sdk import AanaSDK
68 |
69 | llm_deployment = VLLMDeployment.options(
70 | num_replicas=1,
71 | ray_actor_options={"num_gpus": 1},
72 | user_config=VLLMConfig(
73 | model="TheBloke/Llama-2-7b-Chat-AWQ",
74 | dtype=Dtype.AUTO,
75 | quantization="awq",
76 | gpu_memory_reserved=13000,
77 | enforce_eager=True,
78 | default_sampling_params=SamplingParams(
79 | temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024
80 | ),
81 | chat_template="llama2",
82 | ).model_dump(mode="json"),
83 | )
84 |
85 | aana_app = AanaSDK(name="llm_app")
86 | aana_app.register_deployment(name="llm_deployment", instance=llm_deployment)
87 |
88 | if __name__ == "__main__":
89 | aana_app.connect()
90 | aana_app.migrate()
91 | aana_app.deploy()
92 | ```
93 |
94 | You can also use the example project `llama2` to deploy Llama-2-7b Chat model.
95 |
96 | ```bash
97 | CUDA_VISIBLE_DEVICES=0 aana deploy aana.projects.llama2.app:aana_app
98 | ```
99 |
--------------------------------------------------------------------------------
/aana/tests/units/test_app.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101, S113
2 | import json
3 | from typing import TypedDict
4 |
5 | import requests
6 | from ray import serve
7 |
8 | from aana.api.api_generation import Endpoint
9 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle
10 | from aana.deployments.base_deployment import BaseDeployment
11 |
12 |
@serve.deployment
class Lowercase(BaseDeployment):
    """Ray deployment that returns the lowercase version of a text."""

    async def lower(self, text: str) -> dict:
        """Lowercase the text.

        Args:
            text (str): The text to lowercase

        Returns:
            dict: The lowercase text
        """
        lowered = text.lower()
        return {"text": lowered}
27 |
28 |
class LowercaseEndpointOutput(TypedDict):
    """The output of the lowercase endpoint."""

    # The lowercased text returned by the deployment.
    text: str
33 |
34 |
class LowercaseEndpoint(Endpoint):
    """Lowercase endpoint."""

    async def initialize(self):
        """Initialize the endpoint by acquiring a handle to the lowercase deployment."""
        self.lowercase_handle = await AanaDeploymentHandle.create(
            "lowercase_deployment"
        )
        await super().initialize()

    async def run(self, text: str) -> LowercaseEndpointOutput:
        """Lowercase the text.

        Args:
            text (str): The text to lowercase

        Returns:
            LowercaseEndpointOutput: The lowercase text
        """
        lowercase_output = await self.lowercase_handle.lower(text=text)
        return {"text": lowercase_output["text"]}
56 |
57 |
# Deployments and endpoints registered by the `create_app` fixture below.
deployments = [
    {
        "name": "lowercase_deployment",
        "instance": Lowercase,
    }
]

endpoints = [
    {
        "name": "lowercase",
        "path": "/lowercase",
        "summary": "Lowercase text",
        "endpoint_cls": LowercaseEndpoint,
    }
]
73 |
74 |
def test_app(create_app):
    """Test the Ray Serve app."""
    aana_app = create_app(deployments, endpoints)

    base_url = f"http://localhost:{aana_app.port}"

    # The server must report readiness before we hit the endpoints.
    ready_response = requests.get(f"{base_url}/api/ready")
    assert ready_response.status_code == 200
    assert ready_response.json() == {"ready": True}

    # A regular request is lowercased.
    payload = {"text": "Hello World! This is a test."}
    response = requests.post(
        f"{base_url}/lowercase",
        data={"body": json.dumps(payload)},
    )
    assert response.status_code == 200
    assert response.json().get("text") == "hello world! this is a test."

    # Unknown fields must be rejected with a validation error.
    payload = {"text": "Hello World! This is a test.", "extra_field": "extra_value"}
    response = requests.post(
        f"{base_url}/lowercase",
        data={"body": json.dumps(payload)},
    )
    assert response.status_code == 422, response.text
104 |
--------------------------------------------------------------------------------
/aana/tests/db/datastore/test_caption_repo.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: S101
2 |
3 | import random
4 | import uuid
5 |
6 | import pytest
7 |
8 | from aana.core.models.captions import Caption
9 | from aana.exceptions.db import NotFoundException
10 | from aana.storage.repository.caption import CaptionRepository
11 |
12 |
@pytest.fixture(scope="function")
def dummy_caption():
    """Creates a dummy caption for testing."""
    caption_text = Caption(f"This is a caption {uuid.uuid4()}")
    frame = random.randint(0, 100)  # noqa: S311
    ts = random.random()  # noqa: S311
    return caption_text, frame, ts
20 |
21 |
@pytest.mark.asyncio
async def test_save_caption(db_session_manager, dummy_caption):
    """Tests saving a caption: save, read back, delete, verify gone."""
    async with db_session_manager.session() as session:
        caption, frame_id, timestamp = dummy_caption
        model_name = "blip2"

        caption_repo = CaptionRepository(session)
        caption_entity = await caption_repo.save(
            model_name=model_name,
            caption=caption,
            frame_id=frame_id,
            timestamp=timestamp,
        )
        caption_id = caption_entity.id

        # Read back and verify all persisted fields round-trip.
        caption_entity = await caption_repo.read(caption_id)
        assert caption_entity.model == model_name
        assert caption_entity.frame_id == frame_id
        assert caption_entity.timestamp == timestamp
        assert caption_entity.caption == caption

        # After deletion, reading the same id must fail.
        await caption_repo.delete(caption_id)
        with pytest.raises(NotFoundException):
            await caption_repo.read(caption_id)
47 |
48 |
@pytest.mark.asyncio
async def test_save_all_captions(db_session_manager, dummy_caption):
    """Tests saving all captions via save_all, then reading and deleting each."""
    async with db_session_manager.session() as session:
        captions, frame_ids, timestamps = [], [], []
        # NOTE(review): `dummy_caption` is a single fixture value, so all
        # three appended entries are identical — consider generating
        # distinct captions if uniqueness matters for this test.
        for _ in range(3):
            caption, frame_id, timestamp = dummy_caption
            captions.append(caption)
            frame_ids.append(frame_id)
            timestamps.append(timestamp)
        model_name = "blip2"

        caption_repo = CaptionRepository(session)
        caption_entities = await caption_repo.save_all(
            model_name=model_name,
            captions=captions,
            timestamps=timestamps,
            frame_ids=frame_ids,
        )
        assert len(caption_entities) == len(captions)

        # Each saved entity must round-trip its fields.
        caption_ids = [caption_entity.id for caption_entity in caption_entities]
        for caption_id, caption, frame_id, timestamp in zip(
            caption_ids, captions, frame_ids, timestamps, strict=True
        ):
            caption_entity = await caption_repo.read(caption_id)

            assert caption_entity.model == model_name
            assert caption_entity.frame_id == frame_id
            assert caption_entity.timestamp == timestamp
            assert caption_entity.caption == caption

        # delete all captions; reading after delete must fail
        for caption_id in caption_ids:
            await caption_repo.delete(caption_id)
            with pytest.raises(NotFoundException):
                await caption_repo.read(caption_id)
86 |
--------------------------------------------------------------------------------
/aana/exceptions/api_service.py:
--------------------------------------------------------------------------------
1 | from aana.exceptions.core import BaseException
2 |
3 |
class ApiKeyNotProvided(BaseException):
    """Exception raised when the API key is not provided."""

    def __init__(self):
        """Initialize the exception."""
        message = "API key not provided"
        self.message = message
        super().__init__(message=message)

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, ())
15 |
16 |
class ApiKeyNotFound(BaseException):
    """Exception raised when the API key is not found.

    Attributes:
        key (str): the API key that was not found
    """

    def __init__(self, key: str):
        """Initialize the exception.

        Args:
            key (str): the API key that was not found
        """
        message = f"API key {key} not found"
        self.key = key
        self.message = message
        super().__init__(key=key, message=message)

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, (self.key,))
37 |
38 |
class InactiveSubscription(BaseException):
    """Exception raised when the subscription is inactive (e.g. credits are not available).

    Attributes:
        key (str): the API key with inactive subscription
    """

    def __init__(self, key: str):
        """Initialize the exception.

        Args:
            key (str): the API key with inactive subscription
        """
        message = f"API key {key} has an inactive subscription. Check your credits."
        self.key = key
        self.message = message
        super().__init__(key=key, message=message)

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, (self.key,))
61 |
62 |
class AdminOnlyAccess(BaseException):
    """Exception raised when the user does not have enough permissions."""

    def __init__(self):
        """Initialize the exception."""
        message = "Admin only access"
        self.message = message
        super().__init__(message=message)

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, ())
74 |
75 |
class ApiKeyValidationFailed(BaseException):
    """Exception raised when the API key validation fails."""

    def __init__(self):
        """Initialize the exception."""
        message = "API key validation failed"
        self.message = message
        super().__init__(message=message)

    def __reduce__(self):
        """Used for pickling."""
        return (self.__class__, ())
87 |
88 |
class ApiKeyExpired(BaseException):
    """Raised when the provided API key is past its expiration date.

    Attributes:
        key (str): the expired API key
    """

    def __init__(self, key: str):
        """Create the exception for an expired API key.

        Args:
            key (str): the expired API key
        """
        message = f"API key {key} is expired."
        self.key = key
        self.message = message
        super().__init__(key=key, message=message)

    def __reduce__(self):
        """Make the exception picklable by rebuilding it from the key."""
        return (type(self), (self.key,))
105 |
--------------------------------------------------------------------------------
/aana/storage/models/transcript.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations # Let classes use themselves in type annotations
2 |
3 | from typing import TYPE_CHECKING
4 |
5 | from sqlalchemy import JSON, CheckConstraint
6 | from sqlalchemy.orm import Mapped, mapped_column
7 |
8 | from aana.storage.models.base import BaseEntity, TimeStampEntity
9 |
10 | if TYPE_CHECKING:
11 | from aana.core.models.asr import (
12 | AsrSegments,
13 | AsrTranscription,
14 | AsrTranscriptionInfo,
15 | )
16 |
17 |
class TranscriptEntity(BaseEntity, TimeStampEntity):
    """ORM class for media transcripts generated by a model.

    Attributes:
        id (int): Unique identifier for the transcript.
        model (str): Name of the model used to generate the transcript.
        transcript (str): Full text transcript of the media.
        segments (dict): Segments of the transcript.
        language (str): Language of the transcript as predicted by the model.
        language_confidence (float): Confidence score of language prediction.
        transcript_type (str): The type of transcript (populated automatically by ORM based on `polymorphic_identity` of subclass).
    """

    __tablename__ = "transcript"

    id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True)
    model: Mapped[str] = mapped_column(
        nullable=False, comment="Name of model used to generate transcript"
    )
    transcript: Mapped[str] = mapped_column(comment="Full text transcript of media")
    segments: Mapped[dict] = mapped_column(JSON, comment="Segments of the transcript")
    language: Mapped[str] = mapped_column(
        comment="Language of the transcript as predicted by model"
    )
    language_confidence: Mapped[float] = mapped_column(
        # BUG FIX: the previous SQL used a chained comparison
        # ("0 <= language_confidence <= 1"), which is not valid SQL — PostgreSQL
        # rejects it, and SQLite parses it as "(0 <= x) <= 1", which is always
        # true, so the range was never enforced. Spell the bounds out with AND.
        # The constraint name is kept so existing migrations stay consistent.
        CheckConstraint(
            "language_confidence >= 0 AND language_confidence <= 1",
            name="language_confidence_value_range",
        ),
        comment="Confidence score of language prediction",
    )
    transcript_type: Mapped[str] = mapped_column(comment="The type of transcript")

    __mapper_args__ = {  # noqa: RUF012
        "polymorphic_identity": "transcript",
        "polymorphic_on": "transcript_type",
    }

    @classmethod
    def from_asr_output(
        cls,
        model_name: str,
        info: AsrTranscriptionInfo,
        transcription: AsrTranscription,
        segments: AsrSegments,
    ) -> TranscriptEntity:
        """Converts an AsrTranscriptionInfo and AsrTranscription to a single Transcript entity.

        Args:
            model_name (str): Name of the model used to generate the transcript.
            info (AsrTranscriptionInfo): Information about the transcription.
            transcription (AsrTranscription): The full transcription.
            segments (AsrSegments): Segments of the transcription.

        Returns:
            TranscriptEntity: A new instance of the TranscriptEntity class.
        """
        # BUG FIX: construct via `cls` rather than hard-coding TranscriptEntity,
        # so polymorphic subclasses (supported via __mapper_args__) that call
        # this factory get instances of their own type and identity.
        return cls(
            model=model_name,
            language=info.language,
            language_confidence=info.language_confidence,
            transcript=transcription.text,
            segments=[segment.model_dump() for segment in segments],
        )
81 |
--------------------------------------------------------------------------------