├── .github └── workflows │ ├── lint.yml │ ├── local.yaml │ ├── main.yml │ └── unit.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGES ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── README.md ├── astrapy ├── __init__.py ├── admin │ ├── __init__.py │ ├── admin.py │ └── endpoints.py ├── api_options.py ├── authentication.py ├── client.py ├── collection.py ├── constants.py ├── cursors.py ├── data │ ├── __init__.py │ ├── collection.py │ ├── cursors │ │ ├── __init__.py │ │ ├── cursor.py │ │ ├── farr_cursor.py │ │ ├── find_cursor.py │ │ ├── query_engine.py │ │ └── reranked_result.py │ ├── database.py │ ├── info │ │ ├── collection_descriptor.py │ │ ├── database_info.py │ │ ├── reranking.py │ │ ├── table_descriptor │ │ │ ├── table_altering.py │ │ │ ├── table_columns.py │ │ │ ├── table_creation.py │ │ │ ├── table_indexes.py │ │ │ └── table_listing.py │ │ └── vectorize.py │ ├── table.py │ └── utils │ │ ├── __init__.py │ │ ├── collection_converters.py │ │ ├── distinct_extractors.py │ │ ├── extended_json_converters.py │ │ ├── table_converters.py │ │ ├── table_types.py │ │ └── vector_coercion.py ├── data_types │ ├── __init__.py │ ├── data_api_date.py │ ├── data_api_duration.py │ ├── data_api_map.py │ ├── data_api_set.py │ ├── data_api_time.py │ ├── data_api_timestamp.py │ └── data_api_vector.py ├── database.py ├── exceptions │ ├── __init__.py │ ├── collection_exceptions.py │ ├── data_api_exceptions.py │ ├── devops_api_exceptions.py │ └── table_exceptions.py ├── ids.py ├── info.py ├── py.typed ├── results.py ├── settings │ ├── __init__.py │ ├── defaults.py │ └── error_messages.py ├── table.py └── utils │ ├── __init__.py │ ├── api_commander.py │ ├── api_options.py │ ├── date_utils.py │ ├── document_paths.py │ ├── duration_c_utils.py │ ├── duration_std_utils.py │ ├── meta.py │ ├── parsing.py │ ├── request_tools.py │ ├── str_enum.py │ ├── unset.py │ └── user_agents.py ├── pictures ├── astrapy_abstractions.png ├── 
astrapy_datetime_serdes_options.png └── astrapy_exceptions.png ├── pyproject.toml ├── tests ├── __init__.py ├── admin │ ├── __init__.py │ ├── conftest.py │ └── integration │ │ ├── __init__.py │ │ ├── test_admin.py │ │ └── test_nonastra_admin.py ├── base │ ├── __init__.py │ ├── collection_decimal_support_assets.py │ ├── conftest.py │ ├── integration │ │ ├── __init__.py │ │ ├── collection_decimal_support_assets.py │ │ ├── collections │ │ │ ├── __init__.py │ │ │ ├── test_collection_cursor_async.py │ │ │ ├── test_collection_cursor_sync.py │ │ │ ├── test_collection_ddl_async.py │ │ │ ├── test_collection_ddl_sync.py │ │ │ ├── test_collection_decimal_support.py │ │ │ ├── test_collection_dml_async.py │ │ │ ├── test_collection_dml_sync.py │ │ │ ├── test_collection_exceptions_async.py │ │ │ ├── test_collection_exceptions_sync.py │ │ │ ├── test_collection_farr_async.py │ │ │ ├── test_collection_farr_sync.py │ │ │ ├── test_collection_farrcursor_async.py │ │ │ ├── test_collection_farrcursor_sync.py │ │ │ ├── test_collection_timeout_async.py │ │ │ ├── test_collection_timeout_sync.py │ │ │ ├── test_collection_typing.py │ │ │ ├── test_collection_vectorize_methods_async.py │ │ │ └── test_collection_vectorize_methods_sync.py │ │ ├── conftest.py │ │ ├── misc │ │ │ ├── __init__.py │ │ │ ├── test_admin_ops_async.py │ │ │ ├── test_admin_ops_sync.py │ │ │ ├── test_reranking_ops_async.py │ │ │ ├── test_reranking_ops_sync.py │ │ │ ├── test_vectorize_ops_async.py │ │ │ └── test_vectorize_ops_sync.py │ │ └── tables │ │ │ ├── __init__.py │ │ │ ├── table_cql_assets.py │ │ │ ├── table_row_assets.py │ │ │ ├── test_table_column_types_async.py │ │ │ ├── test_table_column_types_sync.py │ │ │ ├── test_table_cqldriven_dml_sync.py │ │ │ ├── test_table_cursor_async.py │ │ │ ├── test_table_cursor_sync.py │ │ │ ├── test_table_dml_async.py │ │ │ ├── test_table_dml_sync.py │ │ │ ├── test_table_lifecycle_async.py │ │ │ ├── test_table_lifecycle_sync.py │ │ │ ├── test_table_mapsastuples_async.py │ │ │ ├── 
test_table_mapsastuples_sync.py │ │ │ ├── test_table_typing.py │ │ │ ├── test_table_vectorize_async.py │ │ │ └── test_table_vectorize_sync.py │ ├── table_decimal_support_assets.py │ ├── table_structure_assets.py │ └── unit │ │ ├── __init__.py │ │ ├── test_admin_conversions.py │ │ ├── test_apicommander.py │ │ ├── test_apioptions.py │ │ ├── test_collection_decimal_support.py │ │ ├── test_collection_options.py │ │ ├── test_collection_timeouts.py │ │ ├── test_collections_async.py │ │ ├── test_collections_sync.py │ │ ├── test_collectionvectorserviceoptions.py │ │ ├── test_dataapidate.py │ │ ├── test_dataapiduration.py │ │ ├── test_dataapimap.py │ │ ├── test_dataapiset.py │ │ ├── test_dataapitime.py │ │ ├── test_dataapitimestamp.py │ │ ├── test_dataapivector.py │ │ ├── test_databases_async.py │ │ ├── test_databases_sync.py │ │ ├── test_datetime_serdes_options.py │ │ ├── test_document_extractors.py │ │ ├── test_document_paths.py │ │ ├── test_embeddingheadersprovider.py │ │ ├── test_exceptions.py │ │ ├── test_findandrerank_collectiondefinition.py │ │ ├── test_findembeddingproviderresult.py │ │ ├── test_findrerankingproviderresult.py │ │ ├── test_ids.py │ │ ├── test_imports.py │ │ ├── test_info.py │ │ ├── test_multicalltimeoutmanager.py │ │ ├── test_regionname_deprecation.py │ │ ├── test_rerankingheadersprovider.py │ │ ├── test_strenum.py │ │ ├── test_table_decimal_support.py │ │ ├── test_table_dry_methods.py │ │ ├── test_tableconverteragent.py │ │ ├── test_tabledescriptors.py │ │ ├── test_tableindexdescriptor_parsing.py │ │ ├── test_timeouts.py │ │ ├── test_token_providers.py │ │ ├── test_tpostprocessors.py │ │ └── test_tpreprocessors_mapastuples.py ├── conftest.py ├── dse_compose │ ├── README │ └── docker-compose.yml ├── env_templates │ ├── env.astra.admin.template │ ├── env.astra.template │ ├── env.local.template │ ├── env.testcontainers.template │ ├── env.vectorize-minimal.template │ └── env.vectorize.template ├── hcd_compose │ ├── cassandra-hcd.yaml │ └── 
docker-compose.yml ├── preprocess_env.py └── vectorize │ ├── __init__.py │ ├── conftest.py │ ├── integration │ ├── __init__.py │ └── test_vectorize_providers.py │ ├── live_provider_info.py │ ├── query_providers.py │ └── vectorize_models.py └── uv.lock /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: ruff and mypy checks 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | mypy: 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.11' # Or any version you prefer 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pipx install uv 26 | make venv 27 | 28 | - name: Ruff Linting AstraPy 29 | run: | 30 | uv run ruff check astrapy 31 | 32 | - name: Ruff Linting Tests 33 | run: | 34 | uv run ruff check tests 35 | 36 | - name: Ruff formatting astrapy 37 | run: | 38 | uv run ruff format --check astrapy 39 | 40 | - name: Ruff formatting tests 41 | run: | 42 | uv run ruff format --check tests 43 | 44 | - name: Run MyPy AstraPy 45 | run: | 46 | uv run mypy astrapy 47 | 48 | - name: Run MyPy Tests 49 | run: | 50 | uv run mypy tests 51 | -------------------------------------------------------------------------------- /.github/workflows/local.yaml: -------------------------------------------------------------------------------- 1 | name: Run base integration tests on a local Data API 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | env: 14 | HEADER_EMBEDDING_API_KEY_OPENAI: ${{ secrets.HEADER_EMBEDDING_API_KEY_OPENAI }} 15 | # hardcoding the target DB 16 | DOCKER_COMPOSE_LOCAL_DATA_API: "yes" 17 | # turn on header-based reranker auth 18 | ASTRAPY_FINDANDRERANK_USE_RERANKER_HEADER: "yes" 19 | 
HEADER_RERANKING_API_KEY_NVIDIA: ${{ secrets.HEADER_RERANKING_API_KEY_NVIDIA }} 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - name: Checkout code 24 | uses: actions/checkout@v2 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: 3.11 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip install --upgrade pip 34 | pipx install uv 35 | make venv 36 | 37 | # Prepare to login to ECR: 38 | - name: Configure AWS credentials 39 | uses: aws-actions/configure-aws-credentials@v4 40 | with: 41 | aws-access-key-id: ${{ secrets.HCD_ECR_ACCESS_KEY }} 42 | aws-secret-access-key: ${{ secrets.HCD_ECR_SECRET_KEY }} 43 | aws-region: us-west-2 44 | 45 | # Login to ECR so we can pull HCD image: 46 | - name: Login to Amazon ECR 47 | id: login-ecr 48 | uses: aws-actions/amazon-ecr-login@v2 49 | with: 50 | mask-password: 'true' 51 | 52 | - name: Run pytest 53 | run: | 54 | uv run pytest tests/base/integration 55 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | name: Run base integration tests on Astra DB 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | env: 14 | # basic secrets 15 | ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} 16 | ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }} 17 | ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }} 18 | ASTRA_DB_SECONDARY_KEYSPACE: ${{ secrets.ASTRA_DB_SECONDARY_KEYSPACE }} 19 | HEADER_EMBEDDING_API_KEY_OPENAI: ${{ secrets.HEADER_EMBEDDING_API_KEY_OPENAI }} 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - name: Checkout code 24 | uses: actions/checkout@v2 25 | 26 | - name: Set up Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: 3.11 30 | 31 | - name: Install dependencies 32 | run: | 33 | python -m pip 
install --upgrade pip 34 | pipx install uv 35 | make venv 36 | 37 | - name: Run pytest 38 | run: | 39 | uv run pytest tests/base/integration 40 | -------------------------------------------------------------------------------- /.github/workflows/unit.yaml: -------------------------------------------------------------------------------- 1 | name: Run unit tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | test: 13 | env: 14 | # basic secrets 15 | ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }} 16 | ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }} 17 | ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }} 18 | ASTRA_DB_SECONDARY_KEYSPACE: ${{ secrets.ASTRA_DB_SECONDARY_KEYSPACE }} 19 | runs-on: ubuntu-latest 20 | strategy: 21 | matrix: 22 | python-version: 23 | - "3.8" 24 | - "3.9" 25 | - "3.10" 26 | - "3.11" 27 | - "3.12" 28 | name: "unit test on #${{ matrix.python-version }}" 29 | steps: 30 | - name: Checkout code 31 | uses: actions/checkout@v2 32 | 33 | - name: Set up Python 34 | uses: actions/setup-python@v2 35 | with: 36 | python-version: ${{ matrix.python-version }} 37 | 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | pipx install uv 42 | make venv 43 | 44 | - name: Run pytest 45 | run: | 46 | uv run pytest tests/base/unit 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .logs 2 | .idea 3 | .DS_Store 4 | .project 5 | .pydevproject 6 | .settings 7 | .env 8 | node_modules 9 | venv 10 | .venv* 11 | dump.rdb 12 | .tmp 13 | npm-debug.log 14 | .ssh 15 | .csscomb.json 16 | tests/screenshots 17 | .bash_history 18 | .ipynb_checkpoints 19 | .scratch.py 20 | .pytest_cache 21 | .mypy_cache 22 | .vscode 23 | manifest 24 | dist 25 | 26 | # generated by findEmbeddingProviders dump: 27 | _providers.json 28 
| 29 | # premade .gitignore from GitHub 30 | 31 | # Byte-compiled / optimized / DLL files 32 | __pycache__/ 33 | *.py[cod] 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | env/ 41 | build/ 42 | !frontend/vendor/**/build 43 | develop-eggs/ 44 | downloads/ 45 | !pd_pfe/static/downloads 46 | eggs/ 47 | .eggs/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *,cover 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Browser test output 89 | junitresults.* 90 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: local 3 | hooks: 4 | - id: check 5 | name: check 6 | language: system 7 | entry: make format-fix 8 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 
| level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. 
Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an appointed 52 | representative at an online or offline event. Representation of a project may be 53 | further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team. All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | When contributing to this repository, please first discuss the change you wish to make via issue, 4 | email, or any other method with the owners of this repository before making a change. 5 | 6 | Please note that we have a [Code of Conduct](CODE_OF_CONDUCT.md); please follow it in all your interactions with the project. 7 | 8 | ## Found an Issue? 
9 | If you find a bug in the source code or a mistake in the documentation, you can help us by 10 | [submitting an issue](#submit-issue) to the GitHub Repository. Even better, you can 11 | [submit a Pull Request](#submit-pr) with a fix. 12 | 13 | ## Want a Feature? 14 | You can *request* a new feature by [submitting an issue](#submit-issue) to the GitHub 15 | Repository. If you would like to *implement* a new feature, please submit an issue with 16 | a proposal for your work first, to be sure that we can use it. 17 | 18 | * **Small Features** can be crafted and directly [submitted as a Pull Request](#submit-pr). 19 | 20 | ## Contribution Guidelines 21 | 22 | ### Contribution License Agreement 23 | To protect the community, all contributors are required to [sign the DataStax Contribution License Agreement](https://cla.datastax.com/). The process is completely electronic and should only take a few minutes. 24 | 25 | ### Submitting an Issue 26 | Before you submit an issue, search the archive, maybe your question was already answered. 27 | 28 | If your issue appears to be a bug, and hasn't been reported, open a new issue. 29 | Help us to maximize the effort we can spend fixing issues and adding new 30 | features, by not reporting duplicate issues. 
Providing the following information will increase the 31 | chances of your issue being dealt with quickly: 32 | 33 | * **Overview of the Issue** - if an error is being thrown, a non-minified stack trace helps 34 | * **Motivation for or Use Case** - explain what you are trying to do and why the current behavior is a bug for you 35 | * **Reproduce the Error** - provide a live example or an unambiguous set of steps 36 | * **Suggest a Fix** - if you can't fix the bug yourself, perhaps you can point to what might be 37 | causing the problem (line of code or commit) 38 | 39 | ### Submitting a Pull Request (PR) 40 | Before you submit your Pull Request (PR), consider the following guidelines: 41 | 42 | * Search the repository (https://github.com/bechbd/[repository-name]/pulls) for an open or closed PR that relates to your submission. You don't want to duplicate effort. 43 | 44 | * Create a fork of the repo 45 | * Navigate to the repo you want to fork 46 | * In the top right corner of the page click **Fork**: 47 | ![](https://help.github.com/assets/images/help/repository/fork_button.jpg) 48 | 49 | * Make your changes in the forked repo 50 | * Commit your changes using a descriptive commit message 51 | * In GitHub, create a pull request: https://help.github.com/en/articles/creating-a-pull-request-from-a-fork 52 | * If we suggest changes, then: 53 | * Make the required updates. 54 | * Rebase your fork and force push to your GitHub repository (this will update your Pull Request): 55 | 56 | ```shell 57 | git rebase main -i 58 | git push -f 59 | ``` 60 | 61 | That's it! Thank you for your contribution! 
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL := /bin/bash 2 | 3 | .PHONY: all venv format format-fix format-tests format-src test-integration test build help 4 | 5 | all: help 6 | 7 | FMT_FLAGS ?= --check 8 | VENV ?= false 9 | 10 | ifeq ($(VENV), true) 11 | VENV_FLAGS := --active 12 | else 13 | VENV_FLAGS := 14 | endif 15 | 16 | venv: 17 | uv venv 18 | uv sync --dev 19 | 20 | format: format-src format-tests 21 | 22 | format-tests: 23 | uv run $(VENV_FLAGS) ruff check tests 24 | uv run $(VENV_FLAGS) ruff format tests $(FMT_FLAGS) 25 | uv run $(VENV_FLAGS) mypy tests 26 | 27 | format-src: 28 | uv run $(VENV_FLAGS) ruff check astrapy 29 | uv run $(VENV_FLAGS) ruff format astrapy $(FMT_FLAGS) 30 | uv run $(VENV_FLAGS) mypy astrapy 31 | 32 | format-fix: format-fix-src format-fix-tests 33 | 34 | format-fix-src: FMT_FLAGS= 35 | format-fix-src: format-src 36 | 37 | format-fix-tests: FMT_FLAGS= 38 | format-fix-tests: format-tests 39 | 40 | test-integration: 41 | uv run $(VENV_FLAGS) pytest tests/base -vv 42 | 43 | test: 44 | uv run $(VENV_FLAGS) pytest tests/base/unit -vv 45 | 46 | docker-test-integration: 47 | DOCKER_COMPOSE_LOCAL_DATA_API="yes" uv run pytest tests/base -vv 48 | 49 | build: 50 | rm -f dist/astrapy* 51 | uv build 52 | 53 | help: 54 | @echo "======================================================================" 55 | @echo "AstraPy make command purpose" 56 | @echo "----------------------------------------------------------------------" 57 | @echo "venv create a virtual env (needs uv)" 58 | @echo "format full lint and format checks" 59 | @echo " format-src limited to source" 60 | @echo " format-tests limited to tests" 61 | @echo " format-fix fixing imports and style" 62 | @echo " format-fix-src limited to source" 63 | @echo " format-fix-tests limited to tests" 64 | @echo "test run unit tests" 65 | @echo "test-integration 
run integration tests" 66 | @echo "docker-test-integration run int.tests on dockerized local" 67 | @echo "build build package ready for PyPI" 68 | @echo "======================================================================" 69 | -------------------------------------------------------------------------------- /astrapy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def get_version() -> str:
    """
    Return the version string of this package.

    The version stored in the metadata of the installed distribution is
    preferred. If the package is not installed (for instance, when imported
    straight from a source checkout), the `pyproject.toml` file expected one
    directory above this module is parsed instead.

    Returns:
        the version string, or the literal "unknown" if it cannot be determined.
    """
    try:
        # Ask the installed distribution's metadata for the version
        return importlib.metadata.version(__package__)
    except importlib.metadata.PackageNotFoundError:
        # Not an installed package: fall back to reading pyproject.toml below
        pass

    dir_path = os.path.dirname(os.path.realpath(__file__))
    pyproject_path = os.path.join(dir_path, "..", "pyproject.toml")

    try:
        with open(pyproject_path, encoding="utf-8") as pyproject:
            pyproject_data = toml.loads(pyproject.read())
    except (OSError, toml.TomlDecodeError):
        # A missing, unreadable or malformed file must not break the import
        return "unknown"

    # The standard location is the top-level [project] table; the legacy
    # [tool.project] lookup is retained as a backward-compatible fallback.
    for table_path in (("project",), ("tool", "project")):
        node = pyproject_data
        try:
            for key in table_path:
                node = node[key]
            return str(node["version"])
        except (KeyError, TypeError):
            continue

    return "unknown"
74 | "DataAPIDatabaseAdmin", 75 | "Table", 76 | "__version__", 77 | ] 78 | 79 | 80 | __pdoc__ = { 81 | "ids": False, 82 | "settings": False, 83 | } 84 | -------------------------------------------------------------------------------- /astrapy/admin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from astrapy.admin.admin import ( 18 | AstraDBAdmin, 19 | AstraDBDatabaseAdmin, 20 | DataAPIDatabaseAdmin, 21 | DatabaseAdmin, 22 | ParsedAPIEndpoint, 23 | async_fetch_database_info, 24 | fetch_database_info, 25 | parse_api_endpoint, 26 | ) 27 | 28 | __all__ = [ 29 | "AstraDBAdmin", 30 | "AstraDBDatabaseAdmin", 31 | "DataAPIDatabaseAdmin", 32 | "DatabaseAdmin", 33 | "ParsedAPIEndpoint", 34 | "async_fetch_database_info", 35 | "fetch_database_info", 36 | "parse_api_endpoint", 37 | ] 38 | -------------------------------------------------------------------------------- /astrapy/api_options.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from astrapy.utils.api_options import ( 18 | APIOptions, 19 | DataAPIURLOptions, 20 | DevOpsAPIURLOptions, 21 | SerdesOptions, 22 | TimeoutOptions, 23 | ) 24 | 25 | __all__ = [ 26 | "APIOptions", 27 | "DataAPIURLOptions", 28 | "DevOpsAPIURLOptions", 29 | "SerdesOptions", 30 | "TimeoutOptions", 31 | ] 32 | -------------------------------------------------------------------------------- /astrapy/collection.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from astrapy.data.collection import ( 16 | AsyncCollection, 17 | Collection, 18 | ) 19 | 20 | __all__ = [ 21 | "AsyncCollection", 22 | "Collection", 23 | ] 24 | -------------------------------------------------------------------------------- /astrapy/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union 18 | 19 | from astrapy.data_types import DataAPIVector 20 | from astrapy.settings.defaults import ( 21 | DATA_API_ENVIRONMENT_CASSANDRA, 22 | DATA_API_ENVIRONMENT_DEV, 23 | DATA_API_ENVIRONMENT_DSE, 24 | DATA_API_ENVIRONMENT_HCD, 25 | DATA_API_ENVIRONMENT_OTHER, 26 | DATA_API_ENVIRONMENT_PROD, 27 | DATA_API_ENVIRONMENT_TEST, 28 | ) 29 | from astrapy.utils.str_enum import StrEnum 30 | 31 | DefaultDocumentType = Dict[str, Any] 32 | DefaultRowType = Dict[str, Any] 33 | ProjectionType = Union[ 34 | Iterable[str], Dict[str, Union[bool, Dict[str, Union[int, Iterable[int]]]]] 35 | ] 36 | SortType = Dict[str, Any] 37 | HybridSortType = Dict[ 38 | str, Union[str, Dict[str, Union[str, List[float], DataAPIVector]]] 39 | ] 40 | FilterType = Dict[str, Any] 41 | CallerType = Tuple[Optional[str], Optional[str]] 42 | 43 | 44 | ROW = TypeVar("ROW") 45 | ROW2 = TypeVar("ROW2") 46 | DOC = 
TypeVar("DOC") 47 | DOC2 = TypeVar("DOC2") 48 | 49 | 50 | def normalize_optional_projection( 51 | projection: ProjectionType | None, 52 | ) -> dict[str, bool | dict[str, int | Iterable[int]]] | None: 53 | if projection: 54 | if isinstance(projection, dict): 55 | # already a dictionary 56 | return projection 57 | else: 58 | # an iterable over strings: coerce to allow-list projection 59 | return {field: True for field in projection} 60 | else: 61 | return None 62 | 63 | 64 | class ReturnDocument: 65 | """ 66 | Admitted values for the `return_document` parameter in 67 | `find_one_and_replace` and `find_one_and_update` collection 68 | methods. 69 | """ 70 | 71 | def __init__(self) -> None: 72 | raise NotImplementedError 73 | 74 | BEFORE = "before" 75 | AFTER = "after" 76 | 77 | 78 | class SortMode: 79 | """ 80 | Admitted values for the `sort` parameter in the find collection methods, 81 | e.g. `sort={"field": SortMode.ASCENDING}`. 82 | """ 83 | 84 | def __init__(self) -> None: 85 | raise NotImplementedError 86 | 87 | ASCENDING = 1 88 | DESCENDING = -1 89 | 90 | 91 | class VectorMetric: 92 | """ 93 | Admitted values for the "metric" parameter to use in CollectionVectorOptions 94 | object, needed when creating vector collections through the database 95 | `create_collection` method. 96 | """ 97 | 98 | def __init__(self) -> None: 99 | raise NotImplementedError 100 | 101 | DOT_PRODUCT = "dot_product" 102 | EUCLIDEAN = "euclidean" 103 | COSINE = "cosine" 104 | 105 | 106 | class DefaultIdType: 107 | """ 108 | Admitted values for the "default_id_type" parameter to use in 109 | CollectionDefaultIDOptions object, needed when creating collections 110 | through the database `create_collection` method. 
class DefaultIdType:
    """
    Admitted values for the "default_id_type" parameter to use in
    CollectionDefaultIDOptions object, needed when creating collections
    through the database `create_collection` method.

    This class is a plain namespace of constants: it cannot be instantiated.
    """

    UUID = "uuid"
    OBJECTID = "objectId"
    UUIDV6 = "uuidv6"
    UUIDV7 = "uuidv7"
    # the default coincides with the plain "uuid" choice
    DEFAULT = "uuid"

    def __init__(self) -> None:
        raise NotImplementedError


class Environment:
    """
    Admitted values for `environment` property,
    denoting the targeted API deployment type.

    Besides the individual constants, two sets are provided: `values`
    (every admitted environment) and `astra_db_values` (PROD, DEV and TEST).
    This class is a plain namespace of constants: it cannot be instantiated.
    """

    PROD = DATA_API_ENVIRONMENT_PROD
    DEV = DATA_API_ENVIRONMENT_DEV
    TEST = DATA_API_ENVIRONMENT_TEST
    DSE = DATA_API_ENVIRONMENT_DSE
    HCD = DATA_API_ENVIRONMENT_HCD
    CASSANDRA = DATA_API_ENVIRONMENT_CASSANDRA
    OTHER = DATA_API_ENVIRONMENT_OTHER

    values = {PROD, DEV, TEST, DSE, HCD, CASSANDRA, OTHER}
    astra_db_values = {PROD, DEV, TEST}

    def __init__(self) -> None:
        raise NotImplementedError


class MapEncodingMode(StrEnum):
    """
    Enum for the possible values of the setting controlling whether to encode
    dicts/DataAPIMaps as lists of pairs ("association lists") in table payloads.
    """

    NEVER = "NEVER"
    DATAAPIMAPS = "DATAAPIMAPS"
    ALWAYS = "ALWAYS"


__all__ = [
    "DefaultIdType",
    "Environment",
    "MapEncodingMode",
    "ReturnDocument",
    "SortMode",
    "VectorMetric",
]
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from astrapy.data.cursors.cursor import ( 16 | AbstractCursor, 17 | CursorState, 18 | ) 19 | from astrapy.data.cursors.farr_cursor import ( 20 | AsyncCollectionFindAndRerankCursor, 21 | CollectionFindAndRerankCursor, 22 | ) 23 | from astrapy.data.cursors.find_cursor import ( 24 | AsyncCollectionFindCursor, 25 | AsyncTableFindCursor, 26 | CollectionFindCursor, 27 | TableFindCursor, 28 | ) 29 | from astrapy.data.cursors.reranked_result import RerankedResult 30 | 31 | __all__ = [ 32 | "AsyncCollectionFindAndRerankCursor", 33 | "AsyncCollectionFindCursor", 34 | "AsyncTableFindCursor", 35 | "CollectionFindAndRerankCursor", 36 | "CollectionFindCursor", 37 | "AbstractCursor", 38 | "CursorState", 39 | "RerankedResult", 40 | "TableFindCursor", 41 | ] 42 | -------------------------------------------------------------------------------- /astrapy/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /astrapy/data/cursors/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /astrapy/data/cursors/reranked_result.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@dataclass
class RerankedResult(Generic[TRAW]):
    """
    A single result coming from a `find_and_rerank` command, i.e. an item
    from the DB together with its scores.

    Attributes:
        document: the item, in the form returned by the `find_and_rerank`
            API command.
        scores: a dictionary of score labels to score values (each a float,
            an int or None), such as
            `{"$rerank": 0.87, "$vector" : 0.65, "$lexical" : 0.91}`.
    """

    document: TRAW
    scores: dict[str, float | int | None]
def convert_to_ejson_date_object(
    date_value: datetime.date | datetime.datetime,
) -> dict[str, int]:
    """
    Encode a date or a datetime as a `{"$date": <milliseconds>}` dictionary.

    A `datetime.datetime` contributes its own `timestamp()`; a plain
    `datetime.date` goes through `time.mktime` applied to its time tuple.
    In both cases the milliseconds are truncated to an integer.
    """
    if isinstance(date_value, datetime.datetime):
        epoch_seconds = date_value.timestamp()
    else:
        epoch_seconds = time.mktime(date_value.timetuple())
    return {"$date": int(epoch_seconds * 1000)}


def convert_to_ejson_apitimestamp_object(
    date_value: DataAPITimestamp,
) -> dict[str, int]:
    """Encode a DataAPITimestamp as `{"$date": <its timestamp_ms>}`."""
    millis = date_value.timestamp_ms
    return {"$date": millis}


def convert_to_ejson_bytes(bytes_value: bytes) -> dict[str, str]:
    """Encode a bytes value as `{"$binary": <base64 text>}`."""
    encoded = base64.b64encode(bytes_value)
    return {"$binary": encoded.decode()}


def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]:
    """Encode a UUID as `{"$uuid": <string form>}`."""
    uuid_text = str(uuid_value)
    return {"$uuid": uuid_text}


