├── aana ├── __init__.py ├── core │ ├── __init__.py │ ├── models │ │ ├── __init__.py │ │ ├── file.py │ │ ├── captions.py │ │ ├── time.py │ │ ├── exception.py │ │ ├── custom_config.py │ │ ├── api_service.py │ │ ├── types.py │ │ ├── speaker.py │ │ ├── task.py │ │ ├── vad.py │ │ └── api.py │ ├── chat │ │ ├── templates │ │ │ ├── __init__.py │ │ │ └── llama2.jinja │ │ └── chat_template.py │ └── libraries │ │ ├── audio.py │ │ └── image.py ├── tests │ ├── __init__.py │ ├── units │ │ ├── __init__.py │ │ ├── test_question.py │ │ ├── test_typed_dict.py │ │ ├── test_media_id.py │ │ ├── test_rate_limiter.py │ │ ├── test_deployment_retry.py │ │ ├── test_app_deploy.py │ │ ├── test_merge_options.py │ │ ├── test_whisper_params.py │ │ ├── test_sampling_params.py │ │ ├── test_deployment_restart.py │ │ ├── test_settings.py │ │ ├── test_event_manager.py │ │ ├── test_speaker.py │ │ ├── test_app_upload.py │ │ └── test_app.py │ ├── projects │ │ ├── __init__.py │ │ └── lowercase │ │ │ ├── __init__.py │ │ │ └── app.py │ ├── integration │ │ └── __init__.py │ ├── const.py │ ├── files │ │ ├── expected │ │ │ ├── vad │ │ │ │ └── squirrel.json │ │ │ ├── hf_blip2 │ │ │ │ └── blip2_deployment_Starry_Night.jpeg.json │ │ │ ├── idefics │ │ │ │ └── idefics_2_8b_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ ├── image_text_generation │ │ │ │ ├── phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ ├── pixtral_12b_2409_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ ├── qwen2_vl_7b_instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ ├── qwen2_vl_3b_gemlite_onthefly_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ │ └── 
qwen2_vl_3b_gemlite_onthefly_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json │ │ │ ├── hf_pipeline │ │ │ │ └── hf_pipeline_blip2_deployment_Starry_Night.jpeg.json │ │ │ ├── text_generation │ │ │ │ ├── phi-3.5-vision-instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── phi3_mini_4k_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── phi3_mini_4k_instruct_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── qwen2_vl_7b_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── pixtral_12b_2409_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── qwen1.5_1.8b_chat_vllm_deployment_with_lora_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── gemma_3_1b_it_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ ├── qwen2_vl_3b_gemlite_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ └── qwen2_vl_3b_gemlite_onthefly_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ ├── whisper │ │ │ │ ├── whisper_medium_squirrel.wav.json │ │ │ │ ├── whisper_tiny_squirrel.wav.json │ │ │ │ ├── whisper_turbo_squirrel.wav.json │ │ │ │ ├── whisper_medium_squirrel.wav_batched.json │ │ │ │ ├── whisper_tiny_squirrel.wav_batched.json │ │ │ │ └── whisper_turbo_squirrel.wav_batched.json │ │ │ ├── hqq_generation │ │ │ │ ├── meta-llama │ │ │ │ │ └── Meta-Llama-3.1-8B-Instruct_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ │ └── mobiuslabsgmbh │ │ │ │ │ └── Llama-3.1-8b-instruct_4bitgs64_hqq_calib_2f8c1d10dab7f75e25cc2d4f29c54469.json │ │ │ └── sd │ │ │ │ └── sd_sample.json │ │ ├── audios │ │ │ ├── squirrel.wav │ │ │ ├── physicsworks.wav │ │ │ └── ATTRIBUTION.md │ │ ├── videos │ │ │ ├── squirrel.mp4 │ │ │ ├── physicsworks.webm │ │ │ ├── squirrel_no_audio.mp4 │ │ │ ├── physicsworks_audio.webm │ │ │ └── ATTRIBUTION.md │ │ └── images │ │ │ ├── Starry_Night.jpeg │ │ │ └── ATTRIBUTION.md │ ├── db │ │ └── datastore │ │ │ ├── test_config.py │ 
│ │ ├── test_video_repo.py │ │ │ ├── test_transcript_repo.py │ │ │ └── test_caption_repo.py │ └── deployments │ │ ├── test_vad_deployment.py │ │ ├── test_pyannote_speaker_diarization_deployment.py │ │ ├── test_hf_blip2_deployment.py │ │ └── test_hf_pipeline_deployment.py ├── utils │ ├── __init__.py │ ├── openapi_code_templates │ │ ├── __init__.py │ │ ├── curl_form.j2 │ │ ├── python_form.j2 │ │ └── python_form_streaming.j2 │ ├── gpu.py │ ├── streamer.py │ ├── file.py │ ├── typing.py │ ├── asyncio.py │ ├── lazy_import.py │ └── json.py ├── alembic │ ├── __init__.py │ ├── README │ ├── script.py.mako │ └── versions │ │ ├── b9860676dd49_set_server_default_for_task_completed_.py │ │ ├── d40eba8ebc4c_added_user_id_to_tasks.py │ │ └── acb40dabc2c0_added_webhooks.py ├── processors │ ├── __init__.py │ ├── remote.py │ └── video.py ├── routers │ └── __init__.py ├── storage │ ├── __init__.py │ ├── repository │ │ ├── __init__.py │ │ ├── video.py │ │ ├── transcript.py │ │ ├── caption.py │ │ └── media.py │ ├── session.py │ ├── models │ │ ├── media.py │ │ ├── video.py │ │ ├── __init__.py │ │ ├── webhook.py │ │ ├── caption.py │ │ ├── api_key.py │ │ ├── task.py │ │ ├── base.py │ │ └── transcript.py │ └── types.py ├── integrations │ ├── __init__.py │ ├── external │ │ └── __init__.py │ └── haystack │ │ ├── __init__.py │ │ └── remote_haystack_component.py ├── api │ ├── __init__.py │ ├── event_handlers │ │ ├── event_handler.py │ │ ├── rate_limit_handler.py │ │ └── event_manager.py │ ├── responses.py │ ├── app.py │ └── security.py ├── exceptions │ ├── __init__.py │ ├── core.py │ ├── db.py │ └── api_service.py ├── deployments │ ├── __init__.py │ └── haystack_component_deployment.py └── configs │ ├── __init__.py │ └── db.py ├── docs ├── reference │ ├── sdk.md │ ├── settings.md │ ├── endpoint.md │ ├── storage │ │ ├── models.md │ │ └── repositories.md │ ├── models │ │ ├── asr.md │ │ ├── chat.md │ │ ├── time.md │ │ ├── vad.md │ │ ├── speaker.md │ │ ├── types.md │ │ ├── video.md │ │ ├── 
captions.md │ │ ├── custom_config.md │ │ ├── sampling.md │ │ ├── whisper.md │ │ ├── image_chat.md │ │ ├── multimodal_chat.md │ │ └── media.md │ ├── utils.md │ ├── exceptions.md │ ├── processors.md │ ├── integrations.md │ └── index.md ├── images │ ├── favicon.ico │ ├── white_logo.png │ ├── AanaSDK_logo_dark_theme.png │ └── AanaSDK_logo_light_theme.png ├── pages │ ├── model_hub │ │ ├── vad.md │ │ ├── hf_pipeline.md │ │ ├── index.md │ │ └── speaker_recognition.md │ ├── dev_environment.md │ ├── code_standards.md │ ├── settings.md │ ├── serve_config_files.md │ ├── docker.md │ ├── code_overview.md │ └── openai_api.md └── stylesheets │ └── extra.css ├── .env ├── .vscode ├── extensions.json └── settings.json ├── mypy.ini ├── test_volume.dstack.yml ├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .github ├── ISSUE_TEMPLATE │ ├── enhancement.md │ ├── question.md │ ├── testing.md │ ├── feature_request.md │ ├── documentation.md │ └── bug_report.md └── workflows │ ├── publish_docs.yml │ ├── run_tests_with_gpu.yml │ ├── tests.yml │ └── publish.yml ├── pg_ctl.sh └── tests.dstack.yml /aana/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/core/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/alembic/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /aana/core/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/processors/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/routers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/storage/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/tests/units/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/integrations/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/tests/projects/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/core/chat/templates/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/integrations/external/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/storage/repository/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /aana/tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/tests/projects/lowercase/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/utils/openapi_code_templates/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aana/alembic/README: -------------------------------------------------------------------------------- 1 | Generic single-database configuration. -------------------------------------------------------------------------------- /docs/reference/sdk.md: -------------------------------------------------------------------------------- 1 | # SDK 2 | 3 | ::: aana.sdk.AanaSDK -------------------------------------------------------------------------------- /aana/tests/const.py: -------------------------------------------------------------------------------- 1 | ALLOWED_LEVENSTEIN_ERROR_RATE = 0.15 2 | -------------------------------------------------------------------------------- /docs/reference/settings.md: -------------------------------------------------------------------------------- 1 | # Settings 2 | 3 | ::: aana.configs -------------------------------------------------------------------------------- /docs/reference/endpoint.md: -------------------------------------------------------------------------------- 1 | # Endpoint 2 | 3 | ::: aana.api.Endpoint -------------------------------------------------------------------------------- /aana/tests/files/expected/vad/squirrel.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"segments": [] 3 | } -------------------------------------------------------------------------------- /docs/reference/storage/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | ::: aana.storage.models -------------------------------------------------------------------------------- /docs/reference/models/asr.md: -------------------------------------------------------------------------------- 1 | # ASR Models 2 | 3 | ::: aana.core.models.asr -------------------------------------------------------------------------------- /.env: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES="" 2 | HF_HUB_ENABLE_HF_TRANSFER = 1 3 | HF_TOKEN="" 4 | -------------------------------------------------------------------------------- /docs/reference/models/chat.md: -------------------------------------------------------------------------------- 1 | # Chat Models 2 | 3 | ::: aana.core.models.chat 4 | -------------------------------------------------------------------------------- /docs/reference/models/time.md: -------------------------------------------------------------------------------- 1 | # Time Models 2 | 3 | ::: aana.core.models.time 4 | -------------------------------------------------------------------------------- /docs/reference/models/vad.md: -------------------------------------------------------------------------------- 1 | # VAD Models 2 | 3 | ::: aana.core.models.vad 4 | -------------------------------------------------------------------------------- /docs/reference/models/speaker.md: -------------------------------------------------------------------------------- 1 | # Speaker Models 2 | 3 | ::: aana.core.models.speaker -------------------------------------------------------------------------------- /docs/reference/models/types.md: -------------------------------------------------------------------------------- 1 | # Types Models 2 | 3 | ::: 
aana.core.models.types 4 | -------------------------------------------------------------------------------- /docs/reference/models/video.md: -------------------------------------------------------------------------------- 1 | # Video Models 2 | 3 | ::: aana.core.models.video 4 | -------------------------------------------------------------------------------- /docs/reference/models/captions.md: -------------------------------------------------------------------------------- 1 | # Caption Models 2 | 3 | ::: aana.core.models.captions 4 | -------------------------------------------------------------------------------- /docs/images/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/favicon.ico -------------------------------------------------------------------------------- /docs/reference/models/custom_config.md: -------------------------------------------------------------------------------- 1 | # Custom Config 2 | 3 | ::: aana.core.models.custom_config -------------------------------------------------------------------------------- /docs/reference/models/sampling.md: -------------------------------------------------------------------------------- 1 | # Sampling Models 2 | 3 | ::: aana.core.models.sampling 4 | -------------------------------------------------------------------------------- /docs/reference/models/whisper.md: -------------------------------------------------------------------------------- 1 | # Whisper Models 2 | 3 | ::: aana.core.models.whisper 4 | 5 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "charliermarsh.ruff" 4 | ] 5 | } -------------------------------------------------------------------------------- /aana/api/__init__.py: 
-------------------------------------------------------------------------------- 1 | from aana.api.api_generation import Endpoint 2 | 3 | __all__ = ["Endpoint"] 4 | -------------------------------------------------------------------------------- /docs/reference/models/image_chat.md: -------------------------------------------------------------------------------- 1 | # Image Chat Models 2 | 3 | ::: aana.core.models.image_chat 4 | -------------------------------------------------------------------------------- /docs/images/white_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/white_logo.png -------------------------------------------------------------------------------- /aana/tests/files/expected/hf_blip2/blip2_deployment_Starry_Night.jpeg.json: -------------------------------------------------------------------------------- 1 | "the starry night by vincent van gogh" -------------------------------------------------------------------------------- /aana/exceptions/__init__.py: -------------------------------------------------------------------------------- 1 | from aana.exceptions.core import BaseException 2 | 3 | __all__ = ["BaseException"] 4 | -------------------------------------------------------------------------------- /docs/pages/model_hub/vad.md: -------------------------------------------------------------------------------- 1 | # Voice Activity Detection (VAD) Models 2 | 3 | #TODO: Make VAD deployment more generic -------------------------------------------------------------------------------- /docs/reference/models/multimodal_chat.md: -------------------------------------------------------------------------------- 1 | # Multimodal Chat Models 2 | 3 | ::: aana.core.models.multimodal_chat 4 | -------------------------------------------------------------------------------- /aana/tests/files/audios/squirrel.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/audios/squirrel.wav -------------------------------------------------------------------------------- /aana/tests/files/expected/idefics/idefics_2_8b_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "Van gogh." -------------------------------------------------------------------------------- /aana/tests/files/videos/squirrel.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/squirrel.mp4 -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | ignore_missing_imports = False 3 | allow_redefinition = True 4 | mypy_path = $MYPY_CONFIG_FILE_DIR/aana 5 | -------------------------------------------------------------------------------- /aana/tests/files/audios/physicsworks.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/audios/physicsworks.wav -------------------------------------------------------------------------------- /docs/images/AanaSDK_logo_dark_theme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/AanaSDK_logo_dark_theme.png -------------------------------------------------------------------------------- /docs/images/AanaSDK_logo_light_theme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/docs/images/AanaSDK_logo_light_theme.png 
-------------------------------------------------------------------------------- /test_volume.dstack.yml: -------------------------------------------------------------------------------- 1 | type: volume 2 | 3 | name: test-models-cache 4 | 5 | backend: runpod 6 | region: EU-SE-1 7 | 8 | size: 150GB -------------------------------------------------------------------------------- /aana/tests/files/images/Starry_Night.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/images/Starry_Night.jpeg -------------------------------------------------------------------------------- /aana/tests/files/videos/physicsworks.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/physicsworks.webm -------------------------------------------------------------------------------- /aana/tests/files/videos/squirrel_no_audio.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/squirrel_no_audio.mp4 -------------------------------------------------------------------------------- /aana/tests/files/videos/physicsworks_audio.webm: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dropbox/aana_sdk/main/aana/tests/files/videos/physicsworks_audio.webm -------------------------------------------------------------------------------- /docs/reference/utils.md: -------------------------------------------------------------------------------- 1 | # Utility functions 2 | 3 | ::: aana.utils.asyncio 4 | ::: aana.utils.json 5 | ::: aana.utils.download 6 | ::: aana.utils.gpu -------------------------------------------------------------------------------- /docs/reference/exceptions.md: 
-------------------------------------------------------------------------------- 1 | # Exceptions 2 | 3 | ::: aana.exceptions 4 | ::: aana.exceptions.runtime 5 | ::: aana.exceptions.io 6 | ::: aana.exceptions.db -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/phi-3.5-vision-instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | " Vincent van Gogh" -------------------------------------------------------------------------------- /aana/integrations/haystack/__init__.py: -------------------------------------------------------------------------------- 1 | from aana.integrations.haystack.deployment_component import AanaDeploymentComponent 2 | 3 | __all__ = ["AanaDeploymentComponent"] 4 | -------------------------------------------------------------------------------- /aana/tests/files/expected/hf_pipeline/hf_pipeline_blip2_deployment_Starry_Night.jpeg.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "generated_text": "the starry night by van gogh\n" 4 | } 5 | ] -------------------------------------------------------------------------------- /docs/reference/models/media.md: -------------------------------------------------------------------------------- 1 | # Media Models 2 | 3 | The `aana.core.models` provides models for such media types as audio, video, and images. 4 | 5 | ::: aana.core.models -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/pixtral_12b_2409_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent Van Gogh." 
-------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent van Gogh." -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/qwen2_vl_7b_instruct_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent van Gogh." -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent van Gogh." -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent van Gogh." 
-------------------------------------------------------------------------------- /docs/reference/processors.md: -------------------------------------------------------------------------------- 1 | # Processors 2 | 3 | ::: aana.processors.remote 4 | ::: aana.processors.video 5 | ::: aana.processors.batch 6 | ::: aana.processors.speaker.PostProcessingForDiarizedAsr -------------------------------------------------------------------------------- /aana/tests/files/expected/image_text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_video_Starry_Night.jpeg_394ead9926577785413d0c748ebf9878.json: -------------------------------------------------------------------------------- 1 | "The painter of the image is Vincent van Gogh." -------------------------------------------------------------------------------- /aana/core/models/file.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import TypedDict 3 | 4 | 5 | class PathResult(TypedDict): 6 | """Represents a path result describing a file on disk.""" 7 | 8 | path: Path 9 | -------------------------------------------------------------------------------- /aana/utils/openapi_code_templates/curl_form.j2: -------------------------------------------------------------------------------- 1 | curl {{ base_url }}{{ path }} \ 2 | --request POST \ 3 | --header 'Content-Type: application/x-www-form-urlencoded' \ 4 | --header 'x-api-key: YOUR_API_KEY_HERE' \ 5 | --data 'body={{ body }}' -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/phi-3.5-vision-instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | " Elon Musk is a business magnate, industrial designer, and engineer. 
He is the CEO and lead designer of SpaceX, CEO and product" -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/phi3_mini_4k_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | " Elon Musk is a business magnate, industrial designer, and engineer. He is the founder, CEO, CTO, and chief designer of Space" -------------------------------------------------------------------------------- /docs/reference/integrations.md: -------------------------------------------------------------------------------- 1 | # Integrations 2 | 3 | ::: aana.integrations.haystack 4 | ::: aana.integrations.external.av 5 | ::: aana.integrations.external.decord 6 | ::: aana.integrations.external.opencv 7 | ::: aana.integrations.external.yt_dlp -------------------------------------------------------------------------------- /aana/deployments/__init__.py: -------------------------------------------------------------------------------- 1 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 2 | from aana.deployments.base_deployment import BaseDeployment 3 | 4 | __all__ = [ 5 | "AanaDeploymentHandle", 6 | "BaseDeployment", 7 | ] 8 | -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/phi3_mini_4k_instruct_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a business magnate, industrial designer, and engineer. 
He is the founder, CEO, CTO, and chief designer of Space" -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/qwen2_vl_7b_instruct_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is the CEO of SpaceX, the founder of Tesla, Inc., and" -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_medium_squirrel.wav.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription_info": { 4 | "language": "silence", 5 | "language_confidence": 1.0 6 | }, 7 | "transcription": { 8 | "text": "" 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_tiny_squirrel.wav.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription_info": { 4 | "language": "silence", 5 | "language_confidence": 1.0 6 | }, 7 | "transcription": { 8 | "text": "" 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/hqq_generation/meta-llama/Meta-Llama-3.1-8B-Instruct_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. 
He is best known for his ambitious ventures in the fields of space exploration," -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/pixtral_12b_2409_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a South African-born Canadian-American business magnate, industrial designer, and engineer. He is known for his work in the fields of electric vehicles" -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/qwen1.5_1.8b_chat_vllm_deployment_with_lora_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a South African-American entrepreneur, investor, engineer, and inventor who co-founded SpaceX, Tesla, Neuralink, The Boring Company, and" -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04 2 | RUN apt-get update && apt-get install -y libgl1 libglib2.0-0 ffmpeg locales 3 | 4 | # Set the locale 5 | RUN locale-gen en_US.UTF-8 6 | ENV LANG="en_US.UTF-8" LANGUAGE="en_US:en" LC_ALL="en_US.UTF-8" 7 | -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/gemma_3_1b_it_hf_text_generation_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Okay, let's break down who Elon Musk is. 
He's a truly fascinating and incredibly influential figure, often described as a disruptive innovator and a visionary" -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_turbo_squirrel.wav.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription_info": { 4 | "language": "silence", 5 | "language_confidence": 1.0 6 | }, 7 | "transcription": { 8 | "text": "" 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/qwen2_vl_3b_gemlite_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is an American entrepreneur, investor, and inventor. He is the founder and CEO of SpaceX, a private aerospace manufacturer and space transportation company. He" -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_medium_squirrel.wav_batched.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription": { 4 | "text": "" 5 | }, 6 | "transcription_info": { 7 | "language": "silence", 8 | "language_confidence": 1.0 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_tiny_squirrel.wav_batched.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription": { 4 | "text": "" 5 | }, 6 | "transcription_info": { 7 | "language": "silence", 8 | "language_confidence": 1.0 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/whisper/whisper_turbo_squirrel.wav_batched.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "segments": [], 3 | "transcription": { 4 | "text": "" 5 | }, 6 | "transcription_info": { 7 | "language": "silence", 8 | "language_confidence": 1.0 9 | } 10 | } -------------------------------------------------------------------------------- /aana/tests/files/expected/hqq_generation/mobiuslabsgmbh/Llama-3.1-8b-instruct_4bitgs64_hqq_calib_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is a South African-born entrepreneur, inventor, and business magnate. He is best known for his ambitious ventures in the fields of space exploration," -------------------------------------------------------------------------------- /aana/tests/files/expected/text_generation/qwen2_vl_3b_gemlite_onthefly_vllm_deployment_2f8c1d10dab7f75e25cc2d4f29c54469.json: -------------------------------------------------------------------------------- 1 | "Elon Musk is an American entrepreneur, investor, and engineer who has made significant contributions to the fields of space exploration, electric vehicles, and renewable energy. He" -------------------------------------------------------------------------------- /aana/tests/units/test_question.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | from aana.core.models.chat import Question 4 | 5 | 6 | def test_question_creation(): 7 | """Test that a question can be created.""" 8 | question = Question("What is the capital of France?") 9 | assert question == "What is the capital of France?" 
10 | -------------------------------------------------------------------------------- /docs/stylesheets/extra.css: -------------------------------------------------------------------------------- 1 | :root { 2 | --md-primary-fg-color: #A66CFF; 3 | --md-primary-fg-color--light: #CFF500; 4 | --md-primary-fg-color--dark: #3E3E3E; 5 | --md-primary-bg-color: #F0F0F0; 6 | --md-primary-bg-color--light: #FFFFFF; 7 | } 8 | 9 | /* .md-header__topic { 10 | display: none; 11 | } */ -------------------------------------------------------------------------------- /docs/pages/dev_environment.md: -------------------------------------------------------------------------------- 1 | # Dev Environment 2 | 3 | If you are using Visual Studio Code, you can run this repository in a 4 | [dev container](https://code.visualstudio.com/docs/devcontainers/containers). This lets you install and 5 | run everything you need for the repo in an isolated environment via docker on a host system. 6 | 7 | -------------------------------------------------------------------------------- /docs/reference/storage/repositories.md: -------------------------------------------------------------------------------- 1 | # Repositories 2 | 3 | ::: aana.storage.repository.base 4 | ::: aana.storage.repository.media 5 | ::: aana.storage.repository.video 6 | ::: aana.storage.repository.caption 7 | ::: aana.storage.repository.transcript 8 | ::: aana.storage.repository.task 9 | ::: aana.storage.repository.webhook 10 | -------------------------------------------------------------------------------- /aana/configs/__init__.py: -------------------------------------------------------------------------------- 1 | from aana.configs.db import DbSettings, DbType, PostgreSQLConfig, SQLiteConfig 2 | from aana.configs.settings import Settings, TaskQueueSettings, TestSettings 3 | 4 | __all__ = [ 5 | "DbSettings", 6 | "DbType", 7 | "PostgreSQLConfig", 8 | "SQLiteConfig", 9 | "Settings", 10 | "TaskQueueSettings", 11 | "TestSettings", 12 
| ] 13 | -------------------------------------------------------------------------------- /aana/core/models/captions.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from pydantic import Field 4 | 5 | __all__ = ["Caption", "CaptionsList"] 6 | 7 | Caption = Annotated[str, Field(description="A caption.")] 8 | """ 9 | A caption. 10 | """ 11 | 12 | CaptionsList = Annotated[list[Caption], Field(description="A list of captions.")] 13 | """ 14 | A list of captions. 15 | """ 16 | -------------------------------------------------------------------------------- /aana/utils/gpu.py: -------------------------------------------------------------------------------- 1 | __all__ = ["get_gpu_memory"] 2 | 3 | 4 | def get_gpu_memory(gpu: int = 0) -> int: 5 | """Get the total memory of a GPU in bytes. 6 | 7 | Args: 8 | gpu (int): the GPU index. Defaults to 0. 9 | 10 | Returns: 11 | int: the total memory of the GPU in bytes 12 | """ 13 | import torch 14 | 15 | return torch.cuda.get_device_properties(gpu).total_memory 16 | -------------------------------------------------------------------------------- /aana/utils/openapi_code_templates/python_form.j2: -------------------------------------------------------------------------------- 1 | import json, requests 2 | 3 | payload = {{ body }} 4 | 5 | headers = { 6 | "Content-Type": "application/x-www-form-urlencoded", 7 | "x-api-key": "YOUR_API_KEY_HERE" 8 | } 9 | 10 | response = requests.post("{{ base_url }}{{ path }}", 11 | headers=headers, 12 | data={"body": json.dumps(payload)} 13 | ) 14 | 15 | data = response.json() 16 | print(data) -------------------------------------------------------------------------------- /aana/utils/streamer.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from queue import Empty 3 | 4 | 5 | async def async_streamer_adapter(streamer): 6 | """Adapt the TextIteratorStreamer to an 
async generator.""" 7 | while True: 8 | try: 9 | for item in streamer: 10 | yield item 11 | break 12 | except Empty: 13 | # wait for the next item 14 | await asyncio.sleep(0.01) 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/enhancement.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Enhancement 3 | about: Recommend improvement to existing features and suggest code quality improvement. 4 | title: "[ENHANCEMENT]" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Enhancement Description 11 | - Overview of the enhancement 12 | 13 | ### Advantages 14 | - Benefits of implementing this enhancement 15 | 16 | ### Possible Implementation 17 | - Suggested methods for implementing the enhancement 18 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Question 3 | about: Ask questions to clarify project-related queries, seek further information, 4 | or understand functionalities better. 5 | title: "[QUESTION]" 6 | labels: question 7 | assignees: '' 8 | 9 | --- 10 | 11 | ### Context 12 | - Background or context of the question 13 | 14 | ### Question 15 | - Specific question being asked 16 | 17 | ### What You've Tried 18 | - List any solutions or research already conducted 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/testing.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Testing 3 | about: Address needs for creating new tests, enhancing existing tests, or reporting 4 | test failures. 
5 | title: "[TESTS]" 6 | labels: tests 7 | assignees: '' 8 | 9 | --- 10 | 11 | ### Testing Requirement 12 | - Describe the testing need or issue 13 | 14 | ### Test Scenarios 15 | - Detail specific test scenarios to be addressed 16 | 17 | ### Acceptance Criteria 18 | - What are the criteria for the test to be considered successful? 19 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "notebook.formatOnSave.enabled": true, 3 | "[python]": { 4 | "editor.defaultFormatter": "charliermarsh.ruff", 5 | "editor.formatOnSave": true, 6 | }, 7 | "python.testing.pytestArgs": [ 8 | "aana" 9 | ], 10 | "python.testing.unittestEnabled": false, 11 | "python.testing.pytestEnabled": true, 12 | "python.testing.pytestPath": "poetry run pytest", 13 | "ruff.fixAll": true, 14 | "ruff.organizeImports": true, 15 | } -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest new functionalities or modifications to enhance the application's capabilities. 4 | title: "[FEATURE REQUEST]" 5 | labels: feature request 6 | assignees: '' 7 | 8 | --- 9 | 10 | ### Feature Summary 11 | - Concise description of the feature 12 | 13 | ### Justification/Rationale 14 | - Why is the feature beneficial? 15 | 16 | ### Proposed Implementation (if any) 17 | - How do you envision this feature's implementation? 
18 | -------------------------------------------------------------------------------- /aana/utils/openapi_code_templates/python_form_streaming.j2: -------------------------------------------------------------------------------- 1 | import json, requests 2 | import ijson # pip install ijson 3 | 4 | payload = {{ body }} 5 | 6 | headers = { 7 | "Content-Type": "application/x-www-form-urlencoded", 8 | "x-api-key": "YOUR_API_KEY_HERE" 9 | } 10 | 11 | response = requests.post("{{ base_url }}{{ path }}", 12 | headers=headers, 13 | data={"body": json.dumps(payload)}, 14 | stream=True 15 | ) 16 | 17 | for obj in ijson.items(response.raw, '', multiple_values=True, buf_size=32): 18 | print(obj) -------------------------------------------------------------------------------- /pg_ctl.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Get the PostgreSQL major version (e.g., 14) 4 | PG_VERSION=$(psql --version | awk '{print $3}' | cut -d '.' -f 1) 5 | 6 | if [ "$2" = "--pgdata" ]; then 7 | # Recursively change parent directories permissions so that pg_ctl can access the necessary files as non-root. 8 | f=$(dirname "$3") 9 | while [[ $f != / ]]; do chmod o+rwx "$f"; f=$(dirname "$f"); done; 10 | fi 11 | 12 | # Execute pg_ctl as postgres instead of root. 13 | sudo -u postgres "/usr/lib/postgresql/$PG_VERSION/bin/pg_ctl" "$@" 14 | -------------------------------------------------------------------------------- /aana/api/event_handlers/event_handler.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class EventHandler(ABC): 5 | """Base class for event handlers. Not guaranteed to be thread safe.""" 6 | 7 | @abstractmethod 8 | def handle(self, event_name: str, *args, **kwargs): 9 | """Handles an event of the given name. 
10 | 11 | Arguments: 12 | event_name (str): name of the event to handle 13 | *args (list): specific, context-dependent args 14 | **kwargs (dict): specific, context-dependent args 15 | """ 16 | pass 17 | -------------------------------------------------------------------------------- /aana/tests/units/test_typed_dict.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | from typing import TypedDict 4 | 5 | from aana.utils.typing import is_typed_dict 6 | 7 | 8 | def test_is_typed_dict() -> None: 9 | """Test the is_typed_dict function.""" 10 | 11 | class Foo(TypedDict): 12 | foo: str 13 | bar: int 14 | 15 | assert is_typed_dict(Foo) == True 16 | 17 | 18 | def test_is_not_typed_dict(): 19 | """Test the is_typed_dict function.""" 20 | 21 | class Foo: 22 | pass 23 | 24 | assert is_typed_dict(Foo) == False 25 | assert is_typed_dict(dict) == False 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Documentation 3 | about: Propose updates or corrections to both internal development documentation and 4 | client-facing documentation. 
5 | title: "[DOCS]" 6 | labels: documentation 7 | assignees: '' 8 | 9 | --- 10 | 11 | ### Documentation Area (Development/Client) 12 | - Specify the area of documentation 13 | 14 | ### Current Content 15 | - Quote the current content or describe the issue 16 | 17 | ### Proposed Changes 18 | - Detail the proposed changes 19 | 20 | ### Reasons for Changes 21 | - Why these changes will improve the documentation 22 | -------------------------------------------------------------------------------- /aana/core/models/time.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, ConfigDict, Field 2 | 3 | __all__ = ["TimeInterval"] 4 | 5 | 6 | class TimeInterval(BaseModel): 7 | """Pydantic schema for TimeInterval. 8 | 9 | Attributes: 10 | start (float): Start time in seconds 11 | end (float): End time in seconds 12 | """ 13 | 14 | start: float = Field(ge=0.0, description="Start time in seconds") 15 | end: float = Field(ge=0.0, description="End time in seconds") 16 | model_config = ConfigDict( 17 | json_schema_extra={ 18 | "description": "Time interval in seconds", 19 | }, 20 | extra="forbid", 21 | ) 22 | -------------------------------------------------------------------------------- /aana/tests/files/images/ATTRIBUTION.md: -------------------------------------------------------------------------------- 1 | # Third-Party Media Attribution 2 | 3 | This folder contains third-party media used for testing. No endorsement implied. 
4 | 5 | _Last updated: 2025-08-22_ 6 | 7 | | File | Source | License/Status | Author/Owner | Changes Made | 8 | |---|---|---|---|---| 9 | | `Starry_Night.jpeg` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg (faithful reproduction; Google Arts & Culture — bgEuwDxel93-Pg) | Public domain (faithful reproduction of a 2D public-domain work) | Vincent van Gogh (original work); reproduction per Commons record | Downloaded and saved as JPEG for testing. | 10 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Report a malfunction, glitch, or error in the application. This includes any 4 | performance-related issues that may arise. 5 | title: "[BUG]" 6 | labels: bug 7 | assignees: '' 8 | 9 | --- 10 | 11 | ### Bug Description 12 | - Brief summary of the issue 13 | 14 | ### Steps to Reproduce 15 | 1. 16 | 2. 17 | 3. 18 | 19 | ### Expected Behavior 20 | - What should have happened? 21 | 22 | ### Actual Behavior 23 | - What actually happened? 24 | 25 | ### Performance Details (if applicable) 26 | - Specifics of the performance issue encountered 27 | 28 | ### Environment 29 | - Version (commit hash, tag, branch) 30 | -------------------------------------------------------------------------------- /aana/core/models/exception.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import BaseModel, ConfigDict 4 | 5 | 6 | class ExceptionResponseModel(BaseModel): 7 | """This class is used to represent an exception response for 400 errors. 8 | 9 | Attributes: 10 | error (str): The error that occurred. 11 | message (str): The message of the error. 12 | data (Optional[Any]): The extra data that helps to debug the error. 
13 | stacktrace (Optional[str]): The stacktrace of the error. 14 | """ 15 | 16 | error: str 17 | message: str 18 | data: Any | None = None 19 | stacktrace: str | None = None 20 | model_config = ConfigDict(extra="forbid") 21 | -------------------------------------------------------------------------------- /.github/workflows/publish_docs.yml: -------------------------------------------------------------------------------- 1 | name: Publish Docs 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | publish: 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - name: Checkout code 14 | uses: actions/checkout@v3 15 | 16 | - name: Set up Python 3.10 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.10" 20 | 21 | - name: Bootstrap uv 22 | uses: astral-sh/setup-uv@v6 23 | with: 24 | version: "latest" 25 | 26 | - name: Install dependencies 27 | run: uv sync --only-group docs 28 | 29 | - name: Deploy docs 30 | run: uv run mkdocs gh-deploy --force 31 | -------------------------------------------------------------------------------- /aana/processors/remote.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from collections.abc import Callable 3 | 4 | import ray 5 | 6 | __all__ = ["run_remote"] 7 | 8 | 9 | def run_remote(func: Callable) -> Callable: 10 | """Wrap a function to run it remotely on Ray. 
11 | 12 | Args: 13 | func (Callable): the function to wrap 14 | 15 | Returns: 16 | Callable: the wrapped function 17 | """ 18 | 19 | async def generator_wrapper(*args, **kwargs): 20 | async for item in ray.remote(func).remote(*args, **kwargs): 21 | yield await item 22 | 23 | if inspect.isgeneratorfunction(func): 24 | return generator_wrapper 25 | else: 26 | return ray.remote(func).remote 27 | -------------------------------------------------------------------------------- /docs/pages/code_standards.md: -------------------------------------------------------------------------------- 1 | # Code Standards 2 | 3 | This project uses Ruff for linting and formatting. If you want to 4 | manually run Ruff on the codebase, using uv it's 5 | 6 | ```sh 7 | uv run ruff check aana 8 | ``` 9 | 10 | You can automatically fix some issues with the `--fix` 11 | and `--unsafe-fixes` options. (Be sure to install the dev 12 | dependencies: `uv sync --group dev`. ) 13 | 14 | To run the auto-formatter, it's 15 | 16 | ```sh 17 | uv run ruff format aana 18 | ``` 19 | 20 | (If you are running code in a non-uv environment, just leave off `uv run`.) 21 | 22 | For users of VS Code, the included `settings.json` should ensure 23 | that Ruff problems appear while you edit, and formatting is applied 24 | automatically on save. 25 | -------------------------------------------------------------------------------- /aana/utils/file.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | from pathlib import Path 3 | 4 | 5 | def get_sha256_hash_file(filename: Path) -> str: 6 | """Compute SHA-256 hash of a file without loading it entirely in memory. 7 | 8 | Args: 9 | filename (Path): Path to the file to be hashed. 10 | 11 | Returns: 12 | str: SHA-256 hash of the file in hexadecimal format. 
13 | """ 14 | # Create a sha256 hash object 15 | sha256 = hashlib.sha256() 16 | 17 | # Open the file in binary mode 18 | with Path.open(filename, "rb") as f: 19 | # Read and update hash in chunks of 4K 20 | for chunk in iter(lambda: f.read(4096), b""): 21 | sha256.update(chunk) 22 | 23 | # Return the hexadecimal representation of the digest 24 | return sha256.hexdigest() 25 | -------------------------------------------------------------------------------- /aana/core/chat/templates/llama2.jinja: -------------------------------------------------------------------------------- 1 | {% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<>\\n' + system_message + '\\n<>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %} -------------------------------------------------------------------------------- /aana/tests/files/audios/ATTRIBUTION.md: -------------------------------------------------------------------------------- 1 | # Third-Party Media Attribution 2 | 3 | This folder contains third-party media used for testing. No endorsement implied. 
4 | 5 | _Last updated: 2025-08-22_ 6 | 7 | | File | Source | License | Author/Owner | Changes Made | 8 | |---|---|---|---|---| 9 | | `squirrel.wav` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Converted from source video to WAV for testing. | 10 | | `physicsworks.wav` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Extracted audio and transcoded to WAV for testing; trimmed for length. | 11 | -------------------------------------------------------------------------------- /aana/alembic/script.py.mako: -------------------------------------------------------------------------------- 1 | """${message} 2 | 3 | Revision ID: ${up_revision} 4 | Revises: ${down_revision | comma,n} 5 | Create Date: ${create_date} 6 | 7 | """ 8 | from typing import Sequence 9 | 10 | from alembic import op 11 | import sqlalchemy as sa 12 | ${imports if imports else ""} 13 | 14 | # revision identifiers, used by Alembic. 
15 | revision: str = ${repr(up_revision)} 16 | down_revision: str | None = ${repr(down_revision)} 17 | branch_labels: str | Sequence[str] | None = ${repr(branch_labels)} 18 | depends_on: str | Sequence[str] | None = ${repr(depends_on)} 19 | 20 | 21 | def upgrade() -> None: 22 | """Upgrade database to this revision from previous.""" 23 | ${upgrades if upgrades else "pass"} 24 | 25 | 26 | def downgrade() -> None: 27 | """Downgrade database from this revision to previous.""" 28 | ${downgrades if downgrades else "pass"} 29 | -------------------------------------------------------------------------------- /aana/processors/video.py: -------------------------------------------------------------------------------- 1 | from aana.core.models.audio import Audio 2 | from aana.core.models.video import Video 3 | from aana.integrations.external.av import load_audio 4 | 5 | __all__ = ["extract_audio"] 6 | 7 | 8 | def extract_audio(video: Video) -> Audio: 9 | """Extract the audio file from a Video and return an Audio object. 10 | 11 | Args: 12 | video (Video): The video file to extract audio. 13 | 14 | Returns: 15 | Audio: an Audio object containing the extracted audio. 
16 | """ 17 | audio_bytes = load_audio(video.path) 18 | 19 | # Only difference will be in path where WAV file will be stored 20 | # and in content but has same media_id 21 | return Audio( 22 | url=video.url, 23 | media_id=f"audio_{video.media_id}", 24 | content=audio_bytes, 25 | title=video.title, 26 | description=video.description, 27 | ) 28 | -------------------------------------------------------------------------------- /tests.dstack.yml: -------------------------------------------------------------------------------- 1 | type: task 2 | 3 | name: aana-tests 4 | 5 | backends: [runpod] 6 | 7 | image: nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04 8 | 9 | env: 10 | - HF_TOKEN 11 | 12 | commands: 13 | - apt-get update 14 | - DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends tzdata 15 | - apt-get install -y libgl1 libglib2.0-0 ffmpeg python3 python3-dev postgresql sudo 16 | - locale-gen en_US.UTF-8 17 | - export LANG="en_US.UTF-8" LANGUAGE="en_US:en" LC_ALL="en_US.UTF-8" 18 | - curl -LsSf https://astral.sh/uv/install.sh | sh 19 | - export PATH=$PATH:/root/.cargo/bin 20 | - uv sync --group tests --extra all 21 | - HF_HUB_CACHE="/models_cache" CUDA_VISIBLE_DEVICES="0" uv run pytest -vv -s 22 | 23 | volumes: 24 | - name: test-models-cache 25 | path: /models_cache 26 | 27 | max_price: 1.0 28 | 29 | resources: 30 | cpu: 9.. 31 | memory: 32GB.. 32 | gpu: 40GB.. -------------------------------------------------------------------------------- /aana/utils/typing.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from collections.abc import AsyncGenerator 3 | 4 | 5 | def is_typed_dict(argument: type) -> bool: 6 | """Checks if a argument is a TypedDict. 7 | 8 | Arguments: 9 | argument (type): the type to check 10 | 11 | Returns: 12 | bool: True if the argument type is a TypedDict anf False if it is not. 
13 | """ 14 | return bool( 15 | argument and getattr(argument, "__orig_bases__", None) == (typing.TypedDict,) 16 | ) 17 | 18 | 19 | def is_async_generator(argument: type) -> bool: 20 | """Checks if a argument is an AsyncGenerator. 21 | 22 | Arguments: 23 | argument (type): the type to check 24 | 25 | Returns: 26 | bool: True if the argument type is an AsyncGenerator and False if it is not. 27 | """ 28 | return hasattr(argument, "__origin__") and issubclass( 29 | argument.__origin__, AsyncGenerator 30 | ) 31 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Ubuntu", 3 | "build": { 4 | "dockerfile": "Dockerfile" 5 | }, 6 | "capAdd": [ 7 | "SYS_PTRACE" 8 | ], 9 | "features": { 10 | "ghcr.io/devcontainers/features/python:1": { 11 | "installTools": true, 12 | "version": "3.10" 13 | }, 14 | "ghcr.io/itsmechlark/features/postgresql:1": { 15 | "version": "16" 16 | }, 17 | "ghcr.io/va-h/devcontainers-features/uv:1": {} 18 | }, 19 | "hostRequirements": { 20 | "gpu": "optional" 21 | }, 22 | "securityOpt": [ 23 | "seccomp=unconfined" 24 | ], 25 | "postCreateCommand": "uv sync --all-groups --extra all --extra api_service", 26 | "postStartCommand": "git config --global --add safe.directory ${containerWorkspaceFolder}", 27 | "customizations": { 28 | "vscode": { 29 | "extensions": [ 30 | "charliermarsh.ruff", 31 | "ms-python.python", 32 | "ms-python.mypy-type-checker", 33 | "ms-toolsai.jupyter" 34 | ] 35 | } 36 | } 37 | } -------------------------------------------------------------------------------- /aana/storage/session.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Annotated 3 | 4 | from fastapi import Depends 5 | from sqlalchemy.ext.asyncio import AsyncSession 6 | 7 | from aana.configs.settings import settings 8 | from 
aana.storage.op import DatabaseSessionManager 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | __all__ = ["GetDbDependency", "get_db", "get_session"] 13 | 14 | session_manager = DatabaseSessionManager(settings) 15 | 16 | 17 | def get_session() -> AsyncSession: 18 | """Get a new SQLAlchemy Session object. 19 | 20 | Returns: 21 | AsyncSession: SQLAlchemy async session 22 | """ 23 | return session_manager.session() 24 | 25 | 26 | async def get_db(): 27 | """Get a database session. 28 | 29 | Returns: 30 | AsyncSession: SQLAlchemy async session 31 | """ 32 | async with session_manager.session() as session: 33 | yield session 34 | 35 | 36 | GetDbDependency = Annotated[AsyncSession, Depends(get_db)] 37 | """ Dependency to get a database session. """ 38 | -------------------------------------------------------------------------------- /aana/storage/models/media.py: -------------------------------------------------------------------------------- 1 | from uuid import uuid4 2 | 3 | from sqlalchemy.orm import Mapped, mapped_column 4 | 5 | from aana.core.models.media import MediaId 6 | from aana.storage.models.base import BaseEntity, TimeStampEntity 7 | 8 | 9 | class MediaEntity(BaseEntity, TimeStampEntity): 10 | """Base ORM class for media (e.g. videos, images, etc.). 11 | 12 | This class is meant to be subclassed by other media types. 13 | 14 | Attributes: 15 | id (MediaId): Unique identifier for the media. 16 | media_type (str): The type of media (populated automatically by ORM based on `polymorphic_identity` of subclass). 
17 | """ 18 | 19 | __tablename__ = "media" 20 | id: Mapped[MediaId] = mapped_column( 21 | primary_key=True, 22 | default=lambda: str(uuid4()), 23 | comment="Unique identifier for the media", 24 | ) 25 | media_type: Mapped[str] = mapped_column(comment="The type of media") 26 | 27 | __mapper_args__ = { # noqa: RUF012 28 | "polymorphic_identity": "media", 29 | "polymorphic_on": "media_type", 30 | } 31 | -------------------------------------------------------------------------------- /aana/storage/models/video.py: -------------------------------------------------------------------------------- 1 | from sqlalchemy import ForeignKey 2 | from sqlalchemy.orm import Mapped, mapped_column 3 | 4 | from aana.core.models.media import MediaId 5 | from aana.storage.models.media import MediaEntity 6 | 7 | 8 | class VideoEntity(MediaEntity): 9 | """Base ORM class for videos. 10 | 11 | Attributes: 12 | id (MediaId): Unique identifier for the video. 13 | path (str): Path to the video file. 14 | url (str): URL to the video file. 15 | title (str): Title of the video. 16 | description (str): Description of the video. 
17 | """ 18 | 19 | __tablename__ = "video" 20 | 21 | id: Mapped[MediaId] = mapped_column(ForeignKey("media.id"), primary_key=True) 22 | path: Mapped[str] = mapped_column(comment="Path", nullable=True) 23 | url: Mapped[str] = mapped_column(comment="URL", nullable=True) 24 | title: Mapped[str] = mapped_column(comment="Title", nullable=True) 25 | description: Mapped[str] = mapped_column(comment="Description", nullable=True) 26 | 27 | __mapper_args__ = { # noqa: RUF012 28 | "polymorphic_identity": "video", 29 | } 30 | -------------------------------------------------------------------------------- /aana/api/responses.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import orjson 4 | from fastapi.responses import JSONResponse 5 | 6 | from aana.utils.json import jsonify 7 | 8 | 9 | class AanaJSONResponse(JSONResponse): 10 | """Response class that uses orjson to serialize data. 11 | 12 | It has additional support for numpy arrays. 
13 | """ 14 | 15 | media_type = "application/json" 16 | option = None 17 | 18 | def __init__( 19 | self, 20 | content: Any, 21 | status_code: int = 200, 22 | option: int | None = orjson.OPT_SERIALIZE_NUMPY, 23 | **kwargs, 24 | ): 25 | """Initialize the response class with the orjson option.""" 26 | self.option = option 27 | super().__init__(content, status_code, **kwargs) 28 | 29 | def render(self, content: Any) -> bytes: 30 | """Override the render method to use orjson.dumps instead of json.dumps.""" 31 | return jsonify(content, option=self.option, as_bytes=True) 32 | 33 | 34 | class AanaConcatenatedJSONResponse(AanaJSONResponse): 35 | """Response class that uses orjson to serialize data to Concatenated JSON format.""" 36 | 37 | media_type = "application/json-seq" 38 | -------------------------------------------------------------------------------- /aana/storage/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: F401 2 | # We need to import all db models here and, other than in the class definitions 3 | # themselves, only import them from aana.models.db directly. The reason for 4 | # this is the way SQLAlchemy's declarative base works. You can use forward 5 | # references like `parent = reference("Parent", backreferences="child")`, but the 6 | # forward reference needs to have been resolved before the first constructor 7 | # is called so that SqlAlchemy "knows" about it. 
# ruff: noqa: F401
# We need to import all db models here and, other than in the class definitions
# themselves, only import them from aana.models.db directly. The reason for
# this is the way SQLAlchemy's declarative base works. You can use forward
# references like `parent = reference("Parent", backreferences="child")`, but the
# forward reference needs to have been resolved before the first constructor
# is called so that SqlAlchemy "knows" about it.
# See:
# https://docs.pylonsproject.org/projects/pyramid_cookbook/en/latest/database/sqlalchemy.html#importing-all-sqlalchemy-models
# (even if not using Pyramid)

from aana.storage.models.base import BaseEntity
from aana.storage.models.caption import CaptionEntity
from aana.storage.models.media import MediaEntity
from aana.storage.models.task import TaskEntity
from aana.storage.models.transcript import TranscriptEntity
from aana.storage.models.video import VideoEntity
from aana.storage.models.webhook import WebhookEntity

# Keep this list in sync with the imports above: every entity is imported for
# re-export (see the F401 suppression), but TaskEntity and WebhookEntity were
# missing from __all__, so `from aana.storage.models import *` omitted them.
__all__ = [
    "BaseEntity",
    "CaptionEntity",
    "MediaEntity",
    "TaskEntity",
    "TranscriptEntity",
    "VideoEntity",
    "WebhookEntity",
]
17 | 18 | For example, you can define a custom configuration field in a deployment configuration like this: 19 | 20 | ```python 21 | class HfPipelineConfig(BaseModel): 22 | model_id: str 23 | task: str | None = None 24 | model_kwargs: CustomConfig = {} 25 | pipeline_kwargs: CustomConfig = {} 26 | generation_kwargs: CustomConfig = {} 27 | ``` 28 | 29 | Then you can use the custom configuration field to pass a configuration to the deployment: 30 | 31 | ```python 32 | HfPipelineConfig( 33 | model_id="Salesforce/blip2-opt-2.7b", 34 | model_kwargs={ 35 | "quantization_config": BitsAndBytesConfig( 36 | load_in_8bit=False, load_in_4bit=True 37 | ), 38 | }, 39 | ) 40 | ``` 41 | """ 42 | -------------------------------------------------------------------------------- /aana/core/models/api_service.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from pydantic import BaseModel, Field 4 | from pydantic.json_schema import SkipJsonSchema 5 | 6 | 7 | class ApiKey(BaseModel): 8 | """Pydantic model for API key entity. 9 | 10 | Attributes: 11 | api_key (str): The API key. 12 | user_id (str): ID of the user who owns this API key. 13 | subscription_id (str): ID of the associated subscription. 14 | is_subscription_active (bool): Whether the subscription is active (credits are available). 15 | is_admin (bool): Whether the user is an admin. 16 | hmac_secret (str | None): The secret key for HMAC signature generation. 17 | """ 18 | 19 | api_key: str 20 | user_id: str 21 | subscription_id: str 22 | is_subscription_active: bool 23 | is_admin: bool 24 | hmac_secret: str | None 25 | 26 | 27 | ApiKeyType = SkipJsonSchema[Annotated[ApiKey, Field(default=None)]] 28 | """ 29 | Type with optional API key information. 30 | 31 | Can be None if API service is disabled. Otherwise, it will be an instance of `ApiKey`. 32 | 33 | Attributes: 34 | api_key (str): The API key. 35 | user_id (str): The user ID. 
Here are the environment variables that can be used to configure the Aana SDK:
19 | - NUM_WORKERS: The number of request workers. Default: `2`. 20 | - DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`. 21 | - HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, the HuggingFace Transformers will use the HF Transfer library to download the models from HuggingFace Hub to speed up the process. Recommended to always set to it `1`. Default: `0`. 22 | - HF_TOKEN: The HuggingFace API token to download the models from HuggingFace Hub, required for private or gated models. 23 | - TEST__SAVE_EXPECTED_OUTPUT: If set to `True`, the expected output will be saved when running the tests. Useful for creating new development tests. Default: `False`. 24 | 25 | 26 | See [reference documentation](./../reference/settings.md#aana.configs.Settings) for more advanced settings. 27 | -------------------------------------------------------------------------------- /aana/core/models/types.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | 5 | __all__ = ["Dtype"] 6 | 7 | 8 | class Dtype(str, Enum): 9 | """Data types. 10 | 11 | Possible values are "auto", "float32", "float16", "bfloat16" and "int8". 12 | 13 | Attributes: 14 | AUTO (str): auto 15 | FLOAT32 (str): float32 16 | FLOAT16 (str): float16 17 | BFLOAT16 (str): bfloat16 18 | INT8 (str): int8 19 | """ 20 | 21 | AUTO = "auto" 22 | FLOAT32 = "float32" 23 | FLOAT16 = "float16" 24 | BFLOAT16 = "bfloat16" 25 | INT8 = "int8" 26 | 27 | def to_torch(self) -> torch.dtype | str: 28 | """Convert the instance's dtype to a torch dtype. 
29 | 30 | Returns: 31 | Union[torch.dtype, str]: the torch dtype or "auto" 32 | 33 | Raises: 34 | ValueError: if the dtype is unknown 35 | """ 36 | match self.value: 37 | case self.AUTO: 38 | return "auto" 39 | case self.FLOAT32: 40 | return torch.float32 41 | case self.FLOAT16: 42 | return torch.float16 43 | case self.BFLOAT16: 44 | return torch.bfloat16 45 | case self.INT8: 46 | return torch.int8 47 | case _: 48 | raise ValueError(f"Unknown dtype: {self}") # noqa: TRY003 49 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests and Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - '**' # Runs on push to any branch 7 | pull_request: 8 | branches: 9 | - '**' # Runs on pull requests to any branch 10 | workflow_dispatch: # Allows for manual triggering 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | fail-fast: false 18 | matrix: 19 | python-version: ["3.10", "3.11", "3.12"] 20 | 21 | steps: 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Display Python version 27 | run: python -c "import sys; print(sys.version)" 28 | - name: Checkout code 29 | uses: actions/checkout@v3 30 | - name: Bootstrap uv 31 | uses: astral-sh/setup-uv@v6 32 | with: 33 | version: "latest" 34 | - name: Install dependencies 35 | run: | 36 | uv sync --group dev --group tests --extra all 37 | sudo apt-get update 38 | sudo apt-get install ffmpeg 39 | - name: Install postgres 40 | uses: ikalnytskyi/action-setup-postgres@v6 41 | - name: Run Ruff Check 42 | run: uv run ruff check 43 | - name: Test with pytest 44 | if: always() 45 | env: 46 | HF_TOKEN: ${{ secrets.HF_TOKEN }} 47 | run: uv run pytest -vv 48 | -------------------------------------------------------------------------------- 
class TimezoneAwareDateTime(TypeDecorator):
    """A custom SQLAlchemy type decorator for timezone-aware datetime objects.

    On the way into the database, naive datetimes are assumed to be UTC and
    tagged accordingly, while aware datetimes are normalized to UTC. On the
    way out, naive values (as returned e.g. by SQLite) are tagged as UTC.
    """

    impl = DateTime
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Normalize a datetime to UTC before it is written to the database."""
        if value is None:
            return None
        if value.tzinfo is None:
            # Naive datetimes are assumed to already represent UTC.
            return value.replace(tzinfo=datetime.timezone.utc)
        # Aware datetimes are converted to UTC.
        return value.astimezone(datetime.timezone.utc)

    def process_result_value(self, value, dialect):
        """Ensure datetimes loaded from the database are UTC-aware."""
        if value is not None and value.tzinfo is None:
            # Some backends (e.g. SQLite) return naive datetimes; treat as UTC.
            value = value.replace(tzinfo=datetime.timezone.utc)
        return value
PyPI repository' 10 | required: true 11 | default: 'testpypi' 12 | type: choice 13 | options: 14 | - pypi 15 | - testpypi 16 | 17 | jobs: 18 | publish: 19 | runs-on: ubuntu-latest 20 | 21 | steps: 22 | - name: Checkout code 23 | uses: actions/checkout@v3 24 | 25 | - name: Set up Python 3.10 26 | uses: actions/setup-python@v5 27 | with: 28 | python-version: "3.10" 29 | 30 | - name: Bootstrap uv 31 | uses: astral-sh/setup-uv@v6 32 | with: 33 | version: "latest" 34 | 35 | - name: Install dependencies 36 | run: uv sync 37 | 38 | - name: Build the package 39 | run: uv build 40 | 41 | - name: Publish to PyPI 42 | if: github.event_name == 'release' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_target == 'pypi') 43 | env: 44 | UV_PUBLISH_USERNAME: "__token__" 45 | UV_PUBLISH_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 46 | run: | 47 | uv publish 48 | 49 | - name: Publish to Test PyPI 50 | if: github.event_name == 'workflow_dispatch' && github.event.inputs.publish_target == 'testpypi' 51 | env: 52 | UV_PUBLISH_USERNAME: "__token__" 53 | UV_PUBLISH_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} 54 | run: | 55 | uv publish --publish-url https://test.pypi.org/legacy/ 56 | -------------------------------------------------------------------------------- /aana/alembic/versions/b9860676dd49_set_server_default_for_task_completed_.py: -------------------------------------------------------------------------------- 1 | """Set server default for task.completed_at and task.assigned_at to none and add num_retries. 2 | 3 | Revision ID: b9860676dd49 4 | Revises: 5ad873484aa3 5 | Create Date: 2024-08-22 07:54:55.921710 6 | 7 | """ 8 | from collections.abc import Sequence 9 | 10 | import sqlalchemy as sa 11 | from alembic import op 12 | 13 | # revision identifiers, used by Alembic. 
14 | revision: str = "b9860676dd49" 15 | down_revision: str | None = "5ad873484aa3" 16 | branch_labels: str | Sequence[str] | None = None 17 | depends_on: str | Sequence[str] | None = None 18 | 19 | 20 | def upgrade() -> None: 21 | """Upgrade database to this revision from previous.""" 22 | with op.batch_alter_table("tasks", schema=None) as batch_op: 23 | batch_op.alter_column( 24 | "completed_at", 25 | server_default=None, 26 | ) 27 | batch_op.alter_column( 28 | "assigned_at", 29 | server_default=None, 30 | ) 31 | batch_op.add_column( 32 | sa.Column( 33 | "num_retries", 34 | sa.Integer(), 35 | nullable=False, 36 | comment="Number of retries", 37 | server_default=sa.text("0"), 38 | ) 39 | ) 40 | 41 | # ### end Alembic commands ### 42 | 43 | 44 | def downgrade() -> None: 45 | """Downgrade database from this revision to previous.""" 46 | with op.batch_alter_table("tasks", schema=None) as batch_op: 47 | batch_op.drop_column("num_retries") 48 | 49 | # ### end Alembic commands ### 50 | -------------------------------------------------------------------------------- /aana/tests/files/videos/ATTRIBUTION.md: -------------------------------------------------------------------------------- 1 | # Third-Party Media Attribution 2 | 3 | This folder contains third-party media used for testing. No endorsement implied. 4 | 5 | _Last updated: 2025-08-22_ 6 | 7 | | File | Source | License | Author/Owner | Changes Made | 8 | |---|---|---|---|---| 9 | | `squirrel.mp4` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Downloaded for testing (may have been re-encoded; no substantive changes). 
| 10 | | `squirrel_no_audio.mp4` | Video by **Nathan J Hilton** on Pexels — https://www.pexels.com/video/a-squirrel-is-sitting-on-top-of-a-table-17977045/ | Pexels License — https://www.pexels.com/license/ | Nathan J Hilton | Audio track removed for testing; re-encoded. | 11 | | `physicsworks.webm` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Transcoded to WEBM for testing; may be trimmed for length. | 12 | | `physicsworks_audio.webm` | Wikimedia Commons — https://commons.wikimedia.org/wiki/File:Physicsworks.ogv (clip from MIT OCW “Work, Energy, and Universal Gravitation”, Lecture 11) | CC BY 3.0 — https://creativecommons.org/licenses/by/3.0/ | Walter Lewin | Extracted audio only and transcoded to WEBM/Opus for testing; trimmed for length. | 13 | 14 | **Additional note (Commons VRT):** The Commons file page indicates a VRT permission ticket (#2011051010013473) confirming publication under the stated terms. 15 | -------------------------------------------------------------------------------- /aana/integrations/haystack/remote_haystack_component.py: -------------------------------------------------------------------------------- 1 | 2 | from haystack import component 3 | 4 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 5 | from aana.utils.asyncio import run_async 6 | 7 | 8 | @component 9 | class RemoteHaystackComponent: 10 | """A component that connects to a remote Haystack component created by HaystackComponentDeployment. 11 | 12 | Attributes: 13 | deployment_name (str): The name of the deployment to use. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | deployment_name: str, 19 | ): 20 | """Initialize the component. 21 | 22 | Args: 23 | deployment_name (str): The name of the deployment to use. 
class WebhookEventType(str, Enum):
    """Enum of task lifecycle events a webhook can subscribe to."""

    TASK_COMPLETED = "task.completed"
    TASK_FAILED = "task.failed"
    TASK_STARTED = "task.started"


class WebhookEntity(BaseEntity, TimeStampEntity):
    """Table for webhook items.

    Attributes:
        id (uuid.UUID): Webhook ID.
        user_id (str | None): The user ID associated with the webhook.
        url (str): The URL to which the webhook will send requests.
        events (list[str]): Events the webhook is subscribed to; an empty
            list means the webhook is subscribed to all events.
    """

    __tablename__ = "webhooks"

    id: Mapped[uuid.UUID] = mapped_column(
        UUID, primary_key=True, default=uuid.uuid4, comment="Webhook ID"
    )
    user_id: Mapped[str | None] = mapped_column(
        nullable=True, index=True, comment="The user ID associated with the webhook"
    )
    url: Mapped[str] = mapped_column(
        nullable=False, comment="The URL to which the webhook will send requests"
    )
    events: Mapped[list[str]] = mapped_column(
        JSON().with_variant(JSONB, "postgresql"),
        nullable=False,
        comment="List of events the webhook is subscribed to. If the list is empty, the webhook is subscribed to all events.",
    )

    def __repr__(self) -> str:
        """String representation of the webhook.

        NOTE(review): the original f-string body was lost (it appears as an
        empty f"" in the source, which would make repr() useless). Rebuilt
        here from the column definitions above — confirm against version
        control that the field selection matches the original.
        """
        return (
            f"<WebhookEntity(id={self.id}, "
            f"user_id={self.user_id}, "
            f"url={self.url}, "
            f"events={self.events})>"
        )
### 41 | with op.batch_alter_table("tasks", schema=None) as batch_op: 42 | batch_op.drop_index(batch_op.f("ix_tasks_user_id")) 43 | batch_op.drop_index(batch_op.f("ix_tasks_status")) 44 | batch_op.drop_column("user_id") 45 | 46 | # ### end Alembic commands ### 47 | -------------------------------------------------------------------------------- /aana/storage/repository/video.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | from sqlalchemy.ext.asyncio import AsyncSession 4 | 5 | from aana.core.models.media import MediaId 6 | from aana.core.models.video import Video, VideoMetadata 7 | from aana.storage.models import VideoEntity 8 | from aana.storage.repository.media import MediaRepository 9 | 10 | V = TypeVar("V", bound=VideoEntity) 11 | 12 | 13 | class VideoRepository(MediaRepository[V]): 14 | """Repository for videos.""" 15 | 16 | def __init__(self, session: AsyncSession, model_class: type[V] = VideoEntity): 17 | """Constructor.""" 18 | super().__init__(session, model_class) 19 | 20 | async def save(self, video: Video) -> dict: 21 | """Saves a video to datastore. 22 | 23 | Args: 24 | video (Video): The video object. 25 | 26 | Returns: 27 | dict: The dictionary with media ID. 28 | """ 29 | video_entity = VideoEntity( 30 | id=video.media_id, 31 | path=str(video.path), 32 | url=video.url, 33 | title=video.title, 34 | description=video.description, 35 | ) 36 | 37 | await self.create(video_entity) 38 | return { 39 | "media_id": video_entity.id, 40 | } 41 | 42 | async def get_metadata(self, media_id: MediaId) -> VideoMetadata: 43 | """Get the metadata of a video. 44 | 45 | Args: 46 | media_id (MediaId): The media ID. 47 | 48 | Returns: 49 | VideoMetadata: The video metadata. 
import asyncio
import threading
from collections.abc import Coroutine
from typing import Any, TypeVar

__all__ = ["run_async"]

T = TypeVar("T")


def run_async(coro: Coroutine[Any, Any, T]) -> T:
    """Run a coroutine in a thread if the current thread is running an event loop.

    Otherwise, run the coroutine in the current asyncio loop.

    Useful when you want to run an async function in a non-async context.

    From: https://stackoverflow.com/a/75094151

    Args:
        coro (Coroutine): The coroutine to run.

    Returns:
        T: The result of the coroutine.
    """

    class _Runner(threading.Thread):
        """Thread that runs a single coroutine to completion on a fresh loop."""

        def __init__(self, coro: Coroutine[Any, Any, T]):
            """Store the coroutine to execute."""
            super().__init__()
            self.coro = coro
            self.result: T | None = None
            self.exception: Exception | None = None

        def run(self):
            """Execute the coroutine, capturing its result or exception."""
            try:
                self.result = asyncio.run(self.coro)
            except Exception as exc:
                self.exception = exc

    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is None or not running_loop.is_running():
        # No loop is running in this thread, so it is safe to run directly.
        return asyncio.run(coro)

    # asyncio.run would raise inside a running loop; use a helper thread.
    runner = _Runner(coro)
    runner.start()
    runner.join()
    if runner.exception is not None:
        raise runner.exception
    return runner.result
| 3 | import pytest 4 | from pydantic import BaseModel, ValidationError 5 | 6 | from aana.core.models.media import MediaId 7 | 8 | 9 | class TestModel(BaseModel): 10 | """Test model for media id.""" 11 | 12 | media_id: MediaId 13 | 14 | 15 | def test_media_id_creation(): 16 | """Test that a media id can be created.""" 17 | media_id = TestModel(media_id="foo").media_id 18 | assert media_id == "foo" 19 | 20 | # Validation only happens when the model is created 21 | # because MediaId is just an annotated string 22 | with pytest.raises(ValueError): 23 | TestModel().media_id # noqa: B018 24 | 25 | with pytest.raises(ValidationError): 26 | TestModel(media_id="").media_id # noqa: B018 27 | 28 | # MediaId is a string with a maximum length of 36 29 | with pytest.raises(ValidationError): 30 | TestModel(media_id="a" * 37).media_id # noqa: B018 31 | 32 | 33 | def test_valid_media_ids(): 34 | """Test that valid media ids are accepted.""" 35 | valid_ids = ["abc123", "abc-123", "abc_123", "A1B2_C3", "123456", "a_b-c"] 36 | for media_id in valid_ids: 37 | model = TestModel(media_id=media_id) 38 | assert model.media_id == media_id 39 | 40 | 41 | def test_invalid_media_ids(): 42 | """Test that invalid media ids are rejected.""" 43 | invalid_ids = [ 44 | "abc 123", # contains a space 45 | "abc@123", # contains an invalid character (@) 46 | "abc#123", # contains an invalid character (#) 47 | "abc.123", # contains an invalid character (.) 
48 | ] 49 | for media_id in invalid_ids: 50 | with pytest.raises(ValidationError): 51 | TestModel(media_id=media_id) 52 | -------------------------------------------------------------------------------- /aana/storage/repository/transcript.py: -------------------------------------------------------------------------------- 1 | from typing import TypeVar 2 | 3 | from sqlalchemy.ext.asyncio import AsyncSession 4 | 5 | from aana.core.models.asr import ( 6 | AsrSegments, 7 | AsrTranscription, 8 | AsrTranscriptionInfo, 9 | ) 10 | from aana.storage.models.transcript import TranscriptEntity 11 | from aana.storage.repository.base import BaseRepository 12 | 13 | T = TypeVar("T", bound=TranscriptEntity) 14 | 15 | 16 | class TranscriptRepository(BaseRepository[T]): 17 | """Repository for Transcripts.""" 18 | 19 | def __init__(self, session: AsyncSession, model_class: type[T] = TranscriptEntity): 20 | """Constructor.""" 21 | super().__init__(session, model_class) 22 | 23 | async def save( 24 | self, 25 | model_name: str, 26 | transcription_info: AsrTranscriptionInfo, 27 | transcription: AsrTranscription, 28 | segments: AsrSegments, 29 | ) -> TranscriptEntity: 30 | """Save transcripts. 31 | 32 | Args: 33 | model_name (str): The name of the model used to generate the transcript. 34 | transcription_info (AsrTranscriptionInfo): The ASR transcription info. 35 | transcription (AsrTranscription): The ASR transcription. 36 | segments (AsrSegments): The ASR segments. 37 | 38 | Returns: 39 | TranscriptEntity: The transcript entity. 
class BaseException(Exception):  # noqa: A001
    """Base class for SDK exceptions.

    Keyword arguments passed to the constructor are stored in ``extra``;
    they appear in the string representation and in the data returned to
    the client.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the exception, keeping keyword arguments as extra data."""
        super().__init__()
        self.extra = kwargs

    def __str__(self) -> str:
        """Return a string representation of the exception.

        Format: ``ClassName(extra_key1=extra_value1, extra_key2=extra_value2, ...)``.
        """
        details = ", ".join(f"{key}={value}" for key, value in self.extra.items())
        return f"{type(self).__name__}({details})"

    def get_data(self) -> dict[str, Any]:
        """Get the data to be returned to the client.

        Returns:
            dict[str, Any]: a copy of the extra data (safe for callers to mutate).
        """
        return dict(self.extra)

    def add_extra(self, data: dict[str, Any]) -> None:
        """Merge extra data into the exception.

        This data will be returned to the user as part of the response.
        Typical usage: in an exception handler, attach context and re-raise:

        ```
        try:
            ...
        except BaseException as e:
            e.add_extra({'extra_key': 'extra_value'})
            raise e
        ```

        Args:
            data (dict[str, Any]): dictionary containing the extra data
        """
        self.extra.update(data)
42 | 43 | To deploy the application using the Serve config files, use [`serve deploy`](https://docs.ray.io/en/latest/serve/advanced-guides/deploy-vm.html#serve-in-production-deploying) command provided by Ray Serve. For example: 44 | 45 | ```bash 46 | serve deploy config.yaml 47 | ``` 48 | -------------------------------------------------------------------------------- /docs/reference/index.md: -------------------------------------------------------------------------------- 1 | # Reference Documentation (Code API) 2 | 3 | This section contains the reference documentation for the public API of the project. Quick links to the most important classes and functions are provided below. 4 | 5 | ## SDK 6 | 7 | [`aana.AanaSDK`](./sdk.md#aana.AanaSDK) - The main class for interacting with the Aana SDK. Use it to register endpoints and deployments and to start the server. 8 | 9 | ## Endpoint 10 | 11 | [`aana.api.Endpoint`](./endpoint.md#aana.api.Endpoint) - The base class for defining endpoints in the Aana SDK. 12 | 13 | ## Deployments 14 | 15 | [Deployments](./deployments.md) contains information about how to deploy models with a number of predefined deployments for such models as Whisper, LLMs, Hugging Face models, and more. 16 | 17 | ## Models 18 | 19 | - [Media Models](./models/media.md) - Models for working with media types like audio, video, and images. 20 | - [Automatic Speech Recognition (ASR) Models](./models/asr.md) - Models for working with automatic speech recognition (ASR) models. 21 | - [Caption Models](./models/captions.md) - Models for working with captions. 22 | - [Chat Models](./models/chat.md) - Models for working with chat models. 23 | - [Image Chat Models](./models/image_chat.md) - Models for working with visual-text content for visual-language models. 24 | - [Custom Config](./models/custom_config.md) - Custom Config model can be used to pass arbitrary configuration to the deployment. 
25 | - [Sampling Models](./models/sampling.md) - Contains Sampling Parameters model which can be used to pass sampling parameters to the LLM models. 26 | - [Time Models](./models/time.md) - Contains time models like TimeInterval. 27 | - [Types Models](./models/types.md) - Contains types models like Dtype. 28 | - [Video Models](./models/video.md) - Models for working with video files. 29 | - [Whisper Models](./models/whisper.md) - Models for working with whispers. 30 | 31 | -------------------------------------------------------------------------------- /aana/tests/units/test_rate_limiter.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | import pytest 4 | 5 | from aana.api.event_handlers.event_manager import EventManager 6 | from aana.api.event_handlers.rate_limit_handler import RateLimitHandler 7 | from aana.exceptions.runtime import TooManyRequestsException 8 | 9 | 10 | def test_rate_limiter_single(): 11 | """Tests that the rate limiter raises if the rate limit is exceeded.""" 12 | event_manager = EventManager() 13 | rate_limiter = RateLimitHandler(1, 1.0) 14 | 15 | event_manager.register_handler_for_events(rate_limiter, ["foo"]) 16 | event_manager.handle("foo") 17 | with pytest.raises(TooManyRequestsException): 18 | event_manager.handle("foo") 19 | 20 | 21 | def test_rate_limiter_multiple(): 22 | """Tests that the rate limiter raises if the rate limit is exceeded.""" 23 | event_manager = EventManager() 24 | rate_limiter = RateLimitHandler(1, 1.0) 25 | 26 | event_manager.register_handler_for_events(rate_limiter, ["foo", "bar"]) 27 | event_manager.handle("foo") 28 | with pytest.raises(TooManyRequestsException): 29 | event_manager.handle("bar") 30 | 31 | 32 | def test_rate_limiter_noraise(): 33 | """Tests that the rate limiter does not raise when the rate limit interval is negligibly small.""" 34 | event_manager = EventManager() 35 | # Smallest possible value such that 1+x>1, should never run into rate limit 36 | rate_limiter =
RateLimitHandler(1, sys.float_info.epsilon) 37 | event_manager.register_handler_for_events(rate_limiter, ["foo"]) 38 | for _ in range(10): 39 | event_manager.handle("foo") 40 | 41 | 42 | def test_event_manager_discriminates(): 43 | """Tests that the event manager only dispatches events to handlers registered for them.""" 44 | event_manager = EventManager() 45 | rate_limiter = RateLimitHandler(1, sys.float_info.max) 46 | event_manager.register_handler_for_events(rate_limiter, ["bar"]) 47 | # Should not raise 48 | event_manager.handle("foo") 49 | -------------------------------------------------------------------------------- /aana/deployments/haystack_component_deployment.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from pydantic import BaseModel 4 | from ray import serve 5 | 6 | from aana.deployments.base_deployment import BaseDeployment, exception_handler 7 | from aana.utils.core import import_from_path 8 | 9 | 10 | class HaystackComponentDeploymentConfig(BaseModel): 11 | """Configuration for the HaystackComponentDeployment. 12 | 13 | Attributes: 14 | component (str): The path to the Haystack component to deploy. 15 | params (dict): The parameters to pass to the component on initialization (model etc). 16 | """ 17 | 18 | component: str 19 | params: dict[str, Any] 20 | 21 | 22 | @serve.deployment 23 | class HaystackComponentDeployment(BaseDeployment): 24 | """Deployment to deploy a Haystack component.""" 25 | 26 | async def apply_config(self, config: dict[str, Any]): 27 | """Apply the configuration. 28 | 29 | The method is called when the deployment is created or updated. 30 | 31 | It creates the Haystack component and warms it up. 32 | 33 | The configuration should conform to the HaystackComponentDeploymentConfig schema.
34 | """ 35 | config_obj = HaystackComponentDeploymentConfig(**config) 36 | 37 | self.params = config_obj.params 38 | self.component_path = config_obj.component 39 | self.component = import_from_path(config_obj.component)(**config_obj.params) 40 | 41 | self.component.warm_up() 42 | 43 | @exception_handler 44 | async def run(self, **data: dict[str, Any]) -> dict[str, Any]: 45 | """Run the model on the input data.""" 46 | return self.component.run(**data) 47 | 48 | async def get_sockets(self): 49 | """Get the input and output sockets of the component.""" 50 | return { 51 | "output": self.component.__haystack_output__._sockets_dict, 52 | "input": self.component.__haystack_input__._sockets_dict, 53 | } 54 | -------------------------------------------------------------------------------- /docs/pages/docker.md: -------------------------------------------------------------------------------- 1 | # Run with Docker 2 | 3 | We provide a docker-compose configuration to run the application in a Docker container in [Aana App Template](https://github.com/mobiusml/aana_app_template/blob/main?tab=readme-ov-file#running-with-docker). 4 | 5 | Requirements: 6 | 7 | - Docker Engine >= 26.1.0 8 | - Docker Compose >= 1.29.2 9 | - NVIDIA Driver >= 525.60.13 10 | 11 | You can edit the [Dockerfile](https://github.com/mobiusml/aana_app_template/blob/main/Dockerfile) to assemble the image as you desire and 12 | and [docker-compose file](https://github.com/mobiusml/aana_app_template/blob/main/docker-compose.yaml) for container instances and their environment variables. 13 | 14 | To run the application, simply run the following command: 15 | 16 | ```bash 17 | docker-compose up 18 | ``` 19 | 20 | The application will be accessible at `http://localhost:8000` on the host server. 21 | 22 | 23 | !!! warning 24 | 25 | If your applications requires GPU to run, you need to specify which GPU to use. 
26 | 27 | The applications will detect the available GPU automatically but you need to make sure that `CUDA_VISIBLE_DEVICES` is set correctly. 28 | 29 | Sometimes `CUDA_VISIBLE_DEVICES` is set to an empty string and the application will not be able to detect the GPU. Use `unset CUDA_VISIBLE_DEVICES` to unset the variable. 30 | 31 | You can also set the `CUDA_VISIBLE_DEVICES` environment variable to the GPU index you want to use: `CUDA_VISIBLE_DEVICES=0 docker-compose up`. 32 | 33 | 34 | !!! Tip 35 | 36 | Some models use Flash Attention for better performance. You can set the build argument `INSTALL_FLASH_ATTENTION` to `true` to install Flash Attention. 37 | 38 | ```bash 39 | INSTALL_FLASH_ATTENTION=true docker-compose build 40 | ``` 41 | 42 | After building the image, you can use `docker-compose up` command to run the application. 43 | 44 | You can also set the `INSTALL_FLASH_ATTENTION` environment variable to `true` in the `docker-compose.yaml` file. 45 | 46 | -------------------------------------------------------------------------------- /aana/utils/lazy_import.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/deepset-ai/haystack 2 | # 3 | # SPDX-FileCopyrightText: 2022-present deepset GmbH 4 | # 5 | # SPDX-License-Identifier: Apache-2.0 6 | 7 | from types import TracebackType 8 | 9 | from lazy_imports.try_import import _DeferredImportExceptionContextManager 10 | 11 | DEFAULT_IMPORT_ERROR_MSG = "Try 'pip install {}'" 12 | 13 | 14 | class LazyImport(_DeferredImportExceptionContextManager): 15 | """Wrapper on top of lazy_import's _DeferredImportExceptionContextManager. 16 | 17 | It adds the possibility to customize the error messages. 
18 | """ 19 | 20 | def __init__(self, message: str = DEFAULT_IMPORT_ERROR_MSG) -> None: 21 | """Initialize the context manager.""" 22 | super().__init__() 23 | self.import_error_msg = message 24 | 25 | def __exit__( 26 | self, 27 | exc_type: type[Exception] | None = None, 28 | exc_value: Exception | None = None, 29 | traceback: TracebackType | None = None, 30 | ) -> bool | None: 31 | """Exit the context manager. 32 | 33 | Args: 34 | exc_type: 35 | Raised exception type. :obj:`None` if nothing is raised. 36 | exc_value: 37 | Raised exception object. :obj:`None` if nothing is raised. 38 | traceback: 39 | Associated traceback. :obj:`None` if nothing is raised. 40 | 41 | Returns: 42 | :obj:`None` if nothing is deferred, otherwise :obj:`True`. 43 | :obj:`True` will suppress any exceptions avoiding them from propagating. 44 | 45 | """ 46 | if isinstance(exc_value, ImportError): 47 | message = ( 48 | f"Failed to import '{exc_value.name}'. {self.import_error_msg.format(exc_value.name)}. " 49 | f"Original error: {exc_value}" 50 | ) 51 | self._deferred = (exc_value, message) 52 | return True 53 | return None 54 | -------------------------------------------------------------------------------- /aana/exceptions/db.py: -------------------------------------------------------------------------------- 1 | from aana.core.models.media import MediaId 2 | from aana.exceptions.core import BaseException 3 | 4 | __all__ = [ 5 | "MediaIdAlreadyExistsException", 6 | "NotFoundException", 7 | ] 8 | 9 | 10 | class DatabaseException(BaseException): 11 | def __init__(self, message: str): 12 | """Constructor. 13 | 14 | Args: 15 | message: (str): the error message. 
16 | """ 17 | super().__init__(message=message) 18 | self.message = message 19 | 20 | def __reduce__(self): 21 | """Used for pickling.""" 22 | return (self.__class__, (self.message,)) 23 | 24 | 25 | class NotFoundException(BaseException): 26 | """Raised when an item searched by id is not found.""" 27 | 28 | def __init__(self, table_name: str, id: int | MediaId): # noqa: A002 29 | """Constructor. 30 | 31 | Args: 32 | table_name (str): the name of the table being queried. 33 | id (int | MediaId): the id of the item to be retrieved. 34 | """ 35 | super().__init__(table=table_name, id=id) 36 | self.table_name = table_name 37 | self.id = id 38 | self.http_status_code = 404 39 | 40 | def __reduce__(self): 41 | """Used for pickling.""" 42 | return (self.__class__, (self.table_name, self.id)) 43 | 44 | 45 | class MediaIdAlreadyExistsException(BaseException): 46 | """Raised when a media_id already exists.""" 47 | 48 | def __init__(self, table_name: str, media_id: MediaId): 49 | """Constructor. 50 | 51 | Args: 52 | table_name (str): the name of the table being queried. 53 | media_id (MediaId): the id of the item to be retrieved. 54 | """ 55 | super().__init__(table=table_name, id=media_id) 56 | self.table_name = table_name 57 | self.media_id = media_id 58 | 59 | def __reduce__(self): 60 | """Used for pickling.""" 61 | return (self.__class__, (self.table_name, self.media_id)) 62 | -------------------------------------------------------------------------------- /aana/core/libraries/audio.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | 5 | 6 | class AbstractAudioLibrary: 7 | """Abstract class for audio libraries.""" 8 | 9 | @classmethod 10 | def read_file(cls, path: Path) -> np.ndarray: 11 | """Read an audio file from path and return as numpy audio array. 12 | 13 | Args: 14 | path (Path): The path of the file to read. 15 | 16 | Returns: 17 | np.ndarray: The file as a numpy array. 
18 | """ 19 | raise NotImplementedError 20 | 21 | @classmethod 22 | def read_from_bytes(cls, content: bytes) -> np.ndarray: 23 | """Read bytes using the image library. 24 | 25 | Args: 26 | content (bytes): The content of the file to read. 27 | 28 | Returns: 29 | np.ndarray: The file as a numpy array. 30 | """ 31 | raise NotImplementedError 32 | 33 | @classmethod 34 | def write_file(cls, path: Path, audio: np.ndarray): 35 | """Write a file using the audio library. 36 | 37 | Args: 38 | path (Path): The path of the file to write. 39 | audio (np.ndarray): The audio to write. 40 | """ 41 | raise NotImplementedError 42 | 43 | @classmethod 44 | def write_to_bytes(cls, audio: np.ndarray) -> bytes: 45 | """Write bytes using the audio library. 46 | 47 | Args: 48 | audio (np.ndarray): The audio to write. 49 | 50 | Returns: 51 | bytes: The audio as bytes. 52 | """ 53 | raise NotImplementedError 54 | 55 | @classmethod 56 | def write_audio_bytes(cls, path: Path, audio: bytes, sample_rate: int = 16000): 57 | """Write a file to wav from the normalized audio bytes. 58 | 59 | Args: 60 | path (Path): The path of the file to write. 61 | audio (bytes): The audio to in 16-bit PCM byte write. 62 | sample_rate (int): The sample rate of the audio. 
63 | """ 64 | raise NotImplementedError 65 | -------------------------------------------------------------------------------- /aana/tests/deployments/test_vad_deployment.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | from importlib import resources 3 | from pathlib import Path 4 | 5 | import pytest 6 | 7 | from aana.core.models.audio import Audio 8 | from aana.core.models.base import pydantic_to_dict 9 | from aana.core.models.vad import VadParams 10 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 11 | from aana.deployments.vad_deployment import VadConfig, VadDeployment 12 | from aana.tests.utils import verify_deployment_results 13 | 14 | deployments = [ 15 | ( 16 | "vad_deployment", 17 | VadDeployment.options( 18 | num_replicas=1, 19 | max_ongoing_requests=1000, 20 | ray_actor_options={"num_gpus": 0}, 21 | user_config=VadConfig( 22 | model_id="pyannote/segmentation", 23 | onset=0.5, 24 | sample_rate=16000, 25 | ).model_dump(mode="json"), 26 | ), 27 | ) 28 | ] 29 | 30 | 31 | @pytest.mark.parametrize("setup_deployment", deployments, indirect=True) 32 | class TestVadDeployment: 33 | """Test VAD deployment.""" 34 | 35 | @pytest.mark.asyncio 36 | @pytest.mark.parametrize("audio_file", ["physicsworks.wav", "squirrel.wav"]) 37 | async def test_vad(self, setup_deployment, audio_file): 38 | """Test VAD.""" 39 | deployment_name, handle_name, _ = setup_deployment 40 | 41 | handle = await AanaDeploymentHandle.create(handle_name) 42 | 43 | audio_file_name = Path(audio_file).stem 44 | expected_output_path = ( 45 | resources.files("aana.tests.files.expected") 46 | / "vad" 47 | / f"{audio_file_name}.json" 48 | ) 49 | 50 | path = resources.files("aana.tests.files.audios") / audio_file 51 | assert path.exists(), f"Audio not found: {path}" 52 | 53 | audio = Audio(path=path) 54 | 55 | output = await handle.asr_preprocess_vad(audio=audio, params=VadParams()) 56 | output = 
pydantic_to_dict(output) 57 | verify_deployment_results(expected_output_path, output) 58 | -------------------------------------------------------------------------------- /aana/alembic/versions/acb40dabc2c0_added_webhooks.py: -------------------------------------------------------------------------------- 1 | """Added webhooks. 2 | 3 | Revision ID: acb40dabc2c0 4 | Revises: d40eba8ebc4c 5 | Create Date: 2025-01-30 14:32:16.596842 6 | 7 | """ 8 | from collections.abc import Sequence 9 | 10 | import sqlalchemy as sa 11 | from alembic import op 12 | from sqlalchemy.dialects import postgresql 13 | 14 | # revision identifiers, used by Alembic. 15 | revision: str = "acb40dabc2c0" 16 | down_revision: str | None = "d40eba8ebc4c" 17 | branch_labels: str | Sequence[str] | None = None 18 | depends_on: str | Sequence[str] | None = None 19 | 20 | 21 | def upgrade() -> None: 22 | """Upgrade database to this revision from previous.""" 23 | # fmt: off 24 | op.create_table('webhooks', 25 | sa.Column('id', sa.UUID(), nullable=False, comment='Webhook ID'), 26 | sa.Column('user_id', sa.String(), nullable=True, comment='The user ID associated with the webhook'), 27 | sa.Column('url', sa.String(), nullable=False, comment='The URL to which the webhook will send requests'), 28 | sa.Column('events', sa.JSON().with_variant(postgresql.JSONB(astext_type=sa.Text()), 'postgresql'), nullable=False, comment='List of events the webhook is subscribed to. 
If None, the webhook is subscribed to all events.'), 29 | sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is inserted'), 30 | sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('(CURRENT_TIMESTAMP)'), nullable=False, comment='Timestamp when row is updated'), 31 | sa.PrimaryKeyConstraint('id', name=op.f('pk_webhooks')) 32 | ) 33 | with op.batch_alter_table('webhooks', schema=None) as batch_op: 34 | batch_op.create_index(batch_op.f('ix_webhooks_user_id'), ['user_id'], unique=False) 35 | # fmt: on 36 | 37 | 38 | def downgrade() -> None: 39 | """Downgrade database from this revision to previous.""" 40 | with op.batch_alter_table("webhooks", schema=None) as batch_op: 41 | batch_op.drop_index(batch_op.f("ix_webhooks_user_id")) 42 | 43 | op.drop_table("webhooks") 44 | -------------------------------------------------------------------------------- /aana/tests/units/test_deployment_retry.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | import pytest 4 | from ray import serve 5 | 6 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 7 | from aana.deployments.base_deployment import BaseDeployment, exception_handler 8 | 9 | 10 | @serve.deployment(health_check_period_s=1, health_check_timeout_s=30) 11 | class Lowercase(BaseDeployment): 12 | """Ray deployment that returns the lowercase version of a text.""" 13 | 14 | def __init__(self): 15 | """Initialize the deployment.""" 16 | super().__init__() 17 | self.num_requests = 0 18 | 19 | @exception_handler 20 | async def lower(self, text: str) -> dict: 21 | """Lowercase the text. 
22 | 23 | Args: 24 | text (str): The text to lowercase 25 | 26 | Returns: 27 | dict: The lowercase text 28 | """ 29 | # Only every 3rd request should be successful 30 | self.num_requests += 1 31 | if self.num_requests % 3 != 0: 32 | raise Exception("Random exception") # noqa: TRY002, TRY003 33 | 34 | return {"text": text.lower()} 35 | 36 | 37 | deployments = [ 38 | { 39 | "name": "lowercase_deployment", 40 | "instance": Lowercase, 41 | } 42 | ] 43 | 44 | 45 | @pytest.mark.asyncio 46 | async def test_deployment_retry(create_app): 47 | """Test the Ray Serve app.""" 48 | create_app(deployments, []) 49 | 50 | text = "Hello, World!" 51 | 52 | # Get deployment handle without retries 53 | handle = await AanaDeploymentHandle.create( 54 | "lowercase_deployment", retry_exceptions=False 55 | ) 56 | 57 | # test the lowercase deployment fails 58 | with pytest.raises(Exception): # noqa: B017 59 | await handle.lower(text=text) 60 | 61 | # Get deployment handle with retries 62 | handle = await AanaDeploymentHandle.create( 63 | "lowercase_deployment", retry_exceptions=True 64 | ) 65 | 66 | # test the lowercase deployment works 67 | response = await handle.lower(text=text) 68 | assert response == {"text": text.lower()} 69 | -------------------------------------------------------------------------------- /aana/utils/json.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import orjson 5 | from pydantic import BaseModel 6 | from sqlalchemy import Engine 7 | 8 | __all__ = ["json_serializer_default", "jsonify"] 9 | 10 | 11 | def json_serializer_default(obj: object) -> object: 12 | """Default function for json serializer to handle custom objects. 13 | 14 | If json serializer does not know how to serialize an object, it calls the default function. 
15 | 16 | For example, if we see that the object is a pydantic model, 17 | we call the dict method to get the dictionary representation of the model 18 | that json serializer can deal with. 19 | 20 | If the object is not supported, we raise a TypeError. 21 | 22 | Args: 23 | obj (object): The object to serialize. 24 | 25 | Returns: 26 | object: The serializable object. 27 | 28 | Raises: 29 | TypeError: If the object is not a pydantic model, Path, or Media object. 30 | """ 31 | if isinstance(obj, Engine): 32 | return None 33 | if isinstance(obj, BaseModel): 34 | return obj.model_dump() 35 | if isinstance(obj, Path): 36 | return str(obj) 37 | if isinstance(obj, type): 38 | return str(type) 39 | if isinstance(obj, bytes): 40 | return obj.decode() 41 | 42 | from aana.core.models.media import Media 43 | 44 | if isinstance(obj, Media): 45 | return str(obj) 46 | 47 | raise TypeError(type(obj)) 48 | 49 | 50 | def jsonify( 51 | data: Any, 52 | option: int | None = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SORT_KEYS, 53 | as_bytes: bool = False, 54 | ) -> str | bytes: 55 | """Serialize content using orjson. 56 | 57 | Args: 58 | data (Any): The content to serialize. 59 | option (int | None): The option for orjson.dumps. 60 | as_bytes (bool): Return output as bytes instead of string 61 | 62 | Returns: 63 | bytes | str: The serialized data as desired format. 
64 | """ 65 | output = orjson.dumps(data, option=option, default=json_serializer_default) 66 | return output if as_bytes else output.decode() 67 | -------------------------------------------------------------------------------- /aana/tests/units/test_app_deploy.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import pytest 4 | from ray import serve 5 | 6 | from aana.deployments.base_deployment import BaseDeployment 7 | from aana.exceptions.runtime import FailedDeployment, InsufficientResources 8 | 9 | 10 | @serve.deployment 11 | class DummyFailingDeployment(BaseDeployment): 12 | """Simple deployment that fails on initialization.""" 13 | 14 | async def apply_config(self, config: dict[str, Any]): 15 | """Apply the configuration to the deployment and initialize it.""" 16 | raise Exception("Dummy exception") # noqa: TRY002, TRY003 17 | 18 | 19 | @serve.deployment 20 | class Lowercase(BaseDeployment): 21 | """Simple deployment that lowercases the text.""" 22 | 23 | async def apply_config(self, config: dict[str, Any]): 24 | """Apply the configuration to the deployment and initialize it.""" 25 | pass 26 | 27 | async def lower(self, text: str) -> dict: 28 | """Lowercase the text. 
29 | 30 | Args: 31 | text (str): The text to lowercase 32 | 33 | Returns: 34 | dict: The lowercase text 35 | """ 36 | return {"text": [t.lower() for t in text]} 37 | 38 | 39 | def test_failed_deployment(create_app): 40 | """Test that a failed deployment raises a FailedDeployment exception.""" 41 | deployments = [ 42 | { 43 | "name": "deployment", 44 | "instance": DummyFailingDeployment.options(num_replicas=1, user_config={}), 45 | } 46 | ] 47 | with pytest.raises(FailedDeployment): 48 | create_app(deployments, []) 49 | 50 | 51 | def test_insufficient_resources(create_app): 52 | """Test that deployment fails when there are insufficient resources to deploy.""" 53 | deployments = [ 54 | { 55 | "name": "deployment", 56 | "instance": Lowercase.options( 57 | num_replicas=1, 58 | ray_actor_options={"num_gpus": 100}, # requires 100 GPUs 59 | user_config={}, 60 | ), 61 | } 62 | ] 63 | with pytest.raises(InsufficientResources): 64 | create_app(deployments, []) 65 | -------------------------------------------------------------------------------- /aana/tests/files/expected/sd/sd_sample.json: -------------------------------------------------------------------------------- 1 | { 2 | "segments": [ 3 | { 4 | "time_interval": { 5 | "start": 6.730343750000001, 6 | "end": 7.16909375 7 | }, 8 | "speaker": "SPEAKER_01" 9 | }, 10 | { 11 | "time_interval": { 12 | "start": 7.16909375, 13 | "end": 7.185968750000001 14 | }, 15 | "speaker": "SPEAKER_02" 16 | }, 17 | { 18 | "time_interval": { 19 | "start": 7.59096875, 20 | "end": 8.316593750000003 21 | }, 22 | "speaker": "SPEAKER_01" 23 | }, 24 | { 25 | "time_interval": { 26 | "start": 8.316593750000003, 27 | "end": 9.919718750000001 28 | }, 29 | "speaker": "SPEAKER_02" 30 | }, 31 | { 32 | "time_interval": { 33 | "start": 9.919718750000001, 34 | "end": 10.93221875 35 | }, 36 | "speaker": "SPEAKER_01" 37 | }, 38 | { 39 | "time_interval": { 40 | "start": 10.93221875, 41 | "end": 14.745968750000003 42 | }, 43 | "speaker": "SPEAKER_02" 44 | }, 45 
| { 46 | "time_interval": { 47 | "start": 14.745968750000003, 48 | "end": 17.88471875 49 | }, 50 | "speaker": "SPEAKER_00" 51 | }, 52 | { 53 | "time_interval": { 54 | "start": 18.01971875, 55 | "end": 21.512843750000002 56 | }, 57 | "speaker": "SPEAKER_02" 58 | }, 59 | { 60 | "time_interval": { 61 | "start": 21.512843750000002, 62 | "end": 28.49909375 63 | }, 64 | "speaker": "SPEAKER_00" 65 | }, 66 | { 67 | "time_interval": { 68 | "start": 28.49909375, 69 | "end": 29.96721875 70 | }, 71 | "speaker": "SPEAKER_02" 72 | } 73 | ] 74 | } -------------------------------------------------------------------------------- /aana/tests/units/test_merge_options.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | import pytest 4 | from pydantic import BaseModel 5 | 6 | from aana.core.models.base import merged_options 7 | 8 | 9 | class MyOptions(BaseModel): 10 | """Test option class.""" 11 | 12 | field1: str 13 | field2: int | None = None 14 | field3: bool 15 | field4: str = "default" 16 | 17 | 18 | def test_merged_options_same_type(): 19 | """Test merged_options with options of the same type as default_options.""" 20 | default = MyOptions(field1="default1", field2=2, field3=True) 21 | to_merge = MyOptions(field1="merge1", field2=None, field3=False) 22 | merged = merged_options(default, to_merge) 23 | 24 | assert merged.field1 == "merge1" 25 | assert ( 26 | merged.field2 == 2 27 | ) # Should retain value from default_options as it's None in options 28 | assert merged.field3 == False 29 | 30 | 31 | def test_merged_options_none(): 32 | """Test merged_options with options=None.""" 33 | default = MyOptions(field1="default1", field2=2, field3=True) 34 | merged = merged_options(default, None) 35 | 36 | assert merged.model_dump() == default.model_dump() 37 | 38 | 39 | def test_merged_options_type_mismatch(): 40 | """Test merged_options with options of a different type from default_options.""" 41 | 42 | class 
AnotherOptions(BaseModel): 43 | another_field: str 44 | 45 | default = MyOptions(field1="default1", field2=2, field3=True) 46 | to_merge = AnotherOptions(another_field="test") 47 | 48 | with pytest.raises(ValueError): 49 | merged_options(default, to_merge) 50 | 51 | 52 | def test_merged_options_unset(): 53 | """Test merged_options with unset fields.""" 54 | default = MyOptions(field1="default1", field2=2, field3=True, field4="new_default") 55 | to_merge = MyOptions(field1="merge1", field3=False) # field4 is not set 56 | merged = merged_options(default, to_merge) 57 | 58 | assert merged.field1 == "merge1" 59 | assert merged.field2 == 2 60 | assert merged.field3 == False 61 | assert ( 62 | merged.field4 == "new_default" 63 | ) # Should retain value from default_options as it's not set in options 64 | -------------------------------------------------------------------------------- /aana/tests/units/test_whisper_params.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | import pytest 3 | 4 | from aana.core.models.whisper import WhisperParams 5 | 6 | 7 | def test_whisper_params_default(): 8 | """Test the default values of WhisperParams object. 9 | 10 | Keeping the default parameters of a function or object is important 11 | in case other code relies on them. 12 | 13 | If you need to change the default parameters, think twice before doing so. 
14 | """ 15 | params = WhisperParams() 16 | 17 | assert params.language is None 18 | assert params.beam_size == 5 19 | assert params.best_of == 5 20 | assert params.temperature == (0.0, 0.2, 0.4, 0.6, 0.8, 1.0) 21 | assert params.word_timestamps is False 22 | assert params.vad_filter is True 23 | 24 | 25 | @pytest.mark.parametrize( 26 | "language, beam_size, best_of, temperature, word_timestamps, vad_filter", 27 | [ 28 | ("en", 5, 5, 0.5, True, True), 29 | ("fr", 3, 3, 0.2, False, False), 30 | (None, 1, 1, [0.8, 0.9], True, False), 31 | ], 32 | ) 33 | def test_whisper_params( 34 | language, beam_size, best_of, temperature, word_timestamps, vad_filter 35 | ): 36 | """Test function for the WhisperParams class with valid parameters.""" 37 | params = WhisperParams( 38 | language=language, 39 | beam_size=beam_size, 40 | best_of=best_of, 41 | temperature=temperature, 42 | word_timestamps=word_timestamps, 43 | vad_filter=vad_filter, 44 | ) 45 | 46 | assert params.language == language 47 | assert params.beam_size == beam_size 48 | assert params.best_of == best_of 49 | assert params.temperature == temperature 50 | assert params.word_timestamps == word_timestamps 51 | assert params.vad_filter == vad_filter 52 | 53 | 54 | @pytest.mark.parametrize( 55 | "temperature", 56 | [ 57 | [-1.0, 0.5, 1.5], 58 | [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 2.0], 59 | "invalid_temperature", 60 | 2, 61 | ], 62 | ) 63 | def test_whisper_params_invalid_temperature(temperature): 64 | """Check ValueError raised if temperature is invalid.""" 65 | with pytest.raises(ValueError): 66 | WhisperParams(temperature=temperature) 67 | -------------------------------------------------------------------------------- /aana/tests/db/datastore/test_video_repo.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | import uuid 4 | from importlib import resources 5 | 6 | import pytest 7 | 8 | from aana.core.models.video import Video, VideoMetadata 9 | from 
aana.exceptions.db import MediaIdAlreadyExistsException, NotFoundException 10 | from aana.storage.repository.video import VideoRepository 11 | 12 | 13 | @pytest.fixture(scope="function") 14 | def dummy_video(): 15 | """Creates a dummy video for testing.""" 16 | media_id = str(uuid.uuid4()) 17 | path = resources.files("aana.tests.files.videos") / "squirrel.mp4" 18 | video = Video(path=path, media_id=media_id) 19 | return video 20 | 21 | 22 | @pytest.mark.asyncio 23 | async def test_save_video(db_session_manager, dummy_video): 24 | """Tests saving a video.""" 25 | async with db_session_manager.session() as session: 26 | video_repo = VideoRepository(session) 27 | await video_repo.save(dummy_video) 28 | 29 | video_entity = await video_repo.read(dummy_video.media_id) 30 | assert video_entity 31 | assert video_entity.id == dummy_video.media_id 32 | 33 | # Try to save the same video again 34 | with pytest.raises(MediaIdAlreadyExistsException): 35 | await video_repo.save(dummy_video) 36 | 37 | await video_repo.delete(dummy_video.media_id) 38 | with pytest.raises(NotFoundException): 39 | await video_repo.read(dummy_video.media_id) 40 | 41 | 42 | @pytest.mark.asyncio 43 | async def test_get_metadata(db_session_manager, dummy_video): 44 | """Tests getting video metadata.""" 45 | async with db_session_manager.session() as session: 46 | video_repo = VideoRepository(session) 47 | await video_repo.save(dummy_video) 48 | 49 | metadata = await video_repo.get_metadata(dummy_video.media_id) 50 | assert isinstance(metadata, VideoMetadata) 51 | assert metadata.title == dummy_video.title 52 | assert metadata.description == dummy_video.description 53 | assert metadata.duration == None 54 | 55 | await video_repo.delete(dummy_video.media_id) 56 | with pytest.raises(NotFoundException): 57 | await video_repo.get_metadata(dummy_video.media_id) 58 | -------------------------------------------------------------------------------- 
deployments = [
    (
        "sd_deployment",
        PyannoteSpeakerDiarizationDeployment.options(
            num_replicas=1,
            max_ongoing_requests=1000,
            ray_actor_options={"num_gpus": 0.05},
            user_config=PyannoteSpeakerDiarizationConfig(
                model_id=("pyannote/speaker-diarization-3.1"),
                sample_rate=16000,
            ).model_dump(mode="json"),
        ),
    )
]


@pytest.mark.skip(reason="The test is temporary disabled")
@pytest.mark.parametrize("setup_deployment", deployments, indirect=True)
class TestPyannoteSpeakerDiarizationDeployment:
    """Test pyannote Speaker Diarization deployment."""

    @pytest.mark.asyncio
    @pytest.mark.parametrize("audio_file", ["sd_sample.wav"])
    async def test_speaker_diarization(self, setup_deployment, audio_file):
        """Diarize a sample audio and compare against the stored expected output."""
        _deployment_name, handle_name, _ = setup_deployment

        handle = await AanaDeploymentHandle.create(handle_name)

        # Expected results are keyed by the audio file's stem.
        stem = Path(audio_file).stem
        expected_output_path = (
            resources.files("aana.tests.files.expected") / "sd" / f"{stem}.json"
        )

        audio_path = resources.files("aana.tests.files.audios") / audio_file
        assert audio_path.exists(), f"Audio not found: {audio_path}"

        result = await handle.diarize(audio=Audio(path=audio_path))
        verify_deployment_results(expected_output_path, pydantic_to_dict(result))
class RateLimitHandler(EventHandler):
    """Event handler that raises TooManyRequestsException if the rate limit is exceeded.

    Implements a sliding-window limiter: timestamps of accepted calls are kept,
    entries older than `interval` seconds are discarded, and a new call is
    accepted only while fewer than `capacity` timestamps remain in the window.

    Attributes:
        capacity (int): number of resources (requests) per interval
        interval (float): the interval for the limit in seconds
    """

    capacity: int
    interval: float

    # Monotonic timestamps of accepted calls, oldest first.
    _calls: list

    def __init__(self, capacity: int, interval: float):
        """Constructor.

        Arguments:
            capacity (int): number of requests allowed per interval
            interval (float): length of the rate-limit window in seconds
        """
        self.capacity = capacity
        self.interval = interval
        self._calls = []

    def _clear_expired(self, expired: float):
        """Removes expired items from list of resources.

        Arguments:
            expired: timestamp before which to clear, as output from time.monotonic()
        """
        # Single O(n) pass instead of repeated list.pop(0) calls,
        # each of which is O(n) on its own.
        self._calls = [ts for ts in self._calls if ts >= expired]

    def _acquire(self):
        """Checks if we can acquire (process) a resource.

        Raises:
            TooManyRequestsException: if the rate limit has been reached
        """
        now = monotonic()
        self._clear_expired(now - self.interval)
        if len(self._calls) < self.capacity:
            self._calls.append(now)
        else:
            raise TooManyRequestsException(self.capacity, self.interval)

    @override
    def handle(self, event_name: str, *args, **kwargs):
        """Handle the event by checking against rate limiting parameters.

        Arguments:
            event_name (str): the name of the event to handle
            *args (list): args for the event
            **kwargs (dict): keyword args for the event

        Raises:
            TooManyRequestsException: if the rate limit has been reached
        """
        # Deferred endpoint executions are queued rather than served
        # immediately, so they are exempt from rate limiting.
        defer = kwargs.get("defer", False)
        if not defer:
            self._acquire()
10 | 11 | ::: aana.deployments.hf_pipeline_deployment.HfPipelineConfig 12 | options: 13 | show_bases: false 14 | heading_level: 4 15 | show_docstring_description: false 16 | docstring_section_style: list 17 | 18 | ### Example Configurations 19 | 20 | As an example, let's see how to configure the Hugging Face Pipeline deployment to serve [Salesforce BLIP-2 OPT-2.7b model](https://huggingface.co/Salesforce/blip2-opt-2.7b). 21 | 22 | !!! example "BLIP-2 OPT-2.7b" 23 | 24 | ```python 25 | from transformers import BitsAndBytesConfig 26 | from aana.deployments.hf_pipeline_deployment import HfPipelineConfig, HfPipelineDeployment 27 | 28 | HfPipelineDeployment.options( 29 | num_replicas=1, 30 | ray_actor_options={"num_gpus": 0.25}, 31 | user_config=HfPipelineConfig( 32 | model_id="Salesforce/blip2-opt-2.7b", 33 | task="image-to-text", 34 | model_kwargs={ 35 | "quantization_config": BitsAndBytesConfig(load_in_8bit=False, load_in_4bit=True), 36 | }, 37 | ).model_dump(mode="json"), 38 | ) 39 | ``` 40 | 41 | Model ID is the Hugging Face model ID. `task` is one of the [Hugging Face Pipelines tasks](https://huggingface.co/transformers/main_classes/pipelines.html) that the model can perform. We deploy the model with 4-bit quantization by setting `quantization_config` in the `model_kwargs` dictionary. You can pass extra arguments to the model in the `model_kwargs` dictionary. 
42 | -------------------------------------------------------------------------------- /aana/tests/db/datastore/test_transcript_repo.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | import pytest 4 | 5 | from aana.core.models.asr import AsrSegment, AsrTranscription, AsrTranscriptionInfo 6 | from aana.core.models.time import TimeInterval 7 | from aana.exceptions.db import NotFoundException 8 | from aana.storage.models.transcript import TranscriptEntity 9 | from aana.storage.repository.transcript import TranscriptRepository 10 | 11 | transcript_entity = TranscriptEntity.from_asr_output( 12 | model_name="whisper", 13 | transcription=AsrTranscription(text="This is a transcript"), 14 | segments=[], 15 | info=AsrTranscriptionInfo(), 16 | ) 17 | 18 | 19 | @pytest.fixture(scope="function") 20 | def dummy_transcript(): 21 | """Creates a dummy transcript for testing.""" 22 | transcript = AsrTranscription(text="This is a transcript") 23 | segments = [ 24 | AsrSegment(text="This is a segment", time_interval=TimeInterval(start=0, end=1)) 25 | ] 26 | info = AsrTranscriptionInfo(language="en", language_confidence=0.9) 27 | return transcript, segments, info 28 | 29 | 30 | @pytest.mark.asyncio 31 | async def test_save_transcript(db_session_manager, dummy_transcript): 32 | """Tests saving a transcript.""" 33 | transcript, segments, info = dummy_transcript 34 | model_name = "whisper" 35 | 36 | async with db_session_manager.session() as session: 37 | transcript_repo = TranscriptRepository(session) 38 | transcript_entity = await transcript_repo.save( 39 | model_name=model_name, 40 | transcription_info=info, 41 | transcription=transcript, 42 | segments=segments, 43 | ) 44 | 45 | transcript_id = transcript_entity.id 46 | 47 | transcript_entity = await transcript_repo.read(transcript_id) 48 | assert transcript_entity 49 | assert transcript_entity.id == transcript_id 50 | assert transcript_entity.model == model_name 51 | 
class AbstractImageLibrary:
    """Abstract class for image libraries.

    Concrete subclasses implement reading and writing RGB images
    as numpy arrays; every method here is an abstract stub.
    """

    @classmethod
    def read_file(cls, path: Path) -> np.ndarray:
        """Read a file using the image library.

        Args:
            path (Path): The path of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def read_from_bytes(cls, content: bytes) -> np.ndarray:
        """Read bytes using the image library.

        Args:
            content (bytes): The content of the file to read.

        Returns:
            np.ndarray: The file as a numpy array.
        """
        raise NotImplementedError

    @classmethod
    def write_file(
        cls,
        path: Path,
        img: np.ndarray,
        format: str = "bmp",  # noqa: A002
        quality: int = 95,
        compression: int = 3,
    ):
        """Write an image to disk in BMP, PNG or JPEG format.

        Args:
            path (Path): Base path (extension will be enforced).
            img (np.ndarray): RGB image array.
            format (str): One of "bmp", "png", "jpeg" (or "jpg").
            quality (int): JPEG quality (0-100; higher is better). Only used if format is JPEG.
            compression (int): PNG compression level (0-9; higher is smaller). Only used if format is PNG.
        """
        raise NotImplementedError

    @classmethod
    def write_to_bytes(
        cls,
        img: np.ndarray,
        format: str = "jpeg",  # noqa: A002
        quality: int = 95,
        compression: int = 3,
    ) -> bytes:
        """Write image to bytes in a specified format.

        Args:
            img (np.ndarray): The image to write.
            format (str): The format to use for encoding. Default is "jpeg".
            quality (int): The quality to use for encoding. Default is 95.
            compression (int): The compression level to use for encoding. Default is 3.

        Returns:
            bytes: The image as bytes.
        """
        raise NotImplementedError
