├── .gitchangelog.rc
├── .github
│   ├── ISSUE_TEMPLATE.md
│   ├── PULL_REQUEST_TEMPLATE.md
│   ├── dependabot.yml
│   └── workflows
│       └── python-package.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CHANGELOG.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── NOTICE
├── README.rst
├── THIRD_PARTY_LICENSES
├── build.sh
├── redshift_connector
│   ├── __init__.py
│   ├── auth
│   │   ├── __init__.py
│   │   └── aws_credentials_provider.py
│   ├── config.py
│   ├── core.py
│   ├── credentials_holder.py
│   ├── cursor.py
│   ├── error.py
│   ├── files
│   │   └── redshift-ca-bundle.crt
│   ├── iam_helper.py
│   ├── idp_auth_helper.py
│   ├── interval.py
│   ├── metadataAPIHelper.py
│   ├── metadataAPIPostProcessing.py
│   ├── metadataServerAPIHelper.py
│   ├── native_plugin_helper.py
│   ├── objects.py
│   ├── pg_types.py
│   ├── plugin
│   │   ├── __init__.py
│   │   ├── adfs_credentials_provider.py
│   │   ├── azure_credentials_provider.py
│   │   ├── browser_azure_credentials_provider.py
│   │   ├── browser_azure_oauth2_credentials_provider.py
│   │   ├── browser_idc_auth_plugin.py
│   │   ├── browser_saml_credentials_provider.py
│   │   ├── common_credentials_provider.py
│   │   ├── credential_provider_constants.py
│   │   ├── i_native_plugin.py
│   │   ├── i_plugin.py
│   │   ├── idp_credentials_provider.py
│   │   ├── idp_token_auth_plugin.py
│   │   ├── jwt_credentials_provider.py
│   │   ├── native_token_holder.py
│   │   ├── okta_credentials_provider.py
│   │   ├── ping_credentials_provider.py
│   │   └── saml_credentials_provider.py
│   ├── py.typed
│   ├── redshift_property.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── array_util.py
│   │   ├── driver_info.py
│   │   ├── extensible_digest.py
│   │   ├── logging_utils.py
│   │   ├── oids.py
│   │   ├── sql_types.py
│   │   └── type_utils.py
│   └── version.py
├── requirements-dev.txt
├── requirements.txt
├── setup.cfg
├── setup.py
├── test
│   ├── __init__.py
│   ├── conftest.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── datatype
│   │   │   ├── _generate_test_datatype_tables.py
│   │   │   ├── test_datatypes.py
│   │   │   └── test_system_table_queries.py
│   │   ├── metadata
│   │   │   ├── test_list_catalog.py
│   │   │   ├── test_metadataAPI.py
│   │   │   ├── test_metadataAPIHelperServer.py
│   │   │   ├── test_metadataAPI_pattern_matching.py
│   │   │   ├── test_metadataAPI_special_character_handling.py
│   │   │   ├── test_metadataAPI_special_character_handling_case_sensitive.py
│   │   │   ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_insensitive.py
│   │   │   ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_insensitive_with_pipe.py
│   │   │   ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_sensitive.py
│   │   │   ├── test_metadataAPI_special_character_handling_standard_identifier.py
│   │   │   └── test_metadataAPI_sql_injection.py
│   │   ├── plugin
│   │   │   ├── conftest.py
│   │   │   ├── test_azure_credentials_provider.py
│   │   │   ├── test_credentials_providers.py
│   │   │   └── test_okta_credentials_provider.py
│   │   ├── test_connection.py
│   │   ├── test_cursor.py
│   │   ├── test_dbapi20.py
│   │   ├── test_pandas.py
│   │   ├── test_paramstyle.py
│   │   ├── test_query.py
│   │   ├── test_redshift_property.py
│   │   └── test_unsupported_datatypes.py
│   ├── manual
│   │   ├── __init__.py
│   │   ├── auth
│   │   │   ├── test_aws_credentials.py
│   │   │   └── test_redshift_auth_profile.py
│   │   ├── plugin
│   │   │   ├── __init__.py
│   │   │   └── test_browser_credentials_provider.py
│   │   ├── test_redshift_custom_domain.py
│   │   └── test_redshift_serverless.py
│   ├── performance
│   │   ├── bulk_insert_data.csv
│   │   ├── bulk_insert_performance.py
│   │   ├── performance.py
│   │   ├── protocol_perf_test.sql
│   │   ├── protocol_performance.py
│   │   └── test.sql
│   ├── unit
│   │   ├── __init__.py
│   │   ├── auth
│   │   │   ├── __init__.py
│   │   │   └── test_aws_credentials_provider.py
│   │   ├── datatype
│   │   │   ├── __init__.py
│   │   │   ├── test_data_in.py
│   │   │   ├── test_oids.py
│   │   │   └── test_sql_types.py
│   │   ├── helpers
│   │   │   ├── __init__.py
│   │   │   └── idp_helpers.py
│   │   ├── mocks
│   │   │   ├── __init__.py
│   │   │   ├── mock_external_credential_provider.py
│   │   │   └── mock_socket.py
│   │   ├── plugin
│   │   │   ├── __init__.py
│   │   │   ├── data
│   │   │   │   ├── __init__.py
│   │   │   │   ├── browser_azure_data.py
│   │   │   │   ├── mock_adfs_saml_response.html
│   │   │   │   ├── mock_adfs_sign_in.html
│   │   │   │   ├── saml_response.xml
│   │   │   │   └── saml_response_data.py
│   │   │   ├── test_adfs_credentials_provider.py
│   │   │   ├── test_azure_credentials_provider.py
│   │   │   ├── test_azure_oauth2_credentials_provider.py
│   │   │   ├── test_browser_azure_credentials_provider.py
│   │   │   ├── test_browser_idc_auth_plugin.py
│   │   │   ├── test_browser_saml_credentials_provider.py
│   │   │   ├── test_credentials_providers.py
│   │   │   ├── test_idp_token_auth_plugin.py
│   │   │   ├── test_jwt_credentials_provider.py
│   │   │   ├── test_okta_credentials_provider.py
│   │   │   ├── test_plugin_inheritance.py
│   │   │   └── test_saml_credentials_provider.py
│   │   ├── test_array_util.py
│   │   ├── test_connection.py
│   │   ├── test_core.py
│   │   ├── test_credentials_holder.py
│   │   ├── test_cursor.py
│   │   ├── test_dbapi20.py
│   │   ├── test_driver_info.py
│   │   ├── test_iam_helper.py
│   │   ├── test_idp_auth_helper.py
│   │   ├── test_import.py
│   │   ├── test_logging_utils.py
│   │   ├── test_metadataAPIHelper.py
│   │   ├── test_metadataAPIPostProcessing.py
│   │   ├── test_metadataServerAPIHelper.py
│   │   ├── test_module.py
│   │   ├── test_native_plugin_helper.py
│   │   ├── test_paramstyle.py
│   │   ├── test_redshift_property.py
│   │   ├── test_type_utils.py
│   │   └── test_typeobjects.py
│   └── utils
│       ├── __init__.py
│       └── decorators.py
└── tutorials
    ├── .ipynb_checkpoints
    │   └── 001 - Connecting to Amazon Redshift-checkpoint.ipynb
    ├── 001 - Connecting to Amazon Redshift.ipynb
    ├── 002 - Data Science Library Integrations.ipynb
    ├── 003 - Amazon Redshift Feature Support.ipynb
    └── 004 - Amazon Redshift Datatypes.ipynb
/.github/ISSUE_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Issue report
3 | about: Report an issue
4 | title: ''
5 | assignees: ''
6 |
7 | ---
8 |
9 | ## Driver version
10 |
11 |
12 | ## Redshift version
13 |
14 |
15 | ## Client Operating System
16 |
17 |
18 | ## Python version
19 |
20 |
21 | ## Table schema
22 |
23 |
24 | ## Problem description
25 |
26 | 1. Expected behaviour:
27 | 2. Actual behaviour:
28 | 3. Error message/stack trace:
29 | 4. Any other details that can be helpful:
30 |
31 | ## Python Driver trace logs
32 |
33 |
34 | ## Reproduction code
35 |
36 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Description
4 |
5 |
6 | ## Motivation and Context
7 |
8 |
9 |
10 | ## Testing
11 |
12 |
13 |
14 |
15 | ## Screenshots (if appropriate)
16 |
17 | ## Types of changes
18 |
19 | - [ ] Bug fix (non-breaking change which fixes an issue)
20 | - [ ] New feature (non-breaking change which adds functionality)
21 |
22 | ## Checklist
23 |
24 |
25 |
26 | - [ ] Local run of `./build.sh` succeeds
27 | - [ ] Code changes have been run against the repository's pre-commit hooks
28 | - [ ] Commit messages follow [Conventional Commit Specification](https://www.conventionalcommits.org/en/v1.0.0/)
29 | - [ ] I have read the **README** document
30 | - [ ] I have added tests to cover my changes
31 | - [ ] I have run all unit tests using `pytest test/unit` and they are passing.
32 |
36 |
37 |
38 | - By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.
39 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "pip"
4 | directory: "/"
5 | schedule:
6 | interval: "weekly"
7 | time: "09:00"
8 | timezone: "Europe/London"
9 | commit-message:
10 | prefix: "chore"
11 | prefix-development: "chore"
12 | include: "scope"
13 | open-pull-requests-limit: 8
14 |
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
3 |
4 | name: Python package
5 |
6 | on: workflow_dispatch
7 |
8 | jobs:
9 | build:
10 |
11 | runs-on: ubuntu-latest
12 | timeout-minutes: 15
13 |
14 | strategy:
15 | fail-fast: true
16 | matrix:
17 | python-version: ["3.6", "3.9"]
18 |
19 | steps:
20 | - uses: actions/checkout@v2
21 | - name: Set up Python ${{ matrix.python-version }}
22 | uses: actions/setup-python@v2
23 | with:
24 | python-version: ${{ matrix.python-version }}
25 | - name: Install dependencies
26 | run: |
27 | python -m pip install --upgrade pip
28 | python -m pip install wheel setuptools
29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
30 | if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
31 | - name: Run Unit Tests
32 | run: |
33 | pytest test/unit/
34 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/pre-commit/pre-commit-hooks
3 | rev: v4.4.0
4 | hooks:
5 | - id: check-yaml
6 | - id: end-of-file-fixer
7 | - id: trailing-whitespace
8 | - repo: https://github.com/pre-commit/mirrors-isort
9 | rev: v5.10.1 # must be >5.0.0 for black compatibility
10 | hooks:
11 | - id: isort
12 | args: ["--profile", "black", "."]
13 | - repo: https://github.com/ambv/black
14 | rev: 23.7.0
15 | hooks:
16 | - id: black
17 | - repo: https://github.com/pre-commit/mirrors-mypy
18 | rev: v1.4.1
19 | hooks:
20 | - id: mypy
21 | additional_dependencies: [types-setuptools, types-requests, types-python-dateutil]
22 | args: [--ignore-missing-imports, --disable-error-code, "annotation-unchecked"]
23 | verbose: true
24 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check [existing open](https://github.com/aws/amazon-redshift-python-driver/issues), or [recently closed](https://github.com/aws/amazon-redshift-python-driver/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *master* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
 34 | 3. Run the repository's pre-commit hooks against your changes.
35 | 4. Ensure unit tests are passing by running `pytest test/unit`
36 | 5. Commit to your fork using clear commit messages that follow [Conventional Commit](https://www.conventionalcommits.org/en/v1.0.0/) specification.
37 | 6. Send us a pull request, answering any default questions in the pull request interface.
38 | 7. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
39 |
 40 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
41 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
42 |
43 |
44 | ## Finding contributions to work on
 45 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
46 |
47 |
48 | ## Code of Conduct
49 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
50 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
51 | opensource-codeofconduct@amazon.com with any additional questions or comments.
52 |
53 |
54 | ## Security issue notifications
55 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue.
56 |
57 |
58 | ## Licensing
59 |
60 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
61 |
62 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes.
63 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include redshift_connector/files/*
2 | recursive-include files *.crt
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | python3 setup.py sdist bdist_wheel
2 |
--------------------------------------------------------------------------------
/redshift_connector/auth/__init__.py:
--------------------------------------------------------------------------------
1 | from .aws_credentials_provider import AWSCredentialsProvider
2 |
--------------------------------------------------------------------------------
/redshift_connector/auth/aws_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 |
4 | from redshift_connector.credentials_holder import (
5 | ABCCredentialsHolder,
6 | AWSDirectCredentialsHolder,
7 | AWSProfileCredentialsHolder,
8 | )
9 | from redshift_connector.error import InterfaceError
10 |
11 | _logger: logging.Logger = logging.getLogger(__name__)
12 |
13 | if typing.TYPE_CHECKING:
14 | import boto3 # type: ignore
15 |
16 | from redshift_connector.redshift_property import RedshiftProperty
17 |
18 |
19 | class AWSCredentialsProvider:
20 | """
21 | A credential provider class for AWS credentials specified via :func:`~redshift_connector.connect` using `profile` or AWS access keys.
22 | """
23 |
24 | def __init__(self: "AWSCredentialsProvider") -> None:
25 | self.cache: typing.Dict[int, typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder]] = {}
26 |
27 | self.access_key_id: typing.Optional[str] = None
28 | self.secret_access_key: typing.Optional[str] = None
29 | self.session_token: typing.Optional[str] = None
 30 |         self.profile: typing.Optional[str] = None
31 |
32 | def get_cache_key(self: "AWSCredentialsProvider") -> int:
33 | """
 34 |         Creates a cache key from the hash of the end-user provided AWS profile name or access key id.
35 |
36 | Returns
37 | -------
38 | An `int` hash representation of the non-secret portion of credential information: `int`
39 | """
40 | if self.profile:
41 | return hash(self.profile)
42 | else:
43 | return hash(self.access_key_id)
44 |
45 | def get_credentials(
46 | self: "AWSCredentialsProvider",
47 | ) -> typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder]:
 48 |         """
 49 |         Retrieves a credentials holder from the cache, building and caching one if absent.
 50 | 
 51 |         Returns
 52 |         -------
 53 |         A credentials holder containing end-user specified AWS credential information: :class:`ABCCredentialsHolder`
 54 |         """
55 | key: int = self.get_cache_key()
56 | if key not in self.cache:
57 | try:
58 | self.refresh()
59 | except Exception as e:
 60 |                 exec_msg: str = "Refreshing AWS credentials failed"
61 | _logger.debug(exec_msg)
62 | raise InterfaceError(e)
63 |
64 | credentials: typing.Union[AWSDirectCredentialsHolder, AWSProfileCredentialsHolder] = self.cache[key]
65 |
66 | if credentials is None:
67 | exec_msg = "Unable to load AWS credentials from cache"
68 | _logger.debug(exec_msg)
69 | raise InterfaceError(exec_msg)
70 |
71 | return credentials
72 |
73 | def add_parameter(self: "AWSCredentialsProvider", info: "RedshiftProperty") -> None:
74 | """
 75 |         Defines instance variables used for creating a :class:`ABCCredentialsHolder` object and associated :class:`boto3.Session`
76 |
77 | Parameters
78 | ----------
79 | info : :class:`RedshiftProperty`
80 | The :class:`RedshiftProperty` object created using end-user specified values passed to :func:`~redshift_connector.connect`
81 | """
82 | self.access_key_id = info.access_key_id
83 | self.secret_access_key = info.secret_access_key
84 | self.session_token = info.session_token
85 | self.profile = info.profile
86 |
87 | def refresh(self: "AWSCredentialsProvider") -> None:
88 | """
89 | Establishes a :class:`boto3.Session` using end-user specified AWS credential information
90 | """
91 | import boto3 # type: ignore
92 |
93 | args: typing.Dict[str, str] = {}
94 |
95 | if self.profile is not None:
96 | args["profile_name"] = self.profile
97 | elif self.access_key_id is not None:
98 | args["aws_access_key_id"] = self.access_key_id
99 | args["aws_secret_access_key"] = typing.cast(str, self.secret_access_key)
100 | if self.session_token is not None:
101 | args["aws_session_token"] = self.session_token
102 |
103 | session: boto3.Session = boto3.Session(**args)
104 | credentials: typing.Optional[typing.Union[AWSProfileCredentialsHolder, AWSDirectCredentialsHolder]] = None
105 |
106 | if self.profile is not None:
107 | credentials = AWSProfileCredentialsHolder(profile=self.profile, session=session)
108 | else:
109 | credentials = AWSDirectCredentialsHolder(
110 | access_key_id=typing.cast(str, self.access_key_id),
111 | secret_access_key=typing.cast(str, self.secret_access_key),
112 | session_token=self.session_token,
113 | session=session,
114 | )
115 |
116 | key = self.get_cache_key()
117 | self.cache[key] = credentials
118 |
--------------------------------------------------------------------------------
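
A minimal usage sketch of AWSCredentialsProvider (illustrative, not part of the repository; the key values are placeholders). Normally redshift_connector.connect() populates the provider through add_parameter() with a RedshiftProperty object; here the same attributes are set directly. Requires boto3, though constructing a Session from placeholder keys performs no network call.

    from redshift_connector.auth import AWSCredentialsProvider

    provider = AWSCredentialsProvider()
    # add_parameter() would normally copy these from a RedshiftProperty object.
    provider.access_key_id = "AKIAEXAMPLE"             # placeholder
    provider.secret_access_key = "secret-placeholder"  # placeholder
    provider.session_token = None

    first = provider.get_credentials()   # cache miss: refresh() builds a boto3.Session
    second = provider.get_credentials()  # cache hit: the same holder is returned
    assert first is second
    assert provider.get_cache_key() == hash("AKIAEXAMPLE")
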
/redshift_connector/error.py:
--------------------------------------------------------------------------------
1 | class Warning(Exception):
2 | """Generic exception raised for important database warnings like data
3 | truncations. This exception is not currently used.
4 |
5 | This exception is part of the `DBAPI 2.0 specification
 6 |     <https://www.python.org/dev/peps/pep-0249/>`_.
7 | """
8 |
9 | pass
10 |
11 |
12 | class Error(Exception):
13 | """Generic exception that is the base exception of all other error
14 | exceptions.
15 |
16 | This exception is part of the `DBAPI 2.0 specification
 17 |     <https://www.python.org/dev/peps/pep-0249/>`_.
18 | """
19 |
20 | pass
21 |
22 |
23 | class InterfaceError(Error):
24 | """Generic exception raised for errors that are related to the database
25 | interface rather than the database itself. For example, if the interface
26 | attempts to use an SSL connection but the server refuses, an InterfaceError
27 | will be raised.
28 |
29 | This exception is part of the `DBAPI 2.0 specification
 30 |     <https://www.python.org/dev/peps/pep-0249/>`_.
31 | """
32 |
33 | pass
34 |
35 |
36 | class DatabaseError(Error):
37 | """Generic exception raised for errors that are related to the database.
38 | This exception is currently never raised.
39 |
40 | This exception is part of the `DBAPI 2.0 specification
 41 |     <https://www.python.org/dev/peps/pep-0249/>`_.
42 | """
43 |
44 | pass
45 |
46 |
47 | class DataError(DatabaseError):
48 | """Generic exception raised for errors that are due to problems with the
49 | processed data. This exception is not currently raised.
50 |
51 | This exception is part of the `DBAPI 2.0 specification
 52 |     <https://www.python.org/dev/peps/pep-0249/>`_.
53 | """
54 |
55 | pass
56 |
57 |
58 | class OperationalError(DatabaseError):
59 | """
60 | Generic exception raised for errors that are related to the database's
61 | operation and not necessarily under the control of the programmer. This
62 | exception is currently never raised.
63 |
64 | This exception is part of the `DBAPI 2.0 specification
 65 |     <https://www.python.org/dev/peps/pep-0249/>`_.
66 | """
67 |
68 | pass
69 |
70 |
71 | class IntegrityError(DatabaseError):
72 | """
73 | Generic exception raised when the relational integrity of the database is
74 | affected. This exception is not currently raised.
75 |
76 | This exception is part of the `DBAPI 2.0 specification
 77 |     <https://www.python.org/dev/peps/pep-0249/>`_.
78 | """
79 |
80 | pass
81 |
82 |
83 | class InternalError(DatabaseError):
84 | """Generic exception raised when the database encounters an internal error.
85 | This is currently only raised when unexpected state occurs in the
 86 |     interface itself, and is typically the result of an interface bug.
87 |
88 | This exception is part of the `DBAPI 2.0 specification
 89 |     <https://www.python.org/dev/peps/pep-0249/>`_.
90 | """
91 |
92 | pass
93 |
94 |
95 | class ProgrammingError(DatabaseError):
96 | """Generic exception raised for programming errors. For example, this
97 | exception is raised if more parameter fields are in a query string than
98 | there are available parameters.
99 |
100 | This exception is part of the `DBAPI 2.0 specification
 101 |     <https://www.python.org/dev/peps/pep-0249/>`_.
102 | """
103 |
104 | pass
105 |
106 |
107 | class NotSupportedError(DatabaseError):
108 | """Generic exception raised in case a method or database API was used which
109 | is not supported by the database.
110 |
111 | This exception is part of the `DBAPI 2.0 specification
 112 |     <https://www.python.org/dev/peps/pep-0249/>`_.
113 | """
114 |
115 | pass
116 |
117 |
118 | class ArrayContentNotSupportedError(NotSupportedError):
119 | """
120 | Raised when attempting to transmit an array where the base type is not
121 | supported for binary data transfer by the interface.
122 | """
123 |
124 | pass
125 |
126 |
127 | class ArrayContentNotHomogenousError(ProgrammingError):
128 | """
129 | Raised when attempting to transmit an array that doesn't contain only a
130 | single type of object.
131 | """
132 |
133 | pass
134 |
135 |
136 | class ArrayDimensionsNotConsistentError(ProgrammingError):
137 | """
138 | Raised when attempting to transmit an array that has inconsistent
139 | multi-dimension sizes.
140 | """
141 |
142 | pass
143 |
144 |
145 | MISSING_MODULE_ERROR_MSG: str = (
146 | "redshift_connector requires {module} support for this functionality. "
147 | "Please install redshift_connector[full] for {module} support"
148 | )
149 |
--------------------------------------------------------------------------------
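
The classes above form the standard DBAPI 2.0 exception hierarchy, so callers can catch errors at whatever granularity suits them. A minimal sketch (illustrative only):

    from redshift_connector.error import DatabaseError, Error, ProgrammingError

    try:
        # e.g. a query string containing more parameter placeholders than parameters
        raise ProgrammingError("parameter count mismatch")
    except DatabaseError as exc:       # ProgrammingError subclasses DatabaseError
        assert isinstance(exc, Error)  # ...which in turn subclasses Error
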
/redshift_connector/native_plugin_helper.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 |
4 | from redshift_connector.error import InterfaceError
5 | from redshift_connector.idp_auth_helper import IdpAuthHelper
6 | from redshift_connector.plugin.i_native_plugin import INativePlugin
7 |
8 | if typing.TYPE_CHECKING:
9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
10 | from redshift_connector.redshift_property import RedshiftProperty
11 |
12 | logging.getLogger(__name__).addHandler(logging.NullHandler())
13 | _logger: logging.Logger = logging.getLogger(__name__)
14 |
15 |
16 | class NativeAuthPluginHelper:
17 | @staticmethod
18 | def set_native_auth_plugin_properties(info: "RedshiftProperty") -> None:
19 | """
20 | Modifies ``info`` to prepare for authentication with Amazon Redshift
21 |
22 | Parameters
23 | ----------
24 | info: RedshiftProperty
25 | RedshiftProperty object storing user defined and derived attributes used for authentication
26 |
27 | Returns
28 | -------
29 | None:None
30 | """
31 | if info.credentials_provider:
32 | # include the authentication token which will be used for authentication via
33 | # Redshift Native IDP Integration
34 | _logger.debug("Attempting to get native auth plugin credentials")
35 | idp_token: str = NativeAuthPluginHelper.get_native_auth_plugin_credentials(info)
36 | if idp_token:
37 | _logger.debug("setting info.web_identity_token")
38 | info.put("web_identity_token", idp_token)
39 |
40 | @staticmethod
41 | def get_native_auth_plugin_credentials(info: "RedshiftProperty") -> str:
42 | """
43 | Retrieves credentials for Amazon Redshift native authentication.
44 |
45 | Parameters
46 | ----------
47 | info: RedshiftProperty
48 | RedshiftProperty object storing user defined and derived attributes used for authentication
49 |
50 | Returns
51 | -------
52 | str: An authentication token compatible with Redshift Native IDP Integration (code 14)
53 | """
54 | idp_token: typing.Optional[str] = None
55 | provider = None
56 |
57 | if info.credentials_provider:
58 | provider = IdpAuthHelper.load_credentials_provider(info)
59 |
60 | if not isinstance(provider, INativePlugin):
 61 |                 _logger.debug("Native auth will not be used, credentials provider is not a native plugin")
62 | return ""
63 | else:
64 | raise InterfaceError(
65 | "Required credentials_provider parameter is null or empty: {}".format(info.credentials_provider)
66 | )
67 |
68 | _logger.debug("Native IDP Credential Provider %s:%s", provider, info.credentials_provider)
69 | _logger.debug("Calling provider.getCredentials()")
70 |
71 | # Provider will cache the credentials, it's OK to call get_credentials() here
72 | credentials: "NativeTokenHolder" = typing.cast("NativeTokenHolder", provider.get_credentials())
73 |
74 | _logger.debug("credentials is None = %s", credentials is None)
75 | _logger.debug("credentials.is_expired() = %s", credentials.is_expired())
76 |
77 | if credentials is None or (credentials.expiration is not None and credentials.is_expired()):
78 | # get idp token
79 | plugin: INativePlugin = provider
80 | _logger.debug("Unable to get IdP token from cache. Calling plugin.get_idp_token()")
81 |
82 | idp_token = plugin.get_idp_token()
83 | _logger.debug("IdP token retrieved")
84 | info.put("idp_token", idp_token)
85 | else:
86 | _logger.debug("Cached idp_token will be used")
87 | idp_token = credentials.access_token
88 |
89 | return idp_token
90 |
--------------------------------------------------------------------------------
/redshift_connector/objects.py:
--------------------------------------------------------------------------------
1 | from datetime import date
2 | from datetime import datetime as Datetime
3 | from datetime import time
4 | from time import localtime
5 |
6 |
7 | def Date(year: int, month: int, day: int) -> date:
8 | """Constuct an object holding a date value.
9 |
10 | This function is part of the `DBAPI 2.0 specification
 11 |     <https://www.python.org/dev/peps/pep-0249/>`_.
12 |
13 | :rtype: :class:`datetime.date`
14 | """
15 | return date(year, month, day)
16 |
17 |
18 | def Time(hour: int, minute: int, second: int) -> time:
19 | """Construct an object holding a time value.
20 |
21 | This function is part of the `DBAPI 2.0 specification
 22 |     <https://www.python.org/dev/peps/pep-0249/>`_.
23 |
24 | :rtype: :class:`datetime.time`
25 | """
26 | return time(hour, minute, second)
27 |
28 |
29 | def Timestamp(year: int, month: int, day: int, hour: int, minute: int, second: int) -> Datetime:
30 | """Construct an object holding a timestamp value.
31 |
32 | This function is part of the `DBAPI 2.0 specification
 33 |     <https://www.python.org/dev/peps/pep-0249/>`_.
34 |
35 | :rtype: :class:`datetime.datetime`
36 | """
37 | return Datetime(year, month, day, hour, minute, second)
38 |
39 |
40 | def DateFromTicks(ticks: float) -> date:
41 | """Construct an object holding a date value from the given ticks value
42 | (number of seconds since the epoch).
43 |
44 | This function is part of the `DBAPI 2.0 specification
 45 |     <https://www.python.org/dev/peps/pep-0249/>`_.
46 |
47 | :rtype: :class:`datetime.date`
48 | """
49 | return Date(*localtime(ticks)[:3])
50 |
51 |
52 | def TimeFromTicks(ticks: float) -> time:
53 | """Construct an objet holding a time value from the given ticks value
54 | (number of seconds since the epoch).
55 |
56 | This function is part of the `DBAPI 2.0 specification
 57 |     <https://www.python.org/dev/peps/pep-0249/>`_.
58 |
59 | :rtype: :class:`datetime.time`
60 | """
61 | return Time(*localtime(ticks)[3:6])
62 |
63 |
64 | def TimestampFromTicks(ticks: float) -> Datetime:
65 | """Construct an object holding a timestamp value from the given ticks value
66 | (number of seconds since the epoch).
67 |
68 | This function is part of the `DBAPI 2.0 specification
 69 |     <https://www.python.org/dev/peps/pep-0249/>`_.
70 |
71 | :rtype: :class:`datetime.datetime`
72 | """
73 | return Timestamp(*localtime(ticks)[:6])
74 |
75 |
76 | def Binary(value: bytes):
77 | """Construct an object holding binary data.
78 |
79 | This function is part of the `DBAPI 2.0 specification
 80 |     <https://www.python.org/dev/peps/pep-0249/>`_.
81 |
82 | """
83 | return value
84 |
--------------------------------------------------------------------------------
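
These constructors are thin wrappers over the standard datetime module, as a quick sketch shows (illustrative only; the *FromTicks helpers go through time.localtime(), so their results are machine-local):

    import datetime

    from redshift_connector.objects import Date, Time, Timestamp, TimestampFromTicks

    assert Date(2024, 1, 31) == datetime.date(2024, 1, 31)
    assert Time(23, 59, 59) == datetime.time(23, 59, 59)
    assert Timestamp(2024, 1, 31, 23, 59, 59) == datetime.datetime(2024, 1, 31, 23, 59, 59)

    # Seconds since the epoch, interpreted in the local timezone:
    print(TimestampFromTicks(0.0))  # 1970-01-01 00:00:00 shifted to local time
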
/redshift_connector/pg_types.py:
--------------------------------------------------------------------------------
1 | from json import dumps
2 |
3 |
4 | class PGType:
5 | def __init__(self: "PGType", value) -> None:
6 | self.value: str = value
7 |
8 | def encode(self, encoding) -> bytes:
9 | return str(self.value).encode(encoding)
10 |
11 |
12 | class PGEnum(PGType):
13 | def __init__(self: "PGEnum", value) -> None:
14 | if isinstance(value, str):
15 | self.value = value
16 | else:
17 | self.value = value.value
18 |
19 |
20 | class PGJson(PGType):
21 | def encode(self: "PGJson", encoding: str) -> bytes:
22 | return dumps(self.value).encode(encoding)
23 |
24 |
25 | class PGJsonb(PGType):
26 | def encode(self: "PGJsonb", encoding: str) -> bytes:
27 | return dumps(self.value).encode(encoding)
28 |
29 |
30 | class PGTsvector(PGType):
31 | pass
32 |
33 |
34 | class PGVarchar(str):
35 | pass
36 |
37 |
38 | class PGText(str):
39 | pass
40 |
--------------------------------------------------------------------------------
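
A quick sketch of how these wrappers affect wire encoding (illustrative only): PGJson and PGJsonb serialize their value through json.dumps, while the base PGType simply stringifies it.

    from redshift_connector.pg_types import PGJson, PGType

    assert PGType(42).encode("utf-8") == b"42"
    assert PGJson({"id": 1, "tags": ["a", "b"]}).encode("utf-8") == b'{"id": 1, "tags": ["a", "b"]}'
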
/redshift_connector/plugin/__init__.py:
--------------------------------------------------------------------------------
1 | from .adfs_credentials_provider import AdfsCredentialsProvider
2 | from .azure_credentials_provider import AzureCredentialsProvider
3 | from .browser_azure_credentials_provider import BrowserAzureCredentialsProvider
4 | from .browser_azure_oauth2_credentials_provider import (
5 | BrowserAzureOAuth2CredentialsProvider,
6 | )
7 | from .browser_idc_auth_plugin import BrowserIdcAuthPlugin
8 | from .browser_saml_credentials_provider import BrowserSamlCredentialsProvider
9 | from .common_credentials_provider import CommonCredentialsProvider
10 | from .idp_credentials_provider import IdpCredentialsProvider
11 | from .idp_token_auth_plugin import IdpTokenAuthPlugin
12 | from .jwt_credentials_provider import (
13 | BasicJwtCredentialsProvider,
14 | JwtCredentialsProvider,
15 | )
16 | from .okta_credentials_provider import OktaCredentialsProvider
17 | from .ping_credentials_provider import PingCredentialsProvider
18 | from .saml_credentials_provider import SamlCredentialsProvider
19 |
--------------------------------------------------------------------------------
/redshift_connector/plugin/browser_saml_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import logging
3 | import re
4 | import socket
5 | import typing
6 | import urllib.parse
7 |
8 | from redshift_connector.error import InterfaceError
9 | from redshift_connector.plugin.saml_credentials_provider import SamlCredentialsProvider
10 | from redshift_connector.redshift_property import RedshiftProperty
11 |
12 | _logger: logging.Logger = logging.getLogger(__name__)
13 |
14 |
15 | # Class to get SAML Response
16 | class BrowserSamlCredentialsProvider(SamlCredentialsProvider):
17 | """
18 | Generic Identity Provider Browser Plugin providing multi-factor authentication access to an Amazon Redshift cluster using an identity provider of your choice.
19 | """
20 |
21 | def __init__(self: "BrowserSamlCredentialsProvider") -> None:
22 | super().__init__()
23 | self.login_url: typing.Optional[str] = None
24 |
25 | self.idp_response_timeout: int = 120
26 | self.listen_port: int = 7890
27 |
28 | # method to grab the field parameters specified by end user.
29 | # This method adds to it specific parameters.
30 | def add_parameter(self: "BrowserSamlCredentialsProvider", info: RedshiftProperty) -> None:
31 | super().add_parameter(info)
32 | self.login_url = info.login_url
33 |
34 | self.idp_response_timeout = info.idp_response_timeout
35 | self.listen_port = info.listen_port
36 |
37 | # Required method to grab the SAML Response. Used in base class to refresh temporary credentials.
38 | def get_saml_assertion(self: "BrowserSamlCredentialsProvider") -> str:
39 | _logger.debug("BrowserSamlCredentialsProvider.get_saml_assertion")
40 |
41 | if self.login_url == "" or self.login_url is None:
42 | BrowserSamlCredentialsProvider.handle_missing_required_property("login_url")
43 |
44 | if self.idp_response_timeout < 10:
45 | BrowserSamlCredentialsProvider.handle_invalid_property_value(
46 | "idp_response_timeout", "Must be 10 seconds or greater"
47 | )
48 | if self.listen_port < 1 or self.listen_port > 65535:
49 | BrowserSamlCredentialsProvider.handle_invalid_property_value("listen_port", "Must be in range [1,65535]")
50 |
51 | return self.authenticate()
52 |
53 | # Authentication consists of:
54 | # Start the Socket Server on the port {@link BrowserSamlCredentialsProvider#m_listen_port}.
55 | # Open the default browser with the link asking a User to enter the credentials.
56 | # Retrieve the SAML Assertion string from the response.
57 | def authenticate(self: "BrowserSamlCredentialsProvider") -> str:
58 | _logger.debug("BrowserSamlCredentialsProvider.authenticate")
59 |
60 | try:
61 | with concurrent.futures.ThreadPoolExecutor() as executor:
62 | _logger.debug("Listening for connection on port %s", self.listen_port)
63 | future = executor.submit(self.run_server, self.listen_port, self.idp_response_timeout)
64 | self.open_browser()
65 | return_value: str = future.result()
66 |
67 | samlresponse = urllib.parse.unquote(return_value)
68 | return str(samlresponse)
69 | except socket.error as e:
70 | _logger.debug("Socket error: %s", e)
71 | raise e
72 | except Exception as e:
73 | _logger.debug("Other Exception: %s", e)
74 | raise e
75 |
76 | def run_server(self: "BrowserSamlCredentialsProvider", listen_port: int, idp_response_timeout: int) -> str:
77 | _logger.debug("BrowserSamlCredentialsProvider.run_server")
78 | HOST: str = "127.0.0.1"
79 | PORT: int = listen_port
80 |
81 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
82 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
83 | _logger.debug("attempting socket bind on host %s port %s", HOST, PORT)
84 | s.bind((HOST, PORT))
85 | s.listen()
86 | conn, addr = s.accept() # typing.Tuple[Socket, Any]
87 | _logger.debug("Localhost socket connection established for Browser SAML IdP")
88 | conn.settimeout(float(idp_response_timeout))
89 | size: int = 102400
90 | with conn:
91 | while True:
92 | part: bytes = conn.recv(size)
93 | decoded_part: str = part.decode()
94 | result: typing.Optional[typing.Match] = re.search(
95 | pattern="SAMLResponse[:=]+[\\n\\r]*", string=decoded_part, flags=re.MULTILINE
96 | )
97 | _logger.debug("Data received contained SAML Response: %s", result is not None)
98 |
99 | if result is not None:
100 | conn.send(self.close_window_http_resp())
101 | saml_resp_block: str = decoded_part[result.regs[0][1] :]
102 | end_idx: int = saml_resp_block.find("&RelayState=")
103 | if end_idx > -1:
104 | saml_resp_block = saml_resp_block[:end_idx]
105 | return saml_resp_block
106 |
107 | # Opens the default browser with the authorization request to the web service.
108 | def open_browser(self: "BrowserSamlCredentialsProvider") -> None:
109 | _logger.debug("BrowserSamlCredentialsProvider.open_browser")
110 | import webbrowser
111 |
112 | url: typing.Optional[str] = self.login_url
113 |
114 | if url is None:
115 | BrowserSamlCredentialsProvider.handle_missing_required_property("login_url")
116 | self.validate_url(typing.cast(str, url))
117 | webbrowser.open(typing.cast(str, url))
118 |
--------------------------------------------------------------------------------
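
A connection sketch using this plugin (hypothetical, with placeholder values; it assumes the string-named credentials_provider and browser-plugin keywords documented for redshift_connector.connect()):

    import redshift_connector

    conn = redshift_connector.connect(
        iam=True,
        credentials_provider="BrowserSamlCredentialsProvider",
        login_url="https://idp.example.com/apps/amazon_redshift",  # placeholder
        idp_response_timeout=120,  # must be >= 10, per get_saml_assertion()
        listen_port=7890,          # must fall within [1, 65535]
        database="dev",
        cluster_identifier="my-cluster",  # placeholder
        region="us-east-1",
    )
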
/redshift_connector/plugin/common_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 | from abc import abstractmethod
4 |
5 | from redshift_connector.error import InterfaceError
6 | from redshift_connector.iam_helper import IamHelper
7 | from redshift_connector.plugin.i_native_plugin import INativePlugin
8 | from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider
9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
10 | from redshift_connector.redshift_property import RedshiftProperty
11 |
12 | _logger: logging.Logger = logging.getLogger(__name__)
13 |
14 |
15 | class CommonCredentialsProvider(INativePlugin, IdpCredentialsProvider):
16 | """
17 | Abstract base class for authentication plugins using IdC authentication.
18 | """
19 |
20 | def __init__(self: "CommonCredentialsProvider") -> None:
21 | super().__init__()
22 | self.last_refreshed_credentials: typing.Optional[NativeTokenHolder] = None
23 |
24 | @abstractmethod
25 | def get_auth_token(self: "CommonCredentialsProvider") -> str:
26 | """
27 | Returns the auth token retrieved from corresponding plugin
28 | """
29 | pass # pragma: no cover
30 |
31 | def add_parameter(
32 | self: "CommonCredentialsProvider",
33 | info: RedshiftProperty,
34 | ) -> None:
35 | self.disable_cache = True
36 |
37 | def get_credentials(self: "CommonCredentialsProvider") -> NativeTokenHolder:
38 | credentials: typing.Optional[NativeTokenHolder] = None
39 |
40 | if not self.disable_cache:
41 | key = self.get_cache_key()
42 | credentials = typing.cast(NativeTokenHolder, self.cache.get(key))
43 |
44 | if not credentials or credentials.is_expired():
45 | if self.disable_cache:
46 | _logger.debug("Auth token Cache disabled : fetching new token")
47 | else:
48 | _logger.debug("Auth token Cache enabled - No auth token found from cache : fetching new token")
49 |
50 | self.refresh()
51 |
52 | if self.disable_cache:
53 | credentials = self.last_refreshed_credentials
54 | self.last_refreshed_credentials = None
55 | else:
56 | credentials.refresh = False
57 | _logger.debug("Auth token found from cache")
58 |
59 | if not self.disable_cache:
60 | credentials = typing.cast(NativeTokenHolder, self.cache[key])
61 | return typing.cast(NativeTokenHolder, credentials)
62 |
63 | def refresh(self: "CommonCredentialsProvider") -> None:
64 | auth_token: str = self.get_auth_token()
65 | _logger.debug("auth token: {}".format(auth_token))
66 |
67 | if auth_token is None:
68 | raise InterfaceError("IdC authentication failed : An error occurred during the request.")
69 |
70 | credentials: NativeTokenHolder = NativeTokenHolder(access_token=auth_token, expiration=None)
71 | credentials.refresh = True
72 |
73 | _logger.debug("disable_cache={}".format(str(self.disable_cache)))
74 | if not self.disable_cache:
75 | self.cache[self.get_cache_key()] = credentials
76 | else:
77 | self.last_refreshed_credentials = credentials
78 |
79 | def get_idp_token(self: "CommonCredentialsProvider") -> str:
80 | auth_token: str = self.get_auth_token()
81 | return auth_token
82 |
83 | def set_group_federation(self: "CommonCredentialsProvider", group_federation: bool):
84 | pass
85 |
86 | def get_sub_type(self: "CommonCredentialsProvider") -> int:
87 | return IamHelper.IDC_PLUGIN
88 |
89 | def get_cache_key(self: "CommonCredentialsProvider") -> str:
90 | return ""
91 |
--------------------------------------------------------------------------------
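
A toy subclass sketch (illustrative only) showing the two abstract hooks a concrete IdC plugin must supply; disable_cache is set by hand here because add_parameter() expects a RedshiftProperty object:

    from redshift_connector.plugin.common_credentials_provider import CommonCredentialsProvider

    class StaticTokenProvider(CommonCredentialsProvider):
        def __init__(self) -> None:
            super().__init__()
            self.token = "my-idc-token"  # placeholder

        def get_auth_token(self) -> str:
            return self.token

        def check_required_parameters(self) -> None:  # abstract on IdpCredentialsProvider
            if not self.token:
                raise ValueError("token missing")

    provider = StaticTokenProvider()
    provider.disable_cache = True                # add_parameter() normally sets this
    holder = provider.get_credentials()          # wraps the token in a NativeTokenHolder
    assert holder.access_token == "my-idc-token"
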
/redshift_connector/plugin/credential_provider_constants.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | azure_headers: typing.Dict[str, str] = {
4 | "Content-Type": "application/x-www-form-urlencoded",
5 | "Accept": "application/json",
6 | }
7 | okta_headers: typing.Dict[str, str] = {
8 | "Accept": "application/json",
9 | "Content-Type": "application/json",
10 | "Cache-Control": "no-cache",
11 | }
12 | # order of preference when searching for attributes in SAML response
13 | SAML_RESP_NAMESPACES: typing.List[str] = ["saml2:", "saml:", ""]
14 |
--------------------------------------------------------------------------------
/redshift_connector/plugin/i_native_plugin.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 |
3 | from redshift_connector.plugin.i_plugin import IPlugin
4 |
5 |
6 | class INativePlugin(IPlugin):
7 | """
8 | Abstract base class for authentication plugins using Redshift Native authentication
9 | """
10 |
11 | @abstractmethod
12 | def get_idp_token(self: "INativePlugin") -> str:
13 | """
14 | Returns the IdP token retrieved after authenticating with the plugin.
15 | """
16 | pass # pragma: no cover
17 |
--------------------------------------------------------------------------------
/redshift_connector/plugin/i_plugin.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 | from abc import ABC, abstractmethod
4 |
5 | if typing.TYPE_CHECKING:
6 | from redshift_connector.credentials_holder import CredentialsHolder
7 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
8 | from redshift_connector.redshift_property import RedshiftProperty
9 |
10 | _logger: logging.Logger = logging.getLogger(__name__)
11 |
12 |
13 | class IPlugin(ABC):
14 | @abstractmethod
15 | def add_parameter(self: "IPlugin", info: "RedshiftProperty"):
16 | """
17 | Defines instance attributes from the :class:RedshiftProperty object associated with the current authentication session.
18 | """
19 | pass
20 |
21 | @abstractmethod
22 | def get_cache_key(self: "IPlugin") -> str:
23 | pass
24 |
25 | @abstractmethod
26 | def get_credentials(self: "IPlugin") -> typing.Union["NativeTokenHolder", "CredentialsHolder"]:
27 | """
28 | Returns a :class:NativeTokenHolder object associated with the current plugin.
29 | """
30 | pass # pragma: no cover
31 |
32 | @abstractmethod
33 | def get_sub_type(self: "IPlugin") -> int:
34 | """
35 | Returns a type code indicating the type of the current plugin.
36 |
37 | See :class:IdpAuthHelper for possible return values
38 |
39 | """
40 | pass # pragma: no cover
41 |
42 | @abstractmethod
43 | def refresh(self: "IPlugin") -> None:
44 | """
45 | Refreshes the credentials, stored in :class:NativeTokenHolder, for the current plugin.
46 | """
47 | pass # pragma: no cover
48 |
49 | @abstractmethod
50 | def set_group_federation(self: "IPlugin", group_federation: bool):
51 | pass
52 |
53 | @staticmethod
54 | def handle_missing_required_property(property_name: str) -> None:
55 | from redshift_connector import InterfaceError
56 |
57 | error_msg: str = "Missing required connection property: {}".format(property_name)
58 | _logger.debug(error_msg)
59 | raise InterfaceError(error_msg)
60 |
61 | @staticmethod
62 | def handle_invalid_property_value(property_name: str, reason: str) -> None:
63 | from redshift_connector import InterfaceError
64 |
65 | error_msg: str = "Invalid value specified for connection property: {}. {}".format(property_name, reason)
66 | _logger.debug(error_msg)
67 | raise InterfaceError(error_msg)
68 |
--------------------------------------------------------------------------------
/redshift_connector/plugin/idp_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from abc import abstractmethod
3 |
4 | from redshift_connector.error import InterfaceError
5 | from redshift_connector.plugin.i_plugin import IPlugin
6 | from redshift_connector.redshift_property import IAM_URL_PATTERN, RedshiftProperty
7 |
8 | if typing.TYPE_CHECKING:
9 | from redshift_connector.credentials_holder import ABCCredentialsHolder
10 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
11 |
12 |
13 | class IdpCredentialsProvider(IPlugin):
14 | """
15 | Abstract base class for authentication plugins.
16 | """
17 |
18 | def __init__(self: "IdpCredentialsProvider") -> None:
19 | self.cache: typing.Dict[str, typing.Union[NativeTokenHolder, ABCCredentialsHolder]] = {}
20 |
21 | @staticmethod
22 | def close_window_http_resp() -> bytes:
23 | """
24 | Builds the client facing HTML contents notifying that the authentication window may be closed.
25 | """
26 | return str.encode(
27 | "HTTP/1.1 200 OK\nContent-Type: text/html\n\n"
 28 |             + "<html><body>Thank you for using Amazon Redshift! You can now close this window.</body></html>\n"
29 | )
30 |
31 | @abstractmethod
32 | def check_required_parameters(self: "IdpCredentialsProvider") -> None:
33 | """
34 | Performs validation on client provided parameters used by the IdP.
35 | """
36 | pass # pragma: no cover
37 |
38 | @classmethod
39 | def validate_url(cls, uri: str) -> None:
40 | import re
41 |
42 | if not re.fullmatch(pattern=IAM_URL_PATTERN, string=str(uri)):
43 | raise InterfaceError("URI: {} is an invalid web address".format(uri))
44 |
--------------------------------------------------------------------------------
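
A sketch of the URL guard (illustrative; it assumes IAM_URL_PATTERN accepts well-formed https URLs, which is how the browser plugins use validate_url before opening a browser):

    from redshift_connector.error import InterfaceError
    from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider

    IdpCredentialsProvider.validate_url("https://idp.example.com/login")  # accepted

    try:
        IdpCredentialsProvider.validate_url("not-a-url")
    except InterfaceError as exc:
        print(exc)  # URI: not-a-url is an invalid web address
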
/redshift_connector/plugin/idp_token_auth_plugin.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 |
4 | from redshift_connector.error import InterfaceError
5 | from redshift_connector.plugin.common_credentials_provider import (
6 | CommonCredentialsProvider,
7 | )
8 | from redshift_connector.redshift_property import RedshiftProperty
9 |
10 | logging.getLogger(__name__).addHandler(logging.NullHandler())
11 | _logger: logging.Logger = logging.getLogger(__name__)
12 |
13 |
14 | class IdpTokenAuthPlugin(CommonCredentialsProvider):
15 | """
16 | A basic IdP Token auth plugin class. This plugin class allows clients to directly provide any auth token that is handled by Redshift.
17 | """
18 |
19 | def __init__(self: "IdpTokenAuthPlugin") -> None:
20 | super().__init__()
21 | self.token: typing.Optional[str] = None
22 | self.token_type: typing.Optional[str] = None
23 |
24 | def add_parameter(
25 | self: "IdpTokenAuthPlugin",
26 | info: RedshiftProperty,
27 | ) -> None:
28 | super().add_parameter(info)
29 | self.token = info.token
30 | self.token_type = info.token_type
31 | _logger.debug("Setting token_type = {}".format(self.token_type))
32 |
33 | def check_required_parameters(self: "IdpTokenAuthPlugin") -> None:
34 | super().check_required_parameters()
35 | if not self.token:
36 | _logger.error("IdC authentication failed: token needs to be provided in connection params")
37 | raise InterfaceError("IdC authentication failed: The token must be included in the connection parameters.")
38 | if not self.token_type:
39 | _logger.error("IdC authentication failed: token_type needs to be provided in connection params")
40 | raise InterfaceError(
41 | "IdC authentication failed: The token type must be included in the connection parameters."
42 | )
43 |
44 | def get_cache_key(self: "IdpTokenAuthPlugin") -> str: # type: ignore
45 | pass
46 |
47 | def get_auth_token(self: "IdpTokenAuthPlugin") -> str:
48 | self.check_required_parameters()
49 | return typing.cast(str, self.token)
50 |
--------------------------------------------------------------------------------
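
A connection sketch using this plugin (hypothetical, with placeholder values; token_type tells Redshift how to interpret the supplied token, e.g. "ACCESS_TOKEN" or "EXT_JWT"):

    import redshift_connector

    conn = redshift_connector.connect(
        credentials_provider="IdpTokenAuthPlugin",
        token="<auth token issued by your IdP>",  # placeholder
        token_type="ACCESS_TOKEN",
        host="examplecluster.abc123xyz789.us-east-1.redshift.amazonaws.com",  # placeholder
        database="dev",
    )
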
/redshift_connector/plugin/jwt_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import typing
3 | from abc import abstractmethod
4 |
5 | from redshift_connector.error import InterfaceError
6 | from redshift_connector.iam_helper import IamHelper
7 | from redshift_connector.plugin.i_native_plugin import INativePlugin
8 | from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider
9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
10 | from redshift_connector.redshift_property import RedshiftProperty
11 |
12 | _logger: logging.Logger = logging.getLogger(__name__)
13 |
14 |
15 | class JwtCredentialsProvider(INativePlugin, IdpCredentialsProvider):
16 | """
17 | Abstract base class for authentication plugins using JWT for Redshift native authentication.
18 | """
19 |
20 | KEY_ROLE_ARN: str = "role_arn"
21 | KEY_WEB_IDENTITY_TOKEN: str = "web_identity_token"
22 | KEY_DURATION: str = "duration"
23 | KEY_ROLE_SESSION_NAME: str = "role_session_name"
24 | DEFAULT_ROLE_SESSION_NAME: str = "jwt_redshift_session"
25 |
26 | def __init__(self: "JwtCredentialsProvider") -> None:
27 | super().__init__()
28 | self.last_refreshed_credentials: typing.Optional[NativeTokenHolder] = None
29 |
30 | @abstractmethod
31 | def get_jwt_assertion(self: "JwtCredentialsProvider") -> str:
32 | """
33 | Returns the jwt assertion retrieved following IdP authentication
34 | """
35 | pass # pragma: no cover
36 |
37 | def add_parameter(
38 | self: "JwtCredentialsProvider",
39 | info: RedshiftProperty,
40 | ) -> None:
41 | self.provider_name = info.provider_name
42 | self.ssl_insecure = info.ssl_insecure
43 | self.disable_cache = info.iam_disable_cache
44 | self.group_federation = False
45 |
46 | if info.role_session_name is not None:
47 | self.role_session_name = info.role_session_name
48 |
49 | def set_group_federation(self: "JwtCredentialsProvider", group_federation: bool):
50 | self.group_federation = group_federation
51 |
52 | def get_credentials(self: "JwtCredentialsProvider") -> NativeTokenHolder:
53 | _logger.debug("JwtCredentialsProvider.get_credentials")
54 | credentials: typing.Optional[NativeTokenHolder] = None
55 |
56 | if not self.disable_cache:
57 | _logger.debug("checking cache for credentials")
58 | key = self.get_cache_key()
59 | credentials = typing.cast(NativeTokenHolder, self.cache.get(key))
60 |
61 | if not credentials or credentials.is_expired():
62 | _logger.debug("JWT get_credentials NOT from cache")
63 | self.refresh()
64 |
65 | if self.disable_cache:
66 | credentials = self.last_refreshed_credentials
67 | self.last_refreshed_credentials = None
68 | else:
69 | credentials.refresh = False
70 | _logger.debug("JWT get_credentials from cache")
71 |
72 | if not self.disable_cache:
73 | credentials = typing.cast(NativeTokenHolder, self.cache[key])
74 | return typing.cast(NativeTokenHolder, credentials)
75 |
76 | def refresh(self: "JwtCredentialsProvider") -> None:
77 | _logger.debug("JwtCredentialsProvider.refresh")
78 | jwt: str = self.get_jwt_assertion()
79 |
80 | if jwt is None:
81 | exec_msg = "Unable to refresh, no jwt provided"
82 | _logger.debug(exec_msg)
83 | raise InterfaceError(exec_msg)
84 |
85 | credentials: NativeTokenHolder = NativeTokenHolder(access_token=jwt, expiration=None)
86 | credentials.refresh = True
87 |
88 | _logger.debug("disable_cache=%s", self.disable_cache)
89 | if not self.disable_cache:
90 | self.cache[self.get_cache_key()] = credentials
91 |
92 | else:
93 | self.last_refreshed_credentials = credentials
94 |
95 | def do_verify_ssl_cert(self: "JwtCredentialsProvider") -> bool:
96 | return not self.ssl_insecure
97 |
98 | def get_idp_token(self: "JwtCredentialsProvider") -> str:
99 | jwt: str = self.get_jwt_assertion()
100 |
101 | return jwt
102 |
103 | def get_sub_type(self: "JwtCredentialsProvider") -> int:
104 | return IamHelper.JWT_PLUGIN
105 |
106 |
107 | class BasicJwtCredentialsProvider(JwtCredentialsProvider):
108 | """
 109 |     A basic JWT credential provider class that can be subclassed to work with any desired JWT service provider.
110 | """
111 |
112 | def __init__(self: "BasicJwtCredentialsProvider") -> None:
113 | super().__init__()
114 | self.jwt: typing.Optional[str] = None
115 |
116 | def add_parameter(
117 | self: "BasicJwtCredentialsProvider",
118 | info: RedshiftProperty,
119 | ) -> None:
120 | super().add_parameter(info)
121 | self.jwt = info.web_identity_token
122 |
123 | if info.role_session_name is not None:
124 | self.role_session_name = info.role_session_name
125 |
126 | def check_required_parameters(self: "BasicJwtCredentialsProvider") -> None:
127 | super().check_required_parameters()
128 | if not self.jwt:
129 | BasicJwtCredentialsProvider.handle_missing_required_property("jwt")
130 |
131 | def get_cache_key(self: "BasicJwtCredentialsProvider") -> str:
132 | return typing.cast(str, self.jwt)
133 |
134 | def get_jwt_assertion(self: "BasicJwtCredentialsProvider") -> str:
135 | self.check_required_parameters()
136 | return self.jwt # type: ignore
137 |
--------------------------------------------------------------------------------
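
A connection sketch using BasicJwtCredentialsProvider (hypothetical, with placeholder values; exact connect() keywords depend on your cluster setup, but web_identity_token carries the JWT obtained from your IdP):

    import redshift_connector

    conn = redshift_connector.connect(
        credentials_provider="BasicJwtCredentialsProvider",
        web_identity_token="<JWT from your identity provider>",  # placeholder
        host="examplecluster.abc123xyz789.us-east-1.redshift.amazonaws.com",  # placeholder
        database="dev",
    )
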
/redshift_connector/plugin/native_token_holder.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import typing
3 |
4 | from dateutil.tz import tzutc
5 |
6 |
7 | class NativeTokenHolder:
8 | """
9 | Holds Redshift Native authentication credentials.
10 | """
11 |
 12 |     def __init__(self: "NativeTokenHolder", access_token: str, expiration: typing.Optional[datetime.datetime]):
13 | self.access_token: str = access_token
14 | self.expiration = expiration
15 | self.refresh: bool = False # True means newly added, false means from cache
16 |
17 | def is_expired(self: "NativeTokenHolder") -> bool:
18 | """
19 | Returns boolean value indicating if the Redshift native authentication credentials have expired.
20 | """
21 | return self.expiration is None or typing.cast(datetime.datetime, self.expiration) < datetime.datetime.now(
22 | tz=tzutc()
23 | )
24 |
--------------------------------------------------------------------------------
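
A quick sketch of the expiry rule (illustrative only): a holder with no recorded expiration reports itself expired, which forces callers like get_credentials() to refresh rather than trust the cache.

    import datetime

    from dateutil.tz import tzutc
    from redshift_connector.plugin.native_token_holder import NativeTokenHolder

    holder = NativeTokenHolder(access_token="abc", expiration=None)
    assert holder.is_expired()  # no expiration recorded -> treated as expired

    holder.expiration = datetime.datetime.now(tz=tzutc()) + datetime.timedelta(hours=1)
    assert not holder.is_expired()  # future expiration keeps the cached token alive
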
/redshift_connector/py.typed:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/redshift_connector/py.typed
--------------------------------------------------------------------------------
/redshift_connector/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .array_util import (
2 | array_check_dimensions,
3 | array_dim_lengths,
4 | array_find_first_element,
5 | array_flatten,
6 | array_has_null,
7 | walk_array,
8 | )
9 | from .driver_info import DriverInfo
10 | from .logging_utils import make_divider_block, mask_secure_info_in_props
11 | from .type_utils import (
12 | FC_BINARY,
13 | FC_TEXT,
14 | NULL,
15 | NULL_BYTE,
16 | array_recv_binary,
17 | array_recv_text,
18 | bh_unpack,
19 | cccc_unpack,
20 | ci_unpack,
21 | date_in,
22 | date_recv_binary,
23 | float_array_recv,
24 | geographyhex_recv,
25 | h_pack,
26 | h_unpack,
27 | i_pack,
28 | i_unpack,
29 | ihihih_unpack,
30 | ii_pack,
31 | iii_pack,
32 | int_array_recv,
33 | numeric_in,
34 | numeric_in_binary,
35 | numeric_to_float_binary,
36 | numeric_to_float_in,
37 | py_types,
38 | q_pack,
39 | redshift_types,
40 | text_recv,
41 | time_in,
42 | time_recv_binary,
43 | timetz_in,
44 | timetz_recv_binary,
45 | varbytehex_recv,
46 | )
47 |
--------------------------------------------------------------------------------
/redshift_connector/utils/array_util.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | from redshift_connector.error import ArrayDimensionsNotConsistentError
4 |
5 |
6 | def walk_array(arr: typing.List) -> typing.Generator:
7 | for i, v in enumerate(arr):
8 | if isinstance(v, list):
9 | for a, i2, v2 in walk_array(v):
10 | yield a, i2, v2
11 | else:
12 | yield arr, i, v
13 |
14 |
15 | def array_find_first_element(arr: typing.List) -> typing.Any:
16 | for v in array_flatten(arr):
17 | if v is not None:
18 | return v
19 | return None
20 |
21 |
22 | def array_flatten(arr: typing.List) -> typing.Generator:
23 | for v in arr:
24 | if isinstance(v, list):
25 | for v2 in array_flatten(v):
26 | yield v2
27 | else:
28 | yield v
29 |
30 |
31 | def array_check_dimensions(arr: typing.List) -> typing.List:
32 | if len(arr) > 0:
33 | v0 = arr[0]
34 | if isinstance(v0, list):
35 | req_len = len(v0)
36 | req_inner_lengths = array_check_dimensions(v0)
37 | for v in arr:
38 | inner_lengths = array_check_dimensions(v)
39 | if len(v) != req_len or inner_lengths != req_inner_lengths:
40 | raise ArrayDimensionsNotConsistentError("array dimensions not consistent")
41 | retval = [req_len]
42 | retval.extend(req_inner_lengths)
43 | return retval
44 | else:
45 | # make sure nothing else at this level is a list
46 | for v in arr:
47 | if isinstance(v, list):
48 | raise ArrayDimensionsNotConsistentError("array dimensions not consistent")
49 | return []
50 |
51 |
52 | def array_has_null(arr: typing.List) -> bool:
53 | for v in array_flatten(arr):
54 | if v is None:
55 | return True
56 | return False
57 |
58 |
59 | def array_dim_lengths(arr: typing.List) -> typing.List:
60 | len_arr = len(arr)
61 | retval = [len_arr]
62 | if len_arr > 0:
63 | v0 = arr[0]
64 | if isinstance(v0, list):
65 | retval.extend(array_dim_lengths(v0))
66 | return retval
67 |
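An illustrative walkthrough (not part of the file above) of these helpers on a 2 x 3 nested list; each result follows directly from the definitions above.

    from redshift_connector.utils.array_util import (
        array_check_dimensions,
        array_dim_lengths,
        array_find_first_element,
        array_has_null,
    )

    arr = [[1, None, 3], [4, 5, 6]]
    assert array_check_dimensions(arr) == [3]  # lengths below the top level
    assert array_dim_lengths(arr) == [2, 3]  # full shape, top level included
    assert array_has_null(arr) is True
    assert array_find_first_element(arr) == 1

    array_check_dimensions([[1, 2], [3]])  # raises ArrayDimensionsNotConsistentError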
--------------------------------------------------------------------------------
/redshift_connector/utils/driver_info.py:
--------------------------------------------------------------------------------
1 | class DriverInfo:
2 | """
3 | Informational class describing the Amazon Redshift Python driver; all methods are static.
4 | """
5 |
6 | @staticmethod
7 | def version() -> str:
8 | """
9 | The version of redshift_connector
10 | Returns
11 | -------
12 | The redshift_connector package version: str
13 | """
14 | from redshift_connector import __version__ as DRIVER_VERSION
15 |
16 | return str(DRIVER_VERSION)
17 |
18 | @staticmethod
19 | def driver_name() -> str:
20 | """
21 | The name of the Amazon Redshift Python driver, redshift_connector
22 | Returns
23 | -------
24 | The human readable name of the redshift_connector package: str
25 | """
26 | return "Redshift Python Driver"
27 |
28 | @staticmethod
29 | def driver_short_name() -> str:
30 | """
31 | The shortened name of the Amazon Redshift Python driver, redshift_connector
32 | Returns
33 | -------
34 | The shortened human readable name of the Amazon Redshift Python driver: str
35 | """
36 | return "RsPython"
37 |
38 | @staticmethod
39 | def driver_full_name() -> str:
40 | """
41 | The fully qualified name of the Amazon Redshift Python driver, redshift_connector
42 | Returns
43 | -------
44 | The fully qualified name of the Amazon Redshift Python driver: str
45 | """
46 | return "{driver_name} {driver_version}".format(
47 | driver_name=DriverInfo.driver_name(), driver_version=DriverInfo.version()
48 | )
49 |
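A quick sketch of the identification helpers above; the printed version assumes the value pinned in redshift_connector/version.py.

    from redshift_connector.utils.driver_info import DriverInfo

    print(DriverInfo.driver_full_name())  # e.g. "Redshift Python Driver 2.1.7"
    print(DriverInfo.driver_short_name())  # "RsPython"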
--------------------------------------------------------------------------------
/redshift_connector/utils/extensible_digest.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from hashlib import new as hashlib_new
3 |
4 | from redshift_connector import InterfaceError
5 |
6 | if typing.TYPE_CHECKING:
7 | from hashlib import _Hash
8 |
9 |
10 | class ExtensibleDigest:
11 | """
12 | Encodes password, salt, and nonce information as SHA2(SHA2(password + salt) + server_nonce + client_nonce).
13 | """
14 |
15 | @staticmethod
16 | def encode(client_nonce: bytes, password: bytes, salt: bytes, algo_name: str, server_nonce: bytes) -> bytes:
17 | """
18 | Encodes password, salt, and nonce information as SHA2(SHA2(password + salt) + server_nonce + client_nonce).
19 | :param client_nonce: The client nonce
20 | :type client_nonce: bytes
21 | :param password: The connecting user's password
22 | :type password: bytes
23 | :param salt: salt sent by the server
24 | :type salt: bytes
25 | :param algo_name: Algorithm name such as "SHA256" etc.
26 | :type algo_name: str
27 | :param server_nonce: random number generated by server
28 | :type server_nonce: bytes
29 | :return: the digest
30 | :rtype: bytes
31 | """
32 | try:
33 | hl1: "_Hash" = hashlib_new(name=algo_name)
34 | except (ImportError, ValueError): # hashlib.new raises ValueError for unsupported algorithm names
35 | raise InterfaceError("Unable to encode password with extensible hashing: {}".format(algo_name))
36 | hl1.update(password)
37 | hl1.update(salt)
38 | pass_digest1: bytes = hl1.digest() # SHA2(password + salt)
39 |
40 | hl2: "_Hash" = hashlib_new(name=algo_name)
41 | hl2.update(pass_digest1)
42 | hl2.update(server_nonce)
43 | hl2.update(client_nonce)
44 | pass_digest2 = hl2.digest() # SHA2(SHA2(password + salt) + server_nonce + client_nonce)
45 |
46 | return pass_digest2
47 |
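An illustrative equivalence check (not part of the file above): with algo_name="sha256", the two rounds above reduce to sha256(sha256(password + salt) + server_nonce + client_nonce).

    import hashlib

    from redshift_connector.utils.extensible_digest import ExtensibleDigest

    digest = ExtensibleDigest.encode(
        client_nonce=b"client", password=b"secret", salt=b"salt", algo_name="sha256", server_nonce=b"server"
    )
    expected = hashlib.sha256(hashlib.sha256(b"secret" + b"salt").digest() + b"server" + b"client").digest()
    assert digest == expected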
--------------------------------------------------------------------------------
/redshift_connector/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import socket
3 | import typing
4 |
5 | if typing.TYPE_CHECKING:
6 | from redshift_connector import RedshiftProperty
7 |
8 |
9 | def make_divider_block() -> str:
10 | return "=" * 35
11 |
12 |
13 | def mask_secure_info_in_props(info: "RedshiftProperty") -> "RedshiftProperty":
14 | from redshift_connector import RedshiftProperty
15 |
16 | logging_allow_list: typing.Tuple[str, ...] = (
17 | # "access_key_id",
18 | "allow_db_user_override",
19 | "app_id",
20 | "app_name",
21 | "application_name",
22 | "auth_profile",
23 | "auto_create",
24 | # "client_id",
25 | "client_protocol_version",
26 | # "client_secret",
27 | "cluster_identifier",
28 | "credentials_provider",
29 | "database_metadata_current_db_only",
30 | "db_groups",
31 | "db_name",
32 | "db_user",
33 | "duration",
34 | "endpoint_url",
35 | "force_lowercase",
36 | "group_federation",
37 | "host",
38 | "iam",
39 | "iam_disable_cache",
40 | "idc_client_display_name",
41 | "idc_region",
42 | "identity_namespace",
43 | "idp_host",
44 | "idpPort",
45 | "idp_response_timeout",
46 | "idp_tenant",
47 | "issuer_url",
48 | "is_serverless",
49 | "listen_port",
50 | "login_url",
51 | "max_prepared_statements",
52 | "numeric_to_float",
53 | "partner_sp_id",
54 | # "password",
55 | "port",
56 | "preferred_role",
57 | "principal",
58 | "profile",
59 | "provider_name",
60 | "region",
61 | "replication",
62 | "role_arn",
63 | "role_session_name",
64 | "scope",
65 | # "secret_access_key",
66 | "serverless_acct_id",
67 | "serverless_work_group",
68 | # "session_token",
69 | "source_address",
70 | "ssl",
71 | "ssl_insecure",
72 | "sslmode",
73 | "tcp_keepalive",
74 | "token_type",
75 | "timeout",
76 | "unix_sock",
77 | "user_name",
78 | # "web_identity_token",
79 | )
80 |
81 | if info is None:
82 | return info
83 |
84 | temp: RedshiftProperty = RedshiftProperty()
85 |
86 | def is_populated(field: typing.Optional[str]) -> bool:
87 | return field is not None and field != ""
88 |
89 | for parameter, value in info.__dict__.items():
90 | if parameter in logging_allow_list:
91 | temp.put(parameter, value)
92 | elif is_populated(value):
93 | try:
94 | temp.put(parameter, "***")
95 | except AttributeError:
96 | pass
97 |
98 | return temp
99 |
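A minimal sketch of the masking behavior (illustrative; assumes RedshiftProperty exposes parameters as attributes, as the __dict__ iteration above suggests): allow-listed fields pass through, any other populated field is replaced with "***".

    from redshift_connector import RedshiftProperty
    from redshift_connector.utils import mask_secure_info_in_props

    props = RedshiftProperty()
    props.put("host", "examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com")
    props.put("password", "hunter2")

    masked = mask_secure_info_in_props(props)
    print(masked.host)  # unmasked: "host" is on the allow list
    print(masked.password)  # "***"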
--------------------------------------------------------------------------------
/redshift_connector/utils/oids.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum, EnumMeta
2 |
3 | class RedshiftOIDMeta(EnumMeta):
4 | def __setattr__(self, name, value):
5 | # raise an error if any constant defined in RedshiftOID is modified,
6 | # e.g. "Cannot modify OID constant 'VARCHAR'"
7 | if name in self.__dict__:
8 | raise AttributeError(f"Cannot modify OID constant '{name}'")
9 | super().__setattr__(name, value)
10 |
11 |
12 | class RedshiftOID(IntEnum, metaclass=RedshiftOIDMeta):
13 | ACLITEM = 1033
14 | ACLITEM_ARRAY = 1034
15 | ANY_ARRAY = 2277
16 | ABSTIME = 702
17 | BIGINT = 20
18 | BIGINT_ARRAY = 1016
19 | BOOLEAN = 16
20 | BOOLEAN_ARRAY = 1000
21 | BPCHAR = 1042
22 | BPCHAR_ARRAY = 1014
23 | BYTES = 17
24 | BYTES_ARRAY = 1001
25 | CHAR = 18
26 | CHAR_ARRAY = 1002
27 | CIDR = 650
28 | CIDR_ARRAY = 651
29 | CSTRING = 2275
30 | CSTRING_ARRAY = 1263
31 | DATE = 1082
32 | DATE_ARRAY = 1182
33 | FLOAT = 701
34 | FLOAT_ARRAY = 1022
35 | GEOGRAPHY = 3001
36 | GEOMETRY = 3000
37 | GEOMETRYHEX = 3999
38 | INET = 869
39 | INET_ARRAY = 1041
40 | INT2VECTOR = 22
41 | INTEGER = 23
42 | INTEGER_ARRAY = 1007
43 | INTERVAL = 1186
44 | INTERVAL_ARRAY = 1187
45 | INTERVALY2M = 1188
46 | INTERVALY2M_ARRAY = 1189
47 | INTERVALD2S = 1190
48 | INTERVALD2S_ARRAY = 1191
49 | JSON = 114
50 | JSON_ARRAY = 199
51 | JSONB = 3802
52 | JSONB_ARRAY = 3807
53 | MACADDR = 829
54 | MONEY = 790
55 | MONEY_ARRAY = 791
56 | NAME = 19
57 | NAME_ARRAY = 1003
58 | NUMERIC = 1700
59 | NUMERIC_ARRAY = 1231
60 | NULLTYPE = -1
61 | OID = 26
62 | OID_ARRAY = 1028
63 | POINT = 600
64 | REAL = 700
65 | REAL_ARRAY = 1021
66 | REGPROC = 24
67 | SMALLINT = 21
68 | SMALLINT_ARRAY = 1005
69 | SMALLINT_VECTOR = 22
70 | STRING = 1043
71 | SUPER = 4000
72 | TEXT = 25
73 | TEXT_ARRAY = 1009
74 | TIME = 1083
75 | TIME_ARRAY = 1183
76 | TIMESTAMP = 1114
77 | TIMESTAMP_ARRAY = 1115
78 | TIMESTAMPTZ = 1184
79 | TIMESTAMPTZ_ARRAY = 1185
80 | TIMETZ = 1266
81 | UNKNOWN = 705
82 | UUID_TYPE = 2950
83 | UUID_ARRAY = 2951
84 | VARCHAR = 1043
85 | VARBYTE = 6551
86 | VARCHAR_ARRAY = 1015
87 | XID = 28
88 |
89 | BIGINTEGER = BIGINT
90 | DATETIME = TIMESTAMP
91 | NUMBER = DECIMAL = NUMERIC
92 | DECIMAL_ARRAY = NUMERIC_ARRAY
93 | ROWID = OID
94 |
95 |
96 | def get_datatype_name(oid: int) -> str:
97 | return RedshiftOID(oid).name
98 |
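A short sketch of the lookup helper and the immutability guard above:

    from redshift_connector.utils.oids import RedshiftOID, get_datatype_name

    assert get_datatype_name(1700) == "NUMERIC"
    assert RedshiftOID.DECIMAL is RedshiftOID.NUMERIC  # aliases resolve to a single member

    try:
        RedshiftOID.VARCHAR = 9999
    except AttributeError as e:
        print(e)  # Cannot modify OID constant 'VARCHAR'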
--------------------------------------------------------------------------------
/redshift_connector/utils/sql_types.py:
--------------------------------------------------------------------------------
1 | from enum import IntEnum, EnumMeta
2 |
3 | class SQLTypeMeta(EnumMeta):
4 | def __setattr__(self, name, value):
5 | # raise an error if any constant defined in SQLType is modified,
6 | # e.g. "Cannot modify SQL type constant 'SQL_VARCHAR'"
7 | if name in self.__dict__:
8 | raise AttributeError(f"Cannot modify SQL type constant '{name}'")
9 | super().__setattr__(name, value)
10 |
11 | class SQLType(IntEnum, metaclass=SQLTypeMeta):
12 | SQL_VARCHAR = 12
13 | SQL_BIT = -7
14 | SQL_TINYINT = -6
15 | SQL_SMALLINT = 5
16 | SQL_INTEGER = 4
17 | SQL_BIGINT = -5
18 | SQL_FLOAT = 6
19 | SQL_REAL = 7
20 | SQL_DOUBLE = 8
21 | SQL_NUMERIC = 2
22 | SQL_DECIMAL = 3
23 | SQL_CHAR = 1
24 | SQL_LONGVARCHAR = -1
25 | SQL_DATE = 91
26 | SQL_TIME = 92
27 | SQL_TIMESTAMP = 93
28 | SQL_BINARY = -2
29 | SQL_VARBINARY = -3
30 | SQL_LONGVARBINARY = -4
31 | SQL_NULL = 0
32 | SQL_OTHER = 1111
33 | SQL_BOOLEAN = 16
34 | SQL_LONGNVARCHAR = -16
35 | SQL_TIME_WITH_TIMEZONE = 2013
36 | SQL_TIMESTAMP_WITH_TIMEZONE = 2014
37 |
38 |
39 | def get_sql_type_name(sql_type: int) -> str:
40 | return SQLType(sql_type).name
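
The same pattern applied to the JDBC-style type codes, for example:

    from redshift_connector.utils.sql_types import get_sql_type_name

    assert get_sql_type_name(12) == "SQL_VARCHAR"
    assert get_sql_type_name(2014) == "SQL_TIMESTAMP_WITH_TIMEZONE"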
--------------------------------------------------------------------------------
/redshift_connector/version.py:
--------------------------------------------------------------------------------
1 | # Store the version here so:
2 | # 1) we don't load dependencies by storing it in __init__.py
3 | # 2) we can import it in setup.py for the same reason
4 | # 3) we can import it into other modules
5 | __version__ = "2.1.7"
6 |
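A minimal sketch of reading the version without importing the package (the same trick setup.py uses below, here with an explicit namespace instead of the calling scope):

    namespace: dict = {}
    with open("redshift_connector/version.py") as f:
        exec(f.read(), namespace)
    print(namespace["__version__"])  # "2.1.7"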
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pytest>=5.4.0,<7.4.0
2 | pytest-xdist[psutil]
3 | mypy>=0.782
4 | pre-commit>=2.6.0
5 | pytest-cov>=2.10.0
6 | pytest-mock>=1.11.1,<=3.2.0
7 | wheel>=0.33
8 | docutils>=0.14
9 | selenium>=4.25
10 | -e .
11 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | scramp>=1.2.0,<1.5.0
2 | pytz>=2020.1
3 | beautifulsoup4>=4.7.0,<5.0.0
4 | boto3>=1.9.201,<2.0.0
5 | requests>=2.23.0,<3.0.0
6 | lxml>=4.6.5
7 | botocore>=1.12.201,<2.0.0
8 | packaging
9 | setuptools
10 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 |
2 | [coverage:run]
3 | branch = true
4 | parallel = true
5 |
6 | [coverage:paths]
7 | source =
8 | ./
9 | build/lib/*/site-packages/
10 |
11 | [coverage:html]
12 | directory = build/coverage
13 |
14 | [coverage:xml]
15 | output = build/coverage/coverage.xml
16 |
17 | [tool:pytest]
18 | addopts =
19 | --verbose
20 | --ignore=build/private
21 | --doctest-modules
22 | --cov redshift_connector
23 | --cov-report term-missing
24 | --cov-report html:build/coverage
25 | --cov-report xml:build/coverage/coverage.xml
26 | test
27 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import os
3 | import sys
4 | import typing
5 |
6 | from setuptools import find_packages, setup
7 | from setuptools.command.install import install as InstallCommandBase
8 | from setuptools.command.test import test as TestCommand
9 | from setuptools.dist import Distribution
10 | from wheel.bdist_wheel import bdist_wheel as BDistWheelCommandBase
11 |
12 |
13 | class BasePytestCommand(TestCommand):
14 | user_options: typing.List = []
15 | test_dir: typing.Optional[str] = None
16 |
17 | def initialize_options(self):
18 | TestCommand.initialize_options(self)
19 |
20 | def finalize_options(self):
21 | TestCommand.finalize_options(self)
22 | self.test_args = []
23 | self.test_suite = True
24 |
25 | def run_tests(self):
26 | import pytest
27 |
28 | src_dir = os.getenv("SRC_DIR", "")
29 | if src_dir:
30 | src_dir += "/"
31 | args = [
32 | self.test_dir,
33 | "--cov=redshift_connector",
34 | "--cov-report=xml",
35 | "--cov-report=html",
36 | ]
37 |
38 | errno = pytest.main(args)
39 | sys.exit(errno)
40 |
41 |
42 | class UnitTestCommand(BasePytestCommand):
43 | test_dir: str = "test/unit"
44 |
45 |
46 | class IntegrationTestCommand(BasePytestCommand):
47 | test_dir = "test/integration"
48 |
49 |
50 | class BinaryDistribution(Distribution):
51 | def has_ext_modules(self):
52 | return True
53 |
54 |
55 | class InstallCommand(InstallCommandBase):
56 | """Override the installation dir."""
57 |
58 | def finalize_options(self):
59 | ret = InstallCommandBase.finalize_options(self)
60 | self.install_lib = self.install_platlib
61 | return ret
62 |
63 |
64 | class BDistWheelCommand(BDistWheelCommandBase):
65 | def finalize_options(self):
66 | super().finalize_options()
67 | self.root_is_pure = False
68 | self.universal = True
69 |
70 | def get_tag(self):
71 | python, abi, plat = "py3", "none", "any"
72 | return python, abi, plat
73 |
74 |
75 | custom_cmds = {
76 | "bdist_wheel": BDistWheelCommand,
77 | "unit_test": UnitTestCommand,
78 | "integration_test": IntegrationTestCommand,
79 | }
80 |
81 | if os.getenv("CUSTOMINSTALL", False):
82 | custom_cmds["install"] = InstallCommand
83 | elif "install" in custom_cmds:
84 | del custom_cmds["install"]
85 |
86 | # read the contents of your README file
87 | this_directory = os.path.abspath(os.path.dirname(__file__))
88 | with open(os.path.join(this_directory, "README.rst"), encoding="utf-8") as f:
89 | long_description = f.read()
90 | exec(open("redshift_connector/version.py").read())
91 |
92 | optional_deps = {
93 | "full": ["numpy", "pandas"],
94 | }
95 |
96 | setup(
97 | name="redshift_connector",
98 | version=__version__, # type: ignore
99 | description="Redshift interface library",
100 | long_description=long_description,
101 | long_description_content_type="text/x-rst",
102 | author="Amazon Web Services",
103 | author_email="redshift-drivers@amazon.com",
104 | url="https://github.com/aws/amazon-redshift-python-driver",
105 | license="Apache License 2.0",
106 | python_requires=">=3.6",
107 | install_requires=open("requirements.txt").read().strip().split("\n"),
108 | extras_require=optional_deps,
109 | classifiers=[
110 | "Development Status :: 5 - Production/Stable",
111 | "Intended Audience :: Developers",
112 | "License :: OSI Approved :: BSD License",
113 | "Programming Language :: Python",
114 | "Programming Language :: Python :: 3",
115 | "Programming Language :: Python :: 3.6",
116 | "Programming Language :: Python :: 3.7",
117 | "Programming Language :: Python :: 3.8",
118 | "Programming Language :: Python :: 3.9",
119 | "Programming Language :: Python :: 3.10",
120 | "Programming Language :: Python :: 3.11",
121 | "Programming Language :: Python :: Implementation",
122 | "Programming Language :: Python :: Implementation :: CPython",
123 | "Programming Language :: Python :: Implementation :: Jython",
124 | "Programming Language :: Python :: Implementation :: PyPy",
125 | "Operating System :: OS Independent",
126 | "Topic :: Database :: Front-Ends",
127 | "Topic :: Software Development :: Libraries :: Python Modules",
128 | ],
129 | keywords="redshift dbapi",
130 | include_package_data=True,
131 | package_data={"redshift_connector": ["*.py", "*.crt", "LICENSE", "NOTICE", "py.typed"]},
132 | packages=find_packages(exclude=["test*"]),
133 | cmdclass=custom_cmds,
134 | )
135 |
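The custom_cmds mapping above registers the test runners and wheel builder as setuptools commands, so in a development environment (requirements-dev.txt installed) they can be invoked as, for example:

    python setup.py unit_test
    python setup.py integration_test
    python setup.py bdist_wheel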
--------------------------------------------------------------------------------
/test/__init__.py:
--------------------------------------------------------------------------------
1 | from test.integration import (
2 | adfs_idp,
3 | azure_browser_idp,
4 | azure_idp,
5 | db_kwargs,
6 | idp_arg,
7 | jumpcloud_browser_idp,
8 | jwt_azure_v2_idp,
9 | jwt_google_idp,
10 | okta_browser_idp,
11 | okta_idp,
12 | ping_browser_idp,
13 | redshift_browser_idc,
14 | redshift_idp_token_auth_plugin,
15 | )
16 |
--------------------------------------------------------------------------------
/test/integration/__init__.py:
--------------------------------------------------------------------------------
1 | from test.conftest import (
2 | adfs_idp,
3 | azure_browser_idp,
4 | azure_idp,
5 | db_kwargs,
6 | ds_consumer_dsdb_kwargs,
7 | idp_arg,
8 | jumpcloud_browser_idp,
9 | jwt_azure_v2_idp,
10 | jwt_google_idp,
11 | okta_browser_idp,
12 | okta_idp,
13 | ping_browser_idp,
14 | redshift_browser_idc,
15 | redshift_idp_token_auth_plugin,
16 | redshift_native_browser_azure_oauth2_idp,
17 | )
18 |
--------------------------------------------------------------------------------
/test/integration/datatype/test_system_table_queries.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from warnings import warn
3 |
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 | from redshift_connector.config import ClientProtocolVersion
8 |
9 | system_tables: typing.List[str] = [
10 | "pg_aggregate",
11 | "pg_am",
12 | "pg_amop",
13 | "pg_amproc",
14 | "pg_attrdef",
15 | "pg_attribute",
16 | "pg_cast",
17 | "pg_class",
18 | "pg_constraint",
19 | "pg_conversion",
20 | "pg_database",
21 | "pg_default_acl",
22 | "pg_depend",
23 | "pg_description",
24 | # "pg_index", # has unsupported type oids 22, 30
25 | "pg_inherits",
26 | "pg_language",
27 | "pg_largeobject",
28 | "pg_namespace",
29 | "pg_opclass",
30 | "pg_operator",
31 | "pg_proc",
32 | "pg_rewrite",
33 | "pg_shdepend",
34 | "pg_statistic",
35 | "pg_tablespace",
36 | # "pg_trigger", # has unsupported type oid 22
37 | "pg_type",
38 | "pg_group",
39 | "pg_indexes",
40 | "pg_locks",
41 | "pg_rules",
42 | "pg_settings",
43 | "pg_stats",
44 | "pg_tables",
45 | "pg_user",
46 | "pg_views",
47 | # "pg_authid",
48 | # "pg_auth_members",
49 | # "pg_collation",
50 | # "pg_db_role_setting",
51 | # "pg_enum",
52 | # "pg_extension",
53 | # "pg_foreign_data_wrapper",
54 | # "pg_foreign_server",
55 | # "pg_foreign_table",
56 | # "pg_largeobject_metadata",
57 | # "pg_opfamily",
58 | # "pg_pltemplate",
59 | # "pg_seclabel",
60 | # "pg_shdescription",
61 | # "pg_ts_config",
62 | # "pg_ts_config_map",
63 | # "pg_ts_dict",
64 | # "pg_ts_parser",
65 | # "pg_ts_template",
66 | # "pg_user_mapping",
67 | # "pg_available_extensions",
68 | # "pg_available_extension_versions",
69 | # "pg_cursors",
70 | # "pg_prepared_statements",
71 | # "pg_prepared_xacts",
72 | # "pg_roles",
73 | # "pg_seclabels",
74 | # "pg_shadow",
75 | # "pg_timezone_abbrevs",
76 | # "pg_timezone_names",
77 | # "pg_user_mappings",
78 | ]
79 |
80 | # This test ensures system tables can be queried without datatype
81 | # conversion issues. No validation of the result set occurs.
82 |
83 |
84 | @pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
85 | @pytest.mark.parametrize("table_name", system_tables)
86 | def test_process_system_table_datatypes(db_kwargs, client_protocol, table_name):
87 | db_kwargs["client_protocol_version"] = client_protocol
88 |
89 | with redshift_connector.connect(**db_kwargs) as conn:
90 | if conn._client_protocol_version != client_protocol:
91 | warn(
92 | "Requested client_protocol_version was not satisfied. Requested {} Got {}".format(
93 | client_protocol, conn._client_protocol_version
94 | )
95 | )
96 | with conn.cursor() as cursor:
97 | cursor.execute("select * from {}".format(table_name))
98 |
--------------------------------------------------------------------------------
/test/integration/metadata/test_metadataAPIHelperServer.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 | from redshift_connector.metadataAPIHelper import MetadataAPIHelper
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | conf.read(root_path + "/config.ini")
12 |
13 |
14 | def test_SHOW_DATABASES_col(db_kwargs) -> None:
15 | with redshift_connector.connect(**db_kwargs) as conn:
16 | with conn.cursor() as cursor:
17 | if cursor.supportSHOWDiscovery() >= 2:
18 | cursor.execute("SHOW DATABASES;")
19 |
20 | column = cursor.description
21 |
22 | col_set: typing.Set = set()
23 |
24 | for col in column:
25 | col_set.add(col[0])
26 |
27 | mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()
28 |
29 | assert mock_metadataAPIHelper._SHOW_DATABASES_database_name in col_set
30 |
31 |
32 | def test_SHOW_SCHEMAS_col(db_kwargs) -> None:
33 | with redshift_connector.connect(**db_kwargs) as conn:
34 | with conn.cursor() as cursor:
35 | if cursor.supportSHOWDiscovery() >= 2:
36 | cursor.execute("SHOW SCHEMAS FROM DATABASE test_catalog;")
37 |
38 | column = cursor.description
39 |
40 | col_set: typing.Set = set()
41 |
42 | for col in column:
43 | col_set.add(col[0])
44 |
45 | mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()
46 |
47 | assert mock_metadataAPIHelper._SHOW_SCHEMA_database_name in col_set
48 | assert mock_metadataAPIHelper._SHOW_SCHEMA_schema_name in col_set
49 |
50 |
51 | def test_SHOW_TABLES_col(db_kwargs) -> None:
52 | with redshift_connector.connect(**db_kwargs) as conn:
53 | with conn.cursor() as cursor:
54 | if cursor.supportSHOWDiscovery() >= 2:
55 | cursor.execute("SHOW TABLES FROM SCHEMA test_catalog.test_schema;")
56 |
57 | column = cursor.description
58 |
59 | col_set: typing.Set = set()
60 |
61 | for col in column:
62 | col_set.add(col[0])
63 |
64 | mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()
65 |
66 | assert mock_metadataAPIHelper._SHOW_TABLES_database_name in col_set
67 | assert mock_metadataAPIHelper._SHOW_TABLES_schema_name in col_set
68 | assert mock_metadataAPIHelper._SHOW_TABLES_table_name in col_set
69 | assert mock_metadataAPIHelper._SHOW_TABLES_table_type in col_set
70 | assert mock_metadataAPIHelper._SHOW_TABLES_remarks in col_set
71 |
72 |
73 | def test_SHOW_COLUMNS_col(db_kwargs) -> None:
74 | with redshift_connector.connect(**db_kwargs) as conn:
75 | with conn.cursor() as cursor:
76 | if cursor.supportSHOWDiscovery() >= 2:
77 | cursor.execute("SHOW COLUMNS FROM TABLE test_catalog.test_schema.test_table;")
78 |
79 | column = cursor.description
80 |
81 | col_set: typing.Set = set()
82 |
83 | for col in column:
84 | col_set.add(col[0])
85 |
86 | mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()
87 |
88 | assert mock_metadataAPIHelper._SHOW_COLUMNS_database_name in col_set
89 | assert mock_metadataAPIHelper._SHOW_COLUMNS_schema_name in col_set
90 | assert mock_metadataAPIHelper._SHOW_COLUMNS_table_name in col_set
91 | assert mock_metadataAPIHelper._SHOW_COLUMNS_column_name in col_set
92 | assert mock_metadataAPIHelper._SHOW_COLUMNS_ordinal_position in col_set
93 | assert mock_metadataAPIHelper._SHOW_COLUMNS_column_default in col_set
94 | assert mock_metadataAPIHelper._SHOW_COLUMNS_is_nullable in col_set
95 | assert mock_metadataAPIHelper._SHOW_COLUMNS_data_type in col_set
96 | assert mock_metadataAPIHelper._SHOW_COLUMNS_character_maximum_length in col_set
97 | assert mock_metadataAPIHelper._SHOW_COLUMNS_numeric_precision in col_set
98 | assert mock_metadataAPIHelper._SHOW_COLUMNS_numeric_scale in col_set
99 | assert mock_metadataAPIHelper._SHOW_COLUMNS_remarks in col_set
100 |
101 |
102 |
--------------------------------------------------------------------------------
/test/integration/metadata/test_metadataAPI_special_character_handling_standard_identifier.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 | from redshift_connector.metadataAPIHelper import MetadataAPIHelper
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | conf.read(root_path + "/config.ini")
12 |
13 | # Manually set the following flag to False if we want to test cross-database metadata api call
14 | disable_cross_database_testing: bool = True
15 |
16 | current_catalog: str = "təst_cata好log$123standard"
17 |
18 | object_name: str = "təst_n𝒜m好e$123standard"
19 |
20 | startup_stmts: typing.Tuple[str, ...] = (
21 | "DROP SCHEMA IF EXISTS {} CASCADE;".format(object_name),
22 | "CREATE SCHEMA {};".format(object_name),
23 | "create table {}.{} ({} INT);".format(object_name, object_name, object_name),
24 | )
25 |
26 | test_cases = [
27 | # Standard identifier with lower case
28 | ([current_catalog, object_name], 1,
29 | [object_name]),
30 |
31 | # Standard identifier with mixed case
32 | ([current_catalog, "təst_N𝒜m好e$123Standard"], 0,
33 | []),
34 |
35 | # Standard identifier with lower case + illegal special character
36 | ([current_catalog, "təst_n𝒜m好e$123!standard"], 0,
37 | []),
38 | ]
39 |
40 | @pytest.fixture(scope="class", autouse=True)
41 | def test_metadataAPI_config(request, db_kwargs):
42 | global cur_db_kwargs
43 | cur_db_kwargs = dict(db_kwargs)
44 | print(cur_db_kwargs)
45 |
46 | with redshift_connector.connect(**cur_db_kwargs) as con:
47 | con.paramstyle = "format"
48 | con.autocommit = True
49 | with con.cursor() as cursor:
50 | try:
51 | cursor.execute("drop database {};".format(current_catalog))
52 | except redshift_connector.ProgrammingError:
53 | pass
54 | cursor.execute("create database {};".format(current_catalog))
55 | cur_db_kwargs["database"] = current_catalog
56 | with redshift_connector.connect(**cur_db_kwargs) as con:
57 | con.paramstyle = "format"
58 | with con.cursor() as cursor:
59 | for stmt in startup_stmts:
60 | cursor.execute(stmt)
61 |
62 | con.commit()
63 | def fin():
64 | try:
65 | with redshift_connector.connect(**db_kwargs) as con:
66 | con.autocommit = True
67 | with con.cursor() as cursor:
68 | cursor.execute("drop database {};".format(current_catalog))
69 | cursor.execute("select 1;")
70 | except redshift_connector.ProgrammingError:
71 | pass
72 |
73 | request.addfinalizer(fin)
74 |
75 | @pytest.mark.parametrize("test_case, expected_row_count, expected_result", test_cases)
76 | def test_get_schemas_special_character(db_kwargs, test_case, expected_row_count, expected_result) -> None:
77 | global cur_db_kwargs
78 | if disable_cross_database_testing:
79 | test_db_kwargs = dict(cur_db_kwargs)
80 | else:
81 | test_db_kwargs = dict(db_kwargs)
82 | test_db_kwargs["database_metadata_current_db_only"] = False
83 |
84 | with redshift_connector.connect(**test_db_kwargs) as conn:
85 | with conn.cursor() as cursor:
86 | result: tuple = cursor.get_schemas(test_case[0], test_case[1])
87 | assert len(result) == expected_row_count
88 | for actual_row, expected_schema in zip(result, expected_result):
89 | assert actual_row[0] == expected_schema
90 | assert actual_row[1] == current_catalog
91 |
92 | @pytest.mark.parametrize("test_case, expected_row_count, expected_result", test_cases)
93 | def test_get_tables_special_character(db_kwargs, test_case, expected_row_count, expected_result) -> None:
94 | global cur_db_kwargs
95 | if disable_cross_database_testing:
96 | test_db_kwargs = dict(cur_db_kwargs)
97 | else:
98 | test_db_kwargs = dict(db_kwargs)
99 | test_db_kwargs["database_metadata_current_db_only"] = False
100 |
101 | with redshift_connector.connect(**test_db_kwargs) as conn:
102 | with conn.cursor() as cursor:
103 | result: tuple = cursor.get_tables(test_case[0], test_case[1], test_case[1], None)
104 | assert len(result) == expected_row_count
105 | for actual_row, expected_name in zip(result, expected_result):
106 | assert actual_row[0] == current_catalog
107 | assert actual_row[1] == expected_name
108 | assert actual_row[2] == expected_name
109 |
110 | @pytest.mark.parametrize("test_case, expected_row_count, expected_result", test_cases)
111 | def test_get_columns_special_character(db_kwargs, test_case, expected_row_count, expected_result) -> None:
112 | global cur_db_kwargs
113 | if disable_cross_database_testing:
114 | test_db_kwargs = dict(cur_db_kwargs)
115 | else:
116 | test_db_kwargs = dict(db_kwargs)
117 | test_db_kwargs["database_metadata_current_db_only"] = False
118 |
119 | with redshift_connector.connect(**test_db_kwargs) as conn:
120 | with conn.cursor() as cursor:
121 | result: tuple = cursor.get_columns(test_case[0], test_case[1], test_case[1], test_case[1])
122 | assert len(result) == expected_row_count
123 | for actual_row, expected_name in zip(result, expected_result):
124 | assert actual_row[0] == current_catalog
125 | assert actual_row[1] == expected_name
126 | assert actual_row[2] == expected_name
127 | assert actual_row[3] == expected_name
128 |
129 |
--------------------------------------------------------------------------------
/test/integration/metadata/test_metadataAPI_sql_injection.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 | from redshift_connector.metadataAPIHelper import MetadataAPIHelper
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | conf.read(root_path + "/config.ini")
12 |
13 | # Manually set the following flag to False if we want to test cross-database metadata api call
14 | disable_cross_database_testing: bool = True
15 |
16 | catalog_name: str = "test_sql_injection -- comment"
17 |
18 | schema_name: str = "test_schema; create table pwn1(i int);--"
19 | table_name: str = "test_table; create table pwn2(i int);--"
20 | col_name: str = "col"
21 |
22 | startup_stmts: typing.Tuple[str, ...] = (
23 | "DROP SCHEMA IF EXISTS \"{}\" CASCADE;".format(schema_name),
24 | "CREATE SCHEMA \"{}\";".format(schema_name),
25 | "create table \"{}\".\"{}\" (col INT);".format(schema_name, table_name),
26 | )
27 |
28 | test_cases = [
29 | "' OR '1'='1",
30 | "' OR 1=1 --",
31 | "' OR '1'='1' --",
32 | "') OR 1=1 --",
33 | "') OR '1'='1' --",
34 | "' OR 1=1;--",
35 | "' OR 1=1 LIMIT 1;--",
36 | "' UNION SELECT null --",
37 | "' UNION SELECT null, null --",
38 | "' UNION SELECT 1, 'username', 'password' FROM users --",
39 | "' UNION SELECT * FROM users --",
40 | "') AND 1=CAST((SELECT current_database()) AS INT) --",
41 | "' AND 1=1 --",
42 | "' AND 1=2 --",
43 | "'; DROP TABLE test_table_100; --",
44 | "\"; DROP TABLE test_table_100; --\"",
45 | "; DROP TABLE test_table_100; --"
46 | ]
47 |
48 | @pytest.fixture(scope="class", autouse=True)
49 | def test_metadataAPI_config(request, db_kwargs):
50 | global cur_db_kwargs
51 | cur_db_kwargs = dict(db_kwargs)
52 | print(cur_db_kwargs)
53 |
54 | with redshift_connector.connect(**cur_db_kwargs) as con:
55 | con.paramstyle = "format"
56 | con.autocommit = True
57 | with con.cursor() as cursor:
58 | try:
59 | cursor.execute("drop database \"{}\";".format(catalog_name))
60 | except redshift_connector.ProgrammingError:
61 | pass
62 | cursor.execute("create database \"{}\";".format(catalog_name))
63 | cur_db_kwargs["database"] = catalog_name
64 | with redshift_connector.connect(**cur_db_kwargs) as con:
65 | con.paramstyle = "format"
66 | with con.cursor() as cursor:
67 | for stmt in startup_stmts:
68 | cursor.execute(stmt)
69 |
70 | con.commit()
71 | def fin():
72 | try:
73 | with redshift_connector.connect(**db_kwargs) as con:
74 | con.autocommit = True
75 | with con.cursor() as cursor:
76 | cursor.execute("drop database \"{}\";".format(catalog_name))
77 | cursor.execute("select 1;")
78 | except redshift_connector.ProgrammingError:
79 | pass
80 |
81 | request.addfinalizer(fin)
82 |
83 | @pytest.mark.parametrize("test_input", test_cases)
84 | def test_input_parameter_sql_injection(db_kwargs, test_input) -> None:
85 | with redshift_connector.connect(**db_kwargs) as conn:
86 | with conn.cursor() as cursor:
87 | try:
88 | result: tuple = cursor.get_schemas(test_input, None)
89 | except Exception as e:
90 | pytest.fail(f"Unexpected exception raised: {e}")
91 |
92 | def test_get_schemas_sql_injection(db_kwargs) -> None:
93 | global cur_db_kwargs
94 | if disable_cross_database_testing:
95 | test_db_kwargs = dict(cur_db_kwargs)
96 | else:
97 | test_db_kwargs = dict(db_kwargs)
98 | test_db_kwargs["database_metadata_current_db_only"] = False
99 |
100 | with redshift_connector.connect(**test_db_kwargs) as conn:
101 | with conn.cursor() as cursor:
102 | result: tuple = cursor.get_schemas(catalog_name, None)
103 |
104 | assert len(result) > 0
105 |
106 | found_expected_row: bool = False
107 | for actual_row in result:
108 | if actual_row[0] == schema_name and actual_row[1] == catalog_name:
109 | found_expected_row = True
110 |
111 | assert found_expected_row
112 |
113 | def test_get_tables_sql_injection(db_kwargs) -> None:
114 | global cur_db_kwargs
115 | if disable_cross_database_testing:
116 | test_db_kwargs = dict(cur_db_kwargs)
117 | else:
118 | test_db_kwargs = dict(db_kwargs)
119 | test_db_kwargs["database_metadata_current_db_only"] = False
120 |
121 | with redshift_connector.connect(**test_db_kwargs) as conn:
122 | with conn.cursor() as cursor:
123 | result: tuple = cursor.get_tables(catalog_name, None, None)
124 |
125 | assert len(result) > 0
126 |
127 | found_expected_row: bool = False
128 | for actual_row in result:
129 | if actual_row[0] == catalog_name and actual_row[1] == schema_name and actual_row[2] == table_name:
130 | found_expected_row = True
131 |
132 | assert found_expected_row
133 |
134 | def test_get_columns_sql_injection(db_kwargs) -> None:
135 | global cur_db_kwargs
136 | if disable_cross_database_testing:
137 | test_db_kwargs = dict(cur_db_kwargs)
138 | else:
139 | test_db_kwargs = dict(db_kwargs)
140 | test_db_kwargs["database_metadata_current_db_only"] = False
141 |
142 | with redshift_connector.connect(**test_db_kwargs) as conn:
143 | with conn.cursor() as cursor:
144 | result: tuple = cursor.get_columns(catalog_name, None, None, None)
145 |
146 | assert len(result) > 0
147 |
148 | found_expected_row: bool = False
149 | for actual_row in result:
150 | if actual_row[0] == catalog_name and actual_row[1] == schema_name and actual_row[2] == table_name and actual_row[3] == col_name:
151 | found_expected_row = True
152 |
153 | assert found_expected_row
154 |
155 |
--------------------------------------------------------------------------------
/test/integration/plugin/conftest.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from test.conftest import _get_default_connection_args, conf
3 |
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 |
8 |
9 | @pytest.fixture(scope="session", autouse=True)
10 | def startup_db_stmts() -> None:
11 | """
12 | Executes a defined set of statements to configure a fresh Amazon Redshift cluster for IdP integration tests.
13 | """
14 | groups: typing.List[str] = conf.get("cluster-setup", "groups").split(sep=",")
15 |
16 | with redshift_connector.connect(**_get_default_connection_args()) as conn: # type: ignore
17 | conn.autocommit = True
18 | with conn.cursor() as cursor:
19 | for grp in groups:
20 | try:
21 | cursor.execute("DROP GROUP {}".format(grp))
22 | except Exception:
23 | pass # we can't use IF EXISTS here, so ignore any error
24 | cursor.execute("CREATE GROUP {}".format(grp))
25 |
--------------------------------------------------------------------------------
/test/integration/plugin/test_azure_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 |
5 | import pytest # type: ignore
6 |
7 | import redshift_connector
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
11 | conf.read(root_path + "/config.ini")
12 |
13 | PROVIDER: typing.List[str] = ["azure_idp"]
14 |
15 |
16 | @pytest.mark.parametrize("idp_arg", PROVIDER, indirect=True)
17 | def test_preferred_role_should_use(idp_arg):
18 | idp_arg["preferred_role"] = conf.get("azure-idp", "preferred_role")
19 | with redshift_connector.connect(**idp_arg):
20 | pass
21 |
--------------------------------------------------------------------------------
/test/integration/plugin/test_okta_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 |
5 | import pytest # type: ignore
6 |
7 | import redshift_connector
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
11 | conf.read(root_path + "/config.ini")
12 |
13 | PROVIDER: typing.List[str] = ["okta_idp"]
14 |
15 |
16 | @pytest.mark.parametrize("idp_arg", PROVIDER, indirect=True)
17 | def test_idp_host_invalid_should_fail(idp_arg) -> None:
18 | wrong_idp_host: str = "andrew.okta.com"
19 | idp_arg["idp_host"] = wrong_idp_host
20 |
21 | with pytest.raises(redshift_connector.InterfaceError, match="Failed to get SAML assertion"):
22 | redshift_connector.connect(**idp_arg)
23 |
24 |
25 | @pytest.mark.parametrize("idp_arg", PROVIDER, indirect=True)
26 | def test_preferred_role_should_use(idp_arg) -> None:
27 | idp_arg["preferred_role"] = conf.get("okta-idp", "preferred_role")
28 | with redshift_connector.connect(**idp_arg):
29 | pass
30 |
--------------------------------------------------------------------------------
/test/integration/test_pandas.py:
--------------------------------------------------------------------------------
1 | from test.utils import numpy_only, pandas_only
2 | from warnings import filterwarnings
3 |
4 | import pytest # type: ignore
5 |
6 | import redshift_connector
7 | from redshift_connector.config import DbApiParamstyle
8 |
9 | # Tests relating to the pandas and numpy operation of the database driver
10 | # redshift_connector custom interface.
11 |
12 |
13 | @pytest.fixture
14 | def db_table(request, con: redshift_connector.Connection) -> redshift_connector.Connection:
15 | filterwarnings("ignore", "DB-API extension cursor.next()")
16 | filterwarnings("ignore", "DB-API extension cursor.__iter__()")
17 | con.paramstyle = "format" # type: ignore
18 | with con.cursor() as cursor:
19 | cursor.execute("drop table if exists book")
20 | cursor.execute("create Temp table book(bookname varchar,author varchar)")
21 |
22 | def fin() -> None:
23 | try:
24 | with con.cursor() as cursor:
25 | cursor.execute("drop table if exists book")
26 | except redshift_connector.ProgrammingError:
27 | pass
28 |
29 | request.addfinalizer(fin)
30 | return con
31 |
32 |
33 | @pandas_only
34 | def test_fetch_dataframe(db_table) -> None:
35 | import numpy as np # type: ignore
36 | import pandas as pd # type: ignore
37 |
38 | df = pd.DataFrame(
39 | np.array(
40 | [
41 | ["One Hundred Years of Solitude", "Gabriel García Márquez"],
42 | ["A Brief History of Time", "Stephen Hawking"],
43 | ]
44 | ),
45 | columns=["bookname", "author"],
46 | )
47 | with db_table.cursor() as cursor:
48 | cursor.executemany(
49 | "insert into book (bookname, author) values (%s, %s)",
50 | [
51 | ("One Hundred Years of Solitude", "Gabriel García Márquez"),
52 | ("A Brief History of Time", "Stephen Hawking"),
53 | ],
54 | )
55 | cursor.execute("select * from book; ")
56 | result = cursor.fetch_dataframe()
57 | assert result.columns[0] == "bookname"
58 | assert result.columns[1] == "author"
59 |
60 |
61 | @pandas_only
62 | @pytest.mark.parametrize("paramstyle", DbApiParamstyle.list())
63 | def test_write_dataframe(db_table, paramstyle) -> None:
64 | import numpy as np
65 | import pandas as pd
66 |
67 | df = pd.DataFrame(
68 | np.array(
69 | [
70 | ["One Hundred Years of Solitude", "Gabriel García Márquez"],
71 | ["A Brief History of Time", "Stephen Hawking"],
72 | ]
73 | ),
74 | columns=["bookname", "author"],
75 | )
76 | db_table.paramstyle = paramstyle
77 |
78 | with db_table.cursor() as cursor:
79 | cursor.write_dataframe(df, "book")
80 | cursor.execute("select * from book; ")
81 | result = cursor.fetchall()
82 | assert len(np.array(result)) == 2
83 |
84 | assert db_table.paramstyle == paramstyle
85 |
86 |
87 | @numpy_only
88 | def test_fetch_numpyarray(db_table) -> None:
89 | with db_table.cursor() as cursor:
90 | cursor.executemany(
91 | "insert into book (bookname, author) values (%s, %s)",
92 | [
93 | ("One Hundred Years of Solitude", "Gabriel García Márquez"),
94 | ("A Brief History of Time", "Stephen Hawking"),
95 | ],
96 | )
97 | cursor.execute("select * from book; ")
98 | result = cursor.fetch_numpy_array()
99 | assert len(result) == 2
100 |
--------------------------------------------------------------------------------
/test/integration/test_paramstyle.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest # type: ignore
4 |
5 |
6 | @pytest.mark.parametrize(
7 | "parameters",
8 | [
9 | ({"c1": "abc", "c2": "defg", "c3": "hijkl"}, ["abc", "defg", "hijkl"]),
10 | ({"c1": "a", "c2": "b", "c3": "c"}, ["a", "b", "c"]),
11 | ],
12 | )
13 | def test_pyformat(cursor, parameters) -> None:
14 | cursor.paramstyle = "pyformat"
15 | data, exp_result = parameters
16 | cursor.execute("create temporary table test_pyformat(c1 varchar, c2 varchar, c3 varchar)")
17 | cursor.execute("insert into test_pyformat(c1, c2, c3) values(%(c1)s, %(c2)s, %(c3)s)", data)
18 | cursor.execute("select * from test_pyformat")
19 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
20 | assert len(res) == 1
21 | assert len(res[0]) == len(data)
22 | assert res[0] == exp_result
23 |
24 |
25 | @pytest.mark.parametrize(
26 | "parameters",
27 | [
28 | (
29 | ({"c1": "abc", "c2": "defg", "c3": "hijkl"}, {"c1": "a", "c2": "b", "c3": "c"}),
30 | [["a", "b", "c"], ["abc", "defg", "hijkl"]],
31 | ),
32 | ],
33 | )
34 | def test_pyformat_multiple_insert(cursor, parameters) -> None:
35 | cursor.paramstyle = "pyformat"
36 | data, exp_result = parameters
37 | cursor.execute("create temporary table test_pyformat(c1 varchar, c2 varchar, c3 varchar)")
38 | cursor.executemany("insert into test_pyformat(c1, c2, c3) values(%(c1)s, %(c2)s, %(c3)s)", data)
39 | cursor.execute("select * from test_pyformat order by c1")
40 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
41 | assert len(res) == len(exp_result)
42 | for idx, row in enumerate(res):
43 | assert len(row) == len(exp_result[idx])
44 | assert row == exp_result[idx]
45 |
46 |
47 | @pytest.mark.parametrize(
48 | "parameters", [(["abc", "defg", "hijkl"], ["abc", "defg", "hijkl"]), (["a", "b", "c"], ["a", "b", "c"])]
49 | )
50 | def test_qmark(cursor, parameters) -> None:
51 | cursor.paramstyle = "qmark"
52 | data, exp_result = parameters
53 | cursor.execute("create temporary table test_qmark(c1 varchar, c2 varchar, c3 varchar)")
54 | cursor.execute("insert into test_qmark(c1, c2, c3) values(?, ?, ?)", data)
55 | cursor.execute("select * from test_qmark")
56 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
57 | assert len(res) == 1
58 | assert len(res[0]) == len(data)
59 | assert res[0] == exp_result
60 |
61 |
62 | @pytest.mark.parametrize(
63 | "parameters", [(["abc", "defg", "hijkl"], ["abc", "defg", "hijkl"]), (["a", "b", "c"], ["a", "b", "c"])]
64 | )
65 | def test_numeric(cursor, parameters) -> None:
66 | cursor.paramstyle = "numeric"
67 | data, exp_result = parameters
68 | cursor.execute("create temporary table test_numeric(c1 varchar, c2 varchar, c3 varchar)")
69 | cursor.execute("insert into test_numeric(c1, c2, c3) values(:1, :2, :3)", data)
70 | cursor.execute("select * from test_numeric")
71 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
72 | assert len(res) == 1
73 | assert len(res[0]) == len(data)
74 | assert res[0] == exp_result
75 |
76 |
77 | @pytest.mark.parametrize(
78 | "parameters",
79 | [
80 | ({"parameter1": "abc", "parameter2": "defg", "parameter3": "hijkl"}, ["abc", "defg", "hijkl"]),
81 | ({"parameter1": "a", "parameter2": "b", "parameter3": "c"}, ["a", "b", "c"]),
82 | ],
83 | )
84 | def test_named(cursor, parameters) -> None:
85 | cursor.paramstyle = "named"
86 | data, exp_result = parameters
87 | cursor.execute("create temporary table test_named(c1 varchar, c2 varchar, c3 varchar)")
88 | cursor.execute("insert into test_named(c1, c2, c3) values(:parameter1, :parameter2, :parameter3)", data)
89 | cursor.execute("select * from test_named")
90 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
91 | assert len(res) == 1
92 | assert len(res[0]) == len(data)
93 | assert res[0] == exp_result
94 |
95 |
96 | @pytest.mark.parametrize(
97 | "parameters", [(["abc", "defg", "hijkl"], ["abc", "defg", "hijkl"]), (["a", "b", "c"], ["a", "b", "c"])]
98 | )
99 | def test_format(cursor, parameters) -> None:
100 | cursor.paramstyle = "format"
101 | data, exp_result = parameters
102 | cursor.execute("create temporary table test_format(c1 varchar, c2 varchar, c3 varchar)")
103 | cursor.execute("insert into test_format(c1, c2, c3) values(%s, %s, %s)", data)
104 | cursor.execute("select * from test_format")
105 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
106 | assert len(res) == 1
107 | assert len(res[0]) == len(data)
108 | assert res[0] == exp_result
109 |
110 |
111 | @pytest.mark.parametrize(
112 | "parameters",
113 | [
114 | ([["abc", "defg", "hijkl"], ["a", "b", "c"]], [["a", "b", "c"], ["abc", "defg", "hijkl"]]),
115 | ],
116 | )
117 | def test_format_multiple(cursor, parameters) -> None:
118 | cursor.paramstyle = "format"
119 | data, exp_result = parameters
120 | cursor.execute("create temporary table test_format(c1 varchar, c2 varchar, c3 varchar)")
121 | cursor.executemany("insert into test_format(c1, c2, c3) values(%s, %s, %s)", data)
122 | cursor.execute("select * from test_format order by c1")
123 | res: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
124 | assert len(res) == len(exp_result)
125 | for idx, row in enumerate(res):
126 | assert len(row) == len(exp_result[idx])
127 | assert row == exp_result[idx]
128 |
--------------------------------------------------------------------------------
/test/integration/test_redshift_property.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import pytest
4 |
5 | from redshift_connector import InterfaceError, RedshiftProperty
6 |
7 | LOGGER = logging.getLogger(__name__)
8 | LOGGER.propagate = True
9 |
10 |
11 | def test_set_region_from_endpoint_lookup(db_kwargs) -> None:
12 | rp: RedshiftProperty = RedshiftProperty()
13 | rp.put(key="host", value=db_kwargs["host"])
14 | rp.put(key="port", value=db_kwargs["port"])
15 | rp.set_region_from_host()
16 |
17 | expected_region = rp.region
18 | rp.region = None
19 |
20 | rp.set_region_from_endpoint_lookup()
21 | assert rp.region == expected_region
22 |
23 |
24 | @pytest.mark.parametrize("host, port", [("x", 1000), ("amazon.com", -1), ("-o", 5439)])
25 | def test_set_region_from_endpoint_lookup_raises(host, port, caplog) -> None:
26 | import logging
27 |
28 | rp: RedshiftProperty = RedshiftProperty()
29 | rp.put(key="host", value=host)
30 | rp.put(key="port", value=port)
31 | expected_msg: str = "Unable to automatically determine AWS region from host {} port {}. Please check host and port connection parameters are correct.".format(
32 | host, port
33 | )
34 |
35 | with caplog.at_level(logging.DEBUG):
36 | rp.set_region_from_endpoint_lookup()
37 | assert expected_msg in caplog.text
38 |
--------------------------------------------------------------------------------
/test/integration/test_unsupported_datatypes.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 | from warnings import filterwarnings
5 |
6 | import pytest # type: ignore
7 |
8 | import redshift_connector
9 |
10 | if typing.TYPE_CHECKING:
11 | from redshift_connector import Connection
12 |
13 | conf: configparser.ConfigParser = configparser.ConfigParser()
14 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
15 | conf.read(root_path + "/config.ini")
16 |
17 |
18 | @pytest.fixture
19 | def db_table(request, con: redshift_connector.Connection) -> redshift_connector.Connection:
20 | filterwarnings("ignore", "DB-API extension cursor.next()")
21 | filterwarnings("ignore", "DB-API extension cursor.__iter__()")
22 | con.paramstyle = "format" # type: ignore
23 |
24 | def fin() -> None:
25 | try:
26 | with con.cursor() as cursor:
27 | cursor.execute("drop table if exists t1")
28 | except redshift_connector.ProgrammingError:
29 | pass
30 |
31 | request.addfinalizer(fin)
32 | return con
33 |
34 |
35 | # https://docs.aws.amazon.com/redshift/latest/dg/c_unsupported-postgresql-datatypes.html
36 | unsupported_datatypes: typing.List[str] = [
37 | "bytea",
38 | "interval",
39 | "bit",
40 | "bit varying",
41 | "hstore",
42 | "json",
43 | "serial",
44 | "bigserial",
45 | "smallserial",
46 | "money",
47 | "txid_snapshot",
48 | "uuid",
49 | "xml",
50 | "inet",
51 | "cidr",
52 | "macaddr",
53 | "oid",
54 | "regproc",
55 | "regprocedure",
56 | "regoper",
57 | "regoperator",
58 | "regclass",
59 | "regtype",
60 | "regrole",
61 | "regnamespace",
62 | "regconfig",
63 | "regdictionary",
64 | "any",
65 | "anyelement",
66 | "anyarray",
67 | "anynonarray",
68 | "anyenum",
69 | "anyrange",
70 | "cstring",
71 | "internal",
72 | "language_handler",
73 | "fdw_handler",
74 | "tsm_handler",
75 | "record",
76 | "trigger",
77 | "event_trigger",
78 | "pg_ddl_command",
79 | "void",
80 | "opaque",
81 | "int4range",
82 | "int8range",
83 | "numrange",
84 | "tsrange",
85 | "tstzrange",
86 | "daterange",
87 | "tsvector",
88 | "tsquery",
89 | ]
90 |
91 |
92 | @pytest.mark.unsupported_datatype
93 | class TestUnsupportedDataTypes:
94 | @pytest.mark.parametrize("datatype", unsupported_datatypes)
95 | def test_create_table_with_unsupported_datatype_fails(self, db_table: "Connection", datatype: str) -> None:
96 | with db_table.cursor() as cursor:
97 | with pytest.raises(Exception) as exception:
98 | cursor.execute("CREATE TEMPORARY TABLE t1 (a {})".format(datatype))
99 | assert exception.type == redshift_connector.ProgrammingError
100 | assert (
101 | 'Column "t1.a" has unsupported type' in exception.__str__()
102 | or 'type "{}" does not exist'.format(datatype) in exception.__str__()
103 | or 'syntax error at or near "{}"'.format(datatype) in exception.__str__()
104 | )
105 |
106 | @pytest.mark.parametrize("datatype", ["int[]", "int[][]"])
107 | def test_create_table_with_array_datatype_fails(self, db_table: "Connection", datatype: str) -> None:
108 | with db_table.cursor() as cursor:
109 | with pytest.raises(Exception) as exception:
110 | cursor.execute("CREATE TEMPORARY TABLE t1 (a {})".format(datatype))
111 | assert exception.type == redshift_connector.ProgrammingError
112 | assert (
113 | 'Column "t1.a" has unsupported type' in exception.__str__()
114 | or 'type "{}" does not exist'.format(datatype) in exception.__str__()
115 | )
116 |
--------------------------------------------------------------------------------
/test/manual/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/manual/__init__.py
--------------------------------------------------------------------------------
/test/manual/auth/test_aws_credentials.py:
--------------------------------------------------------------------------------
1 | import pytest # type: ignore
2 |
3 | import redshift_connector
4 |
5 | aws_secret_access_key: str = ""
6 | aws_access_key: str = ""
7 | aws_session_token: str = ""
8 |
9 | """
10 | How to use:
11 | 0) If necessary, create a Redshift cluster
12 | 1) In the connect method below, specify the connection parameters
13 | 2) Specify the AWS IAM credentials in the variables above
14 | 3) Manually execute this test
15 | """
16 |
17 |
18 | @pytest.mark.skip(reason="manual")
19 | def test_use_aws_credentials_default_profile() -> None:
20 | with redshift_connector.connect(
21 | iam=True,
22 | database="my_database",
23 | db_user="my_db_user",
24 | password="",
25 | user="",
26 | cluster_identifier="my_cluster_identifier",
27 | region="my_region",
28 | access_key_id=aws_access_key,
29 | secret_access_key=aws_secret_access_key,
30 | session_token=aws_session_token,
31 | ) as con:
32 | with con.cursor() as cursor:
33 | cursor.execute("select 1")
34 |
35 |
36 | """
37 | How to use:
38 | 0) Generate credentials using instructions: https://docs.aws.amazon.com/sdk-for-javascript/v2/developer-guide/getting-your-credentials.html
39 | 1) In the connect method below, specify the connection parameters
40 | 2) Specify the AWS IAM credentials in the variables above
41 | 3) Update iam_helper.py to include the correct minimum version.
42 | 4) Manually execute this test
43 | """
44 |
45 |
46 | @pytest.mark.skip(reason="manual")
47 | def test_use_get_cluster_credentials_with_iam(db_kwargs):
48 | role_name = "groupFederationTest"
49 | with redshift_connector.connect(**db_kwargs) as conn:
50 | with conn.cursor() as cursor:
51 | # https://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_USER.html
52 | cursor.execute('create user "IAMR:{}" with password disable;'.format(role_name))
53 | with redshift_connector.connect(
54 | iam=True,
55 | database="replace_me",
56 | cluster_identifier="replace_me",
57 | region="replace_me",
58 | profile="replace_me", # contains credentials for AssumeRole groupFederationTest
59 | group_federation=True,
60 | ) as con:
61 | with con.cursor() as cursor:
62 | cursor.execute("select 1")
63 | cursor.execute("select current_user")
64 | assert cursor.fetchone()[0] == role_name
65 |
--------------------------------------------------------------------------------
/test/manual/auth/test_redshift_auth_profile.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest # type: ignore
4 |
5 | import redshift_connector
6 |
7 | access_key_id: str = "replace_me"
8 | secret_access_key: str = "replace_me"
9 | session_token: str = "replace_me"
10 |
11 | """
12 | This is a manually executable test. It requires valid AWS credentials.
13 |
14 | A Redshift authentication profile will be created prior to the test.
15 | Following the test, this profile will be deleted.
16 |
17 | Please modify the fixture, handle_redshift_auth_profile, to include any additional arguments to the boto3 client
18 | required for your testing configuration. These may include attributes such as endpoint_url and region. The contents
19 | of the auth profile can be modified as needed.
20 | """
21 |
22 |
23 | creds: typing.Dict[str, str] = {
24 | "aws_access_key_id": "replace_me",
25 | "aws_session_token": "replace_me",
26 | "aws_secret_access_key": "replace_me",
27 | }
28 |
29 | auth_profile_name: str = "PythonManualTest"
30 |
31 |
32 | @pytest.fixture(autouse=True)
33 | def handle_redshift_auth_profile(request, db_kwargs: typing.Dict[str, typing.Union[str, bool, int]]) -> None:
34 | import json
35 |
36 | import boto3 # type: ignore
37 | from botocore.exceptions import ClientError # type: ignore
38 |
39 | payload: str = json.dumps(
40 | {
41 | "host": db_kwargs["host"],
42 | "db_user": db_kwargs["user"],
43 | "max_prepared_statements": 5,
44 | "region": db_kwargs["region"],
45 | "cluster_identifier": db_kwargs["cluster_identifier"],
46 | "db_name": db_kwargs["database"],
47 | }
48 | )
49 |
50 | try:
51 | client = boto3.client(
52 | "redshift",
53 | **{**creds, **{"region_name": typing.cast(str, db_kwargs["region"])}},
54 | verify=False,
55 | )
56 | client.create_authentication_profile(
57 | AuthenticationProfileName=auth_profile_name, AuthenticationProfileContent=payload
58 | )
59 | except ClientError:
60 | raise
61 |
62 | def fin() -> None:
63 | import boto3
64 | from botocore.exceptions import ClientError
65 |
66 | try:
67 | client = boto3.client(
68 | "redshift",
69 | **{**creds, **{"region_name": typing.cast(str, db_kwargs["region"])}},
70 | verify=False,
71 | )
72 | client.delete_authentication_profile(
73 | AuthenticationProfileName=auth_profile_name,
74 | )
75 | except ClientError:
76 | raise
77 |
78 | request.addfinalizer(fin)
79 |
80 |
81 | @pytest.mark.skip(reason="manual")
82 | def test_redshift_auth_profile_can_connect(db_kwargs):
83 | with redshift_connector.connect(
84 | region=db_kwargs["region"],
85 | access_key_id=creds["aws_access_key_id"],
86 | secret_access_key=creds["aws_secret_access_key"],
87 | session_token=creds["aws_session_token"],
88 | auth_profile=auth_profile_name,
89 | iam=True,
90 | ) as conn:
91 | assert conn.user == "IAM:{}".format(db_kwargs["user"]).encode()
92 | assert conn.max_prepared_statements == 5
93 |
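When debugging this test, the stored profile can be read back to confirm its contents; a minimal sketch reusing the creds dict and auth_profile_name defined above (the region value is a placeholder, and the response is parsed per the Redshift DescribeAuthenticationProfiles API):

    import json

    import boto3

    client = boto3.client("redshift", region_name="us-east-1", **creds)
    response = client.describe_authentication_profiles(AuthenticationProfileName=auth_profile_name)
    # AuthenticationProfileContent holds the JSON payload stored by the fixture
    print(json.loads(response["AuthenticationProfiles"][0]["AuthenticationProfileContent"]))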
--------------------------------------------------------------------------------
/test/manual/plugin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/manual/plugin/__init__.py
--------------------------------------------------------------------------------
/test/manual/plugin/test_browser_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 |
5 | import pytest # type: ignore
6 |
7 | import redshift_connector
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
11 | conf.read(root_path + "/config.ini")
12 |
13 | BROWSER_CREDENTIAL_PROVIDERS: typing.List[str] = [
14 | "jumpcloud_browser_idp",
15 | "okta_browser_idp",
16 | "azure_browser_idp",
17 | # "jwt_azure_v2_idp",
18 | # "jwt_google_idp",
19 | "ping_browser_idp",
20 | "redshift_native_browser_azure_oauth2_idp",
21 | "redshift_browser_idc",
22 | "redshift_idp_token_auth_plugin",
23 | ]
24 |
25 | """
26 | How to use:
27 | 0) If necessary, create a Redshift cluster and configure it for use with the desired IdP
28 | 1) In config.ini specify the connection parameters required by the desired IdP fixture in test/integration/conftest.py
29 | 2) Ensure browser cookies have been cleared
30 | 3) Manually execute the tests in this file, providing the necessary login information in the web browser
31 | """
32 |
33 |
34 | @pytest.mark.skip(reason="manual")
35 | @pytest.mark.parametrize("idp_arg", BROWSER_CREDENTIAL_PROVIDERS, indirect=True)
36 | def test_browser_credentials_provider_can_auth(idp_arg):
37 | with redshift_connector.connect(**idp_arg) as conn:
38 | with conn.cursor() as cursor:
39 | cursor.execute("select 1;")
40 |
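The idp_arg fixtures above ultimately build keyword arguments for redshift_connector.connect. For orientation, a minimal sketch of an equivalent direct browser SAML connection, with placeholder values for the cluster and IdP login URL:

    import redshift_connector

    with redshift_connector.connect(
        iam=True,
        database="dev",
        cluster_identifier="my-cluster",
        region="us-east-1",
        credentials_provider="BrowserSamlCredentialsProvider",
        login_url="https://example.okta.com/home/amazon_aws/replace_me/sso/saml",  # placeholder
        idp_response_timeout=60,
        listen_port=7890,
    ) as conn:
        with conn.cursor() as cursor:
            cursor.execute("select 1;")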
--------------------------------------------------------------------------------
/test/manual/test_redshift_custom_domain.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import redshift_connector
4 | from redshift_connector.idp_auth_helper import SupportedSSLMode
5 |
6 | """
7 | These functional tests ensure connections to a provisioned Redshift cluster with a custom domain name can be
8 | established when using various authentication methods.
9 |
10 | Pre-requisites:
11 | 1) Redshift provisioned cluster configuration
12 | 2) Existing custom domain association with the cluster created in step 1
13 | """
14 |
15 |
16 | @pytest.mark.skip(reason="manual")
17 | @pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
18 | def test_native_connect(provisioned_cname_db_kwargs, sslmode) -> None:
19 | # this test requires that the aws default profile contains valid credentials granting permission for
20 | # redshift:GetClusterCredentials (only called from this test method)
21 | import boto3
22 |
23 | profile = "default"
24 | client = boto3.client(
25 | service_name="redshift",
26 | region_name="eu-north-1",
27 | )
28 | # fetch cluster credentials and pass them as driver connect parameters
29 | response = client.get_cluster_credentials(
30 | CustomDomainName=provisioned_cname_db_kwargs["host"], DbUser=provisioned_cname_db_kwargs["db_user"]
31 | )
32 |
33 | provisioned_cname_db_kwargs["password"] = response["DbPassword"]
34 | provisioned_cname_db_kwargs["user"] = response["DbUser"]
35 | provisioned_cname_db_kwargs["profile"] = profile
36 | provisioned_cname_db_kwargs["ssl"] = True
37 | provisioned_cname_db_kwargs["sslmode"] = sslmode.value
38 |
39 | with redshift_connector.connect(**provisioned_cname_db_kwargs):
40 | pass
41 |
42 |
43 | @pytest.mark.skip(reason="manual")
44 | @pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
45 | def test_iam_connect(provisioned_cname_db_kwargs, sslmode) -> None:
46 | # this test requires that the aws default profile contains valid credentials granting permissions for
47 | # redshift:GetClusterCredentials (called from driver)
48 | # redshift:DescribeClusters (called from driver)
49 | # redshift:DescribeCustomDomainAssociations (called from driver)
50 | provisioned_cname_db_kwargs["iam"] = True
51 | provisioned_cname_db_kwargs["profile"] = "default"
52 | provisioned_cname_db_kwargs["auto_create"] = True
53 | provisioned_cname_db_kwargs["ssl"] = True
54 | provisioned_cname_db_kwargs["sslmode"] = sslmode.value
55 | with redshift_connector.connect(**provisioned_cname_db_kwargs):
56 | pass
57 |
58 |
59 | def test_idp_connect(okta_idp, provisioned_cname_db_kwargs) -> None:
60 | # todo
61 | pass
62 |
63 |
64 | @pytest.mark.skip(reason="manual")
65 | def test_nlb_connect() -> None:
66 | args = {
67 | "iam": True,
68 | # "access_key_id": "xxx",
69 | # "secret_access_key": "xxx",
70 | "cluster_identifier": "replace-me",
71 | "region": "us-east-1",
72 | "host": "replace-me",
73 | "port": 5439,
74 | "database": "dev",
75 | "db_user": "replace-me",
76 | }
77 | with redshift_connector.connect(**args): # type: ignore
78 | pass
79 |
80 |
81 | @pytest.mark.skip(reason="manual")
82 | @pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
83 | def test_serverless_iam_cname_connect(sslmode, serverless_cname_db_kwargs):
84 | serverless_cname_db_kwargs["iam"] = True
85 | serverless_cname_db_kwargs["profile"] = "default"
86 | serverless_cname_db_kwargs["auto_create"] = True
87 | serverless_cname_db_kwargs["ssl"] = True
88 | serverless_cname_db_kwargs["sslmode"] = sslmode.value
89 |
90 | with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
91 | with conn.cursor() as cursor:
92 | cursor.execute("select current_user")
93 | print(cursor.fetchone())
94 |
95 |
96 | @pytest.mark.skip(reason="manual")
97 | @pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
98 | def test_serverless_cname_connect(sslmode, serverless_cname_db_kwargs):
99 | # this test requires that the aws default profile contains valid credentials granting permission for
100 | # redshift-serverless:GetCredentials (only called from this test method)
101 | import boto3
102 |
103 | profile = "default"
104 | client = boto3.client(
105 | service_name="redshift-serverless",
106 | region_name="us-east-1",
107 | )
108 | # fetch cluster credentials and pass them as driver connect parameters
109 | response = client.get_credentials(
110 | customDomainName=serverless_cname_db_kwargs["host"], dbName=serverless_cname_db_kwargs["database"]
111 | )
112 |
113 | serverless_cname_db_kwargs["sslmode"] = sslmode.value
114 | serverless_cname_db_kwargs["ssl"] = True
115 | serverless_cname_db_kwargs["user"] = response["dbUser"]
116 | serverless_cname_db_kwargs["password"] = response["dbPassword"]
117 | serverless_cname_db_kwargs["profile"] = profile
118 |
119 | with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
120 | with conn.cursor() as cursor:
121 | cursor.execute("select current_user")
122 | print(cursor.fetchone())
123 |
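Before running these tests, the custom domain association can be confirmed with the same API the driver calls internally; a minimal sketch (region and domain values are placeholders):

    import boto3

    client = boto3.client("redshift", region_name="eu-north-1")
    # requires redshift:DescribeCustomDomainAssociations, as noted in test_iam_connect
    response = client.describe_custom_domain_associations(
        CustomDomainName="replace-me",  # same value as provisioned_cname_db_kwargs["host"]
    )
    print(response)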
--------------------------------------------------------------------------------
/test/manual/test_redshift_serverless.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import redshift_connector
4 |
5 | """
6 | These functional tests ensure connections to Redshift serverless can be established when
7 | using various authentication methods.
8 |
9 | Please note these pre-requisites were documented while this feature was in public preview,
10 | and are subject to change.
11 |
12 | Pre-requisites:
13 | 1) Redshift serverless configuration
14 | 2) EC2 instance configured for accessing Redshift serverless (i.e. in a compatible VPC and subnet)
15 | 3) Perform a sanity check using psql to ensure a Redshift serverless connection can be established
16 | 4) EC2 instance has Python installed
17 | 5) Clone redshift_connector on the EC2 instance and install it
18 |
19 | How to use:
20 | 1) Populate config.ini with the Redshift serverless endpoint and user authentication information
21 | 2) Run this file with pytest
22 | """
23 |
24 |
25 | @pytest.mark.skip(reason="manual")
26 | def test_native_auth(serverless_native_db_kwargs) -> None:
27 | with redshift_connector.connect(**serverless_native_db_kwargs):
28 | pass
29 |
30 |
31 | @pytest.mark.skip(reason="manual")
32 | def test_iam_auth(serverless_iam_db_kwargs) -> None:
33 | with redshift_connector.connect(**serverless_iam_db_kwargs):
34 | pass
35 |
36 |
37 | @pytest.mark.skip(reason="manual")
38 | def test_idp_auth(okta_idp) -> None:
39 | okta_idp["host"] = "my_redshift_serverless_endpoint"
40 |
41 | with redshift_connector.connect(**okta_idp):
42 | pass
43 |
44 |
45 | @pytest.mark.skip()
46 | def test_connection_without_host(serverless_iam_db_kwargs) -> None:
47 | serverless_iam_db_kwargs["is_serverless"] = True
48 | serverless_iam_db_kwargs["host"] = None
49 | serverless_iam_db_kwargs["serverless_work_group"] = "default"
50 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn:
51 | with conn.cursor() as cursor:
52 | cursor.execute("select 1")
53 |
54 |
55 | @pytest.mark.skip()
56 | def test_nlb_connection(serverless_iam_db_kwargs) -> None:
57 | serverless_iam_db_kwargs["is_serverless"] = True
58 | serverless_iam_db_kwargs["host"] = "my_nlb_endpoint"
59 | serverless_iam_db_kwargs["serverless_work_group"] = "default"
60 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn:
61 | with conn.cursor() as cursor:
62 | cursor.execute("select 1")
63 |
64 |
65 | @pytest.mark.skip()
66 | def test_vpc_endpoint_connection(serverless_iam_db_kwargs) -> None:
67 | serverless_iam_db_kwargs["is_serverless"] = True
68 | serverless_iam_db_kwargs["host"] = "my_vpc_endpoint"
69 | serverless_iam_db_kwargs["serverless_work_group"] = "default"
70 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn:
71 | with conn.cursor() as cursor:
72 | cursor.execute("select 1")
73 |
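The last three tests above all set is_serverless and serverless_work_group; a minimal standalone sketch of the host-less variant, with placeholder values:

    import redshift_connector

    with redshift_connector.connect(
        iam=True,
        is_serverless=True,
        serverless_work_group="default",
        host=None,
        database="dev",
        region="us-east-1",
        profile="default",  # AWS profile with redshift-serverless permissions
    ) as conn:
        with conn.cursor() as cursor:
            cursor.execute("select 1")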
--------------------------------------------------------------------------------
/test/performance/bulk_insert_performance.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import csv
3 | import os
4 | import time
5 | import typing
6 |
7 | import redshift_connector
8 |
9 | conf: configparser.ConfigParser = configparser.ConfigParser()
10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
11 | conf.read(root_path + "/config.ini")
12 |
13 | root_path = os.path.dirname(os.path.abspath(__file__))
14 |
15 |
16 | def perf_conn():
17 | return redshift_connector.connect(
18 | database=conf.get("ci-cluster", "database"),
19 | host=conf.get("ci-cluster", "host"),
20 | port=conf.getint("default-test", "port"),
21 | user=conf.get("ci-cluster", "test_user"),
22 | password=conf.get("ci-cluster", "test_password"),
23 | ssl=True,
24 | sslmode=conf.get("default-test", "sslmode"),
25 | iam=False,
26 | )
27 |
28 |
29 | print("Reading data from csv file")
30 | with open(root_path + "/bulk_insert_data.csv", "r", encoding="utf8") as csv_data:
31 | reader = csv.reader(csv_data, delimiter="\t")
32 | next(reader)
33 | data = []
34 | for row in reader:
35 | data.append(row)
36 | print("Inserting {} rows having {} columns".format(len(data), len(data[0])))
37 |
38 | print("\nCursor.insert_data_bulk()..")
39 | for batch_size in [1e2, 1e3, 1e4]:
40 | with perf_conn() as conn:
41 | with conn.cursor() as cursor:
42 | cursor.execute("drop table if exists bulk_insert_perf;")
43 | cursor.execute("create table bulk_insert_perf (c1 int, c2 int, c3 int, c4 int, c5 int);")
44 | start_time: float = time.time()
45 | cursor.insert_data_bulk(
46 | filename=root_path + "/bulk_insert_data.csv",
47 | table_name="bulk_insert_perf",
48 | parameter_indices=[0, 1, 2, 3, 4],
49 | column_names=["c1", "c2", "c3", "c4", "c5"],
50 | delimiter="\t",
51 | batch_size=batch_size,
52 | )
53 |
54 | print("batch_size={0} {1} seconds.".format(batch_size, time.time() - start_time))
55 |
56 | print("Cursor.executemany()")
57 | with perf_conn() as conn:
58 | with conn.cursor() as cursor:
59 | cursor.execute("drop table if exists bulk_insert_perf;")
60 | cursor.execute("create table bulk_insert_perf (c1 int, c2 int, c3 int, c4 int, c5 int);")
61 | start_time = time.time()
62 | cursor.executemany("insert into bulk_insert_perf(c1, c2, c3, c4, c5) values(%s, %s, %s, %s, %s)", data)
63 | print("{0} seconds.".format(time.time() - start_time))
64 |
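If bulk_insert_data.csv needs to be regenerated, the script only expects a tab-delimited file with a header row and five integer columns; a minimal sketch:

    import csv
    import random

    with open("bulk_insert_data.csv", "w", newline="", encoding="utf8") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(["c1", "c2", "c3", "c4", "c5"])  # header row, skipped via next(reader) above
        for _ in range(10000):
            writer.writerow([random.randint(0, 1000) for _ in range(5)])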
--------------------------------------------------------------------------------
/test/performance/performance.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import time
4 | import typing
5 |
6 | import redshift_connector
7 |
8 | conf: configparser.ConfigParser = configparser.ConfigParser()
9 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10 | conf.read(root_path + "/config.ini")
11 |
12 | root_path = os.path.dirname(os.path.abspath(__file__))
13 | sql: typing.TextIO = open(root_path + "/test.sql", "r", encoding="utf8")
14 | sqls: typing.List[str] = sql.readlines()
15 | sqls = [_sql.replace("\n", "") for _sql in sqls]
16 | sql.close()
17 |
18 | conn: redshift_connector.Connection = redshift_connector.connect(
19 | database=conf.get("ci-cluster", "database"),
20 | host=conf.get("ci-cluster", "host"),
21 | port=conf.getint("default-test", "port"),
22 | user=conf.get("ci-cluster", "test_user"),
23 | password=conf.get("ci-cluster", "test_password"),
24 | ssl=True,
25 | sslmode=conf.get("default-test", "sslmode"),
26 | iam=False,
27 | )
28 |
29 | cursor: redshift_connector.Cursor = conn.cursor()
30 | for _sql in sqls:
31 | cursor.execute(_sql)
32 |
33 | result: typing.Tuple[typing.List[int], ...] = cursor.fetchall()
34 | print("fetch {result} rows".format(result=len(result)))
35 |
36 | print("start calculate fetch time")
37 | for val in [True, False]:
38 | print("merge_socket_read={val}".format(val=val))
39 | start_time: float = time.time()
40 | cursor.execute("select * from performance", merge_socket_read=val)
41 | results: typing.Tuple[typing.List[int], ...] = cursor.fetchall()
42 | print("Took {0} seconds.".format(time.time() - start_time))
43 | print("fetch {result} rows".format(result=len(results)))
44 |
45 | cursor.close()
46 | conn.commit()
47 |
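The ad hoc timing above can be factored into a small reusable helper; a minimal sketch:

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(label: str):
        # prints the wall-clock duration of the enclosed block
        start = time.time()
        yield
        print("{0}: {1:.3f} seconds".format(label, time.time() - start))

    # usage:
    # with timed("merge_socket_read=True"):
    #     cursor.execute("select * from performance", merge_socket_read=True)
    #     cursor.fetchall()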
--------------------------------------------------------------------------------
/test/performance/protocol_perf_test.sql:
--------------------------------------------------------------------------------
1 | /* To use: PGPASSWORD={password} psql --host={host} --port 5439 --user={user} --dbname={db} -f protocol_perf_test.sql */
2 | /* All used tables have 5 columns, and all columns within the same table have the same datatype */
3 |
4 | drop table if exists perf_varchar;
5 | create table perf_varchar (val1 varchar, val2 varchar, val3 varchar, val4 varchar, val5 varchar);
6 |
7 | drop table if exists perf_time;
8 | create table perf_time (val1 time, val2 time, val3 time, val4 time, val5 time);
9 |
10 | drop table if exists perf_timetz;
11 | create table perf_timetz (val1 timetz, val2 timetz, val3 timetz, val4 timetz, val5 timetz);
12 |
13 | drop table if exists perf_timestamptz;
14 | create table perf_timestamptz (val1 timestamptz, val2 timestamptz, val3 timestamptz, val4 timestamptz, val5 timestamptz);
15 |
16 | insert into perf_varchar values('abcd¬µ3kt¿abcdÆgda123~Øasd', 'abcd¬µ3kt¿abcdÆgda123~Øasd', 'abcd¬µ3kt¿abcdÆgda123~Øasd', 'abcd¬µ3kt¿abcdÆgda123~Øasd', 'abcd¬µ3kt¿abcdÆgda123~Øasd');
17 | insert into perf_varchar (select * from perf_varchar);
18 | insert into perf_varchar (select * from perf_varchar);
19 | insert into perf_varchar (select * from perf_varchar);
20 | insert into perf_varchar (select * from perf_varchar);
21 | insert into perf_varchar (select * from perf_varchar);
22 | insert into perf_varchar (select * from perf_varchar);
23 | insert into perf_varchar (select * from perf_varchar);
24 | insert into perf_varchar (select * from perf_varchar);
25 | insert into perf_varchar (select * from perf_varchar);
26 | insert into perf_varchar (select * from perf_varchar);
27 | insert into perf_varchar (select * from perf_varchar);
28 | insert into perf_varchar (select * from perf_varchar);
29 | insert into perf_varchar (select * from perf_varchar);
30 | insert into perf_varchar (select * from perf_varchar);
31 | insert into perf_varchar (select * from perf_varchar);
32 | insert into perf_varchar (select * from perf_varchar);
33 | insert into perf_varchar (select * from perf_varchar);
34 | insert into perf_varchar (select * from perf_varchar);
35 | insert into perf_varchar (select * from perf_varchar);
36 | insert into perf_varchar (select * from perf_varchar);
37 | insert into perf_varchar (select * from perf_varchar);
38 | insert into perf_varchar (select * from perf_varchar);
39 |
40 | insert into perf_time values('12:13:14', '12:13:14', '12:13:14', '12:13:14', '12:13:14');
41 | insert into perf_time (select * from perf_time);
42 | insert into perf_time (select * from perf_time);
43 | insert into perf_time (select * from perf_time);
44 | insert into perf_time (select * from perf_time);
45 | insert into perf_time (select * from perf_time);
46 | insert into perf_time (select * from perf_time);
47 | insert into perf_time (select * from perf_time);
48 | insert into perf_time (select * from perf_time);
49 | insert into perf_time (select * from perf_time);
50 | insert into perf_time (select * from perf_time);
51 | insert into perf_time (select * from perf_time);
52 | insert into perf_time (select * from perf_time);
53 | insert into perf_time (select * from perf_time);
54 | insert into perf_time (select * from perf_time);
55 | insert into perf_time (select * from perf_time);
56 | insert into perf_time (select * from perf_time);
57 | insert into perf_time (select * from perf_time);
58 | insert into perf_time (select * from perf_time);
59 | insert into perf_time (select * from perf_time);
60 | insert into perf_time (select * from perf_time);
61 | insert into perf_time (select * from perf_time);
62 | insert into perf_time (select * from perf_time);
63 |
64 | insert into perf_timetz values('20:13:14.123456', '20:13:14.123456', '20:13:14.123456', '20:13:14.123456', '20:13:14.123456');
65 | insert into perf_timetz (select * from perf_timetz);
66 | insert into perf_timetz (select * from perf_timetz);
67 | insert into perf_timetz (select * from perf_timetz);
68 | insert into perf_timetz (select * from perf_timetz);
69 | insert into perf_timetz (select * from perf_timetz);
70 | insert into perf_timetz (select * from perf_timetz);
71 | insert into perf_timetz (select * from perf_timetz);
72 | insert into perf_timetz (select * from perf_timetz);
73 | insert into perf_timetz (select * from perf_timetz);
74 | insert into perf_timetz (select * from perf_timetz);
75 | insert into perf_timetz (select * from perf_timetz);
76 | insert into perf_timetz (select * from perf_timetz);
77 | insert into perf_timetz (select * from perf_timetz);
78 | insert into perf_timetz (select * from perf_timetz);
79 | insert into perf_timetz (select * from perf_timetz);
80 | insert into perf_timetz (select * from perf_timetz);
81 | insert into perf_timetz (select * from perf_timetz);
82 | insert into perf_timetz (select * from perf_timetz);
83 |
84 | insert into perf_timestamptz values('1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16');
85 | insert into perf_timestamptz (select * from perf_timestamptz);
86 | insert into perf_timestamptz (select * from perf_timestamptz);
87 | insert into perf_timestamptz (select * from perf_timestamptz);
88 | insert into perf_timestamptz (select * from perf_timestamptz);
89 | insert into perf_timestamptz (select * from perf_timestamptz);
90 | insert into perf_timestamptz (select * from perf_timestamptz);
91 | insert into perf_timestamptz (select * from perf_timestamptz);
92 | insert into perf_timestamptz (select * from perf_timestamptz);
93 | insert into perf_timestamptz (select * from perf_timestamptz);
94 | insert into perf_timestamptz (select * from perf_timestamptz);
95 | insert into perf_timestamptz (select * from perf_timestamptz);
96 | insert into perf_timestamptz (select * from perf_timestamptz);
97 | insert into perf_timestamptz (select * from perf_timestamptz);
98 | insert into perf_timestamptz (select * from perf_timestamptz);
99 | insert into perf_timestamptz (select * from perf_timestamptz);
100 | insert into perf_timestamptz (select * from perf_timestamptz);
101 | insert into perf_timestamptz (select * from perf_timestamptz);
102 | insert into perf_timestamptz (select * from perf_timestamptz);
103 |
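Resulting row counts: each insert ... (select * from ...) statement doubles its table, so from one seed row perf_varchar and perf_time (22 self-inserts each) end with 2^22 = 4,194,304 rows, while perf_timetz and perf_timestamptz (18 self-inserts each) end with 2^18 = 262,144 rows.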
--------------------------------------------------------------------------------
/test/performance/test.sql:
--------------------------------------------------------------------------------
1 | drop table if exists performance CASCADE;
2 | create table performance(c1 integer);
3 | insert into performance values(1);
4 | insert into performance values(2);
5 | insert into performance values(3);
6 | insert into performance values(4);
7 | insert into performance values(5);
8 | insert into performance values(6);
9 | insert into performance values(7);
10 | insert into performance values(8);
11 | insert into performance values(9);
12 | insert into performance select * from performance;
13 | insert into performance select * from performance;
14 | insert into performance select * from performance;
15 | insert into performance select * from performance;
16 | insert into performance select * from performance;
17 | insert into performance select * from performance;
18 | insert into performance select * from performance;
19 | insert into performance select * from performance;
20 | insert into performance select * from performance;
21 | insert into performance select * from performance;
22 | insert into performance select * from performance;
23 | insert into performance select * from performance;
24 | insert into performance select * from performance;
25 | insert into performance select * from performance;
26 | insert into performance select * from performance;
27 | insert into performance select * from performance;
28 | insert into performance select * from performance;
29 | insert into performance select * from performance;
30 | insert into performance select * from performance;
31 | insert into performance select * from performance;
32 | insert into performance select * from performance LIMIT 10000000;
33 | insert into performance select * from performance LIMIT 10000000;
34 | insert into performance select * from performance LIMIT 10000000;
35 | insert into performance select * from performance LIMIT 10000000;
36 | insert into performance select * from performance LIMIT 1457280;
37 | select count(*) from performance;
38 |
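Expected final count: 9 seed rows doubled 20 times gives 9 * 2^20 = 9,437,184 rows; the first insert ... LIMIT 10000000 copies all 9,437,184 existing rows, the next three add 10,000,000 rows each, and the final LIMIT 1457280 adds 1,457,280, so select count(*) should return 50,331,648.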
--------------------------------------------------------------------------------
/test/unit/__init__.py:
--------------------------------------------------------------------------------
1 | from .mocks import MockCredentialsProvider
2 |
--------------------------------------------------------------------------------
/test/unit/auth/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/auth/__init__.py
--------------------------------------------------------------------------------
/test/unit/auth/test_aws_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from unittest.mock import MagicMock, patch
3 |
4 | import pytest # type: ignore
5 |
6 | from redshift_connector import InterfaceError
7 | from redshift_connector.auth.aws_credentials_provider import (
8 | AWSCredentialsProvider,
9 | AWSDirectCredentialsHolder,
10 | )
11 | from redshift_connector.credentials_holder import AWSProfileCredentialsHolder
12 | from redshift_connector.redshift_property import RedshiftProperty
13 |
14 |
15 | def _make_aws_credentials_obj_with_profile() -> AWSCredentialsProvider:
16 | cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
17 | rp: RedshiftProperty = RedshiftProperty()
18 | profile_name: str = "myProfile"
19 |
20 | rp.profile = profile_name
21 |
22 | cred_provider.add_parameter(rp)
23 | return cred_provider
24 |
25 |
26 | def _make_aws_credentials_obj_with_credentials() -> AWSCredentialsProvider:
27 | cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
28 | rp: RedshiftProperty = RedshiftProperty()
29 | access_key_id: str = "my_access"
30 | secret_key: str = "my_secret"
31 | session_token: str = "my_session"
32 |
33 | rp.access_key_id = access_key_id
34 | rp.secret_access_key = secret_key
35 | rp.session_token = session_token
36 |
37 | cred_provider.add_parameter(rp)
38 | return cred_provider
39 |
40 |
41 | def test_create_aws_credentials_provider_with_profile() -> None:
42 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
43 | assert cred_provider.profile == "myProfile"
44 | assert cred_provider.access_key_id is None
45 | assert cred_provider.secret_access_key is None
46 | assert cred_provider.session_token is None
47 |
48 |
49 | def test_create_aws_credentials_provider_with_credentials() -> None:
50 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
51 | assert cred_provider.profile is None
52 | assert cred_provider.access_key_id == "my_access"
53 | assert cred_provider.secret_access_key == "my_secret"
54 | assert cred_provider.session_token == "my_session"
55 |
56 |
57 | def test_get_cache_key_with_profile() -> None:
58 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
59 | assert cred_provider.get_cache_key() == hash(cred_provider.profile)
60 |
61 |
62 | def test_get_cache_key_with_credentials() -> None:
63 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
64 | assert cred_provider.get_cache_key() == hash("my_access")
65 |
66 |
67 | def test_get_credentials_checks_cache_first(mocker) -> None:
68 | mocked_credential_holder = MagicMock()
69 |
70 | def mock_set_cache(cp: AWSCredentialsProvider, key: str = "tomato") -> None:
71 | cp.cache[key] = mocked_credential_holder # type: ignore
72 |
73 | cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
74 | mocker.patch("redshift_connector.auth.AWSCredentialsProvider.get_cache_key", return_value="tomato")
75 |
76 | with patch("redshift_connector.auth.AWSCredentialsProvider.refresh") as mocked_refresh:
77 | mocked_refresh.side_effect = mock_set_cache(cred_provider)  # type: ignore # called eagerly: seeds the cache so get_credentials() finds the holder without refreshing
78 | get_cache_key_spy = mocker.spy(cred_provider, "get_cache_key")
79 |
80 | assert cred_provider.get_credentials() == mocked_credential_holder
81 |
82 | assert get_cache_key_spy.called is True
83 | assert get_cache_key_spy.call_count == 1
84 |
85 |
86 | def test_get_credentials_refresh_error_is_raised(mocker) -> None:
87 | cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
88 | mocker.patch("redshift_connector.auth.AWSCredentialsProvider.get_cache_key", return_value="tomato")
89 | expected_exception = "Refreshing IdP credentials failed"
90 |
91 | with patch("redshift_connector.auth.AWSCredentialsProvider.refresh") as mocked_refresh:
92 | mocked_refresh.side_effect = Exception(expected_exception)
93 |
94 | with pytest.raises(InterfaceError, match=expected_exception):
95 | cred_provider.get_credentials()
96 |
97 |
98 | def test_refresh_uses_profile_if_present(mocker) -> None:
99 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
100 | mocked_boto_session: MagicMock = MagicMock()
101 |
102 | with patch("boto3.Session", return_value=mocked_boto_session):
103 | cred_provider.refresh()
104 |
105 | assert hash("myProfile") in cred_provider.cache
106 | assert isinstance(cred_provider.cache[hash("myProfile")], AWSProfileCredentialsHolder)
107 | assert typing.cast(AWSProfileCredentialsHolder, cred_provider.cache[hash("myProfile")]).profile == "myProfile"
108 |
109 |
110 | def test_refresh_uses_credentials_if_present(mocker) -> None:
111 | cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
112 | mocked_boto_session: MagicMock = MagicMock()
113 |
114 | with patch("boto3.Session", return_value=mocked_boto_session):
115 | cred_provider.refresh()
116 |
117 | assert hash("my_access") in cred_provider.cache
118 | assert isinstance(cred_provider.cache[hash("my_access")], AWSDirectCredentialsHolder)
119 | assert (
120 | typing.cast(AWSDirectCredentialsHolder, cred_provider.cache[hash("my_access")]).access_key_id == "my_access"
121 | )
122 | assert (
123 | typing.cast(AWSDirectCredentialsHolder, cred_provider.cache[hash("my_access")]).secret_access_key
124 | == "my_secret"
125 | )
126 | assert (
127 | typing.cast(AWSDirectCredentialsHolder, cred_provider.cache[hash("my_access")]).session_token
128 | == "my_session"
129 | )
130 |
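Taken together, these tests pin down the provider's lookup flow; a paraphrased sketch of the behavior they assert (not the library source):

    def get_credentials_flow(provider):
        key = provider.get_cache_key()  # hash(profile) when set, else hash(access_key_id)
        if key not in provider.cache:
            provider.refresh()  # populates provider.cache[key]; failures surface as InterfaceError
        return provider.cache[key]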
--------------------------------------------------------------------------------
/test/unit/datatype/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/datatype/__init__.py
--------------------------------------------------------------------------------
/test/unit/datatype/test_oids.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 |
5 | from redshift_connector.utils.oids import RedshiftOID, get_datatype_name
6 |
7 | all_oids: typing.List[int] = [d for d in RedshiftOID]
8 | all_datatypes: typing.List[typing.Tuple[int, str]] = [(d, d.name) for d in RedshiftOID]
9 |
10 |
11 | @pytest.mark.parametrize("oid", all_oids)
12 | def test_RedshiftOID_has_type_int(oid):
13 | assert isinstance(oid, int)
14 |
15 |
16 | @pytest.mark.parametrize("oid, datatype", all_datatypes)
17 | def test_get_datatype_name(oid, datatype):
18 | assert get_datatype_name(oid) == datatype
19 |
20 |
21 | def test_get_datatype_name_invalid_oid_raises() -> None:
22 | with pytest.raises(ValueError, match="not a valid RedshiftOID"):
23 | get_datatype_name(-9)
24 |
25 | def test_modify_oid() -> None:
26 | with pytest.raises(AttributeError, match="Cannot modify OID constant 'VARCHAR'"):
27 | RedshiftOID.VARCHAR = 0
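For illustration, the mapping these tests exercise can be seen directly; a minimal sketch:

    from redshift_connector.utils.oids import RedshiftOID, get_datatype_name

    assert isinstance(RedshiftOID.VARCHAR, int)  # members behave as ints
    assert get_datatype_name(RedshiftOID.VARCHAR) == "VARCHAR"  # name lookup by OID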
--------------------------------------------------------------------------------
/test/unit/datatype/test_sql_types.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 |
5 | from redshift_connector.utils.sql_types import SQLType, get_sql_type_name
6 |
7 | all_oids: typing.List[int] = [d for d in SQLType]
8 | all_datatypes: typing.List[typing.Tuple[int, str]] = [(d, d.name) for d in SQLType]
9 |
10 |
11 | @pytest.mark.parametrize("oid", all_oids)
12 | def test_sql_type_has_type_int(oid):
13 | assert isinstance(oid, int)
14 |
15 |
16 | @pytest.mark.parametrize("oid, datatype", all_datatypes)
17 | def test_get_sql_type_name(oid, datatype):
18 | assert get_sql_type_name(oid) == datatype
19 |
20 |
21 | def test_get_datatype_name_invalid_oid_raises() -> None:
22 | with pytest.raises(ValueError, match="not a valid SQLType"):
23 | get_sql_type_name(3000)
24 |
25 | def test_modify_sql_type() -> None:
26 | with pytest.raises(AttributeError, match="Cannot modify SQL type constant 'SQL_VARCHAR'"):
27 | SQLType.SQL_VARCHAR = 0
--------------------------------------------------------------------------------
/test/unit/helpers/__init__.py:
--------------------------------------------------------------------------------
1 | from .idp_helpers import make_redshift_property
2 |
--------------------------------------------------------------------------------
/test/unit/helpers/idp_helpers.py:
--------------------------------------------------------------------------------
1 | from redshift_connector import RedshiftProperty
2 | from redshift_connector.config import ClientProtocolVersion
3 |
4 |
5 | def make_redshift_property() -> RedshiftProperty:
6 | rp: RedshiftProperty = RedshiftProperty()
7 | rp.user_name = "mario@luigi.com"
8 | rp.password = "bowser"
9 | rp.db_name = "dev"
10 | rp.cluster_identifier = "something"
11 | rp.idp_host = "8000"
12 | rp.duration = 100
13 | rp.preferred_role = "analyst"
14 | rp.ssl_insecure = False
15 | rp.db_user = "primary"
16 | rp.db_groups = ["employees"]
17 | rp.force_lowercase = True
18 | rp.auto_create = False
19 | rp.region = "us-west-1"
20 | rp.principal = "arn:aws:iam::123456789012:user/Development/product_1234/*"
21 | rp.client_protocol_version = ClientProtocolVersion.BASE_SERVER
22 | return rp
23 |
--------------------------------------------------------------------------------
/test/unit/mocks/__init__.py:
--------------------------------------------------------------------------------
1 | from .mock_external_credential_provider import MockCredentialsProvider
2 |
--------------------------------------------------------------------------------
/test/unit/mocks/mock_external_credential_provider.py:
--------------------------------------------------------------------------------
1 | from redshift_connector.plugin import SamlCredentialsProvider
2 |
3 |
4 | class MockCredentialsProvider(SamlCredentialsProvider):
5 | def get_saml_assertion(self: "SamlCredentialsProvider"):
6 | return "mocked"
7 |
--------------------------------------------------------------------------------
/test/unit/mocks/mock_socket.py:
--------------------------------------------------------------------------------
1 | import socket
2 | import typing
3 |
4 |
5 | class MockSocket(socket.socket):
6 | mocked_data: typing.Optional[bytes] = None
7 |
8 | def __init__(self, family=-1, type=-1, proto=-1, fileno=None):
9 | pass
10 |
11 | def __exit__(self, exc_type, exc_val, exc_tb):
12 | pass
13 |
14 | def setsockopt(self, level, optname, value):
15 | pass
16 |
17 | def bind(self, address):
18 | pass
19 |
20 | def listen(self, __backlog=...):
21 | pass
22 |
23 | def settimeout(self, value):
24 | pass
25 |
26 | def accept(self):
27 | return (MockSocket(), "127.0.0.1")
28 |
29 | def recv(self, bufsize, flags=...):
30 | return self.mocked_data
31 |
32 | def close(self) -> None:
33 | pass
34 |
35 | def send(self, *args) -> None: # type: ignore
36 | pass
37 |
--------------------------------------------------------------------------------
/test/unit/plugin/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/plugin/__init__.py
--------------------------------------------------------------------------------
/test/unit/plugin/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/plugin/data/__init__.py
--------------------------------------------------------------------------------
/test/unit/plugin/data/browser_azure_data.py:
--------------------------------------------------------------------------------
1 | code: str = "helloworld"
2 | state: str = "abcdefghij"
3 | valid_response: bytes = (
4 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n"
5 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n"
6 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
7 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html,"
8 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
9 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n"
10 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n"
11 | b"\r\ncode=" + code.encode("utf-8") + b"&state=" + state.encode("utf-8") + b"&session_state=hooplah"
12 | )
13 |
14 | missing_state_response: bytes = (
15 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n"
16 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n"
17 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
18 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html,"
19 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
20 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n"
21 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n"
22 | b"\r\ncode=" + code.encode("utf-8")
23 | )
24 |
25 | mismatched_state_response: bytes = (
26 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n"
27 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n"
28 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
29 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html,"
30 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
31 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n"
32 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n"
33 | b"\r\ncode=" + code.encode("utf-8") + b"&state=" + state[::-1].encode("utf-8") + b"&session_state=hooplah"
34 | )
35 |
36 | missing_code_response: bytes = (
37 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n"
38 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n"
39 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
40 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html,"
41 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
42 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n"
43 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n"
44 | b"&state=" + state.encode("utf-8") + b"&session_state=hooplah"
45 | )
46 |
47 | empty_code_response: bytes = (
48 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n"
49 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n"
50 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "
51 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html,"
52 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,"
53 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n"
54 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n"
55 | b"\r\ncode=" + b"&state=" + state.encode("utf-8") + b"&session_state=hooplah"
56 | )
57 |
58 |
59 | saml_response = b"my_access_token"
60 |
61 | valid_json_response: dict = {
62 | "token_type": "Bearer",
63 | "expires_in": "3599",
64 | "ext_expires_in": "3599",
65 | "expires_on": "1602782647",
66 | "resource": "spn:1234567891011121314151617181920",
67 | "access_token": "bXlfYWNjZXNzX3Rva2Vu", # base64.urlsafe_64encode(saml_response)
68 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2",
69 | "refresh_token": "my_refresh_token",
70 | "id_token": "my_id_token",
71 | }
72 |
73 | json_response_no_access_token: dict = {
74 | "token_type": "Bearer",
75 | "expires_in": "3599",
76 | "ext_expires_in": "3599",
77 | "expires_on": "1602782647",
78 | "resource": "spn:1234567891011121314151617181920",
79 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2",
80 | "refresh_token": "my_refresh_token",
81 | "id_token": "my_id_token",
82 | }
83 |
84 | json_response_empty_access_token: dict = {
85 | "token_type": "Bearer",
86 | "expires_in": "3599",
87 | "ext_expires_in": "3599",
88 | "expires_on": "1602782647",
89 | "resource": "spn:1234567891011121314151617181920",
90 | "access_token": "",
91 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2",
92 | "refresh_token": "my_refresh_token",
93 | "id_token": "my_id_token",
94 | }
95 |
--------------------------------------------------------------------------------
/test/unit/plugin/data/mock_adfs_saml_response.html:
--------------------------------------------------------------------------------
1 |
2 | Working...
3 |
4 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/test/unit/plugin/data/saml_response.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 | http://www.okta.com/exki777jkSgkPNu9L4x6
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | testDigestValue1
17 |
18 |
19 | testSignature1
20 |
21 |
22 | testCertificate1
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 | http://www.okta.com/exki777jkSgkPNu9L4x6
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | testDigestValue2=
44 |
45 |
46 | testSignature2
47 |
48 |
49 | testCertificate2
50 |
51 |
52 |
53 |
54 | example@example.com
55 |
56 |
57 |
58 |
59 |
60 |
61 | urn:amazon:webservices
62 |
63 |
64 |
65 |
66 | urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport
67 |
68 |
69 |
70 |
71 | arn:aws:iam::123456789012:role/myRole,arn:aws:iam::123456789012:saml-provider/myProvider
72 |
73 |
74 | example@example.com
75 |
76 |
77 | true
78 |
79 |
80 | example@example.com
81 |
82 |
83 |
84 |
85 |
--------------------------------------------------------------------------------
/test/unit/plugin/data/saml_response_data.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import zlib
3 |
4 | # uses sample (valid) SAML response from data folder
5 | saml_response: bytes = open("test/unit/plugin/data/saml_response.xml").read().encode("utf-8")
6 |
7 | # the SAML response as received in HTTP response
8 | encoded_saml_response: bytes = base64.b64encode(zlib.compress(saml_response)[2:-4])  # strip the 2-byte zlib header and 4-byte Adler-32 trailer, leaving raw DEFLATE
9 |
10 |
11 | # An HTTP response containing a SAML response can be formatted in two ways:
12 | # 1. an HTTP response containing things other than the SAML response
13 | valid_http_response_with_header_equal_delim: bytes = (
14 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:7890\r\nConnection: keep-alive\r\nContent-Length: "
15 | b"11639\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\nContent-Type: "
16 | b"application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) "
17 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.102 Safari/537.36\r\nAccept: text/html,"
18 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,/;q=0.8,"
19 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: "
20 | b"navigate\r\nSec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,"
21 | b"en;q=0.9\r\n\r\nSAMLResponse=" + encoded_saml_response + b"&RelayState="
22 | )
23 |
24 | valid_http_response_with_header_colon_delim: bytes = (
25 | b"POST\r\nRelayState: \r\nSAMLResponse:\r\n" + encoded_saml_response
26 | )
27 |
28 | # 2. an HTTP response containing *only* the SAML response
29 | valid_http_response_no_header: bytes = b"SAMLResponse=" + encoded_saml_response + b"&RelayState="
30 |
--------------------------------------------------------------------------------
/test/unit/plugin/test_azure_oauth2_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 | from pytest_mock import mocker # type: ignore
5 |
6 | from redshift_connector.error import InterfaceError
7 | from redshift_connector.plugin.browser_azure_oauth2_credentials_provider import (
8 | BrowserAzureOAuth2CredentialsProvider,
9 | )
10 | from redshift_connector.redshift_property import RedshiftProperty
11 |
12 |
13 | def make_valid_azure_oauth2_provider() -> typing.Tuple[BrowserAzureOAuth2CredentialsProvider, RedshiftProperty]:
14 | rp: RedshiftProperty = RedshiftProperty()
15 | rp.idp_tenant = "my_idp_tenant"
16 | rp.client_id = "my_client_id"
17 | rp.scope = "my_scope"
18 | rp.idp_response_timeout = 900
19 | rp.listen_port = 1099
20 | cp: BrowserAzureOAuth2CredentialsProvider = BrowserAzureOAuth2CredentialsProvider()
21 | cp.add_parameter(rp)
22 | return cp, rp
23 |
24 | def test_default_parameters_azure_oauth2_specific() -> None:
25 | acp, _ = make_valid_azure_oauth2_provider()
26 | assert acp.ssl_insecure is False
27 | assert acp.do_verify_ssl_cert() is True
28 |
29 |
30 | def test_add_parameter_sets_azure_oauth2_specific() -> None:
31 | acp, rp = make_valid_azure_oauth2_provider()
32 | assert acp.idp_tenant == rp.idp_tenant
33 | assert acp.client_id == rp.client_id
34 | assert acp.scope == rp.scope
35 | assert acp.idp_response_timeout == rp.idp_response_timeout
36 | assert acp.listen_port == rp.listen_port
37 |
38 |
39 | @pytest.mark.parametrize("value", [None, ""])
40 | def test_check_required_parameters_raises_if_idp_tenant_missing_or_too_small(value) -> None:
41 | acp, _ = make_valid_azure_oauth2_provider()
42 | acp.idp_tenant = value
43 |
44 | with pytest.raises(InterfaceError, match="Missing required connection property: idp_tenant"):
45 | acp.get_jwt_assertion()
46 |
47 |
48 | @pytest.mark.parametrize("value", [None, ""])
49 | def test_check_required_parameters_raises_if_client_id_missing(value) -> None:
50 | acp, _ = make_valid_azure_oauth2_provider()
51 | acp.client_id = value
52 |
53 | with pytest.raises(InterfaceError, match="Missing required connection property: client_id"):
54 | acp.get_jwt_assertion()
55 |
56 |
57 | @pytest.mark.parametrize("value", [None, ""])
58 | def test_check_required_parameters_raises_if_idp_response_timeout_missing(value) -> None:
59 | acp, _ = make_valid_azure_oauth2_provider()
60 | acp.idp_response_timeout = value
61 |
62 | with pytest.raises(
63 | InterfaceError,
64 | match="Invalid value specified for connection property: idp_response_timeout. Must be 10 seconds or greater",
65 | ):
66 | acp.get_jwt_assertion()
67 |
68 |
69 | def test_get_jwt_assertion_fetches_and_extracts(mocker) -> None:
70 | mock_token: str = "mock_token"
71 | mock_content: str = "mock_content"
72 | mock_jwt_assertion: str = "mock_jwt_assertion"
73 | mocker.patch(
74 | "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
75 | "BrowserAzureOAuth2CredentialsProvider.fetch_authorization_token",
76 | return_value=mock_token,
77 | )
78 | mocker.patch(
79 | "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
80 | "BrowserAzureOAuth2CredentialsProvider.fetch_jwt_response",
81 | return_value=mock_content,
82 | )
83 | mocker.patch(
84 | "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
85 | "BrowserAzureOAuth2CredentialsProvider.extract_jwt_assertion",
86 | return_value=mock_jwt_assertion,
87 | )
88 | acp, rp = make_valid_azure_oauth2_provider()
89 |
90 | fetch_token_spy = mocker.spy(acp, "fetch_authorization_token")
91 | fetch_jwt_spy = mocker.spy(acp, "fetch_jwt_response")
92 | extract_jwt_spy = mocker.spy(acp, "extract_jwt_assertion")
93 |
94 | jwt_assertion: str = acp.get_jwt_assertion()
95 |
96 | assert fetch_token_spy.called is True
97 | assert fetch_token_spy.call_count == 1
98 |
99 | assert fetch_jwt_spy.called is True
100 | assert fetch_jwt_spy.call_count == 1
101 | assert fetch_jwt_spy.call_args[0][0] == mock_token
102 |
103 | assert extract_jwt_spy.called is True
104 | assert extract_jwt_spy.call_count == 1
105 | assert extract_jwt_spy.call_args[0][0] == mock_content
106 |
107 | assert jwt_assertion == mock_jwt_assertion
108 |
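The three spies above pin down the call chain; a paraphrased sketch of the order they assert (implied by the test, not the library source):

    def get_jwt_assertion_flow(provider):
        token = provider.fetch_authorization_token()    # browser round trip
        content = provider.fetch_jwt_response(token)    # token exchange with the IdP
        return provider.extract_jwt_assertion(content)  # pull the assertion from the response body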
--------------------------------------------------------------------------------
/test/unit/plugin/test_browser_saml_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from test.unit.mocks.mock_socket import MockSocket
3 | from test.unit.plugin.data import saml_response_data
4 | from unittest.mock import Mock, patch
5 |
6 | import pytest # type: ignore
7 |
8 | from redshift_connector.error import InterfaceError
9 | from redshift_connector.plugin.browser_saml_credentials_provider import (
10 | BrowserSamlCredentialsProvider,
11 | )
12 |
13 | http_response_datas: typing.List[bytes] = [
14 | saml_response_data.valid_http_response_with_header_equal_delim,
15 | saml_response_data.valid_http_response_with_header_colon_delim,
16 | saml_response_data.valid_http_response_no_header,
17 | ]
18 |
19 |
20 | @pytest.fixture(autouse=True)
21 | def cleanup_mock_socket() -> None:
22 | # cleans up class attribute that mocks data the socket receives
23 | MockSocket.mocked_data = None
24 |
25 |
26 | @pytest.mark.parametrize("http_response", http_response_datas)
27 | def test_run_server_parses_saml_response(http_response) -> None:
28 | MockSocket.mocked_data = http_response
29 | with patch("socket.socket", return_value=MockSocket()):
30 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
31 | parsed_saml_response: str = browser_saml_credentials.run_server(listen_port=0, idp_response_timeout=5)
32 | assert parsed_saml_response == saml_response_data.encoded_saml_response.decode("utf-8")
33 |
34 |
35 | invalid_login_url_vals: typing.List[typing.Optional[str]] = ["", None]
36 |
37 |
38 | @pytest.mark.parametrize("login_url", invalid_login_url_vals)
39 | def test_get_saml_assertion_no_login_url_should_fail(login_url) -> None:
40 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
41 | browser_saml_credentials.login_url = login_url
42 |
43 | with pytest.raises(InterfaceError) as ex:
44 | browser_saml_credentials.get_saml_assertion()
45 | assert "Missing required connection property: login_url" in str(ex.value)
46 |
47 |
48 | def test_get_saml_assertion_low_idp_response_timeout_should_fail() -> None:
49 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
50 | browser_saml_credentials.login_url = "https://www.example.com"
51 | browser_saml_credentials.idp_response_timeout = -1
52 |
53 | with pytest.raises(InterfaceError) as ex:
54 | browser_saml_credentials.get_saml_assertion()
55 | assert (
56 | "Invalid value specified for connection property: idp_response_timeout. Must be 10 seconds or greater"
57 | in str(ex.value)
58 | )
59 |
60 |
61 | invalid_listen_port_vals: typing.List[int] = [-1, 0, 65536]
62 |
63 |
64 | @pytest.mark.parametrize("listen_port", invalid_listen_port_vals)
65 | def test_get_saml_assertion_invalid_listen_port_should_fail(listen_port) -> None:
66 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
67 | browser_saml_credentials.login_url = "https://www.example.com"
68 | browser_saml_credentials.idp_response_timeout = 11
69 | browser_saml_credentials.listen_port = listen_port
70 |
71 | with pytest.raises(InterfaceError) as ex:
72 | browser_saml_credentials.get_saml_assertion()
73 | assert "Invalid value specified for connection property: listen_port. Must be in range [1,65535]" in str(ex.value)
74 |
75 |
76 | def test_authenticate_returns_authorization_token(mocker) -> None:
77 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
78 | mock_authorization_token: str = "my_authorization_token"
79 |
80 | mocker.patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.open_browser", return_value=None)
81 | mocker.patch(
82 | "redshift_connector.plugin.BrowserSamlCredentialsProvider.run_server", return_value=mock_authorization_token
83 | )
84 |
85 | assert browser_saml_credentials.authenticate() == mock_authorization_token
86 |
87 |
88 | def test_authenticate_errors_should_fail(mocker) -> None:
89 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
90 |
91 | mocker.patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.open_browser", return_value=None)
92 | with patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.run_server") as mocked_server:
93 | mocked_server.side_effect = Exception("bad mistake")
94 |
95 | with pytest.raises(Exception, match="bad mistake"):
96 | browser_saml_credentials.authenticate()
97 |
98 |
99 | def test_open_browser_no_url_should_fail() -> None:
100 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()
101 |
102 | with pytest.raises(InterfaceError) as ex:
103 | browser_saml_credentials.open_browser()
104 | assert "Missing required connection property: login_url" in str(ex.value)
105 |
--------------------------------------------------------------------------------
/test/unit/plugin/test_credentials_providers.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import os
3 | import typing
4 | from test import (
5 | adfs_idp,
6 | azure_browser_idp,
7 | azure_idp,
8 | idp_arg,
9 | jumpcloud_browser_idp,
10 | jwt_azure_v2_idp,
11 | jwt_google_idp,
12 | okta_browser_idp,
13 | okta_idp,
14 | ping_browser_idp,
15 | redshift_browser_idc,
16 | redshift_idp_token_auth_plugin,
17 | )
18 |
19 | import pytest # type: ignore
20 |
21 | import redshift_connector
22 |
23 | conf: configparser.ConfigParser = configparser.ConfigParser()
24 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
25 | conf.read(root_path + "/config.ini")
26 |
27 |
28 | NON_BROWSER_IDP: typing.List[str] = ["okta_idp", "azure_idp", "adfs_idp"]
29 | NON_BROWSER_IDC: typing.List[str] = ["redshift_idp_token_auth_plugin"]
30 |
31 | ALL_IDP: typing.List[str] = [
32 | "okta_browser_idp",
33 | "azure_browser_idp",
34 | "jumpcloud_browser_idp",
35 | "ping_browser_idp",
36 | "jwt_google_idp",
37 | "jwt_azure_v2_idp",
38 | ] + NON_BROWSER_IDP
39 |
40 |
41 | @pytest.mark.parametrize("idp_arg", ALL_IDP, indirect=True)
42 | def test_credential_provider_dne_should_fail(idp_arg) -> None:
43 | idp_arg["credentials_provider"] = "WrongProvider"
44 | with pytest.raises(
45 | redshift_connector.InterfaceError, match="Invalid IdP specified in credential_provider connection parameter"
46 | ):
47 | redshift_connector.connect(**idp_arg)
48 |
49 |
50 | @pytest.mark.parametrize("idp_arg", ALL_IDP, indirect=True)
51 | def test_ssl_and_iam_invalid_should_fail(idp_arg) -> None:
52 | idp_arg["ssl"] = False
53 | idp_arg["iam"] = True
54 | with pytest.raises(
55 | redshift_connector.InterfaceError,
56 | match="Invalid connection property setting. SSL must be enabled when using IAM",
57 | ):
58 | redshift_connector.connect(**idp_arg)
59 |
60 | idp_arg["ssl"] = True
61 | idp_arg["credentials_provider"] = "OktacredentialSProvider"
62 | with pytest.raises(
63 | redshift_connector.InterfaceError,
64 | match="Invalid IdP specified in credential_provider connection parameter",
65 | ):
66 | redshift_connector.connect(**idp_arg)
67 |
68 |
69 | @pytest.mark.parametrize("idp_arg", NON_BROWSER_IDC, indirect=True)
70 | def test_using_iam_should_fail(idp_arg) -> None:
71 | idp_arg["iam"] = True
72 | with pytest.raises(
73 | redshift_connector.InterfaceError,
74 | match="You can not use this authentication plugin with IAM enabled.",
75 | ):
76 | redshift_connector.connect(**idp_arg)
77 |
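
These checks encode two hard rules: IAM authentication requires SSL, and credentials_provider must name a known plugin. A small sketch of the first failure mode, assuming placeholder endpoint values:

    import redshift_connector

    try:
        # ssl=False combined with iam=True is rejected during property
        # validation, before any network call is attempted
        redshift_connector.connect(
            host="examplecluster.abc123xyz789.us-west-2.redshift.amazonaws.com",
            database="dev",
            iam=True,
            ssl=False,
            credentials_provider="OktaCredentialsProvider",
        )
    except redshift_connector.InterfaceError as e:
        print(e)  # SSL must be enabled when using IAM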
--------------------------------------------------------------------------------
/test/unit/plugin/test_idp_token_auth_plugin.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from redshift_connector import IamHelper, InterfaceError, RedshiftProperty
4 | from redshift_connector.plugin import IdpTokenAuthPlugin
5 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder
6 |
7 |
8 | def test_should_fail_without_token():
9 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
10 | itap.token_type = "blah"
11 |
12 | with pytest.raises(
13 | InterfaceError, match="IdC authentication failed: The token must be included in the connection parameters."
14 | ):
15 | itap.check_required_parameters()
16 |
17 |
18 | def test_should_fail_without_token_type():
19 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
20 | itap.token = "blah"
21 |
22 | with pytest.raises(
23 | InterfaceError, match="IdC authentication failed: The token type must be included in the connection parameters."
24 | ):
25 | itap.check_required_parameters()
26 |
27 |
28 | def test_get_auth_token_calls_check_required_parameters(mocker):
29 | spy = mocker.spy(IdpTokenAuthPlugin, "check_required_parameters")
30 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
31 | itap.token = "my_token"
32 | itap.token_type = "testing_token"
33 |
34 | itap.get_auth_token()
35 | assert spy.called
36 | assert spy.call_count == 1
37 |
38 |
39 | def test_get_auth_token_returns_token():
40 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
41 | itap.token = "my_token"
42 | itap.token_type = "testing_token"
43 |
44 | result = itap.get_auth_token()
45 | assert result == "my_token"
46 |
47 |
48 | def test_add_parameter_sets_token():
49 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
50 | rp: RedshiftProperty = RedshiftProperty()
51 | token_value: str = "a token of appreciation"
52 | rp.token = token_value
53 | itap.add_parameter(rp)
54 | assert itap.token == token_value
55 |
56 |
57 | def test_add_parameter_sets_token_type():
58 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
59 | rp: RedshiftProperty = RedshiftProperty()
60 | token_type_value: str = "appreciative token"
61 | rp.token_type = token_type_value
62 | itap.add_parameter(rp)
63 | assert itap.token_type == token_type_value
64 |
65 |
66 | def test_get_sub_type_is_idc():
67 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
68 | assert itap.get_sub_type() == IamHelper.IDC_PLUGIN
69 |
70 |
71 | def test_cache_disabled_by_default():
72 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
73 | rp: RedshiftProperty = RedshiftProperty()
74 | rp.token_type = "happy token"
75 | rp.token = "hello world"
76 | itap.add_parameter(rp)
77 | assert itap.disable_cache == True
78 |
79 |
80 | def test_get_credentials_calls_refresh(mocker):
81 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
82 | rp: RedshiftProperty = RedshiftProperty()
83 | rp.token_type = "happy token"
84 | rp.token = "hello world"
85 | itap.add_parameter(rp)
86 | mocker.patch("redshift_connector.plugin.IdpTokenAuthPlugin.refresh", return_value=None)
87 | spy = mocker.spy(IdpTokenAuthPlugin, "refresh")
88 | itap.get_credentials()
89 | assert spy.called
90 | assert spy.call_count == 1
91 |
92 |
93 | def test_refresh_sets_credential():
94 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
95 | rp: RedshiftProperty = RedshiftProperty()
96 | rp.token_type = "happy token"
97 | rp.token = "hello world"
98 | itap.add_parameter(rp)
99 |
100 | itap.refresh()
101 | result: NativeTokenHolder = itap.last_refreshed_credentials
102 | assert result is not None
103 | assert isinstance(result, NativeTokenHolder)
104 | assert result.access_token == rp.token
105 |
106 |
107 | def test_get_credentials_returns_credential():
108 | itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
109 | rp: RedshiftProperty = RedshiftProperty()
110 | rp.token_type = "happy token"
111 | rp.token = "hello world"
112 | itap.add_parameter(rp)
113 |
114 | result: NativeTokenHolder = itap.get_credentials()
115 | assert result is not None
116 | assert isinstance(result, NativeTokenHolder)
117 | assert result.access_token == rp.token
118 |
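
These tests describe the plugin's contract: token and token_type are both required, caching is disabled, and refresh() wraps the raw token in a NativeTokenHolder. In practice the plugin is selected through connection parameters; a hedged sketch follows (the host and token values are placeholders, and the token_type value is an assumption):

    import redshift_connector

    # IdC token-based authentication; the token/token_type connection
    # parameters feed IdpTokenAuthPlugin.add_parameter shown above.
    # Note iam must stay False: the test above shows iam=True is rejected.
    conn = redshift_connector.connect(
        database="dev",
        host="my-workgroup.012345678901.us-east-1.redshift-serverless.amazonaws.com",
        credentials_provider="IdpTokenAuthPlugin",
        token="<identity-provider-issued-token>",
        token_type="ACCESS_TOKEN",
    )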
--------------------------------------------------------------------------------
/test/unit/plugin/test_plugin_inheritance.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from redshift_connector import plugin
4 | from redshift_connector.plugin.i_native_plugin import INativePlugin
5 | from redshift_connector.plugin.i_plugin import IPlugin
6 | from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider
7 |
8 |
9 | def test_i_native_plugin_inherits_from_i_plugin() -> None:
10 | assert issubclass(INativePlugin, IPlugin)
11 |
12 |
13 | def test_idp_credentials_provider_inherits_from_i_plugin() -> None:
14 | assert issubclass(IdpCredentialsProvider, IPlugin)
15 |
16 |
17 | def test_saml_provider_plugin_inherit_from_idp_credentials_provider() -> None:
18 | assert issubclass(plugin.SamlCredentialsProvider, IdpCredentialsProvider)
19 |
20 |
21 | def test_jwt_abc_inherit_from_idp_credentials_provider() -> None:
22 | assert issubclass(plugin.JwtCredentialsProvider, IdpCredentialsProvider)
23 |
24 |
25 | saml_provider_plugins = (
26 | plugin.BrowserSamlCredentialsProvider,
27 | plugin.OktaCredentialsProvider,
28 | plugin.AdfsCredentialsProvider,
29 | plugin.AzureCredentialsProvider,
30 | plugin.PingCredentialsProvider,
31 | plugin.BrowserAzureCredentialsProvider,
32 | )
33 |
34 |
35 | @pytest.mark.parametrize("saml_plugin", saml_provider_plugins)
36 | def test_saml_provider_plugins_inherit_from_saml_credentials_provider(saml_plugin):
37 | assert issubclass(saml_plugin, plugin.SamlCredentialsProvider)
38 |
39 |
40 | jwt_plugins = (plugin.BrowserAzureOAuth2CredentialsProvider, plugin.BasicJwtCredentialsProvider)
41 |
42 |
43 | @pytest.mark.parametrize("jwt_plugin", jwt_plugins)
44 | def test_jwt_plugins_inherit_from_jwt_abc(jwt_plugin):
45 | assert issubclass(jwt_plugin, plugin.JwtCredentialsProvider)
46 |
--------------------------------------------------------------------------------
/test/unit/plugin/test_saml_credentials_provider.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import typing
3 | from test.unit.plugin.data import saml_response_data
4 | from unittest.mock import MagicMock, patch
5 |
6 | import pytest # type: ignore
7 |
8 | from redshift_connector import InterfaceError, RedshiftProperty
9 | from redshift_connector.credentials_holder import CredentialsHolder
10 | from redshift_connector.plugin import SamlCredentialsProvider
11 |
12 |
13 | @patch.multiple(SamlCredentialsProvider, __abstractmethods__=set())
14 | def make_valid_saml_credentials_provider() -> typing.Tuple[SamlCredentialsProvider, RedshiftProperty]:
15 | rp: RedshiftProperty = RedshiftProperty()
16 | rp.user_name = "AzureDiamond"
17 | rp.password = "hunter2"
18 | scp: SamlCredentialsProvider = SamlCredentialsProvider() # type: ignore
19 | scp.add_parameter(rp)
20 | return scp, rp
21 |
22 |
23 | def test_default_parameters_saml_credentials_provider() -> None:
24 | acp, _ = make_valid_saml_credentials_provider()
25 | assert acp.ssl_insecure == False
26 | assert acp.do_verify_ssl_cert() == True
27 |
28 |
29 | def test_get_cache_key_format_as_expected() -> None:
30 | scp, _ = make_valid_saml_credentials_provider()
31 | expected_cache_key: str = "{username}{password}{idp_host}{idp_port}{duration}{preferred_role}".format(
32 | username=scp.user_name,
33 | password=scp.password,
34 | idp_host=scp.idp_host,
35 | idp_port=scp.idpPort,
36 | duration=scp.duration,
37 | preferred_role=scp.preferred_role,
38 | )
39 | assert scp.get_cache_key() == expected_cache_key
40 |
41 |
42 | def test_get_credentials_uses_cache_when_exists(mocker) -> None:
43 | scp, _ = make_valid_saml_credentials_provider()
44 | mock_credentials = MagicMock()
45 | mock_credentials.is_expired.return_value = False
46 | scp.cache[scp.get_cache_key()] = mock_credentials
47 |
48 | spy = mocker.spy(SamlCredentialsProvider, "refresh")
49 |
50 | assert scp.get_credentials() == mock_credentials
51 | assert spy.called is False
52 |
53 |
54 | def test_get_credentials_calls_refresh_when_cache_expired(mocker) -> None:
55 | scp, _ = make_valid_saml_credentials_provider()
56 | mock_credentials = MagicMock()
57 | mock_credentials.is_expired.return_value = True
58 | scp.cache[scp.get_cache_key()] = mock_credentials
59 |
60 | mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.refresh", return_value=None)
61 | spy = mocker.spy(SamlCredentialsProvider, "refresh")
62 |
63 | scp.get_credentials()
64 |
65 | assert spy.called
66 | assert spy.call_count == 1
67 |
68 |
69 | def test_get_credentials_sets_db_user_when_present(mocker) -> None:
70 | scp, _ = make_valid_saml_credentials_provider()
71 | mocked_db_user: str = "test_db_user"
72 | scp.db_user = mocked_db_user
73 | mock_credentials = MagicMock()
74 | mock_credentials.is_expired.return_value = True
75 | mock_credentials.metadata = MagicMock()
76 | scp.cache[scp.get_cache_key()] = mock_credentials
77 |
78 | mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.refresh", return_value=None)
79 | spy = mocker.spy(mock_credentials.metadata, "set_db_user")
80 |
81 | scp.get_credentials()
82 |
83 | assert spy.called
84 | assert spy.call_count == 1
85 | assert spy.call_args[0][0] == mocked_db_user
86 |
87 |
88 | def test_refresh_get_saml_assertion_fails(mocker) -> None:
89 | scp, _ = make_valid_saml_credentials_provider()
90 |
91 | with patch("redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion") as mocked_get_saml_assertion:
92 | mocked_get_saml_assertion.side_effect = Exception("bad robot")
93 |
94 | with pytest.raises(InterfaceError, match="Failed to get SAML assertion"):
95 | scp.refresh()
96 |
97 |
98 | def test_refresh_saml_assertion_missing_role_should_fail(mocker) -> None:
99 | scp, _ = make_valid_saml_credentials_provider()
100 | mocked_data: str = "test"
101 | mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion", return_value=mocked_data)
102 |
103 | with pytest.raises(
104 | InterfaceError,
105 | match="No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.",
106 | ):
107 | scp.refresh()
108 |
109 |
110 | def test_refresh_saml_assertion_passed_to_boto(mocker) -> None:
111 | scp, _ = make_valid_saml_credentials_provider()
112 | mocker.patch(
113 | "redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion",
114 | return_value=base64.b64encode(saml_response_data.saml_response),
115 | )
116 |
117 | mocked_response: typing.Dict[str, typing.Any] = {
118 | "Credentials": {"Expiration": "test_expiry", "Things": "much data", "Other Things": "more data"}
119 | }
120 |
121 | mocked_boto = MagicMock()
122 | mocked_boto_response = MagicMock()
123 | mocked_boto_response.return_value = mocked_response
124 | mocked_boto.assume_role_with_saml = mocked_boto_response
125 |
126 | credential_holder_spy = mocker.spy(CredentialsHolder, "__init__")
127 |
128 | mocker.patch("boto3.client", return_value=mocked_boto)
129 |
130 | scp.refresh()
131 | assert credential_holder_spy.called
132 | assert credential_holder_spy.call_count == 1
133 | assert credential_holder_spy.call_args[0][1] == mocked_response["Credentials"]
134 |
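
The last test pins the refresh pipeline: the SAML assertion is handed to STS and the returned Credentials dict seeds a CredentialsHolder. Outside of mocks, the underlying boto3 call looks roughly like this (role and principal ARNs are placeholders):

    import boto3

    sts = boto3.client("sts")
    response = sts.assume_role_with_saml(
        RoleArn="arn:aws:iam::123456789012:role/RedshiftSamlRole",
        PrincipalArn="arn:aws:iam::123456789012:saml-provider/MyIdP",
        SAMLAssertion="<base64-encoded assertion>",
    )
    # this dict is what CredentialsHolder.__init__ receives in the spy above
    credentials = response["Credentials"]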
--------------------------------------------------------------------------------
/test/unit/test_array_util.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest # type: ignore
4 |
5 | from redshift_connector.utils import array_util
6 |
7 | walk_array_data: typing.List = [
8 | (
9 | [10, 9, 8, 7, 6],
10 | [
11 | ([10, 9, 8, 7, 6], 0, 10),
12 | ([10, 9, 8, 7, 6], 1, 9),
13 | ([10, 9, 8, 7, 6], 2, 8),
14 | ([10, 9, 8, 7, 6], 3, 7),
15 | ([10, 9, 8, 7, 6], 4, 6),
16 | ],
17 | ),
18 | (
19 | [1, 2, [3, 4, 5], 6],
20 | [
21 | ([1, 2, [3, 4, 5], 6], 0, 1),
22 | ([1, 2, [3, 4, 5], 6], 1, 2),
23 | ([3, 4, 5], 0, 3),
24 | ([3, 4, 5], 1, 4),
25 | ([3, 4, 5], 2, 5),
26 | ([1, 2, [3, 4, 5], 6], 3, 6),
27 | ],
28 | ),
29 | ]
30 |
31 |
32 | @pytest.mark.parametrize("_input", walk_array_data)
33 | def test_walk_array(_input) -> None:
34 | in_val, exp_vals = _input
35 | x: typing.Generator = array_util.walk_array(in_val)
36 | idx: int = 0
37 | for a, b, c in x:
38 | assert a == exp_vals[idx][0]
39 | assert b == exp_vals[idx][1]
40 | assert c == exp_vals[idx][2]
41 | idx += 1
42 |
43 |
44 | array_flatten_data: typing.List = [
45 | ([1, 2, 3, 4], [1, 2, 3, 4]),
46 | ([1, [2], 3, 4], [1, 2, 3, 4]),
47 | ([1, [2, [3]], 4, [5, [6]]], [1, 2, 3, 4, 5, 6]),
48 | ([[1]], [1]),
49 | ]
50 |
51 |
52 | @pytest.mark.parametrize("_input", array_flatten_data)
53 | def test_array_flatten(_input) -> None:
54 | in_val, exp_val = _input
55 | # flatten and compare against the expected fully-flattened list
56 | assert list(array_util.array_flatten(in_val)) == exp_val
57 |
58 |
59 | array_find_first_element_data: typing.List = [
60 | ([1], 1),
61 | ([None, None, [None, None, 1], 2], 1),
62 | ([[[1]]], 1),
63 | ([None, None, [None]], None),
64 | ]
65 |
66 |
67 | @pytest.mark.parametrize("_input", array_find_first_element_data)
68 | def test_array_find_first_element(_input) -> None:
69 | in_val, exp_val = _input
70 | assert array_util.array_find_first_element(in_val) == exp_val
71 |
72 |
73 | array_has_null_data: typing.List = [
74 | ([None], True),
75 | ([1, 2, 3, 4, [[[None]]]], True),
76 | ([1, 2, 3, 4], False),
77 | ]
78 |
79 |
80 | @pytest.mark.parametrize("_input", array_has_null_data)
81 | def test_array_has_null(_input) -> None:
82 | in_val, exp_val = _input
83 | assert array_util.array_has_null(in_val) is exp_val
84 |
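
For orientation, the helpers exercised here share one convention: nested Python lists model multidimensional Redshift arrays. A quick interactive sketch, using the behavior the tests above pin down:

    from redshift_connector.utils import array_util

    nested = [1, [2, [3, None]]]

    # walk_array yields (containing_list, index, value) tuples depth-first
    for lst, idx, value in array_util.walk_array(nested):
        print(lst, idx, value)

    print(list(array_util.array_flatten(nested)))       # [1, 2, 3, None]
    print(array_util.array_find_first_element(nested))  # 1 (skips None)
    print(array_util.array_has_null(nested))            # True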
--------------------------------------------------------------------------------
/test/unit/test_credentials_holder.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import typing
3 | from unittest.mock import MagicMock
4 |
5 | import pytest # type: ignore
6 |
7 | from redshift_connector.credentials_holder import (
8 | ABCAWSCredentialsHolder,
9 | ABCCredentialsHolder,
10 | AWSDirectCredentialsHolder,
11 | AWSProfileCredentialsHolder,
12 | CredentialsHolder,
13 | )
14 |
15 |
16 | @pytest.mark.parametrize("cred_holder", (AWSDirectCredentialsHolder, AWSProfileCredentialsHolder))
17 | def test_aws_credentials_holder_inherit_from_abc(cred_holder) -> None:
18 | assert issubclass(cred_holder, ABCAWSCredentialsHolder)
19 |
20 |
21 | def test_credentials_holder_inherits_from_abc() -> None:
22 | assert issubclass(CredentialsHolder, ABCCredentialsHolder)
23 |
24 |
25 | def test_aws_direct_credentials_holder_should_have_session() -> None:
26 | mocked_session: MagicMock = MagicMock()
27 | obj: AWSDirectCredentialsHolder = AWSDirectCredentialsHolder(
28 | access_key_id="something",
29 | secret_access_key="secret",
30 | session_token="fornow",
31 | session=mocked_session,
32 | )
33 |
34 | assert isinstance(obj, ABCAWSCredentialsHolder)
35 | assert hasattr(obj, "get_boto_session")
36 | assert obj.has_associated_session == True
37 | assert obj.get_boto_session() == mocked_session
38 |
39 |
40 | valid_aws_direct_credential_params: typing.List[typing.Dict[str, typing.Optional[str]]] = [
41 | {
42 | "access_key_id": "something",
43 | "secret_access_key": "secret",
44 | "session_token": "fornow",
45 | },
46 | {
47 | "access_key_id": "something",
48 | "secret_access_key": "secret",
49 | "session_token": None,
50 | },
51 | ]
52 |
53 |
54 | @pytest.mark.parametrize("input", valid_aws_direct_credential_params)
55 | def test_aws_direct_credentials_holder_get_session_credentials(input) -> None:
56 | input["session"] = MagicMock()
57 | obj: AWSDirectCredentialsHolder = AWSDirectCredentialsHolder(**input)
58 |
59 | ret_value: typing.Dict[str, str] = obj.get_session_credentials()
60 |
61 | assert len(ret_value) == (3 if input["session_token"] is not None else 2)
62 |
63 | assert ret_value["aws_access_key_id"] == input["access_key_id"]
64 | assert ret_value["aws_secret_access_key"] == input["secret_access_key"]
65 |
66 | if input["session_token"] is not None:
67 | assert ret_value["aws_session_token"] == input["session_token"]
68 |
69 |
70 | def test_aws_profile_credentials_holder_should_have_session() -> None:
71 | mocked_session: MagicMock = MagicMock()
72 | obj: AWSProfileCredentialsHolder = AWSProfileCredentialsHolder(profile="myprofile", session=mocked_session)
73 |
74 | assert isinstance(obj, ABCAWSCredentialsHolder)
75 | assert hasattr(obj, "get_boto_session")
76 | assert obj.has_associated_session == True
77 | assert obj.get_boto_session() == mocked_session
78 |
79 |
80 | def test_aws_profile_credentials_holder_get_session_credentials() -> None:
81 | profile_val: str = "myprofile"
82 | obj: AWSProfileCredentialsHolder = AWSProfileCredentialsHolder(profile=profile_val, session=MagicMock())
83 |
84 | ret_value = obj.get_session_credentials()
85 | assert len(ret_value) == 1
86 |
87 | assert ret_value["profile"] == profile_val
88 |
89 |
90 | @pytest.mark.parametrize(
91 | "expiration_delta",
92 | [
93 | datetime.timedelta(hours=3), # expired 3 hrs ago
94 | datetime.timedelta(days=1), # expired 1 day ago
95 | datetime.timedelta(weeks=1), # expired 1 week ago
96 | ],
97 | )
98 | def test_is_expired_true(expiration_delta) -> None:
99 | credentials: typing.Dict[str, typing.Any] = {
100 | "AccessKeyId": "something",
101 | "SecretAccessKey": "secret",
102 | "SessionToken": "fornow",
103 | "Expiration": datetime.datetime.now(datetime.timezone.utc) - expiration_delta,
104 | }
105 |
106 | obj: CredentialsHolder = CredentialsHolder(credentials=credentials)
107 |
108 | assert obj.is_expired() == True
109 |
110 |
111 | @pytest.mark.parametrize(
112 | "expiration_delta",
113 | [
114 | datetime.timedelta(minutes=2), # expires in 2 minutes
115 | datetime.timedelta(hours=3), # expires in 3 hrs
116 | datetime.timedelta(days=1), # expires in 1 day
117 | datetime.timedelta(weeks=1), # expires in 1 week
118 | ],
119 | )
120 | def test_is_expired_false(expiration_delta) -> None:
121 | credentials: typing.Dict[str, typing.Any] = {
122 | "AccessKeyId": "something",
123 | "SecretAccessKey": "secret",
124 | "SessionToken": "fornow",
125 | "Expiration": datetime.datetime.now(datetime.timezone.utc) + expiration_delta,
126 | }
127 |
128 | obj: CredentialsHolder = CredentialsHolder(credentials=credentials)
129 |
130 | assert obj.is_expired() == False
131 |
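
As the assertions above suggest, get_session_credentials() returns keyword arguments shaped for boto3. A minimal sketch of that hand-off (the key values are placeholders):

    import boto3
    from redshift_connector.credentials_holder import AWSDirectCredentialsHolder

    holder = AWSDirectCredentialsHolder(
        access_key_id="AKIA...",
        secret_access_key="secret",
        session_token="fornow",
        session=boto3.Session(),
    )

    # keys are aws_access_key_id / aws_secret_access_key / aws_session_token,
    # so the dict can be splatted straight into a new boto3 Session
    session = boto3.Session(**holder.get_session_credentials())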
--------------------------------------------------------------------------------
/test/unit/test_dbapi20.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import redshift_connector
4 |
5 | driver = redshift_connector
6 |
7 |
8 | def test_apilevel() -> None:
9 | # Must exist
10 | apilevel: str = driver.apilevel
11 |
12 | # Must equal 2.0
13 | assert apilevel == "2.0"
14 |
15 |
16 | def test_threadsafety() -> None:
17 | try:
18 | # Must exist
19 | threadsafety: int = driver.threadsafety
20 | # Must be a valid value
21 | assert threadsafety in (0, 1, 2, 3)
22 | except AttributeError:
23 | assert False, "Driver doesn't define threadsafety"
24 |
25 |
26 | def test_paramstyle() -> None:
27 | from redshift_connector.config import DbApiParamstyle
28 |
29 | try:
30 | # Must exist
31 | paramstyle: str = driver.paramstyle
32 | # Must be a valid value
33 | assert paramstyle in DbApiParamstyle.list()
34 | except AttributeError:
35 | assert False, "Driver doesn't define paramstyle"
36 |
37 |
38 | def test_Exceptions() -> None:
39 | # Make sure required exceptions exist, and are in the
40 | # defined hierarchy.
41 | assert issubclass(driver.Warning, Exception)
42 | assert issubclass(driver.Error, Exception)
43 | assert issubclass(driver.InterfaceError, driver.Error)
44 | assert issubclass(driver.DatabaseError, driver.Error)
45 | assert issubclass(driver.OperationalError, driver.Error)
46 | assert issubclass(driver.IntegrityError, driver.Error)
47 | assert issubclass(driver.InternalError, driver.Error)
48 | assert issubclass(driver.ProgrammingError, driver.Error)
49 | assert issubclass(driver.NotSupportedError, driver.Error)
50 |
51 |
52 | def test_Date() -> None:
53 | driver.Date(2002, 12, 25)
54 | driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
55 |
56 |
57 | def test_Time() -> None:
58 | driver.Time(13, 45, 30)
59 | driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
60 |
61 |
62 | def test_Timestamp() -> None:
63 | driver.Timestamp(2002, 12, 25, 13, 45, 30)
64 | driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
65 |
66 |
67 | def test_Binary() -> None:
68 | driver.Binary(b"Something")
69 | driver.Binary(b"")
70 |
71 |
72 | def test_STRING() -> None:
73 | assert hasattr(driver, "STRING"), "module.STRING must be defined"
74 |
75 |
76 | def test_BINARY() -> None:
77 | assert hasattr(driver, "BINARY"), "module.BINARY must be defined."
78 |
79 |
80 | def test_NUMBER() -> None:
81 | assert hasattr(driver, "NUMBER"), "module.NUMBER must be defined."
82 |
83 |
84 | def test_DATETIME() -> None:
85 | assert hasattr(driver, "DATETIME"), "module.DATETIME must be defined."
86 |
87 |
88 | def test_ROWID() -> None:
89 | assert hasattr(driver, "ROWID"), "module.ROWID must be defined."
90 |
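
These module-level attributes are the PEP 249 surface the driver commits to. The same checks outside pytest, as a short sketch:

    import redshift_connector as driver

    assert driver.apilevel == "2.0"
    assert driver.paramstyle in ("qmark", "numeric", "named", "format", "pyformat")

    # type constructors required by the DB-API
    d = driver.Date(2002, 12, 25)
    ts = driver.Timestamp(2002, 12, 25, 13, 45, 30)
    raw = driver.Binary(b"Something")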
--------------------------------------------------------------------------------
/test/unit/test_driver_info.py:
--------------------------------------------------------------------------------
1 | import re
2 |
3 | from redshift_connector.utils import DriverInfo
4 |
5 |
6 | def test_version_is_not_none() -> None:
7 | assert DriverInfo.version() is not None
8 |
9 |
10 | def test_version_is_str() -> None:
11 | assert isinstance(DriverInfo.version(), str)
12 |
13 |
14 | def test_version_proper_format() -> None:
15 | version_regex: re.Pattern = re.compile(r"^\d+(\.\d+){2,3}$")
16 | assert version_regex.match(DriverInfo.version())
17 |
18 |
19 | def test_driver_name_is_not_none() -> None:
20 | assert DriverInfo.driver_name() is not None
21 |
22 |
23 | def test_driver_short_name_is_not_none() -> None:
24 | assert DriverInfo.driver_short_name() is not None
25 |
26 |
27 | def test_driver_full_name_is_not_none() -> None:
28 | assert DriverInfo.driver_full_name() is not None
29 |
30 |
31 | def test_driver_full_name_contains_name() -> None:
32 | assert DriverInfo.driver_name() in DriverInfo.driver_full_name()
33 |
34 |
35 | def test_driver_full_name_contains_version() -> None:
36 | assert DriverInfo.version() in DriverInfo.driver_full_name()
37 |
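
DriverInfo is what the connector reports about itself. A quick sketch of the accessors tested above:

    from redshift_connector.utils import DriverInfo

    # e.g. a driver name, a dotted version like "2.1.0", and a full name
    # combining both (exact strings depend on the installed release)
    print(DriverInfo.driver_name())
    print(DriverInfo.version())
    print(DriverInfo.driver_full_name())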
--------------------------------------------------------------------------------
/test/unit/test_idp_auth_helper.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock
2 |
3 | from packaging.version import Version
4 | from redshift_connector.idp_auth_helper import IdpAuthHelper
5 |
6 |
7 | def test_get_pkg_version(mocker) -> None:
8 | mocker.patch("importlib.metadata.version", return_value=None)
9 |
10 | module_mock = MagicMock()
11 | module_mock.__version__ = "9.8.7"
12 | mocker.patch("importlib.import_module", return_value=module_mock)
13 |
14 | actual_version: Version = IdpAuthHelper.get_pkg_version("test_module")
15 |
16 | assert actual_version == Version("9.8.7")
17 |
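
As the mocks above arrange, get_pkg_version first consults importlib metadata and then falls back to the module's own __version__ attribute. A usage sketch:

    from packaging.version import Version
    from redshift_connector.idp_auth_helper import IdpAuthHelper

    # resolves the installed version of a dependency, e.g. boto3
    v: Version = IdpAuthHelper.get_pkg_version("boto3")
    print(v)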
--------------------------------------------------------------------------------
/test/unit/test_import.py:
--------------------------------------------------------------------------------
1 | def test_import_redshift_connector() -> None:
2 | import redshift_connector
3 |
--------------------------------------------------------------------------------
/test/unit/test_logging_utils.py:
--------------------------------------------------------------------------------
1 | import pytest # type: ignore
2 |
3 | from redshift_connector import RedshiftProperty
4 | from redshift_connector.utils.logging_utils import mask_secure_info_in_props
5 |
6 | secret_rp_values = (
7 | "password",
8 | "access_key_id",
9 | "session_token",
10 | "secret_access_key",
11 | "client_id",
12 | "client_secret",
13 | "web_identity_token",
14 | )
15 |
16 |
17 | @pytest.mark.parametrize("rp_arg", secret_rp_values)
18 | def test_mask_secure_info_in_props_obscures_secret_value(rp_arg) -> None:
19 | rp: RedshiftProperty = RedshiftProperty()
20 | secret_value: str = "SECRET_VALUE"
21 | rp.put(rp_arg, secret_value)
22 | result = mask_secure_info_in_props(rp)
23 | assert result.__getattribute__(rp_arg) != secret_value
24 |
25 |
26 | def test_mask_secure_info_in_props_no_info() -> None:
27 | assert mask_secure_info_in_props(None) is None # type: ignore
28 |
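
The masking helper exists so connection properties can be logged without leaking secrets. A minimal sketch:

    from redshift_connector import RedshiftProperty
    from redshift_connector.utils.logging_utils import mask_secure_info_in_props

    rp = RedshiftProperty()
    rp.put("user_name", "awsuser")
    rp.put("password", "hunter2")

    masked = mask_secure_info_in_props(rp)
    # safe to log: the attribute no longer holds the plaintext "hunter2"
    print(masked.password)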
--------------------------------------------------------------------------------
/test/unit/test_module.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | import redshift_connector
4 | from redshift_connector.utils.oids import RedshiftOID
5 |
6 |
7 | @pytest.mark.parametrize(
8 | "error_class",
9 | (
10 | "Warning",
11 | "Error",
12 | "InterfaceError",
13 | "DatabaseError",
14 | "OperationalError",
15 | "IntegrityError",
16 | "InternalError",
17 | "ProgrammingError",
18 | "NotSupportedError",
19 | "ArrayContentNotSupportedError",
20 | "ArrayContentNotHomogenousError",
21 | "ArrayDimensionsNotConsistentError",
22 | ),
23 | )
24 | def test_errors_available_on_module(error_class) -> None:
25 | import importlib
26 |
27 | getattr(importlib.import_module("redshift_connector"), error_class)
28 |
29 |
30 | def test_cursor_on_module() -> None:
31 | import importlib
32 |
33 | getattr(importlib.import_module("redshift_connector"), "Cursor")
34 |
35 |
36 | def test_connection_on_module() -> None:
37 | import importlib
38 |
39 | getattr(importlib.import_module("redshift_connector"), "Connection")
40 |
41 |
42 | def test_version_on_module() -> None:
43 | import importlib
44 |
45 | getattr(importlib.import_module("redshift_connector"), "__version__")
46 |
47 |
48 | @pytest.mark.parametrize("datatype", [d.name for d in RedshiftOID])
49 | def test_datatypes_on_module(datatype) -> None:
50 | assert datatype in redshift_connector.__all__
51 |
--------------------------------------------------------------------------------
/test/unit/test_native_plugin_helper.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock
2 |
3 | # the mocker fixture is injected by the pytest-mock plugin; no import is needed
4 |
5 | from redshift_connector.iam_helper import IamHelper, IdpAuthHelper
6 | from redshift_connector.native_plugin_helper import NativeAuthPluginHelper
7 |
8 |
9 | def test_set_native_auth_plugin_properties_gets_idp_token_when_credentials_provider(mocker) -> None:
10 | mocked_idp_token: str = "my_idp_token"
11 | mocker.patch("redshift_connector.iam_helper.IdpAuthHelper.set_auth_properties", return_value=None)
12 | mocker.patch(
13 | "redshift_connector.native_plugin_helper.NativeAuthPluginHelper.get_native_auth_plugin_credentials",
14 | return_value=mocked_idp_token,
15 | )
16 | spy = mocker.spy(NativeAuthPluginHelper, "get_native_auth_plugin_credentials")
17 | mock_rp: MagicMock = MagicMock()
18 |
19 | NativeAuthPluginHelper.set_native_auth_plugin_properties(mock_rp)
20 |
21 | assert spy.called is True
22 | assert spy.call_count == 1
23 | assert spy.call_args[0][0] == mock_rp
24 | assert mock_rp.method_calls[0][0] == "put"
25 | assert mock_rp.method_calls[0][1] == ("web_identity_token", mocked_idp_token)
26 |
--------------------------------------------------------------------------------
/test/unit/test_paramstyle.py:
--------------------------------------------------------------------------------
1 | import typing
2 |
3 | import pytest
4 |
5 | from redshift_connector.config import DbApiParamstyle
6 | from redshift_connector.core import convert_paramstyle as convert
7 |
8 | # Tests of the convert_paramstyle function.
9 |
10 |
11 | @pytest.mark.parametrize(
12 | "in_statement,out_statement,args",
13 | [
14 | (
15 | 'SELECT ?, ?, "field_?" FROM t ' "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
16 | 'SELECT $1, $2, "field_?" FROM t WHERE ' "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'",
17 | (1, 2, 3),
18 | ),
19 | (
20 | "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
21 | "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'",
22 | (1, 2, 3),
23 | ),
24 | ],
25 | )
26 | def test_qmark(in_statement, out_statement, args) -> None:
27 | new_query, make_args = convert(DbApiParamstyle.QMARK.value, in_statement)
28 | assert new_query == out_statement
29 | assert make_args(args) == args
30 |
31 |
32 | def test_numeric() -> None:
33 | new_query, make_args = convert(
34 | DbApiParamstyle.NUMERIC.value, "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3"
35 | )
36 | expected: str = "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3"
37 | assert new_query == expected
38 | assert make_args((1, 2, 3)) == (1, 2, 3)
39 |
40 |
41 | def test_numeric_default_parameter() -> None:
42 | new_query, make_args = convert(DbApiParamstyle.NUMERIC.value, "make_interval(days := 10)")
43 |
44 | assert new_query == "make_interval(days := 10)"
45 | assert make_args((1, 2, 3)) == (1, 2, 3)
46 |
47 |
48 | def test_named() -> None:
49 | new_query, make_args = convert(
50 | DbApiParamstyle.NAMED.value, "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2"
51 | )
52 | expected: str = "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1"
53 | assert new_query == expected
54 | assert make_args({"f_2": 1, "f1": 2}) == (1, 2)
55 |
56 |
57 | def test_format() -> None:
58 | new_query, make_args = convert(
59 | DbApiParamstyle.FORMAT.value,
60 | "SELECT %s, %s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
61 | )
62 | expected: str = (
63 | "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%' AND c = '%' -- Comment with %"
64 | )
65 | assert new_query == expected
66 | assert make_args((1, 2, 3)) == (1, 2, 3)
67 |
68 |
69 | def test_format_multiline() -> None:
70 | new_query, make_args = convert(DbApiParamstyle.FORMAT.value, "SELECT -- Comment\n%s FROM t")
71 | assert new_query == "SELECT -- Comment\n$1 FROM t"
72 |
73 |
74 | @pytest.mark.parametrize("paramstyle", DbApiParamstyle.list())
75 | @pytest.mark.parametrize(
76 | "statement",
77 | (
78 | """
79 | EXPLAIN
80 | /* blabla
81 | something 100% with percent
82 | */
83 | SELECT {}
84 | """,
85 | """
86 | EXPLAIN
87 | /* blabla
88 | %% %s :blah %sbooze $1 %%s
89 | */
90 | SELECT {}
91 | """,
92 | """/* multiple line comment here */""",
93 | """
94 | /* this is my multi-line sql comment */
95 | select
96 | pk_id,
97 | {},
98 | -- shared_id, disabled until 12/12/2020
99 | order_date
100 | from my_table
101 | """,
102 | """/**/select {}""",
103 | """select {}
104 | /*\n
105 | some comments about the logic
106 | */ -- redo later""",
107 | r"""COMMENT ON TABLE test_schema.comment_test """ r"""IS 'the test % '' " \ table comment'""",
108 | ),
109 | )
110 | def test_multiline_single_parameter(paramstyle, statement) -> None:
111 | in_statement = statement
112 | format_char = None
113 | expected = statement.format("$1")
114 |
115 | if paramstyle == DbApiParamstyle.FORMAT.value:
116 | format_char = "%s"
117 | elif paramstyle == DbApiParamstyle.PYFORMAT.value:
118 | format_char = "%(f1)s"
119 | elif paramstyle == DbApiParamstyle.NAMED.value:
120 | format_char = ":beer"
121 | elif paramstyle == DbApiParamstyle.NUMERIC.value:
122 | format_char = ":1"
123 | elif paramstyle == DbApiParamstyle.QMARK.value:
124 | format_char = "?"
125 | in_statement = in_statement.format(format_char)
126 |
127 | new_query, make_args = convert(paramstyle, in_statement)
128 | assert new_query == expected
129 |
130 |
131 | def test_py_format() -> None:
132 | new_query, make_args = convert(
133 | DbApiParamstyle.PYFORMAT.value,
134 | "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%(f2)s AND b='75%%'",
135 | )
136 | expected: str = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND " "b='75%%'"
137 | assert new_query == expected
138 | assert make_args({"f2": 1, "f1": 2, "f3": 3}) == (1, 2)
139 |
140 | # pyformat should support %s and an array, too:
141 | new_query, make_args = convert(
142 | DbApiParamstyle.PYFORMAT.value, "SELECT %s, %s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%s AND b='75%%'"
143 | )
144 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%'"
145 | assert new_query == expected
146 | assert make_args((1, 2, 3)) == (1, 2, 3)
147 |
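
All five DB-API paramstyles funnel through convert_paramstyle, which rewrites placeholders into the server-side $1, $2, ... form and returns a callable that selects and orders the arguments. A compact sketch of the named style pinned above:

    from redshift_connector.config import DbApiParamstyle
    from redshift_connector.core import convert_paramstyle

    query, make_args = convert_paramstyle(
        DbApiParamstyle.NAMED.value, "SELECT * FROM t WHERE a=:a AND b=:b"
    )
    print(query)                        # SELECT * FROM t WHERE a=$1 AND b=$2
    print(make_args({"a": 1, "b": 2}))  # (1, 2)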
--------------------------------------------------------------------------------
/test/unit/test_type_utils.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from datetime import date, datetime, time
3 | from decimal import Decimal
4 | from enum import Enum
5 |
6 | import pytest # type: ignore
7 |
8 | from redshift_connector.config import (
9 | EPOCH,
10 | INFINITY_MICROSECONDS,
11 | MINUS_INFINITY_MICROSECONDS,
12 | )
13 | from redshift_connector.utils import type_utils
14 |
15 |
16 | @pytest.mark.parametrize("_input", [(True, b"\x01"), (False, b"\x00")])
17 | def test_bool_send(_input) -> None:
18 | in_val, exp_val = _input
19 | assert type_utils.bool_send(in_val) == exp_val
20 |
21 |
22 | @pytest.mark.parametrize("_input", [None, 1])
23 | def test_null_send(_input) -> None:
24 | assert type_utils.null_send(_input) == type_utils.NULL
25 |
26 |
27 | class Apple(Enum):
28 | macintosh: int = 1
29 | granny_smith: int = 2
30 | ambrosia: int = 3
31 |
32 |
33 | class Orange(Enum):
34 | navel: int = 1
35 | blood: int = 2
36 | cara_cara: int = 3
37 |
38 |
39 | @pytest.mark.parametrize("_input", [(Apple.macintosh, b"1"), (Orange.cara_cara, b"3")])
40 | def test_enum_out(_input) -> None:
41 | in_val, exp_val = _input
42 | assert type_utils.enum_out(in_val) == exp_val
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "_input", [(time(hour=0, minute=0, second=0), b"00:00:00"), (time(hour=12, minute=34, second=56), b"12:34:56")]
47 | )
48 | def test_time_out(_input) -> None:
49 | in_val, exp_val = _input
50 | assert type_utils.time_out(in_val) == exp_val
51 |
52 |
53 | @pytest.mark.parametrize(
54 | "_input", [(date(month=1, day=1, year=1), b"0001-01-01"), (date(month=1, day=31, year=2020), b"2020-01-31")]
55 | )
56 | def test_date_out(_input) -> None:
57 | in_val, exp_val = _input
58 | assert type_utils.date_out(in_val) == exp_val
59 |
60 |
61 | @pytest.mark.parametrize(
62 | "_input",
63 | [
64 | (Decimal(123.45678), b"123.4567799999999948568074614740908145904541015625"),
65 | (Decimal(123456789.012345), b"123456789.01234500110149383544921875"),
66 | ],
67 | )
68 | def test_numeric_out(_input) -> None:
69 | in_val, exp_val = _input
70 | assert type_utils.numeric_out(in_val) == exp_val
71 |
72 |
73 | timestamp_send_integer_data: typing.List[typing.Tuple[bytes, datetime]] = [
74 | (b"00000000", datetime.max),
75 | (b"12345678", datetime.max),
76 | (INFINITY_MICROSECONDS.to_bytes(length=8, byteorder="big"), datetime.max),
77 | (MINUS_INFINITY_MICROSECONDS.to_bytes(signed=True, length=8, byteorder="big"), datetime.min),
78 | ]
79 |
80 |
81 | @pytest.mark.parametrize("_input", timestamp_send_integer_data)
82 | def test_timestamp_recv_integer(_input) -> None:
83 | in_val, exp_val = _input
84 | # microsecond counts at or beyond the infinity sentinels resolve to
85 | # datetime.max / datetime.min rather than raising
86 | assert type_utils.timestamp_recv_integer(in_val, 0, 0) == exp_val
87 |
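
These out/send helpers serialize Python values into the bytes the driver puts on the wire. A quick sketch of the cases pinned above:

    from datetime import date, time
    from redshift_connector.utils import type_utils

    print(type_utils.bool_send(True))              # b'\x01'
    print(type_utils.date_out(date(2020, 1, 31)))  # b'2020-01-31'
    print(type_utils.time_out(time(12, 34, 56)))   # b'12:34:56'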
--------------------------------------------------------------------------------
/test/unit/test_typeobjects.py:
--------------------------------------------------------------------------------
1 | import typing
2 | import unittest
3 |
4 | import pytest
5 |
6 | from redshift_connector.interval import Interval
7 |
8 |
9 | @pytest.mark.parametrize(
10 | "kwargs, exp_months, exp_days, exp_microseconds",
11 | (
12 | (
13 | {"months": 1},
14 | 1,
15 | 0,
16 | 0,
17 | ),
18 | ({"days": 1}, 0, 1, 0),
19 | ({"microseconds": 1}, 0, 0, 1),
20 | ({"months": 1, "days": 2, "microseconds": 3}, 1, 2, 3),
21 | ),
22 | )
23 | def test_interval_constructor(kwargs, exp_months, exp_days, exp_microseconds) -> None:
24 | i = Interval(**kwargs)
25 | assert i.months == exp_months
26 | assert i.days == exp_days
27 | assert i.microseconds == exp_microseconds
28 |
29 |
30 | def test_default_constructor() -> None:
31 | i: Interval = Interval()
32 | assert i.months == 0
33 | assert i.days == 0
34 | assert i.microseconds == 0
35 |
36 |
37 | def interval_range_test(parameter, in_range, out_of_range):
38 | # values outside the field's range must raise OverflowError
39 | for v in out_of_range:
40 | with pytest.raises(OverflowError):
41 | Interval(**{parameter: v})
42 | # boundary values inside the range must construct
43 | # without raising
44 | for v in in_range:
45 | Interval(**{parameter: v})
46 |
47 |
48 | def test_interval_days_range() -> None:
49 | out_of_range_days = (
50 | -2147483649,
51 | +2147483648,
52 | )
53 | in_range_days = (
54 | -2147483648,
55 | +2147483647,
56 | )
57 | interval_range_test("days", in_range_days, out_of_range_days)
58 |
59 |
60 | def test_interval_months_range() -> None:
61 | out_of_range_months = (
62 | -2147483649,
63 | +2147483648,
64 | )
65 | in_range_months = (
66 | -2147483648,
67 | +2147483647,
68 | )
69 | interval_range_test("months", in_range_months, out_of_range_months)
70 |
71 |
72 | def test_interval_microseconds_range() -> None:
73 | out_of_range_microseconds = (
74 | -9223372036854775809,
75 | +9223372036854775808,
76 | )
77 | in_range_microseconds = (
78 | -9223372036854775808,
79 | +9223372036854775807,
80 | )
81 | interval_range_test("microseconds", in_range_microseconds, out_of_range_microseconds)
82 |
83 |
84 | @pytest.mark.parametrize(
85 | "kwargs, exp_total_seconds",
86 | (
87 | ({"months": 1}, 0),
88 | ({"days": 1}, 86400),
89 | ({"microseconds": 1}, 1e-6),
90 | ({"months": 1, "days": 2, "microseconds": 3}, 172800.000003),
91 | ),
92 | )
93 | def test_total_seconds(kwargs, exp_total_seconds) -> None:
94 | i: Interval = Interval(**kwargs)
95 | assert i.total_seconds() == exp_total_seconds
96 |
97 |
98 | def test_set_months_raises_type_error() -> None:
99 | with pytest.raises(TypeError):
100 | Interval(months="foobar") # type: ignore
101 |
102 |
103 | def test_set_days_raises_type_error() -> None:
104 | with pytest.raises(TypeError):
105 | Interval(days="foobar") # type: ignore
106 |
107 |
108 | def test_set_microseconds_raises_type_error() -> None:
109 | with pytest.raises(TypeError):
110 | Interval(microseconds="foobar") # type: ignore
111 |
112 |
113 | interval_equality_test_vals: typing.Tuple[
114 | typing.Tuple[typing.Optional[Interval], typing.Optional[Interval], bool], ...
115 | ] = (
116 | (Interval(months=1), Interval(months=1), True),
117 | (Interval(months=1), Interval(), False),
118 | (Interval(months=1), Interval(months=2), False),
119 | (Interval(days=1), Interval(days=1), True),
120 | (Interval(days=1), Interval(), False),
121 | (Interval(days=1), Interval(days=2), False),
122 | (Interval(microseconds=1), Interval(microseconds=1), True),
123 | (Interval(microseconds=1), Interval(), False),
124 | (Interval(microseconds=1), Interval(microseconds=2), False),
125 | (Interval(), Interval(), True),
126 | )
127 |
128 |
129 | @pytest.mark.parametrize("i1, i2, exp_eq", interval_equality_test_vals)
130 | def test__eq__(i1, i2, exp_eq) -> None:
131 | actual_eq = i1.__eq__(i2)
132 | assert actual_eq == exp_eq
133 |
134 |
135 | @pytest.mark.parametrize("i1, i2, exp_eq", interval_equality_test_vals)
136 | def test__neq__(i1, i2, exp_eq) -> None:
137 | exp_neq = not exp_eq
138 | actual_neq = i1.__neq__(i2)
139 | assert actual_neq == exp_neq
140 |
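
Interval mirrors Redshift's month/day/microsecond interval representation, with months and days range-checked against signed 32-bit bounds and microseconds against signed 64-bit bounds, as the range tests above show. A small sketch:

    from redshift_connector.interval import Interval

    i = Interval(months=1, days=2, microseconds=3)
    # months do not contribute to total_seconds(); only days and microseconds do
    print(i.total_seconds())  # 172800.000003

    try:
        Interval(days=2**31)  # one past the 32-bit max
    except OverflowError:
        print("days out of range")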
--------------------------------------------------------------------------------
/test/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .decorators import numpy_only, pandas_only
2 |
--------------------------------------------------------------------------------
/test/utils/decorators.py:
--------------------------------------------------------------------------------
1 | import pytest # type: ignore
2 |
3 |
4 | def is_numpy_installed() -> bool:
5 | try:
6 | import numpy # type: ignore
7 |
8 | return True
9 | except ModuleNotFoundError:
10 | return False
11 |
12 |
13 | def is_pandas_installed() -> bool:
14 | try:
15 | import pandas # type: ignore
16 |
17 | return True
18 | except ModuleNotFoundError:
19 | return False
20 |
21 |
22 | numpy_only = pytest.mark.skipif(not is_numpy_installed(), reason="requires numpy")
23 |
24 | pandas_only = pytest.mark.skipif(not is_pandas_installed(), reason="requires pandas")
25 |
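
These markers let optional-dependency tests live alongside the rest of the suite. A usage sketch (the test body is illustrative):

    from test.utils import pandas_only

    # skipped automatically when pandas is not installed
    @pandas_only
    def test_needs_pandas() -> None:
        import pandas

        assert pandas.DataFrame([{"a": 1}]).shape == (1, 1)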
--------------------------------------------------------------------------------
/tutorials/003 - Amazon Redshift Feature Support.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "# Amazon Redshift Feature Support"
12 | ]
13 | },
14 | {
15 | "cell_type": "markdown",
16 | "metadata": {},
17 | "source": [
18 | "# Overview\n",
19 | "`redshift_connector` aims to support the latest and greatest features provided by Amazon Redshift so you can get the most out of your data."
20 | ]
21 | },
22 | {
23 | "cell_type": "markdown",
24 | "metadata": {},
25 | "source": [
26 | "## COPY and UNLOAD Support - Amazon S3\n",
27 | "`redshift_connector` provides the ability to `COPY` and `UNLOAD` data from an Amazon S3 bucket. Shown below is a sample workflow which copies and unloads data from an Amazon S3 bucket"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {
33 | "pycharm": {
34 | "name": "#%% md\n"
35 | }
36 | },
37 | "source": [
38 | "1. Upload the following text file to an Amazon S3 bucket and name it `category_csv.txt`\n",
39 | "\n",
40 | "```text\n",
41 | " 12,Shows,Musicals,Musical theatre\n",
42 | " 13,Shows,Plays,\"All \"\"non-musical\"\" theatre\"\n",
43 | " 14,Shows,Opera,\"All opera, light, and \"\"rock\"\" opera\"\n",
44 | " 15,Concerts,Classical,\"All symphony, concerto, and choir concerts\"\n",
45 | "```"
46 | ]
47 | },
48 | {
49 | "cell_type": "code",
50 | "execution_count": null,
51 | "metadata": {
52 | "pycharm": {
53 | "name": "#%%\n"
54 | }
55 | },
56 | "outputs": [],
57 | "source": [
58 | "import redshift_connector\n",
59 | "\n",
60 | "with redshift_connector.connect(\n",
61 | " host='examplecluster.abc123xyz789.us-west-1.redshift.amazonaws.com',\n",
62 | " database='dev',\n",
63 | " user='awsuser',\n",
64 | " password='my_password'\n",
65 | ") as conn:\n",
66 | " with conn.cursor() as cursor:\n",
67 | " cursor.execute(\"create table category (catid int, cargroup varchar, catname varchar, catdesc varchar)\")\n",
68 | " cursor.execute(\"copy category from 's3://testing/category_csv.txt' iam_role 'arn:aws:iam::123:role/RedshiftCopyUnload' csv;\")\n",
69 | " cursor.execute(\"select * from category\")\n",
70 | " print(cursor.fetchall())\n",
71 | " cursor.execute(\"unload ('select * from category') to 's3://testing/unloaded_category_csv.txt' iam_role 'arn:aws:iam::123:role/RedshiftCopyUnload' csv;\")\n",
72 | " print('done')\n"
73 | ]
74 | },
75 | {
76 | "cell_type": "markdown",
77 | "metadata": {},
78 | "source": [
79 | "After executing the above code block, we can see the requested data was unloaded into the following file, `unloaded_category_csv.text0000_part00`, in the specified Amazon s3 bucket\n"
80 | ]
81 | }
82 | ],
83 | "metadata": {
84 | "kernelspec": {
85 | "display_name": "Python 3 (ipykernel)",
86 | "language": "python",
87 | "name": "python3"
88 | },
89 | "language_info": {
90 | "codemirror_mode": {
91 | "name": "ipython",
92 | "version": 3
93 | },
94 | "file_extension": ".py",
95 | "mimetype": "text/x-python",
96 | "name": "python",
97 | "nbconvert_exporter": "python",
98 | "pygments_lexer": "ipython3",
99 | "version": "3.9.7"
100 | }
101 | },
102 | "nbformat": 4,
103 | "nbformat_minor": 1
104 | }
105 |
--------------------------------------------------------------------------------