def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]:
    """Encode an ObjectId as `{"$objectId": <string form>}`."""
    objectid_text = str(objectid_value)
    return {"$objectId": objectid_text}


def convert_ejson_date_object_to_datetime(
    date_object: dict[str, int], tz: datetime.timezone | None
) -> datetime.datetime:
    """
    Decode a `{"$date": <milliseconds>}` dictionary into a datetime.

    The `tz` argument is handed over to `datetime.datetime.fromtimestamp`.
    """
    epoch_seconds = date_object["$date"] / 1000.0
    return datetime.datetime.fromtimestamp(epoch_seconds, tz=tz)


def convert_ejson_date_object_to_apitimestamp(
    date_object: dict[str, int],
) -> DataAPITimestamp:
    """Decode a `{"$date": <milliseconds>}` dictionary into a DataAPITimestamp."""
    millis = date_object["$date"]
    return DataAPITimestamp(millis)


def convert_ejson_binary_object_to_bytes(
    binary_object: dict[str, str],
) -> bytes:
    """Decode a `{"$binary": <base64 text>}` dictionary into bytes."""
    encoded = binary_object["$binary"]
    return base64.b64decode(encoded)


def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID:
    """Decode a `{"$uuid": <string>}` dictionary into a UUID."""
    uuid_text = uuid_object["$uuid"]
    return UUID(uuid_text)


def convert_ejson_objectid_object_to_objectid(
    objectid_object: dict[str, str],
) -> ObjectId:
    """Decode a `{"$objectId": <string>}` dictionary into an ObjectId."""
    objectid_text = objectid_object["$objectId"]
    return ObjectId(objectid_text)
class ColumnType(StrEnum):
    """
    Enum to describe the scalar column types for Tables.

    A 'scalar' type is a non-composite type: that means, no sets, lists, maps
    and other non-primitive data types.

    Each member's value is the lowercase type name (e.g. `ColumnType.TEXT`
    has the value "text").
    """

    ASCII = "ascii"
    BIGINT = "bigint"
    BLOB = "blob"
    BOOLEAN = "boolean"
    COUNTER = "counter"
    DATE = "date"
    DECIMAL = "decimal"
    DOUBLE = "double"
    DURATION = "duration"
    FLOAT = "float"
    INET = "inet"
    INT = "int"
    SMALLINT = "smallint"
    TEXT = "text"
    TIME = "time"
    TIMESTAMP = "timestamp"
    TIMEUUID = "timeuuid"
    TINYINT = "tinyint"
    UUID = "uuid"
    VARINT = "varint"


class TableValuedColumnType(StrEnum):
    """
    An enum to describe the types of column with "values"
    (namely, lists and sets).
    """

    LIST = "list"
    SET = "set"


class TableKeyValuedColumnType(StrEnum):
    """
    An enum to describe the types of column with "keys and values"
    (namely, maps).
    """

    MAP = "map"


class TableVectorColumnType(StrEnum):
    """
    An enum to describe the types of 'vector-like' column.
    """

    VECTOR = "vector"


class TableUnsupportedColumnType(StrEnum):
    """
    An enum to describe the types of column falling into the 'unsupported' group
    (read/describe path).
    """

    # unlike the other column-type enums in this module, the value is uppercase
    UNSUPPORTED = "UNSUPPORTED"


class TablePassthroughColumnType(StrEnum):
    """
    An enum to describe the types for 'passthrough' columns (read/describe path).
    """

    # unlike most column-type enums in this module, the value is uppercase
    PASSTHROUGH = "PASSTHROUGH"
# Iterable types that `ensure_unrolled_if_iterable` must hand back untouched.
ITERABLES_TO_NOT_UNROLL = (
    list,
    str,
    bytes,
    dict,
    DataAPIVector,
    DataAPIMap,
    DataAPISet,
)


def ensure_unrolled_if_iterable(value: Any) -> Any:
    """
    Expand a generic iterable into a list, leaving anything else as it is.

    Non-iterables, as well as instances of the types listed in
    `ITERABLES_TO_NOT_UNROLL`, are returned unchanged.
    """
    if isinstance(value, ITERABLES_TO_NOT_UNROLL) or not isinstance(value, Iterable):
        return value
    return list(value)


def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]:
    """
    Convert a vector of objects (e.g. strings) to a vector of floats.

    Args:
        vector (list): A vector of objects, each of which `float()` accepts.

    Returns:
        list: A vector of floats.
    """
    return list(map(float, vector))


def is_list_of_floats(vector: Iterable[Any]) -> bool:
    """
    Safely determine if it's a list of numbers (floats or ints).
    Assumption: if list, and first item is a number, then all items are.
    An empty list qualifies; anything that is not a list does not.
    """
    if not isinstance(vector, list):
        return False
    if not vector:
        return True
    return isinstance(vector[0], (float, int))
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from astrapy.data_types.data_api_date import DataAPIDate 18 | from astrapy.data_types.data_api_duration import DataAPIDuration 19 | from astrapy.data_types.data_api_map import DataAPIMap 20 | from astrapy.data_types.data_api_set import DataAPISet 21 | from astrapy.data_types.data_api_time import DataAPITime 22 | from astrapy.data_types.data_api_timestamp import DataAPITimestamp 23 | from astrapy.data_types.data_api_vector import DataAPIVector 24 | 25 | __all__ = [ 26 | "DataAPITimestamp", 27 | "DataAPIVector", 28 | "DataAPIDate", 29 | "DataAPIDuration", 30 | "DataAPIMap", 31 | "DataAPISet", 32 | "DataAPITime", 33 | ] 34 | -------------------------------------------------------------------------------- /astrapy/data_types/data_api_map.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
T = TypeVar("T")
U = TypeVar("U")


def _accumulate_pairs(
    destination: tuple[list[T], list[U]], source: Iterable[tuple[T, U]]
) -> tuple[list[T], list[U]]:
    """
    Merge (key, value) pairs into a pair of parallel key/value lists.

    New lists are returned: the lists in `destination` are not modified.
    Membership is tested with a linear scan (`in` on a list), which is what
    makes non-hashable keys possible. If a key occurs more than once, the
    first occurrence wins and later ones are ignored.
    """
    _new_ks = list(destination[0])
    _new_vs = list(destination[1])
    for k, v in source:
        if k not in _new_ks:
            _new_ks.append(k)
            _new_vs.append(v)
    return (_new_ks, _new_vs)


class DataAPIMap(Generic[T, U], Mapping[T, U]):
    """
    An immutable 'map-like' class that preserves the order and can employ
    non-hashable keys (which must support __eq__). Not designed for performance.

    Despite internally preserving the order, equality between DataAPIMap instances
    (and with regular dicts) is independent of the order.
    """

    _keys: list[T]
    _values: list[U]

    def __init__(self, source: Iterable[tuple[T, U]] | Mapping[T, U] = ()) -> None:
        """
        Args:
            source: either a mapping (a dict, another DataAPIMap, ...) or an
                iterable of (key, value) pairs. Defaults to an empty map.
        """
        # Any mapping must be read through `.items()`: iterating it directly
        # would yield just its keys, and not (key, value) pairs.
        pairs = source.items() if isinstance(source, Mapping) else source
        self._keys, self._values = _accumulate_pairs(([], []), pairs)

    def __getitem__(self, key: T) -> U:
        """
        Return the value for `key`, found with a linear scan over the keys.

        Raises:
            KeyError: if no stored key matches.
        """
        if isinstance(key, float) and math.isnan(key):
            # NaN never compares equal to itself: match on "is a NaN" instead
            for idx, k in enumerate(self._keys):
                if isinstance(k, float) and math.isnan(k):
                    return self._values[idx]
        else:
            for idx, k in enumerate(self._keys):
                if k == key:
                    return self._values[idx]
        # Keep the message short: this is raised on every failed `in` check
        # (Mapping.__contains__ relies on __getitem__).
        raise KeyError(str(key))

    def __iter__(self) -> Iterator[T]:
        return iter(self._keys)

    def __len__(self) -> int:
        return len(self._keys)

    def __eq__(self, other: object) -> bool:
        """
        Order-independent equality with another DataAPIMap or with anything
        that `dict()` can convert. Objects that cannot be converted yield
        NotImplemented, so that `==` falls back to Python's default handling.
        """
        if isinstance(other, DataAPIMap):
            try:
                return (
                    len(self) == len(other)
                    and all(o_k in self for o_k in other)
                    and all(other[k] == self[k] for k in self)
                )
            except KeyError:
                # a key of this map is missing from the other one
                return False
        try:
            dother = dict(other)  # type: ignore[call-overload]
        except (TypeError, ValueError):
            # TypeError: not iterable / unhashable keys. ValueError: an
            # iterable whose items are not pairs (e.g. a non-empty string).
            return NotImplemented
        try:
            return (
                len(dother) == len(self)
                and all(o_k in self for o_k in dother)
                and all(dother[k] == self[k] for k in self)
            )
        except KeyError:
            return False
        except TypeError:
            # e.g. a non-hashable key of this map used to index a dict
            return NotImplemented

    def __repr__(self) -> str:
        _map_repr = ", ".join(
            f"({repr(k)}, {repr(v)})" for k, v in zip(self._keys, self._values)
        )
        return f"{self.__class__.__name__}([{_map_repr}])"

    def __str__(self) -> str:
        _map_repr = ", ".join(f"({k}, {v})" for k, v in zip(self._keys, self._values))
        return f"{_map_repr}"

    def __reduce__(self) -> tuple[type, tuple[Iterable[tuple[T, U]]]]:
        # pickling support: rebuild the map from its list of (key, value) pairs
        return self.__class__, (list(zip(self._keys, self._values)),)
@dataclass
class TooManyDocumentsToCountException(DataAPIException):
    """
    A `count_documents()` operation on a collection failed because the resulting
    number of documents exceeded either the upper bound set by the caller or the
    hard limit imposed by the Data API.

    Attributes:
        text: a text message about the exception.
        server_max_count_exceeded: True if the count limit imposed by the API
            is reached. In that case, increasing the upper bound in the method
            invocation is of no help.
    """

    text: str
    server_max_count_exceeded: bool

    def __init__(
        self,
        text: str,
        *,
        server_max_count_exceeded: bool,
    ) -> None:
        # the text is also handed to the parent constructor
        super().__init__(text)
        self.server_max_count_exceeded = server_max_count_exceeded
        self.text = text


@dataclass
class CollectionInsertManyException(DataAPIException):
    """
    An exception occurring within an insert_many (an operation that can span
    several requests). As such, it represents both the root error(s) that happened
    and information on the portion of the documents that were successfully inserted.

    The behaviour of insert_many (concurrency and the `ordered` setting) makes it
    possible that more than one "root error" is collected.

    Attributes:
        inserted_ids: a list of the document IDs that have been successfully inserted.
        exceptions: a list of the root exceptions leading to this error. The list,
            under normal circumstances, is not empty.
    """

    inserted_ids: list[Any]
    exceptions: Sequence[Exception]

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        num_ids = len(self.inserted_ids)
        if not self.exceptions:
            return f"{class_name}()"
        # at most eight root exceptions are spelled out in the description
        exc_desc = ", ".join(str(exc) for exc in self.exceptions[:8])
        if len(self.exceptions) > 8:
            exc_desc += " ... (more exceptions)"
        return f"{class_name}({exc_desc} [with {num_ids} inserted ids])"