class CaptionRepository(BaseRepository[T]):
    """Repository for Captions."""

    def __init__(self, session: AsyncSession, model_class: type[T] = CaptionEntity):
        """Constructor.

        Args:
            session (AsyncSession): The database session.
            model_class (type[T]): The entity class managed by this repository.
        """
        super().__init__(session, model_class)

    async def save(
        self, model_name: str, caption: Caption, timestamp: float, frame_id: int
    ) -> CaptionEntity:
        """Save a caption.

        Args:
            model_name (str): The name of the model used to generate the caption.
            caption (Caption): The caption.
            timestamp (float): The timestamp.
            frame_id (int): The frame ID.

        Returns:
            CaptionEntity: The saved caption entity.
        """
        entity = CaptionEntity.from_caption_output(
            model_name=model_name,
            frame_id=frame_id,
            timestamp=timestamp,
            caption=caption,
        )
        await self.create(entity)
        return entity

    async def save_all(
        self,
        model_name: str,
        captions: CaptionsList,
        timestamps: list[float],
        frame_ids: list[int],
    ) -> list[CaptionEntity]:
        """Save captions.

        Args:
            model_name (str): The name of the model used to generate the captions.
            captions (CaptionsList): The captions.
            timestamps (list[float]): The timestamps.
            frame_ids (list[int]): The frame IDs.

        Returns:
            list[CaptionEntity]: The list of caption entities.
        """
        entities = [
            CaptionEntity.from_caption_output(
                model_name=model_name,
                frame_id=frame_id,
                timestamp=timestamp,
                caption=caption,
            )
            for caption, timestamp, frame_id in zip(
                captions, timestamps, frame_ids, strict=True
            )
        ]
        results = await self.create_multiple(entities)
        return results
