├── .github
└── workflows
│ ├── lint.yml
│ ├── local.yaml
│ ├── main.yml
│ └── unit.yaml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGES
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── Makefile
├── README.md
├── astrapy
├── __init__.py
├── admin
│ ├── __init__.py
│ ├── admin.py
│ └── endpoints.py
├── api_options.py
├── authentication.py
├── client.py
├── collection.py
├── constants.py
├── cursors.py
├── data
│ ├── __init__.py
│ ├── collection.py
│ ├── cursors
│ │ ├── __init__.py
│ │ ├── cursor.py
│ │ ├── farr_cursor.py
│ │ ├── find_cursor.py
│ │ ├── query_engine.py
│ │ └── reranked_result.py
│ ├── database.py
│ ├── info
│ │ ├── collection_descriptor.py
│ │ ├── database_info.py
│ │ ├── reranking.py
│ │ ├── table_descriptor
│ │ │ ├── table_altering.py
│ │ │ ├── table_columns.py
│ │ │ ├── table_creation.py
│ │ │ ├── table_indexes.py
│ │ │ └── table_listing.py
│ │ └── vectorize.py
│ ├── table.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── collection_converters.py
│ │ ├── distinct_extractors.py
│ │ ├── extended_json_converters.py
│ │ ├── table_converters.py
│ │ ├── table_types.py
│ │ └── vector_coercion.py
├── data_types
│ ├── __init__.py
│ ├── data_api_date.py
│ ├── data_api_duration.py
│ ├── data_api_map.py
│ ├── data_api_set.py
│ ├── data_api_time.py
│ ├── data_api_timestamp.py
│ └── data_api_vector.py
├── database.py
├── exceptions
│ ├── __init__.py
│ ├── collection_exceptions.py
│ ├── data_api_exceptions.py
│ ├── devops_api_exceptions.py
│ └── table_exceptions.py
├── ids.py
├── info.py
├── py.typed
├── results.py
├── settings
│ ├── __init__.py
│ ├── defaults.py
│ └── error_messages.py
├── table.py
└── utils
│ ├── __init__.py
│ ├── api_commander.py
│ ├── api_options.py
│ ├── date_utils.py
│ ├── document_paths.py
│ ├── duration_c_utils.py
│ ├── duration_std_utils.py
│ ├── meta.py
│ ├── parsing.py
│ ├── request_tools.py
│ ├── str_enum.py
│ ├── unset.py
│ └── user_agents.py
├── pictures
├── astrapy_abstractions.png
├── astrapy_datetime_serdes_options.png
└── astrapy_exceptions.png
├── pyproject.toml
├── tests
├── __init__.py
├── admin
│ ├── __init__.py
│ ├── conftest.py
│ └── integration
│ │ ├── __init__.py
│ │ ├── test_admin.py
│ │ └── test_nonastra_admin.py
├── base
│ ├── __init__.py
│ ├── collection_decimal_support_assets.py
│ ├── conftest.py
│ ├── integration
│ │ ├── __init__.py
│ │ ├── collection_decimal_support_assets.py
│ │ ├── collections
│ │ │ ├── __init__.py
│ │ │ ├── test_collection_cursor_async.py
│ │ │ ├── test_collection_cursor_sync.py
│ │ │ ├── test_collection_ddl_async.py
│ │ │ ├── test_collection_ddl_sync.py
│ │ │ ├── test_collection_decimal_support.py
│ │ │ ├── test_collection_dml_async.py
│ │ │ ├── test_collection_dml_sync.py
│ │ │ ├── test_collection_exceptions_async.py
│ │ │ ├── test_collection_exceptions_sync.py
│ │ │ ├── test_collection_farr_async.py
│ │ │ ├── test_collection_farr_sync.py
│ │ │ ├── test_collection_farrcursor_async.py
│ │ │ ├── test_collection_farrcursor_sync.py
│ │ │ ├── test_collection_timeout_async.py
│ │ │ ├── test_collection_timeout_sync.py
│ │ │ ├── test_collection_typing.py
│ │ │ ├── test_collection_vectorize_methods_async.py
│ │ │ └── test_collection_vectorize_methods_sync.py
│ │ ├── conftest.py
│ │ ├── misc
│ │ │ ├── __init__.py
│ │ │ ├── test_admin_ops_async.py
│ │ │ ├── test_admin_ops_sync.py
│ │ │ ├── test_reranking_ops_async.py
│ │ │ ├── test_reranking_ops_sync.py
│ │ │ ├── test_vectorize_ops_async.py
│ │ │ └── test_vectorize_ops_sync.py
│ │ └── tables
│ │ │ ├── __init__.py
│ │ │ ├── table_cql_assets.py
│ │ │ ├── table_row_assets.py
│ │ │ ├── test_table_column_types_async.py
│ │ │ ├── test_table_column_types_sync.py
│ │ │ ├── test_table_cqldriven_dml_sync.py
│ │ │ ├── test_table_cursor_async.py
│ │ │ ├── test_table_cursor_sync.py
│ │ │ ├── test_table_dml_async.py
│ │ │ ├── test_table_dml_sync.py
│ │ │ ├── test_table_lifecycle_async.py
│ │ │ ├── test_table_lifecycle_sync.py
│ │ │ ├── test_table_mapsastuples_async.py
│ │ │ ├── test_table_mapsastuples_sync.py
│ │ │ ├── test_table_typing.py
│ │ │ ├── test_table_vectorize_async.py
│ │ │ └── test_table_vectorize_sync.py
│ ├── table_decimal_support_assets.py
│ ├── table_structure_assets.py
│ └── unit
│ │ ├── __init__.py
│ │ ├── test_admin_conversions.py
│ │ ├── test_apicommander.py
│ │ ├── test_apioptions.py
│ │ ├── test_collection_decimal_support.py
│ │ ├── test_collection_options.py
│ │ ├── test_collection_timeouts.py
│ │ ├── test_collections_async.py
│ │ ├── test_collections_sync.py
│ │ ├── test_collectionvectorserviceoptions.py
│ │ ├── test_dataapidate.py
│ │ ├── test_dataapiduration.py
│ │ ├── test_dataapimap.py
│ │ ├── test_dataapiset.py
│ │ ├── test_dataapitime.py
│ │ ├── test_dataapitimestamp.py
│ │ ├── test_dataapivector.py
│ │ ├── test_databases_async.py
│ │ ├── test_databases_sync.py
│ │ ├── test_datetime_serdes_options.py
│ │ ├── test_document_extractors.py
│ │ ├── test_document_paths.py
│ │ ├── test_embeddingheadersprovider.py
│ │ ├── test_exceptions.py
│ │ ├── test_findandrerank_collectiondefinition.py
│ │ ├── test_findembeddingproviderresult.py
│ │ ├── test_findrerankingproviderresult.py
│ │ ├── test_ids.py
│ │ ├── test_imports.py
│ │ ├── test_info.py
│ │ ├── test_multicalltimeoutmanager.py
│ │ ├── test_regionname_deprecation.py
│ │ ├── test_rerankingheadersprovider.py
│ │ ├── test_strenum.py
│ │ ├── test_table_decimal_support.py
│ │ ├── test_table_dry_methods.py
│ │ ├── test_tableconverteragent.py
│ │ ├── test_tabledescriptors.py
│ │ ├── test_tableindexdescriptor_parsing.py
│ │ ├── test_timeouts.py
│ │ ├── test_token_providers.py
│ │ ├── test_tpostprocessors.py
│ │ └── test_tpreprocessors_mapastuples.py
├── conftest.py
├── dse_compose
│ ├── README
│ └── docker-compose.yml
├── env_templates
│ ├── env.astra.admin.template
│ ├── env.astra.template
│ ├── env.local.template
│ ├── env.testcontainers.template
│ ├── env.vectorize-minimal.template
│ └── env.vectorize.template
├── hcd_compose
│ ├── cassandra-hcd.yaml
│ └── docker-compose.yml
├── preprocess_env.py
└── vectorize
│ ├── __init__.py
│ ├── conftest.py
│ ├── integration
│ ├── __init__.py
│ └── test_vectorize_providers.py
│ ├── live_provider_info.py
│ ├── query_providers.py
│ └── vectorize_models.py
└── uv.lock
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
1 | name: ruff and mypy checks
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | mypy:
13 | runs-on: ubuntu-latest
14 |
15 | steps:
16 | - uses: actions/checkout@v2
17 | - name: Set up Python
18 | uses: actions/setup-python@v2
19 | with:
20 | python-version: '3.11' # Or any version you prefer
21 |
22 | - name: Install dependencies
23 | run: |
24 | python -m pip install --upgrade pip
25 | pipx install uv
26 | make venv
27 |
28 | - name: Ruff Linting AstraPy
29 | run: |
30 | uv run ruff check astrapy
31 |
32 | - name: Ruff Linting Tests
33 | run: |
34 | uv run ruff check tests
35 |
36 | - name: Ruff formatting astrapy
37 | run: |
38 | uv run ruff format --check astrapy
39 |
40 | - name: Ruff formatting tests
41 | run: |
42 | uv run ruff format --check tests
43 |
44 | - name: Run MyPy AstraPy
45 | run: |
46 | uv run mypy astrapy
47 |
48 | - name: Run MyPy Tests
49 | run: |
50 | uv run mypy tests
51 |
--------------------------------------------------------------------------------
/.github/workflows/local.yaml:
--------------------------------------------------------------------------------
1 | name: Run base integration tests on a local Data API
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test:
13 | env:
14 | HEADER_EMBEDDING_API_KEY_OPENAI: ${{ secrets.HEADER_EMBEDDING_API_KEY_OPENAI }}
15 | # hardcoding the target DB
16 | DOCKER_COMPOSE_LOCAL_DATA_API: "yes"
17 | # turn on header-based reranker auth
18 | ASTRAPY_FINDANDRERANK_USE_RERANKER_HEADER: "yes"
19 | HEADER_RERANKING_API_KEY_NVIDIA: ${{ secrets.HEADER_RERANKING_API_KEY_NVIDIA }}
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - name: Checkout code
24 | uses: actions/checkout@v2
25 |
26 | - name: Set up Python
27 | uses: actions/setup-python@v2
28 | with:
29 | python-version: 3.11
30 |
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pipx install uv
35 | make venv
36 |
37 | # Prepare to login to ECR:
38 | - name: Configure AWS credentials
39 | uses: aws-actions/configure-aws-credentials@v4
40 | with:
41 | aws-access-key-id: ${{ secrets.HCD_ECR_ACCESS_KEY }}
42 | aws-secret-access-key: ${{ secrets.HCD_ECR_SECRET_KEY }}
43 | aws-region: us-west-2
44 |
45 | # Login to ECR so we can pull HCD image:
46 | - name: Login to Amazon ECR
47 | id: login-ecr
48 | uses: aws-actions/amazon-ecr-login@v2
49 | with:
50 | mask-password: 'true'
51 |
52 | - name: Run pytest
53 | run: |
54 | uv run pytest tests/base/integration
55 |
--------------------------------------------------------------------------------
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | name: Run base integration tests on Astra DB
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test:
13 | env:
14 | # basic secrets
15 | ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
16 | ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
17 | ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
18 | ASTRA_DB_SECONDARY_KEYSPACE: ${{ secrets.ASTRA_DB_SECONDARY_KEYSPACE }}
19 | HEADER_EMBEDDING_API_KEY_OPENAI: ${{ secrets.HEADER_EMBEDDING_API_KEY_OPENAI }}
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - name: Checkout code
24 | uses: actions/checkout@v2
25 |
26 | - name: Set up Python
27 | uses: actions/setup-python@v2
28 | with:
29 | python-version: 3.11
30 |
31 | - name: Install dependencies
32 | run: |
33 | python -m pip install --upgrade pip
34 | pipx install uv
35 | make venv
36 |
37 | - name: Run pytest
38 | run: |
39 | uv run pytest tests/base/integration
40 |
--------------------------------------------------------------------------------
/.github/workflows/unit.yaml:
--------------------------------------------------------------------------------
1 | name: Run unit tests
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | pull_request:
8 | branches:
9 | - main
10 |
11 | jobs:
12 | test:
13 | env:
14 | # basic secrets
15 | ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
16 | ASTRA_DB_API_ENDPOINT: ${{ secrets.ASTRA_DB_API_ENDPOINT }}
17 | ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
18 | ASTRA_DB_SECONDARY_KEYSPACE: ${{ secrets.ASTRA_DB_SECONDARY_KEYSPACE }}
19 | runs-on: ubuntu-latest
20 | strategy:
21 | matrix:
22 | python-version:
23 | - "3.8"
24 | - "3.9"
25 | - "3.10"
26 | - "3.11"
27 | - "3.12"
28 | name: "unit test on #${{ matrix.python-version }}"
29 | steps:
30 | - name: Checkout code
31 | uses: actions/checkout@v2
32 |
33 | - name: Set up Python
34 | uses: actions/setup-python@v2
35 | with:
36 | python-version: ${{ matrix.python-version }}
37 |
38 | - name: Install dependencies
39 | run: |
40 | python -m pip install --upgrade pip
41 | pipx install uv
42 | make venv
43 |
44 | - name: Run pytest
45 | run: |
46 | uv run pytest tests/base/unit
47 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .logs
2 | .idea
3 | .DS_Store
4 | .project
5 | .pydevproject
6 | .settings
7 | .env
8 | node_modules
9 | venv
10 | .venv*
11 | dump.rdb
12 | .tmp
13 | npm-debug.log
14 | .ssh
15 | .csscomb.json
16 | tests/screenshots
17 | .bash_history
18 | .ipynb_checkpoints
19 | .scratch.py
20 | .pytest_cache
21 | .mypy_cache
22 | .vscode
23 | manifest
24 | dist
25 |
26 | # generated by findEmbeddingProviders dump:
27 | _providers.json
28 |
29 | # premade .gitignore from GitHub
30 |
31 | # Byte-compiled / optimized / DLL files
32 | __pycache__/
33 | *.py[cod]
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | env/
41 | build/
42 | !frontend/vendor/**/build
43 | develop-eggs/
44 | downloads/
45 | !pd_pfe/static/downloads
46 | eggs/
47 | .eggs/
48 | parts/
49 | sdist/
50 | var/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 |
55 | # PyInstaller
56 | # Usually these files are written by a python script from a template
57 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
58 | *.manifest
59 | *.spec
60 |
61 | # Installer logs
62 | pip-log.txt
63 | pip-delete-this-directory.txt
64 |
65 | # Unit test / coverage reports
66 | htmlcov/
67 | .tox/
68 | .coverage
69 | .coverage.*
70 | .cache
71 | nosetests.xml
72 | coverage.xml
73 | *,cover
74 |
75 | # Translations
76 | *.mo
77 | *.pot
78 |
79 | # Django stuff:
80 | *.log
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Browser test output
89 | junitresults.*
90 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: local
3 | hooks:
4 | - id: check
5 | name: check
6 | language: system
7 | entry: make format-fix
8 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Contributor Covenant Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies both within project spaces and in public spaces
49 | when an individual is representing the project or its community. Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at [INSERT EMAIL ADDRESS]. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 |
73 | [homepage]: https://www.contributor-covenant.org
74 |
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing
2 |
3 | When contributing to this repository, please first discuss the change you wish to make via issue,
4 | email, or any other method with the owners of this repository before making a change.
5 |
6 | Please note that we have a [Code of Conduct](CODE_OF_CONDUCT.md); please follow it in all your interactions with the project.
7 |
8 | ## Found an Issue?
9 | If you find a bug in the source code or a mistake in the documentation, you can help us by
10 | [submitting an issue](#submitting-an-issue) to the GitHub Repository. Even better, you can
11 | [submit a Pull Request](#submitting-a-pull-request-pr) with a fix.
12 |
13 | ## Want a Feature?
14 | You can *request* a new feature by [submitting an issue](#submitting-an-issue) to the GitHub
15 | Repository. If you would like to *implement* a new feature, please submit an issue with
16 | a proposal for your work first, to be sure that we can use it.
17 |
18 | * **Small Features** can be crafted and directly [submitted as a Pull Request](#submitting-a-pull-request-pr).
19 |
20 | ## Contribution Guidelines
21 |
22 | ### Contribution License Agreement
23 | To protect the community, all contributors are required to [sign the DataStax Contribution License Agreement](https://cla.datastax.com/). The process is completely electronic and should only take a few minutes.
24 |
25 | ### Submitting an Issue
26 | Before you submit an issue, search the archive: your question may already have been answered.
27 |
28 | If your issue appears to be a bug, and hasn't been reported, open a new issue.
29 | Help us to maximize the effort we can spend fixing issues and adding new
30 | features, by not reporting duplicate issues. Providing the following information will increase the
31 | chances of your issue being dealt with quickly:
32 |
33 | * **Overview of the Issue** - if an error is being thrown, a non-minified stack trace helps
34 | * **Motivation for or Use Case** - explain what you are trying to do and why the current behavior is a bug for you
35 | * **Reproduce the Error** - provide a live example or an unambiguous set of steps
36 | * **Suggest a Fix** - if you can't fix the bug yourself, perhaps you can point to what might be
37 | causing the problem (line of code or commit)
38 |
39 | ### Submitting a Pull Request (PR)
40 | Before you submit your Pull Request (PR) consider the following guidelines:
41 |
42 | * Search the repository (https://github.com/datastax/astrapy/pulls) for an open or closed PR that relates to your submission. You don't want to duplicate effort.
43 |
44 | * Create a fork of the repo
45 | * Navigate to the repo you want to fork
46 | * In the top right corner of the page click **Fork**:
47 | 
48 |
49 | * Make your changes in the forked repo
50 | * Commit your changes using a descriptive commit message
51 | * In GitHub, create a pull request: https://help.github.com/en/articles/creating-a-pull-request-from-a-fork
52 | * If we suggest changes then:
53 | * Make the required updates.
54 | * Rebase your fork and force push to your GitHub repository (this will update your Pull Request):
55 |
56 | ```shell
57 | git rebase main -i
58 | git push -f
59 | ```
60 |
61 | That's it! Thank you for your contribution!
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | SHELL := /bin/bash
2 |
.PHONY: all venv format format-fix format-tests format-src format-fix-src format-fix-tests test-integration test docker-test-integration build help
4 |
5 | all: help
6 |
7 | FMT_FLAGS ?= --check
8 | VENV ?= false
9 |
10 | ifeq ($(VENV), true)
11 | VENV_FLAGS := --active
12 | else
13 | VENV_FLAGS :=
14 | endif
15 |
16 | venv:
17 | uv venv
18 | uv sync --dev
19 |
20 | format: format-src format-tests
21 |
22 | format-tests:
23 | uv run $(VENV_FLAGS) ruff check tests
24 | uv run $(VENV_FLAGS) ruff format tests $(FMT_FLAGS)
25 | uv run $(VENV_FLAGS) mypy tests
26 |
27 | format-src:
28 | uv run $(VENV_FLAGS) ruff check astrapy
29 | uv run $(VENV_FLAGS) ruff format astrapy $(FMT_FLAGS)
30 | uv run $(VENV_FLAGS) mypy astrapy
31 |
32 | format-fix: format-fix-src format-fix-tests
33 |
34 | format-fix-src: FMT_FLAGS=
35 | format-fix-src: format-src
36 |
37 | format-fix-tests: FMT_FLAGS=
38 | format-fix-tests: format-tests
39 |
40 | test-integration:
41 | uv run $(VENV_FLAGS) pytest tests/base -vv
42 |
43 | test:
44 | uv run $(VENV_FLAGS) pytest tests/base/unit -vv
45 |
46 | docker-test-integration:
47 | DOCKER_COMPOSE_LOCAL_DATA_API="yes" uv run pytest tests/base -vv
48 |
49 | build:
50 | rm -f dist/astrapy*
51 | uv build
52 |
53 | help:
54 | @echo "======================================================================"
55 | @echo "AstraPy make command purpose"
56 | @echo "----------------------------------------------------------------------"
57 | @echo "venv create a virtual env (needs uv)"
58 | @echo "format full lint and format checks"
59 | @echo " format-src limited to source"
60 | @echo " format-tests limited to tests"
61 | @echo " format-fix fixing imports and style"
62 | @echo " format-fix-src limited to source"
63 | @echo " format-fix-tests limited to tests"
64 | @echo "test run unit tests"
65 | @echo "test-integration run integration tests"
66 | @echo "docker-test-integration run int.tests on dockerized local"
67 | @echo "build build package ready for PyPI"
68 | @echo "======================================================================"
69 |
--------------------------------------------------------------------------------
/astrapy/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import importlib.metadata
18 | import os
19 |
20 | import toml
21 |
22 |
def get_version() -> str:
    """
    Determine the version string of the astrapy package.

    The installed distribution metadata is consulted first. If the package is
    not installed (e.g. when running from a source checkout), the
    ``pyproject.toml`` file located one directory above this package is
    parsed instead.

    Returns:
        the version as a string, or "unknown" if it cannot be determined.
    """
    try:
        # Version recorded in the installed distribution's metadata.
        return importlib.metadata.version(__package__)
    except importlib.metadata.PackageNotFoundError:
        pass

    # Not installed: fall back to the pyproject.toml of the source tree.
    dir_path = os.path.dirname(os.path.realpath(__file__))
    pyproject_path = os.path.join(dir_path, "..", "pyproject.toml")

    try:
        with open(pyproject_path, encoding="utf-8") as pyproject:
            pyproject_data = toml.loads(pyproject.read())
    except (OSError, toml.TomlDecodeError):
        # Missing, unreadable or malformed file: never break the import.
        return "unknown"

    # PEP 621 places the version in the top-level [project] table; the
    # legacy [tool.project] location is still accepted as a fallback.
    for table_path in (("project",), ("tool", "project")):
        section = pyproject_data
        try:
            for key in table_path:
                section = section[key]
            return str(section["version"])
        except (KeyError, TypeError):
            continue

    return "unknown"
47 |
48 |
49 | __version__: str = get_version()
50 |
51 |
52 | from astrapy import api_options # noqa: E402, F401
53 | from astrapy.admin import ( # noqa: E402
54 | AstraDBAdmin,
55 | AstraDBDatabaseAdmin,
56 | DataAPIDatabaseAdmin,
57 | )
58 | from astrapy.client import DataAPIClient # noqa: E402
59 | from astrapy.collection import AsyncCollection, Collection # noqa: E402
60 |
61 | # A circular-import issue requires this to happen at the end of this module:
62 | from astrapy.database import AsyncDatabase, Database # noqa: E402
63 | from astrapy.table import AsyncTable, Table # noqa: E402
64 |
65 | __all__ = [
66 | "AstraDBAdmin",
67 | "AstraDBDatabaseAdmin",
68 | "AsyncCollection",
69 | "AsyncDatabase",
70 | "AsyncTable",
71 | "Collection",
72 | "Database",
73 | "DataAPIClient",
74 | "DataAPIDatabaseAdmin",
75 | "Table",
76 | "__version__",
77 | ]
78 |
79 |
80 | __pdoc__ = {
81 | "ids": False,
82 | "settings": False,
83 | }
84 |
--------------------------------------------------------------------------------
/astrapy/admin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.admin.admin import (
18 | AstraDBAdmin,
19 | AstraDBDatabaseAdmin,
20 | DataAPIDatabaseAdmin,
21 | DatabaseAdmin,
22 | ParsedAPIEndpoint,
23 | async_fetch_database_info,
24 | fetch_database_info,
25 | parse_api_endpoint,
26 | )
27 |
28 | __all__ = [
29 | "AstraDBAdmin",
30 | "AstraDBDatabaseAdmin",
31 | "DataAPIDatabaseAdmin",
32 | "DatabaseAdmin",
33 | "ParsedAPIEndpoint",
34 | "async_fetch_database_info",
35 | "fetch_database_info",
36 | "parse_api_endpoint",
37 | ]
38 |
--------------------------------------------------------------------------------
/astrapy/api_options.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.utils.api_options import (
18 | APIOptions,
19 | DataAPIURLOptions,
20 | DevOpsAPIURLOptions,
21 | SerdesOptions,
22 | TimeoutOptions,
23 | )
24 |
25 | __all__ = [
26 | "APIOptions",
27 | "DataAPIURLOptions",
28 | "DevOpsAPIURLOptions",
29 | "SerdesOptions",
30 | "TimeoutOptions",
31 | ]
32 |
--------------------------------------------------------------------------------
/astrapy/collection.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from astrapy.data.collection import (
16 | AsyncCollection,
17 | Collection,
18 | )
19 |
20 | __all__ = [
21 | "AsyncCollection",
22 | "Collection",
23 | ]
24 |
--------------------------------------------------------------------------------
/astrapy/constants.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from typing import Any, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
18 |
19 | from astrapy.data_types import DataAPIVector
20 | from astrapy.settings.defaults import (
21 | DATA_API_ENVIRONMENT_CASSANDRA,
22 | DATA_API_ENVIRONMENT_DEV,
23 | DATA_API_ENVIRONMENT_DSE,
24 | DATA_API_ENVIRONMENT_HCD,
25 | DATA_API_ENVIRONMENT_OTHER,
26 | DATA_API_ENVIRONMENT_PROD,
27 | DATA_API_ENVIRONMENT_TEST,
28 | )
29 | from astrapy.utils.str_enum import StrEnum
30 |
# Generic document (collections) and row (tables) types: string-keyed dicts.
DefaultDocumentType = Dict[str, Any]
DefaultRowType = Dict[str, Any]

# A projection is either an iterable of field names, or a dict mapping each
# field name to a bool or to a nested dict of int / iterable-of-int values.
ProjectionType = Union[
    Iterable[str],
    Dict[str, Union[bool, Dict[str, Union[int, Iterable[int]]]]],
]

# Sort specifications: a generic one, and a "hybrid" one whose values are
# strings or dicts of strings / float lists / DataAPIVector objects.
SortType = Dict[str, Any]
HybridSortType = Dict[
    str,
    Union[str, Dict[str, Union[str, List[float], DataAPIVector]]],
]

FilterType = Dict[str, Any]

# A pair of optional strings (presumably caller name and version - confirm).
CallerType = Tuple[Optional[str], Optional[str]]


# Type variables for row (tables) and document (collections) types.
ROW = TypeVar("ROW")
ROW2 = TypeVar("ROW2")
DOC = TypeVar("DOC")
DOC2 = TypeVar("DOC2")
48 |
49 |
def normalize_optional_projection(
    projection: ProjectionType | None,
) -> dict[str, bool | dict[str, int | Iterable[int]]] | None:
    """
    Normalize a projection into its dictionary form.

    Args:
        projection: a dictionary projection, returned unchanged; or an
            iterable of field names, turned into an allow-list projection
            of the form ``{field: True, ...}``. A single string is treated
            as one field name, not as an iterable of characters.

    Returns:
        the dictionary projection, or None if `projection` is None or empty.
    """
    if not projection:
        return None
    if isinstance(projection, dict):
        # already a dictionary
        return projection
    if isinstance(projection, str):
        # A bare string is itself an Iterable[str]: iterating over it would
        # produce one single-character "field" per character.
        return {projection: True}
    # an iterable over strings: coerce to allow-list projection
    return {field: True for field in projection}
62 |
63 |
class ReturnDocument:
    """
    Allowed values for the `return_document` argument of the
    `find_one_and_replace` and `find_one_and_update` collection methods.

    This is a constants-only namespace: instantiating it raises
    NotImplementedError.
    """

    BEFORE = "before"
    AFTER = "after"

    def __init__(self) -> None:
        raise NotImplementedError
76 |
77 |
class SortMode:
    """
    Allowed values for the `sort` argument of the collection find methods,
    for instance `sort={"field": SortMode.ASCENDING}`.

    This is a constants-only namespace: instantiating it raises
    NotImplementedError.
    """

    ASCENDING = 1
    DESCENDING = -1

    def __init__(self) -> None:
        raise NotImplementedError
89 |
90 |
class VectorMetric:
    """
    Namespace of the values accepted by the "metric" parameter of a
    CollectionVectorOptions object. Such an object is needed when creating
    vector collections through the database `create_collection` method.

    The class only holds string constants and cannot be instantiated.
    """

    DOT_PRODUCT = "dot_product"
    EUCLIDEAN = "euclidean"
    COSINE = "cosine"

    def __init__(self) -> None:
        # Pure constants holder: creating instances is deliberately forbidden.
        raise NotImplementedError
104 |
105 |
class DefaultIdType:
    """
    Namespace of the values accepted by the "default_id_type" parameter of a
    CollectionDefaultIDOptions object. Such an object is needed when creating
    collections through the database `create_collection` method.

    The class only holds string constants and cannot be instantiated.
    """

    UUID = "uuid"
    OBJECTID = "objectId"
    UUIDV6 = "uuidv6"
    UUIDV7 = "uuidv7"
    # The default is the plain "uuid" type.
    DEFAULT = UUID

    def __init__(self) -> None:
        # Pure constants holder: creating instances is deliberately forbidden.
        raise NotImplementedError
121 |
122 |
class Environment:
    """
    Namespace of the values accepted by the `environment` property, which
    denotes the targeted API deployment type.

    The class only holds constants and cannot be instantiated.
    """

    PROD = DATA_API_ENVIRONMENT_PROD
    DEV = DATA_API_ENVIRONMENT_DEV
    TEST = DATA_API_ENVIRONMENT_TEST
    DSE = DATA_API_ENVIRONMENT_DSE
    HCD = DATA_API_ENVIRONMENT_HCD
    CASSANDRA = DATA_API_ENVIRONMENT_CASSANDRA
    OTHER = DATA_API_ENVIRONMENT_OTHER

    # The environments grouped under "Astra DB".
    astra_db_values = {PROD, DEV, TEST}
    # Every admitted environment: the Astra DB ones plus all the others.
    values = astra_db_values | {DSE, HCD, CASSANDRA, OTHER}

    def __init__(self) -> None:
        # Pure constants holder: creating instances is deliberately forbidden.
        raise NotImplementedError
142 |
143 |
class MapEncodingMode(StrEnum):
    """
    Enum for the possible values of the setting controlling whether to encode
    dicts/DataAPIMaps as lists of pairs ("association lists") in table payloads.

    Member meanings below are inferred from the names; confirm them against
    the code that consumes this setting:
        NEVER: presumably the list-of-pairs encoding is never used.
        DATAAPIMAPS: presumably only DataAPIMap values are encoded as pairs.
        ALWAYS: presumably both dicts and DataAPIMaps are encoded as pairs.
    """

    NEVER = "NEVER"
    DATAAPIMAPS = "DATAAPIMAPS"
    ALWAYS = "ALWAYS"
153 |
154 |
# Names exported by `from ... import *`: the constant-holder classes and
# MapEncodingMode. The type aliases, type variables and
# normalize_optional_projection defined in this module are not listed.
__all__ = [
    "DefaultIdType",
    "Environment",
    "MapEncodingMode",
    "ReturnDocument",
    "SortMode",
    "VectorMetric",
]
163 |
--------------------------------------------------------------------------------
/astrapy/cursors.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from astrapy.data.cursors.cursor import (
16 | AbstractCursor,
17 | CursorState,
18 | )
19 | from astrapy.data.cursors.farr_cursor import (
20 | AsyncCollectionFindAndRerankCursor,
21 | CollectionFindAndRerankCursor,
22 | )
23 | from astrapy.data.cursors.find_cursor import (
24 | AsyncCollectionFindCursor,
25 | AsyncTableFindCursor,
26 | CollectionFindCursor,
27 | TableFindCursor,
28 | )
29 | from astrapy.data.cursors.reranked_result import RerankedResult
30 |
# Cursor-related names re-exported from the astrapy.data.cursors submodules
# imported above, so they can be imported from this module directly.
__all__ = [
    "AsyncCollectionFindAndRerankCursor",
    "AsyncCollectionFindCursor",
    "AsyncTableFindCursor",
    "CollectionFindAndRerankCursor",
    "CollectionFindCursor",
    "AbstractCursor",
    "CursorState",
    "RerankedResult",
    "TableFindCursor",
]
42 |
--------------------------------------------------------------------------------
/astrapy/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/astrapy/data/cursors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/astrapy/data/cursors/reranked_result.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from dataclasses import dataclass
18 | from typing import Generic
19 |
20 | from astrapy.data.cursors.cursor import TRAW
21 |
22 |
@dataclass
class RerankedResult(Generic[TRAW]):
    """
    A single result coming from a `find_and_rerank` command, i.e. an item from
    the database together with its scores.

    Attributes:
        document: a document/row as returned by the `find_and_rerank` API command.
        scores: a dictionary of score labels to score values (floats, ints or
            None), such as `{"$rerank": 0.87, "$vector" : 0.65, "$lexical" : 0.91}`.
    """

    document: TRAW
    scores: dict[str, float | int | None]
36 |
--------------------------------------------------------------------------------
/astrapy/data/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/astrapy/data/utils/extended_json_converters.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import base64
18 | import datetime
19 | import time
20 |
21 | from astrapy.data_types import DataAPITimestamp
22 | from astrapy.ids import UUID, ObjectId
23 |
24 |
def convert_to_ejson_date_object(
    date_value: datetime.date | datetime.datetime,
) -> dict[str, int]:
    """
    Encode a date or datetime as an extended-JSON `{"$date": <ms>}` object.

    Naive datetimes are interpreted as `datetime.timestamp()` does (local
    time), and plain dates as `time.mktime` does (local midnight). Aware
    datetimes are interpreted according to their own tzinfo.

    Args:
        date_value: the `datetime.date` or `datetime.datetime` to encode.

    Returns:
        a dictionary `{"$date": n}`, n being the integer milliseconds since
        the Unix epoch. Sub-millisecond precision is floored away.
    """
    if isinstance(date_value, datetime.datetime):
        # Whole seconds and microseconds are combined with integer arithmetic.
        # `timestamp() * 1000` on a float can land just below an integer
        # (e.g. 1.005 * 1000 == 1004.9999999999999) and then be truncated to
        # the wrong millisecond. The timestamp of a whole-second datetime is
        # an exactly representable integer.
        whole_seconds = int(date_value.replace(microsecond=0).timestamp())
        return {"$date": whole_seconds * 1000 + date_value.microsecond // 1000}
    # `datetime.datetime` subclasses `datetime.date`, hence the order of checks.
    # `time.mktime` yields whole seconds, so this product is exact.
    return {"$date": int(time.mktime(date_value.timetuple()) * 1000)}
31 |
32 |
def convert_to_ejson_apitimestamp_object(
    date_value: DataAPITimestamp,
) -> dict[str, int]:
    """
    Wrap a DataAPITimestamp into an extended-JSON `{"$date": ...}` object,
    using the timestamp's `timestamp_ms` attribute unchanged.
    """
    millis = date_value.timestamp_ms
    encoded = {"$date": millis}
    return encoded
37 |
38 |
def convert_to_ejson_bytes(bytes_value: bytes) -> dict[str, str]:
    """Encode raw bytes as an extended-JSON `{"$binary": <base64 text>}` object."""
    encoded = base64.b64encode(bytes_value)
    # Base64 output consists of ASCII characters only.
    return {"$binary": encoded.decode("ascii")}
41 |
42 |
def convert_to_ejson_uuid_object(uuid_value: UUID) -> dict[str, str]:
    """Encode a UUID as an extended-JSON `{"$uuid": <string form>}` object."""
    uuid_text = str(uuid_value)
    encoded = {"$uuid": uuid_text}
    return encoded
45 |
46 |
def convert_to_ejson_objectid_object(objectid_value: ObjectId) -> dict[str, str]:
    """Encode an ObjectId as an extended-JSON `{"$objectId": <string form>}` object."""
    objectid_text = str(objectid_value)
    encoded = {"$objectId": objectid_text}
    return encoded
49 |
50 |
def convert_ejson_date_object_to_datetime(
    date_object: dict[str, int], tz: datetime.timezone | None
) -> datetime.datetime:
    """
    Decode an extended-JSON `{"$date": <ms>}` object into a datetime.

    Args:
        date_object: a dictionary holding milliseconds since the epoch under "$date".
        tz: timezone of the result. When None, `datetime.fromtimestamp`
            returns a naive datetime expressed in local time.
    """
    seconds = date_object["$date"] / 1000.0
    return datetime.datetime.fromtimestamp(seconds, tz)
55 |
56 |
def convert_ejson_date_object_to_apitimestamp(
    date_object: dict[str, int],
) -> DataAPITimestamp:
    """Build a DataAPITimestamp from the milliseconds stored under "$date"."""
    millis = date_object["$date"]
    return DataAPITimestamp(millis)
61 |
62 |
def convert_ejson_binary_object_to_bytes(
    binary_object: dict[str, str],
) -> bytes:
    """Decode the base64 text stored under "$binary" back into raw bytes."""
    encoded_text = binary_object["$binary"]
    return base64.b64decode(encoded_text)
67 |
68 |
def convert_ejson_uuid_object_to_uuid(uuid_object: dict[str, str]) -> UUID:
    """Parse the string stored under "$uuid" into a UUID instance."""
    uuid_text = uuid_object["$uuid"]
    return UUID(uuid_text)
71 |
72 |
def convert_ejson_objectid_object_to_objectid(
    objectid_object: dict[str, str],
) -> ObjectId:
    """Parse the string stored under "$objectId" into an ObjectId instance."""
    objectid_text = objectid_object["$objectId"]
    return ObjectId(objectid_text)
77 |
--------------------------------------------------------------------------------
/astrapy/data/utils/table_types.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.utils.str_enum import StrEnum
18 |
19 |
class ColumnType(StrEnum):
    """
    Enum to describe the scalar column types for Tables.

    A 'scalar' type is a non-composite type: that is, not a set, a list, a map
    or any other non-primitive data type. Each member's value is the
    lowercase type name (e.g. "bigint", "timeuuid").
    """

    ASCII = "ascii"
    BIGINT = "bigint"
    BLOB = "blob"
    BOOLEAN = "boolean"
    COUNTER = "counter"
    DATE = "date"
    DECIMAL = "decimal"
    DOUBLE = "double"
    DURATION = "duration"
    FLOAT = "float"
    INET = "inet"
    INT = "int"
    SMALLINT = "smallint"
    TEXT = "text"
    TIME = "time"
    TIMESTAMP = "timestamp"
    TIMEUUID = "timeuuid"
    TINYINT = "tinyint"
    UUID = "uuid"
    VARINT = "varint"
48 |
49 |
class TableValuedColumnType(StrEnum):
    """
    An enum to describe the types of column holding a collection of "values",
    namely lists and sets.
    """

    LIST = "list"
    SET = "set"
57 |
58 |
class TableKeyValuedColumnType(StrEnum):
    """
    An enum to describe the types of column with "keys and values",
    i.e. maps.
    """

    MAP = "map"
65 |
66 |
class TableVectorColumnType(StrEnum):
    """
    An enum to describe the types of 'vector-like' column
    (currently only "vector").
    """

    VECTOR = "vector"
73 |
74 |
class TableUnsupportedColumnType(StrEnum):
    """
    An enum to describe the types of column falling into the 'unsupported' group
    (read/describe path). Unlike the other column-type enums, the member value
    is upper-case.
    """

    UNSUPPORTED = "UNSUPPORTED"
82 |
83 |
class TablePassthroughColumnType(StrEnum):
    """
    An enum to describe the types for 'passthrough' columns (read/describe path).
    Unlike most other column-type enums, the member value is upper-case.
    """

    PASSTHROUGH = "PASSTHROUGH"
90 |
--------------------------------------------------------------------------------
/astrapy/data/utils/vector_coercion.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from typing import Any, Iterable
18 |
19 | from astrapy.data_types import DataAPIMap, DataAPISet, DataAPIVector
20 |
# Iterable types that `ensure_unrolled_if_iterable` leaves untouched. Any
# other iterable (e.g. a tuple, a set or a generator) is turned into a list.
ITERABLES_TO_NOT_UNROLL = (
    list,
    str,
    bytes,
    dict,
    DataAPIVector,
    DataAPIMap,
    DataAPISet,
)
22 |
23 |
def ensure_unrolled_if_iterable(value: Any) -> Any:
    """
    Materialize `value` into a list if it is an iterable whose type is not in
    ITERABLES_TO_NOT_UNROLL. Anything else is returned unchanged.
    """
    if not isinstance(value, Iterable) or isinstance(value, ITERABLES_TO_NOT_UNROLL):
        return value
    return list(value)
28 |
29 |
def convert_vector_to_floats(vector: Iterable[Any]) -> list[float]:
    """
    Coerce every item of a vector (e.g. numeric strings or ints) to float.

    Args:
        vector: an iterable of objects accepted by `float()`.

    Returns:
        a new list holding `float(item)` for each item, in the original order.
    """
    return list(map(float, vector))
41 |
42 |
def is_list_of_floats(vector: Iterable[Any]) -> bool:
    """
    Cheaply tell whether `vector` is a list of numbers.

    Only the first item is inspected; the remaining ones are assumed to share
    its type. Ints (and hence bools) are accepted alongside floats, and an
    empty list counts as a list of floats. Non-list iterables give False.
    """
    if not isinstance(vector, list):
        return False
    if not vector:
        return True
    return isinstance(vector[0], (float, int))
51 |
--------------------------------------------------------------------------------
/astrapy/data_types/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.data_types.data_api_date import DataAPIDate
18 | from astrapy.data_types.data_api_duration import DataAPIDuration
19 | from astrapy.data_types.data_api_map import DataAPIMap
20 | from astrapy.data_types.data_api_set import DataAPISet
21 | from astrapy.data_types.data_api_time import DataAPITime
22 | from astrapy.data_types.data_api_timestamp import DataAPITimestamp
23 | from astrapy.data_types.data_api_vector import DataAPIVector
24 |
# DataAPI* data types, re-exported from their individual modules imported above.
__all__ = [
    "DataAPITimestamp",
    "DataAPIVector",
    "DataAPIDate",
    "DataAPIDuration",
    "DataAPIMap",
    "DataAPISet",
    "DataAPITime",
]
34 |
--------------------------------------------------------------------------------
/astrapy/data_types/data_api_map.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import math
18 | from typing import Generic, Iterable, Iterator, Mapping, TypeVar
19 |
T = TypeVar("T")
U = TypeVar("U")


def _accumulate_pairs(
    destination: tuple[list[T], list[U]], source: Iterable[tuple[T, U]]
) -> tuple[list[T], list[U]]:
    """
    Append the (key, value) pairs of `source` to copies of the parallel
    key/value lists in `destination`, skipping keys already present.

    Key presence is checked with a list membership test (identity or `==`).
    Unhashable keys are therefore supported, at quadratic cost. For duplicate
    keys the first occurrence wins. The lists in `destination` are not modified.

    Returns:
        a (keys, values) pair of new lists.
    """
    _new_ks = list(destination[0])
    _new_vs = list(destination[1])
    for k, v in source:
        if k not in _new_ks:
            _new_ks.append(k)
            _new_vs.append(v)
    return (_new_ks, _new_vs)


class DataAPIMap(Generic[T, U], Mapping[T, U]):
    """
    An immutable 'map-like' class that preserves the order and can employ
    non-hashable keys (which must support __eq__). Not designed for performance.

    Despite internally preserving the order, equality between DataAPIMap instances
    (and with regular dicts) is independent of the order. Comparing with an
    object that cannot be turned into a dict (e.g. a number or a string) does
    not raise; normally the comparison simply yields False.
    """

    _keys: list[T]
    _values: list[U]

    def __init__(self, source: Iterable[tuple[T, U]] | dict[T, U] = ()) -> None:
        """
        Build the map from a dict or from an iterable of (key, value) pairs.
        For repeated keys, the first occurrence is kept.
        """
        # An immutable empty tuple as default avoids a shared mutable default.
        pairs = source.items() if isinstance(source, dict) else source
        self._keys, self._values = _accumulate_pairs(([], []), pairs)

    def __getitem__(self, key: T) -> U:
        if isinstance(key, float) and math.isnan(key):
            # NaN != NaN, so NaN keys need an explicit isnan-based lookup.
            for idx, k in enumerate(self._keys):
                if isinstance(k, float) and math.isnan(k):
                    return self._values[idx]
        else:
            for idx, k in enumerate(self._keys):
                if k == key:
                    return self._values[idx]
        # Only the missing key is reported: the contents of the map are not
        # dumped into the error message.
        raise KeyError(str(key))

    def __iter__(self) -> Iterator[T]:
        return iter(self._keys)

    def __len__(self) -> int:
        return len(self._keys)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, DataAPIMap):
            if len(self) != len(other):
                return False
            if not all(o_k in self for o_k in other):
                return False
            return all(other[k] == self[k] for k in self)
        try:
            dother = dict(other)  # type: ignore[call-overload]
        except (TypeError, ValueError):
            # Not convertible to a dict: e.g. an int (TypeError), or a string
            # or a sequence whose items are not pairs (ValueError). Python then
            # tries the reflected comparison, falling back to identity.
            return NotImplemented
        if len(dother) != len(self):
            return False
        try:
            return all(o_k in self for o_k in dother) and all(
                dother[k] == self[k] for k in self
            )
        except KeyError:
            return False
        except TypeError:
            # E.g. an unhashable key of this map used to index `dother`.
            return NotImplemented

    def __repr__(self) -> str:
        _map_repr = ", ".join(
            f"({repr(k)}, {repr(v)})" for k, v in zip(self._keys, self._values)
        )
        return f"{self.__class__.__name__}([{_map_repr}])"

    def __str__(self) -> str:
        _map_repr = ", ".join(f"({k}, {v})" for k, v in zip(self._keys, self._values))
        return f"{_map_repr}"

    def __reduce__(self) -> tuple[type, tuple[Iterable[tuple[T, U]]]]:
        # Pickle as the constructor applied to the list of (key, value) pairs.
        return self.__class__, (list(zip(self._keys, self._values)),)
111 |
--------------------------------------------------------------------------------
/astrapy/database.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from astrapy.data.database import (
16 | AsyncDatabase,
17 | Database,
18 | )
19 |
# Sync and async database classes, re-exported from astrapy.data.database.
__all__ = [
    "AsyncDatabase",
    "Database",
]
24 |
--------------------------------------------------------------------------------
/astrapy/exceptions/collection_exceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from dataclasses import dataclass
18 | from typing import TYPE_CHECKING, Any, Sequence
19 |
20 | from astrapy.exceptions.data_api_exceptions import DataAPIException
21 |
22 | if TYPE_CHECKING:
23 | from astrapy.results import (
24 | CollectionDeleteResult,
25 | CollectionUpdateResult,
26 | )
27 |
28 |
@dataclass
class TooManyDocumentsToCountException(DataAPIException):
    """
    A `count_documents()` operation on a collection failed because the resulting
    number of documents exceeded either the upper bound set by the caller or the
    hard limit imposed by the Data API.

    Attributes:
        text: a text message about the exception.
        server_max_count_exceeded: True if the count limit imposed by the API
            is reached. In that case, increasing the upper bound in the method
            invocation is of no help.
    """

    text: str
    server_max_count_exceeded: bool

    def __init__(
        self,
        text: str,
        *,
        server_max_count_exceeded: bool,
    ) -> None:
        # @dataclass does not generate __init__ when the class defines one, so
        # this hand-written constructor (with a keyword-only flag) is the one used.
        # The message is forwarded to the base exception class as well.
        super().__init__(text)
        self.text = text
        self.server_max_count_exceeded = server_max_count_exceeded
55 |
56 |
@dataclass
class CollectionInsertManyException(DataAPIException):
    """
    Raised by an insert_many, an operation that can span several requests.
    It carries both the root error(s) that happened and the IDs of the
    documents that were successfully inserted.

    Because of how insert_many works (concurrency and the `ordered` setting),
    more than one root error may be collected.

    Attributes:
        inserted_ids: a list of the document IDs that have been successfully inserted.
        exceptions: a list of the root exceptions leading to this error. The list,
            under normal circumstances, is not empty.
    """

    inserted_ids: list[Any]
    exceptions: Sequence[Exception]

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        inserted_count = len(self.inserted_ids)
        if not self.exceptions:
            return f"{class_name}()"
        # At most 8 root errors are spelled out, to keep the message bounded.
        shown = ", ".join(str(exc) for exc in self.exceptions[:8])
        if len(self.exceptions) > 8:
            shown += " ... (more exceptions)"
        return f"{class_name}({shown} [with {inserted_count} inserted ids])"
90 |
91 |
@dataclass
class CollectionDeleteManyException(DataAPIException):
    """
    Raised by a delete_many, an operation that can span several requests.
    Besides the root-cause error, it may carry a partial result describing
    the part of the operation that succeeded.

    Attributes:
        partial_result: a CollectionDeleteResult object, just like the one that would
            be the return value of the operation, had it succeeded completely.
        cause: a root exception that happened during the delete_many, causing
            the method call to stop and raise this error.
    """

    partial_result: CollectionDeleteResult
    cause: Exception

    def __str__(self) -> str:
        cause_text = str(self.cause)
        return f"{type(self).__name__}({cause_text})"
111 |
112 |
@dataclass
class CollectionUpdateManyException(DataAPIException):
    """
    Raised by an update_many, an operation that can span several requests.
    Besides the root-cause error, it may carry a partial result describing
    the part of the operation that succeeded.

    Attributes:
        partial_result: a CollectionUpdateResult object, just like the one that would
            be the return value of the operation, had it succeeded completely.
        cause: a root exception that happened during the update_many, causing
            the method call to stop and raise this error.
    """

    partial_result: CollectionUpdateResult
    cause: Exception

    def __str__(self) -> str:
        cause_text = str(self.cause)
        return f"{type(self).__name__}({cause_text})"
132 |
--------------------------------------------------------------------------------
/astrapy/exceptions/table_exceptions.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from dataclasses import dataclass
18 | from typing import Any, Sequence
19 |
20 | from astrapy.exceptions.data_api_exceptions import DataAPIException
21 |
22 |
@dataclass
class TooManyRowsToCountException(DataAPIException):
    """
    A `count_documents()` operation on a table failed because of the excessive amount
    of rows to count.

    Attributes:
        text: a text message about the exception.
        server_max_count_exceeded: whether the server-side maximum count was
            exceeded (meaning inferred from the name; confirm at the raise sites).
    """

    text: str
    server_max_count_exceeded: bool

    def __init__(
        self,
        text: str,
        *,
        server_max_count_exceeded: bool,
    ) -> None:
        # @dataclass does not generate __init__ when the class defines one, so
        # this hand-written constructor (with a keyword-only flag) is the one used.
        # The message is forwarded to the base exception class as well.
        super().__init__(text)
        self.text = text
        self.server_max_count_exceeded = server_max_count_exceeded
45 |
46 |
@dataclass
class TableInsertManyException(DataAPIException):
    """
    An exception occurring within an insert_many (an operation that can span
    several requests). As such, it represents both the root error(s) that happened
    and information on the portion of the rows that were successfully inserted.

    The behaviour of insert_many (concurrency and the `ordered` setting) makes it
    possible that more than one "root error" is collected.

    Attributes:
        inserted_ids: a list of the row IDs that have been successfully inserted,
            each in the form of a dictionary matching the table primary key.
        inserted_id_tuples: the same information as for `inserted_ids` (in the same
            order), but in the form of a tuple for each ID.
        exceptions: a list of the root exceptions leading to this error. The list,
            under normal circumstances, is not empty.
    """

    inserted_ids: list[Any]
    inserted_id_tuples: list[tuple[Any, ...]]
    exceptions: Sequence[Exception]

    def __str__(self) -> str:
        class_name = self.__class__.__name__
        inserted_count = len(self.inserted_ids)
        if not self.exceptions:
            return f"{class_name}()"
        # At most 8 root errors are spelled out, to keep the message bounded.
        shown = ", ".join(str(exc) for exc in self.exceptions[:8])
        if len(self.exceptions) > 8:
            shown += " ... (more exceptions)"
        return f"{class_name}({shown} [with {inserted_count} inserted ids])"
84 |
--------------------------------------------------------------------------------
/astrapy/ids.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from uuid import UUID, uuid1, uuid3, uuid4, uuid5 # noqa: F401
18 |
19 | from bson.objectid import ObjectId # noqa: F401
20 | from uuid6 import uuid6, uuid7, uuid8 # noqa: F401
21 |
# ID-related names re-exported from the stdlib `uuid` module (UUID, uuid1,
# uuid3, uuid4, uuid5), from `bson.objectid` (ObjectId) and from `uuid6`
# (uuid6, uuid7, uuid8).
__all__ = [
    "ObjectId",
    "uuid1",
    "uuid3",
    "uuid4",
    "uuid5",
    "uuid6",
    "uuid7",
    "uuid8",
    "UUID",
]
33 |
--------------------------------------------------------------------------------
/astrapy/info.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.data.info.collection_descriptor import (
18 | CollectionDefaultIDOptions,
19 | CollectionDefinition,
20 | CollectionDescriptor,
21 | CollectionInfo,
22 | CollectionLexicalOptions,
23 | CollectionRerankOptions,
24 | CollectionVectorOptions,
25 | )
26 | from astrapy.data.info.database_info import (
27 | AstraDBAdminDatabaseInfo,
28 | AstraDBAvailableRegionInfo,
29 | AstraDBDatabaseInfo,
30 | )
31 | from astrapy.data.info.reranking import (
32 | FindRerankingProvidersResult,
33 | RerankingAPIModelSupport,
34 | RerankingProvider,
35 | RerankingProviderAuthentication,
36 | RerankingProviderModel,
37 | RerankingProviderParameter,
38 | RerankingProviderToken,
39 | RerankServiceOptions,
40 | )
41 | from astrapy.data.info.table_descriptor.table_altering import (
42 | AlterTableAddColumns,
43 | AlterTableAddVectorize,
44 | AlterTableDropColumns,
45 | AlterTableDropVectorize,
46 | )
47 | from astrapy.data.info.table_descriptor.table_columns import (
48 | TableAPISupportDescriptor,
49 | TableKeyValuedColumnTypeDescriptor,
50 | TablePassthroughColumnTypeDescriptor,
51 | TablePrimaryKeyDescriptor,
52 | TableScalarColumnTypeDescriptor,
53 | TableUnsupportedColumnTypeDescriptor,
54 | TableValuedColumnTypeDescriptor,
55 | TableVectorColumnTypeDescriptor,
56 | )
57 | from astrapy.data.info.table_descriptor.table_creation import (
58 | CreateTableDefinition,
59 | )
60 | from astrapy.data.info.table_descriptor.table_indexes import (
61 | TableAPIIndexSupportDescriptor,
62 | TableBaseIndexDefinition,
63 | TableIndexDefinition,
64 | TableIndexDescriptor,
65 | TableIndexOptions,
66 | TableIndexType,
67 | TableUnsupportedIndexDefinition,
68 | TableVectorIndexDefinition,
69 | TableVectorIndexOptions,
70 | )
71 | from astrapy.data.info.table_descriptor.table_listing import (
72 | ListTableDefinition,
73 | ListTableDescriptor,
74 | TableInfo,
75 | )
76 | from astrapy.data.info.vectorize import (
77 | EmbeddingAPIModelSupport,
78 | EmbeddingProvider,
79 | EmbeddingProviderAuthentication,
80 | EmbeddingProviderModel,
81 | EmbeddingProviderParameter,
82 | EmbeddingProviderToken,
83 | FindEmbeddingProvidersResult,
84 | VectorServiceOptions,
85 | )
86 | from astrapy.data.utils.table_types import (
87 | ColumnType,
88 | TableKeyValuedColumnType,
89 | TableValuedColumnType,
90 | )
91 |
92 | __all__ = [
93 | "AlterTableAddColumns",
94 | "AlterTableAddVectorize",
95 | "AlterTableDropColumns",
96 | "AlterTableDropVectorize",
97 | "AstraDBAdminDatabaseInfo",
98 | "AstraDBAvailableRegionInfo",
99 | "AstraDBDatabaseInfo",
100 | "CollectionDefaultIDOptions",
101 | "CollectionDefinition",
102 | "CollectionDescriptor",
103 | "CollectionInfo",
104 | "CollectionLexicalOptions",
105 | "CollectionRerankOptions",
106 | "CollectionVectorOptions",
107 | "ColumnType",
108 | "CreateTableDefinition",
109 | "EmbeddingAPIModelSupport",
110 | "EmbeddingProvider",
111 | "EmbeddingProviderAuthentication",
112 | "EmbeddingProviderModel",
113 | "EmbeddingProviderParameter",
114 | "EmbeddingProviderToken",
115 | "FindEmbeddingProvidersResult",
116 | "FindRerankingProvidersResult",
117 | "ListTableDefinition",
118 | "ListTableDescriptor",
119 | "RerankingAPIModelSupport",
120 | "RerankingProvider",
121 | "RerankingProviderAuthentication",
122 | "RerankingProviderModel",
123 | "RerankingProviderParameter",
124 | "RerankingProviderToken",
125 | "RerankServiceOptions",
126 | "TableAPIIndexSupportDescriptor",
127 | "TableAPISupportDescriptor",
128 | "TableBaseIndexDefinition",
129 | "TableIndexDefinition",
130 | "TableIndexDescriptor",
131 | "TableIndexOptions",
132 | "TableIndexType",
133 | "TableInfo",
134 | "TableKeyValuedColumnType",
135 | "TableKeyValuedColumnTypeDescriptor",
136 | "TablePassthroughColumnTypeDescriptor",
137 | "TablePrimaryKeyDescriptor",
138 | "TableScalarColumnTypeDescriptor",
139 | "TableUnsupportedColumnTypeDescriptor",
140 | "TableUnsupportedIndexDefinition",
141 | "TableValuedColumnType",
142 | "TableValuedColumnTypeDescriptor",
143 | "TableVectorColumnTypeDescriptor",
144 | "TableVectorIndexDefinition",
145 | "TableVectorIndexOptions",
146 | "VectorServiceOptions",
147 | ]
148 |
--------------------------------------------------------------------------------
/astrapy/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/astrapy/py.typed
--------------------------------------------------------------------------------
/astrapy/settings/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/astrapy/settings/defaults.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import datetime
18 |
# whether to go the extra mile and ensure the "decimal escape trick" is
# not colliding with legitimate user-provided (extremely rare) strings.
# A system setting, off at this time.
CHECK_DECIMAL_ESCAPING_CONSISTENCY = False

# Environment names for management internal to astrapy.
# The first three (prod/dev/test) are the Astra environments: they are the
# only keys of the endpoint-template and DevOps maps below.
DATA_API_ENVIRONMENT_PROD = "prod"
DATA_API_ENVIRONMENT_DEV = "dev"
DATA_API_ENVIRONMENT_TEST = "test"
DATA_API_ENVIRONMENT_DSE = "dse"
DATA_API_ENVIRONMENT_HCD = "hcd"
DATA_API_ENVIRONMENT_CASSANDRA = "cassandra"
DATA_API_ENVIRONMENT_OTHER = "other"

# Defaults/settings for Database management
DEFAULT_ASTRA_DB_KEYSPACE = "default_keyspace"
# API endpoint URL templates (Astra environments only), to be filled with
# the `database_id` and `region` placeholders.
API_ENDPOINT_TEMPLATE_ENV_MAP = {
    DATA_API_ENVIRONMENT_PROD: "https://{database_id}-{region}.apps.astra.datastax.com",
    DATA_API_ENVIRONMENT_DEV: "https://{database_id}-{region}.apps.astra-dev.datastax.com",
    DATA_API_ENVIRONMENT_TEST: "https://{database_id}-{region}.apps.astra-test.datastax.com",
}
# Data API path per environment: "/api/json" on Astra, empty elsewhere.
API_PATH_ENV_MAP = {
    DATA_API_ENVIRONMENT_PROD: "/api/json",
    DATA_API_ENVIRONMENT_DEV: "/api/json",
    DATA_API_ENVIRONMENT_TEST: "/api/json",
    # non-Astra environments
    DATA_API_ENVIRONMENT_DSE: "",
    DATA_API_ENVIRONMENT_HCD: "",
    DATA_API_ENVIRONMENT_CASSANDRA: "",
    DATA_API_ENVIRONMENT_OTHER: "",
}
# Data API version segment per environment.
# NOTE(review): the Astra entries carry a leading "/" while the non-Astra
# ones do not; presumably the URL-building code joins these differently
# per environment -- confirm there before normalizing.
API_VERSION_ENV_MAP = {
    DATA_API_ENVIRONMENT_PROD: "/v1",
    DATA_API_ENVIRONMENT_DEV: "/v1",
    DATA_API_ENVIRONMENT_TEST: "/v1",
    # non-Astra environments
    DATA_API_ENVIRONMENT_DSE: "v1",
    DATA_API_ENVIRONMENT_HCD: "v1",
    DATA_API_ENVIRONMENT_CASSANDRA: "v1",
    DATA_API_ENVIRONMENT_OTHER: "v1",
}

# Defaults/settings for Data API requests
DEFAULT_USE_DECIMALS_IN_COLLECTIONS = False
DEFAULT_BINARY_ENCODE_VECTORS = True
DEFAULT_CUSTOM_DATATYPES_IN_READING = True
DEFAULT_UNROLL_ITERABLES_TO_LISTS = False
DEFAULT_ENCODE_MAPS_AS_LISTS_IN_TABLES = "NEVER"  # later coerced as MapEncodingMode

DEFAULT_ACCEPT_NAIVE_DATETIMES = False
DEFAULT_DATETIME_TZINFO = datetime.timezone.utc

DEFAULT_INSERT_MANY_CHUNK_SIZE = 50
DEFAULT_INSERT_MANY_CONCURRENCY = 20
DEFAULT_REQUEST_TIMEOUT_MS = 10000
DEFAULT_GENERAL_METHOD_TIMEOUT_MS = 30000
DEFAULT_COLLECTION_ADMIN_TIMEOUT_MS = 60000
DEFAULT_TABLE_ADMIN_TIMEOUT_MS = 30000
# Header names used for authentication / third-party provider credentials.
DEFAULT_DATA_API_AUTH_HEADER = "Token"
EMBEDDING_HEADER_AWS_ACCESS_ID = "X-Embedding-Access-Id"
EMBEDDING_HEADER_AWS_SECRET_ID = "X-Embedding-Secret-Id"
EMBEDDING_HEADER_API_KEY = "X-Embedding-Api-Key"
RERANKING_HEADER_API_KEY = "Reranking-Api-Key"

# Defaults/settings for DevOps API requests and admin operations
DEFAULT_DEV_OPS_AUTH_HEADER = "Authorization"
DEFAULT_DEV_OPS_AUTH_PREFIX = "Bearer "
DEV_OPS_KEYSPACE_POLL_INTERVAL_S = 2
DEV_OPS_DATABASE_POLL_INTERVAL_S = 15
DEFAULT_DATABASE_ADMIN_TIMEOUT_MS = 600000
DEFAULT_KEYSPACE_ADMIN_TIMEOUT_MS = 30000
# Database status strings.
DEV_OPS_DATABASE_STATUS_MAINTENANCE = "MAINTENANCE"
DEV_OPS_DATABASE_STATUS_ACTIVE = "ACTIVE"
DEV_OPS_DATABASE_STATUS_PENDING = "PENDING"
DEV_OPS_DATABASE_STATUS_INITIALIZING = "INITIALIZING"
DEV_OPS_DATABASE_STATUS_ERROR = "ERROR"
DEV_OPS_DATABASE_STATUS_TERMINATING = "TERMINATING"
# DevOps API base URL and version, defined for the Astra environments only.
DEV_OPS_URL_ENV_MAP = {
    DATA_API_ENVIRONMENT_PROD: "https://api.astra.datastax.com",
    DATA_API_ENVIRONMENT_DEV: "https://api.dev.cloud.datastax.com",
    DATA_API_ENVIRONMENT_TEST: "https://api.test.cloud.datastax.com",
}
DEV_OPS_VERSION_ENV_MAP = {
    DATA_API_ENVIRONMENT_PROD: "v2",
    DATA_API_ENVIRONMENT_DEV: "v2",
    DATA_API_ENVIRONMENT_TEST: "v2",
}
# HTTP status codes 202 Accepted / 201 Created.
DEV_OPS_RESPONSE_HTTP_ACCEPTED = 202
DEV_OPS_RESPONSE_HTTP_CREATED = 201
DEV_OPS_DEFAULT_DATABASES_PAGE_SIZE = 50

# Settings for redacting secrets in string representations and logging
SECRETS_REDACT_ENDING = "..."
SECRETS_REDACT_CHAR = "*"
SECRETS_REDACT_ENDING_LENGTH = 3
FIXED_SECRET_PLACEHOLDER = "***"
# Header names whose values are redacted.
# NOTE(review): these are mixed-case strings in a plain set, so membership is
# case-sensitive; confirm the redaction code normalizes header-name case.
DEFAULT_REDACTED_HEADER_NAMES = {
    DEFAULT_DATA_API_AUTH_HEADER,
    DEFAULT_DEV_OPS_AUTH_HEADER,
    EMBEDDING_HEADER_AWS_ACCESS_ID,
    EMBEDDING_HEADER_AWS_SECRET_ID,
    EMBEDDING_HEADER_API_KEY,
    RERANKING_HEADER_API_KEY,
}
123 |
--------------------------------------------------------------------------------
/astrapy/settings/error_messages.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
# Error text for the (presumed) failure when a datetime lacking `tzinfo` is
# encoded while naive datetimes are not accepted; it points users to the
# setting that relaxes the check.
CANNOT_ENCODE_NAIVE_DATETIME_ERROR_MESSAGE = (
    "Cannot encode a datetime without timezone information ('tzinfo'). "
    "See the APIOptions.SerdesOptions.accept_naive_datetimes setting "
    "if you want to relax this write-time safeguard."
)
22 |
--------------------------------------------------------------------------------
/astrapy/table.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from astrapy.data.table import (
16 | AsyncTable,
17 | Table,
18 | )
19 |
20 | __all__ = [
21 | "AsyncTable",
22 | "Table",
23 | ]
24 |
--------------------------------------------------------------------------------
/astrapy/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/astrapy/utils/meta.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import warnings
18 | from functools import wraps
19 | from typing import Callable, TypeVar
20 |
21 | from deprecation import DeprecatedWarning
22 | from typing_extensions import ParamSpec # compatible with pre-3.10 Python
23 |
# Type parameters used by the decorators in this module to preserve the
# signature of the wrapped callables: P for the parameters, R for the result.
P = ParamSpec("P")
R = TypeVar("R")

# Message template emitted by `beta_method`; `{method_name}` is filled with
# the qualified name of the decorated callable.
BETA_WARNING_TEMPLATE = (
    "Method '{method_name}' is in beta and might undergo signature "
    "or behaviour changes in the future."
)
31 |
32 |
class BetaFeatureWarning(UserWarning):
    """Warning category emitted by `beta_method`-decorated callables."""
35 |
36 |
def check_deprecated_alias(
    new_name: str | None,
    deprecated_name: str | None,
    *,
    new_param_name: str = "new_name",
    deprecated_param_name: str = "deprecated_name",
    deprecated_in: str = "2.0.0",
    removed_in: str = "3.0.0",
) -> str | None:
    """Generic blueprint utility for deprecating parameters through an alias.

    Normalize the values of two aliased parameters. A deprecation warning is
    issued whenever the deprecated alias is supplied; an error is raised if
    both are supplied.

    Args:
        new_name: the value passed for the new (preferred) parameter, or None.
        deprecated_name: the value passed for the deprecated alias, or None.
        new_param_name: the name of the new parameter, as shown in the
            messages. Defaults to the generic placeholder "new_name".
        deprecated_param_name: the name of the deprecated alias, as shown in
            the messages. Defaults to the generic placeholder "deprecated_name".
        deprecated_in: the version in which the alias was deprecated.
        removed_in: the version in which the alias is scheduled for removal.

    Returns:
        The final value for the parameter. This is `new_name` if the alias was
        not supplied, otherwise `deprecated_name`.

    Raises:
        ValueError: if both `new_name` and `deprecated_name` are not None.
    """

    if deprecated_name is None:
        # no need for deprecation nor exceptions
        return new_name

    # issue a deprecation warning; stacklevel=3 skips this helper and the
    # function delegating to it, attributing the warning to the end caller
    the_warning = DeprecatedWarning(
        f"Parameter '{deprecated_param_name}'",
        deprecated_in=deprecated_in,
        removed_in=removed_in,
        details=f"Please use '{new_param_name}' instead.",
    )
    warnings.warn(
        the_warning,
        stacklevel=3,
    )

    if new_name is None:
        return deprecated_name

    msg = (
        f"Parameters `{new_param_name}` and `{deprecated_param_name}` "
        "(a deprecated alias for the former) cannot be passed at the same time."
    )
    raise ValueError(msg)
72 |
73 |
def deprecated_property(
    new_name: str, deprecated_in: str, removed_in: str
) -> Callable[[Callable[P, R]], Callable[P, R]]:
    """
    Build a decorator marking a @property getter as a deprecated alias.

    Every invocation of the decorated getter emits a `DeprecatedWarning`
    attributed to the caller's frame, then returns the getter's result.

    Args:
        new_name: name of the attribute that users should switch to.
        deprecated_in: version in which the alias became deprecated.
        removed_in: version in which the alias is slated for removal.

    Returns:
        A decorator that wraps the getter while preserving its metadata.
    """
    replacement_hint = f"Please use '{new_name}' instead."

    def _decorate(getter: Callable[P, R]) -> Callable[P, R]:
        @wraps(getter)
        def _warn_then_get(*args: P.args, **kwargs: P.kwargs) -> R:
            warnings.warn(
                DeprecatedWarning(
                    f"Property '{getter.__name__}'",
                    deprecated_in=deprecated_in,
                    removed_in=removed_in,
                    details=replacement_hint,
                ),
                stacklevel=2,
            )
            return getter(*args, **kwargs)

        return _warn_then_get

    return _decorate
99 |
100 |
def beta_method(method: Callable[P, R]) -> Callable[P, R]:
    """
    Decorate a callable so that each call emits a `BetaFeatureWarning`.

    The warning text is `BETA_WARNING_TEMPLATE` filled with the callable's
    qualified name. It is attributed to the caller's frame (stacklevel=2).
    After warning, the original callable is invoked and its result returned.
    """

    @wraps(method)
    def _beta_wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        message = BETA_WARNING_TEMPLATE.format(method_name=method.__qualname__)
        warnings.warn(message, category=BetaFeatureWarning, stacklevel=2)
        return method(*args, **kwargs)

    return _beta_wrapper
115 |
--------------------------------------------------------------------------------
/astrapy/utils/parsing.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import warnings
18 | from typing import Any
19 |
20 |
21 | def _warn_residual_keys(
22 | klass: type, raw_dict: dict[str, Any], known_keys: set[str]
23 | ) -> None:
24 | residual_keys = raw_dict.keys() - known_keys
25 | if residual_keys:
26 | warnings.warn(
27 | "Unexpected key(s) encountered parsing a dictionary into "
28 | f"a `{klass.__name__}`: '{','.join(sorted(residual_keys))}'"
29 | )
30 |
--------------------------------------------------------------------------------
/astrapy/utils/request_tools.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import logging
18 | from typing import Any
19 |
20 | import httpx
21 |
22 | from astrapy.exceptions import _TimeoutContext
23 |
logger = logging.getLogger(__name__)


def log_httpx_request(
    http_method: str,
    full_url: str,
    request_params: dict[str, Any] | None,
    redacted_request_headers: dict[str, str],
    encoded_payload: str | None,
    timeout_context: _TimeoutContext,
) -> None:
    """
    Log the details of an HTTP request for debugging purposes.

    All output goes to this module's logger at DEBUG level. If DEBUG is not
    enabled, the function returns immediately without building any message.

    Args:
        http_method: the HTTP verb of the request (e.g. "POST").
        full_url: the URL of the request (e.g. "https://domain.com/full/path").
        request_params: parameters of the request; logged only if non-empty.
        redacted_request_headers: caution, as these will be logged as they are.
        encoded_payload: the already-encoded (string) payload sent with the
            request, if any; logged whenever it is not None.
        timeout_context: the timeout settings for the request. Its `request_ms`
            and `nominal_ms` values are logged in milliseconds, with
            '(unset)' shown for falsy values.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        # avoid formatting (possibly large) payloads that would be discarded
        return
    logger.debug(f"Request URL: {http_method} {full_url}")
    if request_params:
        logger.debug(f"Request params: '{request_params}'")
    if redacted_request_headers:
        logger.debug(f"Request headers: '{redacted_request_headers}'")
    if encoded_payload is not None:
        logger.debug(f"Request payload: '{encoded_payload}'")
    if timeout_context:
        logger.debug(
            f"Timeout (ms): for request {timeout_context.request_ms or '(unset)'} ms"
            f", overall operation {timeout_context.nominal_ms or '(unset)'} ms"
        )


def log_httpx_response(response: httpx.Response) -> None:
    """
    Log the details of an httpx.Response at DEBUG level.

    Returns immediately when DEBUG is disabled. In that case `response.text`
    and the headers are not evaluated at all.

    Args:
        response: the httpx.Response object to log.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    logger.debug(f"Response status code: {response.status_code}")
    logger.debug(f"Response headers: '{response.headers}'")
    logger.debug(f"Response text: '{response.text}'")
70 |
71 |
class HttpMethod:
    """String constants naming the HTTP request methods."""

    GET: str = "GET"
    POST: str = "POST"
    PUT: str = "PUT"
    PATCH: str = "PATCH"
    DELETE: str = "DELETE"
78 |
79 |
def to_httpx_timeout(timeout_context: _TimeoutContext) -> httpx.Timeout | None:
    """
    Convert the per-request timeout of a timeout context into an httpx.Timeout.

    A `request_ms` of None or 0 yields None. Any other value is converted from
    milliseconds to seconds and wrapped in `httpx.Timeout`.
    """
    per_request_ms = timeout_context.request_ms
    if per_request_ms is not None and per_request_ms != 0:
        return httpx.Timeout(per_request_ms / 1000)
    return None
85 |
--------------------------------------------------------------------------------
/astrapy/utils/str_enum.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from enum import Enum, EnumMeta
18 | from typing import TypeVar
19 |
# Type variable letting `StrEnum.coerce` return the concrete subclass type.
T = TypeVar("T", bound="StrEnum")


class StrEnumMeta(EnumMeta):
    """Enum metaclass adding lenient string lookup and membership tests."""

    def _name_lookup(cls, value: str) -> str | None:
        """
        Return a proper key in the enum if some matching logic works, or None.

        Matching is attempted in this order:
        1. exact member name;
        2. member name, ignoring case;
        3. member value, ignoring case.
        """
        mmap = {k: v.value for k, v in cls._member_map_.items()}
        # try exact key match
        if value in mmap:
            return value
        # try case-insensitive key match
        u_value = value.upper()
        u_mmap = {k.upper(): k for k in mmap.keys()}
        if u_value in u_mmap:
            return u_mmap[u_value]
        # try case-insensitive *value* match (member values must be strings)
        v_mmap = {v.upper(): k for k, v in mmap.items()}
        if u_value in v_mmap:
            return v_mmap[u_value]
        return None

    def __contains__(cls, value: object) -> bool:
        """Return True if the provided string matches a member (see _name_lookup)."""
        if isinstance(value, str):
            return cls._name_lookup(value) is not None
        return False


class StrEnum(Enum, metaclass=StrEnumMeta):
    """Enum whose members can be obtained from loosely-matching strings."""

    @classmethod
    def coerce(cls: type[T], value: str | T) -> T:
        """
        Accepts either a string or an instance of the Enum itself.

        If a string is passed, it converts it to the corresponding
        Enum member. Matching is by name or value, case-insensitive.
        If an Enum instance is passed, it returns it as-is.

        Raises:
            ValueError: if the string does not match any enum member, or if
                `value` is neither a string nor a member of this enum.
        """

        if isinstance(value, cls):
            return value
        if isinstance(value, str):
            norm_value = cls._name_lookup(value)
            if norm_value is not None:
                return cls[norm_value]
        # no matches (or unsupported input type): single shared error path
        raise ValueError(
            f"Invalid value '{value}' for {cls.__name__}. "
            f"Allowed values are: {[e.value for e in cls]}"
        )
74 |
--------------------------------------------------------------------------------
/astrapy/utils/unset.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 |
18 | class UnsetType:
19 | _instance = None
20 |
21 | def __new__(cls) -> UnsetType:
22 | if cls._instance is None:
23 | cls._instance = super().__new__(cls)
24 | return cls._instance
25 |
26 | def __repr__(self) -> str:
27 | return "(unset)"
28 |
29 |
30 | _UNSET = UnsetType()
31 |
--------------------------------------------------------------------------------
/astrapy/utils/user_agents.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from typing import Sequence
18 |
19 | from astrapy import __version__
20 | from astrapy.constants import CallerType
21 |
22 |
def detect_astrapy_user_agent() -> CallerType:
    """Return the (top-level package name, package version) caller pair."""
    top_package, _sep, _rest = __name__.partition(".")
    return (top_package, __version__)


def compose_user_agent_string(
    caller_name: str | None, caller_version: str | None
) -> str | None:
    """
    Format a single caller as a User-Agent token.

    Returns "name/version"; just "name" when the version is empty or None;
    None when the name is empty or None.
    """
    if not caller_name:
        return None
    if not caller_version:
        return f"{caller_name}"
    return f"{caller_name}/{caller_version}"


def compose_full_user_agent(callers: Sequence[CallerType]) -> str | None:
    """
    Join the User-Agent tokens of all callers with single spaces.

    Callers whose token is empty or None are skipped. None is returned when
    no caller contributes a token.
    """
    tokens = []
    for caller in callers:
        token = compose_user_agent_string(caller[0], caller[1])
        if token:
            tokens.append(token)
    if not tokens:
        return None
    return " ".join(tokens)
52 |
--------------------------------------------------------------------------------
/pictures/astrapy_abstractions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/pictures/astrapy_abstractions.png
--------------------------------------------------------------------------------
/pictures/astrapy_datetime_serdes_options.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/pictures/astrapy_datetime_serdes_options.png
--------------------------------------------------------------------------------
/pictures/astrapy_exceptions.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/datastax/astrapy/ffc693317399673632b72baf3ccbeb52a20f1eec/pictures/astrapy_exceptions.png
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | requires-python = ">=3.8,<4.0"
3 | name = "astrapy"
4 | version = "2.0.1"
5 | description = "A Python client for the Data API on DataStax Astra DB"
6 | authors = [
7 | {"name" = "Stefano Lottini", "email" = "stefano.lottini@datastax.com"},
8 | {"name" = "Eric Hare", "email" = "eric.hare@datastax.com"},
9 | ]
10 | readme = "README.md"
11 |
12 | keywords = ["DataStax", "Astra DB", "Astra"]
13 | dependencies = [
14 | "deprecation ~= 2.1.0",
15 | "httpx[http2]>=0.25.2,<1",
16 | "h11 >= 0.16.0",
17 | "pymongo >= 3",
18 | "toml >= 0.10.2,<1.0.0",
19 | "typing-extensions >= 4.0",
20 | "uuid6 >= 2024.1.12"
21 | ]
22 | classifiers = [
23 | "Development Status :: 5 - Production/Stable",
24 | "Intended Audience :: Developers",
25 | "License :: OSI Approved :: Apache Software License",
26 | "Programming Language :: Python :: 3.8",
27 | "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | "Programming Language :: Python :: 3.11",
30 | "Programming Language :: Python :: 3.12",
31 | "Topic :: Software Development :: Build Tools"
32 | ]
33 |
34 | [project.urls]
35 | Homepage = "https://github.com/datastax/astrapy"
36 | Documentation = "https://docs.datastax.com/en/astra-db-serverless/api-reference/dataapiclient.html"
37 | Repository = "https://github.com/datastax/astrapy"
38 | Issues = "https://github.com/datastax/astrapy/issues"
39 | Changelog = "https://github.com/datastax/astrapy/blob/main/CHANGES"
40 |
41 | [dependency-groups]
42 | dev = [
43 | "blockbuster ~= 1.5.5",
44 | "build >= 1.0.0",
45 | "cassio ~= 0.1.10; python_version >= '3.9'",
46 | "faker ~= 23.1.0",
47 | "mypy ~= 1.9.0",
48 | "ruff >= 0.11.9,<0.12",
49 | "pre-commit ~= 3.5.0",
50 | "pytest ~= 8.0.0",
51 | "pytest-asyncio ~= 0.23.5",
52 | "pytest-cov ~= 4.1.0",
53 | "pytest-testdox ~= 3.1.0",
54 | "python-dotenv ~= 1.0.1",
55 | "pytest-httpserver ~= 1.0.8",
56 | "setuptools >= 61.0",
57 | "testcontainers ~= 3.7.1",
58 | "types-toml >= 0.10.8.7,<1.0.0"
59 | ]
60 |
61 | [tool.hatch.build.targets.wheel]
62 | packages = ["astrapy"]
63 |
64 | [build-system]
65 | requires = ["hatchling"]
66 | build-backend = "hatchling.build"
67 |
68 | [tool.setuptools.packages.find]
69 | include = ["astrapy*"]
70 |
71 | [tool.ruff.lint]
72 | select = ["E4", "E7", "E9", "F", "FA", "I", "UP"]
73 |
74 | [tool.mypy]
75 | disallow_any_generics = true
76 | disallow_incomplete_defs = true
77 | disallow_untyped_calls = true
78 | disallow_untyped_decorators = true
79 | disallow_untyped_defs = true
80 | follow_imports = "normal"
81 | ignore_missing_imports = true
82 | no_implicit_reexport = true
83 | show_error_codes = true
84 | show_error_context = true
85 | strict_equality = true
86 | strict_optional = true
87 | warn_redundant_casts = true
88 | warn_return_any = true
89 | warn_unused_ignores = true
90 |
91 | [tool.pytest.ini_options]
92 | filterwarnings = "ignore::DeprecationWarning"
93 | addopts = "-v --cov=astrapy --testdox --cov-report term-missing"
94 | asyncio_mode = "auto"
95 | log_cli = 1
96 | log_cli_level = "INFO"
97 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/admin/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/admin/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from ..conftest import (
18 | ADMIN_ENV_LIST,
19 | ADMIN_ENV_VARIABLE_MAP,
20 | IS_ASTRA_DB,
21 | )
22 |
# Explicit export list for the names imported above from the parent conftest:
# with mypy's ``no_implicit_reexport`` enabled, imported names count as
# re-exported only when listed here.
__all__ = ["ADMIN_ENV_LIST", "ADMIN_ENV_VARIABLE_MAP", "IS_ASTRA_DB"]
28 |
--------------------------------------------------------------------------------
/tests/admin/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/base/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/base/collection_decimal_support_assets.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from decimal import Decimal
18 | from typing import Any
19 |
20 | from astrapy.utils.api_options import FullSerdesOptions
21 |
# Serdes settings shared by the two option sets below, which differ only in
# ``use_decimals_in_collections``.
_COMMON_SERDES_KWARGS: dict[str, Any] = {
    "binary_encode_vectors": False,
    "custom_datatypes_in_reading": True,
    "unroll_iterables_to_lists": True,
    "encode_maps_as_lists_in_tables": "NEVER",
    "accept_naive_datetimes": False,
    "datetime_tzinfo": None,
}

# Decimal support in collections switched off.
S_OPTS_NO_DECS = FullSerdesOptions(
    use_decimals_in_collections=False,
    **_COMMON_SERDES_KWARGS,
)
# Decimal support in collections switched on.
S_OPTS_OK_DECS = FullSerdesOptions(
    use_decimals_in_collections=True,
    **_COMMON_SERDES_KWARGS,
)


def _nest_scalars(scalars: dict[str, Any]) -> dict[str, Any]:
    """
    Build a test document out of a flat dict of scalars.

    The result holds the scalars at top level, the very same dict object
    under "subdict", and the dict's values (as a list) under "sublist", so
    that nested dicts and lists get exercised as well.
    """
    return {
        **scalars,
        "subdict": scalars,
        "sublist": list(scalars.values()),
    }


# Plain scalars only, no Decimal anywhere.
_BASELINE_SCALAR_CASES = {"_id": "baseline", "f": 1.23, "i": 123, "t": "T"}
# The same scalars (with a different "_id", kept in first position) followed
# by Decimal values: an exact decimal literal, a Decimal built from the float
# 1.23, and a Decimal given as a long digit string.
_W_DECIMALS_SCALAR_CASES = {
    **_BASELINE_SCALAR_CASES,
    "_id": "decimals",
    "de": Decimal("1.23"),
    "dfA": Decimal(1.23),
    "dfB": Decimal("1.229999999999999982236431605997495353221893310546875"),
}
BASELINE_OBJ = _nest_scalars(_BASELINE_SCALAR_CASES)
WDECS_OBJ = _nest_scalars(_W_DECIMALS_SCALAR_CASES)
65 |
66 |
def is_decimal_super(more_decs: Any, less_decs: Any) -> bool:
    """
    Return True if the first item is "the same values, possibly made Decimal
    where the corresponding second item can be another number (int/float)".

    Lists and dicts are compared recursively: lists must have the same length
    and dicts the same key set. At the scalar level:

    - a Decimal in ``more_decs`` matches an equal Decimal, or an int/float in
      ``less_decs`` with the same float value;
    - a non-Decimal in ``more_decs`` never matches a Decimal in ``less_decs``,
      and otherwise must compare equal with ``==``.

    A Decimal facing a non-numeric value (e.g. a string or None) simply yields
    False instead of raising.

    Args:
        more_decs: the structure that may contain Decimal values (e.g. a
            document read back with decimal support enabled).
        less_decs: the reference structure, possibly holding plain numbers.

    Returns:
        True if ``more_decs`` matches ``less_decs`` according to the rules above.
    """
    if isinstance(more_decs, list):
        if not isinstance(less_decs, list) or len(more_decs) != len(less_decs):
            return False
        # generator (not a list) so that the first mismatch stops the scan
        return all(
            is_decimal_super(v_more, v_less)
            for v_more, v_less in zip(more_decs, less_decs)
        )
    if isinstance(more_decs, dict):
        if not isinstance(less_decs, dict) or more_decs.keys() != less_decs.keys():
            return False
        return all(
            is_decimal_super(v_more, less_decs[k]) for k, v_more in more_decs.items()
        )
    # other scalars
    if isinstance(more_decs, Decimal):
        if isinstance(less_decs, Decimal):
            return more_decs == less_decs
        if isinstance(less_decs, (int, float)):
            # Decimal-vs-float equality is exact (Decimal("1.23") != 1.23),
            # hence the comparison is carried out between floats.
            return float(more_decs) == float(less_decs)
        # non-numeric counterpart: float() would raise, it is just a mismatch
        return False
    if isinstance(less_decs, Decimal):
        return False
    return bool(more_decs == less_decs)
102 |
--------------------------------------------------------------------------------
/tests/base/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/base/integration/collection_decimal_support_assets.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from ..collection_decimal_support_assets import (
18 | BASELINE_OBJ,
19 | S_OPTS_NO_DECS,
20 | S_OPTS_OK_DECS,
21 | WDECS_OBJ,
22 | is_decimal_super,
23 | )
24 |
# Explicit export list for the decimal-support assets imported above from the
# parent package: with mypy's ``no_implicit_reexport`` enabled, imported names
# count as re-exported only when listed here.
__all__ = [
    "BASELINE_OBJ", "S_OPTS_NO_DECS", "S_OPTS_OK_DECS", "WDECS_OBJ",
    "is_decimal_super",
]
32 |
--------------------------------------------------------------------------------
/tests/base/integration/collections/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/base/integration/collections/test_collection_decimal_support.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.api_options import APIOptions
20 |
21 | from ..collection_decimal_support_assets import (
22 | BASELINE_OBJ,
23 | S_OPTS_NO_DECS,
24 | S_OPTS_OK_DECS,
25 | WDECS_OBJ,
26 | is_decimal_super,
27 | )
28 | from ..conftest import DefaultAsyncCollection, DefaultCollection
29 |
30 |
class TestCollectionDecimalSupportIntegration:
    """
    Round-trip tests for Decimal values in collections, with decimal support
    turned off (S_OPTS_NO_DECS) and on (S_OPTS_OK_DECS), sync and async.
    """

    @pytest.mark.describe(
        "test of decimals not supported by default in collections, sync"
    )
    def test_decimalsupport_collections_defaultsettings_sync(
        self,
        sync_empty_collection: DefaultCollection,
    ) -> None:
        plain_coll = sync_empty_collection.with_options(
            api_options=APIOptions(serdes_options=S_OPTS_NO_DECS),
        )

        # a document without Decimal values survives the round trip unchanged
        plain_coll.insert_one(BASELINE_OBJ)
        read_back = plain_coll.find_one({"_id": BASELINE_OBJ["_id"]})
        assert read_back is not None
        assert read_back == BASELINE_OBJ

        # writing Decimal values is refused while decimal support is off
        with pytest.raises(TypeError):
            plain_coll.insert_one(WDECS_OBJ)

    @pytest.mark.describe(
        "test of decimals not supported by default in collections, async"
    )
    async def test_decimalsupport_collections_defaultsettings_async(
        self,
        async_empty_collection: DefaultAsyncCollection,
    ) -> None:
        plain_acoll = async_empty_collection.with_options(
            api_options=APIOptions(serdes_options=S_OPTS_NO_DECS),
        )

        # a document without Decimal values survives the round trip unchanged
        await plain_acoll.insert_one(BASELINE_OBJ)
        read_back = await plain_acoll.find_one({"_id": BASELINE_OBJ["_id"]})
        assert read_back is not None
        assert read_back == BASELINE_OBJ

        # writing Decimal values is refused while decimal support is off
        with pytest.raises(TypeError):
            await plain_acoll.insert_one(WDECS_OBJ)

    @pytest.mark.describe(
        "test of decimals supported in collections if set to do so, sync"
    )
    def test_decimalsupport_collections_decimalsettings_sync(
        self,
        sync_empty_collection: DefaultCollection,
    ) -> None:
        decimal_coll = sync_empty_collection.with_options(
            api_options=APIOptions(serdes_options=S_OPTS_OK_DECS),
        )

        # first the plain document, then the one with Decimals: each must read
        # back equal, up to numbers possibly coming back as Decimal
        for written in (BASELINE_OBJ, WDECS_OBJ):
            decimal_coll.insert_one(written)
            read_back = decimal_coll.find_one({"_id": written["_id"]})
            assert read_back is not None
            assert is_decimal_super(read_back, written)

    @pytest.mark.describe(
        "test of decimals supported in collections if set to do so, async"
    )
    async def test_decimalsupport_collections_decimalsettings_async(
        self,
        async_empty_collection: DefaultAsyncCollection,
    ) -> None:
        decimal_acoll = async_empty_collection.with_options(
            api_options=APIOptions(serdes_options=S_OPTS_OK_DECS),
        )

        # first the plain document, then the one with Decimals: each must read
        # back equal, up to numbers possibly coming back as Decimal
        for written in (BASELINE_OBJ, WDECS_OBJ):
            await decimal_acoll.insert_one(written)
            read_back = await decimal_acoll.find_one({"_id": written["_id"]})
            assert read_back is not None
            assert is_decimal_super(read_back, written)
123 |
--------------------------------------------------------------------------------
/tests/base/integration/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from ..conftest import (
18 | ADMIN_ENV_LIST,
19 | ADMIN_ENV_VARIABLE_MAP,
20 | CQL_AVAILABLE,
21 | HEADER_EMBEDDING_API_KEY_OPENAI,
22 | IS_ASTRA_DB,
23 | SECONDARY_KEYSPACE,
24 | TEST_COLLECTION_NAME,
25 | VECTORIZE_TEXTS,
26 | DataAPICredentials,
27 | DataAPICredentialsInfo,
28 | DefaultAsyncCollection,
29 | DefaultAsyncTable,
30 | DefaultCollection,
31 | DefaultTable,
32 | _repaint_NaNs,
33 | _typify_tuple,
34 | async_fail_if_not_removed,
35 | clean_nulls_from_dict,
36 | sync_fail_if_not_removed,
37 | )
38 |
# Explicit export list for the names imported above from the parent conftest:
# with mypy's ``no_implicit_reexport`` enabled, imported names count as
# re-exported only when listed here. The underscore-prefixed helpers also need
# to be listed for a star-import to pick them up.
__all__ = [
    "DataAPICredentials", "DataAPICredentialsInfo",
    "async_fail_if_not_removed", "clean_nulls_from_dict",
    "sync_fail_if_not_removed",
    "HEADER_EMBEDDING_API_KEY_OPENAI", "IS_ASTRA_DB",
    "ADMIN_ENV_LIST", "ADMIN_ENV_VARIABLE_MAP",
    "CQL_AVAILABLE", "SECONDARY_KEYSPACE", "TEST_COLLECTION_NAME",
    "VECTORIZE_TEXTS",
    "_repaint_NaNs", "_typify_tuple",
    "DefaultCollection", "DefaultAsyncCollection",
    "DefaultAsyncTable", "DefaultTable",
]
60 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_admin_ops_async.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy import DataAPIClient
20 |
21 | from ..conftest import IS_ASTRA_DB, DataAPICredentials
22 |
23 |
class TestFindRegionsAsync:
    """Checks on the admin's ``async_find_available_regions`` (Astra DB only)."""

    @pytest.mark.skipif(not IS_ASTRA_DB, reason="This test requires Astra DB")
    @pytest.mark.describe("test of find_available_regions, async")
    async def test_findavailableregions_async(
        self,
        client: DataAPIClient,
        data_api_credentials_kwargs: DataAPICredentials,
    ) -> None:
        token = data_api_credentials_kwargs["token"]
        admin = client.get_admin(token=token)

        default_regions = await admin.async_find_available_regions()
        enabled_regions = await admin.async_find_available_regions(
            only_org_enabled_regions=True
        )
        every_region = await admin.async_find_available_regions(
            only_org_enabled_regions=False
        )

        # without arguments, the org-enabled listing is expected ...
        assert default_regions == enabled_regions
        # ... and all of its entries must be in the unrestricted listing
        assert len(every_region) >= len(enabled_regions)
        assert all(region in every_region for region in enabled_regions)
41 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_admin_ops_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy import DataAPIClient
20 |
21 | from ..conftest import IS_ASTRA_DB, DataAPICredentials
22 |
23 |
class TestFindRegionsSync:
    """Checks on the admin's ``find_available_regions`` (Astra DB only)."""

    @pytest.mark.skipif(not IS_ASTRA_DB, reason="This test requires Astra DB")
    @pytest.mark.describe("test of find_available_regions, sync")
    def test_findavailableregions_sync(
        self,
        client: DataAPIClient,
        data_api_credentials_kwargs: DataAPICredentials,
    ) -> None:
        token = data_api_credentials_kwargs["token"]
        admin = client.get_admin(token=token)

        default_regions = admin.find_available_regions()
        enabled_regions = admin.find_available_regions(only_org_enabled_regions=True)
        every_region = admin.find_available_regions(only_org_enabled_regions=False)

        # without arguments, the org-enabled listing is expected ...
        assert default_regions == enabled_regions
        # ... and all of its entries must be in the unrestricted listing
        assert len(every_region) >= len(enabled_regions)
        assert all(region in every_region for region in enabled_regions)
41 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_reranking_ops_async.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import os
18 |
19 | import pytest
20 |
21 | from astrapy import AsyncDatabase
22 | from astrapy.info import FindRerankingProvidersResult, RerankingProvider
23 |
24 | from ..conftest import IS_ASTRA_DB
25 |
26 |
27 | def _count_models(frp_result: FindRerankingProvidersResult) -> int:
28 | return len(
29 | [
30 | model
31 | for prov_v in frp_result.reranking_providers.values()
32 | for model in prov_v.models
33 | ]
34 | )
35 |
36 |
class TestRerankingOpsAsync:
    """Tests of the reranking-provider listing of the database admin (async)."""

    @pytest.mark.describe("test of find_reranking_providers, async")
    async def test_findrerankingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        database_admin = async_database.get_database_admin()
        result = await database_admin.async_find_reranking_providers()
        assert isinstance(result, FindRerankingProvidersResult)

        providers = result.reranking_providers
        assert all(
            isinstance(provider, RerankingProvider) for provider in providers.values()
        )

        # 'raw_info' not compared, for resiliency against newly-introduced fields
        round_trip = FindRerankingProvidersResult._from_dict(result.as_dict())
        assert round_trip.reranking_providers == providers

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_reranking_providers filtering, async")
    async def test_filtered_findrerankingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        database_admin = async_database.get_database_admin()
        default_count = _count_models(
            await database_admin.async_find_reranking_providers()
        )
        # model count per status filter; "" is expected to yield every model
        counts = {
            status: _count_models(
                await database_admin.async_find_reranking_providers(
                    filter_model_status=status
                )
            )
            for status in ("", "SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must add up to the unfiltered total, and the
        # default (no filter argument) must match "SUPPORTED"
        assert (
            counts["SUPPORTED"] + counts["DEPRECATED"] + counts["END_OF_LIFE"]
            == counts[""]
        )
        assert counts["SUPPORTED"] == default_count
98 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_reranking_ops_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import os
18 |
19 | import pytest
20 |
21 | from astrapy import Database
22 | from astrapy.info import FindRerankingProvidersResult, RerankingProvider
23 |
24 | from ..conftest import IS_ASTRA_DB
25 |
26 |
27 | def _count_models(frp_result: FindRerankingProvidersResult) -> int:
28 | return len(
29 | [
30 | model
31 | for prov_v in frp_result.reranking_providers.values()
32 | for model in prov_v.models
33 | ]
34 | )
35 |
36 |
class TestRerankingOpsSync:
    """Tests of the reranking-provider listing of the database admin (sync)."""

    @pytest.mark.describe("test of find_reranking_providers, sync")
    def test_findrerankingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        database_admin = sync_database.get_database_admin()
        result = database_admin.find_reranking_providers()
        assert isinstance(result, FindRerankingProvidersResult)

        providers = result.reranking_providers
        assert all(
            isinstance(provider, RerankingProvider) for provider in providers.values()
        )

        # 'raw_info' not compared, for resiliency against newly-introduced fields
        round_trip = FindRerankingProvidersResult._from_dict(result.as_dict())
        assert round_trip.reranking_providers == providers

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_reranking_providers filtering, sync")
    def test_filtered_findrerankingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        database_admin = sync_database.get_database_admin()
        default_count = _count_models(database_admin.find_reranking_providers())
        # model count per status filter; "" is expected to yield every model
        counts = {
            status: _count_models(
                database_admin.find_reranking_providers(filter_model_status=status)
            )
            for status in ("", "SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must add up to the unfiltered total, and the
        # default (no filter argument) must match "SUPPORTED"
        assert (
            counts["SUPPORTED"] + counts["DEPRECATED"] + counts["END_OF_LIFE"]
            == counts[""]
        )
        assert counts["SUPPORTED"] == default_count
90 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_vectorize_ops_async.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import os
18 |
19 | import pytest
20 |
21 | from astrapy import AsyncDatabase
22 | from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult
23 |
24 | from ..conftest import IS_ASTRA_DB, clean_nulls_from_dict
25 |
26 |
27 | def _count_models(fep_result: FindEmbeddingProvidersResult) -> int:
28 | return len(
29 | [
30 | model
31 | for prov_v in fep_result.embedding_providers.values()
32 | for model in prov_v.models
33 | ]
34 | )
35 |
36 |
class TestVectorizeOpsAsync:
    """Tests of the embedding-provider listing of the database admin (async)."""

    @pytest.mark.describe("test of find_embedding_providers, async")
    async def test_findembeddingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        database_admin = async_database.get_database_admin()
        ep_result = await database_admin.async_find_embedding_providers()
        assert isinstance(ep_result, FindEmbeddingProvidersResult)

        providers = ep_result.embedding_providers
        assert all(
            isinstance(provider, EmbeddingProvider) for provider in providers.values()
        )

        # as_dict -> _from_dict must rebuild each provider exactly
        reconstructed = {
            name: EmbeddingProvider._from_dict(provider.as_dict())
            for name, provider in providers.items()
        }
        assert reconstructed == providers

        # as_dict must also match the raw response, both sides passed
        # through clean_nulls_from_dict
        cleaned_dict_mapping = clean_nulls_from_dict(
            {name: provider.as_dict() for name, provider in providers.items()}
        )
        cleaned_raw_info = clean_nulls_from_dict(
            ep_result.raw_info["embeddingProviders"]  # type: ignore[index]
        )
        # TODO remove this cleanup once Astra prod and HCD get the support status flags
        raw_info_has_ams = any(
            "apiModelSupport" in model_dict
            for prov_v in cleaned_raw_info.values()
            for model_dict in prov_v["models"]
        )
        if not raw_info_has_ams:
            for provider_dict in cleaned_dict_mapping.values():
                for model in provider_dict["models"]:
                    model.pop("apiModelSupport", None)
        assert cleaned_dict_mapping == cleaned_raw_info

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_embedding_providers filtering, async")
    async def test_filtered_findembeddingproviders_async(
        self,
        async_database: AsyncDatabase,
    ) -> None:
        database_admin = async_database.get_database_admin()
        default_count = _count_models(
            await database_admin.async_find_embedding_providers()
        )
        # model count per status filter; "" is expected to yield every model
        counts = {
            status: _count_models(
                await database_admin.async_find_embedding_providers(
                    filter_model_status=status
                )
            )
            for status in ("", "SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        }

        # the three statuses must add up to the unfiltered total, and the
        # default (no filter argument) must match "SUPPORTED"
        assert (
            counts["SUPPORTED"] + counts["DEPRECATED"] + counts["END_OF_LIFE"]
            == counts[""]
        )
        assert counts["SUPPORTED"] == default_count
116 |
--------------------------------------------------------------------------------
/tests/base/integration/misc/test_vectorize_ops_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import os
18 |
19 | import pytest
20 |
21 | from astrapy import Database
22 | from astrapy.info import EmbeddingProvider, FindEmbeddingProvidersResult
23 |
24 | from ..conftest import IS_ASTRA_DB, clean_nulls_from_dict
25 |
26 |
def _count_models(fep_result: FindEmbeddingProvidersResult) -> int:
    """Return how many models are listed across all providers of a result.

    Args:
        fep_result: the outcome of a `find_embedding_providers` call.

    Returns:
        the total count of models, summed over every embedding provider.
    """
    providers = fep_result.embedding_providers.values()
    return sum(1 for provider in providers for _ in provider.models)
35 |
36 |
class TestVectorizeOpsSync:
    """Sync tests of the embedding-provider listing of the database admin."""

    @pytest.mark.describe("test of find_embedding_providers, sync")
    def test_findembeddingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Check result types, dict round-trips and agreement with the raw info."""
        admin = sync_database.get_database_admin()
        result = admin.find_embedding_providers()

        assert isinstance(result, FindEmbeddingProvidersResult)

        providers = result.embedding_providers
        assert all(
            isinstance(provider, EmbeddingProvider) for provider in providers.values()
        )

        # Serializing then re-parsing each provider must give an equal object.
        rebuilt = {
            name: EmbeddingProvider._from_dict(provider.as_dict())
            for name, provider in providers.items()
        }
        assert rebuilt == providers

        # The serialized providers must match the raw response (nulls removed).
        serialized = clean_nulls_from_dict(
            {name: provider.as_dict() for name, provider in providers.items()}
        )
        raw_providers = clean_nulls_from_dict(
            result.raw_info["embeddingProviders"]  # type: ignore[index]
        )
        # TODO remove this cleanup once Astra prod and HCD get the support status flags
        api_reports_support = any(
            "apiModelSupport" in raw_model
            for raw_provider in raw_providers.values()
            for raw_model in raw_provider["models"]
        )
        if not api_reports_support:
            # Drop the client-side field the API did not send, for comparison.
            for provider_dict in serialized.values():
                for model_dict in provider_dict["models"]:
                    model_dict.pop("apiModelSupport", None)
        assert serialized == raw_providers

    @pytest.mark.skipif(
        "ASTRAPY_TEST_LATEST_MAIN" not in os.environ,
        reason="No 'latest main' tests required.",
    )
    @pytest.mark.skipif(IS_ASTRA_DB, reason="Filtering models not yet on Astra DB")
    @pytest.mark.describe("test of find_embedding_providers filtering, sync")
    def test_filtered_findembeddingproviders_sync(
        self,
        sync_database: Database,
    ) -> None:
        """Per-status model counts add up to the unfiltered total.

        The default (no filter) count must equal the SUPPORTED count.
        """
        admin = sync_database.get_database_admin()

        def count_models_for(**filter_kwargs: str) -> int:
            return _count_models(admin.find_embedding_providers(**filter_kwargs))

        default_count = count_models_for()
        # An empty status string requests models of every status.
        all_count = count_models_for(filter_model_status="")
        supported, deprecated, end_of_life = (
            count_models_for(filter_model_status=status)
            for status in ("SUPPORTED", "DEPRECATED", "END_OF_LIFE")
        )

        assert supported + deprecated + end_of_life == all_count
        assert supported == default_count
--------------------------------------------------------------------------------
/tests/base/integration/tables/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/base/integration/tables/table_cql_assets.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from astrapy.data_types import DataAPITimestamp
18 |
# Schemas, CQL statements, filters and expected rows for Data API tests run
# against tables that are created directly through CQL.

# --- Table with a COUNTER column ---------------------------------------------
TABLE_NAME_COUNTER = "test_table_counter"
CREATE_TABLE_COUNTER = (
    f"CREATE TABLE {TABLE_NAME_COUNTER} (a TEXT PRIMARY KEY, col_counter COUNTER);"
)
DROP_TABLE_COUNTER = f"DROP TABLE {TABLE_NAME_COUNTER};"
# Rows are created by incrementing the counter with UPDATE statements:
# row a='a' ends at 137, row a='z' at 1.
INSERTS_TABLE_COUNTER = [
    f"UPDATE {TABLE_NAME_COUNTER} SET col_counter=col_counter+137 WHERE a='a';",
    f"UPDATE {TABLE_NAME_COUNTER} SET col_counter=col_counter+1 WHERE a='z';",
]
# Targets the a='a' row only.
FILTER_COUNTER = {"a": "a"}
EXPECTED_ROW_COUNTER = {"a": "a", "col_counter": 137}

# --- Table with static collection columns, a nested collection and a UDT ----
TABLE_NAME_LOWSUPPORT = "test_table_lowsupport"
TYPE_NAME_LOWSUPPORT = "test_type_lowsupport"
CREATE_TYPE_LOWSUPPORT = (
    f"CREATE TYPE {TYPE_NAME_LOWSUPPORT} (genus TEXT, species TEXT, size FLOAT);"
)
# NOTE(review): in this copy the collection columns read e.g. "LIST STATIC"
# with no element type, and col_unsupported reads "FROZEN>, SMALLINT>>>>".
# This looks like the angle-bracketed type parameters ("<...>") were stripped.
# Verify this statement against the repository copy before relying on it.
CREATE_TABLE_LOWSUPPORT = f"""CREATE TABLE {TABLE_NAME_LOWSUPPORT} (
    a TEXT,
    b TEXT,
    col_static_timestamp TIMESTAMP STATIC,
    col_static_list LIST STATIC,
    col_static_list_exotic LIST STATIC,
    col_static_set SET STATIC,
    col_static_set_exotic SET STATIC,
    col_static_map MAP STATIC,
    col_static_map_exotic MAP STATIC,
    col_unsupported FROZEN>, SMALLINT>>>>,
    col_udt {TYPE_NAME_LOWSUPPORT},
    PRIMARY KEY ((A),B)
);"""
# Drop the table before the type it uses.
DROP_TABLE_LOWSUPPORT = f"DROP TABLE {TABLE_NAME_LOWSUPPORT};"
DROP_TYPE_LOWSUPPORT = f"DROP TYPE {TYPE_NAME_LOWSUPPORT};"
# A single row filling every column. The "exotic" collections hold the
# byte 0xff, including a set {0xff, 0xff} that collapses to one element.
INSERTS_TABLE_LOWSUPPORT = [
    (
        f"INSERT INTO {TABLE_NAME_LOWSUPPORT}"
        " ("
        " a,"
        " b,"
        " col_static_timestamp,"
        " col_static_list,"
        " col_static_list_exotic,"
        " col_static_set,"
        " col_static_set_exotic,"
        " col_static_map,"
        " col_static_map_exotic,"
        " col_unsupported,"
        " col_udt"
        ") VALUES ("
        " 'a',"
        " 'b',"
        " '2022-01-01T12:34:56.000',"
        " [1, 2, 3],"
        " [0xff, 0xff],"
        " {1, 2, 3},"
        " {0xff, 0xff},"
        " {1: 'one'},"
        " {0xff: 0xff},"
        " [{{0.1, 0.2}: 3}, {{0.3, 0.4}: 6}],"
        " {genus: 'Eratigena', species: 'atrica', size: 1.8}"
        ");"
    ),
]
FILTER_LOWSUPPORT = {"a": "a", "b": "b"}
# Projections the tests expect the Data API to reject.
# NOTE(review): no column named "col_unsupported_udt" appears in
# CREATE_TABLE_LOWSUPPORT (the UDT column is "col_udt"). Confirm whether this
# is intentional, e.g. an unknown-column check, or a naming slip.
ILLEGAL_PROJECTIONS_LOWSUPPORT = [
    {"col_unsupported": True},
    {"col_unsupported_udt": True},
]
# Exclude the nested-collection and UDT columns so the rest can be read.
PROJECTION_LOWSUPPORT = {"col_unsupported": False, "col_udt": False}
# Expected read-back under PROJECTION_LOWSUPPORT:
# - blobs come back as {"$binary": ...}; "/w==" is base64 for the byte 0xff;
# - maps come back as lists of [key, value] pairs;
# - the offset-less timestamp literal is expected back as "...Z".
EXPECTED_ROW_LOWSUPPORT = {
    "a": "a",
    "b": "b",
    "col_static_list_exotic": [
        {"$binary": "/w=="},
        {"$binary": "/w=="},
    ],
    "col_static_set": [1, 2, 3],
    "col_static_map_exotic": [
        [
            {"$binary": "/w=="},
            {"$binary": "/w=="},
        ]
    ],
    "col_static_timestamp": DataAPITimestamp.from_string("2022-01-01T12:34:56Z"),
    "col_static_map": [[1, "one"]],
    "col_static_list": [1, 2, 3],
    "col_static_set_exotic": [{"$binary": "/w=="}],
}
107 |
--------------------------------------------------------------------------------
/tests/base/integration/tables/test_table_column_types_async.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.api_options import APIOptions, SerdesOptions
20 |
21 | from ..conftest import DefaultAsyncTable, _repaint_NaNs
22 | from .table_row_assets import (
23 | FULL_AR_ROW_CUSTOMTYPED,
24 | FULL_AR_ROW_NONCUSTOMTYPED,
25 | )
26 |
27 |
class TestTableColumnTypesSync:
    """Async write-then-read round trips of a full row on a table.

    NOTE(review): despite its name, this class holds the *async* tests. The
    name duplicates the class in the sync module; consider renaming it.
    """

    @staticmethod
    async def _write_then_read(
        atable: DefaultAsyncTable,
        custom_datatypes: bool,
    ) -> None:
        """Insert the reference row, read it back and compare (NaN-safe).

        `custom_datatypes` selects both the reading serdes option and which
        reference row is used.
        """
        reference_row = (
            FULL_AR_ROW_CUSTOMTYPED if custom_datatypes else FULL_AR_ROW_NONCUSTOMTYPED
        )
        configured_atable = atable.with_options(
            api_options=APIOptions(
                serdes_options=SerdesOptions(
                    custom_datatypes_in_reading=custom_datatypes,
                ),
            ),
        )
        await configured_atable.insert_one(reference_row)
        read_back = await configured_atable.find_one({})
        assert _repaint_NaNs(read_back) == _repaint_NaNs(reference_row)

    @pytest.mark.describe("test of table write-and-read, with custom types, async")
    async def test_table_write_then_read_customtypes_async(
        self,
        async_empty_table_all_returns: DefaultAsyncTable,
    ) -> None:
        await self._write_then_read(async_empty_table_all_returns, True)

    @pytest.mark.describe("test of table write-and-read, with noncustom types, async")
    async def test_table_write_then_read_noncustomtypes_async(
        self,
        async_empty_table_all_returns: DefaultAsyncTable,
    ) -> None:
        await self._write_then_read(async_empty_table_all_returns, False)
60 |
--------------------------------------------------------------------------------
/tests/base/integration/tables/test_table_column_types_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.api_options import APIOptions, SerdesOptions
20 |
21 | from ..conftest import DefaultTable, _repaint_NaNs
22 | from .table_row_assets import (
23 | FULL_AR_ROW_CUSTOMTYPED,
24 | FULL_AR_ROW_NONCUSTOMTYPED,
25 | )
26 |
27 |
class TestTableColumnTypesSync:
    """Sync write-then-read round trips of a full row on a table."""

    @staticmethod
    def _write_then_read(table: DefaultTable, custom_datatypes: bool) -> None:
        """Insert the reference row, read it back and compare (NaN-safe).

        `custom_datatypes` selects both the reading serdes option and which
        reference row is used.
        """
        reference_row = (
            FULL_AR_ROW_CUSTOMTYPED if custom_datatypes else FULL_AR_ROW_NONCUSTOMTYPED
        )
        configured_table = table.with_options(
            api_options=APIOptions(
                serdes_options=SerdesOptions(
                    custom_datatypes_in_reading=custom_datatypes,
                ),
            ),
        )
        configured_table.insert_one(reference_row)
        read_back = configured_table.find_one({})
        assert _repaint_NaNs(read_back) == _repaint_NaNs(reference_row)

    @pytest.mark.describe("test of table write-and-read, with custom types, sync")
    def test_table_write_then_read_customtypes_sync(
        self,
        sync_empty_table_all_returns: DefaultTable,
    ) -> None:
        self._write_then_read(sync_empty_table_all_returns, True)

    @pytest.mark.describe("test of table write-and-read, with noncustom types, sync")
    def test_table_write_then_read_noncustomtypes_sync(
        self,
        sync_empty_table_all_returns: DefaultTable,
    ) -> None:
        self._write_then_read(sync_empty_table_all_returns, False)
60 |
--------------------------------------------------------------------------------
/tests/base/integration/tables/test_table_cqldriven_dml_sync.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import time
18 | from typing import TYPE_CHECKING
19 |
20 | import pytest
21 |
22 | from astrapy import Database
23 | from astrapy.exceptions import DataAPIResponseException
24 |
25 | from ..conftest import CQL_AVAILABLE
26 | from .table_cql_assets import (
27 | CREATE_TABLE_COUNTER,
28 | CREATE_TABLE_LOWSUPPORT,
29 | CREATE_TYPE_LOWSUPPORT,
30 | DROP_TABLE_COUNTER,
31 | DROP_TABLE_LOWSUPPORT,
32 | DROP_TYPE_LOWSUPPORT,
33 | EXPECTED_ROW_COUNTER,
34 | EXPECTED_ROW_LOWSUPPORT,
35 | FILTER_COUNTER,
36 | FILTER_LOWSUPPORT,
37 | ILLEGAL_PROJECTIONS_LOWSUPPORT,
38 | INSERTS_TABLE_COUNTER,
39 | INSERTS_TABLE_LOWSUPPORT,
40 | PROJECTION_LOWSUPPORT,
41 | TABLE_NAME_COUNTER,
42 | TABLE_NAME_LOWSUPPORT,
43 | )
44 |
45 | if TYPE_CHECKING:
46 | from cassandra.cluster import Session
47 |
48 | try:
49 | from cassandra.cluster import Session
50 | except ImportError:
51 | pass
52 |
53 |
@pytest.mark.skipif(not CQL_AVAILABLE, reason="Not CQL session available")
class TestTableCQLDrivenDMLSync:
    """Data API reads and deletes on tables whose schema is created via raw CQL.

    Each test does the following:
    - creates its schema and rows through `cql_session`;
    - sleeps briefly for schema propagation;
    - exercises the table through the Data API;
    - drops what it created in a `finally` clause.
    """

    @pytest.mark.describe(
        "test of reading from a CQL-driven table with a Counter, sync"
    )
    def test_table_cqldriven_counter_sync(
        self,
        cql_session: Session,
        sync_database: Database,
    ) -> None:
        """Check that a counter row can be read and deleted through the Data API."""
        try:
            cql_session.execute(CREATE_TABLE_COUNTER)
            for insert_statement in INSERTS_TABLE_COUNTER:
                cql_session.execute(insert_statement)
            time.sleep(1.5)  # delay for schema propagation

            table = sync_database.get_table(TABLE_NAME_COUNTER)
            table.definition()
            row = table.find_one(filter=FILTER_COUNTER)
            assert row == EXPECTED_ROW_COUNTER
            table.delete_one(filter=FILTER_COUNTER)
            row = table.find_one(filter=FILTER_COUNTER)
            assert row is None
        finally:
            cql_session.execute(DROP_TABLE_COUNTER)

    @pytest.mark.describe(
        "test of reading from a CQL-driven table with limited-support columns, sync"
    )
    def test_table_cqldriven_lowsupport_sync(
        self,
        cql_session: Session,
        sync_database: Database,
    ) -> None:
        """Check a table with limited-support columns.

        Each illegal projection must raise. The remaining columns must be
        readable, and the row deletable.
        """
        try:
            cql_session.execute(CREATE_TYPE_LOWSUPPORT)
            cql_session.execute(CREATE_TABLE_LOWSUPPORT)
            for insert_statement in INSERTS_TABLE_LOWSUPPORT:
                cql_session.execute(insert_statement)
            time.sleep(1.5)  # delay for schema propagation

            table = sync_database.get_table(TABLE_NAME_LOWSUPPORT)
            table.definition()
            for ill_proj in ILLEGAL_PROJECTIONS_LOWSUPPORT:
                # Pass the projection under test. Previously the loop variable
                # was unused, so every iteration issued the same unprojected
                # query and the listed projections were never exercised.
                with pytest.raises(DataAPIResponseException):
                    table.find_one(filter=FILTER_LOWSUPPORT, projection=ill_proj)
            row = table.find_one(
                filter=FILTER_LOWSUPPORT, projection=PROJECTION_LOWSUPPORT
            )
            assert row == EXPECTED_ROW_LOWSUPPORT
            table.delete_one(filter=FILTER_LOWSUPPORT)
            row = table.find_one(
                filter=FILTER_LOWSUPPORT, projection=PROJECTION_LOWSUPPORT
            )
            assert row is None
        finally:
            try:
                cql_session.execute(DROP_TABLE_LOWSUPPORT)
            finally:
                # Still attempt the type cleanup if dropping the table failed,
                # so a partial setup does not leave the type behind.
                cql_session.execute(DROP_TYPE_LOWSUPPORT)
112 |
--------------------------------------------------------------------------------
/tests/base/unit/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/base/unit/test_apioptions.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.utils.api_options import APIOptions, defaultAPIOptions
20 |
21 |
class TestAPIOptions:
    @pytest.mark.describe("test of header inheritance in APIOptions")
    def test_apioptions_headers(self) -> None:
        """Compare one combined override against two chained overrides.

        Both start from the dev defaults. The chained version sets "D"/"A" to a
        value and then to None, and adds redacted names "x" then "y". It must
        equal the single override that sets the final values and {"x", "y"}.
        """
        base = defaultAPIOptions(environment="dev")

        combined = APIOptions(
            database_additional_headers={"d": "y", "D": None},
            admin_additional_headers={"a": "y", "A": None},
            redacted_header_names={"x", "y"},
        )
        one_step = base.with_override(combined)

        earlier = APIOptions(
            database_additional_headers={"D": "y"},
            admin_additional_headers={"A": "y"},
            redacted_header_names={"x"},
        )
        later = APIOptions(
            database_additional_headers={"d": "y", "D": None},
            admin_additional_headers={"a": "y", "A": None},
            redacted_header_names={"y"},
        )
        two_steps = base.with_override(earlier).with_override(later)

        assert one_step == two_steps
48 |
--------------------------------------------------------------------------------
/tests/base/unit/test_collection_decimal_support.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.data.utils.collection_converters import (
20 | postprocess_collection_response,
21 | preprocess_collection_payload,
22 | )
23 | from astrapy.utils.api_commander import APICommander
24 |
25 | from ..collection_decimal_support_assets import (
26 | BASELINE_OBJ,
27 | S_OPTS_NO_DECS,
28 | S_OPTS_OK_DECS,
29 | WDECS_OBJ,
30 | is_decimal_super,
31 | )
32 |
33 |
class TestCollectionDecimalSupportUnit:
    """Encode/decode round trips of collection payloads, with and without Decimal support."""

    @pytest.mark.describe(
        "test of decimals not supported by default in collection codec paths"
    )
    def test_decimalsupport_collection_codecpath_defaultsettings(self) -> None:
        """With S_OPTS_NO_DECS, the baseline round-trips exactly and decimals fail to encode."""
        # Write path for the baseline object.
        preprocessed = preprocess_collection_payload(
            BASELINE_OBJ,
            options=S_OPTS_NO_DECS,
        )
        encoded = APICommander._decimal_unaware_encode_payload(preprocessed)
        assert encoded is not None

        # Read path back from the encoded form.
        parsed = APICommander._decimal_unaware_parse_json_response(encoded)
        decoded = postprocess_collection_response(parsed, options=S_OPTS_NO_DECS)
        # The baseline must come back exactly as it went in.
        assert BASELINE_OBJ == decoded

        # Encoding an object that contains decimals must raise instead.
        with pytest.raises(TypeError):
            APICommander._decimal_unaware_encode_payload(
                preprocess_collection_payload(WDECS_OBJ, options=S_OPTS_NO_DECS)
            )

    @pytest.mark.describe(
        "test of decimals supported in collection codec paths if set to do so"
    )
    def test_decimalsupport_collection_codecpath_decimalsettings(self) -> None:
        """With S_OPTS_OK_DECS, both objects round-trip up to extra Decimal-ness."""
        for source_obj in (BASELINE_OBJ, WDECS_OBJ):
            encoded = APICommander._decimal_aware_encode_payload(
                preprocess_collection_payload(source_obj, options=S_OPTS_OK_DECS)
            )
            assert encoded is not None
            decoded = postprocess_collection_response(
                APICommander._decimal_aware_parse_json_response(encoded),
                options=S_OPTS_OK_DECS,
            )
            # The re-read object must be "more-or-equally Decimal" than the
            # source, but otherwise coincide with it.
            assert is_decimal_super(decoded, source_obj)
107 |
--------------------------------------------------------------------------------
/tests/base/unit/test_collectionvectorserviceoptions.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.info import VectorServiceOptions
20 |
21 |
class TestVectorServiceOptions:
    """Round trips between VectorServiceOptions and its dict representation."""

    @staticmethod
    def _auth() -> dict[str, object]:
        """Return a fresh sample "authentication" mapping."""
        return {
            "type": ["A_T"],
            "field": "value",
        }

    @staticmethod
    def _params() -> dict[str, object]:
        """Return a fresh sample "parameters" mapping with assorted JSON value types."""
        return {
            "field1": "value1",
            "field2": 12.3,
            "field3": 123,
            "field4": True,
            "field5": None,
            "field6": {"a": 1},
        }

    @pytest.mark.describe("test of VectorServiceOptions conversions, base")
    def test_VectorServiceOptions_conversions_base(self) -> None:
        source = {"provider": "PRO", "modelName": "MOD"}
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(provider="PRO", model_name="MOD")
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe("test of VectorServiceOptions conversions, with auth")
    def test_VectorServiceOptions_conversions_auth(self) -> None:
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "authentication": self._auth(),
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            authentication=self._auth(),
        )
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe("test of VectorServiceOptions conversions, with params")
    def test_VectorServiceOptions_conversions_params(self) -> None:
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "parameters": self._params(),
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            parameters=self._params(),
        )
        assert parsed == expected
        assert parsed.as_dict() == source

    @pytest.mark.describe(
        "test of VectorServiceOptions conversions, with params and auth"
    )
    def test_VectorServiceOptions_conversions_params_auth(self) -> None:
        source = {
            "provider": "PRO",
            "modelName": "MOD",
            "authentication": self._auth(),
            "parameters": self._params(),
        }
        parsed = VectorServiceOptions._from_dict(source)
        expected = VectorServiceOptions(
            provider="PRO",
            model_name="MOD",
            authentication=self._auth(),
            parameters=self._params(),
        )
        assert parsed == expected
        assert parsed.as_dict() == source
124 |
--------------------------------------------------------------------------------
/tests/base/unit/test_dataapimap.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pickle
18 |
19 | import pytest
20 |
21 | from astrapy.data_types import DataAPIMap
22 |
23 |
class TestDataAPIMap:
    """Equality and mapping-protocol behavior of DataAPIMap."""

    @pytest.mark.describe("test of table map usage with hashables")
    def test_dataapimap_hashables(self) -> None:
        empty: DataAPIMap[int, int] = DataAPIMap()
        assert empty == DataAPIMap()
        assert dict(empty) == {}
        assert empty == {}

        pairs = [(1, "a"), (2, "b"), (3, "c")]
        the_map = DataAPIMap(pairs)
        # Same content: exact copy, repeated pair, rotated order.
        same_content = [pairs, pairs + [(2, "b")], pairs[2:] + pairs[:2]]
        # Different content: changed value for key 1, or unknown key 9.
        other_content = [pairs[1:] + [(1, "z")], pairs[1:] + [(9, "z")]]

        # identity/equality against other DataAPIMaps ...
        for candidate in same_content:
            assert the_map == DataAPIMap(candidate)
        for candidate in other_content:
            assert the_map != DataAPIMap(candidate)
        # ... and against plain dicts, on either side of ==
        assert dict(pairs) == the_map
        for candidate in same_content:
            assert the_map == dict(candidate)
        for candidate in other_content:
            assert the_map != dict(candidate)

        # mapping protocol
        assert list(the_map.keys()) == [1, 2, 3]
        assert list(the_map.values()) == ["a", "b", "c"]
        assert the_map[2] == "b"
        with pytest.raises(KeyError):
            the_map[4]
        assert list(the_map.items()) == pairs
        assert len(the_map) == len(pairs)
        assert list(the_map) == [1, 2, 3]

    @pytest.mark.describe("test of table map usage with non-hashables")
    def test_dataapimap_nonhashables(self) -> None:
        empty: DataAPIMap[list[int], int] = DataAPIMap()
        assert empty == DataAPIMap()
        assert dict(empty) == {}

        pairs = [([1], "a"), ([2], "b"), ([3], "c")]
        the_map = DataAPIMap(pairs)
        same_content = [pairs, pairs + [([2], "b")], pairs[2:] + pairs[:2]]
        other_content = [pairs[1:] + [([1], "z")], pairs[1:] + [([9], "z")]]

        # identity/equality
        for candidate in same_content:
            assert the_map == DataAPIMap(candidate)
        for candidate in other_content:
            assert the_map != DataAPIMap(candidate)

        # mapping protocol, with list keys
        assert list(the_map.keys()) == [[1], [2], [3]]
        assert list(the_map.values()) == ["a", "b", "c"]
        assert the_map[[2]] == "b"
        with pytest.raises(KeyError):
            the_map[[4]]
        assert list(the_map.items()) == pairs
        assert len(the_map) == len(pairs)
        assert list(the_map) == [[1], [2], [3]]

    @pytest.mark.describe("test pickling of DataAPIMap")
    def test_dataapimap_pickle(self) -> None:
        original = DataAPIMap([("key1", 1), (None, 2), ("key3", {"a": 1})])
        restored = pickle.loads(pickle.dumps(original))
        assert restored == original
82 |
--------------------------------------------------------------------------------
/tests/base/unit/test_dataapiset.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pickle
18 |
19 | import pytest
20 |
21 | from astrapy.data_types import DataAPISet
22 |
23 |
class TestDataAPISet:
    """Equality and set-algebra behavior of DataAPISet."""

    @pytest.mark.describe("test of table set usage with hashables")
    def test_dataapiset_hashables(self) -> None:
        empty: DataAPISet[int] = DataAPISet()
        assert empty == DataAPISet()
        assert set(empty) == set()

        numbers = DataAPISet([1, 2, 3])
        # Duplicates collapse, but a different element order compares unequal.
        assert numbers == DataAPISet([1, 2, 3])
        assert numbers == DataAPISet([1, 2, 3, 2])
        assert numbers != DataAPISet([1, 3, 2])
        assert set(numbers) == {1, 2, 3}

        # set operations, both with builtin sets and with DataAPISets
        others = [4, 2, 5]
        assert numbers - {2} == DataAPISet([1, 3])
        assert numbers | {2} == DataAPISet([1, 2, 3])
        assert numbers | {4} == DataAPISet([1, 2, 3, 4])
        assert numbers | DataAPISet(others) == DataAPISet([1, 2, 3, 4, 5])
        assert numbers ^ DataAPISet(others) == DataAPISet([1, 3, 4, 5])
        assert numbers & DataAPISet(others) == DataAPISet([2])

    @pytest.mark.describe("test of table set usage with non-hashables")
    def test_dataapiset_nonhashables(self) -> None:
        empty: DataAPISet[list[int]] = DataAPISet()
        assert empty == DataAPISet()

        singletons = DataAPISet([[1], [2], [3]])
        assert singletons == DataAPISet([[1], [2], [3]])
        assert singletons == DataAPISet([[1], [2], [3], [2]])
        assert singletons != DataAPISet([[1], [3], [2]])

        # set operations with list elements
        others = [[4], [2], [5]]
        assert singletons - DataAPISet([[2]]) == DataAPISet([[1], [3]])
        assert singletons | DataAPISet([[2]]) == DataAPISet([[1], [2], [3]])
        assert singletons | DataAPISet([[4]]) == DataAPISet([[1], [2], [3], [4]])
        assert singletons | DataAPISet(others) == DataAPISet(
            [[1], [2], [3], [4], [5]]
        )
        assert singletons ^ DataAPISet(others) == DataAPISet([[1], [3], [4], [5]])
        assert singletons & DataAPISet(others) == DataAPISet([[2]])

    @pytest.mark.describe("test pickling of DataAPISet")
    def test_dataapiset_pickle(self) -> None:
        original = DataAPISet([1, 2, 3])
        restored = pickle.loads(pickle.dumps(original))
        assert restored == original
67 |
--------------------------------------------------------------------------------
/tests/base/unit/test_dataapitime.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import datetime
18 | import pickle
19 |
20 | import pytest
21 |
22 | from astrapy.data_types import DataAPITime
23 |
24 |
class TestDataAPITime:
    """Unit tests for DataAPITime: parsing, conversions, ordering, pickling."""

    @pytest.mark.describe("test of time type, errors in parsing from string")
    def test_dataapitime_parse_errors(self) -> None:
        rejected_strings = [
            # empty, faulty
            "",
            "boom",
            # misformatted
            "12:34:56:21",
            "+12:34:56",
            "12:+34:56",
            "12:34:56.",
            "12:34:56X",
            "X12:34:56",
            "12:34X:56",
            # out-of-range hour / minute / second
            "24:00:00",
            "00:60:00",
            "00:00:60",
        ]
        for time_string in rejected_strings:
            with pytest.raises(ValueError):
                DataAPITime.from_string(time_string)

        # these must all parse without raising: seconds may be omitted and
        # the fractional part may carry up to nine digits
        accepted_strings = [
            "12:23",
            "00:00:00",
            "23:00:00",
            "00:59:00",
            "00:00:00.123456789",
            "00:00:00.123",
            "01:02:03.123456789",
            "01:02:03.123",
        ]
        for time_string in accepted_strings:
            DataAPITime.from_string(time_string)

    @pytest.mark.describe("test of time type, lifecycle")
    def test_dataapitime_lifecycle(self) -> None:
        # whole seconds; parsing tolerates missing or extra leading zeros
        whole = DataAPITime.from_string("02:03:04")
        assert whole == DataAPITime(hour=2, minute=3, second=4)
        assert whole == DataAPITime.from_string("2:3:4")
        assert whole == DataAPITime.from_string("0002:003:000004")
        py_whole = datetime.time(2, 3, 4)
        assert DataAPITime.from_time(py_whole) == whole
        assert whole.to_time() == py_whole
        repr(whole)
        str(whole)
        whole.to_string()

        # fractional seconds: ".9876" is 987600000 ns, i.e. 987600 us
        fractional = DataAPITime.from_string("02:03:04.9876")
        assert fractional == DataAPITime(
            hour=2, minute=3, second=4, nanosecond=987600000
        )
        assert fractional == DataAPITime.from_string("2:3:4.9876")
        assert fractional == DataAPITime.from_string("0002:003:000004.987600000")
        py_fractional = datetime.time(2, 3, 4, 987600)
        assert DataAPITime.from_time(py_fractional) == fractional
        assert fractional.to_time() == py_fractional
        repr(fractional)
        str(fractional)
        fractional.to_string()

        # to_string -> from_string round trip for assorted nanosecond counts
        for nanos in (123, 12, 12345, 123456, 12345678, 123456789):
            sample = DataAPITime(1, 2, 3, nanos)
            assert DataAPITime.from_string(sample.to_string()) == sample

        # ordering, also mixing DataAPITime with datetime.time operands
        earlier = DataAPITime(1, 2, 3, 30000)
        later = DataAPITime(1, 2, 3, 45000)
        py_earlier = earlier.to_time()
        py_later = later.to_time()
        assert earlier < later
        assert earlier <= later
        assert later > earlier
        assert later >= earlier
        assert py_earlier < later
        assert py_earlier <= later
        assert py_later > earlier
        assert py_later >= earlier
        assert earlier < py_later
        assert earlier <= py_later
        assert later > py_earlier
        assert later >= py_earlier

    @pytest.mark.describe("test pickling of DataAPITime")
    def test_dataapitime_pickle(self) -> None:
        original = DataAPITime(12, 34, 56, 789)
        round_tripped = pickle.loads(pickle.dumps(original))
        assert round_tripped == original
125 |
--------------------------------------------------------------------------------
/tests/base/unit/test_dataapivector.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.data.utils.extended_json_converters import (
20 | convert_ejson_binary_object_to_bytes,
21 | convert_to_ejson_bytes,
22 | )
23 | from astrapy.data_types import DataAPIVector
24 | from astrapy.data_types.data_api_vector import bytes_to_floats, floats_to_bytes
25 |
# Absolute tolerance used when comparing float lists after a binary round
# trip; the encoding is presumably lower precision than Python floats
# (a 16-char $binary payload decodes to three values) - exact equality is
# therefore not expected.
COMPARE_EPSILON = 0.00001

# Fixture vectors: small ints/floats, an empty list, and a long list of
# alternating-sign values (even indices negative, odd indices positive).
SAMPLE_FLOAT_LISTS: list[list[float]] = [
    [10, 100, 1000, 0.1],
    [0.0043, 0.123, 1.332],
    [],
    [103.1, 104.5, 105.6],
    [(i if i % 2 else -i) / 5000 for i in range(4096)],
]
35 |
36 |
def _nearly_equal_lists(list1: list[float], list2: list[float]) -> bool:
    """Tell whether two float lists match element-wise within COMPARE_EPSILON.

    Lists of different lengths never match; two empty lists always do.
    """
    size = len(list1)
    if size != len(list2):
        return False
    if size == 0:
        return True
    worst_gap = max(abs(a - b) for a, b in zip(list1, list2))
    return worst_gap < COMPARE_EPSILON
43 |
44 |
def _nearly_equal_vectors(vec1: DataAPIVector, vec2: DataAPIVector) -> bool:
    """Compare two DataAPIVectors component-wise via their ``.data`` lists."""
    data1, data2 = vec1.data, vec2.data
    return _nearly_equal_lists(data1, data2)
47 |
48 |
class TestDataAPIVector:
    """Unit tests for DataAPIVector binary conversions and list-like behaviour."""

    @pytest.mark.describe("test of float-binary conversions")
    def test_dataapivector_byteconversions(self) -> None:
        # floats -> bytes -> floats must round-trip within COMPARE_EPSILON
        for test_list0 in SAMPLE_FLOAT_LISTS:
            test_bytes0 = floats_to_bytes(test_list0)
            test_list1 = bytes_to_floats(test_bytes0)
            assert _nearly_equal_lists(test_list0, test_list1)

    @pytest.mark.describe("test of float-string conversions")
    def test_dataapivector_stringconversions(self) -> None:
        # known expectation: this $binary payload decodes to ~[0.1, 0.2, 0.3]
        obj_1 = {"$binary": "PczMzT5MzM0+mZma"}
        vec = DataAPIVector.from_bytes(convert_ejson_binary_object_to_bytes(obj_1))
        vec_exp = DataAPIVector([0.1, 0.2, 0.3])
        assert _nearly_equal_vectors(vec, vec_exp)
        # some full-round conversions: vector -> bytes -> EJSON -> bytes -> vector
        for test_list0 in SAMPLE_FLOAT_LISTS:
            test_vec0 = DataAPIVector(test_list0)
            test_ejson = convert_to_ejson_bytes(test_vec0.to_bytes())
            test_vec1 = DataAPIVector.from_bytes(
                convert_ejson_binary_object_to_bytes(test_ejson)
            )
            assert _nearly_equal_vectors(test_vec0, test_vec1)

    @pytest.mark.describe("test of DataAPIVector lifecycle")
    def test_dataapivector_lifecycle(self) -> None:
        v0 = DataAPIVector([])
        v1 = DataAPIVector([1.1, 2.2, 3.3])
        # iterating must work on empty and non-empty vectors alike; the loop
        # variable is deliberately unused (was `i`, flagged by B007)
        for _ in v0:
            pass
        for _ in v1:
            pass
        assert list(v0) == []
        assert list(v1) == [1.1, 2.2, 3.3]
        assert v0 != v1
        assert v1 == DataAPIVector([1.1, 2.2, 3.3])

        # list-likeness: slicing yields a DataAPIVector, len() works
        assert v1[1:2] == DataAPIVector([2.2])
        assert len(v1) == 3
89 |
--------------------------------------------------------------------------------
/tests/base/unit/test_document_paths.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.utils.document_paths import (
20 | ILLEGAL_ESCAPE_ERROR_MESSAGE_TEMPLATE,
21 | UNTERMINATED_ESCAPE_ERROR_MESSAGE_TEMPLATE,
22 | escape_field_names,
23 | unescape_field_path,
24 | )
25 |
# Literal field names mapped to their expected escaped form: "&" becomes
# "&&" and "." becomes "&."; all other characters (including emoji and
# non-Latin letters) pass through unchanged.
ESCAPE_TEST_FIELDS = {
    "a": "a",
    "a.b": "a&.b",
    "a&c": "a&&c",
    "a..": "a&.&.",
    "": "",
    "xyz": "xyz",
    ".": "&.",
    "...": "&.&.&.",
    "&": "&&",
    "&&": "&&&&",
    "&.&.&..": "&&&.&&&.&&&.&.",
    "🇦🇨": "🇦🇨",
    "q🇦🇨w": "q🇦🇨w",
    ".🇦🇨&": "&.🇦🇨&&",
    "a&b.c.&.d": "a&&b&.c&.&&&.d",
    ".env": "&.env",
    # HTML-entity-like text: only its leading "&" needs escaping. (This line
    # had been mangled into the invalid `""": "&"",` by HTML unescaping.)
    "&quot;": "&&quot;",
    "ݤ": "ݤ",
    "Aݤ": "Aݤ",
    "ݤZ": "ݤZ",
    "🚾 🆒 🆓 🆕 🆖 🆗 🆙 🏧": "🚾 🆒 🆓 🆕 🆖 🆗 🆙 🏧",
    "✋🏿 💪🏿 👐🏿 🙌🏿 👏🏿 🙏🏿": "✋🏿 💪🏿 👐🏿 🙌🏿 👏🏿 🙏🏿",
    "👨👩👦 👨👩👧👦 👨👨👦 👩👩👧 👨👦 👨👧👦 👩👦 👩👧👦": "👨👩👦 👨👩👧👦 👨👨👦 👩👩👧 👨👦 👨👧👦 👩👦 👩👧👦",
}
51 |
52 |
class TestDocumentPaths:
    """Unit tests for escaping field names and unescaping field paths."""

    @pytest.mark.describe("test of escape_field_names, one arg")
    def test_escape_field_names_onearg(self) -> None:
        # a single name escapes the same whether passed bare or in a list
        for literal, escaped in ESCAPE_TEST_FIELDS.items():
            assert escape_field_names(literal) == escaped
            assert escape_field_names([literal]) == escaped
        # integer segments are rendered with str()
        for number in (0, 12, 130099):
            assert escape_field_names(number) == str(number)
            assert escape_field_names([number]) == str(number)

    @pytest.mark.describe("test of escape_field_names, multiple args")
    def test_escape_field_names_multiargs(self) -> None:
        literals, escapes = zip(*ESCAPE_TEST_FIELDS.items())
        # whole, empty, leading and middle windows of the fixture; each one
        # is passed both as one sequence argument and as separate arguments
        for window in (slice(None), slice(0), slice(3), slice(3, 6)):
            expected = ".".join(escapes[window])
            assert escape_field_names(literals[window]) == expected
            assert escape_field_names(*literals[window]) == expected

        # mixed str/int segments
        assert escape_field_names(["first", 12, "last&!."]) == "first.12.last&&!&."
        assert escape_field_names("first", 12, "last&!.") == "first.12.last&&!&."

    @pytest.mark.describe("test of unescape_field_path")
    def test_unescape_field_path(self) -> None:
        assert unescape_field_path("a&.b") == ["a.b"]

        literals, escapes = zip(*ESCAPE_TEST_FIELDS.items())
        for window in (slice(None), slice(3), slice(3, 6)):
            joined_path = ".".join(escapes[window])
            assert unescape_field_path(joined_path) == list(literals[window])

        assert unescape_field_path("") == []

        # "&?" is not a valid escape sequence
        with pytest.raises(
            ValueError, match=ILLEGAL_ESCAPE_ERROR_MESSAGE_TEMPLATE[:24]
        ):
            unescape_field_path("a.b&?c.d")

        # a trailing "&" leaves the escape unterminated
        with pytest.raises(
            ValueError, match=UNTERMINATED_ESCAPE_ERROR_MESSAGE_TEMPLATE[:24]
        ):
            unescape_field_path("a.b.c&")
98 |
--------------------------------------------------------------------------------
/tests/base/unit/test_embeddingheadersprovider.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.authentication import (
20 | AWSEmbeddingHeadersProvider,
21 | EmbeddingAPIKeyHeaderProvider,
22 | coerce_embedding_headers_provider,
23 | )
24 | from astrapy.settings.defaults import (
25 | EMBEDDING_HEADER_API_KEY,
26 | EMBEDDING_HEADER_AWS_ACCESS_ID,
27 | EMBEDDING_HEADER_AWS_SECRET_ID,
28 | )
29 |
30 |
class TestEmbeddingHeadersProvider:
    """Unit tests for the embedding-service header providers and coercion."""

    @pytest.mark.describe("test of headers from EmbeddingAPIKeyHeaderProvider")
    def test_embeddingheadersprovider_static(self) -> None:
        provider = EmbeddingAPIKeyHeaderProvider("x")
        # header names are compared case-insensitively
        lowered = {
            name.lower(): value for name, value in provider.get_headers().items()
        }
        assert lowered == {EMBEDDING_HEADER_API_KEY.lower(): "x"}

    @pytest.mark.describe("test of headers from empty EmbeddingAPIKeyHeaderProvider")
    def test_embeddingheadersprovider_null(self) -> None:
        # a provider built without a key contributes no headers at all
        assert EmbeddingAPIKeyHeaderProvider(None).get_headers() == {}

    @pytest.mark.describe("test of headers from AWSEmbeddingHeadersProvider")
    def test_embeddingheadersprovider_aws(self) -> None:
        provider = AWSEmbeddingHeadersProvider(
            embedding_access_id="x",
            embedding_secret_id="y",
        )
        expected = {
            EMBEDDING_HEADER_AWS_ACCESS_ID.lower(): "x",
            EMBEDDING_HEADER_AWS_SECRET_ID.lower(): "y",
        }
        lowered = {
            name.lower(): value for name, value in provider.get_headers().items()
        }
        assert lowered == expected

    @pytest.mark.describe("test of embedding headers provider coercion")
    def test_embeddingheadersprovider_coercion(self) -> None:
        """Check coercion of providers and strings; also exercises equality."""
        keyed = EmbeddingAPIKeyHeaderProvider("x")
        keyless = EmbeddingAPIKeyHeaderProvider(None)
        aws = AWSEmbeddingHeadersProvider(
            embedding_access_id="x",
            embedding_secret_id="y",
        )
        # provider instances come back out of coercion equal to themselves
        for provider in (keyed, keyless, aws):
            assert coerce_embedding_headers_provider(provider) == provider
        # a bare string is coerced to the equivalent API-key provider
        assert coerce_embedding_headers_provider("x") == keyed
71 |
--------------------------------------------------------------------------------
/tests/base/unit/test_findrerankingproviderresult.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.info import FindRerankingProvidersResult
20 |
def _make_response_dict(
    with_api_model_support: bool,
) -> dict[str, dict[str, dict[str, object]]]:
    """Build a fresh single-provider findRerankingProviders response dict.

    The provider's only model carries an ``apiModelSupport`` entry only when
    ``with_api_model_support`` is true; everything else is identical.
    """
    model: dict[str, object] = {"name": "provider/"}
    if with_api_model_support:
        model["apiModelSupport"] = {"status": "SUPPORTED"}
    model["isDefault"] = True
    model["url"] = "https:///ranking"
    model["properties"] = None
    return {
        "rerankingProviders": {
            "provider": {
                "isDefault": True,
                "displayName": "TheProvider",
                "supportedAuthentication": {
                    "NONE": {"tokens": [], "enabled": True},
                },
                "models": [model],
            },
        },
    }


# bare response, without apiModelSupport on the model
RESPONSE_DICT_0 = _make_response_dict(with_api_model_support=False)

# richer response, with apiModelSupport on the model
RESPONSE_DICT_1 = _make_response_dict(with_api_model_support=True)
69 |
70 |
class TestFindRerankingProvidersResult:
    """Parse-and-dump round trips for FindRerankingProvidersResult."""

    @staticmethod
    def _check_single_default_model(parsed: FindRerankingProvidersResult) -> None:
        """Assert the parsed result has one named provider with one default model."""
        providers = parsed.reranking_providers
        assert len(providers) == 1
        the_provider = providers["provider"]
        assert the_provider.display_name is not None
        the_models = the_provider.models
        assert len(the_models) == 1
        assert the_models[0].is_default

    @pytest.mark.describe("test of FindRerankingProvidersResult parsing and back")
    def test_reranking_providers_result_parsing_and_back_base(self) -> None:
        parsed = FindRerankingProvidersResult._from_dict(RESPONSE_DICT_0)
        self._check_single_default_model(parsed)

        dumped = parsed.as_dict()
        # the dump carries apiModelSupport on every model (the del would fail
        # otherwise) while RESPONSE_DICT_0 has none: strip it before comparing
        for provider_dict in dumped["rerankingProviders"].values():
            for model_dict in provider_dict["models"]:
                del model_dict["apiModelSupport"]
        assert dumped == RESPONSE_DICT_0

    @pytest.mark.describe(
        "test of FindRerankingProvidersResult parsing and back (with apiModelSupport)"
    )
    def test_reranking_providers_result_parsing_and_back_rich(self) -> None:
        parsed = FindRerankingProvidersResult._from_dict(RESPONSE_DICT_1)
        self._check_single_default_model(parsed)
        # with apiModelSupport in the input, the dump must match it exactly
        assert parsed.as_dict() == RESPONSE_DICT_1
105 |
--------------------------------------------------------------------------------
/tests/base/unit/test_ids.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Unit tests for the ObjectIds and UUIDn conversions, 'idiomatic' imports
17 | """
18 |
19 | from __future__ import annotations
20 |
21 | import json
22 |
23 | import pytest
24 |
25 | from astrapy.data.utils.collection_converters import (
26 | postprocess_collection_response,
27 | preprocess_collection_payload,
28 | )
29 | from astrapy.ids import UUID, ObjectId
30 | from astrapy.utils.api_options import FullSerdesOptions
31 |
32 |
@pytest.mark.describe("test of serdes for ids")
def test_ids_serdes() -> None:
    """Round-trip UUIDs (versions 1, 3-8) and an ObjectId through collection serdes."""
    full_structure = {
        "f_u1": UUID("8ccd6ff8-e61b-11ee-a2fc-7df4a8c4164b"),  # uuid1
        "f_u3": UUID("6fa459ea-ee8a-3ca4-894e-db77e160355e"),  # uuid3
        "f_u4": UUID("4f16cba8-1115-43ab-aa39-3a9c29f37db5"),  # uuid4
        "f_u5": UUID("886313e1-3b8a-5372-9b90-0c9aee199e5d"),  # uuid5
        "f_u6": UUID("1eee61b9-8f2d-69ad-8ebb-5054d2a1a2c0"),  # uuid6
        "f_u7": UUID("018e57e5-f586-7ed6-be55-6b0de3041116"),  # uuid7
        "f_u8": UUID("018e57e5-fbcd-8bd4-b794-be914f2c4c85"),  # uuid8
        "f_oi": ObjectId("65f9cfa0d7fabb3f255c25a1"),
    }

    # the same serdes options are used for both directions of the trip
    serdes_options = FullSerdesOptions(
        binary_encode_vectors=True,
        custom_datatypes_in_reading=True,
        unroll_iterables_to_lists=True,
        use_decimals_in_collections=False,
        encode_maps_as_lists_in_tables="NEVER",
        accept_naive_datetimes=False,
        datetime_tzinfo=None,
    )

    normalized = preprocess_collection_payload(full_structure, options=serdes_options)
    # the outgoing payload must be plain-JSON serializable
    json.dumps(normalized)
    assert normalized is not None

    restored = postprocess_collection_response(normalized, options=serdes_options)
    assert restored == full_structure
82 |
--------------------------------------------------------------------------------
/tests/base/unit/test_info.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Unit tests for the parsing of API endpoints and related
17 | """
18 |
19 | from __future__ import annotations
20 |
21 | import pytest
22 |
23 | from astrapy.admin import ParsedAPIEndpoint, parse_api_endpoint
24 | from astrapy.info import AstraDBAvailableRegionInfo
25 |
26 |
@pytest.mark.describe("test of parsing API endpoints")
def test_parse_api_endpoint() -> None:
    """Check parse_api_endpoint on well-formed and malformed Astra DB URLs.

    Well-formed endpoints must parse into the expected ParsedAPIEndpoint, one
    per environment (prod, dev, test); malformed endpoints must yield None.
    """
    # (endpoint, expected parse result); the old version reused the name
    # `parsed_prod` for the dev and test cases too, which was misleading
    well_formed = [
        (
            "https://01234567-89ab-cdef-0123-456789abcdef-eu-west-1.apps.astra.datastax.com",
            ParsedAPIEndpoint(
                database_id="01234567-89ab-cdef-0123-456789abcdef",
                region="eu-west-1",
                environment="prod",
            ),
        ),
        (
            "https://a1234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com",
            ParsedAPIEndpoint(
                database_id="a1234567-89ab-cdef-0123-456789abcdef",
                region="us-central1",
                environment="dev",
            ),
        ),
        (
            # path and query string after the host are tolerated
            "https://b1234567-89ab-cdef-0123-456789abcdef-eu-southwest-4.apps.astra-test.datastax.com/subpath?a=1",
            ParsedAPIEndpoint(
                database_id="b1234567-89ab-cdef-0123-456789abcdef",
                region="eu-southwest-4",
                environment="test",
            ),
        ),
    ]
    for endpoint, expected in well_formed:
        parsed = parse_api_endpoint(endpoint)
        assert isinstance(parsed, ParsedAPIEndpoint)
        assert parsed == expected

    malformed_endpoints = [
        # http instead of https
        "http://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com",
        # 'q' is not a hex digit in the database id
        "https://a909bdbf-q9ba-4e5e-893c-a859ed701407-us-central1.apps.astra-dev.datastax.com",
        # underscore in the region name
        "https://01234567-89ab-cdef-0123-456789abcdef-us-c_entral1.apps.astra-dev.datastax.com",
        # unknown environment suffix
        "https://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-fake.datastax.com",
        # wrong domain
        "https://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax-staging.com",
        # garbage before the scheme
        "ahttps://01234567-89ab-cdef-0123-456789abcdef-us-central1.apps.astra-dev.datastax.com",
    ]
    for m_ep in malformed_endpoints:
        assert parse_api_endpoint(m_ep) is None
70 |
71 |
@pytest.mark.describe("test of marshaling of available region info")
def test_parse_availableregioninfo() -> None:
    """Parsing an available-region dict and dumping it back must be lossless."""
    region_dict = {
        "classification": "standard",
        "cloudProvider": "AWS",
        "displayName": "US East (Ohio)",
        "enabled": True,
        "name": "us-east-2",
        "region_type": "vector",
        "reservedForQualifiedUsers": False,
        "zone": "na",
    }
    region_info = AstraDBAvailableRegionInfo._from_dict(region_dict)
    assert region_info.as_dict() == region_dict
85 |
--------------------------------------------------------------------------------
/tests/base/unit/test_multicalltimeoutmanager.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import time
18 |
19 | import pytest
20 |
21 | from astrapy.exceptions import (
22 | DataAPITimeoutException,
23 | DevOpsAPITimeoutException,
24 | MultiCallTimeoutManager,
25 | )
26 |
27 |
class TestTimeouts:
    """Timing-based checks of MultiCallTimeoutManager (each test sleeps ~1.7 s)."""

    @pytest.mark.describe("test MultiCallTimeoutManager")
    def test_multicalltimeoutmanager(self) -> None:
        # without an overall timeout there is never a per-request limit
        unlimited = MultiCallTimeoutManager(overall_timeout_ms=None)
        assert unlimited.remaining_timeout().request_ms is None
        time.sleep(0.5)
        assert unlimited.remaining_timeout().request_ms is None

        # a 1000 ms budget yields finite per-request timeouts until it has
        # been spent (2 x 0.6 s), after which asking again raises
        budgeted = MultiCallTimeoutManager(overall_timeout_ms=1000)
        for _ in range(2):
            assert budgeted.remaining_timeout().request_ms is not None
            time.sleep(0.6)
        with pytest.raises(DataAPITimeoutException):
            budgeted.remaining_timeout().request_ms

    @pytest.mark.describe("test MultiCallTimeoutManager DevOps")
    def test_multicalltimeoutmanager_devops(self) -> None:
        # same scenario, but DevOps-flavoured managers raise the DevOps error
        unlimited = MultiCallTimeoutManager(overall_timeout_ms=None, dev_ops_api=True)
        assert unlimited.remaining_timeout().request_ms is None
        time.sleep(0.5)
        assert unlimited.remaining_timeout().request_ms is None

        budgeted = MultiCallTimeoutManager(overall_timeout_ms=1000, dev_ops_api=True)
        for _ in range(2):
            assert budgeted.remaining_timeout().request_ms is not None
            time.sleep(0.6)
        with pytest.raises(DevOpsAPITimeoutException):
            budgeted.remaining_timeout().request_ms
62 |
--------------------------------------------------------------------------------
/tests/base/unit/test_regionname_deprecation.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 | from deprecation import DeprecatedWarning
19 |
20 | from astrapy.data.info.database_info import (
21 | AstraDBAdminDatabaseRegionInfo,
22 | AstraDBAvailableRegionInfo,
23 | )
24 |
25 | from ..conftest import is_future_version
26 |
27 |
class TestRegionNameDeprecation:
    """The deprecated ``region_name`` alias must warn and match ``name``."""

    @staticmethod
    def _check_one_future_deprecation(recorder: pytest.WarningsRecorder) -> None:
        """Assert exactly one DeprecatedWarning was recorded, with a future removal."""
        assert len(recorder.list) == 1
        the_warning = recorder.list[0].message
        assert isinstance(the_warning, DeprecatedWarning)
        assert is_future_version(the_warning.removed_in)

    @pytest.mark.describe("test of region_name in AstraDBAdminDatabaseRegionInfo")
    def test_regionname_dbregioninfo(self) -> None:
        dc0 = {
            "capacityUnits": 1,
            "cloudProvider": "AWS",
            "dateCreated": "2025-04-10T19:14:44Z",
            "id": "...",
            "isPrimary": True,
            "name": "dc-1",
            "region": "us-east-2",
            "regionClassification": "standard",
            "regionZone": "na",
            "requestedNodeCount": 3,
            "secureBundleInternalUrl": "...",
            "secureBundleMigrationProxyInternalUrl": "...",
            "secureBundleMigrationProxyUrl": "...",
            "secureBundleUrl": "...",
            "status": "",
            "streamingTenant": {},
            "targetAccount": "abcd0123",
            "tier": "serverless",
        }
        region_info = AstraDBAdminDatabaseRegionInfo(
            raw_datacenter_dict=dc0,
            environment="dev",
            database_id="D",
        )

        current_name = region_info.name
        with pytest.warns(DeprecationWarning) as recorder:
            legacy_name = region_info.region_name
        self._check_one_future_deprecation(recorder)

        assert current_name == legacy_name

    @pytest.mark.describe("test of region_name in AstraDBAvailableRegionInfo")
    def test_regionname_availableregion(self) -> None:
        ar0 = {
            "classification": "premium",
            "cloudProvider": "GCP",
            "displayName": "West Europe3 (Frankfurt, Germany)",
            "enabled": True,
            "name": "europe-west3",
            "region_type": "serverless",
            "reservedForQualifiedUsers": True,
            "zone": "emea",
        }
        available_region_info = AstraDBAvailableRegionInfo._from_dict(ar0)

        current_name = available_region_info.name
        with pytest.warns(DeprecationWarning) as recorder:
            legacy_name = available_region_info.region_name
        self._check_one_future_deprecation(recorder)

        assert current_name == legacy_name
91 |
--------------------------------------------------------------------------------
/tests/base/unit/test_rerankingheadersprovider.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.authentication import (
20 | RerankingAPIKeyHeaderProvider,
21 | coerce_reranking_headers_provider,
22 | )
23 | from astrapy.settings.defaults import RERANKING_HEADER_API_KEY
24 |
25 |
class TestRerankingHeadersProvider:
    """Unit tests for the reranking API-key header provider and its coercion."""

    @pytest.mark.describe("test of headers from RerankingAPIKeyHeaderProvider")
    def test_rerankingheadersprovider_static(self) -> None:
        """A provider built with a key emits exactly the reranking API-key header."""
        provider = RerankingAPIKeyHeaderProvider("x")
        # Header names are compared case-insensitively.
        lowered_headers = {
            name.lower(): value for name, value in provider.get_headers().items()
        }
        expected_headers = {RERANKING_HEADER_API_KEY.lower(): "x"}
        assert lowered_headers == expected_headers

    @pytest.mark.describe("test of headers from empty RerankingAPIKeyHeaderProvider")
    def test_rerankingheadersprovider_null(self) -> None:
        """A provider built without a key emits no headers at all."""
        assert RerankingAPIKeyHeaderProvider(None).get_headers() == {}

    @pytest.mark.describe("test of reranking headers provider coercion")
    def test_rerankingheadersprovider_coercion(self) -> None:
        """
        Coercion passes providers through unchanged and turns a plain string
        into an equivalent provider. This doubles as an equality test.
        """
        keyed_provider = RerankingAPIKeyHeaderProvider("x")
        empty_provider = RerankingAPIKeyHeaderProvider(None)
        for provider in (keyed_provider, empty_provider):
            assert coerce_reranking_headers_provider(provider) == provider

        assert coerce_reranking_headers_provider("x") == keyed_provider
48 |
--------------------------------------------------------------------------------
/tests/base/unit/test_strenum.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import pytest
18 |
19 | from astrapy.utils.str_enum import StrEnum
20 |
21 |
class TestEnum(StrEnum):
    """Sample StrEnum used by the tests in this module (not a test class)."""

    # The class name matches pytest's default `Test*` collection pattern, and
    # Enum classes define `__new__`, which makes pytest emit a
    # PytestCollectionWarning. Opt out of collection explicitly. Dunder names
    # are never turned into Enum members, so this does not add a member.
    __test__ = False

    VALUE = "value"
    VALUE_DASH = "value-dash"
25 |
26 |
class TestStrEnum:
    """Unit tests for the lenient lookup and coercion logic of StrEnum."""

    @pytest.mark.describe("test of StrEnum membership check")
    def test_strenum_contains(self) -> None:
        # Membership matches member names and member values, case-insensitively.
        assert "value" in TestEnum
        assert "value_dash" in TestEnum
        assert "value-dash" in TestEnum
        assert "VALUE-DASH" in TestEnum
        assert "pippo" not in TestEnum
        # A non-string object is not a member.
        assert {6: 12} not in TestEnum

    @pytest.mark.describe("test of StrEnum coercion")
    def test_strenum_coerce(self) -> None:
        # Every accepted spelling must resolve to the corresponding member,
        # not merely avoid raising.
        assert TestEnum.coerce("value") is TestEnum.VALUE
        assert TestEnum.coerce("value_dash") is TestEnum.VALUE_DASH
        assert TestEnum.coerce("value-dash") is TestEnum.VALUE_DASH
        assert TestEnum.coerce("VALUE-DASH") is TestEnum.VALUE_DASH
        assert TestEnum.coerce(TestEnum.VALUE) is TestEnum.VALUE
        with pytest.raises(ValueError):
            TestEnum.coerce("pippo")
44 |
--------------------------------------------------------------------------------
/tests/base/unit/test_table_decimal_support.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from typing import Any
18 |
19 | import pytest
20 |
21 | from astrapy.constants import DefaultRowType
22 | from astrapy.data.utils.table_converters import _TableConverterAgent
23 | from astrapy.utils.api_commander import APICommander
24 | from astrapy.utils.api_options import FullSerdesOptions, defaultAPIOptions
25 |
26 | from ..conftest import _repaint_NaNs
27 | from ..table_decimal_support_assets import (
28 | BASELINE_COLUMNS,
29 | BASELINE_KEY_STR,
30 | BASELINE_OBJ,
31 | COLLTYPES_CUSTOM_OBJ_SCHEMA_TRIPLES,
32 | COLLTYPES_TRIPLE_IDS,
33 | WDECS_KEY_STR,
34 | WDECS_OBJ,
35 | WDECS_OBJ_COLUMNS,
36 | )
37 |
38 |
class TestTableDecimalSupportUnit:
    """Round-trip checks of decimal-bearing rows/keys through the table codec path."""

    @staticmethod
    def _roundtrip_row(
        t_agent: _TableConverterAgent[DefaultRowType],
        row: dict[str, Any],
        columns: dict[str, Any],
    ) -> DefaultRowType:
        """
        Send `row` through the write-side encoding (converter agent, then
        decimal-aware JSON serialization). Parse it back with the decimal-aware
        JSON parser, and decode it as a row with the given column schema.
        """
        wire_payload = APICommander._decimal_aware_encode_payload(
            t_agent.preprocess_payload(row, map2tuple_checker=None)
        )
        return t_agent.postprocess_row(
            APICommander._decimal_aware_parse_json_response(wire_payload),  # type: ignore[arg-type]
            columns_dict=columns,
            similarity_pseudocolumn=None,
        )

    @pytest.mark.describe("test of decimal-related conversions in table codec paths")
    def test_decimalsupport_table_codecpath(self) -> None:
        t_agent: _TableConverterAgent[DefaultRowType] = _TableConverterAgent(
            options=defaultAPIOptions(environment="prod").serdes_options,
        )
        # Full rows (baseline first, then with-decimals): encode, decode, compare.
        # Both sides go through _repaint_NaNs (presumably so NaN values compare
        # equal).
        for row, columns in (
            (BASELINE_OBJ, BASELINE_COLUMNS),
            (WDECS_OBJ, WDECS_OBJ_COLUMNS),
        ):
            decoded_row = self._roundtrip_row(t_agent, row, columns)
            assert _repaint_NaNs(decoded_row) == _repaint_NaNs(row)

        # Primary keys: decode only, starting from the serialized key strings.
        for key_str, expected_obj, pk_columns in (
            (BASELINE_KEY_STR, BASELINE_OBJ, BASELINE_COLUMNS),
            (WDECS_KEY_STR, WDECS_OBJ, WDECS_OBJ_COLUMNS),
        ):
            decoded_key = t_agent.postprocess_key(
                APICommander._decimal_aware_parse_json_response(key_str),  # type: ignore[arg-type]
                primary_key_schema_dict=pk_columns,
            )[1]
            assert _repaint_NaNs(decoded_key) == _repaint_NaNs(expected_obj)

    @pytest.mark.parametrize(
        ("colltype_custom", "colltype_obj", "colltype_columns"),
        COLLTYPES_CUSTOM_OBJ_SCHEMA_TRIPLES,
        ids=COLLTYPES_TRIPLE_IDS,
    )
    @pytest.mark.describe(
        "test of decimal-related conversions in table codec paths, collection types"
    )
    def test_decimalsupport_table_collectiontypes_codecpath(
        self,
        colltype_custom: bool,
        colltype_obj: dict[str, Any],
        colltype_columns: dict[str, Any],
    ) -> None:
        serdes_options = FullSerdesOptions(
            binary_encode_vectors=True,
            custom_datatypes_in_reading=colltype_custom,
            unroll_iterables_to_lists=True,
            use_decimals_in_collections=True,
            encode_maps_as_lists_in_tables="never",
            accept_naive_datetimes=False,
            datetime_tzinfo=None,
        )
        t_agent: _TableConverterAgent[DefaultRowType] = _TableConverterAgent(
            options=serdes_options,
        )

        decoded_row = self._roundtrip_row(t_agent, colltype_obj, colltype_columns)
        assert _repaint_NaNs(decoded_row) == _repaint_NaNs(colltype_obj)
114 |
--------------------------------------------------------------------------------
/tests/base/unit/test_token_providers.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Unit tests for the token providers
17 | """
18 |
19 | from __future__ import annotations
20 |
21 | import pytest
22 |
23 | from astrapy.authentication import (
24 | StaticTokenProvider,
25 | UsernamePasswordTokenProvider,
26 | coerce_token_provider,
27 | )
28 |
29 |
@pytest.mark.describe("test of static token provider")
def test_static_token_provider() -> None:
    """A static provider returns exactly the token it was built with."""
    token = "AstraCS:1"
    provider = StaticTokenProvider(token)
    assert provider.get_token() == token
36 |
37 |
@pytest.mark.describe("test of username-password token provider")
def test_username_password_token_provider() -> None:
    """
    Username and password end up base64-encoded in a "Cassandra:..." token
    ("Y2Fzc2FuZHJhQQ==" is the base64 of "cassandraA").
    """
    expected_token = "Cassandra:Y2Fzc2FuZHJhQQ==:Y2Fzc2FuZHJhQg=="
    provider = UsernamePasswordTokenProvider("cassandraA", "cassandraB")
    assert provider.get_token() == expected_token
43 |
44 |
@pytest.mark.describe("test of null token provider")
def test_null_token_provider() -> None:
    """A static provider built from None yields no token."""
    assert StaticTokenProvider(None).get_token() is None
50 |
51 |
@pytest.mark.describe("test of token providers coercion")
def test_coerce_token_provider() -> None:
    """
    Coercion turns a literal string into a provider of that token, and keeps
    the token of any provider it is handed (including the null one).
    """
    literal_token = "AstraCS:1"
    static_provider = StaticTokenProvider(literal_token)
    null_provider = StaticTokenProvider(None)
    up_provider = UsernamePasswordTokenProvider("cassandraA", "cassandraB")

    assert coerce_token_provider(literal_token).get_token() == literal_token
    for provider in (static_provider, up_provider):
        assert coerce_token_provider(provider).get_token() == provider.get_token()
    assert coerce_token_provider(null_provider).get_token() is None
63 |
64 |
@pytest.mark.describe("test of token providers equality")
def test_token_provider_equality() -> None:
    """
    Providers compare equal iff they yield the same token: identical
    constructions match, different kinds differ, and a username/password
    provider equals a static provider holding the same computed token.
    """
    literal_token = "AstraCS:1"
    first_batch = [
        StaticTokenProvider(literal_token),
        StaticTokenProvider(None),
        UsernamePasswordTokenProvider("cassandraA", "cassandraB"),
    ]
    second_batch = [
        StaticTokenProvider(literal_token),
        StaticTokenProvider(None),
        UsernamePasswordTokenProvider("cassandraA", "cassandraB"),
    ]

    # Same construction, separate instances: equal.
    for left, right in zip(first_batch, second_batch):
        assert left == right

    # Every ordered pair of different kinds: not equal (checked both ways).
    for left_idx, left in enumerate(first_batch):
        for right_idx, right in enumerate(first_batch):
            if left_idx != right_idx:
                assert left != right

    # effective equality
    up_provider = first_batch[2]
    assert up_provider == StaticTokenProvider(
        "Cassandra:Y2Fzc2FuZHJhQQ==:Y2Fzc2FuZHJhQg=="
    )
88 |
89 |
@pytest.mark.describe("test of token provider inheritance yield")
def test_token_provider_inheritance_yield() -> None:
    """
    Check the "token inheritance" combinators of token providers.

    Both `a | b` and `a or b` are expected to yield `a`, unless `a` is the
    null provider (no token), in which case they yield `b`.
    """
    static_tp = StaticTokenProvider("AstraCS:xyz")
    null_tp = StaticTokenProvider(None)
    up_tp = UsernamePasswordTokenProvider("cassandraA", "cassandraB")

    # `|` binds tighter than `==`, so `a | b == c` already means `(a | b) == c`.
    assert static_tp | static_tp == static_tp
    assert static_tp | null_tp == static_tp
    assert static_tp | up_tp == static_tp

    assert null_tp | static_tp == static_tp
    assert null_tp | null_tp == null_tp
    assert null_tp | up_tp == up_tp

    assert up_tp | static_tp == up_tp
    assert up_tp | null_tp == up_tp
    assert up_tp | up_tp == up_tp

    # `or` binds LOOSER than `==`. Without the parentheses, `a or b == c` parses
    # as `a or (b == c)`, which passes whenever `a` is truthy or `b == c`. The
    # result of `or` would then never actually be checked.
    assert (static_tp or static_tp) == static_tp
    assert (static_tp or null_tp) == static_tp
    assert (static_tp or up_tp) == static_tp

    assert (null_tp or static_tp) == static_tp
    assert (null_tp or null_tp) == null_tp
    assert (null_tp or up_tp) == up_tp

    assert (up_tp or static_tp) == up_tp
    assert (up_tp or null_tp) == up_tp
    assert (up_tp or up_tp) == up_tp
119 |
--------------------------------------------------------------------------------
/tests/dse_compose/README:
--------------------------------------------------------------------------------
1 | Caution: DSE does not support all the features required for full Data API usage
2 | (for example BM25 indexing, which the Data API reranker requires).
3 |
4 | As such, do not expect all features (or all tests) to work when running this compose setup.
5 |
--------------------------------------------------------------------------------
/tests/dse_compose/docker-compose.yml:
--------------------------------------------------------------------------------
services:

  # Complement with something like the following, to open a CQL shell and
  # create the keyspace expected by the tests:
  #   docker run -it --rm --network container:dse_compose-coordinator-1 cassandra:latest cqlsh -u cassandra -p cassandra
  #   create keyspace default_keyspace WITH replication = {'class': 'SimpleStrategy', 'replication_factor': 1};

  # Stargate coordinator (DSE-based) in developer mode; the image tag is pinned to v2.1.
  coordinator:
    image: stargateio/coordinator-dse-next:v2.1
    networks:
      - stargate
    ports:
      - "9042:9042"
      - "8081:8081"
      - "8090:8090"
    mem_limit: 2G
    environment:
      # NOTE(review): with the list syntax, Compose likely passes the double
      # quotes through as part of the value (i.e. `"-Xmx2G"`). Left as-is
      # because a 2G heap would equal `mem_limit` -- confirm the intent.
      - JAVA_OPTS="-Xmx2G"
      - CLUSTER_NAME=sgv2-cluster
      - RACK_NAME=rack1
      - DATACENTER_NAME=datacenter1
      - ENABLE_AUTH=true
      - DEVELOPER_MODE=true
    healthcheck:
      # The Data API service below waits for this check to pass.
      test: curl -f http://localhost:8084/checker/readiness || exit 1
      interval: 15s
      timeout: 10s
      retries: 10

  # Data API, connecting to the coordinator service as its Cassandra end point.
  data-api:
    image: stargateio/data-api:v1.0.25
    depends_on:
      coordinator:
        condition: service_healthy
    networks:
      - stargate
    ports:
      - "8181:8181"
    mem_limit: 2G
    environment:
      #- QUARKUS_GRPC_CLIENTS_BRIDGE_HOST=coordinator
      #- QUARKUS_GRPC_CLIENTS_BRIDGE_PORT=8091
      - STARGATE_DATA_STORE_SAI_ENABLED=true
      - STARGATE_DATA_STORE_VECTOR_SEARCH_ENABLED=true
      - STARGATE_JSONAPI_OPERATIONS_VECTORIZE_ENABLED=true
      - STARGATE_FEATURE_FLAGS_TABLES=true
      - STARGATE_DATA_STORE_IGNORE_BRIDGE=true
      - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_CASSANDRA_END_POINTS=coordinator
      - QUARKUS_HTTP_ACCESS_LOG_ENABLED=FALSE
      - QUARKUS_LOG_LEVEL=INFO
      - JAVA_MAX_MEM_RATIO=75
      - JAVA_INITIAL_MEM_RATIO=50
      - GC_CONTAINER_OPTIONS=-XX:+UseG1GC
      - JAVA_OPTS_APPEND=-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=org.jboss.logmanager.LogManager
    healthcheck:
      test: curl -f http://localhost:8181/stargate/health || exit 1
      interval: 5s
      timeout: 10s
      retries: 10

networks:
  stargate:
--------------------------------------------------------------------------------
/tests/env_templates/env.astra.admin.template:
--------------------------------------------------------------------------------
1 | ########################
2 | # FOR THE ADMIN TESTS: #
3 | ########################
4 |
5 | # PROD settings
6 | export PROD_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."
7 | export PROD_ADMIN_TEST_ASTRA_DB_PROVIDER="aws"
8 | export PROD_ADMIN_TEST_ASTRA_DB_REGION="eu-west-1"
9 |
10 | # DEV settings (optional)
11 | export DEV_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."
12 | export DEV_ADMIN_TEST_ASTRA_DB_PROVIDER="aws"
13 | export DEV_ADMIN_TEST_ASTRA_DB_REGION="us-west-2"
14 |
15 | # TEST settings (optional)
16 | export TEST_ADMIN_TEST_ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."
17 | export TEST_ADMIN_TEST_ASTRA_DB_PROVIDER="aws"
18 | export TEST_ADMIN_TEST_ASTRA_DB_REGION="us-west-2"
19 |
--------------------------------------------------------------------------------
/tests/env_templates/env.astra.template:
--------------------------------------------------------------------------------
1 | ##############################
2 | # FOR THE TESTS ON ASTRA DB: #
3 | ##############################
4 |
5 | export ASTRA_DB_APPLICATION_TOKEN="AstraCS:..."
6 |
7 | export ASTRA_DB_API_ENDPOINT="https://<DATABASE_ID>-<REGION>.apps.astra.datastax.com"
8 |
9 | # OPTIONAL (the first has a default; a few tests are skipped if the second is missing):
10 | # export ASTRA_DB_KEYSPACE="..."
11 | # export ASTRA_DB_SECONDARY_KEYSPACE="..."
12 |
--------------------------------------------------------------------------------
/tests/env_templates/env.local.template:
--------------------------------------------------------------------------------
1 | ########################################
2 | # FOR THE TESTS WITH A LOCAL DATA API: #
3 | ########################################
4 |
5 | # Authentication:
6 | export LOCAL_DATA_API_USERNAME="cassandra"
7 | export LOCAL_DATA_API_PASSWORD="cassandra"
8 |
9 | export LOCAL_DATA_API_ENDPOINT="http://localhost:8181"
10 | export LOCAL_CASSANDRA_CONTACT_POINT="127.0.0.1"
11 |
12 | # OPTIONAL: (if defined here, they will be created as needed)
13 | export LOCAL_DATA_API_KEYSPACE="default_keyspace"
14 | export LOCAL_DATA_API_SECONDARY_KEYSPACE="alternate_keyspace"
15 |
16 | # RERANKER SETTINGS
17 | export ASTRAPY_FINDANDRERANK_USE_RERANKER_HEADER="y"
18 | # Local runs require a DEV token as reranker API Key:
19 | export HEADER_RERANKING_API_KEY_NVIDIA="AstraCS:[...]"
20 |
--------------------------------------------------------------------------------
/tests/env_templates/env.testcontainers.template:
--------------------------------------------------------------------------------
1 | ######################################
2 | # FOR THE TESTS WITH TESTCONTAINERS: #
3 | ######################################
4 |
5 | export DOCKER_COMPOSE_LOCAL_DATA_API="yes"
6 |
--------------------------------------------------------------------------------
/tests/env_templates/env.vectorize-minimal.template:
--------------------------------------------------------------------------------
1 | ####################################
2 | # FOR THE VECTORIZE TESTS IN BASE: #
3 | ####################################
4 |
5 | export HEADER_EMBEDDING_API_KEY_OPENAI="..."
6 |
--------------------------------------------------------------------------------
/tests/env_templates/env.vectorize.template:
--------------------------------------------------------------------------------
1 | #####################################
2 | # FOR THE EXTENDED VECTORIZE TESTS: #
3 | #####################################
4 |
5 |
6 | export HEADER_EMBEDDING_API_KEY_HUGGINGFACE="..."
7 |
8 | export HEADER_EMBEDDING_API_KEY_COHERE="..."
9 |
10 | export HEADER_EMBEDDING_API_KEY_VOYAGEAI="..."
11 |
12 | export HEADER_EMBEDDING_API_KEY_MISTRAL="..."
13 |
14 | export HEADER_EMBEDDING_API_KEY_UPSTAGE="..."
15 |
16 | export HEADER_EMBEDDING_API_KEY_OPENAI="..."
17 | export OPENAI_ORGANIZATION_ID="..."
18 | export OPENAI_PROJECT_ID="..."
19 |
20 | export HEADER_EMBEDDING_API_KEY_AZURE_OPENAI="..."
21 | export AZURE_OPENAI_DEPLOY_ID_EMB3LARGE="..."
22 | export AZURE_OPENAI_RESNAME_EMB3LARGE="..."
23 | export AZURE_OPENAI_DEPLOY_ID_EMB3SMALL="..."
24 | export AZURE_OPENAI_RESNAME_EMB3SMALL="..."
25 | export AZURE_OPENAI_DEPLOY_ID_ADA2="..."
26 | export AZURE_OPENAI_RESNAME_ADA2="..."
27 |
28 | export HEADER_EMBEDDING_API_KEY_JINAAI="..."
29 |
30 | export HEADER_EMBEDDING_API_KEY_VERTEXAI="..."
31 |
32 | export HEADER_EMBEDDING_VERTEXAI_PROJECT_ID="..."
33 |
34 | export HEADER_EMBEDDING_API_KEY_HUGGINGFACEDED="..."
35 | export HUGGINGFACEDED_DIMENSION="..."
36 | export HUGGINGFACEDED_ENDPOINTNAME="..."
37 | export HUGGINGFACEDED_REGIONNAME="..."
38 | export HUGGINGFACEDED_CLOUDNAME="..."
39 |
40 | # Additional SHARED_SECRET testing information
41 | # (Preparation to be done in the UI at the moment)
42 | # Scope these secrets, with these names, to the targeted DB(s):
43 | # SHARED_SECRET_EMBEDDING_API_KEY_AZURE_OPENAI
44 | # SHARED_SECRET_EMBEDDING_API_KEY_HUGGINGFACE
45 | # SHARED_SECRET_EMBEDDING_API_KEY_HUGGINGFACEDED
46 | # SHARED_SECRET_EMBEDDING_API_KEY_JINAAI
47 | # SHARED_SECRET_EMBEDDING_API_KEY_MISTRAL
48 | # SHARED_SECRET_EMBEDDING_API_KEY_OPENAI
49 | # SHARED_SECRET_EMBEDDING_API_KEY_UPSTAGE
50 | # SHARED_SECRET_EMBEDDING_API_KEY_VOYAGEAI
51 |
--------------------------------------------------------------------------------
/tests/hcd_compose/docker-compose.yml:
--------------------------------------------------------------------------------
services:
  # Database node. Its cassandra.yaml is replaced by the local
  # cassandra-hcd.yaml through the volume mount below.
  # NOTE(review): the image is hosted in a private AWS ECR registry, so pulling
  # it presumably requires credentials for that registry -- confirm.
  hcd:
    image: 559669398656.dkr.ecr.us-west-2.amazonaws.com/engops-shared/hcd/staging/hcd:1.2.1-early-preview
    networks:
      - stargate
    mem_limit: 2G
    environment:
      - MAX_HEAP_SIZE=1536M
      - CLUSTER_NAME=hcd-1.2.1-early-preview-cluster
      - DS_LICENSE=accept
      # Presumably stops the image from auto-generating cassandra.yaml, so the
      # mounted file is used as-is -- confirm against the image docs.
      - HCD_AUTO_CONF_OFF=cassandra.yaml
    volumes:
      - ./cassandra-hcd.yaml:/opt/hcd/resources/cassandra/conf/cassandra.yaml:rw
    ports:
      - "9042:9042"
    healthcheck:
      # Healthy once cqlsh can log in with the default credentials and list keyspaces.
      test: [ "CMD-SHELL", "cqlsh -u cassandra -p cassandra -e 'describe keyspaces'" ]
      interval: 15s
      timeout: 10s
      retries: 20

  # Data API pointed straight at the `hcd` service (bridge ignored), with the
  # vectorize, tables and reranking features switched on.
  data-api:
    image: stargateio/data-api:v1.0.25
    depends_on:
      hcd:
        condition: service_healthy
    networks:
      - stargate
    ports:
      - "8181:8181"
    mem_limit: 2G
    environment:
      - JAVA_MAX_MEM_RATIO=75
      - JAVA_INITIAL_MEM_RATIO=50
      - STARGATE_DATA_STORE_IGNORE_BRIDGE=true
      - GC_CONTAINER_OPTIONS=-XX:+UseG1GC
      - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_CASSANDRA_END_POINTS=hcd
      - STARGATE_JSONAPI_OPERATIONS_DATABASE_CONFIG_LOCAL_DATACENTER=dc1
      - QUARKUS_HTTP_ACCESS_LOG_ENABLED=FALSE
      - QUARKUS_LOG_LEVEL=INFO
      - STARGATE_JSONAPI_OPERATIONS_VECTORIZE_ENABLED=true
      - STARGATE_FEATURE_FLAGS_TABLES=true
      - STARGATE_FEATURE_FLAGS_RERANKING=true
      - JAVA_OPTS_APPEND=-Dquarkus.http.host=0.0.0.0 -Djava.util.logging.manager=org.jboss.logmanager.LogManager
    healthcheck:
      test: curl -f http://localhost:8181/stargate/health || exit 1
      interval: 5s
      timeout: 10s
      retries: 10
networks:
  stargate:
52 |
--------------------------------------------------------------------------------
/tests/vectorize/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/vectorize/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from ..conftest import (
18 | IS_ASTRA_DB,
19 | )
20 |
# Public names of this conftest: the flag re-exported from the parent conftest.
__all__ = ["IS_ASTRA_DB"]
24 |
--------------------------------------------------------------------------------
/tests/vectorize/integration/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
--------------------------------------------------------------------------------
/tests/vectorize/live_provider_info.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | from preprocess_env import (
18 | ASTRA_DB_API_ENDPOINT,
19 | ASTRA_DB_KEYSPACE,
20 | ASTRA_DB_TOKEN_PROVIDER,
21 | IS_ASTRA_DB,
22 | LOCAL_DATA_API_ENDPOINT,
23 | LOCAL_DATA_API_KEYSPACE,
24 | LOCAL_DATA_API_TOKEN_PROVIDER,
25 | )
26 |
27 | from astrapy import DataAPIClient, Database
28 | from astrapy.admin import parse_api_endpoint
29 | from astrapy.constants import Environment
30 | from astrapy.info import FindEmbeddingProvidersResult
31 |
32 |
def live_provider_info() -> FindEmbeddingProvidersResult:
    """
    Run the Data API `findEmbeddingProviders` command to get the latest
    embedding-provider information.

    The target database is chosen from the environment: the Astra DB settings
    when `IS_ASTRA_DB` is set, otherwise the local Data API settings.

    Returns:
        the result of `find_embedding_providers()` on the target database's admin.

    Raises:
        ValueError: if targeting Astra DB and the configured API endpoint
            cannot be parsed.
    """
    database: Database
    if IS_ASTRA_DB:
        parsed = parse_api_endpoint(ASTRA_DB_API_ENDPOINT)
        if parsed is None:
            # f-string so that the offending endpoint actually appears in the message.
            raise ValueError(
                f"Cannot parse the Astra DB API Endpoint '{ASTRA_DB_API_ENDPOINT}'"
            )
        # Use the environment (prod/dev/test) encoded in the endpoint itself.
        client = DataAPIClient(environment=parsed.environment)
        database = client.get_database(
            ASTRA_DB_API_ENDPOINT,
            token=ASTRA_DB_TOKEN_PROVIDER,
            keyspace=ASTRA_DB_KEYSPACE,
        )
    else:
        client = DataAPIClient(environment=Environment.OTHER)
        database = client.get_database(
            LOCAL_DATA_API_ENDPOINT,
            token=LOCAL_DATA_API_TOKEN_PROVIDER,
            keyspace=LOCAL_DATA_API_KEYSPACE,
        )

    return database.get_database_admin().find_embedding_providers()
66 |
--------------------------------------------------------------------------------
/tests/vectorize/query_providers.py:
--------------------------------------------------------------------------------
1 | # Copyright DataStax, Inc.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from __future__ import annotations
16 |
17 | import json
18 | import os
19 | import sys
20 |
21 | from astrapy.info import EmbeddingProviderParameter, FindEmbeddingProvidersResult
22 |
23 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
24 |
25 | from live_provider_info import live_provider_info
26 | from vectorize_models import live_test_models
27 |
28 |
29 | def desc_param(param_data: EmbeddingProviderParameter) -> str:
30 | if param_data.parameter_type.lower() == "string":
31 | return "str"
32 | elif param_data.parameter_type.lower() == "number":
33 | validation = param_data.validation
34 | if "numericRange" in validation:
35 | validation_nr = validation["numericRange"]
36 | assert isinstance(validation_nr, list) and len(validation_nr) == 2
37 | range_desc = f"[{validation_nr[0]} : {validation_nr[1]}]"
38 | if param_data.default_value is not None:
39 | range_desc2 = f"{range_desc} (default={param_data.default_value})"
40 | else:
41 | range_desc2 = range_desc
42 | return f"number, {range_desc2}"
43 | elif "options" in validation:
44 | validation_op = validation["options"]
45 | assert isinstance(validation_op, list) and len(validation_op) > 1
46 | return f"number, {' / '.join(str(v) for v in validation_op)}"
47 | else:
48 | raise ValueError(
49 | f"Unknown number validation spec: '{json.dumps(validation)}'"
50 | )
51 | elif param_data.parameter_type.lower() == "boolean":
52 | return "bool"
53 | else:
54 | raise NotImplementedError
55 |
56 |
57 | if __name__ == "__main__":
58 | provider_info: FindEmbeddingProvidersResult = live_provider_info()
59 | providers_json = (provider_info.raw_info or {}).get("embeddingProviders")
60 | if not providers_json:
61 | raise ValueError(
62 | "raw info from embedding providers lacks `embeddingProviders` content."
63 | )
64 | json.dump(providers_json, open("_providers.json", "w"), indent=2, sort_keys=True)
65 |
66 | for provider, provider_data in sorted(provider_info.embedding_providers.items()):
67 | print(f"{provider} ({len(provider_data.models)} models)")
68 | print(" auth:")
69 | for auth_type, auth_data in sorted(
70 | provider_data.supported_authentication.items()
71 | ):
72 | if auth_data.enabled:
73 | tokens = ", ".join(f"'{tok.accepted}'" for tok in auth_data.tokens)
74 | print(f" {auth_type} ({tokens})")
75 | if provider_data.parameters:
76 | print(" parameters")
77 | for param_data in provider_data.parameters:
78 | param_name = param_data.name
79 | if param_data.required:
80 | param_display_name = param_name
81 | else:
82 | param_display_name = f"({param_name})"
83 | param_desc = desc_param(param_data)
84 | print(f" - {param_display_name}: {param_desc}")
85 | print(" models:")
86 | for model_data in sorted(provider_data.models, key=lambda pro: pro.name):
87 | model_name = model_data.name
88 | if model_data.vector_dimension is not None:
89 | assert model_data.vector_dimension > 0
90 | model_dim_desc = f" (D = {model_data.vector_dimension})"
91 | else:
92 | model_dim_desc = ""
93 | if True:
94 | print(f" {model_name}{model_dim_desc}")
95 | if model_data.parameters:
96 | for param_data in model_data.parameters:
97 | param_name = param_data.name
98 | if param_data.required:
99 | param_display_name = param_name
100 | else:
101 | param_display_name = f"({param_name})"
102 | param_desc = desc_param(param_data)
103 | print(f" - {param_display_name}: {param_desc}")
104 |
105 | print("\n" * 2)
106 | all_test_models = list(live_test_models())
107 | for auth_type in ["HEADER", "NONE", "SHARED_SECRET"]:
108 | print(f"Tags for auth type {auth_type}:", end="")
109 | #
110 | at_test_models = [
111 | test_model
112 | for test_model in all_test_models
113 | if test_model["auth_type_name"] == auth_type
114 | ]
115 | at_model_ids: list[str] = sorted(
116 | [str(model_desc["model_tag"]) for model_desc in at_test_models]
117 | )
118 | if at_model_ids:
119 | print("")
120 | print("\n".join(f" {ami}" for ami in at_model_ids))
121 | else:
122 | print(" (no tags)")
123 |
--------------------------------------------------------------------------------