@dataclass
class CollectionDeleteManyException(DataAPIException):
    """
    An exception occurring during a delete_many (an operation that can span
    several requests). As such, besides information on the root-cause error,
    there may be a partial result about the part that succeeded.

    Attributes:
        partial_result: a CollectionDeleteResult object, just like the one that would
            be the return value of the operation, had it succeeded completely.
        cause: a root exception that happened during the delete_many, causing
            the method call to stop and raise this error.
    """

    partial_result: CollectionDeleteResult
    cause: Exception

    def __str__(self) -> str:
        cause_desc = str(self.cause)
        return f"{self.__class__.__name__}({cause_desc})"


@dataclass
class CollectionUpdateManyException(DataAPIException):
    """
    An exception occurring during an update_many (an operation that can span
    several requests). As such, besides information on the root-cause error,
    there may be a partial result about the part that succeeded.

    Attributes:
        partial_result: a CollectionUpdateResult object, just like the one that would
            be the return value of the operation, had it succeeded completely.
        cause: a root exception that happened during the update_many, causing
            the method call to stop and raise this error.
    """

    partial_result: CollectionUpdateResult
    cause: Exception

    def __str__(self) -> str:
        cause_desc = str(self.cause)
        return f"{self.__class__.__name__}({cause_desc})"
@dataclass
class TooManyRowsToCountException(DataAPIException):
    """
    A `count_documents()` operation on a table failed because of the excessive amount
    of rows to count.

    Attributes:
        text: a text message about the exception.
        server_max_count_exceeded: a boolean flag set by whoever raises the
            exception (judging by its name, it tells whether the limit that
            was hit is the one imposed by the server).
    """

    text: str
    server_max_count_exceeded: bool

    def __init__(
        self,
        text: str,
        *,
        server_max_count_exceeded: bool,
    ) -> None:
        # the text is also handed to the parent constructor
        super().__init__(text)
        self.server_max_count_exceeded = server_max_count_exceeded
        self.text = text


@dataclass
class TableInsertManyException(DataAPIException):
    """
    An exception occurring within an insert_many (an operation that can span
    several requests). As such, it represents both the root error(s) that happened
    and information on the portion of the rows that were successfully inserted.

    The behaviour of insert_many (concurrency and the `ordered` setting) makes it
    possible that more than one "root error" is collected.

    Attributes:
        inserted_ids: a list of the row IDs that have been successfully inserted,
            each in the form of a dictionary matching the table primary key.
        inserted_id_tuples: the same information as for `inserted_ids` (in the same
            order), but in the form of a tuple for each ID.
        exceptions: a list of the root exceptions leading to this error. The list,
            under normal circumstances, is not empty.
    """

    inserted_ids: list[Any]
    inserted_id_tuples: list[tuple[Any, ...]]
    exceptions: Sequence[Exception]

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        num_ids = len(self.inserted_ids)
        if not self.exceptions:
            return f"{class_name}()"
        # at most eight root exceptions are spelled out in the description
        exc_desc = ", ".join(str(exc) for exc in self.exceptions[:8])
        if len(self.exceptions) > 8:
            exc_desc += " ... (more exceptions)"
        return f"{class_name}({exc_desc} [with {num_ids} inserted ids])"
14 | 15 | from __future__ import annotations 16 | 17 | from uuid import UUID, uuid1, uuid3, uuid4, uuid5 # noqa: F401 18 | 19 | from bson.objectid import ObjectId # noqa: F401 20 | from uuid6 import uuid6, uuid7, uuid8 # noqa: F401 21 | 22 | __all__ = [ 23 | "ObjectId", 24 | "uuid1", 25 | "uuid3", 26 | "uuid4", 27 | "uuid5", 28 | "uuid6", 29 | "uuid7", 30 | "uuid8", 31 | "UUID", 32 | ] 33 | -------------------------------------------------------------------------------- /astrapy/info.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import annotations 16 | 17 | from astrapy.data.info.collection_descriptor import ( 18 | CollectionDefaultIDOptions, 19 | CollectionDefinition, 20 | CollectionDescriptor, 21 | CollectionInfo, 22 | CollectionLexicalOptions, 23 | CollectionRerankOptions, 24 | CollectionVectorOptions, 25 | ) 26 | from astrapy.data.info.database_info import ( 27 | AstraDBAdminDatabaseInfo, 28 | AstraDBAvailableRegionInfo, 29 | AstraDBDatabaseInfo, 30 | ) 31 | from astrapy.data.info.reranking import ( 32 | FindRerankingProvidersResult, 33 | RerankingAPIModelSupport, 34 | RerankingProvider, 35 | RerankingProviderAuthentication, 36 | RerankingProviderModel, 37 | RerankingProviderParameter, 38 | RerankingProviderToken, 39 | RerankServiceOptions, 40 | ) 41 | from astrapy.data.info.table_descriptor.table_altering import ( 42 | AlterTableAddColumns, 43 | AlterTableAddVectorize, 44 | AlterTableDropColumns, 45 | AlterTableDropVectorize, 46 | ) 47 | from astrapy.data.info.table_descriptor.table_columns import ( 48 | TableAPISupportDescriptor, 49 | TableKeyValuedColumnTypeDescriptor, 50 | TablePassthroughColumnTypeDescriptor, 51 | TablePrimaryKeyDescriptor, 52 | TableScalarColumnTypeDescriptor, 53 | TableUnsupportedColumnTypeDescriptor, 54 | TableValuedColumnTypeDescriptor, 55 | TableVectorColumnTypeDescriptor, 56 | ) 57 | from astrapy.data.info.table_descriptor.table_creation import ( 58 | CreateTableDefinition, 59 | ) 60 | from astrapy.data.info.table_descriptor.table_indexes import ( 61 | TableAPIIndexSupportDescriptor, 62 | TableBaseIndexDefinition, 63 | TableIndexDefinition, 64 | TableIndexDescriptor, 65 | TableIndexOptions, 66 | TableIndexType, 67 | TableUnsupportedIndexDefinition, 68 | TableVectorIndexDefinition, 69 | TableVectorIndexOptions, 70 | ) 71 | from astrapy.data.info.table_descriptor.table_listing import ( 72 | ListTableDefinition, 73 | ListTableDescriptor, 74 | TableInfo, 75 | ) 76 | from astrapy.data.info.vectorize import ( 77 | 
EmbeddingAPIModelSupport, 78 | EmbeddingProvider, 79 | EmbeddingProviderAuthentication, 80 | EmbeddingProviderModel, 81 | EmbeddingProviderParameter, 82 | EmbeddingProviderToken, 83 | FindEmbeddingProvidersResult, 84 | VectorServiceOptions, 85 | ) 86 | from astrapy.data.utils.table_types import ( 87 | ColumnType, 88 | TableKeyValuedColumnType, 89 | TableValuedColumnType, 90 | ) 91 | 92 | __all__ = [ 93 | "AlterTableAddColumns", 94 | "AlterTableAddVectorize", 95 | "AlterTableDropColumns", 96 | "AlterTableDropVectorize", 97 | "AstraDBAdminDatabaseInfo", 98 | "AstraDBAvailableRegionInfo", 99 | "AstraDBDatabaseInfo", 100 | "CollectionDefaultIDOptions", 101 | "CollectionDefinition", 102 | "CollectionDescriptor", 103 | "CollectionInfo", 104 | "CollectionLexicalOptions", 105 | "CollectionRerankOptions", 106 | "CollectionVectorOptions", 107 | "ColumnType", 108 | "CreateTableDefinition", 109 | "EmbeddingAPIModelSupport", 110 | "EmbeddingProvider", 111 | "EmbeddingProviderAuthentication", 112 | "EmbeddingProviderModel", 113 | "EmbeddingProviderParameter", 114 | "EmbeddingProviderToken", 115 | "FindEmbeddingProvidersResult", 116 | "FindRerankingProvidersResult", 117 | "ListTableDefinition", 118 | "ListTableDescriptor", 119 | "RerankingAPIModelSupport", 120 | "RerankingProvider", 121 | "RerankingProviderAuthentication", 122 | "RerankingProviderModel", 123 | "RerankingProviderParameter", 124 | "RerankingProviderToken", 125 | "RerankServiceOptions", 126 | "TableAPIIndexSupportDescriptor", 127 | "TableAPISupportDescriptor", 128 | "TableBaseIndexDefinition", 129 | "TableIndexDefinition", 130 | "TableIndexDescriptor", 131 | "TableIndexOptions", 132 | "TableIndexType", 133 | "TableInfo", 134 | "TableKeyValuedColumnType", 135 | "TableKeyValuedColumnTypeDescriptor", 136 | "TablePassthroughColumnTypeDescriptor", 137 | "TablePrimaryKeyDescriptor", 138 | "TableScalarColumnTypeDescriptor", 139 | "TableUnsupportedColumnTypeDescriptor", 140 | "TableUnsupportedIndexDefinition", 141 | 
"TableValuedColumnType", 142 | "TableValuedColumnTypeDescriptor", 143 | "TableVectorColumnTypeDescriptor", 144 | "TableVectorIndexDefinition", 145 | "TableVectorIndexOptions", 146 | "VectorServiceOptions", 147 | ] 148 | -------------------------------------------------------------------------------- /astrapy/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/astrapy/py.typed -------------------------------------------------------------------------------- /astrapy/settings/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /astrapy/settings/defaults.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import datetime 18 | 19 | # whether to go the extra mile and ensure the "decimal escape trick" is 20 | # not colliding with legitimate user-provided (extremely rare) strings. 21 | # A system setting, off at this time 22 | CHECK_DECIMAL_ESCAPING_CONSISTENCY = False 23 | 24 | # Environment names for management internal to astrapy 25 | DATA_API_ENVIRONMENT_PROD = "prod" 26 | DATA_API_ENVIRONMENT_DEV = "dev" 27 | DATA_API_ENVIRONMENT_TEST = "test" 28 | DATA_API_ENVIRONMENT_DSE = "dse" 29 | DATA_API_ENVIRONMENT_HCD = "hcd" 30 | DATA_API_ENVIRONMENT_CASSANDRA = "cassandra" 31 | DATA_API_ENVIRONMENT_OTHER = "other" 32 | 33 | # Defaults/settings for Database management 34 | DEFAULT_ASTRA_DB_KEYSPACE = "default_keyspace" 35 | API_ENDPOINT_TEMPLATE_ENV_MAP = { 36 | DATA_API_ENVIRONMENT_PROD: "https://{database_id}-{region}.apps.astra.datastax.com", 37 | DATA_API_ENVIRONMENT_DEV: "https://{database_id}-{region}.apps.astra-dev.datastax.com", 38 | DATA_API_ENVIRONMENT_TEST: "https://{database_id}-{region}.apps.astra-test.datastax.com", 39 | } 40 | API_PATH_ENV_MAP = { 41 | DATA_API_ENVIRONMENT_PROD: "/api/json", 42 | DATA_API_ENVIRONMENT_DEV: "/api/json", 43 | DATA_API_ENVIRONMENT_TEST: "/api/json", 44 | # 45 | DATA_API_ENVIRONMENT_DSE: "", 46 | DATA_API_ENVIRONMENT_HCD: "", 47 | DATA_API_ENVIRONMENT_CASSANDRA: "", 48 | DATA_API_ENVIRONMENT_OTHER: "", 49 | } 50 | API_VERSION_ENV_MAP = { 51 | DATA_API_ENVIRONMENT_PROD: "/v1", 52 | DATA_API_ENVIRONMENT_DEV: "/v1", 53 | 
DATA_API_ENVIRONMENT_TEST: "/v1", 54 | # 55 | DATA_API_ENVIRONMENT_DSE: "v1", 56 | DATA_API_ENVIRONMENT_HCD: "v1", 57 | DATA_API_ENVIRONMENT_CASSANDRA: "v1", 58 | DATA_API_ENVIRONMENT_OTHER: "v1", 59 | } 60 | 61 | # Defaults/settings for Data API requests 62 | DEFAULT_USE_DECIMALS_IN_COLLECTIONS = False 63 | DEFAULT_BINARY_ENCODE_VECTORS = True 64 | DEFAULT_CUSTOM_DATATYPES_IN_READING = True 65 | DEFAULT_UNROLL_ITERABLES_TO_LISTS = False 66 | DEFAULT_ENCODE_MAPS_AS_LISTS_IN_TABLES = "NEVER" # later coerced as MapEncodingMode 67 | 68 | DEFAULT_ACCEPT_NAIVE_DATETIMES = False 69 | DEFAULT_DATETIME_TZINFO = datetime.timezone.utc 70 | 71 | DEFAULT_INSERT_MANY_CHUNK_SIZE = 50 72 | DEFAULT_INSERT_MANY_CONCURRENCY = 20 73 | DEFAULT_REQUEST_TIMEOUT_MS = 10000 74 | DEFAULT_GENERAL_METHOD_TIMEOUT_MS = 30000 75 | DEFAULT_COLLECTION_ADMIN_TIMEOUT_MS = 60000 76 | DEFAULT_TABLE_ADMIN_TIMEOUT_MS = 30000 77 | DEFAULT_DATA_API_AUTH_HEADER = "Token" 78 | EMBEDDING_HEADER_AWS_ACCESS_ID = "X-Embedding-Access-Id" 79 | EMBEDDING_HEADER_AWS_SECRET_ID = "X-Embedding-Secret-Id" 80 | EMBEDDING_HEADER_API_KEY = "X-Embedding-Api-Key" 81 | RERANKING_HEADER_API_KEY = "Reranking-Api-Key" 82 | 83 | # Defaults/settings for DevOps API requests and admin operations 84 | DEFAULT_DEV_OPS_AUTH_HEADER = "Authorization" 85 | DEFAULT_DEV_OPS_AUTH_PREFIX = "Bearer " 86 | DEV_OPS_KEYSPACE_POLL_INTERVAL_S = 2 87 | DEV_OPS_DATABASE_POLL_INTERVAL_S = 15 88 | DEFAULT_DATABASE_ADMIN_TIMEOUT_MS = 600000 89 | DEFAULT_KEYSPACE_ADMIN_TIMEOUT_MS = 30000 90 | DEV_OPS_DATABASE_STATUS_MAINTENANCE = "MAINTENANCE" 91 | DEV_OPS_DATABASE_STATUS_ACTIVE = "ACTIVE" 92 | DEV_OPS_DATABASE_STATUS_PENDING = "PENDING" 93 | DEV_OPS_DATABASE_STATUS_INITIALIZING = "INITIALIZING" 94 | DEV_OPS_DATABASE_STATUS_ERROR = "ERROR" 95 | DEV_OPS_DATABASE_STATUS_TERMINATING = "TERMINATING" 96 | DEV_OPS_URL_ENV_MAP = { 97 | DATA_API_ENVIRONMENT_PROD: "https://api.astra.datastax.com", 98 | DATA_API_ENVIRONMENT_DEV: 
"https://api.dev.cloud.datastax.com", 99 | DATA_API_ENVIRONMENT_TEST: "https://api.test.cloud.datastax.com", 100 | } 101 | DEV_OPS_VERSION_ENV_MAP = { 102 | DATA_API_ENVIRONMENT_PROD: "v2", 103 | DATA_API_ENVIRONMENT_DEV: "v2", 104 | DATA_API_ENVIRONMENT_TEST: "v2", 105 | } 106 | DEV_OPS_RESPONSE_HTTP_ACCEPTED = 202 107 | DEV_OPS_RESPONSE_HTTP_CREATED = 201 108 | DEV_OPS_DEFAULT_DATABASES_PAGE_SIZE = 50 109 | 110 | # Settings for redacting secrets in string representations and logging 111 | SECRETS_REDACT_ENDING = "..." 112 | SECRETS_REDACT_CHAR = "*" 113 | SECRETS_REDACT_ENDING_LENGTH = 3 114 | FIXED_SECRET_PLACEHOLDER = "***" 115 | DEFAULT_REDACTED_HEADER_NAMES = { 116 | DEFAULT_DATA_API_AUTH_HEADER, 117 | DEFAULT_DEV_OPS_AUTH_HEADER, 118 | EMBEDDING_HEADER_AWS_ACCESS_ID, 119 | EMBEDDING_HEADER_AWS_SECRET_ID, 120 | EMBEDDING_HEADER_API_KEY, 121 | RERANKING_HEADER_API_KEY, 122 | } 123 | -------------------------------------------------------------------------------- /astrapy/settings/error_messages.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE = ( 18 | "Cannot encode a datetime without timezone information ('tzinfo'). 
" 19 | "See the APIOptions.SerdesOptions.accept_naive_datetimes setting " 20 | "if you want to relax this write-time safeguard." 21 | ) 22 | -------------------------------------------------------------------------------- /astrapy/table.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from astrapy.data.table import ( 16 | AsyncTable, 17 | Table, 18 | ) 19 | 20 | __all__ = [ 21 | "AsyncTable", 22 | "Table", 23 | ] 24 | -------------------------------------------------------------------------------- /astrapy/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def check_deprecated_alias(
    new_name: str | None,
    deprecated_name: str | None,
) -> str | None:
    """Resolve a parameter that can be passed under a new or a deprecated name.

    This is a generic blueprint: the messages refer to the literal
    placeholder names 'new_name' and 'deprecated_name'.

    Args:
        new_name: the value received through the current parameter name.
        deprecated_name: the value received through the deprecated alias.

    Returns:
        the value to use for the parameter: `new_name` when the alias was
        not supplied, otherwise the value of `deprecated_name`.

    Raises:
        ValueError: if both values are supplied (i.e. neither is None).
            The deprecation warning is still issued before raising.
    """

    # Guard clause: the alias was not used, so there is nothing to warn about.
    if deprecated_name is None:
        return new_name

    # From here on the deprecated alias has been used: always warn.
    warnings.warn(
        DeprecatedWarning(
            "Parameter 'deprecated_name'",
            deprecated_in="2.0.0",
            removed_in="3.0.0",
            details="Please use 'new_name' instead.",
        ),
        stacklevel=3,
    )

    # Supplying both spellings at once is ambiguous and rejected.
    if new_name is not None:
        raise ValueError(
            "Parameters `new_name` and `deprecated_name` "
            "(a deprecated alias for the former) cannot be passed at the same time."
        )

    return deprecated_name
def _warn_residual_keys(
    klass: type, raw_dict: dict[str, Any], known_keys: set[str]
) -> None:
    """Warn if a dictionary being parsed has keys outside the expected set.

    Nothing is raised and nothing is returned: at most one warning is issued,
    listing all unexpected keys sorted and comma-separated.

    Args:
        klass: the class the dictionary is being parsed into; only its
            `__name__` is used, to compose the warning message.
        raw_dict: the dictionary being parsed.
        known_keys: the keys that are expected to possibly appear in `raw_dict`.
    """
    residual_keys = raw_dict.keys() - known_keys
    if residual_keys:
        warnings.warn(
            "Unexpected key(s) encountered parsing a dictionary into "
            f"a `{klass.__name__}`: '{','.join(sorted(residual_keys))}'",
            # Attribute the warning to the code calling this helper, rather
            # than to this very line (which would be the same for all callers).
            stacklevel=2,
        )
def log_httpx_request(
    http_method: str,
    full_url: str,
    request_params: dict[str, Any] | None,
    redacted_request_headers: dict[str, str],
    encoded_payload: str | None,
    timeout_context: _TimeoutContext,
) -> None:
    """
    Log the details of an HTTP request for debugging purposes.

    All messages are emitted at DEBUG level through this module's logger.

    Args:
        http_method: the HTTP verb of the request (e.g. "POST").
        full_url: the URL of the request (e.g. "https://domain.com/full/path").
        request_params: parameters of the request. Logged only if non-empty.
        redacted_request_headers: caution, as these will be logged as they are.
            Logged only if non-empty.
        encoded_payload: the already-encoded payload (a string) sent with
            the request, if any. Logged whenever it is not None.
        timeout_context: the timeout settings for the request. Its `request_ms`
            and `nominal_ms` attributes are logged, with "(unset)" shown
            in place of a falsy (None or zero) value.
    """
    # Values are handed to the logger as lazy arguments: the (possibly large)
    # payload/headers are stringified only if DEBUG logging is actually enabled.
    logger.debug("Request URL: %s %s", http_method, full_url)
    if request_params:
        logger.debug("Request params: '%s'", request_params)
    if redacted_request_headers:
        logger.debug("Request headers: '%s'", redacted_request_headers)
    if encoded_payload is not None:
        logger.debug("Request payload: '%s'", encoded_payload)
    if timeout_context:
        logger.debug(
            "Timeout (ms): for request %s ms, overall operation %s ms",
            timeout_context.request_ms or "(unset)",
            timeout_context.nominal_ms or "(unset)",
        )