text.""" 15 | 16 | def __init__(self): 17 | """Initialize the deployment.""" 18 | super().__init__() 19 | self.active = True 20 | 21 | @exception_handler 22 | async def lower(self, text: str) -> dict: 23 | """Lowercase the text. 24 | 25 | Args: 26 | text (str): The text to lowercase 27 | 28 | Returns: 29 | dict: The lowercase text 30 | """ 31 | if text == "inference_exception" or not self.active: 32 | self.active = False 33 | raise InferenceException(model_name="lowercase_deployment") 34 | 35 | return {"text": text.lower()} 36 | 37 | 38 | deployments = [ 39 | { 40 | "name": "lowercase_deployment", 41 | "instance": Lowercase, 42 | } 43 | ] 44 | 45 | 46 | @pytest.mark.asyncio 47 | async def test_deployment_restart(create_app): 48 | """Test the Ray Serve app.""" 49 | create_app(deployments, []) 50 | 51 | handle = await AanaDeploymentHandle.create("lowercase_deployment") 52 | 53 | text = "Hello, World!" 54 | 55 | # test the lowercase deployment works 56 | response = await handle.lower(text=text) 57 | assert response == {"text": text.lower()} 58 | 59 | # Cause an InferenceException in the deployment and make it inactive. 60 | # After the deployment is inactive, the deployment should always raise an InferenceException. 61 | with pytest.raises(InferenceException): 62 | await handle.lower(text="inference_exception") 63 | 64 | # The deployment should restart and work again, wait for around 60 seconds for the deployment to restart. 
65 | for _ in range(60): 66 | await asyncio.sleep(1) 67 | try: 68 | response = await handle.lower(text=text) 69 | if response == {"text": text.lower()}: 70 | break 71 | except: # noqa: S110 72 | pass 73 | 74 | assert response == {"text": text.lower()} 75 | -------------------------------------------------------------------------------- /aana/configs/db.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from os import PathLike 3 | 4 | from pydantic_settings import BaseSettings 5 | from typing_extensions import TypedDict 6 | 7 | 8 | class DbType(str, Enum): 9 | """Engine types for relational database. 10 | 11 | Attributes: 12 | POSTGRESQL: PostgreSQL database. 13 | SQLITE: SQLite database. 14 | """ 15 | 16 | POSTGRESQL = "postgresql" 17 | SQLITE = "sqlite" 18 | 19 | 20 | class SQLiteConfig(TypedDict): 21 | """Config values for SQLite. 22 | 23 | Attributes: 24 | path (PathLike): The path to the SQLite database file. 25 | """ 26 | 27 | path: PathLike | str 28 | 29 | 30 | class PostgreSQLConfig(TypedDict): 31 | """Config values for PostgreSQL. 32 | 33 | Attributes: 34 | host (str): The host of the PostgreSQL server. 35 | port (int): The port of the PostgreSQL server. 36 | user (str): The user to connect to the PostgreSQL server. 37 | password (str): The password to connect to the PostgreSQL server. 38 | database (str): The database name. 39 | """ 40 | 41 | host: str 42 | port: int 43 | user: str 44 | password: str 45 | database: str 46 | 47 | 48 | class DbSettings(BaseSettings): 49 | """Database configuration. 50 | 51 | Attributes: 52 | datastore_type (DbType | str): The type of the datastore. Default is DbType.SQLITE. 53 | datastore_config (SQLiteConfig | PostgreSQLConfig): The configuration for the datastore. 54 | Default is SQLiteConfig(path="/var/lib/aana_data"). 55 | pool_size (int): The number of connections to keep in the pool. Default is 5. 
def test_default_tmp_data_dir():
    """Test that the default temporary data directory is set correctly."""
    assert Settings().tmp_data_dir == Path("/tmp/aana_data")


def test_custom_tmp_data_dir(monkeypatch):
    """Test that the custom temporary data directory with environment variable is set correctly."""
    override_path = "/tmp/override/path"
    monkeypatch.setenv("TMP_DATA_DIR", override_path)
    assert Settings().tmp_data_dir == Path(override_path)


def test_changing_tmp_data_dir():
    """Test that changing the temporary data directory is reflected in the other directories."""
    base_dir = Path("/tmp/new_tmp_data_dir")
    settings = Settings(tmp_data_dir=base_dir)

    # Derived directories follow the base directory by default.
    assert settings.tmp_data_dir == base_dir
    assert settings.image_dir == base_dir / "images"
    assert settings.video_dir == base_dir / "videos"
    assert settings.audio_dir == base_dir / "audios"
    assert settings.model_dir == base_dir / "models"

    # Each subdirectory can still be overridden independently of the base.
    overrides = [
        ("image_dir", Path("/tmp/new_image_dir")),
        ("video_dir", Path("/tmp/new_video_dir")),
        ("audio_dir", Path("/tmp/new_audio_dir")),
        ("model_dir", Path("/tmp/new_model_dir")),
    ]
    for attr_name, custom_dir in overrides:
        settings = Settings(tmp_data_dir=base_dir, **{attr_name: custom_dir})
        assert settings.tmp_data_dir == base_dir
        assert getattr(settings, attr_name) == custom_dir