def to_httpx_timeout(timeout_context: _TimeoutContext) -> httpx.Timeout | None:
    """Translate a timeout context into the timeout object to pass to httpx.

    Args:
        timeout_context: the timeout settings; only `request_ms` is read.

    Returns:
        None if `request_ms` is None or zero; otherwise an `httpx.Timeout`
        built from `request_ms` converted from milliseconds to seconds.
    """
    request_ms = timeout_context.request_ms
    if request_ms is not None and request_ms != 0:
        # httpx.Timeout is given seconds, the context stores milliseconds
        return httpx.Timeout(request_ms / 1000)
    return None
class StrEnumMeta(EnumMeta):
    """
    Metaclass for string-valued enums, adding a lenient string lookup of
    members and a matching `in` operator on the enum class.
    """

    def _name_lookup(cls, value: str) -> str | None:
        """Return a proper key in the enum if some matching logic works, or None.

        The following matches are attempted, in this order:
        exact member name; case-insensitive member name;
        case-insensitive member *value*.
        """
        mmap = {k: v.value for k, v in cls._member_map_.items()}
        # try exact key match
        if value in mmap:
            return value
        # try case-insensitive key match
        u_value = value.upper()
        u_mmap = {k.upper(): k for k in mmap}
        if u_value in u_mmap:
            return u_mmap[u_value]
        # try case-insensitive *value* match
        v_mmap = {v.upper(): k for k, v in mmap.items()}
        if u_value in v_mmap:
            return v_mmap[u_value]
        return None

    def __contains__(cls, value: object) -> bool:
        """Return True if the provided item belongs to the enum.

        An item belongs to the enum if it is one of its members, or if it is
        a string that `_name_lookup` can resolve to a member name.
        Anything else is not contained.
        """
        # Members of the enum itself must satisfy `member in EnumClass`,
        # as is the case for any standard Enum.
        if isinstance(value, cls):
            return True
        if isinstance(value, str):
            return cls._name_lookup(value) is not None
        return False
class UnsetType:
    """Type of a sentinel value: instantiating it always yields the same object."""

    # The one shared instance, created the first time the class is instantiated.
    _instance = None

    def __new__(cls) -> UnsetType:
        existing = cls._instance
        if existing is not None:
            return existing
        # first instantiation: create the object and remember it on the class
        created = super().__new__(cls)
        cls._instance = created
        return created

    def __repr__(self) -> str:
        return "(unset)"
def compose_user_agent_string(
    caller_name: str | None, caller_version: str | None
) -> str | None:
    """Format a single caller as a user-agent token.

    Args:
        caller_name: the name of the caller, if any.
        caller_version: the version of the caller, if any.

    Returns:
        "name/version" if both are non-empty; just "name" if the version
        is missing or empty; None if the name is missing or empty.
    """
    # without a name there is nothing to identify, whatever the version
    if not caller_name:
        return None
    if not caller_version:
        return f"{caller_name}"
    return f"{caller_name}/{caller_version}"
https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/pictures/astrapy_exceptions.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | requires-python = ">=3.8,<4.0" 3 | name = "astrapy" 4 | version = "2.0.1" 5 | description = "A Python client for the Data API on DataStax Astra DB" 6 | authors = [ 7 | {"name" = "Stefano Lottini", "email" = "stefano.lottini@datastax.com"}, 8 | {"name" = "Eric Hare", "email" = "eric.hare@datastax.com"}, 9 | ] 10 | readme = "README.md" 11 | 12 | keywords = ["DataStax", "Astra DB", "Astra"] 13 | dependencies = [ 14 | "deprecation ~= 2.1.0", 15 | "httpx[http2]>=0.25.2,<1", 16 | "h11 >= 0.16.0", 17 | "pymongo >= 3", 18 | "toml >= 0.10.2,<1.0.0", 19 | "typing-extensions >= 4.0", 20 | "uuid6 >= 2024.1.12" 21 | ] 22 | classifiers = [ 23 | "Development Status :: 5 - Production/Stable", 24 | "Intended Audience :: Developers", 25 | "License :: OSI Approved :: Apache Software License", 26 | "Programming Language :: Python :: 3.8", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: 3.12", 31 | "Topic :: Software Development :: Build Tools" 32 | ] 33 | 34 | [project.urls] 35 | Homepage = "https://github.com/datastax/astrapy" 36 | Documentation = "https://docs.datastax.com/en/astra-db-serverless/api-reference/dataapiclient.html" 37 | Repository = "https://github.com/datastax/astrapy" 38 | Issues = "https://github.com/datastax/astrapy/issues" 39 | Changelog = "https://github.com/datastax/astrapy/blob/main/CHANGES" 40 | 41 | [dependency-groups] 42 | dev = [ 43 | "blockbuster ~= 1.5.5", 44 | "build >= 1.0.0", 45 | "cassio ~= 0.1.10; python_version >= '3.9'", 46 | "faker ~= 23.1.0", 47 | "mypy ~= 1.9.0", 48 | "ruff >= 0.11.9,<0.12", 49 | 
"pre-commit ~= 3.5.0", 50 | "pytest ~= 8.0.0", 51 | "pytest-asyncio ~= 0.23.5", 52 | "pytest-cov ~= 4.1.0", 53 | "pytest-testdox ~= 3.1.0", 54 | "python-dotenv ~= 1.0.1", 55 | "pytest-httpserver ~= 1.0.8", 56 | "setuptools >= 61.0", 57 | "testcontainers ~= 3.7.1", 58 | "types-toml >= 0.10.8.7,<1.0.0" 59 | ] 60 | 61 | [tool.hatch.build.targets.wheel] 62 | packages = ["astrapy"] 63 | 64 | [build-system] 65 | requires = ["hatchling"] 66 | build-backend = "hatchling.build" 67 | 68 | [tool.setuptools.packages.find] 69 | include = ["astrapy*"] 70 | 71 | [tool.ruff.lint] 72 | select = ["E4", "E7", "E9", "F", "FA", "I", "UP"] 73 | 74 | [tool.mypy] 75 | disallow_any_generics = true 76 | disallow_incomplete_defs = true 77 | disallow_untyped_calls = true 78 | disallow_untyped_decorators = true 79 | disallow_untyped_defs = true 80 | follow_imports = "normal" 81 | ignore_missing_imports = true 82 | no_implicit_reexport = true 83 | show_error_codes = true 84 | show_error_context = true 85 | strict_equality = true 86 | strict_optional = true 87 | warn_redundant_casts = true 88 | warn_return_any = true 89 | warn_unused_ignores = true 90 | 91 | [tool.pytest.ini_options] 92 | filterwarnings = "ignore::DeprecationWarning" 93 | addopts = "-v --cov=astrapy --testdox --cov-report term-missing" 94 | asyncio_mode = "auto" 95 | log_cli = 1 96 | log_cli_level = "INFO" 97 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/admin/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/admin/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from ..conftest import ( 18 | ADMIN_ENV_LIST, 19 | ADMIN_ENV_VARIABLE_MAP, 20 | IS_ASTRA_DB, 21 | ) 22 | 23 | __all__ = [ 24 | "ADMIN_ENV_LIST", 25 | "ADMIN_ENV_VARIABLE_MAP", 26 | "IS_ASTRA_DB", 27 | ] 28 | -------------------------------------------------------------------------------- /tests/admin/integration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/base/collection_decimal_support_assets.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def is_decimal_super(more_decs: Any, less_decs: Any) -> bool:
    """
    Return True if the first item is "the same values, possibly made Decimal
    where the corresponding second item can be another number (int/float)".

    The comparison is recursive and structural:
      - lists must have the same length and match element by element;
      - dicts must have the same key set and match value by value;
      - a Decimal matches another Decimal by equality, or an int/float
        when both convert to the same float;
      - any other scalar must compare equal, and is never matched by a
        Decimal on the `less_decs` side.

    Args:
        more_decs: the candidate object, where numbers may be Decimal.
        less_decs: the reference object, expected to use plain int/float.

    Returns:
        True if `more_decs` mirrors `less_decs` up to numbers being promoted
        to Decimal, False otherwise (including any structural mismatch).
    """
    if isinstance(more_decs, list):
        return (
            isinstance(less_decs, list)
            and len(more_decs) == len(less_decs)
            and all(
                is_decimal_super(v_more, v_less)
                for v_more, v_less in zip(more_decs, less_decs)
            )
        )
    if isinstance(more_decs, dict):
        return (
            isinstance(less_decs, dict)
            and more_decs.keys() == less_decs.keys()
            and all(
                is_decimal_super(v_more, less_decs[k])
                for k, v_more in more_decs.items()
            )
        )
    # scalars from here on
    if isinstance(more_decs, Decimal):
        if isinstance(less_decs, Decimal):
            return more_decs == less_decs
        if isinstance(less_decs, (int, float)):
            # compare through float, since the counterpart is a binary number
            return float(more_decs) == float(less_decs)
        # a Decimal can only stand in for a number: anything else (str, None,
        # containers, ...) is a mismatch rather than a conversion error
        return False
    if isinstance(less_decs, Decimal):
        # "less" must never be the side carrying Decimals the other lacks
        return False
    return more_decs == less_decs
14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/base/integration/collection_decimal_support_assets.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from ..collection_decimal_support_assets import ( 18 | BASELINE_OBJ, 19 | S_OPTS_NO_DECS, 20 | S_OPTS_OK_DECS, 21 | WDECS_OBJ, 22 | is_decimal_super, 23 | ) 24 | 25 | __all__ = [ 26 | "BASELINE_OBJ", 27 | "S_OPTS_NO_DECS", 28 | "S_OPTS_OK_DECS", 29 | "WDECS_OBJ", 30 | "is_decimal_super", 31 | ] 32 | -------------------------------------------------------------------------------- /tests/base/integration/collections/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/base/integration/collections/test_collection_decimal_support.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class TestCollectionDecimalSupportIntegration:
    """
    Integration tests for writing and reading documents with Decimal values
    in collections, under the two serdes settings S_OPTS_NO_DECS and
    S_OPTS_OK_DECS, in both sync and async flavors.
    """

    @pytest.mark.describe(
        "test of decimals not supported by default in collections, sync"
    )
    def test_decimalsupport_collections_defaultsettings_sync(
        self,
        sync_empty_collection: DefaultCollection,
    ) -> None:
        """Decimal-free documents round-trip; Decimal-bearing writes raise."""
        options = APIOptions(serdes_options=S_OPTS_NO_DECS)
        collection = sync_empty_collection.with_options(api_options=options)

        # a document without Decimals is read back equal to what was written
        collection.insert_one(BASELINE_OBJ)
        retrieved = collection.find_one({"_id": BASELINE_OBJ["_id"]})
        assert retrieved is not None
        assert BASELINE_OBJ == retrieved

        # inserting a document with Decimals is expected to raise TypeError
        with pytest.raises(TypeError):
            collection.insert_one(WDECS_OBJ)

    @pytest.mark.describe(
        "test of decimals not supported by default in collections, async"
    )
    async def test_decimalsupport_collections_defaultsettings_async(
        self,
        async_empty_collection: DefaultAsyncCollection,
    ) -> None:
        """Async counterpart of the no-decimals round-trip/failure check."""
        options = APIOptions(serdes_options=S_OPTS_NO_DECS)
        acollection = async_empty_collection.with_options(api_options=options)

        # a document without Decimals is read back equal to what was written
        await acollection.insert_one(BASELINE_OBJ)
        retrieved = await acollection.find_one({"_id": BASELINE_OBJ["_id"]})
        assert retrieved is not None
        assert BASELINE_OBJ == retrieved

        # inserting a document with Decimals is expected to raise TypeError
        with pytest.raises(TypeError):
            await acollection.insert_one(WDECS_OBJ)

    @pytest.mark.describe(
        "test of decimals supported in collections if set to do so, sync"
    )
    def test_decimalsupport_collections_decimalsettings_sync(
        self,
        sync_empty_collection: DefaultCollection,
    ) -> None:
        """With decimals enabled, both documents round-trip up to Decimals."""
        options = APIOptions(serdes_options=S_OPTS_OK_DECS)
        collection = sync_empty_collection.with_options(api_options=options)

        # plain document: what is read may carry Decimals in place of numbers
        collection.insert_one(BASELINE_OBJ)
        retrieved_baseline = collection.find_one({"_id": BASELINE_OBJ["_id"]})
        assert retrieved_baseline is not None
        assert is_decimal_super(retrieved_baseline, BASELINE_OBJ)

        # document with Decimals: same check against the written object
        collection.insert_one(WDECS_OBJ)
        retrieved_wdecs = collection.find_one({"_id": WDECS_OBJ["_id"]})
        assert retrieved_wdecs is not None
        assert is_decimal_super(retrieved_wdecs, WDECS_OBJ)

    @pytest.mark.describe(
        "test of decimals supported in collections if set to do so, async"
    )
    async def test_decimalsupport_collections_decimalsettings_async(
        self,
        async_empty_collection: DefaultAsyncCollection,
    ) -> None:
        """Async counterpart of the decimals-enabled round-trip check."""
        options = APIOptions(serdes_options=S_OPTS_OK_DECS)
        acollection = async_empty_collection.with_options(api_options=options)

        # plain document: what is read may carry Decimals in place of numbers
        await acollection.insert_one(BASELINE_OBJ)
        retrieved_baseline = await acollection.find_one(
            {"_id": BASELINE_OBJ["_id"]}
        )
        assert retrieved_baseline is not None
        assert is_decimal_super(retrieved_baseline, BASELINE_OBJ)

        # document with Decimals: same check against the written object
        await acollection.insert_one(WDECS_OBJ)
        retrieved_wdecs = await acollection.find_one({"_id": WDECS_OBJ["_id"]})
        assert retrieved_wdecs is not None
        assert is_decimal_super(retrieved_wdecs, WDECS_OBJ)
-------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from ..conftest import ( 18 | ADMIN_ENV_LIST, 19 | ADMIN_ENV_VARIABLE_MAP, 20 | CQL_AVAILABLE, 21 | HEADER_EMBEDDING_API_KEY_OPENAI, 22 | IS_ASTRA_DB, 23 | SECONDARY_KEYSPACE, 24 | TEST_COLLECTION_NAME, 25 | VECTORIZE_TEXTS, 26 | DataAPICredentials, 27 | DataAPICredentialsInfo, 28 | DefaultAsyncCollection, 29 | DefaultAsyncTable, 30 | DefaultCollection, 31 | DefaultTable, 32 | _repaint_NaNs, 33 | _typify_tuple, 34 | async_fail_if_not_removed, 35 | clean_nulls_from_dict, 36 | sync_fail_if_not_removed, 37 | ) 38 | 39 | __all__ = [ 40 | "DataAPICredentials", 41 | "DataAPICredentialsInfo", 42 | "async_fail_if_not_removed", 43 | "clean_nulls_from_dict", 44 | "sync_fail_if_not_removed", 45 | "HEADER_EMBEDDING_API_KEY_OPENAI", 46 | "IS_ASTRA_DB", 47 | "ADMIN_ENV_LIST", 48 | "ADMIN_ENV_VARIABLE_MAP", 49 | "CQL_AVAILABLE", 50 | "SECONDARY_KEYSPACE", 51 | "TEST_COLLECTION_NAME", 52 | "VECTORIZE_TEXTS", 53 | "_repaint_NaNs", 54 | "_typify_tuple", 55 | "DefaultCollection", 56 | "DefaultAsyncCollection", 57 | "DefaultAsyncTable", 58 | "DefaultTable", 59 | ] 60 | -------------------------------------------------------------------------------- /tests/base/integration/misc/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/base/integration/misc/test_admin_ops_async.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class TestFindRegionsAsync:
    """Integration test for the async `async_find_available_regions` call."""

    @pytest.mark.skipif(not IS_ASTRA_DB, reason="This test requires Astra DB")
    @pytest.mark.describe("test of find_available_regions, async")
    async def test_findavailableregions_async(
        self,
        client: DataAPIClient,
        data_api_credentials_kwargs: DataAPICredentials,
    ) -> None:
        """Compare region listings with the org-enabled flag unset/True/False."""
        token = data_api_credentials_kwargs["token"]
        astra_admin = client.get_admin(token=token)

        default_regions = await astra_admin.async_find_available_regions()
        org_regions = await astra_admin.async_find_available_regions(
            only_org_enabled_regions=True,
        )
        all_regions = await astra_admin.async_find_available_regions(
            only_org_enabled_regions=False,
        )

        # leaving the flag out must give the same result as passing True
        assert default_regions == org_regions
        # the unrestricted listing must contain every org-enabled region
        assert len(all_regions) >= len(org_regions)
        assert all(region in all_regions for region in org_regions)
class TestFindRegionsSync:
    """Integration test for the sync `find_available_regions` call."""

    @pytest.mark.skipif(not IS_ASTRA_DB, reason="This test requires Astra DB")
    @pytest.mark.describe("test of find_available_regions, sync")
    def test_findavailableregions_sync(
        self,
        client: DataAPIClient,
        data_api_credentials_kwargs: DataAPICredentials,
    ) -> None:
        """Compare region listings with the org-enabled flag unset/True/False."""
        token = data_api_credentials_kwargs["token"]
        astra_admin = client.get_admin(token=token)

        default_regions = astra_admin.find_available_regions()
        org_regions = astra_admin.find_available_regions(
            only_org_enabled_regions=True,
        )
        all_regions = astra_admin.find_available_regions(
            only_org_enabled_regions=False,
        )

        # leaving the flag out must give the same result as passing True
        assert default_regions == org_regions
        # the unrestricted listing must contain every org-enabled region
        assert len(all_regions) >= len(org_regions)
        assert all(region in all_regions for region in org_regions)
def _count_models(frp_result: FindRerankingProvidersResult) -> int:
    """Return the total number of models listed over all reranking providers."""
    return sum(
        1
        for provider in frp_result.reranking_providers.values()
        for _ in provider.models
    )