def test_event_dispatch():
    """Tests that event dispatch works correctly."""
    event_manager = EventManager()
    expected = ("foo", (1, 2, 3, 4, 5), {"a": "A", "b": "B"})

    def callback(event_name, *args, **kwargs):
        # The handler must receive exactly what was dispatched.
        assert (event_name, args, kwargs) == expected

    event_manager.register_handler_for_events(CallbackHandler(callback), ["foo"])
    event_manager.handle("foo", 1, 2, 3, 4, 5, a="A", b="B")


def test_remove_all_raises():
    """Tests that removing handler not added from all events raises an error."""
    manager = EventManager()
    unregistered = CallbackHandler(lambda _, *_args, **_kwargs: None)
    with pytest.raises(HandlerNotRegisteredException):
        manager.deregister_handler_from_all_events(unregistered)


def test_remove_works():
    """Tests that removing a handler works."""
    manager = EventManager()
    handler = CallbackHandler(lambda _, *_args, **_kwargs: None)
    manager.register_handler_for_events(handler, ["foo"])
    manager.deregister_handler_from_event(handler, "foo")
    assert not manager._handlers["foo"]


def test_remove_all_works():
    """Tests that removing all handlers works."""
    manager = EventManager()
    handler = CallbackHandler(lambda _, *_args, **_kwargs: None)
    manager.register_handler_for_events(handler, ["foo"])
    manager.deregister_handler_from_all_events(handler)
    assert not manager._handlers["foo"]