class TestRerankingOpsAsync:
    """Integration tests for the async `async_find_reranking_providers` call."""

    @pytest.mark.describe("test of find_reranking_providers, async")
    async def test_findrerankingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        """Check result types and the as_dict/_from_dict round trip."""
        db_admin = async_database.get_database_admin()
        rp_result = await db_admin.async_find_reranking_providers()

        assert isinstance(rp_result, FindRerankingProvidersResult)
        providers = rp_result.reranking_providers
        assert all(
            isinstance(provider, RerankingProvider) for provider in providers.values()
        )

        # 'raw_info' not compared, for resiliency against newly-introduced fields
        round_tripped = FindRerankingProvidersResult._from_dict(rp_result.as_dict())
        assert round_tripped.reranking_providers == rp_result.reranking_providers

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_reranking_providers filtering, async")
    async def test_filtered_findrerankingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        """Check model counts per status filter add up consistently."""
        db_admin = async_database.get_database_admin()

        default_count = _count_models(
            await db_admin.async_find_reranking_providers(),
        )
        all_count = _count_models(
            await db_admin.async_find_reranking_providers(filter_model_status=""),
        )
        # one query per status, issued in this order
        count_by_status = {
            status: _count_models(
                await db_admin.async_find_reranking_providers(
                    filter_model_status=status,
                )
            )
            for status in ("SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must partition the unfiltered ("") listing
        assert sum(count_by_status.values()) == all_count
        # no filter at all must behave like filtering on SUPPORTED
        assert count_by_status["SUPPORTED"] == default_count
def _count_models(frp_result: FindRerankingProvidersResult) -> int:
    """Return the total number of models listed over all reranking providers."""
    return sum(
        1
        for provider in frp_result.reranking_providers.values()
        for _ in provider.models
    )


class TestRerankingOpsSync:
    """Integration tests for the sync `find_reranking_providers` call."""

    @pytest.mark.describe("test of find_reranking_providers, sync")
    def test_findrerankingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Check result types and the as_dict/_from_dict round trip."""
        db_admin = sync_database.get_database_admin()
        rp_result = db_admin.find_reranking_providers()

        assert isinstance(rp_result, FindRerankingProvidersResult)
        providers = rp_result.reranking_providers
        assert all(
            isinstance(provider, RerankingProvider) for provider in providers.values()
        )

        # 'raw_info' not compared, for resiliency against newly-introduced fields
        round_tripped = FindRerankingProvidersResult._from_dict(rp_result.as_dict())
        assert round_tripped.reranking_providers == rp_result.reranking_providers

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_reranking_providers filtering, sync")
    def test_filtered_findrerankingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Check model counts per status filter add up consistently."""
        db_admin = sync_database.get_database_admin()

        default_count = _count_models(db_admin.find_reranking_providers())
        all_count = _count_models(
            db_admin.find_reranking_providers(filter_model_status=""),
        )
        # one query per status, issued in this order
        count_by_status = {
            status: _count_models(
                db_admin.find_reranking_providers(filter_model_status=status)
            )
            for status in ("SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must partition the unfiltered ("") listing
        assert sum(count_by_status.values()) == all_count
        # no filter at all must behave like filtering on SUPPORTED
        assert count_by_status["SUPPORTED"] == default_count
def _count_models(fep_result: FindEmbeddingProvidersResult) -> int:
    """Return the total number of models listed over all embedding providers."""
    return sum(
        1
        for provider in fep_result.embedding_providers.values()
        for _ in provider.models
    )


class TestVectorizeOpsAsync:
    """Integration tests for the async `async_find_embedding_providers` call."""

    @pytest.mark.describe("test of find_embedding_providers, async")
    async def test_findembeddingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        """Check result types, dict round trip and agreement with raw_info."""
        db_admin = async_database.get_database_admin()
        ep_result = await db_admin.async_find_embedding_providers()

        assert isinstance(ep_result, FindEmbeddingProvidersResult)
        providers = ep_result.embedding_providers
        assert all(
            isinstance(provider, EmbeddingProvider) for provider in providers.values()
        )

        # each provider must survive an as_dict -> _from_dict round trip
        reconstructed = {
            name: EmbeddingProvider._from_dict(provider.as_dict())
            for name, provider in providers.items()
        }
        assert reconstructed == ep_result.embedding_providers

        # compare the dict form against the raw API payload, nulls stripped
        cleaned_dict_mapping = clean_nulls_from_dict(
            {name: provider.as_dict() for name, provider in providers.items()}
        )
        cleaned_raw_info = clean_nulls_from_dict(
            ep_result.raw_info["embeddingProviders"]  # type: ignore[index]
        )
        # TODO remove this cleanup once Astra prod and HCD get the support status flags
        raw_info_has_ams = any(
            "apiModelSupport" in raw_model
            for raw_provider in cleaned_raw_info.values()
            for raw_model in raw_provider["models"]
        )
        if not raw_info_has_ams:
            # the payload has no support flags: drop them from our side too
            for provider_dict in cleaned_dict_mapping.values():
                for model_dict in provider_dict["models"]:
                    model_dict.pop("apiModelSupport", None)
        assert cleaned_dict_mapping == cleaned_raw_info

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_embedding_providers filtering, async")
    async def test_filtered_findembeddingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        """Check model counts per status filter add up consistently."""
        db_admin = async_database.get_database_admin()

        default_count = _count_models(
            await db_admin.async_find_embedding_providers(),
        )
        all_count = _count_models(
            await db_admin.async_find_embedding_providers(filter_model_status=""),
        )
        # one query per status, issued in this order
        count_by_status = {
            status: _count_models(
                await db_admin.async_find_embedding_providers(
                    filter_model_status=status,
                )
            )
            for status in ("SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must partition the unfiltered ("") listing
        assert sum(count_by_status.values()) == all_count
        # no filter at all must behave like filtering on SUPPORTED
        assert count_by_status["SUPPORTED"] == default_count
def _count_models(fep_result: FindEmbeddingProvidersResult) -> int:
    """Return the total number of models listed over all embedding providers."""
    return sum(
        1
        for provider in fep_result.embedding_providers.values()
        for _ in provider.models
    )


class TestVectorizeOpsSync:
    """Integration tests for the sync `find_embedding_providers` call."""

    @pytest.mark.describe("test of find_embedding_providers, sync")
    def test_findembeddingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Check result types, dict round trip and agreement with raw_info."""
        db_admin = sync_database.get_database_admin()
        ep_result = db_admin.find_embedding_providers()

        assert isinstance(ep_result, FindEmbeddingProvidersResult)
        providers = ep_result.embedding_providers
        assert all(
            isinstance(provider, EmbeddingProvider) for provider in providers.values()
        )

        # each provider must survive an as_dict -> _from_dict round trip
        reconstructed = {
            name: EmbeddingProvider._from_dict(provider.as_dict())
            for name, provider in providers.items()
        }
        assert reconstructed == ep_result.embedding_providers

        # compare the dict form against the raw API payload, nulls stripped
        cleaned_dict_mapping = clean_nulls_from_dict(
            {name: provider.as_dict() for name, provider in providers.items()}
        )
        cleaned_raw_info = clean_nulls_from_dict(
            ep_result.raw_info["embeddingProviders"]  # type: ignore[index]
        )
        # TODO remove this cleanup once Astra prod and HCD get the support status flags
        raw_info_has_ams = any(
            "apiModelSupport" in raw_model
            for raw_provider in cleaned_raw_info.values()
            for raw_model in raw_provider["models"]
        )
        if not raw_info_has_ams:
            # the payload has no support flags: drop them from our side too
            for provider_dict in cleaned_dict_mapping.values():
                for model_dict in provider_dict["models"]:
                    model_dict.pop("apiModelSupport", None)
        assert cleaned_dict_mapping == cleaned_raw_info

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_embedding_providers filtering, sync")
    def test_filtered_findembeddingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Check model counts per status filter add up consistently."""
        db_admin = sync_database.get_database_admin()

        default_count = _count_models(db_admin.find_embedding_providers())
        all_count = _count_models(
            db_admin.find_embedding_providers(filter_model_status=""),
        )
        # one query per status, issued in this order
        count_by_status = {
            status: _count_models(
                db_admin.find_embedding_providers(filter_model_status=status)
            )
            for status in ("SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must partition the unfiltered ("") listing
        assert sum(count_by_status.values()) == all_count
        # no filter at all must behave like filtering on SUPPORTED
        assert count_by_status["SUPPORTED"] == default_count
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/base/integration/tables/table_cql_assets.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# --- Counter table: created and populated through CQL, read through the API.
TABLE_NAME_COUNTER = "test_table_counter"
CREATE_TABLE_COUNTER = (
    f"CREATE TABLE {TABLE_NAME_COUNTER} (a TEXT PRIMARY KEY, col_counter COUNTER);"
)
DROP_TABLE_COUNTER = f"DROP TABLE {TABLE_NAME_COUNTER};"
# Counter columns are populated via UPDATE statements.
INSERTS_TABLE_COUNTER = [
    f"UPDATE {TABLE_NAME_COUNTER} SET col_counter=col_counter+137 WHERE a='a';",
    f"UPDATE {TABLE_NAME_COUNTER} SET col_counter=col_counter+1 WHERE a='z';",
]
FILTER_COUNTER = {"a": "a"}
EXPECTED_ROW_COUNTER = {"a": "a", "col_counter": 137}

# --- "Low support" table: static columns, a nested frozen collection and a UDT.
TABLE_NAME_LOWSUPPORT = "test_table_lowsupport"
TYPE_NAME_LOWSUPPORT = "test_type_lowsupport"
CREATE_TYPE_LOWSUPPORT = (
    f"CREATE TYPE {TYPE_NAME_LOWSUPPORT} (genus TEXT, species TEXT, size FLOAT);"
)
# NOTE(review): the collection type parameters (<...>) were missing from this
# statement, which made it invalid CQL. They are reconstructed here from the
# literals inserted below and from the count of closing brackets that survived;
# the exact element types (e.g. INT vs SMALLINT, FLOAT vs DOUBLE) are an
# assumption and should be confirmed against the upstream file.
CREATE_TABLE_LOWSUPPORT = f"""CREATE TABLE {TABLE_NAME_LOWSUPPORT} (
  a TEXT,
  b TEXT,
  col_static_timestamp TIMESTAMP STATIC,
  col_static_list LIST<INT> STATIC,
  col_static_list_exotic LIST<BLOB> STATIC,
  col_static_set SET<INT> STATIC,
  col_static_set_exotic SET<BLOB> STATIC,
  col_static_map MAP<INT,TEXT> STATIC,
  col_static_map_exotic MAP<BLOB,BLOB> STATIC,
  col_unsupported FROZEN<LIST<FROZEN<MAP<FROZEN<SET<FLOAT>>, SMALLINT>>>>,
  col_udt {TYPE_NAME_LOWSUPPORT},
  PRIMARY KEY ((A),B)
);"""
DROP_TABLE_LOWSUPPORT = f"DROP TABLE {TABLE_NAME_LOWSUPPORT};"
DROP_TYPE_LOWSUPPORT = f"DROP TYPE {TYPE_NAME_LOWSUPPORT};"
INSERTS_TABLE_LOWSUPPORT = [
    (
        # Only this first fragment is an f-string: the fragments below hold
        # literal CQL braces and must stay plain strings.
        f"INSERT INTO {TABLE_NAME_LOWSUPPORT}"
        " ("
        " a,"
        " b,"
        " col_static_timestamp,"
        " col_static_list,"
        " col_static_list_exotic,"
        " col_static_set,"
        " col_static_set_exotic,"
        " col_static_map,"
        " col_static_map_exotic,"
        " col_unsupported,"
        " col_udt"
        ") VALUES ("
        " 'a',"
        " 'b',"
        " '2022-01-01T12:34:56.000',"
        " [1, 2, 3],"
        " [0xff, 0xff],"
        " {1, 2, 3},"
        " {0xff, 0xff},"
        " {1: 'one'},"
        " {0xff: 0xff},"
        " [{{0.1, 0.2}: 3}, {{0.3, 0.4}: 6}],"
        " {genus: 'Eratigena', species: 'atrica', size: 1.8}"
        ");"
    ),
]
FILTER_LOWSUPPORT = {"a": "a", "b": "b"}
# Projections that the reading test expects to be rejected.
# NOTE(review): "col_unsupported_udt" is not a column of the table above (the
# UDT column is "col_udt") -- confirm whether that is intentional.
ILLEGAL_PROJECTIONS_LOWSUPPORT = [
    {"col_unsupported": True},
    {"col_unsupported_udt": True},
]
# Projection that leaves out the two columns with limited support.
PROJECTION_LOWSUPPORT = {"col_unsupported": False, "col_udt": False}
EXPECTED_ROW_LOWSUPPORT = {
    "a": "a",
    "b": "b",
    "col_static_list_exotic": [
        {"$binary": "/w=="},
        {"$binary": "/w=="},
    ],
    "col_static_set": [1, 2, 3],
    "col_static_map_exotic": [
        [
            {"$binary": "/w=="},
            {"$binary": "/w=="},
        ]
    ],
    "col_static_timestamp": DataAPITimestamp.from_string("2022-01-01T12:34:56Z"),
    "col_static_map": [[1, "one"]],
    "col_static_list": [1, 2, 3],
    "col_static_set_exotic": [{"$binary": "/w=="}],
}
class TestTableColumnTypesSync:
    """Write-then-read round-trip tests on an async table.

    NOTE(review): the class name ends in "Sync" although every test here is a
    coroutine working on an async table; the name is left unchanged.
    """

    @pytest.mark.describe("test of table write-and-read, with custom types, async")
    async def test_table_write_then_read_customtypes_async(
        self,
        async_empty_table_all_returns: DefaultAsyncTable,
    ) -> None:
        """Insert the custom-typed row and read it back with custom types on."""
        serdes = SerdesOptions(custom_datatypes_in_reading=True)
        atable = async_empty_table_all_returns.with_options(
            api_options=APIOptions(serdes_options=serdes),
        )

        await atable.insert_one(FULL_AR_ROW_CUSTOMTYPED)
        row_back = await atable.find_one({})

        # Both sides go through _repaint_NaNs before being compared.
        assert _repaint_NaNs(row_back) == _repaint_NaNs(FULL_AR_ROW_CUSTOMTYPED)

    @pytest.mark.describe("test of table write-and-read, with noncustom types, async")
    async def test_table_write_then_read_noncustomtypes_async(
        self,
        async_empty_table_all_returns: DefaultAsyncTable,
    ) -> None:
        """Insert the non-custom-typed row and read it back with custom types off."""
        serdes = SerdesOptions(custom_datatypes_in_reading=False)
        atable = async_empty_table_all_returns.with_options(
            api_options=APIOptions(serdes_options=serdes),
        )

        await atable.insert_one(FULL_AR_ROW_NONCUSTOMTYPED)
        row_back = await atable.find_one({})

        assert _repaint_NaNs(row_back) == _repaint_NaNs(FULL_AR_ROW_NONCUSTOMTYPED)
class TestTableColumnTypesSync:
    """Write-then-read round-trip tests on a sync table."""

    @pytest.mark.describe("test of table write-and-read, with custom types, sync")
    def test_table_write_then_read_customtypes_sync(
        self,
        sync_empty_table_all_returns: DefaultTable,
    ) -> None:
        """Insert the custom-typed row and read it back with custom types on."""
        serdes = SerdesOptions(custom_datatypes_in_reading=True)
        table = sync_empty_table_all_returns.with_options(
            api_options=APIOptions(serdes_options=serdes),
        )

        table.insert_one(FULL_AR_ROW_CUSTOMTYPED)
        row_back = table.find_one({})

        # Both sides go through _repaint_NaNs before being compared.
        assert _repaint_NaNs(row_back) == _repaint_NaNs(FULL_AR_ROW_CUSTOMTYPED)

    @pytest.mark.describe("test of table write-and-read, with noncustom types, sync")
    def test_table_write_then_read_noncustomtypes_sync(
        self,
        sync_empty_table_all_returns: DefaultTable,
    ) -> None:
        """Insert the non-custom-typed row and read it back with custom types off."""
        serdes = SerdesOptions(custom_datatypes_in_reading=False)
        table = sync_empty_table_all_returns.with_options(
            api_options=APIOptions(serdes_options=serdes),
        )

        table.insert_one(FULL_AR_ROW_NONCUSTOMTYPED)
        row_back = table.find_one({})

        assert _repaint_NaNs(row_back) == _repaint_NaNs(FULL_AR_ROW_NONCUSTOMTYPED)
@pytest.mark.skipif(not CQL_AVAILABLE, reason="Not CQL session available")
class TestTableCQLDrivenDMLSync:
    """Tests reading, through the sync API, tables created and filled via CQL.

    Each test creates its schema through the CQL session, works on the table
    through ``sync_database`` and drops the schema in a ``finally`` clause.
    """

    @pytest.mark.describe(
        "test of reading from a CQL-driven table with a Counter, sync"
    )
    def test_table_cqldriven_counter_sync(
        self,
        cql_session: Session,
        sync_database: Database,
    ) -> None:
        """Read, then delete, a row from a table holding a COUNTER column."""
        try:
            cql_session.execute(CREATE_TABLE_COUNTER)
            for insert_statement in INSERTS_TABLE_COUNTER:
                cql_session.execute(insert_statement)
            time.sleep(1.5)  # delay for schema propagation

            table = sync_database.get_table(TABLE_NAME_COUNTER)
            # The definition is requested; its return value is not inspected.
            table.definition()
            row = table.find_one(filter=FILTER_COUNTER)
            assert row == EXPECTED_ROW_COUNTER
            table.delete_one(filter=FILTER_COUNTER)
            row = table.find_one(filter=FILTER_COUNTER)
            assert row is None
        finally:
            cql_session.execute(DROP_TABLE_COUNTER)

    @pytest.mark.describe(
        "test of reading from a CQL-driven table with limited-support columns, sync"
    )
    def test_table_cqldriven_lowsupport_sync(
        self,
        cql_session: Session,
        sync_database: Database,
    ) -> None:
        """Read, then delete, a row from a table with limited-support columns.

        Every projection in ``ILLEGAL_PROJECTIONS_LOWSUPPORT`` must be rejected
        with a ``DataAPIResponseException``; reading with
        ``PROJECTION_LOWSUPPORT`` must return ``EXPECTED_ROW_LOWSUPPORT``.
        """
        try:
            cql_session.execute(CREATE_TYPE_LOWSUPPORT)
            cql_session.execute(CREATE_TABLE_LOWSUPPORT)
            for insert_statement in INSERTS_TABLE_LOWSUPPORT:
                cql_session.execute(insert_statement)
            time.sleep(1.5)  # delay for schema propagation

            table = sync_database.get_table(TABLE_NAME_LOWSUPPORT)
            # The definition is requested; its return value is not inspected.
            table.definition()
            for ill_proj in ILLEGAL_PROJECTIONS_LOWSUPPORT:
                with pytest.raises(DataAPIResponseException):
                    # The projection under test must actually be sent: the
                    # loop variable was previously never used, so the same
                    # unprojected read was issued on every iteration.
                    table.find_one(filter=FILTER_LOWSUPPORT, projection=ill_proj)
            row = table.find_one(
                filter=FILTER_LOWSUPPORT, projection=PROJECTION_LOWSUPPORT
            )
            assert row == EXPECTED_ROW_LOWSUPPORT
            table.delete_one(filter=FILTER_LOWSUPPORT)
            row = table.find_one(
                filter=FILTER_LOWSUPPORT, projection=PROJECTION_LOWSUPPORT
            )
            assert row is None
        finally:
            # The table goes first, since it uses the type.
            cql_session.execute(DROP_TABLE_LOWSUPPORT)
            cql_session.execute(DROP_TYPE_LOWSUPPORT)
class TestAPIOptions:
    """Unit tests for how ``APIOptions`` overrides combine."""

    @pytest.mark.describe("test of header inheritance in APIOptions")
    def test_apioptions_headers(self) -> None:
        """One override must equal the matching chain of two overrides."""
        base_options = defaultAPIOptions(environment="dev")

        # Everything applied in a single step.
        one_step_override = APIOptions(
            database_additional_headers={"d": "y", "D": None},
            admin_additional_headers={"a": "y", "A": None},
            redacted_header_names={"x", "y"},
        )
        one_step = base_options.with_override(one_step_override)

        # Same target reached in two steps: "D"/"A" are set first and then
        # overridden with None, while the redacted names accumulate.
        first_override = APIOptions(
            database_additional_headers={"D": "y"},
            admin_additional_headers={"A": "y"},
            redacted_header_names={"x"},
        )
        intermediate = base_options.with_override(first_override)
        second_override = APIOptions(
            database_additional_headers={"d": "y", "D": None},
            admin_additional_headers={"a": "y", "A": None},
            redacted_header_names={"y"},
        )
        two_steps = intermediate.with_override(second_override)

        assert one_step == two_steps
class TestCollectionDecimalSupportUnit:
    """Unit tests of the collection write/read codec paths regarding Decimals."""

    @pytest.mark.describe(
        "test of decimals not supported by default in collection codec paths"
    )
    def test_decimalsupport_collection_codecpath_defaultsettings(self) -> None:
        """Decimal-unaware path: baseline round-trips, Decimals raise TypeError."""
        # write path with baseline
        preprocessed = preprocess_collection_payload(
            BASELINE_OBJ,
            options=S_OPTS_NO_DECS,
        )
        encoded = APICommander._decimal_unaware_encode_payload(preprocessed)
        assert encoded is not None

        # read path back to it
        parsed = APICommander._decimal_unaware_parse_json_response(encoded)
        round_tripped = postprocess_collection_response(
            parsed,
            options=S_OPTS_NO_DECS,
        )
        # this must match exactly (baseline)
        assert BASELINE_OBJ == round_tripped

        # write path with decimals should error instead
        # (both the preprocessing and the encoding sit inside the raises block)
        with pytest.raises(TypeError):
            preprocessed_decimals = preprocess_collection_payload(
                WDECS_OBJ,
                options=S_OPTS_NO_DECS,
            )
            APICommander._decimal_unaware_encode_payload(preprocessed_decimals)

    @pytest.mark.describe(
        "test of decimals supported in collection codec paths if set to do so"
    )
    def test_decimalsupport_collection_codecpath_decimalsettings(self) -> None:
        """Decimal-aware path: baseline and Decimal objects both round-trip."""
        # write path with baseline
        preprocessed = preprocess_collection_payload(
            BASELINE_OBJ,
            options=S_OPTS_OK_DECS,
        )
        encoded = APICommander._decimal_aware_encode_payload(preprocessed)
        assert encoded is not None

        # read path back to it
        parsed = APICommander._decimal_aware_parse_json_response(encoded)
        round_tripped = postprocess_collection_response(
            parsed,
            options=S_OPTS_OK_DECS,
        )
        # the re-read object must "be more-or-equally Decimal" than the source
        # but otherwise coincide
        assert is_decimal_super(round_tripped, BASELINE_OBJ)

        # write path with decimals
        preprocessed_decimals = preprocess_collection_payload(
            WDECS_OBJ,
            options=S_OPTS_OK_DECS,
        )
        encoded_decimals = APICommander._decimal_aware_encode_payload(
            preprocessed_decimals
        )
        assert encoded_decimals is not None

        # read path back to it
        parsed_decimals = APICommander._decimal_aware_parse_json_response(
            encoded_decimals
        )
        round_tripped_decimals = postprocess_collection_response(
            parsed_decimals,
            options=S_OPTS_OK_DECS,
        )
        # the re-read object must "be more-or-equally Decimal" than the source
        # but otherwise coincide
        assert is_decimal_super(round_tripped_decimals, WDECS_OBJ)
class TestVectorServiceOptions:
    """Unit tests for dict <-> ``VectorServiceOptions`` conversions.

    Each test parses a dict with ``_from_dict``, compares the result to an
    explicitly constructed object, and checks that ``as_dict`` gives the
    source dict back.
    """

    @pytest.mark.describe("test of VectorServiceOptions conversions, base")
    def test_VectorServiceOptions_conversions_base(self) -> None:
        """Provider and model name only."""
        source = {
            "provider": "PRO",
            "modelName": "MOD",
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
        )
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe("test of VectorServiceOptions conversions, with auth")
    def test_VectorServiceOptions_conversions_auth(self) -> None:
        """Provider, model name and an authentication section."""
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "authentication": {
                "type": ["A_T"],
                "field": "value",
            },
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            authentication={
                "type": ["A_T"],
                "field": "value",
            },
        )
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe("test of VectorServiceOptions conversions, with params")
    def test_VectorServiceOptions_conversions_params(self) -> None:
        """Provider, model name and a parameters section of mixed value types."""
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "parameters": {
                "field1": "value1",
                "field2": 12.3,
                "field3": 123,
                "field4": True,
                "field5": None,
                "field6": {"a": 1},
            },
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            parameters={
                "field1": "value1",
                "field2": 12.3,
                "field3": 123,
                "field4": True,
                "field5": None,
                "field6": {"a": 1},
            },
        )
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe(
        "test of VectorServiceOptions conversions, with params and auth"
    )
    def test_VectorServiceOptions_conversions_params_auth(self) -> None:
        """Provider, model name, authentication and parameters together."""
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "authentication": {
                "type": ["A_T"],
                "field": "value",
            },
            "parameters": {
                "field1": "value1",
                "field2": 12.3,
                "field3": 123,
                "field4": True,
                "field5": None,
                "field6": {"a": 1},
            },
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            authentication={
                "type": ["A_T"],
                "field": "value",
            },
            parameters={
                "field1": "value1",
                "field2": 12.3,
                "field3": 123,
                "field4": True,
                "field5": None,
                "field6": {"a": 1},
            },
        )
        assert parsed == expected
        assert parsed.as_dict() == source
class TestDataAPIMap:
    """Unit tests for ``DataAPIMap`` with hashable and non-hashable keys."""

    @pytest.mark.describe("test of table map usage with hashables")
    def test_dataapimap_hashables(self) -> None:
        """Check equality and mapping access with integer keys."""
        empty_map: DataAPIMap[int, int] = DataAPIMap()
        assert empty_map == DataAPIMap()
        assert dict(empty_map) == {}
        assert empty_map == {}

        # identity/equality
        pairs = [(1, "a"), (2, "b"), (3, "c")]
        filled = DataAPIMap(pairs)
        assert filled == DataAPIMap(pairs)
        assert filled == DataAPIMap(pairs + [(2, "b")])
        assert filled == DataAPIMap(pairs[2:] + pairs[:2])
        assert filled != DataAPIMap(pairs[1:] + [(1, "z")])
        assert filled != DataAPIMap(pairs[1:] + [(9, "z")])
        # the same comparisons against plain dicts, in both operand orders
        assert dict(pairs) == filled
        assert filled == dict(pairs)
        assert filled == dict(pairs + [(2, "b")])
        assert filled == dict(pairs[2:] + pairs[:2])
        assert filled != dict(pairs[1:] + [(1, "z")])
        assert filled != dict(pairs[1:] + [(9, "z")])

        # map operations
        assert list(filled.keys()) == [1, 2, 3]
        assert list(filled.values()) == ["a", "b", "c"]
        assert filled[2] == "b"
        with pytest.raises(KeyError):
            filled[4]
        assert list(filled.items()) == pairs
        assert len(filled) == len(pairs)
        assert list(filled) == [1, 2, 3]

    @pytest.mark.describe("test of table map usage with non-hashables")
    def test_dataapimap_nonhashables(self) -> None:
        """Check equality and mapping access with list (unhashable) keys."""
        empty_map: DataAPIMap[list[int], int] = DataAPIMap()
        assert empty_map == DataAPIMap()
        assert dict(empty_map) == {}

        # identity/equality
        pairs = [([1], "a"), ([2], "b"), ([3], "c")]
        filled = DataAPIMap(pairs)
        assert filled == DataAPIMap(pairs)
        assert filled == DataAPIMap(pairs + [([2], "b")])
        assert filled == DataAPIMap(pairs[2:] + pairs[:2])
        assert filled != DataAPIMap(pairs[1:] + [([1], "z")])
        assert filled != DataAPIMap(pairs[1:] + [([9], "z")])

        # map operations
        assert list(filled.keys()) == [[1], [2], [3]]
        assert list(filled.values()) == ["a", "b", "c"]
        assert filled[[2]] == "b"
        with pytest.raises(KeyError):
            filled[[4]]
        assert list(filled.items()) == pairs
        assert len(filled) == len(pairs)
        assert list(filled) == [[1], [2], [3]]

    @pytest.mark.describe("test pickling of DataAPIMap")
    def test_dataapimap_pickle(self) -> None:
        """A map must compare equal to its pickle round-trip."""
        original = DataAPIMap([("key1", 1), (None, 2), ("key3", {"a": 1})])
        restored = pickle.loads(pickle.dumps(original))
        assert restored == original
class TestDataAPISet:
    """Unit tests for ``DataAPISet`` with hashable and non-hashable items."""

    @pytest.mark.describe("test of table set usage with hashables")
    def test_dataapiset_hashables(self) -> None:
        """Check equality and set operators with integer items."""
        empty_set: DataAPISet[int] = DataAPISet()
        assert empty_set == DataAPISet()
        assert set(empty_set) == set()

        # identity/equality
        base = DataAPISet([1, 2, 3])
        assert base == DataAPISet([1, 2, 3])
        assert base == DataAPISet([1, 2, 3, 2])
        assert base != DataAPISet([1, 3, 2])
        assert set(base) == {1, 2, 3}

        # set operations
        assert base - {2} == DataAPISet([1, 3])
        assert base | {2} == DataAPISet([1, 2, 3])
        assert base | {4} == DataAPISet([1, 2, 3, 4])
        assert base | DataAPISet([4, 2, 5]) == DataAPISet([1, 2, 3, 4, 5])
        assert base ^ DataAPISet([4, 2, 5]) == DataAPISet([1, 3, 4, 5])
        assert base & DataAPISet([4, 2, 5]) == DataAPISet([2])

    @pytest.mark.describe("test of table set usage with non-hashables")
    def test_dataapiset_nonhashables(self) -> None:
        """Check equality and set operators with list (unhashable) items."""
        empty_set: DataAPISet[list[int]] = DataAPISet()
        assert empty_set == DataAPISet()

        # identity/equality
        base = DataAPISet([[1], [2], [3]])
        assert base == DataAPISet([[1], [2], [3]])
        assert base == DataAPISet([[1], [2], [3], [2]])
        assert base != DataAPISet([[1], [3], [2]])

        # set operations
        assert base - DataAPISet([[2]]) == DataAPISet([[1], [3]])
        assert base | DataAPISet([[2]]) == DataAPISet([[1], [2], [3]])
        assert base | DataAPISet([[4]]) == DataAPISet([[1], [2], [3], [4]])
        assert base | DataAPISet([[4], [2], [5]]) == DataAPISet(
            [[1], [2], [3], [4], [5]]
        )
        assert base ^ DataAPISet([[4], [2], [5]]) == DataAPISet([[1], [3], [4], [5]])
        assert base & DataAPISet([[4], [2], [5]]) == DataAPISet([[2]])

    @pytest.mark.describe("test pickling of DataAPISet")
    def test_dataapiset_pickle(self) -> None:
        """A set must compare equal to its pickle round-trip."""
        original = DataAPISet([1, 2, 3])
        restored = pickle.loads(pickle.dumps(original))
        assert restored == original


class TestDataAPITime:
    """Unit tests for ``DataAPITime``: parsing, conversions, ordering, pickling."""

    @pytest.mark.describe("test of time type, errors in parsing from string")
    def test_dataapitime_parse_errors(self) -> None:
        """Malformed or out-of-range strings raise; well-formed ones parse."""
        rejected_strings = (
            # empty, faulty, misformatted
            "",
            "boom",
            "12:34:56:21",
            "+12:34:56",
            "12:+34:56",
            "12:34:56.",
            "12:34:56X",
            "X12:34:56",
            "12:34X:56",
            # hour, minute and second just past their maximum
            "24:00:00",
            "00:60:00",
            "00:00:60",
        )
        for time_string in rejected_strings:
            with pytest.raises(ValueError):
                DataAPITime.from_string(time_string)

        accepted_strings = (
            "12:23",
            "00:00:00",
            "23:00:00",
            "00:59:00",
            "00:00:00.123456789",
            "00:00:00.123",
            "01:02:03.123456789",
            "01:02:03.123",
        )
        for time_string in accepted_strings:
            DataAPITime.from_string(time_string)

    @pytest.mark.describe("test of time type, lifecycle")
    def test_dataapitime_lifecycle(self) -> None:
        """Parse, convert to/from ``datetime.time``, stringify and compare."""
        # whole-second time
        whole = DataAPITime.from_string("02:03:04")
        assert whole == DataAPITime(hour=2, minute=3, second=4)
        assert whole == DataAPITime.from_string("2:3:4")
        assert whole == DataAPITime.from_string("0002:003:000004")
        py_whole = datetime.time(2, 3, 4)
        assert DataAPITime.from_time(py_whole) == whole
        assert whole.to_time() == py_whole
        repr(whole)
        str(whole)
        whole.to_string()

        # time with a fractional part
        fractional = DataAPITime.from_string("02:03:04.9876")
        assert fractional == DataAPITime(
            hour=2, minute=3, second=4, nanosecond=987600000
        )
        assert fractional == DataAPITime.from_string("2:3:4.9876")
        assert fractional == DataAPITime.from_string("0002:003:000004.987600000")
        py_fractional = datetime.time(2, 3, 4, 987600)
        assert DataAPITime.from_time(py_fractional) == fractional
        assert fractional.to_time() == py_fractional
        repr(fractional)
        str(fractional)
        fractional.to_string()

        # to_string/from_string round-trip for several nanosecond magnitudes;
        # all values are built first, then checked in the same order
        nanosecond_values = (123, 12, 12345, 123456, 12345678, 123456789)
        sample_times = [DataAPITime(1, 2, 3, nanos) for nanos in nanosecond_values]
        for sample_time in sample_times:
            assert DataAPITime.from_string(sample_time.to_string()) == sample_time

        # ordering, also mixed with datetime.time on either side
        earlier = DataAPITime(1, 2, 3, 30000)
        later = DataAPITime(1, 2, 3, 45000)
        py_earlier = earlier.to_time()
        py_later = later.to_time()
        assert earlier < later
        assert earlier <= later
        assert later > earlier
        assert later >= earlier
        assert py_earlier < later
        assert py_earlier <= later
        assert py_later > earlier
        assert py_later >= earlier
        assert earlier < py_later
        assert earlier <= py_later
        assert later > py_earlier
        assert later >= py_earlier

    @pytest.mark.describe("test pickling of DataAPITime")
    def test_dataapitime_pickle(self) -> None:
        """A time must compare equal to its pickle round-trip."""
        original = DataAPITime(12, 34, 56, 789)
        restored = pickle.loads(pickle.dumps(original))
        assert restored == original
14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.data.utils.extended_json_converters import ( 20 | convert_ejson_binary_object_to_bytes, 21 | convert_to_ejson_bytes, 22 | ) 23 | from astrapy.data_types import DataAPIVector 24 | from astrapy.data_types.data_api_vector import bytes_to_floats, floats_to_bytes 25 | 26 | COMPARE_EPSILON = 0.00001 27 | 28 | SAMPLE_FLOAT_LISTS: list[list[float]] = [ 29 | [10, 100, 1000, 0.1], 30 | [0.0043, 0.123, 1.332], 31 | [], 32 | [103.1, 104.5, 105.6], 33 | [(-1 if i % 2 == 0 else +1) * i / 5000 for i in range(4096)], 34 | ] 35 | 36 | 37 | def _nearly_equal_lists(list1: list[float], list2: list[float]) -> bool: 38 | if len(list1) != len(list2): 39 | return False 40 | if len(list1) == 0: 41 | return True 42 | return max(abs(x - y) for x, y in zip(list1, list2)) < COMPARE_EPSILON 43 | 44 | 45 | def _nearly_equal_vectors(vec1: DataAPIVector, vec2: DataAPIVector) -> bool: 46 | return _nearly_equal_lists(vec1.data, vec2.data) 47 | 48 | 49 | class TestDataAPIVector: 50 | @pytest.mark.describe("test of float-binary conversions") 51 | def test_dataapivector_byteconversions(self) -> None: 52 | for test_list0 in SAMPLE_FLOAT_LISTS: 53 | test_bytes0 = floats_to_bytes(test_list0) 54 | test_list1 = bytes_to_floats(test_bytes0) 55 | assert _nearly_equal_lists(test_list0, test_list1) 56 | 57 | @pytest.mark.describe("test of float-string conversions") 58 | def test_dataapivector_stringconversions(self) -> None: 59 | # known expectation 60 | obj_1 = {"$binary": "PczMzT5MzM0+mZma"} 61 | vec = DataAPIVector.from_bytes(convert_ejson_binary_object_to_bytes(obj_1)) 62 | vec_exp = DataAPIVector([0.1, 0.2, 0.3]) 63 | assert _nearly_equal_vectors(vec, vec_exp) 64 | # some full-round conversions 65 | for test_list0 in SAMPLE_FLOAT_LISTS: 66 | test_vec0 = DataAPIVector(test_list0) 67 | test_ejson = convert_to_ejson_bytes(test_vec0.to_bytes()) 68 | test_vec1 = DataAPIVector.from_bytes( 69 | 
convert_ejson_binary_object_to_bytes(test_ejson) 70 | ) 71 | assert _nearly_equal_vectors(test_vec0, test_vec1) 72 | 73 | @pytest.mark.describe("test of DataAPIVector lifecycle") 74 | def test_dataapivector_lifecycle(self) -> None: 75 | v0 = DataAPIVector([]) 76 | v1 = DataAPIVector([1.1, 2.2, 3.3]) 77 | for i in v0: 78 | pass 79 | for i in v1: 80 | pass 81 | assert list(v0) == [] 82 | assert list(v1) == [1.1, 2.2, 3.3] 83 | assert v0 != v1 84 | assert v1 == DataAPIVector([1.1, 2.2, 3.3]) 85 | 86 | # list-likeness 87 | assert v1[1:2] == DataAPIVector([2.2]) 88 | assert len(v1) == 3 89 | -------------------------------------------------------------------------------- /tests/base/unit/test_document_paths.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.utils.document_paths import ( 20 | ILLEGAL_ESCAPE_ERROR_MESSAGE_TEMPLATE, 21 | UNTERMINATED_ESCAPE_ERROR_MESSAGE_TEMPLATE, 22 | escape_field_names, 23 | unescape_field_path, 24 | ) 25 | 26 | ESCAPE_TEST_FIELDS = { 27 | "a": "a", 28 | "a.b": "a&.b", 29 | "a&c": "a&&c", 30 | "a..": "a&.&.", 31 | "": "", 32 | "xyz": "xyz", 33 | ".": "&.", 34 | "...": "&.&.&.", 35 | "&": "&&", 36 | "&&": "&&&&", 37 | "&.&.&..": "&&&.&&&.&&&.&.", 38 | "🇦🇨": "🇦🇨", 39 | "q🇦🇨w": "q🇦🇨w", 40 | ".🇦🇨&": "&.🇦🇨&&", 41 | "a&b.c.&.d": "a&&b&.c&.&&&.d", 42 | ".env": "&.env", 43 | "&quot;": "&&quot;", 44 | "ݤ": "ݤ", 45 | "Aݤ": "Aݤ", 46 | "ݤZ": "ݤZ", 47 | "🚾 🆒 🆓 🆕 🆖 🆗 🆙 🏧": "🚾 🆒 🆓 🆕 🆖 🆗 🆙 🏧", 48 | "✋🏿 💪🏿 👐🏿 🙌🏿 👏🏿 🙏🏿": "✋🏿 💪🏿 👐🏿 🙌🏿 👏🏿 🙏🏿", 49 | "👨‍👩‍👦 👨‍👩‍👧‍👦 👨‍👨‍👦 👩‍👩‍👧 👨‍👦 👨‍👧‍👦 👩‍👦 👩‍👧‍👦": "👨‍👩‍👦 👨‍👩‍👧‍👦 👨‍👨‍👦 👩‍👩‍👧 👨‍👦 👨‍👧‍👦 👩‍👦 👩‍👧‍👦", 50 | } 51 | 52 | 53 | class TestDocumentPaths: 54 | @pytest.mark.describe("test of escape_field_names, one arg") 55 | def test_escape_field_names_onearg(self) -> None: 56 | for lit_fn, esc_fn in ESCAPE_TEST_FIELDS.items(): 57 | assert escape_field_names(lit_fn) == esc_fn 58 | assert escape_field_names([lit_fn]) == esc_fn 59 | for num in [0, 12, 130099]: 60 | assert escape_field_names(num) == str(num) 61 | assert escape_field_names([num]) == str(num) 62 | 63 | @pytest.mark.describe("test of escape_field_names, multiple args") 64 | def test_escape_field_names_multiargs(self) -> None: 65 | all_lits, all_escs = list(zip(*ESCAPE_TEST_FIELDS.items())) 66 | assert escape_field_names(all_lits) == ".".join(all_escs) 67 | assert escape_field_names(*all_lits) == ".".join(all_escs) 68 | assert escape_field_names(all_lits[:0]) == ".".join(all_escs[:0]) 69 | assert escape_field_names(*all_lits[:0]) == ".".join(all_escs[:0]) 70 | assert escape_field_names(all_lits[:3]) == ".".join(all_escs[:3]) 71 | assert escape_field_names(*all_lits[:3]) == ".".join(all_escs[:3]) 72 | assert
escape_field_names(all_lits[3:6]) == ".".join(all_escs[3:6]) 73 | assert escape_field_names(*all_lits[3:6]) == ".".join(all_escs[3:6]) 74 | 75 | assert escape_field_names(["first", 12, "last&!."]) == "first.12.last&&!&." 76 | assert escape_field_names("first", 12, "last&!.") == "first.12.last&&!&." 77 | 78 | @pytest.mark.describe("test of unescape_field_path") 79 | def test_unescape_field_path(self) -> None: 80 | assert unescape_field_path("a&.b") == ["a.b"] 81 | 82 | all_lits, all_escs = list(zip(*ESCAPE_TEST_FIELDS.items())) 83 | assert unescape_field_path(".".join(all_escs)) == list(all_lits) 84 | assert unescape_field_path(".".join(all_escs[:3])) == list(all_lits[:3]) 85 | assert unescape_field_path(".".join(all_escs[3:6])) == list(all_lits[3:6]) 86 | 87 | assert unescape_field_path("") == [] 88 | 89 | with pytest.raises( 90 | ValueError, match=ILLEGAL_ESCAPE_ERROR_MESSAGE_TEMPLATE[:24] 91 | ): 92 | unescape_field_path("a.b&?c.d") 93 | 94 | with pytest.raises( 95 | ValueError, match=UNTERMINATED_ESCAPE_ERROR_MESSAGE_TEMPLATE[:24] 96 | ): 97 | unescape_field_path("a.b.c&") 98 | -------------------------------------------------------------------------------- /tests/base/unit/test_embeddingheadersprovider.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.authentication import ( 20 | AWSEmbeddingHeadersProvider, 21 | EmbeddingAPIKeyHeaderProvider, 22 | coerce_embedding_headers_provider, 23 | ) 24 | from astrapy.settings.defaults import ( 25 | EMBEDDING_HEADER_API_KEY, 26 | EMBEDDING_HEADER_AWS_ACCESS_ID, 27 | EMBEDDING_HEADER_AWS_SECRET_ID, 28 | ) 29 | 30 | 31 | class TestEmbeddingHeadersProvider: 32 | @pytest.mark.describe("test of headers from EmbeddingAPIKeyHeaderProvider") 33 | def test_embeddingheadersprovider_static(self) -> None: 34 | ehp = EmbeddingAPIKeyHeaderProvider("x") 35 | assert {k.lower(): v for k, v in ehp.get_headers().items()} == { 36 | EMBEDDING_HEADER_API_KEY.lower(): "x" 37 | } 38 | 39 | @pytest.mark.describe("test of headers from empty EmbeddingAPIKeyHeaderProvider") 40 | def test_embeddingheadersprovider_null(self) -> None: 41 | ehp = EmbeddingAPIKeyHeaderProvider(None) 42 | assert ehp.get_headers() == {} 43 | 44 | @pytest.mark.describe("test of headers from AWSEmbeddingHeadersProvider") 45 | def test_embeddingheadersprovider_aws(self) -> None: 46 | ehp = AWSEmbeddingHeadersProvider( 47 | embedding_access_id="x", 48 | embedding_secret_id="y", 49 | ) 50 | gen_headers_lower = {k.lower(): v for k, v in ehp.get_headers().items()} 51 | exp_headers_lower = { 52 | EMBEDDING_HEADER_AWS_ACCESS_ID.lower(): "x", 53 | EMBEDDING_HEADER_AWS_SECRET_ID.lower(): "y", 54 | } 55 | assert gen_headers_lower == exp_headers_lower 56 | 57 | @pytest.mark.describe("test of embedding headers provider coercion") 58 | def test_embeddingheadersprovider_coercion(self) -> None: 59 | """This doubles as equality test.""" 60 | ehp_s = EmbeddingAPIKeyHeaderProvider("x") 61 | ehp_n = EmbeddingAPIKeyHeaderProvider(None) 62 | ehp_a = AWSEmbeddingHeadersProvider( 63 | embedding_access_id="x", 64 | embedding_secret_id="y", 65 | ) 66 | assert coerce_embedding_headers_provider(ehp_s) == ehp_s 67 | assert coerce_embedding_headers_provider(ehp_n) 
== ehp_n 68 | assert coerce_embedding_headers_provider(ehp_a) == ehp_a 69 | 70 | assert coerce_embedding_headers_provider("x") == ehp_s 71 | -------------------------------------------------------------------------------- /tests/base/unit/test_findrerankingproviderresult.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.info import FindRerankingProvidersResult 20 | 21 | RESPONSE_DICT_0 = { 22 | "rerankingProviders": { 23 | "provider": { 24 | "isDefault": True, 25 | "displayName": "TheProvider", 26 | "supportedAuthentication": { 27 | "NONE": { 28 | "tokens": [], 29 | "enabled": True, 30 | }, 31 | }, 32 | "models": [ 33 | { 34 | "name": "provider/", 35 | "isDefault": True, 36 | "url": "https:///ranking", 37 | "properties": None, 38 | }, 39 | ], 40 | }, 41 | }, 42 | } 43 | 44 | RESPONSE_DICT_1 = { 45 | "rerankingProviders": { 46 | "provider": { 47 | "isDefault": True, 48 | "displayName": "TheProvider", 49 | "supportedAuthentication": { 50 | "NONE": { 51 | "tokens": [], 52 | "enabled": True, 53 | }, 54 | }, 55 | "models": [ 56 | { 57 | "name": "provider/", 58 | "apiModelSupport": { 59 | "status": "SUPPORTED", 60 | }, 61 | "isDefault": True, 62 | "url": "https:///ranking", 63 | "properties": None, 64 | }, 65 | ], 66 | }, 67 | }, 68 
| } 69 | 70 | 71 | class TestFindRerankingProvidersResult: 72 | @pytest.mark.describe("test of FindRerankingProvidersResult parsing and back") 73 | def test_reranking_providers_result_parsing_and_back_base(self) -> None: 74 | parsed = FindRerankingProvidersResult._from_dict(RESPONSE_DICT_0) 75 | providers = parsed.reranking_providers 76 | assert len(providers) == 1 77 | assert providers["provider"].display_name is not None 78 | 79 | models = providers["provider"].models 80 | assert len(models) == 1 81 | assert models[0].is_default 82 | 83 | dumped = parsed.as_dict() 84 | # clean out apiModelSupport from generated dict before checking 85 | for pro_v in dumped["rerankingProviders"].values(): 86 | for mod_v in pro_v["models"]: 87 | del mod_v["apiModelSupport"] 88 | assert dumped == RESPONSE_DICT_0 89 | 90 | @pytest.mark.describe( 91 | "test of FindRerankingProvidersResult parsing and back (with apiModelSupport)" 92 | ) 93 | def test_reranking_providers_result_parsing_and_back_rich(self) -> None: 94 | parsed = FindRerankingProvidersResult._from_dict(RESPONSE_DICT_1) 95 | providers = parsed.reranking_providers 96 | assert len(providers) == 1 97 | assert providers["provider"].display_name is not None 98 | 99 | models = providers["provider"].models 100 | assert len(models) == 1 101 | assert models[0].is_default 102 | 103 | dumped = parsed.as_dict() 104 | assert dumped == RESPONSE_DICT_1 105 | -------------------------------------------------------------------------------- /tests/base/unit/test_ids.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Unit tests for the ObjectIds and UUIDn conversions, 'idiomatic' imports 17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | import json 22 | 23 | import pytest 24 | 25 | from astrapy.data.utils.collection_converters import ( 26 | postprocess_collection_response, 27 | preprocess_collection_payload, 28 | ) 29 | from astrapy.ids import UUID, ObjectId 30 | from astrapy.utils.api_options import FullSerdesOptions 31 | 32 | 33 | @pytest.mark.describe("test of serdes for ids") 34 | def test_ids_serdes() -> None: 35 | f_u1 = UUID("8ccd6ff8-e61b-11ee-a2fc-7df4a8c4164b") # uuid1 36 | f_u3 = UUID("6fa459ea-ee8a-3ca4-894e-db77e160355e") # uuid3 37 | f_u4 = UUID("4f16cba8-1115-43ab-aa39-3a9c29f37db5") # uuid4 38 | f_u5 = UUID("886313e1-3b8a-5372-9b90-0c9aee199e5d") # uuid5 39 | f_u6 = UUID("1eee61b9-8f2d-69ad-8ebb-5054d2a1a2c0") # uuid6 40 | f_u7 = UUID("018e57e5-f586-7ed6-be55-6b0de3041116") # uuid7 41 | f_u8 = UUID("018e57e5-fbcd-8bd4-b794-be914f2c4c85") # uuid8 42 | f_oi = ObjectId("65f9cfa0d7fabb3f255c25a1") 43 | 44 | full_structure = { 45 | "f_u1": f_u1, 46 | "f_u3": f_u3, 47 | "f_u4": f_u4, 48 | "f_u5": f_u5, 49 | "f_u6": f_u6, 50 | "f_u7": f_u7, 51 | "f_u8": f_u8, 52 | "f_oi": f_oi, 53 | } 54 | 55 | normalized = preprocess_collection_payload( 56 | full_structure, 57 | options=FullSerdesOptions( 58 | binary_encode_vectors=True, 59 | custom_datatypes_in_reading=True, 60 | unroll_iterables_to_lists=True, 61 | use_decimals_in_collections=False, 62 | encode_maps_as_lists_in_tables="NEVER", 63 | 
accept_naive_datetimes=False, 64 | datetime_tzinfo=None, 65 | ), 66 | ) 67 | json.dumps(normalized) 68 | assert normalized is not None 69 | restored = postprocess_collection_response( 70 | normalized, 71 | options=FullSerdesOptions( 72 | binary_encode_vectors=True, 73 | custom_datatypes_in_reading=True, 74 | unroll_iterables_to_lists=True, 75 | use_decimals_in_collections=False, 76 | encode_maps_as_lists_in_tables="NEVER", 77 | accept_naive_datetimes=False, 78 | datetime_tzinfo=None, 79 | ), 80 | ) 81 | assert restored == full_structure 82 | -------------------------------------------------------------------------------- /tests/base/unit/test_info.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """ 16 | Unit tests for the parsing of API endpoints and related 17 | """ 18 | 19 | from __future__ import annotations 20 | 21 | import pytest 22 | 23 | from astrapy.admin import ParsedAPIEndpoint, parse_api_endpoint 24 | from astrapy.info import AstraDBAvailableRegionInfo 25 | 26 | 27 | @pytest.mark.describe("test of parsing API endpoints") 28 | def test_parse_api_endpoint() -> None: 29 | parsed_prod = parse_api_endpoint( 30 | "https://01234567-89ab-cdef-0123-456789abcdef-eu-west-1.apps.astra.datastax.com" 31 | ) 32 | assert isinstance(parsed_prod, ParsedAPIEndpoint) 33 | assert parsed_prod == ParsedAPIEndpoint( 34 | database_id="01234567-89ab-cdef-0123-456789abcdef", 35 | region="eu-west-1", 36 | environment="prod", 37 | ) 38 | 39 | parsed_prod = parse_api_endpoint( 40 | "https://a1234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com" 41 | ) 42 | assert isinstance(parsed_prod, ParsedAPIEndpoint) 43 | assert parsed_prod == ParsedAPIEndpoint( 44 | database_id="a1234567-89ab-cdef-0123-456789abcdef", 45 | region="us-central1", 46 | environment="dev", 47 | ) 48 | 49 | parsed_prod = parse_api_endpoint( 50 | "https://b1234567-89ab-cdef-0123-456789abcdef-eu-southwest-4.apps.astra-test.datastax.com/subpath?a=1" 51 | ) 52 | assert isinstance(parsed_prod, ParsedAPIEndpoint) 53 | assert parsed_prod == ParsedAPIEndpoint( 54 | database_id="b1234567-89ab-cdef-0123-456789abcdef", 55 | region="eu-southwest-4", 56 | environment="test", 57 | ) 58 | 59 | malformed_endpoints = [ 60 | "http://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com", 61 | "https://a909bdbf-q9ba-4e5e-893c-a859ed701407-us-central1.apps.astra-dev.datastax.com", 62 | "https://01234567-89ab-cdef-0123-456789abcdef-us-c_entral1.apps.astra-dev.datastax.com", 63 | "https://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-fake.datastax.com", 64 | "https://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax-staging.com", 65 | 
"ahttps://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com", 66 | ] 67 | 68 | for m_ep in malformed_endpoints: 69 | assert parse_api_endpoint(m_ep) is None 70 | 71 | 72 | @pytest.mark.describe("test of marshaling of available region info") 73 | def test_parse_availableregioninfo() -> None: 74 | region_dict = { 75 | "classification": "standard", 76 | "cloudProvider": "AWS", 77 | "displayName": "US East (Ohio)", 78 | "enabled": True, 79 | "name": "us-east-2", 80 | "region_type": "vector", 81 | "reservedForQualifiedUsers": False, 82 | "zone": "na", 83 | } 84 | assert AstraDBAvailableRegionInfo._from_dict(region_dict).as_dict() == region_dict 85 | -------------------------------------------------------------------------------- /tests/base/unit/test_multicalltimeoutmanager.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import annotations 16 | 17 | import time 18 | 19 | import pytest 20 | 21 | from astrapy.exceptions import ( 22 | DataAPITimeoutException, 23 | DevOpsAPITimeoutException, 24 | MultiCallTimeoutManager, 25 | ) 26 | 27 | 28 | class TestTimeouts: 29 | @pytest.mark.describe("test MultiCallTimeoutManager") 30 | def test_multicalltimeoutmanager(self) -> None: 31 | mgr_n = MultiCallTimeoutManager(overall_timeout_ms=None) 32 | assert mgr_n.remaining_timeout().request_ms is None 33 | time.sleep(0.5) 34 | assert mgr_n.remaining_timeout().request_ms is None 35 | 36 | mgr_1 = MultiCallTimeoutManager(overall_timeout_ms=1000) 37 | crt_1 = mgr_1.remaining_timeout().request_ms 38 | assert crt_1 is not None 39 | time.sleep(0.6) 40 | crt_2 = mgr_1.remaining_timeout().request_ms 41 | assert crt_2 is not None 42 | time.sleep(0.6) 43 | with pytest.raises(DataAPITimeoutException): 44 | mgr_1.remaining_timeout().request_ms 45 | 46 | @pytest.mark.describe("test MultiCallTimeoutManager DevOps") 47 | def test_multicalltimeoutmanager_devops(self) -> None: 48 | mgr_n = MultiCallTimeoutManager(overall_timeout_ms=None, dev_ops_api=True) 49 | assert mgr_n.remaining_timeout().request_ms is None 50 | time.sleep(0.5) 51 | assert mgr_n.remaining_timeout().request_ms is None 52 | 53 | mgr_1 = MultiCallTimeoutManager(overall_timeout_ms=1000, dev_ops_api=True) 54 | crt_1 = mgr_1.remaining_timeout().request_ms 55 | assert crt_1 is not None 56 | time.sleep(0.6) 57 | crt_2 = mgr_1.remaining_timeout().request_ms 58 | assert crt_2 is not None 59 | time.sleep(0.6) 60 | with pytest.raises(DevOpsAPITimeoutException): 61 | mgr_1.remaining_timeout().request_ms 62 | -------------------------------------------------------------------------------- /tests/base/unit/test_regionname_deprecation.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | from deprecation import DeprecatedWarning 19 | 20 | from astrapy.data.info.database_info import ( 21 | AstraDBAdminDatabaseRegionInfo, 22 | AstraDBAvailableRegionInfo, 23 | ) 24 | 25 | from ..conftest import is_future_version 26 | 27 | 28 | class TestRegionNameDeprecation: 29 | @pytest.mark.describe("test of region_name in AstraDBAdminDatabaseRegionInfo") 30 | def test_regionname_dbregioninfo(self) -> None: 31 | dc0 = { 32 | "capacityUnits": 1, 33 | "cloudProvider": "AWS", 34 | "dateCreated": "2025-04-10T19:14:44Z", 35 | "id": "...", 36 | "isPrimary": True, 37 | "name": "dc-1", 38 | "region": "us-east-2", 39 | "regionClassification": "standard", 40 | "regionZone": "na", 41 | "requestedNodeCount": 3, 42 | "secureBundleInternalUrl": "...", 43 | "secureBundleMigrationProxyInternalUrl": "...", 44 | "secureBundleMigrationProxyUrl": "...", 45 | "secureBundleUrl": "...", 46 | "status": "", 47 | "streamingTenant": {}, 48 | "targetAccount": "abcd0123", 49 | "tier": "serverless", 50 | } 51 | region_info = AstraDBAdminDatabaseRegionInfo( 52 | raw_datacenter_dict=dc0, 53 | environment="dev", 54 | database_id="D", 55 | ) 56 | 57 | r0 = region_info.name 58 | with pytest.warns(DeprecationWarning) as w_checker: 59 | r1 = region_info.region_name 60 | assert len(w_checker.list) == 1 61 | warning0 = w_checker.list[0].message 62 
| assert isinstance(warning0, DeprecatedWarning) 63 | assert is_future_version(warning0.removed_in) 64 | 65 | assert r0 == r1 66 | 67 | @pytest.mark.describe("test of region_name in AstraDBAvailableRegionInfo") 68 | def test_regionname_availableregion(self) -> None: 69 | ar0 = { 70 | "classification": "premium", 71 | "cloudProvider": "GCP", 72 | "displayName": "West Europe3 (Frankfurt, Germany)", 73 | "enabled": True, 74 | "name": "europe-west3", 75 | "region_type": "serverless", 76 | "reservedForQualifiedUsers": True, 77 | "zone": "emea", 78 | } 79 | available_region_info = AstraDBAvailableRegionInfo._from_dict(ar0) 80 | 81 | r0 = available_region_info.name 82 | with pytest.warns(DeprecationWarning) as w_checker: 83 | r1 = available_region_info.region_name 84 | 85 | assert len(w_checker.list) == 1 86 | warning0 = w_checker.list[0].message 87 | assert isinstance(warning0, DeprecatedWarning) 88 | assert is_future_version(warning0.removed_in) 89 | 90 | assert r0 == r1 91 | -------------------------------------------------------------------------------- /tests/base/unit/test_rerankingheadersprovider.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.authentication import ( 20 | RerankingAPIKeyHeaderProvider, 21 | coerce_reranking_headers_provider, 22 | ) 23 | from astrapy.settings.defaults import RERANKING_HEADER_API_KEY 24 | 25 | 26 | class TestRerankingHeadersProvider: 27 | @pytest.mark.describe("test of headers from RerankingAPIKeyHeaderProvider") 28 | def test_rerankingheadersprovider_static(self) -> None: 29 | ehp = RerankingAPIKeyHeaderProvider("x") 30 | assert {k.lower(): v for k, v in ehp.get_headers().items()} == { 31 | RERANKING_HEADER_API_KEY.lower(): "x" 32 | } 33 | 34 | @pytest.mark.describe("test of headers from empty RerankingAPIKeyHeaderProvider") 35 | def test_rerankingheadersprovider_null(self) -> None: 36 | ehp = RerankingAPIKeyHeaderProvider(None) 37 | assert ehp.get_headers() == {} 38 | 39 | @pytest.mark.describe("test of reranking headers provider coercion") 40 | def test_rerankingheadersprovider_coercion(self) -> None: 41 | """This doubles as equality test.""" 42 | ehp_s = RerankingAPIKeyHeaderProvider("x") 43 | ehp_n = RerankingAPIKeyHeaderProvider(None) 44 | assert coerce_reranking_headers_provider(ehp_s) == ehp_s 45 | assert coerce_reranking_headers_provider(ehp_n) == ehp_n 46 | 47 | assert coerce_reranking_headers_provider("x") == ehp_s 48 | -------------------------------------------------------------------------------- /tests/base/unit/test_strenum.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | import pytest 18 | 19 | from astrapy.utils.str_enum import StrEnum 20 | 21 | 22 | class TestEnum(StrEnum): 23 | VALUE = "value" 24 | VALUE_DASH = "value-dash" 25 | 26 | 27 | class TestStrEnum: 28 | def test_strenum_contains(self) -> None: 29 | assert "value" in TestEnum 30 | assert "value_dash" in TestEnum 31 | assert "value-dash" in TestEnum 32 | assert "VALUE-DASH" in TestEnum 33 | assert "pippo" not in TestEnum 34 | assert {6: 12} not in TestEnum 35 | 36 | def test_strenum_coerce(self) -> None: 37 | TestEnum.coerce("value") 38 | TestEnum.coerce("value_dash") 39 | TestEnum.coerce("value-dash") 40 | TestEnum.coerce("VALUE-DASH") 41 | TestEnum.coerce(TestEnum.VALUE) 42 | with pytest.raises(ValueError): 43 | TestEnum.coerce("pippo") 44 | -------------------------------------------------------------------------------- /tests/base/unit/test_table_decimal_support.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class TestTableDecimalSupportUnit:
    """Unit checks for decimal handling along the table encode/decode paths."""

    @staticmethod
    def _roundtrip_row(
        agent: _TableConverterAgent[DefaultRowType],
        row: dict[str, Any],
        columns: dict[str, Any],
    ) -> Any:
        """
        Push `row` through the write path and back through the read path.

        The row is preprocessed by `agent`, encoded with the APICommander
        decimal-aware encoder, parsed back with the matching decimal-aware
        parser and finally postprocessed against the `columns` schema.
        """
        wire_text = APICommander._decimal_aware_encode_payload(
            agent.preprocess_payload(row, map2tuple_checker=None)
        )
        return agent.postprocess_row(
            APICommander._decimal_aware_parse_json_response(wire_text),  # type: ignore[arg-type]
            columns_dict=columns,
            similarity_pseudocolumn=None,
        )

    @pytest.mark.describe("test of decimal-related conversions in table codec paths")
    def test_decimalsupport_table_codecpath(self) -> None:
        """Round-trip full rows, and decode keys, with the default prod serdes options."""
        serdes = defaultAPIOptions(environment="prod").serdes_options
        agent: _TableConverterAgent[DefaultRowType] = _TableConverterAgent(
            options=serdes,
        )

        # full rows: encode then decode and check (baseline first, then with-decimals)
        for source_row, schema in (
            (BASELINE_OBJ, BASELINE_COLUMNS),
            (WDECS_OBJ, WDECS_OBJ_COLUMNS),
        ):
            decoded_row = self._roundtrip_row(agent, source_row, schema)
            assert _repaint_NaNs(decoded_row) == _repaint_NaNs(source_row)

        # keys: decode only (baseline first, then with-decimals)
        for key_text, expected_row, schema in (
            (BASELINE_KEY_STR, BASELINE_OBJ, BASELINE_COLUMNS),
            (WDECS_KEY_STR, WDECS_OBJ, WDECS_OBJ_COLUMNS),
        ):
            # postprocess_key returns a pair; the second item is the one compared
            decoded_key = agent.postprocess_key(
                APICommander._decimal_aware_parse_json_response(key_text),  # type: ignore[arg-type]
                primary_key_schema_dict=schema,
            )[1]
            assert _repaint_NaNs(decoded_key) == _repaint_NaNs(expected_row)

    @pytest.mark.parametrize(
        ("colltype_custom", "colltype_obj", "colltype_columns"),
        COLLTYPES_CUSTOM_OBJ_SCHEMA_TRIPLES,
        ids=COLLTYPES_TRIPLE_IDS,
    )
    @pytest.mark.describe(
        "test of decimal-related conversions in table codec paths, collection types"
    )
    def test_decimalsupport_table_collectiontypes_codecpath(
        self,
        colltype_custom: bool,
        colltype_obj: dict[str, Any],
        colltype_columns: dict[str, Any],
    ) -> None:
        """Round-trip a collection-typed row, with custom datatypes on or off in reading."""
        serdes = FullSerdesOptions(
            binary_encode_vectors=True,
            custom_datatypes_in_reading=colltype_custom,
            unroll_iterables_to_lists=True,
            use_decimals_in_collections=True,
            encode_maps_as_lists_in_tables="never",
            accept_naive_datetimes=False,
            datetime_tzinfo=None,
        )
        agent: _TableConverterAgent[DefaultRowType] = _TableConverterAgent(
            options=serdes,
        )

        decoded_row = self._roundtrip_row(agent, colltype_obj, colltype_columns)
        assert _repaint_NaNs(decoded_row) == _repaint_NaNs(colltype_obj)
@pytest.mark.describe("test of token provider inheritance yield")
def test_token_provider_inheritance_yield() -> None:
    """
    Check which provider results from combining two token providers.

    Every pairing of a static, a null and a username/password provider is
    combined, first with the `|` operator and then with the boolean `or`.
    The expected result is the left operand, unless that is the null
    provider, in which case it is the right operand.
    """
    static_tp = StaticTokenProvider("AstraCS:xyz")
    null_tp = StaticTokenProvider(None)
    up_tp = UsernamePasswordTokenProvider("cassandraA", "cassandraB")

    # `|` binds tighter than `==`: these already compare the combined provider.
    assert static_tp | static_tp == static_tp
    assert static_tp | null_tp == static_tp
    assert static_tp | up_tp == static_tp

    assert null_tp | static_tp == static_tp
    assert null_tp | null_tp == null_tp
    assert null_tp | up_tp == up_tp

    assert up_tp | static_tp == up_tp
    assert up_tp | null_tp == up_tp
    assert up_tp | up_tp == up_tp

    # `or` binds looser than `==`: unparenthesized, `a or b == c` is parsed as
    # `a or (b == c)` and passes whenever `a` is truthy, checking nothing.
    # NOTE(review): the parenthesized form assumes a null provider is falsy
    # (i.e. the provider base class defines its truthiness) -- confirm
    # against astrapy.authentication.
    assert (static_tp or static_tp) == static_tp
    assert (static_tp or null_tp) == static_tp
    assert (static_tp or up_tp) == static_tp

    assert (null_tp or static_tp) == static_tp
    assert (null_tp or null_tp) == null_tp
    assert (null_tp or up_tp) == up_tp

    assert (up_tp or static_tp) == up_tp
    assert (up_tp or null_tp) == up_tp
    assert (up_tp or up_tp) == up_tp
interval: 15s 27 | timeout: 10s 28 | retries: 10 29 | 30 | data-api: 31 | image: stargateio/data-api:v1.0.25 32 | depends_on: 33 | coordinator: 34 | condition: service_healthy 35 | networks: 36 | - stargate 37 | ports: 38 | - "8181:8181" 39 | mem_limit: 2G 40 | environment: 41 | #- QUARKUS_GRPC_CLIENTS_BRIDGE_HOST=coordinator 42 | #- QUARKUS_GRPC_CLIENTS_BRIDGE_PORT=8091 43 | - STARGATE_DATA_STORE_SAI_ENABLED=true 44 | - STARGATE_DATA_STORE_VECTOR_SEARCH_ENABLED=true 45 | - STARGATE_JSONAPI_OPERATIONS_VECTORIZE_ENABLED=true 46 | - STARGATE_FEATURE_FLAGS_TABLES=true 47 | - STARGATE_DATA_STORE_IGNORE_BRIDGE=true 48 | - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_CASSANDRA_END_POINTS=coordinator 49 | - QUARKUS_HTTP_ACCESS_LOG_ENABLED=FALSE 50 | - QUARKUS_LOG_LEVEL=INFO 51 | - JAVA_MAX_MEM_RATIO=75 52 | - JAVA_INITIAL_MEM_RATIO=50 53 | - GC_CONTAINER_OPTIONS=-XX:+UseG1GC 54 | - JAVA_OPTS_APPEND=-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=org.jboss.logmanager.LogManager 55 | healthcheck: 56 | test: curl -f http://localhost:8181/stargate/health || exit 1 57 | interval: 5s 58 | timeout: 10s 59 | retries: 10 60 | 61 | networks: 62 | stargate: -------------------------------------------------------------------------------- /tests/env_templates/env.astra.admin.template: -------------------------------------------------------------------------------- 1 | ######################## 2 | # FOR THE ADMIN TESTS: # 3 | ######################## 4 | 5 | # PROD settings 6 | export PROD_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." 7 | export PROD_ADMIN_TEST_ASTRA_DB_PROVIDER="aws" 8 | export PROD_ADMIN_TEST_ASTRA_DB_REGION="eu-west-1" 9 | 10 | # DEV settings (optional) 11 | export DEV_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." 12 | export DEV_ADMIN_TEST_ASTRA_DB_PROVIDER="aws" 13 | export DEV_ADMIN_TEST_ASTRA_DB_REGION="us-west-2" 14 | 15 | # TEST settings (optional) 16 | export TEST_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." 
17 | export TEST_ADMIN_TEST_ASTRA_DB_PROVIDER="aws" 18 | export TEST_ADMIN_TEST_ASTRA_DB_REGION="us-west-2" 19 | -------------------------------------------------------------------------------- /tests/env_templates/env.astra.template: -------------------------------------------------------------------------------- 1 | ############################## 2 | # FOR THE TESTS ON ASTRA DB: # 3 | ############################## 4 | 5 | export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..." 6 | 7 | export ASTRA_DB_API_ENDPOINT="https://<DATABASE_ID>-<REGION>.apps.astra.datastax.com" 8 | 9 | # OPTIONAL (the first has a default; a few tests are skipped if the second is missing): 10 | # export ASTRA_DB_KEYSPACE="..." 11 | # export ASTRA_DB_SECONDARY_KEYSPACE="..." 12 | -------------------------------------------------------------------------------- /tests/env_templates/env.local.template: -------------------------------------------------------------------------------- 1 | ######################################## 2 | # FOR THE TESTS WITH A LOCAL DATA API: # 3 | ######################################## 4 | 5 | # Authentication: 6 | export LOCAL_DATA_API_USERNAME="cassandra" 7 | export LOCAL_DATA_API_PASSWORD="cassandra" 8 | 9 | export LOCAL_DATA_API_ENDPOINT="http://localhost:8181" 10 | export LOCAL_CASSANDRA_CONTACT_POINT="127.0.0.1" 11 | 12 | # OPTIONAL: (if defined here, they will be created as needed) 13 | export LOCAL_DATA_API_KEYSPACE="default_keyspace" 14 | export LOCAL_DATA_API_SECONDARY_KEYSPACE="alternate_keyspace" 15 | 16 | # RERANKER SETTINGS 17 | export ASTRAPY_FINDANDRERANK_USE_RERANKER_HEADER="y" 18 | # Local runs require a DEV token as reranker API Key: 19 | export HEADER_RERANKING_API_KEY_NVIDIA="AstraCS:[...]" 20 | -------------------------------------------------------------------------------- /tests/env_templates/env.testcontainers.template: -------------------------------------------------------------------------------- 1 | ###################################### 2 | # FOR THE TESTS WITH
TESTCONTAINERS: # 3 | ###################################### 4 | 5 | export DOCKER_COMPOSE_LOCAL_DATA_API="yes" 6 | -------------------------------------------------------------------------------- /tests/env_templates/env.vectorize-minimal.template: -------------------------------------------------------------------------------- 1 | #################################### 2 | # FOR THE VECTORIZE TESTS IN BASE: # 3 | #################################### 4 | 5 | export HEADER_EMBEDDING_API_KEY_OPENAI="..." 6 | -------------------------------------------------------------------------------- /tests/env_templates/env.vectorize.template: -------------------------------------------------------------------------------- 1 | ##################################### 2 | # FOR THE EXTENDED VECTORIZE TESTS: # 3 | ##################################### 4 | 5 | 6 | export HEADER_EMBEDDING_API_KEY_HUGGINGFACE="..." 7 | 8 | export HEADER_EMBEDDING_API_KEY_COHERE="..." 9 | 10 | export HEADER_EMBEDDING_API_KEY_VOYAGEAI="..." 11 | 12 | export HEADER_EMBEDDING_API_KEY_MISTRAL="..." 13 | 14 | export HEADER_EMBEDDING_API_KEY_UPSTAGE="..." 15 | 16 | export HEADER_EMBEDDING_API_KEY_OPENAI="..." 17 | export OPENAI_ORGANIZATION_ID="..." 18 | export OPENAI_PROJECT_ID="..." 19 | 20 | export HEADER_EMBEDDING_API_KEY_AZURE_OPENAI="..." 21 | export AZURE_OPENAI_DEPLOY_ID_EMB3LARGE="..." 22 | export AZURE_OPENAI_RESNAME_EMB3LARGE="..." 23 | export AZURE_OPENAI_DEPLOY_ID_EMB3SMALL="..." 24 | export AZURE_OPENAI_RESNAME_EMB3SMALL="..." 25 | export AZURE_OPENAI_DEPLOY_ID_ADA2="..." 26 | export AZURE_OPENAI_RESNAME_ADA2="..." 27 | 28 | export HEADER_EMBEDDING_API_KEY_JINAAI="..." 29 | 30 | export HEADER_EMBEDDING_API_KEY_VERTEXAI="..." 31 | 32 | export HEADER_EMBEDDING_VERTEXAI_PROJECT_ID="..." 33 | 34 | export HEADER_EMBEDDING_API_KEY_HUGGINGFACEDED="..." 35 | export HUGGINGFACEDED_DIMENSION="..." 36 | export HUGGINGFACEDED_ENDPOINTNAME="..." 37 | export HUGGINGFACEDED_REGIONNAME="..." 
38 | export HUGGINGFACEDED_CLOUDNAME="..." 39 | 40 | # Additional SHARED_SECRET testing information 41 | # (Preparation to be done in the UI at the moment) 42 | # Scope these secrets, with these names, to the targeted DB(s): 43 | # SHARED_SECRET_EMBEDDING_API_KEY_AZURE_OPENAI 44 | # SHARED_SECRET_EMBEDDING_API_KEY_HUGGINGFACE 45 | # SHARED_SECRET_EMBEDDING_API_KEY_HUGGINGFACEDED 46 | # SHARED_SECRET_EMBEDDING_API_KEY_JINAAI 47 | # SHARED_SECRET_EMBEDDING_API_KEY_MISTRAL 48 | # SHARED_SECRET_EMBEDDING_API_KEY_OPENAI 49 | # SHARED_SECRET_EMBEDDING_API_KEY_UPSTAGE 50 | # SHARED_SECRET_EMBEDDING_API_KEY_VOYAGEAI 51 | -------------------------------------------------------------------------------- /tests/hcd_compose/docker-compose.yml: -------------------------------------------------------------------------------- 1 | services: 2 | hcd: 3 | image: 559669398656.dkr.ecr.us-west-2.amazonaws.com/engops-shared/hcd/staging/hcd:1.2.1-early-preview 4 | networks: 5 | - stargate 6 | mem_limit: 2G 7 | environment: 8 | - MAX_HEAP_SIZE=1536M 9 | - CLUSTER_NAME=hcd-1.2.1-early-preview-cluster 10 | - DS_LICENSE=accept 11 | - HCD_AUTO_CONF_OFF=cassandra.yaml 12 | volumes: 13 | - ./cassandra-hcd.yaml:/opt/hcd/resources/cassandra/conf/cassandra.yaml:rw 14 | ports: 15 | - "9042:9042" 16 | healthcheck: 17 | test: [ "CMD-SHELL", "cqlsh -u cassandra -p cassandra -e 'describe keyspaces'" ] 18 | interval: 15s 19 | timeout: 10s 20 | retries: 20 21 | 22 | data-api: 23 | image: stargateio/data-api:v1.0.25 24 | depends_on: 25 | hcd: 26 | condition: service_healthy 27 | networks: 28 | - stargate 29 | ports: 30 | - "8181:8181" 31 | mem_limit: 2G 32 | environment: 33 | - JAVA_MAX_MEM_RATIO=75 34 | - JAVA_INITIAL_MEM_RATIO=50 35 | - STARGATE_DATA_STORE_IGNORE_BRIDGE=true 36 | - GC_CONTAINER_OPTIONS=-XX:+UseG1GC 37 | - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_CASSANDRA_END_POINTS=hcd 38 | - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_LOCAL_DATACENTER=dc1 39 | - QUARKUS_HTTP_ACCESS_LOG_ENABLED=FALSE 
40 | - QUARKUS_LOG_LEVEL=INFO 41 | - STARGATE_JSONAPI_OPERATIONS_VECTORIZE_ENABLED=true 42 | - STARGATE_FEATURE_FLAGS_TABLES=true 43 | - STARGATE_FEATURE_FLAGS_RERANKING=true 44 | - JAVA_OPTS_APPEND=-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=org.jboss.logmanager.LogManager 45 | healthcheck: 46 | test: curl -f http://localhost:8181/stargate/health || exit 1 47 | interval: 5s 48 | timeout: 10s 49 | retries: 10 50 | networks: 51 | stargate: 52 | -------------------------------------------------------------------------------- /tests/vectorize/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/vectorize/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | 17 | from ..conftest import ( 18 | IS_ASTRA_DB, 19 | ) 20 | 21 | __all__ = [ 22 | "IS_ASTRA_DB", 23 | ] 24 | -------------------------------------------------------------------------------- /tests/vectorize/integration/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from __future__ import annotations 16 | -------------------------------------------------------------------------------- /tests/vectorize/live_provider_info.py: -------------------------------------------------------------------------------- 1 | # Copyright DataStax, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def live_provider_info() -> FindEmbeddingProvidersResult:
    """
    Query the Data API `findEmbeddingProviders` endpoint
    for the latest information.

    This utility function uses the environment variables it can find
    to establish a target database to query: the Astra DB settings when
    `IS_ASTRA_DB` is set, the local Data API settings otherwise.

    Returns:
        the result of `find_embedding_providers` as obtained from the
        database admin of the target database.

    Raises:
        ValueError: if targeting Astra DB and the API endpoint cannot be parsed.
    """

    database: Database
    if IS_ASTRA_DB:
        parsed = parse_api_endpoint(ASTRA_DB_API_ENDPOINT)
        if parsed is None:
            # f-string, so that the offending endpoint actually shows in the message
            raise ValueError(
                f"Cannot parse the Astra DB API Endpoint '{ASTRA_DB_API_ENDPOINT}'"
            )
        # the client environment is the one encoded in the endpoint itself
        client = DataAPIClient(environment=parsed.environment)
        database = client.get_database(
            ASTRA_DB_API_ENDPOINT,
            token=ASTRA_DB_TOKEN_PROVIDER,
            keyspace=ASTRA_DB_KEYSPACE,
        )
    else:
        client = DataAPIClient(environment=Environment.OTHER)
        database = client.get_database(
            LOCAL_DATA_API_ENDPOINT,
            token=LOCAL_DATA_API_TOKEN_PROVIDER,
            keyspace=LOCAL_DATA_API_KEYSPACE,
        )

    database_admin = database.get_database_admin()
    response = database_admin.find_embedding_providers()
    return response
def desc_param(param_data: EmbeddingProviderParameter) -> str:
    """
    Return a short human-readable description of a provider/model parameter.

    Args:
        param_data: the parameter to describe. Its `parameter_type`
            (compared case-insensitively), `validation` and `default_value`
            attributes are read.

    Returns:
        "str" or "bool" for string/boolean parameters; for numbers, a string
        starting with "number, " followed by either the allowed range (with
        the default value, if any) or the list of allowed options.

    Raises:
        ValueError: for a number whose validation has neither a
            "numericRange" nor an "options" entry.
        NotImplementedError: for any other parameter type.
    """
    param_type = param_data.parameter_type.lower()
    if param_type == "string":
        return "str"
    if param_type == "boolean":
        return "bool"
    if param_type != "number":
        raise NotImplementedError(
            f"Unsupported parameter type: '{param_data.parameter_type}'"
        )

    validation = param_data.validation
    if "numericRange" in validation:
        validation_nr = validation["numericRange"]
        assert isinstance(validation_nr, list) and len(validation_nr) == 2
        range_desc = f"[{validation_nr[0]} : {validation_nr[1]}]"
        if param_data.default_value is not None:
            range_desc = f"{range_desc} (default={param_data.default_value})"
        return f"number, {range_desc}"
    if "options" in validation:
        validation_op = validation["options"]
        assert isinstance(validation_op, list) and len(validation_op) > 1
        return f"number, {' / '.join(str(v) for v in validation_op)}"
    raise ValueError(f"Unknown number validation spec: '{json.dumps(validation)}'")


if __name__ == "__main__":
    # Fetch the live provider listing and keep a raw copy on disk.
    provider_info: FindEmbeddingProvidersResult = live_provider_info()
    providers_json = (provider_info.raw_info or {}).get("embeddingProviders")
    if not providers_json:
        raise ValueError(
            "raw info from embedding providers lacks `embeddingProviders` content."
        )
    # context manager: the handle is closed (and the dump flushed) right away
    with open("_providers.json", "w", encoding="utf-8") as providers_file:
        json.dump(providers_json, providers_file, indent=2, sort_keys=True)

    # Per-provider report: enabled auth methods, parameters, models.
    for provider, provider_data in sorted(provider_info.embedding_providers.items()):
        print(f"{provider} ({len(provider_data.models)} models)")
        print(" auth:")
        for auth_type, auth_data in sorted(
            provider_data.supported_authentication.items()
        ):
            if auth_data.enabled:
                tokens = ", ".join(f"'{tok.accepted}'" for tok in auth_data.tokens)
                print(f" {auth_type} ({tokens})")
        if provider_data.parameters:
            print(" parameters")
            for param_data in provider_data.parameters:
                param_name = param_data.name
                # optional parameters are shown in parentheses
                if param_data.required:
                    param_display_name = param_name
                else:
                    param_display_name = f"({param_name})"
                param_desc = desc_param(param_data)
                print(f" - {param_display_name}: {param_desc}")
        print(" models:")
        for model_data in sorted(provider_data.models, key=lambda pro: pro.name):
            model_name = model_data.name
            if model_data.vector_dimension is not None:
                assert model_data.vector_dimension > 0
                model_dim_desc = f" (D = {model_data.vector_dimension})"
            else:
                model_dim_desc = ""
            print(f" {model_name}{model_dim_desc}")
            if model_data.parameters:
                for param_data in model_data.parameters:
                    param_name = param_data.name
                    if param_data.required:
                        param_display_name = param_name
                    else:
                        param_display_name = f"({param_name})"
                    param_desc = desc_param(param_data)
                    print(f" - {param_display_name}: {param_desc}")

    # Test-model tags, grouped by authentication type.
    print("\n" * 2)
    all_test_models = list(live_test_models())
    for auth_type in ["HEADER", "NONE", "SHARED_SECRET"]:
        print(f"Tags for auth type {auth_type}:", end="")
        at_test_models = [
            test_model
            for test_model in all_test_models
            if test_model["auth_type_name"] == auth_type
        ]
        at_model_ids: list[str] = sorted(
            [str(model_desc["model_tag"]) for model_desc in at_test_models]
        )
        if at_model_ids:
            print("")
            print("\n".join(f" {ami}" for ami in at_model_ids))
        else:
            print(" (no tags)")