@lru_cache(maxsize=128)
def load_chat_template(chat_template_name: str) -> str:
    """Loads a chat template from the chat templates directory.

    Args:
        chat_template_name (str): The name of the chat template to load.

    Returns:
        str: The loaded chat template.

    Raises:
        ValueError: If the chat template does not exist.
    """
    template_path = (
        resources.files("aana.core.chat.templates") / f"{chat_template_name}.jinja"
    )
    if not template_path.exists():
        raise ValueError(f"Chat template {chat_template_name} does not exist.")  # noqa: TRY003
    return template_path.read_text()


def apply_chat_template(
    tokenizer: "PreTrainedTokenizerBase",
    dialog: ChatDialog | list[dict],
    chat_template_name: str | None = None,
) -> str:
    """Applies a chat template to a list of messages to generate a prompt for the model.

    If the chat template is not specified, the tokenizer's default chat template is used.
    If the chat template is specified, the template with given name loaded from the chat templates directory.

    Args:
        tokenizer (PreTrainedTokenizerBase): The tokenizer to use.
        dialog (ChatDialog | list[dict]): The dialog to generate a prompt for.
        chat_template_name (str, optional): The name of the chat template to use. Defaults to None, which uses the tokenizer's default chat template.

    Returns:
        str: The generated prompt.

    Raises:
        ValueError: If the tokenizer does not have a chat template.
        ValueError: If the chat template does not exist.
    """
    messages = (
        dialog.model_dump()["messages"] if isinstance(dialog, ChatDialog) else dialog
    )

    if chat_template_name is not None:
        # NOTE(review): this overwrites the tokenizer's template, which also
        # affects later calls made without `chat_template_name` — confirm
        # this is intended before relying on the tokenizer's default afterwards.
        tokenizer.chat_template = load_chat_template(chat_template_name)

    if tokenizer.chat_template is None:
        raise ValueError("Tokenizer does not have a chat template.")  # noqa: TRY003

    return tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
class SpeakerDiarizationSegment(BaseModel):
    """Pydantic schema for Segment from Speaker Diarization model.

    Attributes:
        time_interval (TimeInterval): The start and end time of the segment
        speaker (str): speaker assignment of the model in the format "SPEAKER_XX"
    """

    time_interval: TimeInterval = Field(description="Time interval of the segment")
    speaker: str = Field(description="speaker assignment from the model")

    def to_dict(self) -> dict:
        """Generate dictionary with start, end and speaker keys from SpeakerDiarizationSegment.

        Returns:
            dict: Dictionary with start, end and speaker keys
        """
        # Flat {start, end, speaker} form used when combining diarization with
        # ASR output (see PostProcessingForDiarizedAsr); mirrors
        # VadSegment.to_whisper_dict in shape.
        return {
            "start": self.time_interval.start,
            "end": self.time_interval.end,
            "speaker": self.speaker,
        }

    model_config = ConfigDict(
        json_schema_extra={
            "description": "Speaker Diarization Segment",
        },
        extra="forbid",  # reject unknown fields so malformed segments fail fast
    )


SpeakerDiarizationSegments = Annotated[
    list[SpeakerDiarizationSegment],
    Field(description="List of Speaker Diarization segments", default_factory=list),
]
"""
List of SpeakerDiarizationSegment objects.
"""
output = await handle.call(image) 58 | verify_deployment_results(expected_output_path, output) 59 | 60 | output = await handle.call(images=[str(path)]) 61 | verify_deployment_results(expected_output_path, [output]) 62 | 63 | output = await handle.call(images=[image]) 64 | verify_deployment_results(expected_output_path, [output]) 65 | 66 | output = await handle.call([image]) 67 | verify_deployment_results(expected_output_path, [output]) 68 | -------------------------------------------------------------------------------- /aana/tests/projects/lowercase/app.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated, TypedDict 2 | 3 | from pydantic import Field 4 | from ray import serve 5 | 6 | from aana.api.api_generation import Endpoint 7 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 8 | from aana.deployments.base_deployment import BaseDeployment 9 | from aana.sdk import AanaSDK 10 | 11 | 12 | @serve.deployment 13 | class Lowercase(BaseDeployment): 14 | """Ray deployment that returns the lowercase version of a text.""" 15 | 16 | async def lower(self, text: str) -> dict: 17 | """Lowercase the text. 18 | 19 | Args: 20 | text (str): The text to lowercase 21 | 22 | Returns: 23 | dict: The lowercase text 24 | """ 25 | return {"text": [t.lower() for t in text]} 26 | 27 | 28 | TextList = Annotated[list[str], Field(description="List of text to lowercase.")] 29 | 30 | 31 | class LowercaseEndpointOutput(TypedDict): 32 | """The output of the lowercase endpoint.""" 33 | 34 | text: list[str] 35 | 36 | 37 | class LowercaseEndpoint(Endpoint): 38 | """Lowercase endpoint.""" 39 | 40 | async def initialize(self): 41 | """Initialize the endpoint.""" 42 | self.lowercase_handle = await AanaDeploymentHandle.create( 43 | "lowercase_deployment" 44 | ) 45 | await super().initialize() 46 | 47 | async def run(self, text: TextList) -> LowercaseEndpointOutput: 48 | """Lowercase the text. 
class TaskInfo(BaseModel):
    """Task information.

    Attributes:
        id (str): The task ID.
        endpoint (str): The endpoint to which the task is assigned.
        data (Any): The task data.
        status (TaskStatus): The task status.
        result (Any): The task result.
    """

    id: TaskId
    endpoint: str = Field(
        ..., description="The endpoint to which the task is assigned."
    )
    data: Any = Field(..., description="The task data.")
    status: TaskStatus = Field(..., description="The task status.")
    result: Any = Field(None, description="The task result.")

    model_config = ConfigDict(
        json_schema_extra={
            "examples": [
                {
                    "id": "11111111-1111-1111-1111-111111111111",
                    "endpoint": "/index",
                    "data": {
                        "image": {
                            "url": "https://example.com/image.jpg",
                            "media_id": "abc123",
                        }
                    },
                    "status": "running",
                    "result": None,
                }
            ]
        },
        extra="forbid",
    )

    @classmethod
    def from_entity(cls, task: TaskEntity, is_admin: bool = False) -> "TaskInfo":
        """Create a TaskInfo from a TaskEntity.

        Args:
            task (TaskEntity): The task entity to convert.
            is_admin (bool): Whether the requester is an admin. Non-admins do
                not see the stacktrace in the result.

        Returns:
            TaskInfo: The API-facing task information.
        """
        # Prepare data (remove ApiKey, None values, etc.)
        task_data = {
            key: value
            for key, value in task.data.items()
            if value is not None and not isinstance(value, ApiKey)
        }

        # Hide the stacktrace from non-admins. Build a filtered copy instead of
        # popping from task.result: the previous in-place pop mutated the ORM
        # entity's dict, so the removal could leak back into the session (and,
        # with mutation tracking enabled, into the database).
        result = task.result
        if not is_admin and result and "stacktrace" in result:
            result = {k: v for k, v in result.items() if k != "stacktrace"}

        return TaskInfo(
            id=str(task.id),
            endpoint=task.endpoint,
            data=task_data,
            status=task.status,
            result=result,
        )
16 | merge_offset (float): "Optional offset to be used for the merging operation. 17 | """ 18 | 19 | chunk_size: float = Field( 20 | default=30, ge=10.0, description="The maximum length of each vad output chunk." 21 | ) 22 | 23 | merge_onset: float = Field( 24 | default=0.5, ge=0.0, description="Onset to be used for the merging operation." 25 | ) 26 | 27 | merge_offset: float | None = Field( 28 | default=None, 29 | description="Optional offset to be used for the merging operation.", 30 | ) 31 | 32 | model_config = ConfigDict( 33 | json_schema_extra={ 34 | "description": "Parameters for the voice activity detection model." 35 | }, 36 | extra="forbid", 37 | ) 38 | 39 | 40 | class VadSegment(BaseModel): 41 | """Pydantic schema for Segment from Voice Activity Detection model. 42 | 43 | Attributes: 44 | time_interval (TimeInterval): The start and end time of the segment 45 | segments (list[tuple[float, float]]): smaller voiced segments within a merged vad segment 46 | """ 47 | 48 | time_interval: TimeInterval = Field(description="Time interval of the segment") 49 | segments: list[tuple[float, float]] = Field( 50 | description="List of voiced segments within a Segment for ASR" 51 | ) 52 | 53 | def to_whisper_dict(self) -> dict: 54 | """Generate dictionary with start, end and segments keys from VADSegment for faster whisper. 55 | 56 | Returns: 57 | dict: Dictionary with start, end and segments keys 58 | """ 59 | return { 60 | "start": self.time_interval.start, 61 | "end": self.time_interval.end, 62 | "segments": self.segments, 63 | } 64 | 65 | model_config = ConfigDict( 66 | json_schema_extra={ 67 | "description": "VAD Segment for ASR", 68 | }, 69 | extra="forbid", 70 | ) 71 | 72 | 73 | VadSegments = Annotated[ 74 | list[VadSegment], Field(description="List of VAD segments", default_factory=list) 75 | ] 76 | """ 77 | List of VadSegment objects. 
@app.middleware("http")
async def api_key_check(request: Request, call_next):
    """Middleware to check the API key and subscription status.

    When the API service is enabled, every request (except the excluded paths
    and CORS preflight requests) must carry a valid, non-expired API key in the
    ``x-api-key`` header. The resolved key info is stored on
    ``request.state.api_key_info`` for downstream dependencies (see
    aana/api/security.py).

    Raises:
        ApiKeyNotProvided: If the header is missing.
        ApiKeyValidationFailed: If the database lookup fails.
        ApiKeyNotFound: If the key is unknown.
        ApiKeyExpired: If the key has expired.
    """
    # Health/docs endpoints and CORS preflight must stay reachable without a key.
    excluded_paths = ["/openapi.json", "/docs", "/redoc", "/api/ready"]
    if request.url.path in excluded_paths or request.method == "OPTIONS":
        return await call_next(request)

    if aana_settings.api_service.enabled:
        api_key = request.headers.get("x-api-key")

        if not api_key:
            raise ApiKeyNotProvided()

        async with get_session() as session:
            try:
                result = await session.execute(
                    select(ApiKeyEntity).where(ApiKeyEntity.api_key == api_key)
                )
                api_key_info = result.scalars().first()
            except Exception as e:
                # Surface any DB error as a validation failure instead of a bare 500.
                raise ApiKeyValidationFailed() from e

            if not api_key_info:
                raise ApiKeyNotFound(key=api_key)

            # NOTE(review): assumes expired_at is stored timezone-aware —
            # comparing a naive datetime against datetime.now(timezone.utc)
            # would raise TypeError; confirm against the DB schema.
            if api_key_info.expired_at < datetime.now(timezone.utc):
                raise ApiKeyExpired(key=api_key)

            request.state.api_key_info = api_key_info.to_model()

    response = await call_next(request)
    return response
41 | 42 | Returns: 43 | MediaEntity: The inserted entity. 44 | """ 45 | # TODO: throw MediaIdAlreadyExistsException without checking if media exists. 46 | # The following code is a better way to check if media already exists because it does only one query and 47 | # prevents race condition where two processes try to create the same media. 48 | # But it has an issue with the exception handling. 49 | # Unique constraint violation raises IntegrityError, but IntegrityError much more broader than 50 | # Unique constraint violation. So we cannot catch IntegrityError and raise MediaIdAlreadyExistsException, 51 | # we need to check the exception if it is Unique constraint violation or not and if it's on the media_id column. 52 | # Also different DBMS have different exception messages. 53 | # 54 | # try: 55 | # return super().create(entity) 56 | # except IntegrityError as e: 57 | # self.session.rollback() 58 | # raise MediaIdAlreadyExistsException(self.table_name, entity.id) from e 59 | 60 | if await self.check_media_exists(entity.id): 61 | raise MediaIdAlreadyExistsException(self.table_name, entity.id) 62 | 63 | return await super().create(entity) 64 | -------------------------------------------------------------------------------- /docs/pages/code_overview.md: -------------------------------------------------------------------------------- 1 | # Code overview 2 | 3 | ``` 4 | aana/ | top level source code directory for the project 5 | ├── alembic/ | directory for database migrations 6 | │ └── versions/ | individual migrations 7 | ├── api/ | API functionality 8 | │ ├── api_generation.py | API generation code, defines Endpoint class 9 | │ ├── request_handler.py | request handler routes requests to endpoints 10 | │ ├── exception_handler.py | exception handler to process exceptions and return them as JSON 11 | │ ├── responses.py | custom responses for the API 12 | │ └── app.py | defines the FastAPI app and connects exception handlers 13 | ├── config/ | various configuration 
objects, including settings and preconfigured deployments
class FileUploadModel(BaseModel):
    """Model for a file upload input.

    `content` holds the *name* of the uploaded file; the raw bytes are injected
    later via `set_files` from the multipart request.
    """

    content: str | None = Field(
        None,
        description="The name of the file to upload.",
    )
    # Raw bytes of the uploaded file, populated by set_files().
    _file: bytes | None = None

    def set_files(self, files: dict[str, bytes]):
        """Set files.

        Args:
            files (dict[str, bytes]): Mapping of uploaded file names to their bytes.

        Raises:
            UploadedFileNotFound: If `content` names a file that was not uploaded.
        """
        if self.content:
            if self.content not in files:
                raise UploadedFileNotFound(self.content)
            self._file = files[self.content]

    model_config = ConfigDict(extra="forbid")


class FileUploadEndpointOutput(TypedDict):
    """The output of the file upload endpoint."""

    text: str


class FileUploadEndpoint(Endpoint):
    """File upload endpoint."""

    async def run(self, file: FileUploadModel) -> FileUploadEndpointOutput:
        """Upload a file.

        Args:
            file (FileUploadModel): The file to upload

        Returns:
            FileUploadEndpointOutput: The uploaded file
        """
        # _file was populated by the framework calling set_files() with the
        # multipart payload before run() is invoked.
        file = file._file
        return {"text": file.decode()}


deployments = []

endpoints = [
    {
        "name": "file_upload",
        "path": "/file_upload",
        "summary": "Upload a file",
        "endpoint_cls": FileUploadEndpoint,
    }
]


def test_file_upload_app(create_app):
    """Test the app with a file upload endpoint."""
    aana_app = create_app(deployments, endpoints)

    port = aana_app.port
    route_prefix = ""

    # Check that the server is ready
    response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready")
    assert response.status_code == 200
    assert response.json() == {"ready": True}

    # Test the file upload endpoint: post the JSON body referencing "file.txt"
    # plus the multipart file itself; the endpoint should echo the file contents.
    # data = {"content": "file.txt"}
    data = {"file": {"content": "file.txt"}}
    file = b"Hello world! This is a test."
    files = {"file.txt": io.BytesIO(file)}
    response = requests.post(
        f"http://localhost:{port}{route_prefix}/file_upload",
        data={"body": json.dumps(data)},
        files=files,
    )
    assert response.status_code == 200, response.text
    text = response.json().get("text")
    assert text == file.decode()
24 | """ 25 | 26 | __tablename__ = "caption" 27 | 28 | id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True) 29 | model: Mapped[str] = mapped_column( 30 | nullable=False, comment="Name of model used to generate the caption" 31 | ) 32 | frame_id: Mapped[int] = mapped_column( 33 | CheckConstraint("frame_id >= 0", "frame_id_positive"), 34 | comment="The 0-based frame id of video for caption", 35 | ) 36 | caption: Mapped[str] = mapped_column(comment="Frame caption") 37 | timestamp: Mapped[float] = mapped_column( 38 | CheckConstraint("timestamp >= 0", name="timestamp_positive"), 39 | comment="Frame timestamp in seconds", 40 | ) 41 | caption_type: Mapped[str] = mapped_column(comment="The type of caption") 42 | 43 | __mapper_args__ = { # noqa: RUF012 44 | "polymorphic_identity": "caption", 45 | "polymorphic_on": "caption_type", 46 | } 47 | 48 | @classmethod 49 | def from_caption_output( 50 | cls, 51 | model_name: str, 52 | caption: Caption, 53 | frame_id: int, 54 | timestamp: float, 55 | ) -> CaptionEntity: 56 | """Converts a Caption pydantic model to a CaptionEntity. 57 | 58 | Args: 59 | model_name (str): Name of the model used to generate the caption. 60 | caption (Caption): Caption pydantic model. 61 | frame_id (int): The 0-based frame id of video for caption. 62 | timestamp (float): Frame timestamp in seconds. 63 | 64 | Returns: 65 | CaptionEntity: ORM model for video captions. 66 | """ 67 | return CaptionEntity( 68 | model=model_name, 69 | frame_id=frame_id, 70 | caption=str(caption), 71 | timestamp=timestamp, 72 | ) 73 | -------------------------------------------------------------------------------- /docs/pages/model_hub/index.md: -------------------------------------------------------------------------------- 1 | # Model Hub 2 | 3 | 4 | 5 | Model deployment is a crucial part of the machine learning workflow. Aana SDK uses concept of deployments to serve models. 6 | 7 | The deployments are "recipes" that can be used to deploy models. 
The full list of predefined deployments can be found on the [Integrations](./../integrations.md) page.
class EventManager:
    """Class for event manager. Not guaranteed to be thread safe."""

    # Maps an event name to the list of handlers subscribed to it.
    # (The previous annotation said `MutableMapping[str, EventHandler]`, but
    # the values have always been lists of handlers.)
    _handlers: MutableMapping[str, list[EventHandler]]

    def __init__(self):
        """Constructor."""
        self._handlers = defaultdict(list)

    def handle(self, event_name: str, *args, **kwargs):
        """Trigger event handlers for `event_name`.

        Arguments:
            event_name (str): name of event
            *args (list): specific args
            **kwargs (dict): specific args
        """
        # Use .get() instead of indexing: indexing a defaultdict would insert
        # an empty list for every unknown event name. Iterate over a snapshot
        # so a handler may (de)register handlers during dispatch without
        # breaking the iteration.
        for handler in tuple(self._handlers.get(event_name, ())):
            handler.handle(event_name, *args, **kwargs)

    def register_handler_for_events(
        self, handler: EventHandler, event_names: list[str]
    ):
        """Adds a handler to the event handler list.

        Arguments:
            handler (EventHandler): the handler to register
            event_names (list[str]): the events for which this handler is to be registered
        """
        for event_name in event_names:
            # Avoid duplicate registrations of the same handler instance.
            if handler not in self._handlers[event_name]:
                self._handlers[event_name].append(handler)

    def deregister_handler_from_event(self, handler: EventHandler, event_name: str):
        """Removes a handler from the event handler list.

        Arguments:
            handler (EventHandler): the handler to remove
            event_name (str): the name of the event from which the handler should be removed

        Raises:
            HandlerNotRegisteredException: if the handler isn't registered. (embed in try-except to suppress)
        """
        try:
            self._handlers[event_name].remove(handler)
        except ValueError as e:
            raise HandlerNotRegisteredException() from e

    def deregister_handler_from_all_events(self, handler: EventHandler):
        """Removes a handler from all event handlers.

        Arguments:
            handler (EventHandler): the exact instance of the handler to remove.

        Raises:
            HandlerNotRegisteredException: if the handler isn't registered. (embed in try-except to suppress)
        """
        has_removed = False
        for handler_list in self._handlers.values():
            if handler in handler_list:
                handler_list.remove(handler)
                has_removed = True
        if not has_removed:
            raise HandlerNotRegisteredException()
| expired_at: Mapped[timestamp] = mapped_column( 46 | nullable=False, comment="The expiration date of the API key" 47 | ) 48 | 49 | def __repr__(self) -> str: 50 | """String representation of the API key.""" 51 | return ( 52 | f"" 63 | ) 64 | 65 | def to_model(self) -> ApiKey: 66 | """Convert the object to a dictionary.""" 67 | return ApiKey( 68 | api_key=self.api_key, 69 | user_id=self.user_id, 70 | is_admin=self.is_admin, 71 | subscription_id=self.subscription_id, 72 | is_subscription_active=self.is_subscription_active, 73 | hmac_secret=self.hmac_secret, 74 | ) 75 | -------------------------------------------------------------------------------- /aana/api/security.py: -------------------------------------------------------------------------------- 1 | from typing import Annotated 2 | 3 | from fastapi import Depends, Request 4 | 5 | from aana.configs.settings import settings as aana_settings 6 | from aana.core.models.api_service import ApiKey 7 | from aana.exceptions.api_service import AdminOnlyAccess, InactiveSubscription 8 | 9 | 10 | def is_admin(request: Request) -> bool: 11 | """Check if the user is an admin. 12 | 13 | Args: 14 | request (Request): The request object 15 | 16 | Returns: 17 | bool: True if the user is an admin, False otherwise 18 | """ 19 | if aana_settings.api_service.enabled: 20 | api_key_info: ApiKey = request.state.api_key_info 21 | return api_key_info.is_admin if api_key_info else False 22 | return True 23 | 24 | 25 | def require_admin_access(request: Request) -> bool: 26 | """Check if the user is an admin. If not, raise an exception. 
27 | 28 | Args: 29 | request (Request): The request object 30 | 31 | Raises: 32 | AdminOnlyAccess: If the user is not an admin 33 | """ 34 | _is_admin = is_admin(request) 35 | if not _is_admin: 36 | raise AdminOnlyAccess() 37 | return True 38 | 39 | 40 | def extract_api_key_info(request: Request) -> ApiKey | None: 41 | """Get the API key info dependency.""" 42 | return getattr(request.state, "api_key_info", None) 43 | 44 | 45 | def extract_user_id(request: Request) -> str | None: 46 | """Get the user ID dependency.""" 47 | api_key_info = extract_api_key_info(request) 48 | return api_key_info.user_id if api_key_info else None 49 | 50 | 51 | def require_active_subscription(request: Request) -> bool: 52 | """Check if the user has an active subscription. If not, raise an exception. 53 | 54 | Args: 55 | request (Request): The request object 56 | 57 | Raises: 58 | InactiveSubscription: If the user does not have an active subscription 59 | """ 60 | if aana_settings.api_service.enabled: 61 | api_key_info: ApiKey = request.state.api_key_info 62 | if not api_key_info.is_subscription_active: 63 | raise InactiveSubscription(key=api_key_info.api_key) 64 | return True 65 | 66 | 67 | AdminAccessDependency = Annotated[bool, Depends(require_admin_access)] 68 | """ Dependency to check if the user is an admin. If not, it will raise an exception. """ 69 | 70 | IsAdminDependency = Annotated[bool, Depends(is_admin)] 71 | """ Dependency to check if the user is an admin. """ 72 | 73 | UserIdDependency = Annotated[str | None, Depends(extract_user_id)] 74 | """ Dependency to get the user ID. """ 75 | 76 | ApiKeyInfoDependency = Annotated[ApiKey | None, Depends(extract_api_key_info)] 77 | """ Dependency to get the API key info. """ 78 | 79 | ActiveSubscriptionRequiredDependency = Annotated[ 80 | bool, Depends(require_active_subscription) 81 | ] 82 | """ Dependency to check if the user has an active subscription. If not, it will raise an exception. 
""" 83 | -------------------------------------------------------------------------------- /aana/storage/models/task.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from enum import Enum 3 | 4 | from sqlalchemy import ( 5 | JSON, 6 | UUID, 7 | PickleType, 8 | ) 9 | from sqlalchemy.orm import Mapped, mapped_column 10 | 11 | from aana.storage.models.base import BaseEntity, TimeStampEntity, timestamp 12 | 13 | 14 | class Status(str, Enum): 15 | """Enum for task status. 16 | 17 | Attributes: 18 | CREATED: The task is created. 19 | ASSIGNED: The task is assigned to a worker. 20 | COMPLETED: The task is completed. 21 | RUNNING: The task is running. 22 | FAILED: The task has failed. 23 | NOT_FINISHED: The task is not finished. 24 | """ 25 | 26 | CREATED = "created" 27 | ASSIGNED = "assigned" 28 | COMPLETED = "completed" 29 | RUNNING = "running" 30 | FAILED = "failed" 31 | NOT_FINISHED = "not_finished" 32 | 33 | 34 | class TaskEntity(BaseEntity, TimeStampEntity): 35 | """Table for task items.""" 36 | 37 | __tablename__ = "tasks" 38 | 39 | id: Mapped[uuid.UUID] = mapped_column( 40 | UUID, primary_key=True, default=uuid.uuid4, comment="Task ID" 41 | ) 42 | endpoint: Mapped[str] = mapped_column( 43 | nullable=False, comment="The endpoint to which the task is assigned" 44 | ) 45 | data = mapped_column(PickleType, nullable=False, comment="Data for the task") 46 | status: Mapped[Status] = mapped_column( 47 | default=Status.CREATED, 48 | comment="Status of the task", 49 | index=True, 50 | ) 51 | priority: Mapped[int] = mapped_column( 52 | nullable=False, default=0, comment="Priority of the task (0 is the lowest)" 53 | ) 54 | assigned_at: Mapped[timestamp | None] = mapped_column( 55 | comment="Timestamp when the task was assigned", 56 | ) 57 | completed_at: Mapped[timestamp | None] = mapped_column( 58 | server_default=None, 59 | comment="Timestamp when the task was completed", 60 | ) 61 | progress: Mapped[float] = 
mapped_column( 62 | nullable=False, default=0.0, comment="Progress of the task in percentage" 63 | ) 64 | result: Mapped[dict | None] = mapped_column( 65 | JSON, comment="Result of the task in JSON format" 66 | ) 67 | num_retries: Mapped[int] = mapped_column( 68 | nullable=False, default=0, comment="Number of retries" 69 | ) 70 | user_id: Mapped[str] = mapped_column( 71 | nullable=True, 72 | comment="ID of the user who launched the task", 73 | index=True, 74 | ) 75 | 76 | def __repr__(self) -> str: 77 | """String representation of the task.""" 78 | return ( 79 | f"" 85 | ) 86 | -------------------------------------------------------------------------------- /aana/core/models/api.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from pydantic import BaseModel, ConfigDict, Field 4 | from ray.serve.schema import ApplicationStatus 5 | 6 | 7 | class SDKStatus(str, Enum): 8 | """The status of the SDK.""" 9 | 10 | UNHEALTHY = "UNHEALTHY" 11 | RUNNING = "RUNNING" 12 | DEPLOYING = "DEPLOYING" 13 | 14 | 15 | class DeploymentStatus(BaseModel): 16 | """The status of a deployment.""" 17 | 18 | status: ApplicationStatus = Field(description="The status of the deployment.") 19 | message: str = Field( 20 | description="The message for more information like error message." 21 | ) 22 | 23 | model_config = ConfigDict(extra="forbid") 24 | 25 | 26 | class SDKStatusResponse(BaseModel): 27 | """The response for the SDK status endpoint. 28 | 29 | Attributes: 30 | status (SDKStatus): The status of the SDK. 31 | message (str): The message for more information like error message. 32 | deployments (dict[str, DeploymentStatus]): The status of each deployment in the Aana app. 33 | """ 34 | 35 | status: SDKStatus = Field(description="The status of the SDK.") 36 | message: str = Field( 37 | description="The message for more information like error message." 
38 | ) 39 | deployments: dict[str, DeploymentStatus] = Field( 40 | description="The status of each deployment in the Aana app." 41 | ) 42 | 43 | model_config = ConfigDict( 44 | json_schema_extra={ 45 | "description": "The response for the SDK status endpoint.", 46 | "examples": [ 47 | { 48 | "status": "RUNNING", 49 | "message": "", 50 | "deployments": { 51 | "app": { 52 | "status": "RUNNING", 53 | "message": "", 54 | }, 55 | "lowercase_deployment": { 56 | "status": "RUNNING", 57 | "message": "", 58 | }, 59 | }, 60 | }, 61 | { 62 | "status": "UNHEALTHY", 63 | "message": "Error: Lowercase (lowercase_deployment): A replica's health check failed. " 64 | "This deployment will be UNHEALTHY until the replica recovers or a new deploy happens.", 65 | "deployments": { 66 | "app": { 67 | "status": "RUNNING", 68 | "message": "", 69 | }, 70 | "lowercase_deployment": { 71 | "status": "UNHEALTHY", 72 | "message": "A replica's health check failed. This deployment will be UNHEALTHY " 73 | "until the replica recovers or a new deploy happens.", 74 | }, 75 | }, 76 | }, 77 | ], 78 | }, 79 | extra="forbid", 80 | ) 81 | -------------------------------------------------------------------------------- /aana/storage/models/base.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Annotated, Any, TypeVar 3 | 4 | from sqlalchemy import MetaData, String, func 5 | from sqlalchemy.orm import ( 6 | DeclarativeBase, 7 | Mapped, 8 | mapped_column, 9 | object_mapper, 10 | registry, 11 | ) 12 | 13 | from aana.core.models.media import MediaId 14 | from aana.storage.types import TimezoneAwareDateTime 15 | 16 | timestamp = Annotated[ 17 | datetime.datetime, 18 | mapped_column(TimezoneAwareDateTime(timezone=True)), 19 | ] 20 | 21 | T = TypeVar("T", bound="InheritanceReuseMixin") 22 | 23 | 24 | class InheritanceReuseMixin: 25 | """Mixin for instantiating child classes from parent instances.""" 26 | 27 | @classmethod 28 | def 
from_parent(cls: type[T], parent_instance: Any, **kwargs: Any) -> T: 29 | """Create a new instance of the child class, reusing attributes from the parent instance. 30 | 31 | Args: 32 | parent_instance (Any): An instance of the parent class 33 | kwargs (Any): Additional keyword arguments to set on the new instance 34 | 35 | Returns: 36 | T: A new instance of the child class 37 | """ 38 | # Get the mapped attributes of the parent class 39 | mapper = object_mapper(parent_instance) 40 | attributes = { 41 | prop.key: getattr(parent_instance, prop.key) 42 | for prop in mapper.iterate_properties 43 | if hasattr(parent_instance, prop.key) 44 | and prop.key 45 | != mapper.polymorphic_on.name # don't copy the polymorphic_on attribute from the parent 46 | } 47 | 48 | # Update attributes with any additional kwargs 49 | attributes.update(kwargs) 50 | 51 | # Create and return a new instance of the child class 52 | return cls(**attributes) 53 | 54 | 55 | class BaseEntity(DeclarativeBase, InheritanceReuseMixin): 56 | """Base for all ORM classes.""" 57 | 58 | metadata = MetaData( 59 | naming_convention={ 60 | "ix": "ix_%(column_0_label)s", 61 | "uq": "uq_%(table_name)s_%(column_0_name)s", 62 | "ck": "ck_%(table_name)s_`%(constraint_name)s`", 63 | "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", 64 | "pk": "pk_%(table_name)s", 65 | } 66 | ) 67 | 68 | registry = registry( 69 | type_annotation_map={ 70 | MediaId: String(36), 71 | } 72 | ) 73 | 74 | def __repr__(self) -> str: 75 | """Get the representation of the entity.""" 76 | return f"{self.__class__.__name__}(id={self.id})" 77 | 78 | 79 | class TimeStampEntity: 80 | """Mixin for database entities that will have create/update timestamps.""" 81 | 82 | created_at: Mapped[timestamp] = mapped_column( 83 | server_default=func.now(), 84 | comment="Timestamp when row is inserted", 85 | ) 86 | updated_at: Mapped[timestamp] = mapped_column( 87 | onupdate=func.now(), 88 | server_default=func.now(), 89 | comment="Timestamp 
when row is updated", 90 | ) 91 | -------------------------------------------------------------------------------- /docs/pages/model_hub/speaker_recognition.md: -------------------------------------------------------------------------------- 1 | # Speaker Recognition 2 | 3 | ## Speaker Diarization (SD) Models 4 | 5 | [PyannoteSpeakerDiarizationDeployment](./../../reference/deployments.md#aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationDeployment) allows you to diarize the audio for speakers audio with pyannote models. The deployment is based on the [pyannote.audio](https://github.com/pyannote/pyannote-audio) library. 6 | 7 | !!! Tip 8 | To use Pyannotate Speaker Diarization deployment, install required libraries with `pip install pyannote-audio` or include extra dependencies using `pip install aana[asr]`. 9 | 10 | [PyannoteSpeakerDiarizationConfig](./../../reference/deployments.md#aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationConfig) is used to configure the Speaker Diarization deployment. 11 | 12 | ::: aana.deployments.pyannote_speaker_diarization_deployment.PyannoteSpeakerDiarizationConfig 13 | options: 14 | show_bases: false 15 | heading_level: 4 16 | show_docstring_description: false 17 | docstring_section_style: list 18 | 19 | 20 | ## Accessing Gated Models 21 | 22 | The PyAnnote speaker diarization models are gated, requiring special access. To use these models: 23 | 24 | 1. **Request Access**: 25 | Visit the [PyAnnote Speaker Diarization 3.1 model page](https://huggingface.co/pyannote/speaker-diarization-3.1) and [Pyannote Speaker Segmentation 3.0 model page](https://huggingface.co/pyannote/segmentation-3.0) on Hugging Face. Log in, fil out the forms, and request access. 26 | 27 | 2. **Approval**: 28 | - If automatic, access is granted immediately. 29 | - If manual, wait for the model authors to approve your request. 30 | 31 | 3. 
**Set Up the SDK**: 32 | After approval, add your Hugging Face access token to your `.env` file by setting the `HF_TOKEN` variable: 33 | 34 | ```plaintext 35 | HF_TOKEN=your_huggingface_access_token 36 | ``` 37 | 38 | To get your Hugging Face access token, visit the [Hugging Face Settings - Tokens](https://huggingface.co/settings/tokens). 39 | 40 | 41 | ## Example Configurations 42 | 43 | As an example, let's see how to configure the Pyannote Speaker Diarization deployment for the [Speaker Diarization-3.1 model](https://huggingface.co/pyannote/speaker-diarization-3.1). 44 | 45 | !!! example "Speaker diarization-3.1" 46 | 47 | ```python 48 | from aana.deployments.pyannote_speaker_diarization_deployment import PyannoteSpeakerDiarizationDeployment, PyannoteSpeakerDiarizationConfig 49 | 50 | PyannoteSpeakerDiarizationDeployment.options( 51 | num_replicas=1, 52 | max_ongoing_requests=1000, 53 | ray_actor_options={"num_gpus": 0.05}, 54 | user_config=PyannoteSpeakerDiarizationConfig( 55 | model_name=("pyannote/speaker-diarization-3.1"), 56 | sample_rate=16000, 57 | ).model_dump(mode="json"), 58 | ) 59 | ``` 60 | 61 | ## Diarized ASR 62 | 63 | Speaker Diarization output can be combined with ASR to generate transcription with speaker information. Further details and code snippet are available in [ASR model hub](./asr.md/#diarized-asr). 64 | 65 | -------------------------------------------------------------------------------- /docs/pages/openai_api.md: -------------------------------------------------------------------------------- 1 | # OpenAI-compatible API 2 | 3 | Aana SDK provides an OpenAI-compatible Chat Completions API that allows you to integrate Aana with any OpenAI-compatible application. 4 | 5 | Chat Completions API is available at the `/chat/completions` endpoint. 6 | 7 | !!! Tip 8 | The endpoint is enabled by default but can be disabled by setting the environment variable: `OPENAI_ENDPOINT_ENABLED=False`. 
9 | 10 | It is compatible with the OpenAI client libraries and can be used as a drop-in replacement for OpenAI API. 11 | 12 | ```python 13 | from openai import OpenAI 14 | 15 | client = OpenAI( 16 | api_key="token", # Any non empty string will work, we don't require an API key 17 | base_url="http://localhost:8000", 18 | ) 19 | 20 | messages = [ 21 | {"role": "user", "content": "What is the capital of France?"} 22 | ] 23 | 24 | completion = client.chat.completions.create( 25 | messages=messages, 26 | model="llm_deployment", 27 | ) 28 | 29 | print(completion.choices[0].message.content) 30 | ``` 31 | 32 | The API also supports streaming: 33 | 34 | ```python 35 | from openai import OpenAI 36 | 37 | client = OpenAI( 38 | api_key="token", # Any non empty string will work, we don't require an API key 39 | base_url="http://localhost:8000", 40 | ) 41 | 42 | messages = [ 43 | {"role": "user", "content": "What is the capital of France?"} 44 | ] 45 | 46 | stream = client.chat.completions.create( 47 | messages=messages, 48 | model="llm_deployment", 49 | stream=True, 50 | ) 51 | for chunk in stream: 52 | print(chunk.choices[0].delta.content or "", end="") 53 | ``` 54 | 55 | The API requires an LLM deployment. Aana SDK provides support for [vLLM](integrations.md#vllm) and [Hugging Face Transformers](integrations.md#hugging-face-transformers). 56 | 57 | The name of the model matches the name of the deployment. For example, if you registered a vLLM deployment with the name `llm_deployment`, you can use it with the OpenAI API as `model="llm_deployment"`. 
58 | 59 | ```python 60 | import os 61 | 62 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 63 | 64 | from aana.core.models.sampling import SamplingParams 65 | from aana.core.models.types import Dtype 66 | from aana.deployments.vllm_deployment import VLLMConfig, VLLMDeployment 67 | from aana.sdk import AanaSDK 68 | 69 | llm_deployment = VLLMDeployment.options( 70 | num_replicas=1, 71 | ray_actor_options={"num_gpus": 1}, 72 | user_config=VLLMConfig( 73 | model="TheBloke/Llama-2-7b-Chat-AWQ", 74 | dtype=Dtype.AUTO, 75 | quantization="awq", 76 | gpu_memory_reserved=13000, 77 | enforce_eager=True, 78 | default_sampling_params=SamplingParams( 79 | temperature=0.0, top_p=1.0, top_k=-1, max_tokens=1024 80 | ), 81 | chat_template="llama2", 82 | ).model_dump(mode="json"), 83 | ) 84 | 85 | aana_app = AanaSDK(name="llm_app") 86 | aana_app.register_deployment(name="llm_deployment", instance=llm_deployment) 87 | 88 | if __name__ == "__main__": 89 | aana_app.connect() 90 | aana_app.migrate() 91 | aana_app.deploy() 92 | ``` 93 | 94 | You can also use the example project `llama2` to deploy Llama-2-7b Chat model. 95 | 96 | ```bash 97 | CUDA_VISIBLE_DEVICES=0 aana deploy aana.projects.llama2.app:aana_app 98 | ``` 99 | -------------------------------------------------------------------------------- /aana/tests/units/test_app.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101, S113 2 | import json 3 | from typing import TypedDict 4 | 5 | import requests 6 | from ray import serve 7 | 8 | from aana.api.api_generation import Endpoint 9 | from aana.deployments.aana_deployment_handle import AanaDeploymentHandle 10 | from aana.deployments.base_deployment import BaseDeployment 11 | 12 | 13 | @serve.deployment 14 | class Lowercase(BaseDeployment): 15 | """Ray deployment that returns the lowercase version of a text.""" 16 | 17 | async def lower(self, text: str) -> dict: 18 | """Lowercase the text. 
19 | 20 | Args: 21 | text (str): The text to lowercase 22 | 23 | Returns: 24 | dict: The lowercase text 25 | """ 26 | return {"text": text.lower()} 27 | 28 | 29 | class LowercaseEndpointOutput(TypedDict): 30 | """The output of the lowercase endpoint.""" 31 | 32 | text: str 33 | 34 | 35 | class LowercaseEndpoint(Endpoint): 36 | """Lowercase endpoint.""" 37 | 38 | async def initialize(self): 39 | """Initialize the endpoint.""" 40 | self.lowercase_handle = await AanaDeploymentHandle.create( 41 | "lowercase_deployment" 42 | ) 43 | await super().initialize() 44 | 45 | async def run(self, text: str) -> LowercaseEndpointOutput: 46 | """Lowercase the text. 47 | 48 | Args: 49 | text (str): The list of text to lowercase 50 | 51 | Returns: 52 | LowercaseEndpointOutput: The lowercase texts 53 | """ 54 | lowercase_output = await self.lowercase_handle.lower(text=text) 55 | return {"text": lowercase_output["text"]} 56 | 57 | 58 | deployments = [ 59 | { 60 | "name": "lowercase_deployment", 61 | "instance": Lowercase, 62 | } 63 | ] 64 | 65 | endpoints = [ 66 | { 67 | "name": "lowercase", 68 | "path": "/lowercase", 69 | "summary": "Lowercase text", 70 | "endpoint_cls": LowercaseEndpoint, 71 | } 72 | ] 73 | 74 | 75 | def test_app(create_app): 76 | """Test the Ray Serve app.""" 77 | aana_app = create_app(deployments, endpoints) 78 | 79 | port = aana_app.port 80 | route_prefix = "" 81 | 82 | # Check that the server is ready 83 | response = requests.get(f"http://localhost:{port}{route_prefix}/api/ready") 84 | assert response.status_code == 200 85 | assert response.json() == {"ready": True} 86 | 87 | # Test lowercase endpoint 88 | data = {"text": "Hello World! This is a test."} 89 | response = requests.post( 90 | f"http://localhost:{port}{route_prefix}/lowercase", 91 | data={"body": json.dumps(data)}, 92 | ) 93 | assert response.status_code == 200 94 | lowercase_text = response.json().get("text") 95 | assert lowercase_text == "hello world! this is a test." 
96 | 97 | # Test that extra fields are not allowed 98 | data = {"text": "Hello World! This is a test.", "extra_field": "extra_value"} 99 | response = requests.post( 100 | f"http://localhost:{port}{route_prefix}/lowercase", 101 | data={"body": json.dumps(data)}, 102 | ) 103 | assert response.status_code == 422, response.text 104 | -------------------------------------------------------------------------------- /aana/tests/db/datastore/test_caption_repo.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: S101 2 | 3 | import random 4 | import uuid 5 | 6 | import pytest 7 | 8 | from aana.core.models.captions import Caption 9 | from aana.exceptions.db import NotFoundException 10 | from aana.storage.repository.caption import CaptionRepository 11 | 12 | 13 | @pytest.fixture(scope="function") 14 | def dummy_caption(): 15 | """Creates a dummy caption for testing.""" 16 | caption = Caption(f"This is a caption {uuid.uuid4()}") 17 | frame_id = random.randint(0, 100) # noqa: S311 18 | timestamp = random.random() # noqa: S311 19 | return caption, frame_id, timestamp 20 | 21 | 22 | @pytest.mark.asyncio 23 | async def test_save_caption(db_session_manager, dummy_caption): 24 | """Tests saving a caption.""" 25 | async with db_session_manager.session() as session: 26 | caption, frame_id, timestamp = dummy_caption 27 | model_name = "blip2" 28 | 29 | caption_repo = CaptionRepository(session) 30 | caption_entity = await caption_repo.save( 31 | model_name=model_name, 32 | caption=caption, 33 | frame_id=frame_id, 34 | timestamp=timestamp, 35 | ) 36 | caption_id = caption_entity.id 37 | 38 | caption_entity = await caption_repo.read(caption_id) 39 | assert caption_entity.model == model_name 40 | assert caption_entity.frame_id == frame_id 41 | assert caption_entity.timestamp == timestamp 42 | assert caption_entity.caption == caption 43 | 44 | await caption_repo.delete(caption_id) 45 | with pytest.raises(NotFoundException): 46 | await 
caption_repo.read(caption_id) 47 | 48 | 49 | @pytest.mark.asyncio 50 | async def test_save_all_captions(db_session_manager, dummy_caption): 51 | """Tests saving all captions.""" 52 | async with db_session_manager.session() as session: 53 | captions, frame_ids, timestamps = [], [], [] 54 | for _ in range(3): 55 | caption, frame_id, timestamp = dummy_caption 56 | captions.append(caption) 57 | frame_ids.append(frame_id) 58 | timestamps.append(timestamp) 59 | model_name = "blip2" 60 | 61 | caption_repo = CaptionRepository(session) 62 | caption_entities = await caption_repo.save_all( 63 | model_name=model_name, 64 | captions=captions, 65 | timestamps=timestamps, 66 | frame_ids=frame_ids, 67 | ) 68 | assert len(caption_entities) == len(captions) 69 | 70 | caption_ids = [caption_entity.id for caption_entity in caption_entities] 71 | for caption_id, caption, frame_id, timestamp in zip( 72 | caption_ids, captions, frame_ids, timestamps, strict=True 73 | ): 74 | caption_entity = await caption_repo.read(caption_id) 75 | 76 | assert caption_entity.model == model_name 77 | assert caption_entity.frame_id == frame_id 78 | assert caption_entity.timestamp == timestamp 79 | assert caption_entity.caption == caption 80 | 81 | # delete all captions 82 | for caption_id in caption_ids: 83 | await caption_repo.delete(caption_id) 84 | with pytest.raises(NotFoundException): 85 | await caption_repo.read(caption_id) 86 | -------------------------------------------------------------------------------- /aana/exceptions/api_service.py: -------------------------------------------------------------------------------- 1 | from aana.exceptions.core import BaseException 2 | 3 | 4 | class ApiKeyNotProvided(BaseException): 5 | """Exception raised when the API key is not provided.""" 6 | 7 | def __init__(self): 8 | """Initialize the exception.""" 9 | self.message = "API key not provided" 10 | super().__init__(message=self.message) 11 | 12 | def __reduce__(self): 13 | """Used for pickling.""" 14 | return 
(self.__class__, ()) 15 | 16 | 17 | class ApiKeyNotFound(BaseException): 18 | """Exception raised when the API key is not found. 19 | 20 | Attributes: 21 | key (str): the API key that was not found 22 | """ 23 | 24 | def __init__(self, key: str): 25 | """Initialize the exception. 26 | 27 | Args: 28 | key (str): the API key that was not found 29 | """ 30 | self.key = key 31 | self.message = f"API key {key} not found" 32 | super().__init__(key=key, message=self.message) 33 | 34 | def __reduce__(self): 35 | """Used for pickling.""" 36 | return (self.__class__, (self.key,)) 37 | 38 | 39 | class InactiveSubscription(BaseException): 40 | """Exception raised when the subscription is inactive (e.g. credits are not available). 41 | 42 | Attributes: 43 | key (str): the API key with inactive subscription 44 | """ 45 | 46 | def __init__(self, key: str): 47 | """Initialize the exception. 48 | 49 | Args: 50 | key (str): the API key with inactive subscription 51 | """ 52 | self.key = key 53 | self.message = ( 54 | f"API key {key} has an inactive subscription. Check your credits." 
55 | ) 56 | super().__init__(key=key, message=self.message) 57 | 58 | def __reduce__(self): 59 | """Used for pickling.""" 60 | return (self.__class__, (self.key,)) 61 | 62 | 63 | class AdminOnlyAccess(BaseException): 64 | """Exception raised when the user does not have enough permissions.""" 65 | 66 | def __init__(self): 67 | """Initialize the exception.""" 68 | self.message = "Admin only access" 69 | super().__init__(message=self.message) 70 | 71 | def __reduce__(self): 72 | """Used for pickling.""" 73 | return (self.__class__, ()) 74 | 75 | 76 | class ApiKeyValidationFailed(BaseException): 77 | """Exception raised when the API key validation fails.""" 78 | 79 | def __init__(self): 80 | """Initialize the exception.""" 81 | self.message = "API key validation failed" 82 | super().__init__(message=self.message) 83 | 84 | def __reduce__(self): 85 | """Used for pickling.""" 86 | return (self.__class__, ()) 87 | 88 | 89 | class ApiKeyExpired(BaseException): 90 | """Exception raised when the API key is expired.""" 91 | 92 | def __init__(self, key: str): 93 | """Initialize the exception. 94 | 95 | Args: 96 | key (str): the expired API key 97 | """ 98 | self.key = key 99 | self.message = f"API key {key} is expired." 
100 | super().__init__(key=key, message=self.message) 101 | 102 | def __reduce__(self): 103 | """Used for pickling.""" 104 | return (self.__class__, (self.key,)) 105 | -------------------------------------------------------------------------------- /aana/storage/models/transcript.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations # Let classes use themselves in type annotations 2 | 3 | from typing import TYPE_CHECKING 4 | 5 | from sqlalchemy import JSON, CheckConstraint 6 | from sqlalchemy.orm import Mapped, mapped_column 7 | 8 | from aana.storage.models.base import BaseEntity, TimeStampEntity 9 | 10 | if TYPE_CHECKING: 11 | from aana.core.models.asr import ( 12 | AsrSegments, 13 | AsrTranscription, 14 | AsrTranscriptionInfo, 15 | ) 16 | 17 | 18 | class TranscriptEntity(BaseEntity, TimeStampEntity): 19 | """ORM class for media transcripts generated by a model. 20 | 21 | Attributes: 22 | id (int): Unique identifier for the transcript. 23 | model (str): Name of the model used to generate the transcript. 24 | transcript (str): Full text transcript of the media. 25 | segments (dict): Segments of the transcript. 26 | language (str): Language of the transcript as predicted by the model. 27 | language_confidence (float): Confidence score of language prediction. 28 | transcript_type (str): The type of transcript (populated automatically by ORM based on `polymorphic_identity` of subclass). 
29 | """ 30 | 31 | __tablename__ = "transcript" 32 | 33 | id: Mapped[int] = mapped_column(autoincrement=True, primary_key=True) 34 | model: Mapped[str] = mapped_column( 35 | nullable=False, comment="Name of model used to generate transcript" 36 | ) 37 | transcript: Mapped[str] = mapped_column(comment="Full text transcript of media") 38 | segments: Mapped[dict] = mapped_column(JSON, comment="Segments of the transcript") 39 | language: Mapped[str] = mapped_column( 40 | comment="Language of the transcript as predicted by model" 41 | ) 42 | language_confidence: Mapped[float] = mapped_column( 43 | CheckConstraint( 44 | "0 <= language_confidence <= 1", name="language_confidence_value_range" 45 | ), 46 | comment="Confidence score of language prediction", 47 | ) 48 | transcript_type: Mapped[str] = mapped_column(comment="The type of transcript") 49 | 50 | __mapper_args__ = { # noqa: RUF012 51 | "polymorphic_identity": "transcript", 52 | "polymorphic_on": "transcript_type", 53 | } 54 | 55 | @classmethod 56 | def from_asr_output( 57 | cls, 58 | model_name: str, 59 | info: AsrTranscriptionInfo, 60 | transcription: AsrTranscription, 61 | segments: AsrSegments, 62 | ) -> TranscriptEntity: 63 | """Converts an AsrTranscriptionInfo and AsrTranscription to a single Transcript entity. 64 | 65 | Args: 66 | model_name (str): Name of the model used to generate the transcript. 67 | info (AsrTranscriptionInfo): Information about the transcription. 68 | transcription (AsrTranscription): The full transcription. 69 | segments (AsrSegments): Segments of the transcription. 70 | 71 | Returns: 72 | TranscriptEntity: A new instance of the TranscriptEntity class. 73 | """ 74 | return TranscriptEntity( 75 | model=model_name, 76 | language=info.language, 77 | language_confidence=info.language_confidence, 78 | transcript=transcription.text, 79 | segments=[s.model_dump() for s in segments], 80 | ) 81 | --------------------------------------------------------------------------------