├── .gitchangelog.rc ├── .github ├── ISSUE_TEMPLATE.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ └── python-package.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── NOTICE ├── README.rst ├── THIRD_PARTY_LICENSES ├── build.sh ├── redshift_connector ├── __init__.py ├── auth │ ├── __init__.py │ └── aws_credentials_provider.py ├── config.py ├── core.py ├── credentials_holder.py ├── cursor.py ├── error.py ├── files │ └── redshift-ca-bundle.crt ├── iam_helper.py ├── idp_auth_helper.py ├── interval.py ├── metadataAPIHelper.py ├── metadataAPIPostProcessing.py ├── metadataServerAPIHelper.py ├── native_plugin_helper.py ├── objects.py ├── pg_types.py ├── plugin │ ├── __init__.py │ ├── adfs_credentials_provider.py │ ├── azure_credentials_provider.py │ ├── browser_azure_credentials_provider.py │ ├── browser_azure_oauth2_credentials_provider.py │ ├── browser_idc_auth_plugin.py │ ├── browser_saml_credentials_provider.py │ ├── common_credentials_provider.py │ ├── credential_provider_constants.py │ ├── i_native_plugin.py │ ├── i_plugin.py │ ├── idp_credentials_provider.py │ ├── idp_token_auth_plugin.py │ ├── jwt_credentials_provider.py │ ├── native_token_holder.py │ ├── okta_credentials_provider.py │ ├── ping_credentials_provider.py │ └── saml_credentials_provider.py ├── py.typed ├── redshift_property.py ├── utils │ ├── __init__.py │ ├── array_util.py │ ├── driver_info.py │ ├── extensible_digest.py │ ├── logging_utils.py │ ├── oids.py │ ├── sql_types.py │ └── type_utils.py └── version.py ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── test ├── __init__.py ├── conftest.py ├── integration │ ├── __init__.py │ ├── datatype │ │ ├── _generate_test_datatype_tables.py │ │ ├── test_datatypes.py │ │ └── test_system_table_queries.py │ ├── metadata │ │ ├── test_list_catalog.py │ │ ├── test_metadataAPI.py │ │ ├── test_metadataAPIHelperServer.py │ │ ├── 
test_metadataAPI_pattern_matching.py │ │ ├── test_metadataAPI_special_character_handling.py │ │ ├── test_metadataAPI_special_character_handling_case_sensitive.py │ │ ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_insensitive.py │ │ ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_insensitive_with_pipe.py │ │ ├── test_metadataAPI_special_character_handling_standard_delimited_identifier_case_sensitive.py │ │ ├── test_metadataAPI_special_character_handling_standard_identifier.py │ │ └── test_metadataAPI_sql_injection.py │ ├── plugin │ │ ├── conftest.py │ │ ├── test_azure_credentials_provider.py │ │ ├── test_credentials_providers.py │ │ └── test_okta_credentials_provider.py │ ├── test_connection.py │ ├── test_cursor.py │ ├── test_dbapi20.py │ ├── test_pandas.py │ ├── test_paramstyle.py │ ├── test_query.py │ ├── test_redshift_property.py │ └── test_unsupported_datatypes.py ├── manual │ ├── __init__.py │ ├── auth │ │ ├── test_aws_credentials.py │ │ └── test_redshift_auth_profile.py │ ├── plugin │ │ ├── __init__.py │ │ └── test_browser_credentials_provider.py │ ├── test_redshift_custom_domain.py │ └── test_redshift_serverless.py ├── performance │ ├── bulk_insert_data.csv │ ├── bulk_insert_performance.py │ ├── performance.py │ ├── protocol_perf_test.sql │ ├── protocol_performance.py │ └── test.sql ├── unit │ ├── __init__.py │ ├── auth │ │ ├── __init__.py │ │ └── test_aws_credentials_provider.py │ ├── datatype │ │ ├── __init__.py │ │ ├── test_data_in.py │ │ ├── test_oids.py │ │ └── test_sql_types.py │ ├── helpers │ │ ├── __init__.py │ │ └── idp_helpers.py │ ├── mocks │ │ ├── __init__.py │ │ ├── mock_external_credential_provider.py │ │ └── mock_socket.py │ ├── plugin │ │ ├── __init__.py │ │ ├── data │ │ │ ├── __init__.py │ │ │ ├── browser_azure_data.py │ │ │ ├── mock_adfs_saml_response.html │ │ │ ├── mock_adfs_sign_in.html │ │ │ ├── saml_response.xml │ │ │ └── saml_response_data.py │ │ ├── 
test_adfs_credentials_provider.py │ │ ├── test_azure_credentials_provider.py │ │ ├── test_azure_oauth2_credentials_provider.py │ │ ├── test_browser_azure_credentials_provider.py │ │ ├── test_browser_idc_auth_plugin.py │ │ ├── test_browser_saml_credentials_provider.py │ │ ├── test_credentials_providers.py │ │ ├── test_idp_token_auth_plugin.py │ │ ├── test_jwt_credentials_provider.py │ │ ├── test_okta_credentials_provider.py │ │ ├── test_plugin_inheritance.py │ │ └── test_saml_credentials_provider.py │ ├── test_array_util.py │ ├── test_connection.py │ ├── test_core.py │ ├── test_credentials_holder.py │ ├── test_cursor.py │ ├── test_dbapi20.py │ ├── test_driver_info.py │ ├── test_iam_helper.py │ ├── test_idp_auth_helper.py │ ├── test_import.py │ ├── test_logging_utils.py │ ├── test_metadataAPIHelper.py │ ├── test_metadataAPIPostProcessing.py │ ├── test_metadataServerAPIHelper.py │ ├── test_module.py │ ├── test_native_plugin_helper.py │ ├── test_paramstyle.py │ ├── test_redshift_property.py │ ├── test_type_utils.py │ └── test_typeobjects.py └── utils │ ├── __init__.py │ └── decorators.py └── tutorials ├── .ipynb_checkpoints └── 001 - Connecting to Amazon Redshift-checkpoint.ipynb ├── 001 - Connecting to Amazon Redshift.ipynb ├── 002 - Data Science Library Integrations.ipynb ├── 003 - Amazon Redshift Feature Support.ipynb └── 004 - Amazon Redshift Datatypes.ipynb /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Issue report 3 | about: Report an issue 4 | title: '' 5 | assignees: '' 6 | 7 | --- 8 | 9 | ## Driver version 10 | 11 | 12 | ## Redshift version 13 | 14 | 15 | ## Client Operating System 16 | 17 | 18 | ## Python version 19 | 20 | 21 | ## Table schema 22 | 23 | 24 | ## Problem description 25 | 26 | 1. Expected behaviour: 27 | 2. Actual behaviour: 28 | 3. Error message/stack trace: 29 | 4. 
Any other details that can be helpful: 30 | 31 | ## Python Driver trace logs 32 | 33 | 34 | ## Reproduction code 35 | 36 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | 6 | ## Motivation and Context 7 | 8 | 9 | 10 | ## Testing 11 | 12 | 13 | 14 | 15 | ## Screenshots (if appropriate) 16 | 17 | ## Types of changes 18 | 19 | - [ ] Bug fix (non-breaking change which fixes an issue) 20 | - [ ] New feature (non-breaking change which adds functionality) 21 | 22 | ## Checklist 23 | 24 | 25 | 26 | - [ ] Local run of `./build.sh` succeeds 27 | - [ ] Code changes have been run against the repository's pre-commit hooks 28 | - [ ] Commit messages follow [Conventional Commit Specification](https://www.conventionalcommits.org/en/v1.0.0/) 29 | - [ ] I have read the **README** document 30 | - [ ] I have added tests to cover my changes 31 | - [ ] I have run all unit tests using `pytest test/unit` and they are passing. 32 | 36 | 37 | 38 | - By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice. 
39 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" 7 | time: "09:00" 8 | timezone: "Europe/London" 9 | commit-message: 10 | prefix: "chore" 11 | prefix-development: "chore" 12 | include: "scope" 13 | open-pull-requests-limit: 8 14 | -------------------------------------------------------------------------------- /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: workflow_dispatch 7 | 8 | jobs: 9 | build: 10 | 11 | runs-on: ubuntu-latest 12 | timeout-minutes: 15 13 | 14 | strategy: 15 | fail-fast: true 16 | matrix: 17 | python-version: ["3.6", "3.9"] 18 | 19 | steps: 20 | - uses: actions/checkout@v2 21 | - name: Set up Python ${{ matrix.python-version }} 22 | uses: actions/setup-python@v2 23 | with: 24 | python-version: ${{ matrix.python-version }} 25 | - name: Install dependencies 26 | run: | 27 | python -m pip install --upgrade pip 28 | python -m pip install wheel setuptools 29 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 30 | if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi 31 | - name: Run Unit Tests 32 | run: | 33 | pytest test/unit/ 34 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- 
/.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/pre-commit/mirrors-isort 9 | rev: v5.10.1 # must be >5.0.0 for black compatibility 10 | hooks: 11 | - id: isort 12 | args: ["--profile", "black", "."] 13 | - repo: https://github.com/ambv/black 14 | rev: 23.7.0 15 | hooks: 16 | - id: black 17 | - repo: https://github.com/pre-commit/mirrors-mypy 18 | rev: v1.4.1 19 | hooks: 20 | - id: mypy 21 | additional_dependencies: [types-setuptools, types-requests, types-python-dateutil] 22 | args: [--ignore-missing-imports, --disable-error-code, "annotation-unchecked"] 23 | verbose: true 24 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 
8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check [existing open](https://github.com/aws/amazon-redshift-python-driver/issues), or [recently closed](https://github.com/aws/amazon-redshift-python-driver/issues?utf8=%E2%9C%93&q=is%3Aissue%20is%3Aclosed%20), issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *master* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Pre-commit hooks have been run. 35 | 4. Ensure unit tests are passing by running `pytest test/unit` 36 | 5. Commit to your fork using clear commit messages that follow [Conventional Commit](https://www.conventionalcommits.org/en/v1.0.0/) specification. 37 | 6. Send us a pull request, answering any default questions in the pull request interface. 38 | 7. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 
39 | 40 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 41 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 42 | 43 | 44 | ## Finding contributions to work on 45 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 46 | 47 | 48 | ## Code of Conduct 49 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 50 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 51 | opensource-codeofconduct@amazon.com with any additional questions or comments. 52 | 53 | 54 | ## Security issue notifications 55 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 56 | 57 | 58 | ## Licensing 59 | 60 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 61 | 62 | We may ask you to sign a [Contributor License Agreement (CLA)](http://en.wikipedia.org/wiki/Contributor_License_Agreement) for larger changes. 63 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include redshift_connector/files/* 2 | recursive-include files *.crt -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
import logging
import typing

_logger: logging.Logger = logging.getLogger(__name__)

# NOTE(review): AWSDirectCredentialsHolder, AWSProfileCredentialsHolder and
# InterfaceError are imported at module level from
# redshift_connector.credentials_holder / redshift_connector.error in the
# original file; annotations below use string forward references so the class
# definition itself has no import-time dependency on them.


class AWSCredentialsProvider:
    """
    A credential provider class for AWS credentials specified via :func:`~redshift_connector.connect` using `profile` or AWS access keys.
    """

    def __init__(self: "AWSCredentialsProvider") -> None:
        # Holders previously built by refresh(), keyed by get_cache_key().
        self.cache: typing.Dict[
            int, typing.Union["AWSDirectCredentialsHolder", "AWSProfileCredentialsHolder"]
        ] = {}

        self.access_key_id: typing.Optional[str] = None
        self.secret_access_key: typing.Optional[str] = None
        self.session_token: typing.Optional[str] = None
        # Fix: this attribute holds the AWS profile *name* (a str). It is
        # hashed in get_cache_key() and passed to boto3.Session(profile_name=...)
        # in refresh(); the original Optional["boto3.Session"] annotation was wrong.
        self.profile: typing.Optional[str] = None

    def get_cache_key(self: "AWSCredentialsProvider") -> int:
        """
        Creates a cache key from the hash of either the end-user provided AWS
        profile name or, when no profile is set, the AWS access key id.

        Returns
        -------
        An `int` hash representation of the non-secret portion of credential information: `int`
        """
        if self.profile:
            return hash(self.profile)
        else:
            return hash(self.access_key_id)

    def get_credentials(
        self: "AWSCredentialsProvider",
    ) -> typing.Union["AWSDirectCredentialsHolder", "AWSProfileCredentialsHolder"]:
        """
        Retrieves a :class:`ABCCredentialsHolder` from cache or builds one.

        Returns
        -------
        An `AWSCredentialsHolder` object containing end-user specified AWS credential information: :class:`ABCAWSCredentialsHolder`

        Raises
        ------
        InterfaceError
            If refreshing credentials fails, or no credentials could be cached.
        """
        key: int = self.get_cache_key()
        if key not in self.cache:
            try:
                self.refresh()
            except Exception as e:
                exec_msg: str = "Refreshing IdP credentials failed"
                _logger.debug(exec_msg)
                raise InterfaceError(e)

        # Fix: use .get() so a missing entry surfaces as the InterfaceError
        # below rather than an unhandled KeyError.
        credentials: typing.Optional[
            typing.Union["AWSDirectCredentialsHolder", "AWSProfileCredentialsHolder"]
        ] = self.cache.get(key)

        if credentials is None:
            exec_msg = "Unable to load AWS credentials from cache"
            _logger.debug(exec_msg)
            raise InterfaceError(exec_msg)

        return credentials

    def add_parameter(self: "AWSCredentialsProvider", info: "RedshiftProperty") -> None:
        """
        Defines instance variables used for creating a :class:`ABCCredentialsHolder` object and associated :class:`boto3.Session`

        Parameters
        ----------
        info : :class:`RedshiftProperty`
            The :class:`RedshiftProperty` object created using end-user specified values passed to :func:`~redshift_connector.connect`
        """
        self.access_key_id = info.access_key_id
        self.secret_access_key = info.secret_access_key
        self.session_token = info.session_token
        self.profile = info.profile

    def refresh(self: "AWSCredentialsProvider") -> None:
        """
        Establishes a :class:`boto3.Session` using end-user specified AWS credential information
        """
        import boto3  # type: ignore

        args: typing.Dict[str, str] = {}

        if self.profile is not None:
            args["profile_name"] = self.profile
        elif self.access_key_id is not None:
            args["aws_access_key_id"] = self.access_key_id
            args["aws_secret_access_key"] = typing.cast(str, self.secret_access_key)
            if self.session_token is not None:
                args["aws_session_token"] = self.session_token

        session: "boto3.Session" = boto3.Session(**args)
        credentials: typing.Optional[typing.Union["AWSProfileCredentialsHolder", "AWSDirectCredentialsHolder"]] = None

        if self.profile is not None:
            credentials = AWSProfileCredentialsHolder(profile=self.profile, session=session)
        else:
            credentials = AWSDirectCredentialsHolder(
                access_key_id=typing.cast(str, self.access_key_id),
                secret_access_key=typing.cast(str, self.secret_access_key),
                session_token=self.session_token,
                session=session,
            )

        key = self.get_cache_key()
        self.cache[key] = credentials
class Warning(Exception):
    """Raised for important database warnings such as data truncation.
    Currently unused by this driver.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class Error(Exception):
    """Root of the driver's exception hierarchy; every other error
    exception derives from this class.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class InterfaceError(Error):
    """Raised for problems in the database interface itself rather than in
    the database — e.g. the interface requests SSL but the server refuses.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class DatabaseError(Error):
    """Raised for errors related to the database. Currently never raised
    directly.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class DataError(DatabaseError):
    """Raised for problems with the processed data. Currently never raised.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class OperationalError(DatabaseError):
    """Raised for errors in the database's operation that are not
    necessarily under the programmer's control. Currently never raised.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class IntegrityError(DatabaseError):
    """Raised when the relational integrity of the database is affected.
    Currently never raised.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class InternalError(DatabaseError):
    """Raised when the database hits an internal error. In practice this
    signals unexpected state inside the interface — typically an interface bug.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class ProgrammingError(DatabaseError):
    """Raised for programming errors — for example when a query string
    contains more parameter fields than parameters were supplied.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class NotSupportedError(DatabaseError):
    """Raised when a method or database API is used that the database does
    not support.

    Part of the `DBAPI 2.0 specification
    `_.
    """


class ArrayContentNotSupportedError(NotSupportedError):
    """Raised when transmitting an array whose base type the interface does
    not support for binary data transfer.
    """


class ArrayContentNotHomogenousError(ProgrammingError):
    """Raised when transmitting an array holding more than a single type of
    object.
    """


class ArrayDimensionsNotConsistentError(ProgrammingError):
    """Raised when transmitting an array whose multi-dimension sizes are
    inconsistent.
    """


MISSING_MODULE_ERROR_MSG: str = (
    "redshift_connector requires {module} support for this functionality. "
    "Please install redshift_connector[full] for {module} support"
)
" 147 | "Please install redshift_connector[full] for {module} support" 148 | ) 149 | -------------------------------------------------------------------------------- /redshift_connector/native_plugin_helper.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | 4 | from redshift_connector.error import InterfaceError 5 | from redshift_connector.idp_auth_helper import IdpAuthHelper 6 | from redshift_connector.plugin.i_native_plugin import INativePlugin 7 | 8 | if typing.TYPE_CHECKING: 9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder 10 | from redshift_connector.redshift_property import RedshiftProperty 11 | 12 | logging.getLogger(__name__).addHandler(logging.NullHandler()) 13 | _logger: logging.Logger = logging.getLogger(__name__) 14 | 15 | 16 | class NativeAuthPluginHelper: 17 | @staticmethod 18 | def set_native_auth_plugin_properties(info: "RedshiftProperty") -> None: 19 | """ 20 | Modifies ``info`` to prepare for authentication with Amazon Redshift 21 | 22 | Parameters 23 | ---------- 24 | info: RedshiftProperty 25 | RedshiftProperty object storing user defined and derived attributes used for authentication 26 | 27 | Returns 28 | ------- 29 | None:None 30 | """ 31 | if info.credentials_provider: 32 | # include the authentication token which will be used for authentication via 33 | # Redshift Native IDP Integration 34 | _logger.debug("Attempting to get native auth plugin credentials") 35 | idp_token: str = NativeAuthPluginHelper.get_native_auth_plugin_credentials(info) 36 | if idp_token: 37 | _logger.debug("setting info.web_identity_token") 38 | info.put("web_identity_token", idp_token) 39 | 40 | @staticmethod 41 | def get_native_auth_plugin_credentials(info: "RedshiftProperty") -> str: 42 | """ 43 | Retrieves credentials for Amazon Redshift native authentication. 
44 | 45 | Parameters 46 | ---------- 47 | info: RedshiftProperty 48 | RedshiftProperty object storing user defined and derived attributes used for authentication 49 | 50 | Returns 51 | ------- 52 | str: An authentication token compatible with Redshift Native IDP Integration (code 14) 53 | """ 54 | idp_token: typing.Optional[str] = None 55 | provider = None 56 | 57 | if info.credentials_provider: 58 | provider = IdpAuthHelper.load_credentials_provider(info) 59 | 60 | if not isinstance(provider, INativePlugin): 61 | _logger.debug("Native auth will not be used, no credentials provider specified") 62 | return "" 63 | else: 64 | raise InterfaceError( 65 | "Required credentials_provider parameter is null or empty: {}".format(info.credentials_provider) 66 | ) 67 | 68 | _logger.debug("Native IDP Credential Provider %s:%s", provider, info.credentials_provider) 69 | _logger.debug("Calling provider.getCredentials()") 70 | 71 | # Provider will cache the credentials, it's OK to call get_credentials() here 72 | credentials: "NativeTokenHolder" = typing.cast("NativeTokenHolder", provider.get_credentials()) 73 | 74 | _logger.debug("credentials is None = %s", credentials is None) 75 | _logger.debug("credentials.is_expired() = %s", credentials.is_expired()) 76 | 77 | if credentials is None or (credentials.expiration is not None and credentials.is_expired()): 78 | # get idp token 79 | plugin: INativePlugin = provider 80 | _logger.debug("Unable to get IdP token from cache. 
from datetime import date
from datetime import datetime as Datetime
from datetime import time
from time import localtime


def Date(year: int, month: int, day: int) -> date:
    """Construct an object holding a date value.

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.date`
    """
    # Fix: corrected "Constuct" typo in the docstring above.
    return date(year, month, day)


def Time(hour: int, minute: int, second: int) -> time:
    """Construct an object holding a time value.

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.time`
    """
    return time(hour, minute, second)


def Timestamp(year: int, month: int, day: int, hour: int, minute: int, second: int) -> Datetime:
    """Construct an object holding a timestamp value.

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.datetime`
    """
    return Datetime(year, month, day, hour, minute, second)


def DateFromTicks(ticks: float) -> date:
    """Construct an object holding a date value from the given ticks value
    (number of seconds since the epoch).

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.date`
    """
    return Date(*localtime(ticks)[:3])


def TimeFromTicks(ticks: float) -> time:
    """Construct an object holding a time value from the given ticks value
    (number of seconds since the epoch).

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.time`
    """
    # Fix: corrected "objet" typo in the docstring above.
    return Time(*localtime(ticks)[3:6])


def TimestampFromTicks(ticks: float) -> Datetime:
    """Construct an object holding a timestamp value from the given ticks value
    (number of seconds since the epoch).

    This function is part of the `DBAPI 2.0 specification
    `_.

    :rtype: :class:`datetime.datetime`
    """
    return Timestamp(*localtime(ticks)[:6])


def Binary(value: bytes):
    """Construct an object holding binary data.

    This function is part of the `DBAPI 2.0 specification
    `_.

    """
    return value
from json import dumps


class PGType:
    """Wrapper tagging a value so it is transmitted as a specific PG type."""

    def __init__(self: "PGType", value) -> None:
        self.value: str = value

    def encode(self, encoding) -> bytes:
        # Default wire form: the value's string representation.
        text_form = str(self.value)
        return text_form.encode(encoding)


class PGEnum(PGType):
    """Wrapper for enum values; accepts either a plain string or an Enum member."""

    def __init__(self: "PGEnum", value) -> None:
        # An Enum member carries its payload in .value; a str is used as-is.
        self.value = value if isinstance(value, str) else value.value


class PGJson(PGType):
    """Wrapper serializing its value as JSON text."""

    def encode(self: "PGJson", encoding: str) -> bytes:
        serialized = dumps(self.value)
        return serialized.encode(encoding)


class PGJsonb(PGType):
    """Wrapper serializing its value as JSONB-bound JSON text."""

    def encode(self: "PGJsonb", encoding: str) -> bytes:
        serialized = dumps(self.value)
        return serialized.encode(encoding)


class PGTsvector(PGType):
    """Marker type for tsvector values; inherits the default encoding."""


class PGVarchar(str):
    """Marker subclass of str for varchar-typed parameters."""


class PGText(str):
    """Marker subclass of str for text-typed parameters."""
Identity Provider Browser Plugin providing multi-factor authentication access to an Amazon Redshift cluster using an identity provider of your choice. 19 | """ 20 | 21 | def __init__(self: "BrowserSamlCredentialsProvider") -> None: 22 | super().__init__() 23 | self.login_url: typing.Optional[str] = None 24 | 25 | self.idp_response_timeout: int = 120 26 | self.listen_port: int = 7890 27 | 28 | # method to grab the field parameters specified by end user. 29 | # This method adds to it specific parameters. 30 | def add_parameter(self: "BrowserSamlCredentialsProvider", info: RedshiftProperty) -> None: 31 | super().add_parameter(info) 32 | self.login_url = info.login_url 33 | 34 | self.idp_response_timeout = info.idp_response_timeout 35 | self.listen_port = info.listen_port 36 | 37 | # Required method to grab the SAML Response. Used in base class to refresh temporary credentials. 38 | def get_saml_assertion(self: "BrowserSamlCredentialsProvider") -> str: 39 | _logger.debug("BrowserSamlCredentialsProvider.get_saml_assertion") 40 | 41 | if self.login_url == "" or self.login_url is None: 42 | BrowserSamlCredentialsProvider.handle_missing_required_property("login_url") 43 | 44 | if self.idp_response_timeout < 10: 45 | BrowserSamlCredentialsProvider.handle_invalid_property_value( 46 | "idp_response_timeout", "Must be 10 seconds or greater" 47 | ) 48 | if self.listen_port < 1 or self.listen_port > 65535: 49 | BrowserSamlCredentialsProvider.handle_invalid_property_value("listen_port", "Must be in range [1,65535]") 50 | 51 | return self.authenticate() 52 | 53 | # Authentication consists of: 54 | # Start the Socket Server on the port {@link BrowserSamlCredentialsProvider#m_listen_port}. 55 | # Open the default browser with the link asking a User to enter the credentials. 56 | # Retrieve the SAML Assertion string from the response. 
57 | def authenticate(self: "BrowserSamlCredentialsProvider") -> str: 58 | _logger.debug("BrowserSamlCredentialsProvider.authenticate") 59 | 60 | try: 61 | with concurrent.futures.ThreadPoolExecutor() as executor: 62 | _logger.debug("Listening for connection on port %s", self.listen_port) 63 | future = executor.submit(self.run_server, self.listen_port, self.idp_response_timeout) 64 | self.open_browser() 65 | return_value: str = future.result() 66 | 67 | samlresponse = urllib.parse.unquote(return_value) 68 | return str(samlresponse) 69 | except socket.error as e: 70 | _logger.debug("Socket error: %s", e) 71 | raise e 72 | except Exception as e: 73 | _logger.debug("Other Exception: %s", e) 74 | raise e 75 | 76 | def run_server(self: "BrowserSamlCredentialsProvider", listen_port: int, idp_response_timeout: int) -> str: 77 | _logger.debug("BrowserSamlCredentialsProvider.run_server") 78 | HOST: str = "127.0.0.1" 79 | PORT: int = listen_port 80 | 81 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 82 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 83 | _logger.debug("attempting socket bind on host %s port %s", HOST, PORT) 84 | s.bind((HOST, PORT)) 85 | s.listen() 86 | conn, addr = s.accept() # typing.Tuple[Socket, Any] 87 | _logger.debug("Localhost socket connection established for Browser SAML IdP") 88 | conn.settimeout(float(idp_response_timeout)) 89 | size: int = 102400 90 | with conn: 91 | while True: 92 | part: bytes = conn.recv(size) 93 | decoded_part: str = part.decode() 94 | result: typing.Optional[typing.Match] = re.search( 95 | pattern="SAMLResponse[:=]+[\\n\\r]*", string=decoded_part, flags=re.MULTILINE 96 | ) 97 | _logger.debug("Data received contained SAML Response: %s", result is not None) 98 | 99 | if result is not None: 100 | conn.send(self.close_window_http_resp()) 101 | saml_resp_block: str = decoded_part[result.regs[0][1] :] 102 | end_idx: int = saml_resp_block.find("&RelayState=") 103 | if end_idx > -1: 104 | saml_resp_block = 
saml_resp_block[:end_idx] 105 | return saml_resp_block 106 | 107 | # Opens the default browser with the authorization request to the web service. 108 | def open_browser(self: "BrowserSamlCredentialsProvider") -> None: 109 | _logger.debug("BrowserSamlCredentialsProvider.open_browser") 110 | import webbrowser 111 | 112 | url: typing.Optional[str] = self.login_url 113 | 114 | if url is None: 115 | BrowserSamlCredentialsProvider.handle_missing_required_property("login_url") 116 | self.validate_url(typing.cast(str, url)) 117 | webbrowser.open(typing.cast(str, url)) 118 | -------------------------------------------------------------------------------- /redshift_connector/plugin/common_credentials_provider.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | from abc import abstractmethod 4 | 5 | from redshift_connector.error import InterfaceError 6 | from redshift_connector.iam_helper import IamHelper 7 | from redshift_connector.plugin.i_native_plugin import INativePlugin 8 | from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider 9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder 10 | from redshift_connector.redshift_property import RedshiftProperty 11 | 12 | _logger: logging.Logger = logging.getLogger(__name__) 13 | 14 | 15 | class CommonCredentialsProvider(INativePlugin, IdpCredentialsProvider): 16 | """ 17 | Abstract base class for authentication plugins using IdC authentication. 
18 | """ 19 | 20 | def __init__(self: "CommonCredentialsProvider") -> None: 21 | super().__init__() 22 | self.last_refreshed_credentials: typing.Optional[NativeTokenHolder] = None 23 | 24 | @abstractmethod 25 | def get_auth_token(self: "CommonCredentialsProvider") -> str: 26 | """ 27 | Returns the auth token retrieved from corresponding plugin 28 | """ 29 | pass # pragma: no cover 30 | 31 | def add_parameter( 32 | self: "CommonCredentialsProvider", 33 | info: RedshiftProperty, 34 | ) -> None: 35 | self.disable_cache = True 36 | 37 | def get_credentials(self: "CommonCredentialsProvider") -> NativeTokenHolder: 38 | credentials: typing.Optional[NativeTokenHolder] = None 39 | 40 | if not self.disable_cache: 41 | key = self.get_cache_key() 42 | credentials = typing.cast(NativeTokenHolder, self.cache.get(key)) 43 | 44 | if not credentials or credentials.is_expired(): 45 | if self.disable_cache: 46 | _logger.debug("Auth token Cache disabled : fetching new token") 47 | else: 48 | _logger.debug("Auth token Cache enabled - No auth token found from cache : fetching new token") 49 | 50 | self.refresh() 51 | 52 | if self.disable_cache: 53 | credentials = self.last_refreshed_credentials 54 | self.last_refreshed_credentials = None 55 | else: 56 | credentials.refresh = False 57 | _logger.debug("Auth token found from cache") 58 | 59 | if not self.disable_cache: 60 | credentials = typing.cast(NativeTokenHolder, self.cache[key]) 61 | return typing.cast(NativeTokenHolder, credentials) 62 | 63 | def refresh(self: "CommonCredentialsProvider") -> None: 64 | auth_token: str = self.get_auth_token() 65 | _logger.debug("auth token: {}".format(auth_token)) 66 | 67 | if auth_token is None: 68 | raise InterfaceError("IdC authentication failed : An error occurred during the request.") 69 | 70 | credentials: NativeTokenHolder = NativeTokenHolder(access_token=auth_token, expiration=None) 71 | credentials.refresh = True 72 | 73 | _logger.debug("disable_cache={}".format(str(self.disable_cache))) 74 
| if not self.disable_cache: 75 | self.cache[self.get_cache_key()] = credentials 76 | else: 77 | self.last_refreshed_credentials = credentials 78 | 79 | def get_idp_token(self: "CommonCredentialsProvider") -> str: 80 | auth_token: str = self.get_auth_token() 81 | return auth_token 82 | 83 | def set_group_federation(self: "CommonCredentialsProvider", group_federation: bool): 84 | pass 85 | 86 | def get_sub_type(self: "CommonCredentialsProvider") -> int: 87 | return IamHelper.IDC_PLUGIN 88 | 89 | def get_cache_key(self: "CommonCredentialsProvider") -> str: 90 | return "" 91 | -------------------------------------------------------------------------------- /redshift_connector/plugin/credential_provider_constants.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | azure_headers: typing.Dict[str, str] = { 4 | "Content-Type": "application/x-www-form-urlencoded", 5 | "Accept": "application/json", 6 | } 7 | okta_headers: typing.Dict[str, str] = { 8 | "Accept": "application/json", 9 | "Content-Type": "application/json", 10 | "Cache-Control": "no-cache", 11 | } 12 | # order of preference when searching for attributes in SAML response 13 | SAML_RESP_NAMESPACES: typing.List[str] = ["saml2:", "saml:", ""] 14 | -------------------------------------------------------------------------------- /redshift_connector/plugin/i_native_plugin.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | 3 | from redshift_connector.plugin.i_plugin import IPlugin 4 | 5 | 6 | class INativePlugin(IPlugin): 7 | """ 8 | Abstract base class for authentication plugins using Redshift Native authentication 9 | """ 10 | 11 | @abstractmethod 12 | def get_idp_token(self: "INativePlugin") -> str: 13 | """ 14 | Returns the IdP token retrieved after authenticating with the plugin. 
15 | """ 16 | pass # pragma: no cover 17 | -------------------------------------------------------------------------------- /redshift_connector/plugin/i_plugin.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | from abc import ABC, abstractmethod 4 | 5 | if typing.TYPE_CHECKING: 6 | from redshift_connector.credentials_holder import CredentialsHolder 7 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder 8 | from redshift_connector.redshift_property import RedshiftProperty 9 | 10 | _logger: logging.Logger = logging.getLogger(__name__) 11 | 12 | 13 | class IPlugin(ABC): 14 | @abstractmethod 15 | def add_parameter(self: "IPlugin", info: "RedshiftProperty"): 16 | """ 17 | Defines instance attributes from the :class:RedshiftProperty object associated with the current authentication session. 18 | """ 19 | pass 20 | 21 | @abstractmethod 22 | def get_cache_key(self: "IPlugin") -> str: 23 | pass 24 | 25 | @abstractmethod 26 | def get_credentials(self: "IPlugin") -> typing.Union["NativeTokenHolder", "CredentialsHolder"]: 27 | """ 28 | Returns a :class:NativeTokenHolder object associated with the current plugin. 29 | """ 30 | pass # pragma: no cover 31 | 32 | @abstractmethod 33 | def get_sub_type(self: "IPlugin") -> int: 34 | """ 35 | Returns a type code indicating the type of the current plugin. 36 | 37 | See :class:IdpAuthHelper for possible return values 38 | 39 | """ 40 | pass # pragma: no cover 41 | 42 | @abstractmethod 43 | def refresh(self: "IPlugin") -> None: 44 | """ 45 | Refreshes the credentials, stored in :class:NativeTokenHolder, for the current plugin. 
46 | """ 47 | pass # pragma: no cover 48 | 49 | @abstractmethod 50 | def set_group_federation(self: "IPlugin", group_federation: bool): 51 | pass 52 | 53 | @staticmethod 54 | def handle_missing_required_property(property_name: str) -> None: 55 | from redshift_connector import InterfaceError 56 | 57 | error_msg: str = "Missing required connection property: {}".format(property_name) 58 | _logger.debug(error_msg) 59 | raise InterfaceError(error_msg) 60 | 61 | @staticmethod 62 | def handle_invalid_property_value(property_name: str, reason: str) -> None: 63 | from redshift_connector import InterfaceError 64 | 65 | error_msg: str = "Invalid value specified for connection property: {}. {}".format(property_name, reason) 66 | _logger.debug(error_msg) 67 | raise InterfaceError(error_msg) 68 | -------------------------------------------------------------------------------- /redshift_connector/plugin/idp_credentials_provider.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import abstractmethod 3 | 4 | from redshift_connector.error import InterfaceError 5 | from redshift_connector.plugin.i_plugin import IPlugin 6 | from redshift_connector.redshift_property import IAM_URL_PATTERN, RedshiftProperty 7 | 8 | if typing.TYPE_CHECKING: 9 | from redshift_connector.credentials_holder import ABCCredentialsHolder 10 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder 11 | 12 | 13 | class IdpCredentialsProvider(IPlugin): 14 | """ 15 | Abstract base class for authentication plugins. 16 | """ 17 | 18 | def __init__(self: "IdpCredentialsProvider") -> None: 19 | self.cache: typing.Dict[str, typing.Union[NativeTokenHolder, ABCCredentialsHolder]] = {} 20 | 21 | @staticmethod 22 | def close_window_http_resp() -> bytes: 23 | """ 24 | Builds the client facing HTML contents notifying that the authentication window may be closed. 
25 | """ 26 | return str.encode( 27 | "HTTP/1.1 200 OK\nContent-Type: text/html\n\n" 28 | + "Thank you for using Amazon Redshift! You can now close this window.\n" 29 | ) 30 | 31 | @abstractmethod 32 | def check_required_parameters(self: "IdpCredentialsProvider") -> None: 33 | """ 34 | Performs validation on client provided parameters used by the IdP. 35 | """ 36 | pass # pragma: no cover 37 | 38 | @classmethod 39 | def validate_url(cls, uri: str) -> None: 40 | import re 41 | 42 | if not re.fullmatch(pattern=IAM_URL_PATTERN, string=str(uri)): 43 | raise InterfaceError("URI: {} is an invalid web address".format(uri)) 44 | -------------------------------------------------------------------------------- /redshift_connector/plugin/idp_token_auth_plugin.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | 4 | from redshift_connector.error import InterfaceError 5 | from redshift_connector.plugin.common_credentials_provider import ( 6 | CommonCredentialsProvider, 7 | ) 8 | from redshift_connector.redshift_property import RedshiftProperty 9 | 10 | logging.getLogger(__name__).addHandler(logging.NullHandler()) 11 | _logger: logging.Logger = logging.getLogger(__name__) 12 | 13 | 14 | class IdpTokenAuthPlugin(CommonCredentialsProvider): 15 | """ 16 | A basic IdP Token auth plugin class. This plugin class allows clients to directly provide any auth token that is handled by Redshift. 
17 | """ 18 | 19 | def __init__(self: "IdpTokenAuthPlugin") -> None: 20 | super().__init__() 21 | self.token: typing.Optional[str] = None 22 | self.token_type: typing.Optional[str] = None 23 | 24 | def add_parameter( 25 | self: "IdpTokenAuthPlugin", 26 | info: RedshiftProperty, 27 | ) -> None: 28 | super().add_parameter(info) 29 | self.token = info.token 30 | self.token_type = info.token_type 31 | _logger.debug("Setting token_type = {}".format(self.token_type)) 32 | 33 | def check_required_parameters(self: "IdpTokenAuthPlugin") -> None: 34 | super().check_required_parameters() 35 | if not self.token: 36 | _logger.error("IdC authentication failed: token needs to be provided in connection params") 37 | raise InterfaceError("IdC authentication failed: The token must be included in the connection parameters.") 38 | if not self.token_type: 39 | _logger.error("IdC authentication failed: token_type needs to be provided in connection params") 40 | raise InterfaceError( 41 | "IdC authentication failed: The token type must be included in the connection parameters." 
42 | ) 43 | 44 | def get_cache_key(self: "IdpTokenAuthPlugin") -> str: # type: ignore 45 | pass 46 | 47 | def get_auth_token(self: "IdpTokenAuthPlugin") -> str: 48 | self.check_required_parameters() 49 | return typing.cast(str, self.token) 50 | -------------------------------------------------------------------------------- /redshift_connector/plugin/jwt_credentials_provider.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import typing 3 | from abc import abstractmethod 4 | 5 | from redshift_connector.error import InterfaceError 6 | from redshift_connector.iam_helper import IamHelper 7 | from redshift_connector.plugin.i_native_plugin import INativePlugin 8 | from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider 9 | from redshift_connector.plugin.native_token_holder import NativeTokenHolder 10 | from redshift_connector.redshift_property import RedshiftProperty 11 | 12 | _logger: logging.Logger = logging.getLogger(__name__) 13 | 14 | 15 | class JwtCredentialsProvider(INativePlugin, IdpCredentialsProvider): 16 | """ 17 | Abstract base class for authentication plugins using JWT for Redshift native authentication. 
18 | """ 19 | 20 | KEY_ROLE_ARN: str = "role_arn" 21 | KEY_WEB_IDENTITY_TOKEN: str = "web_identity_token" 22 | KEY_DURATION: str = "duration" 23 | KEY_ROLE_SESSION_NAME: str = "role_session_name" 24 | DEFAULT_ROLE_SESSION_NAME: str = "jwt_redshift_session" 25 | 26 | def __init__(self: "JwtCredentialsProvider") -> None: 27 | super().__init__() 28 | self.last_refreshed_credentials: typing.Optional[NativeTokenHolder] = None 29 | 30 | @abstractmethod 31 | def get_jwt_assertion(self: "JwtCredentialsProvider") -> str: 32 | """ 33 | Returns the jwt assertion retrieved following IdP authentication 34 | """ 35 | pass # pragma: no cover 36 | 37 | def add_parameter( 38 | self: "JwtCredentialsProvider", 39 | info: RedshiftProperty, 40 | ) -> None: 41 | self.provider_name = info.provider_name 42 | self.ssl_insecure = info.ssl_insecure 43 | self.disable_cache = info.iam_disable_cache 44 | self.group_federation = False 45 | 46 | if info.role_session_name is not None: 47 | self.role_session_name = info.role_session_name 48 | 49 | def set_group_federation(self: "JwtCredentialsProvider", group_federation: bool): 50 | self.group_federation = group_federation 51 | 52 | def get_credentials(self: "JwtCredentialsProvider") -> NativeTokenHolder: 53 | _logger.debug("JwtCredentialsProvider.get_credentials") 54 | credentials: typing.Optional[NativeTokenHolder] = None 55 | 56 | if not self.disable_cache: 57 | _logger.debug("checking cache for credentials") 58 | key = self.get_cache_key() 59 | credentials = typing.cast(NativeTokenHolder, self.cache.get(key)) 60 | 61 | if not credentials or credentials.is_expired(): 62 | _logger.debug("JWT get_credentials NOT from cache") 63 | self.refresh() 64 | 65 | if self.disable_cache: 66 | credentials = self.last_refreshed_credentials 67 | self.last_refreshed_credentials = None 68 | else: 69 | credentials.refresh = False 70 | _logger.debug("JWT get_credentials from cache") 71 | 72 | if not self.disable_cache: 73 | credentials = 
typing.cast(NativeTokenHolder, self.cache[key]) 74 | return typing.cast(NativeTokenHolder, credentials) 75 | 76 | def refresh(self: "JwtCredentialsProvider") -> None: 77 | _logger.debug("JwtCredentialsProvider.refresh") 78 | jwt: str = self.get_jwt_assertion() 79 | 80 | if jwt is None: 81 | exec_msg = "Unable to refresh, no jwt provided" 82 | _logger.debug(exec_msg) 83 | raise InterfaceError(exec_msg) 84 | 85 | credentials: NativeTokenHolder = NativeTokenHolder(access_token=jwt, expiration=None) 86 | credentials.refresh = True 87 | 88 | _logger.debug("disable_cache=%s", self.disable_cache) 89 | if not self.disable_cache: 90 | self.cache[self.get_cache_key()] = credentials 91 | 92 | else: 93 | self.last_refreshed_credentials = credentials 94 | 95 | def do_verify_ssl_cert(self: "JwtCredentialsProvider") -> bool: 96 | return not self.ssl_insecure 97 | 98 | def get_idp_token(self: "JwtCredentialsProvider") -> str: 99 | jwt: str = self.get_jwt_assertion() 100 | 101 | return jwt 102 | 103 | def get_sub_type(self: "JwtCredentialsProvider") -> int: 104 | return IamHelper.JWT_PLUGIN 105 | 106 | 107 | class BasicJwtCredentialsProvider(JwtCredentialsProvider): 108 | """ 109 | A basic JWT Credential provider class that can be changed and implemented to work with any desired JWT service provider. 
110 | """ 111 | 112 | def __init__(self: "BasicJwtCredentialsProvider") -> None: 113 | super().__init__() 114 | self.jwt: typing.Optional[str] = None 115 | 116 | def add_parameter( 117 | self: "BasicJwtCredentialsProvider", 118 | info: RedshiftProperty, 119 | ) -> None: 120 | super().add_parameter(info) 121 | self.jwt = info.web_identity_token 122 | 123 | if info.role_session_name is not None: 124 | self.role_session_name = info.role_session_name 125 | 126 | def check_required_parameters(self: "BasicJwtCredentialsProvider") -> None: 127 | super().check_required_parameters() 128 | if not self.jwt: 129 | BasicJwtCredentialsProvider.handle_missing_required_property("jwt") 130 | 131 | def get_cache_key(self: "BasicJwtCredentialsProvider") -> str: 132 | return typing.cast(str, self.jwt) 133 | 134 | def get_jwt_assertion(self: "BasicJwtCredentialsProvider") -> str: 135 | self.check_required_parameters() 136 | return self.jwt # type: ignore 137 | -------------------------------------------------------------------------------- /redshift_connector/plugin/native_token_holder.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | 4 | from dateutil.tz import tzutc 5 | 6 | 7 | class NativeTokenHolder: 8 | """ 9 | Holds Redshift Native authentication credentials. 10 | """ 11 | 12 | def __init__(self: "NativeTokenHolder", access_token: str, expiration: typing.Optional[str]): 13 | self.access_token: str = access_token 14 | self.expiration = expiration 15 | self.refresh: bool = False # True means newly added, false means from cache 16 | 17 | def is_expired(self: "NativeTokenHolder") -> bool: 18 | """ 19 | Returns boolean value indicating if the Redshift native authentication credentials have expired. 
20 | """ 21 | return self.expiration is None or typing.cast(datetime.datetime, self.expiration) < datetime.datetime.now( 22 | tz=tzutc() 23 | ) 24 | -------------------------------------------------------------------------------- /redshift_connector/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/redshift_connector/py.typed -------------------------------------------------------------------------------- /redshift_connector/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .array_util import ( 2 | array_check_dimensions, 3 | array_dim_lengths, 4 | array_find_first_element, 5 | array_flatten, 6 | array_has_null, 7 | walk_array, 8 | ) 9 | from .driver_info import DriverInfo 10 | from .logging_utils import make_divider_block, mask_secure_info_in_props 11 | from .type_utils import ( 12 | FC_BINARY, 13 | FC_TEXT, 14 | NULL, 15 | NULL_BYTE, 16 | array_recv_binary, 17 | array_recv_text, 18 | bh_unpack, 19 | cccc_unpack, 20 | ci_unpack, 21 | date_in, 22 | date_recv_binary, 23 | float_array_recv, 24 | geographyhex_recv, 25 | h_pack, 26 | h_unpack, 27 | i_pack, 28 | i_unpack, 29 | ihihih_unpack, 30 | ii_pack, 31 | iii_pack, 32 | int_array_recv, 33 | numeric_in, 34 | numeric_in_binary, 35 | numeric_to_float_binary, 36 | numeric_to_float_in, 37 | py_types, 38 | q_pack, 39 | redshift_types, 40 | text_recv, 41 | time_in, 42 | time_recv_binary, 43 | timetz_in, 44 | timetz_recv_binary, 45 | varbytehex_recv, 46 | ) 47 | -------------------------------------------------------------------------------- /redshift_connector/utils/array_util.py: -------------------------------------------------------------------------------- 1 | import typing 2 | 3 | from redshift_connector.error import ArrayDimensionsNotConsistentError 4 | 5 | 6 | def walk_array(arr: typing.List) -> 
def array_find_first_element(arr: typing.List) -> typing.Any:
    """Return the first non-None leaf element of the (possibly nested) list
    *arr*, or None when every leaf is None or the list is empty."""
    return next((item for item in array_flatten(arr) if item is not None), None)


def array_flatten(arr: typing.List) -> typing.Generator:
    """Yield every leaf (non-list) element of *arr*, depth-first."""
    for item in arr:
        if isinstance(item, list):
            yield from array_flatten(item)
        else:
            yield item


def array_check_dimensions(arr: typing.List) -> typing.List:
    """Validate that the nested list *arr* is rectangular.

    Returns the lengths of the inner dimensions (excluding the outermost);
    a flat or empty list yields [].

    :raises ArrayDimensionsNotConsistentError: if sibling sub-arrays differ
        in length or nesting structure.
    """
    if not arr:
        # Fix: previously fell through and implicitly returned None for an
        # empty list, violating the declared typing.List return type.
        return []
    head = arr[0]
    if not isinstance(head, list):
        # Leaf level: no sibling element may itself be a list.
        for item in arr:
            if isinstance(item, list):
                raise ArrayDimensionsNotConsistentError("array dimensions not consistent")
        return []
    req_len = len(head)
    req_inner_lengths = array_check_dimensions(head)
    for sub in arr:
        inner_lengths = array_check_dimensions(sub)
        if len(sub) != req_len or inner_lengths != req_inner_lengths:
            raise ArrayDimensionsNotConsistentError("array dimensions not consistent")
    return [req_len] + req_inner_lengths


def array_has_null(arr: typing.List) -> bool:
    """Return True if any leaf element of *arr* is None."""
    return any(item is None for item in array_flatten(arr))


def array_dim_lengths(arr: typing.List) -> typing.List:
    """Return the length of each dimension of *arr*, outermost first,
    measured along the first element of each level."""
    lengths = [len(arr)]
    if arr and isinstance(arr[0], list):
        lengths.extend(array_dim_lengths(arr[0]))
    return lengths
4 | """ 5 | 6 | @staticmethod 7 | def version() -> str: 8 | """ 9 | The version of redshift_connector 10 | Returns 11 | ------- 12 | The redshift_connector package version: str 13 | """ 14 | from redshift_connector import __version__ as DRIVER_VERSION 15 | 16 | return str(DRIVER_VERSION) 17 | 18 | @staticmethod 19 | def driver_name() -> str: 20 | """ 21 | The name of the Amazon Redshift Python driver, redshift_connector 22 | Returns 23 | ------- 24 | The human readable name of the redshift_connector package: str 25 | """ 26 | return "Redshift Python Driver" 27 | 28 | @staticmethod 29 | def driver_short_name() -> str: 30 | """ 31 | The shortened name of the Amazon Redshift Python driver, redshift_connector 32 | Returns 33 | ------- 34 | The shortened human readable name of the Amazon Redshift Python driver: str 35 | """ 36 | return "RsPython" 37 | 38 | @staticmethod 39 | def driver_full_name() -> str: 40 | """ 41 | The fully qualified name of the Amazon Redshift Python driver, redshift_connector 42 | Returns 43 | ------- 44 | The fully qualified name of the Amazon Redshift Python driver: str 45 | """ 46 | return "{driver_name} {driver_version}".format( 47 | driver_name=DriverInfo.driver_name(), driver_version=DriverInfo.version() 48 | ) 49 | -------------------------------------------------------------------------------- /redshift_connector/utils/extensible_digest.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from hashlib import new as hashlib_new 3 | 4 | from redshift_connector import InterfaceError 5 | 6 | if typing.TYPE_CHECKING: 7 | from hashlib import _Hash 8 | 9 | 10 | class ExtensibleDigest: 11 | """ 12 | Encodes user/password/salt information in the following way: SHA2(SHA2(password + user) + salt). 
13 | """ 14 | 15 | @staticmethod 16 | def encode(client_nonce: bytes, password: bytes, salt: bytes, algo_name: str, server_nonce: bytes) -> bytes: 17 | """ 18 | Encodes user/password/salt information in the following way: SHA2(SHA2(password + user) + salt). 19 | :param client_nonce: The client nonce 20 | :type client_nonce: bytes 21 | :param password: The connecting user's password 22 | :type password: bytes 23 | :param salt: salt sent by the server 24 | :type salt: bytes 25 | :param algo_name: Algorithm name such as "SHA256" etc. 26 | :type algo_name: str 27 | :param server_nonce: random number generated by server 28 | :type server_nonce: bytes 29 | :return: the digest 30 | :rtype: bytes 31 | """ 32 | try: 33 | hl1: "_Hash" = hashlib_new(name=algo_name) 34 | except ImportError: 35 | raise InterfaceError("Unable to encode password with extensible hashing: {}".format(algo_name)) 36 | hl1.update(password) 37 | hl1.update(salt) 38 | pass_digest1: bytes = hl1.digest() # SHA2(user + password) 39 | 40 | hl2: "_Hash" = hashlib_new(name=algo_name) 41 | hl2.update(pass_digest1) 42 | hl2.update(server_nonce) 43 | hl2.update(client_nonce) 44 | pass_digest2 = hl2.digest() # SHA2(SHA2(password + user) + salt) 45 | 46 | return pass_digest2 47 | -------------------------------------------------------------------------------- /redshift_connector/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import socket 3 | import typing 4 | 5 | if typing.TYPE_CHECKING: 6 | from redshift_connector import RedshiftProperty 7 | 8 | 9 | def make_divider_block() -> str: 10 | return "=" * 35 11 | 12 | 13 | def mask_secure_info_in_props(info: "RedshiftProperty") -> "RedshiftProperty": 14 | from redshift_connector import RedshiftProperty 15 | 16 | logging_allow_list: typing.Tuple[str, ...] 
= ( 17 | # "access_key_id", 18 | "allow_db_user_override", 19 | "app_id", 20 | "app_name", 21 | "application_name", 22 | "auth_profile", 23 | "auto_create", 24 | # "client_id", 25 | "client_protocol_version", 26 | # "client_secret", 27 | "cluster_identifier", 28 | "credentials_provider", 29 | "database_metadata_current_db_only", 30 | "db_groups", 31 | "db_name", 32 | "db_user", 33 | "duration", 34 | "endpoint_url", 35 | "force_lowercase", 36 | "group_federation", 37 | "host", 38 | "iam", 39 | "iam_disable_cache", 40 | "idc_client_display_name", 41 | "idc_region", 42 | "identity_namespace", 43 | "idp_host", 44 | "idpPort", 45 | "idp_response_timeout", 46 | "idp_tenant", 47 | "issuer_url", 48 | "is_serverless", 49 | "listen_port", 50 | "login_url", 51 | "max_prepared_statements", 52 | "numeric_to_float", 53 | "partner_sp_id", 54 | # "password", 55 | "port", 56 | "preferred_role", 57 | "principal", 58 | "profile", 59 | "provider_name", 60 | "region", 61 | "replication", 62 | "role_arn", 63 | "role_session_name", 64 | "scope", 65 | # "secret_access_key", 66 | "serverless_acct_id", 67 | "serverless_work_group", 68 | # "session_token", 69 | "source_address", 70 | "ssl", 71 | "ssl_insecure", 72 | "sslmode", 73 | "tcp_keepalive", 74 | "token_type", 75 | "timeout", 76 | "unix_sock", 77 | "user_name", 78 | # "web_identity_token", 79 | ) 80 | 81 | if info is None: 82 | return info 83 | 84 | temp: RedshiftProperty = RedshiftProperty() 85 | 86 | def is_populated(field: typing.Optional[str]): 87 | return field is not None and field != "" 88 | 89 | for parameter, value in info.__dict__.items(): 90 | if parameter in logging_allow_list: 91 | temp.put(parameter, value) 92 | elif is_populated(value): 93 | try: 94 | temp.put(parameter, "***") 95 | except AttributeError: 96 | pass 97 | 98 | return temp 99 | -------------------------------------------------------------------------------- /redshift_connector/utils/oids.py: 
from enum import EnumMeta, IntEnum


class RedshiftOIDMeta(EnumMeta):
    """Enum metaclass that freezes the OID constants after class creation."""

    def __setattr__(self, name, value):
        # Reject any rebinding of an attribute that already exists on the
        # class, e.g. "Cannot modify OID constant 'VARCHAR'".
        if name in self.__dict__:
            raise AttributeError(f"Cannot modify OID constant '{name}'")
        super().__setattr__(name, value)


class RedshiftOID(IntEnum, metaclass=RedshiftOIDMeta):
    """Type OIDs as reported by the Amazon Redshift / PostgreSQL wire protocol.

    Member order matters: for duplicate values (e.g. STRING/VARCHAR = 1043)
    the first definition is the canonical member and later ones are aliases.
    """

    ACLITEM = 1033
    ACLITEM_ARRAY = 1034
    ANY_ARRAY = 2277
    ABSTIME = 702
    BIGINT = 20
    BIGINT_ARRAY = 1016
    BOOLEAN = 16
    BOOLEAN_ARRAY = 1000
    BPCHAR = 1042
    BPCHAR_ARRAY = 1014
    BYTES = 17
    BYTES_ARRAY = 1001
    CHAR = 18
    CHAR_ARRAY = 1002
    CIDR = 650
    CIDR_ARRAY = 651
    CSTRING = 2275
    CSTRING_ARRAY = 1263
    DATE = 1082
    DATE_ARRAY = 1182
    FLOAT = 701
    FLOAT_ARRAY = 1022
    GEOGRAPHY = 3001
    GEOMETRY = 3000
    GEOMETRYHEX = 3999
    INET = 869
    INET_ARRAY = 1041
    INT2VECTOR = 22
    INTEGER = 23
    INTEGER_ARRAY = 1007
    INTERVAL = 1186
    INTERVAL_ARRAY = 1187
    INTERVALY2M = 1188
    INTERVALY2M_ARRAY = 1189
    INTERVALD2S = 1190
    INTERVALD2S_ARRAY = 1191
    JSON = 114
    JSON_ARRAY = 199
    JSONB = 3802
    JSONB_ARRAY = 3807
    MACADDR = 829
    MONEY = 790
    MONEY_ARRAY = 791
    NAME = 19
    NAME_ARRAY = 1003
    NUMERIC = 1700
    NUMERIC_ARRAY = 1231
    NULLTYPE = -1
    OID = 26
    OID_ARRAY = 1028
    POINT = 600
    REAL = 700
    REAL_ARRAY = 1021
    REGPROC = 24
    SMALLINT = 21
    SMALLINT_ARRAY = 1005
    SMALLINT_VECTOR = 22  # alias of INT2VECTOR (same value)
    STRING = 1043
    SUPER = 4000
    TEXT = 25
    TEXT_ARRAY = 1009
    TIME = 1083
    TIME_ARRAY = 1183
    TIMESTAMP = 1114
    TIMESTAMP_ARRAY = 1115
    TIMESTAMPTZ = 1184
    TIMESTAMPTZ_ARRAY = 1185
    TIMETZ = 1266
    UNKNOWN = 705
    UUID_TYPE = 2950
    UUID_ARRAY = 2951
    VARCHAR = 1043  # alias of STRING (same value)
    VARBYTE = 6551
    VARCHAR_ARRAY = 1015
    XID = 28

    # Convenience aliases for callers that use the alternative spellings.
    BIGINTEGER = BIGINT
    DATETIME = TIMESTAMP
    NUMBER = DECIMAL = NUMERIC
    DECIMAL_ARRAY = NUMERIC_ARRAY
    ROWID = OID


def get_datatype_name(oid: int) -> str:
    """Return the symbolic datatype name for *oid*.

    Raises ValueError when *oid* is not a known RedshiftOID value.
    """
    return RedshiftOID(oid).name


class SQLTypeMeta(EnumMeta):
    """Enum metaclass that freezes the SQL type constants after class creation."""

    def __setattr__(self, name, value):
        # Reject any rebinding of an attribute that already exists on the
        # class, e.g. "Cannot modify SQL type constant 'SQL_VARCHAR'".
        if name in self.__dict__:
            raise AttributeError(f"Cannot modify SQL type constant '{name}'")
        super().__setattr__(name, value)


class SQLType(IntEnum, metaclass=SQLTypeMeta):
    """JDBC-style SQL type codes (java.sql.Types values)."""

    SQL_VARCHAR = 12
    SQL_BIT = -7
    SQL_TINYINT = -6
    SQL_SMALLINT = 5
    SQL_INTEGER = 4
    SQL_BIGINT = -5
    SQL_FLOAT = 6
    SQL_REAL = 7
    SQL_DOUBLE = 8
    SQL_NUMERIC = 2
    SQL_DECIMAL = 3
    SQL_CHAR = 1
    SQL_LONGVARCHAR = -1
    SQL_DATE = 91
    SQL_TIME = 92
    SQL_TIMESTAMP = 93
    SQL_BINARY = -2
    SQL_VARBINARY = -3
    SQL_LONGVARBINARY = -4
    SQL_NULL = 0
    SQL_OTHER = 1111
    SQL_BOOLEAN = 16
    SQL_LONGNVARCHAR = -16
    SQL_TIME_WITH_TIMEZONE = 2013
    SQL_TIMESTAMP_WITH_TIMEZONE = 2014


def get_sql_type_name(sql_type: int) -> str:
    """Return the symbolic SQL type name for *sql_type*.

    Raises ValueError when *sql_type* is not a known SQLType value.
    """
    return SQLType(sql_type).name
the same reason 4 | # 3) we can import it into your module module 5 | __version__ = "2.1.7" 6 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | pytest>=5.4.0,<7.4.0 2 | pytest-xdist[psutil] 3 | mypy>=0.782 4 | pre-commit>=2.6.0 5 | pytest-cov>=2.10.0 6 | pytest-mock>=1.11.1,<=3.2.0 7 | wheel>=0.33 8 | docutils>=0.14 9 | selenium>=4.25 10 | -e . 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scramp>=1.2.0,<1.5.0 2 | pytz>=2020.1 3 | beautifulsoup4>=4.7.0,<5.0.0 4 | boto3>=1.9.201,<2.0.0 5 | requests>=2.23.0,<3.0.0 6 | lxml>=4.6.5 7 | botocore>=1.12.201,<2.0.0 8 | packaging 9 | setuptools 10 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | 2 | [coverage:run] 3 | branch = true 4 | parallel = true 5 | 6 | [coverage:paths] 7 | source = 8 | ./ 9 | build/lib/*/site-packages/ 10 | 11 | [coverage:html] 12 | directory = build/coverage 13 | 14 | [coverage:xml] 15 | output = build/coverage/coverage.xml 16 | 17 | [tool:pytest] 18 | addopts = 19 | --verbose 20 | --ignore=build/private 21 | --doctest-modules 22 | --cov redshift_connector 23 | --cov-report term-missing 24 | --cov-report html:build/coverage 25 | --cov-report xml:build/coverage/coverage.xml 26 | test 27 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import typing 5 | 6 | from setuptools import find_packages, setup 7 | from setuptools.command.install import install as InstallCommandBase 8 | from setuptools.command.test import 
class BasePytestCommand(TestCommand):
    """Base ``setup.py`` command that runs a pytest suite with coverage.

    Subclasses set :attr:`test_dir` to the directory holding the suite.
    """

    user_options: typing.List = []
    # Directory passed to pytest; overridden by each subclass.
    test_dir: typing.Optional[str] = None

    def initialize_options(self):
        TestCommand.initialize_options(self)

    def finalize_options(self):
        TestCommand.finalize_options(self)
        self.test_args = []
        self.test_suite = True

    def run_tests(self):
        """Run pytest on ``test_dir`` and exit with pytest's return code."""
        import pytest

        # Fix: the previous implementation read the SRC_DIR environment
        # variable and appended a "/" but never used the result (dead code);
        # that block has been removed.
        args = [
            self.test_dir,
            "--cov=redshift_connector",
            "--cov-report=xml",
            "--cov-report=html",
        ]

        errno = pytest.main(args)
        sys.exit(errno)


class UnitTestCommand(BasePytestCommand):
    """``setup.py unit_test`` — runs the unit test suite."""

    test_dir: str = "test/unit"


class IntegrationTestCommand(BasePytestCommand):
    """``setup.py integration_test`` — runs the integration test suite."""

    test_dir = "test/integration"


class BinaryDistribution(Distribution):
    """Marks the distribution as containing extension modules so that a
    platform-specific install layout is used."""

    def has_ext_modules(self):
        return True


class InstallCommand(InstallCommandBase):
    """Override the installation dir: install pure-Python files into the
    platform-specific library directory."""

    def finalize_options(self):
        ret = InstallCommandBase.finalize_options(self)
        self.install_lib = self.install_platlib
        return ret


class BDistWheelCommand(BDistWheelCommandBase):
    """Builds a universal ``py3-none-any`` wheel regardless of platform."""

    def finalize_options(self):
        super().finalize_options()
        self.root_is_pure = False
        self.universal = True

    def get_tag(self):
        # Force a platform-independent tag; the wheel ships no compiled code.
        python, abi, plat = "py3", "none", "any"
        return python, abi, plat


custom_cmds = {
    "bdist_wheel": BDistWheelCommand,
    "unit_test": UnitTestCommand,
    "integration_test": IntegrationTestCommand,
}

# Register the custom install command only when explicitly requested via the
# CUSTOMINSTALL environment variable.  NOTE(review): the ``elif`` branch is
# unreachable ("install" is never present in custom_cmds at this point); it is
# retained for safety should the dict literal above ever change.
if os.getenv("CUSTOMINSTALL", False):
    custom_cmds["install"] = InstallCommand
elif "install" in custom_cmds:
    del custom_cmds["install"]
# Read the contents of the README for the long PyPI description.
this_directory = os.path.abspath(os.path.dirname(__file__))
with open(os.path.join(this_directory, "README.rst"), encoding="utf-8") as f:
    long_description = f.read()

# Execute version.py to obtain __version__ without importing the package
# (importing would pull runtime dependencies in at build time).
# Fix: the file is now opened via ``with`` so the handle is closed.
with open("redshift_connector/version.py") as f:
    exec(f.read())

optional_deps = {
    "full": ["numpy", "pandas"],
}

# Fix: close requirements.txt after reading (was a leaked file handle).
with open("requirements.txt") as f:
    install_requires = f.read().strip().split("\n")

setup(
    name="redshift_connector",
    version=__version__,  # type: ignore  # defined by the exec() above
    description="Redshift interface library",
    long_description=long_description,
    long_description_content_type="text/x-rst",
    author="Amazon Web Services",
    author_email="redshift-drivers@amazon.com",
    url="https://github.com/aws/amazon-redshift-python-driver",
    license="Apache License 2.0",
    python_requires=">=3.6",
    install_requires=install_requires,
    extras_require=optional_deps,
    classifiers=[
        "Development Status :: 5 - Production/Stable",
        "Intended Audience :: Developers",
        # Fix: the project is Apache-2.0 licensed (see ``license`` above and
        # the repository LICENSE/NOTICE files); the previous "BSD License"
        # classifier was inconsistent with the declared license.
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Programming Language :: Python :: 3.11",
        "Programming Language :: Python :: Implementation",
        "Programming Language :: Python :: Implementation :: CPython",
        "Programming Language :: Python :: Implementation :: Jython",
        "Programming Language :: Python :: Implementation :: PyPy",
        "Operating System :: OS Independent",
        "Topic :: Database :: Front-Ends",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    keywords="redshift dbapi",
    include_package_data=True,
    # Fix: the key must match the package name "redshift_connector"; the
    # previous hyphenated "redshift-connector" key never matched any package,
    # so the data files were silently omitted from this mapping.
    package_data={"redshift_connector": ["*.py", "*.crt", "LICENSE", "NOTICE", "py.typed"]},
    packages=find_packages(exclude=["test*"]),
    cmdclass=custom_cmds,
)
system_tables: typing.List[str] = [
    "pg_aggregate",
    "pg_am",
    "pg_amop",
    "pg_amproc",
    "pg_attrdef",
    "pg_attribute",
    "pg_cast",
    "pg_class",
    "pg_constraint",
    "pg_conversion",
    "pg_database",
    "pg_default_acl",
    "pg_depend",
    "pg_description",
    # "pg_index",  # has unsupported type oids 22, 30
    "pg_inherits",
    "pg_language",
    "pg_largeobject",
    "pg_namespace",
    "pg_opclass",
    "pg_operator",
    "pg_proc",
    "pg_rewrite",
    "pg_shdepend",
    "pg_statistic",
    "pg_tablespace",
    # "pg_trigger",  # has unsupported type oid 22
    "pg_type",
    "pg_group",
    "pg_indexes",
    "pg_locks",
    "pg_rules",
    "pg_settings",
    "pg_stats",
    "pg_tables",
    "pg_user",
    "pg_views",
    # Currently excluded from the smoke test:
    # pg_authid, pg_auth_members, pg_collation, pg_db_role_setting, pg_enum,
    # pg_extension, pg_foreign_data_wrapper, pg_foreign_server,
    # pg_foreign_table, pg_largeobject_metadata, pg_opfamily, pg_pltemplate,
    # pg_seclabel, pg_shdescription, pg_ts_config, pg_ts_config_map,
    # pg_ts_dict, pg_ts_parser, pg_ts_template, pg_user_mapping,
    # pg_available_extensions, pg_available_extension_versions, pg_cursors,
    # pg_prepared_statements, pg_prepared_xacts, pg_roles, pg_seclabels,
    # pg_shadow, pg_timezone_abbrevs, pg_timezone_names, pg_user_mappings
]


@pytest.mark.parametrize("client_protocol", ClientProtocolVersion.list())
@pytest.mark.parametrize("table_name", system_tables)
def test_process_system_table_datatypes(db_kwargs, client_protocol, table_name):
    """Smoke test: each system table can be selected without a datatype
    conversion issue.  The result set contents are not validated."""
    db_kwargs["client_protocol_version"] = client_protocol

    with redshift_connector.connect(**db_kwargs) as conn:
        # The server may downgrade the protocol; warn (rather than fail) so
        # the datatype conversion check still runs under whatever was granted.
        granted = conn._client_protocol_version
        if granted != client_protocol:
            warn(f"Requested client_protocol_version was not satisfied. Requested {client_protocol} Got {granted}")
        with conn.cursor() as cursor:
            cursor.execute(f"select * from {table_name}")


def test_SHOW_DATABASES_col(db_kwargs) -> None:
    """SHOW DATABASES must expose the column name MetadataAPIHelper expects."""
    with redshift_connector.connect(**db_kwargs) as conn:
        with conn.cursor() as cursor:
            if cursor.supportSHOWDiscovery() >= 2:
                cursor.execute("SHOW DATABASES;")

                col_set: typing.Set = {col[0] for col in cursor.description}

                mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()

                assert mock_metadataAPIHelper._SHOW_DATABASES_database_name in col_set


def test_SHOW_SCHEMAS_col(db_kwargs) -> None:
    """SHOW SCHEMAS must expose every column name MetadataAPIHelper expects."""
    with redshift_connector.connect(**db_kwargs) as conn:
        with conn.cursor() as cursor:
            if cursor.supportSHOWDiscovery() >= 2:
                cursor.execute("SHOW SCHEMAS FROM DATABASE test_catalog;")

                col_set: typing.Set = {col[0] for col in cursor.description}

                mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()

                for expected in (
                    mock_metadataAPIHelper._SHOW_SCHEMA_database_name,
                    mock_metadataAPIHelper._SHOW_SCHEMA_schema_name,
                ):
                    assert expected in col_set
def test_SHOW_COLUMNS_col(db_kwargs) -> None:
    """SHOW COLUMNS must expose every column name MetadataAPIHelper expects."""
    with redshift_connector.connect(**db_kwargs) as conn:
        with conn.cursor() as cursor:
            if cursor.supportSHOWDiscovery() >= 2:
                cursor.execute("SHOW COLUMNS FROM TABLE test_catalog.test_schema.test_table;")

                # Collect the result-set column names reported by the server.
                col_set: typing.Set = {col[0] for col in cursor.description}

                mock_metadataAPIHelper: MetadataAPIHelper = MetadataAPIHelper()

                # Every column the helper relies on must be present.
                for expected in (
                    mock_metadataAPIHelper._SHOW_COLUMNS_database_name,
                    mock_metadataAPIHelper._SHOW_COLUMNS_schema_name,
                    mock_metadataAPIHelper._SHOW_COLUMNS_table_name,
                    mock_metadataAPIHelper._SHOW_COLUMNS_column_name,
                    mock_metadataAPIHelper._SHOW_COLUMNS_ordinal_position,
                    mock_metadataAPIHelper._SHOW_COLUMNS_column_default,
                    mock_metadataAPIHelper._SHOW_COLUMNS_is_nullable,
                    mock_metadataAPIHelper._SHOW_COLUMNS_data_type,
                    mock_metadataAPIHelper._SHOW_COLUMNS_character_maximum_length,
                    mock_metadataAPIHelper._SHOW_COLUMNS_numeric_precision,
                    mock_metadataAPIHelper._SHOW_COLUMNS_numeric_scale,
                    mock_metadataAPIHelper._SHOW_COLUMNS_remarks,
                ):
                    assert expected in col_set
# Statements that build the special-character schema/table/column fixtures.
startup_stmts: typing.Tuple[str, ...] = (
    "DROP SCHEMA IF EXISTS {} CASCADE;".format(object_name),
    "CREATE SCHEMA {};".format(object_name),
    "create table {}.{} ({} INT);".format(object_name, object_name, object_name),
)

test_cases = [
    # Standard identifier, lower case: matches the created objects exactly.
    ([current_catalog, object_name], 1, [object_name]),
    # Standard identifier, mixed case: standard identifiers are folded to
    # lower case, so this pattern matches nothing.
    ([current_catalog, "təst_N𝒜m好e$123Standard"], 0, []),
    # Standard identifier containing an illegal special character ('!').
    ([current_catalog, "təst_n𝒜m好e$123!standard"], 0, []),
]


@pytest.fixture(scope="class", autouse=True)
def test_metadataAPI_config(request, db_kwargs):
    """Creates — and on teardown drops — a dedicated database, schema and
    table whose names contain multi-byte and special characters."""
    global cur_db_kwargs
    cur_db_kwargs = dict(db_kwargs)
    print(cur_db_kwargs)

    with redshift_connector.connect(**cur_db_kwargs) as con:
        con.paramstyle = "format"
        con.autocommit = True
        with con.cursor() as cursor:
            # Drop a leftover database from a previous run, if any.
            try:
                cursor.execute("drop database {};".format(current_catalog))
            except redshift_connector.ProgrammingError:
                pass
            cursor.execute("create database {};".format(current_catalog))
    cur_db_kwargs["database"] = current_catalog
    with redshift_connector.connect(**cur_db_kwargs) as con:
        con.paramstyle = "format"
        with con.cursor() as cursor:
            for stmt in startup_stmts:
                cursor.execute(stmt)

        con.commit()

    def fin():
        # Best-effort teardown; ignore errors if the database is already gone.
        try:
            with redshift_connector.connect(**db_kwargs) as con:
                con.autocommit = True
                with con.cursor() as cursor:
                    cursor.execute("drop database {};".format(current_catalog))
                    cursor.execute("select 1;")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)


def _connection_kwargs(db_kwargs: dict) -> dict:
    """Connection kwargs for the catalog under test: connect directly to the
    special-character database when cross-database testing is disabled,
    otherwise connect normally with cross-database metadata enabled."""
    if disable_cross_database_testing:
        return dict(cur_db_kwargs)
    kwargs = dict(db_kwargs)
    kwargs["database_metadata_current_db_only"] = False
    return kwargs


@pytest.mark.parametrize("test_case, expected_row_count, expected_result", test_cases)
def test_get_schemas_special_character(db_kwargs, test_case, expected_row_count, expected_result) -> None:
    """get_schemas must match special-character names exactly (case folded)."""
    with redshift_connector.connect(**_connection_kwargs(db_kwargs)) as conn:
        with conn.cursor() as cursor:
            result: tuple = cursor.get_schemas(test_case[0], test_case[1])
            assert len(result) == expected_row_count
            for actual_row, expected_schema in zip(result, expected_result):
                assert actual_row[0] == expected_schema
                assert actual_row[1] == current_catalog


@pytest.mark.parametrize("test_case, expected_row_count, expected_result", test_cases)
def test_get_tables_special_character(db_kwargs, test_case, expected_row_count, expected_result) -> None:
    """get_tables must match special-character schema/table names exactly."""
    with redshift_connector.connect(**_connection_kwargs(db_kwargs)) as conn:
        with conn.cursor() as cursor:
            result: tuple = cursor.get_tables(test_case[0], test_case[1], test_case[1], None)
            assert len(result) == expected_row_count
            for actual_row, expected_name in zip(result, expected_result):
                assert actual_row[0] == current_catalog
                assert actual_row[1] == expected_name
                assert actual_row[2] == expected_name
result: tuple = cursor.get_columns(test_case[0], test_case[1], test_case[1], test_case[1]) 122 | assert len(result) == expected_row_count 123 | for actual_row, expected_name in zip(result, expected_result): 124 | assert actual_row[0] == current_catalog 125 | assert actual_row[1] == expected_name 126 | assert actual_row[2] == expected_name 127 | assert actual_row[3] == expected_name 128 | 129 | -------------------------------------------------------------------------------- /test/integration/metadata/test_metadataAPI_sql_injection.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import typing 4 | import pytest # type: ignore 5 | 6 | import redshift_connector 7 | from redshift_connector.metadataAPIHelper import MetadataAPIHelper 8 | 9 | conf: configparser.ConfigParser = configparser.ConfigParser() 10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | conf.read(root_path + "/config.ini") 12 | 13 | # Manually set the following flag to False if we want to test cross-database metadata api call 14 | disable_cross_database_testing: bool = True 15 | 16 | catalog_name: str = "test_sql_injection -- comment" 17 | 18 | schema_name: str = "test_schema; create table pwn1(i int);--" 19 | table_name: str = "test_table; create table pwn2(i int);--" 20 | col_name: str = "col" 21 | 22 | startup_stmts: typing.Tuple[str, ...] 
# Statements that create a schema and table whose names embed injection text.
startup_stmts: typing.Tuple[str, ...] = (
    'DROP SCHEMA IF EXISTS "{}" CASCADE;'.format(schema_name),
    'CREATE SCHEMA "{}";'.format(schema_name),
    'create table "{}"."{}" (col INT);'.format(schema_name, table_name),
)

# Classic SQL-injection payloads used as metadata-API filter arguments.
test_cases = [
    "' OR '1'='1",
    "' OR 1=1 --",
    "' OR '1'='1' --",
    "') OR 1=1 --",
    "') OR '1'='1' --",
    "' OR 1=1;--",
    "' OR 1=1 LIMIT 1;--",
    "' UNION SELECT null --",
    "' UNION SELECT null, null --",
    "' UNION SELECT 1, 'username', 'password' FROM users --",
    "' UNION SELECT * FROM users --",
    "') AND 1=CAST((SELECT current_database()) AS INT) --",
    "' AND 1=1 --",
    "' AND 1=2 --",
    "'; DROP TABLE test_table_100; --",
    '"; DROP TABLE test_table_100; --"',
    "; DROP TABLE test_table_100; --",
]


@pytest.fixture(scope="class", autouse=True)
def test_metadataAPI_config(request, db_kwargs):
    """Creates — and on teardown drops — a database whose schema and table
    names embed SQL-injection payloads, so metadata APIs can be probed."""
    global cur_db_kwargs
    cur_db_kwargs = dict(db_kwargs)
    print(cur_db_kwargs)

    with redshift_connector.connect(**cur_db_kwargs) as con:
        con.paramstyle = "format"
        con.autocommit = True
        with con.cursor() as cursor:
            # Drop a leftover database from a previous run, if any.
            try:
                cursor.execute('drop database "{}";'.format(catalog_name))
            except redshift_connector.ProgrammingError:
                pass
            cursor.execute('create database "{}";'.format(catalog_name))
    cur_db_kwargs["database"] = catalog_name
    with redshift_connector.connect(**cur_db_kwargs) as con:
        con.paramstyle = "format"
        with con.cursor() as cursor:
            for stmt in startup_stmts:
                cursor.execute(stmt)

        con.commit()

    def fin():
        # Best-effort teardown; ignore errors if the database is already gone.
        try:
            with redshift_connector.connect(**db_kwargs) as con:
                con.autocommit = True
                with con.cursor() as cursor:
                    cursor.execute('drop database "{}";'.format(catalog_name))
                    cursor.execute("select 1;")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)


def _injection_db_kwargs(db_kwargs: dict) -> dict:
    """Connection kwargs for the injected-name catalog: direct connection when
    cross-database testing is disabled, otherwise the default database with
    cross-database metadata lookups enabled."""
    if disable_cross_database_testing:
        return dict(cur_db_kwargs)
    kwargs = dict(db_kwargs)
    kwargs["database_metadata_current_db_only"] = False
    return kwargs


@pytest.mark.parametrize("test_input", test_cases)
def test_input_parameter_sql_injection(db_kwargs, test_input) -> None:
    """Injection payloads passed as filter arguments must not raise."""
    with redshift_connector.connect(**db_kwargs) as conn:
        with conn.cursor() as cursor:
            try:
                result: tuple = cursor.get_schemas(test_input, None)
            except Exception as e:
                pytest.fail(f"Unexpected exception raised: {e}")


def test_get_schemas_sql_injection(db_kwargs) -> None:
    """get_schemas must treat the hostile schema name as data, not SQL."""
    with redshift_connector.connect(**_injection_db_kwargs(db_kwargs)) as conn:
        with conn.cursor() as cursor:
            result: tuple = cursor.get_schemas(catalog_name, None)

    assert len(result) > 0
    assert any(row[0] == schema_name and row[1] == catalog_name for row in result)


def test_get_tables_sql_injection(db_kwargs) -> None:
    """get_tables must treat the hostile table name as data, not SQL."""
    with redshift_connector.connect(**_injection_db_kwargs(db_kwargs)) as conn:
        with conn.cursor() as cursor:
            result: tuple = cursor.get_tables(catalog_name, None, None)

    assert len(result) > 0
    assert any(
        row[0] == catalog_name and row[1] == schema_name and row[2] == table_name
        for row in result
    )
@pytest.fixture(scope="session", autouse=True)
def startup_db_stmts() -> None:
    """
    Executes a defined set of statements to configure a fresh Amazon Redshift cluster for IdP integration tests.

    Reads the comma-separated group names from the ``cluster-setup`` section of
    config.ini, then drops (best-effort) and recreates each group.
    """
    groups: typing.List[str] = conf.get("cluster-setup", "groups").split(sep=",")

    with redshift_connector.connect(**_get_default_connection_args()) as conn:  # type: ignore
        conn.autocommit = True
        with conn.cursor() as cursor:
            for grp in groups:
                try:
                    cursor.execute("DROP GROUP {}".format(grp))
                except Exception:
                    # DROP GROUP has no IF EXISTS clause, so a failed drop
                    # (group absent) is expected and ignored.  Fix: narrowed
                    # the former bare ``except:`` so KeyboardInterrupt and
                    # SystemExit are no longer swallowed.
                    pass
                cursor.execute("CREATE GROUP {}".format(grp))
@pytest.mark.parametrize("idp_arg", PROVIDER, indirect=True)
def test_idp_host_invalid_should_fail(idp_arg) -> None:
    """Connecting through a wrong Okta IdP host must surface an InterfaceError."""
    idp_arg["idp_host"] = "andrew.okta.com"  # deliberately incorrect host

    with pytest.raises(redshift_connector.InterfaceError, match="Failed to get SAML assertion"):
        redshift_connector.connect(**idp_arg)


@pytest.mark.parametrize("idp_arg", PROVIDER, indirect=True)
def test_preferred_role_should_use(idp_arg) -> None:
    """A connection that specifies preferred_role should authenticate end to end."""
    idp_arg["preferred_role"] = conf.get("okta-idp", "preferred_role")
    with redshift_connector.connect(**idp_arg):
        pass


from test.utils import numpy_only, pandas_only
from warnings import filterwarnings

import pytest  # type: ignore

import redshift_connector
from redshift_connector.config import DbApiParamstyle

# Tests relating to the pandas and numpy operation of the database driver
# redshift_connector custom interface.
@pytest.fixture
def db_table(request, con: redshift_connector.Connection) -> redshift_connector.Connection:
    """Yield a connection owning a temporary ``book`` table; the table is dropped on teardown."""
    filterwarnings("ignore", "DB-API extension cursor.next()")
    filterwarnings("ignore", "DB-API extension cursor.__iter__()")
    con.paramstyle = "format"  # type: ignore
    with con.cursor() as cursor:
        cursor.execute("drop table if exists book")
        # NOTE: the second column name deliberately embeds U+200E (left-to-right mark);
        # test_fetch_dataframe asserts the driver round-trips it unchanged.
        cursor.execute("create Temp table book(bookname varchar,author‎ varchar)")

    def fin() -> None:
        try:
            with con.cursor() as cursor:
                cursor.execute("drop table if exists book")
        except redshift_connector.ProgrammingError:
            pass

    request.addfinalizer(fin)
    return con


# Single source of truth for the rows used below. Previously the same literals were
# duplicated in each test while the constructed DataFrame went unused.
_BOOK_ROWS = [
    ("One Hundred Years of Solitude", "Gabriel García Márquez"),
    ("A Brief History of Time", "Stephen Hawking"),
]


def _make_book_df():
    """Build the reference DataFrame shared by the pandas tests."""
    import numpy as np  # type: ignore
    import pandas as pd  # type: ignore

    return pd.DataFrame(np.array(_BOOK_ROWS), columns=["bookname", "author‎"])


@pandas_only
def test_fetch_dataframe(db_table) -> None:
    """fetch_dataframe returns the inserted rows with original (unicode) column names."""
    with db_table.cursor() as cursor:
        cursor.executemany(
            "insert into book (bookname, author‎) values (%s, %s)",
            _BOOK_ROWS,
        )
        cursor.execute("select * from book; ")
        result = cursor.fetch_dataframe()
        assert result.columns[0] == "bookname"
        assert result.columns[1] == "author\u200e"


@pandas_only
@pytest.mark.parametrize("paramstyle", DbApiParamstyle.list())
def test_write_dataframe(db_table, paramstyle) -> None:
    """write_dataframe inserts all DataFrame rows and leaves paramstyle untouched."""
    import numpy as np

    df = _make_book_df()
    db_table.paramstyle = paramstyle

    with db_table.cursor() as cursor:
        cursor.write_dataframe(df, "book")
        cursor.execute("select * from book; ")
        result = cursor.fetchall()
        assert len(np.array(result)) == 2

    assert db_table.paramstyle == paramstyle


@numpy_only
def test_fetch_numpyarray(db_table) -> None:
    """fetch_numpy_array returns one entry per inserted row."""
    with db_table.cursor() as cursor:
        cursor.executemany(
            "insert into book (bookname, author‎) values (%s, %s)",
            _BOOK_ROWS,
        )
        cursor.execute("select * from book; ")
        result = cursor.fetch_numpy_array()
        assert len(result) == 2
@pytest.mark.parametrize(
    "parameters",
    [
        (
            ({"c1": "abc", "c2": "defg", "c3": "hijkl"}, {"c1": "a", "c2": "b", "c3": "c"}),
            [["a", "b", "c"], ["abc", "defg", "hijkl"]],
        ),
    ],
)
def test_pyformat_multiple_insert(cursor, parameters) -> None:
    """executemany with pyformat placeholders persists every supplied row."""
    data, expected = parameters
    cursor.paramstyle = "pyformat"
    cursor.execute("create temporary table test_pyformat(c1 varchar, c2 varchar, c3 varchar)")
    cursor.executemany("insert into test_pyformat(c1, c2, c3) values(%(c1)s, %(c2)s, %(c3)s)", data)
    cursor.execute("select * from test_pyformat order by c1")
    rows: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
    assert len(rows) == len(expected)
    for row, want in zip(rows, expected):
        assert len(row) == len(want)
        assert row == want
@pytest.mark.parametrize(
    "parameters", [(["abc", "defg", "hijkl"], ["abc", "defg", "hijkl"]), (["a", "b", "c"], ["a", "b", "c"])]
)
def test_numeric(cursor, parameters) -> None:
    """Numeric paramstyle (:1, :2, ...) round-trips a single row intact."""
    data, expected = parameters
    cursor.paramstyle = "numeric"
    cursor.execute("create temporary table test_numeric(c1 varchar, c2 varchar, c3 varchar)")
    cursor.execute("insert into test_numeric(c1, c2, c3) values(:1, :2, :3)", data)
    cursor.execute("select * from test_numeric")
    rows: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
    assert len(rows) == 1
    first_row = rows[0]
    assert len(first_row) == len(data)
    assert first_row == expected
@pytest.mark.parametrize(
    "parameters", [(["abc", "defg", "hijkl"], ["abc", "defg", "hijkl"]), (["a", "b", "c"], ["a", "b", "c"])]
)
def test_format(cursor, parameters) -> None:
    """Format paramstyle (%s placeholders) round-trips a single row intact."""
    data, expected = parameters
    cursor.paramstyle = "format"
    cursor.execute("create temporary table test_format(c1 varchar, c2 varchar, c3 varchar)")
    cursor.execute("insert into test_format(c1, c2, c3) values(%s, %s, %s)", data)
    cursor.execute("select * from test_format")
    rows: typing.Tuple[typing.List[str], ...] = cursor.fetchall()
    assert len(rows) == 1
    first_row = rows[0]
    assert len(first_row) == len(data)
    assert first_row == expected
import logging

import pytest

from redshift_connector import InterfaceError, RedshiftProperty

LOGGER = logging.getLogger(__name__)
LOGGER.propagate = True


def test_set_region_from_endpoint_lookup(db_kwargs) -> None:
    """Region resolved via endpoint lookup must match the region parsed from the host name."""
    rp: RedshiftProperty = RedshiftProperty()
    rp.put(key="host", value=db_kwargs["host"])
    rp.put(key="port", value=db_kwargs["port"])
    rp.set_region_from_host()

    expected_region = rp.region
    rp.region = None  # clear so the endpoint lookup must re-derive it

    rp.set_region_from_endpoint_lookup()
    assert rp.region == expected_region


@pytest.mark.parametrize("host, port", [("x", 1000), ("amazon.com", -1), ("-o", 5439)])
def test_set_region_from_endpoint_lookup_raises(host, port, caplog) -> None:
    """An unresolvable host/port pair should log a diagnostic rather than raise."""
    # note: the redundant function-local ``import logging`` was removed; the module-level
    # import above is already in scope here.
    rp: RedshiftProperty = RedshiftProperty()
    rp.put(key="host", value=host)
    rp.put(key="port", value=port)
    expected_msg: str = (
        "Unable to automatically determine AWS region from host {} port {}. "
        "Please check host and port connection parameters are correct.".format(host, port)
    )

    with caplog.at_level(logging.DEBUG):
        rp.set_region_from_endpoint_lookup()
    assert expected_msg in caplog.text
"regdictionary", 64 | "any", 65 | "anyelement", 66 | "anyarray", 67 | "anynonarray", 68 | "anyenum", 69 | "anyrange", 70 | "cstring", 71 | "internal", 72 | "language_handler", 73 | "fdw_handler", 74 | "tsm_handler", 75 | "record", 76 | "trigger", 77 | "event_trigger", 78 | "pg_ddl_command", 79 | "void", 80 | "opaque", 81 | "int4range", 82 | "int8range", 83 | "numrange", 84 | "tsrange", 85 | "tstzrange", 86 | "daterange", 87 | "tsvector", 88 | "tsquery", 89 | ] 90 | 91 | 92 | @pytest.mark.unsupported_datatype 93 | class TestUnsupportedDataTypes: 94 | @pytest.mark.parametrize("datatype", unsupported_datatypes) 95 | def test_create_table_with_unsupported_datatype_fails(self, db_table: "Connection", datatype: str) -> None: 96 | with db_table.cursor() as cursor: 97 | with pytest.raises(Exception) as exception: 98 | cursor.execute("CREATE TEMPORARY TABLE t1 (a {})".format(datatype)) 99 | assert exception.type == redshift_connector.ProgrammingError 100 | assert ( 101 | 'Column "t1.a" has unsupported type' in exception.__str__() 102 | or 'type "{}" does not exist'.format(datatype) in exception.__str__() 103 | or 'syntax error at or near "{}"'.format(datatype) in exception.__str__() 104 | ) 105 | 106 | @pytest.mark.parametrize("datatype", ["int[]", "int[][]"]) 107 | def test_create_table_with_array_datatype_fails(self, db_table: "Connection", datatype: str) -> None: 108 | with db_table.cursor() as cursor: 109 | with pytest.raises(Exception) as exception: 110 | cursor.execute("CREATE TEMPORARY TABLE t1 (a {})".format(datatype)) 111 | assert exception.type == redshift_connector.ProgrammingError 112 | assert ( 113 | 'Column "t1.a" has unsupported type' in exception.__str__() 114 | or 'type "{}" does not exist'.format(datatype) in exception.__str__() 115 | ) 116 | -------------------------------------------------------------------------------- /test/manual/__init__.py: -------------------------------------------------------------------------------- 
@pytest.mark.skip(reason="manual")
def test_use_aws_credentials_default_profile() -> None:
    """Connect via IAM using the explicit AWS credentials declared at module scope."""
    connect_args = dict(
        iam=True,
        database="my_database",
        db_user="my_db_user",
        password="",
        user="",
        cluster_identifier="my_cluster_identifier",
        region="my_region",
        access_key_id=aws_access_key,
        secret_access_key=aws_secret_access_key,
        session_token=aws_session_token,
    )
    with redshift_connector.connect(**connect_args) as con:
        with con.cursor() as cursor:
            cursor.execute("select 1")
@pytest.mark.skip(reason="manual")
def test_use_get_cluster_credentials_with_iam(db_kwargs):
    """Exercise GetClusterCredentials with group federation for an IAM role user."""
    role_name = "groupFederationTest"
    # https://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_USER.html
    with redshift_connector.connect(**db_kwargs) as conn:
        with conn.cursor() as cursor:
            cursor.execute('create user "IAMR:{}" with password disable;'.format(role_name))
    federated_args = dict(
        iam=True,
        database="replace_me",
        cluster_identifier="replace_me",
        region="replace_me",
        profile="replace_me",  # contains credentials for AssumeRole groupFederationTest
        group_federation=True,
    )
    with redshift_connector.connect(**federated_args) as con:
        with con.cursor() as cursor:
            cursor.execute("select 1")
            cursor.execute("select current_user")
            assert cursor.fetchone()[0] == role_name
20 | """ 21 | 22 | 23 | creds: typing.Dict[str, str] = { 24 | "aws_access_key_id": "replace_me", 25 | "aws_session_token": "replace_me", 26 | "aws_secret_access_key": "replace_me", 27 | } 28 | 29 | auth_profile_name: str = "PythonManualTest" 30 | 31 | 32 | @pytest.fixture(autouse=True) 33 | def handle_redshift_auth_profile(request, db_kwargs: typing.Dict[str, typing.Union[str, bool, int]]) -> None: 34 | import json 35 | 36 | import boto3 # type: ignore 37 | from botocore.exceptions import ClientError # type: ignore 38 | 39 | payload: str = json.dumps( 40 | { 41 | "host": db_kwargs["host"], 42 | "db_user": db_kwargs["user"], 43 | "max_prepared_statements": 5, 44 | "region": db_kwargs["region"], 45 | "cluster_identifier": db_kwargs["cluster_identifier"], 46 | "db_name": db_kwargs["database"], 47 | } 48 | ) 49 | 50 | try: 51 | client = boto3.client( 52 | "redshift", 53 | **{**creds, **{"region_name": typing.cast(str, db_kwargs["region"])}}, 54 | verify=False, 55 | ) 56 | client.create_authentication_profile( 57 | AuthenticationProfileName=auth_profile_name, AuthenticationProfileContent=payload 58 | ) 59 | except ClientError: 60 | raise 61 | 62 | def fin() -> None: 63 | import boto3 64 | from botocore.exceptions import ClientError 65 | 66 | try: 67 | client = boto3.client( 68 | "redshift", 69 | **{**creds, **{"region_name": typing.cast(str, db_kwargs["region"])}}, 70 | verify=False, 71 | ) 72 | client.delete_authentication_profile( 73 | AuthenticationProfileName=auth_profile_name, 74 | ) 75 | except ClientError: 76 | raise 77 | 78 | request.addfinalizer(fin) 79 | 80 | 81 | @pytest.mark.skip(reason="manual") 82 | def test_redshift_auth_profile_can_connect(db_kwargs): 83 | with redshift_connector.connect( 84 | region=db_kwargs["region"], 85 | access_key_id=creds["aws_access_key_id"], 86 | secret_access_key=creds["aws_secret_access_key"], 87 | session_token=creds["aws_session_token"], 88 | auth_profile=auth_profile_name, 89 | iam=True, 90 | ) as conn: 91 | assert 
conn.user == "IAM:{}".format(db_kwargs["user"]).encode() 92 | assert conn.max_prepared_statements == 5 93 | -------------------------------------------------------------------------------- /test/manual/plugin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/manual/plugin/__init__.py -------------------------------------------------------------------------------- /test/manual/plugin/test_browser_credentials_provider.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import typing 4 | 5 | import pytest # type: ignore 6 | 7 | import redshift_connector 8 | 9 | conf: configparser.ConfigParser = configparser.ConfigParser() 10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) 11 | conf.read(root_path + "/config.ini") 12 | 13 | BROWSER_CREDENTIAL_PROVIDERS: typing.List[str] = [ 14 | "jumpcloud_browser_idp", 15 | "okta_browser_idp", 16 | "azure_browser_idp", 17 | # "jwt_azure_v2_idp", 18 | # "jwt_google_idp", 19 | "ping_browser_idp", 20 | "redshift_native_browser_azure_oauth2_idp", 21 | "redshift_browser_idc", 22 | "redshift_idp_token_auth_plugin", 23 | ] 24 | 25 | """ 26 | How to use: 27 | 0) If necessary, create a Redshift cluster and configure it for use with the desired IdP 28 | 1) In config.ini specify the connection parameters required by the desired IdP fixture in test/integration/conftest.py 29 | 2) Ensure browser cookies have been cleared 30 | 3) Manually execute the tests in this file, providing the necessary login information in the web browser 31 | """ 32 | 33 | 34 | @pytest.mark.skip(reason="manual") 35 | @pytest.mark.parametrize("idp_arg", BROWSER_CREDENTIAL_PROVIDERS, indirect=True) 36 | def test_browser_credentials_provider_can_auth(idp_arg): 37 | with 
redshift_connector.connect(**idp_arg) as conn: 38 | with conn.cursor() as cursor: 39 | cursor.execute("select 1;") 40 | -------------------------------------------------------------------------------- /test/manual/test_redshift_custom_domain.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import redshift_connector 4 | from redshift_connector.idp_auth_helper import SupportedSSLMode 5 | 6 | """ 7 | These functional tests ensure connections to Redshift provisioned customer with custom domain name can be established 8 | when using various authentication methods. 9 | 10 | Pre-requisites: 11 | 1) Redshift provisioned configuration 12 | 2) Existing custom domain association with instance created in step 1 13 | """ 14 | 15 | 16 | @pytest.mark.skip(reason="manual") 17 | @pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL)) 18 | def test_native_connect(provisioned_cname_db_kwargs, sslmode) -> None: 19 | # this test requires aws default profile contains valid credentials that provide permissions for 20 | # redshift:GetClusterCredentials ( Only called from this test method) 21 | import boto3 22 | 23 | profile = "default" 24 | client = boto3.client( 25 | service_name="redshift", 26 | region_name="eu-north-1", 27 | ) 28 | # fetch cluster credentials and pass them as driver connect parameters 29 | response = client.get_cluster_credentials( 30 | CustomDomainName=provisioned_cname_db_kwargs["host"], DbUser=provisioned_cname_db_kwargs["db_user"] 31 | ) 32 | 33 | provisioned_cname_db_kwargs["password"] = response["DbPassword"] 34 | provisioned_cname_db_kwargs["user"] = response["DbUser"] 35 | provisioned_cname_db_kwargs["profile"] = profile 36 | provisioned_cname_db_kwargs["ssl"] = True 37 | provisioned_cname_db_kwargs["sslmode"] = sslmode.value 38 | 39 | with redshift_connector.connect(**provisioned_cname_db_kwargs): 40 | pass 41 | 42 | 43 | @pytest.mark.skip(reason="manual") 44 
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
def test_iam_connect(provisioned_cname_db_kwargs, sslmode) -> None:
    """IAM auth against a custom domain; the driver performs all credential lookups.

    Requires the aws default profile to grant:
    redshift:GetClusterCredentials, redshift:DescribeClusters, and
    redshift:DescribeCustomDomainAssociations (all called from the driver).
    """
    provisioned_cname_db_kwargs.update(
        iam=True,
        profile="default",
        auto_create=True,
        ssl=True,
        sslmode=sslmode.value,
    )
    with redshift_connector.connect(**provisioned_cname_db_kwargs):
        pass


def test_idp_connect(okta_idp, provisioned_cname_db_kwargs) -> None:
    # todo
    pass


@pytest.mark.skip(reason="manual")
def test_nlb_connect() -> None:
    """IAM auth through a network load balancer endpoint."""
    args = {
        "iam": True,
        # "access_key_id": "xxx",
        # "secret_access_key": "xxx",
        "cluster_identifier": "replace-me",
        "region": "us-east-1",
        "host": "replace-me",
        "port": 5439,
        "database": "dev",
        "db_user": "replace-me",
    }
    with redshift_connector.connect(**args):  # type: ignore
        pass


@pytest.mark.skip(reason="manual")
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
def test_serverless_iam_cname_connect(sslmode, serverless_cname_db_kwargs):
    """IAM auth against a serverless custom domain name."""
    serverless_cname_db_kwargs.update(
        iam=True,
        profile="default",
        auto_create=True,
        ssl=True,
        sslmode=sslmode.value,
    )

    with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
        with conn.cursor() as cursor:
            cursor.execute("select current_user")
            print(cursor.fetchone())


@pytest.mark.skip(reason="manual")
@pytest.mark.parametrize("sslmode", (SupportedSSLMode.VERIFY_CA, SupportedSSLMode.VERIFY_FULL))
def test_serverless_cname_connect(sslmode, serverless_cname_db_kwargs):
    """Native auth against a serverless custom domain with out-of-band credentials.

    Requires the aws default profile to grant redshift-serverless:GetCredentials
    (only called from this test method).
    """
    import boto3

    client = boto3.client(
        service_name="redshift-serverless",
        region_name="us-east-1",
    )
    # fetch cluster credentials and pass them as driver connect parameters
    response = client.get_credentials(
        customDomainName=serverless_cname_db_kwargs["host"], dbName=serverless_cname_db_kwargs["database"]
    )

    serverless_cname_db_kwargs.update(
        sslmode=sslmode.value,
        ssl=True,
        user=response["dbUser"],
        password=response["dbPassword"],
        profile="default",
    )

    with redshift_connector.connect(**serverless_cname_db_kwargs) as conn:
        with conn.cursor() as cursor:
            cursor.execute("select current_user")
            print(cursor.fetchone())
in compatible VPC, subnet) 15 | 3) Perform a sanity check using psql to ensure Redshift serverless connection can be established 16 | 3) EC2 instance has Python installed 17 | 4) Clone redshift_connector on EC2 instance and install 18 | 19 | How to use: 20 | 1) Populate config.ini with the Redshift serverless endpoint and user authentication information 21 | 2) Run this file with pytest 22 | """ 23 | 24 | 25 | @pytest.mark.skip(reason="manual") 26 | def test_native_auth(serverless_native_db_kwargs) -> None: 27 | with redshift_connector.connect(**serverless_native_db_kwargs): 28 | pass 29 | 30 | 31 | @pytest.mark.skip(reason="manual") 32 | def test_iam_auth(serverless_iam_db_kwargs) -> None: 33 | with redshift_connector.connect(**serverless_iam_db_kwargs): 34 | pass 35 | 36 | 37 | @pytest.mark.skip(reason="manual") 38 | def test_idp_auth(okta_idp) -> None: 39 | okta_idp["host"] = "my_redshift_serverless_endpoint" 40 | 41 | with redshift_connector.connect(**okta_idp): 42 | pass 43 | 44 | 45 | @pytest.mark.skip() 46 | def test_connection_without_host(serverless_iam_db_kwargs) -> None: 47 | serverless_iam_db_kwargs["is_serverless"] = True 48 | serverless_iam_db_kwargs["host"] = None 49 | serverless_iam_db_kwargs["serverless_work_group"] = "default" 50 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn: 51 | with conn.cursor() as cursor: 52 | cursor.execute("select 1") 53 | 54 | 55 | @pytest.mark.skip() 56 | def test_nlb_connection(serverless_iam_db_kwargs) -> None: 57 | serverless_iam_db_kwargs["is_serverless"] = True 58 | serverless_iam_db_kwargs["host"] = "my_nlb_endpoint" 59 | serverless_iam_db_kwargs["serverless_work_group"] = "default" 60 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn: 61 | with conn.cursor() as cursor: 62 | cursor.execute("select 1") 63 | 64 | 65 | @pytest.mark.skip() 66 | def test_vpc_endpoint_connection(serverless_iam_db_kwargs) -> None: 67 | serverless_iam_db_kwargs["is_serverless"] = True 68 | 
serverless_iam_db_kwargs["host"] = "my_vpc_endpoint" 69 | serverless_iam_db_kwargs["serverless_work_group"] = "default" 70 | with redshift_connector.connect(**serverless_iam_db_kwargs) as conn: 71 | with conn.cursor() as cursor: 72 | cursor.execute("select 1") 73 | -------------------------------------------------------------------------------- /test/performance/bulk_insert_performance.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import csv 3 | import os 4 | import time 5 | import typing 6 | 7 | import redshift_connector 8 | 9 | conf: configparser.ConfigParser = configparser.ConfigParser() 10 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 11 | conf.read(root_path + "/config.ini") 12 | 13 | root_path = os.path.dirname(os.path.abspath(__file__)) 14 | 15 | 16 | def perf_conn(): 17 | return redshift_connector.connect( 18 | database=conf.get("ci-cluster", "database"), 19 | host=conf.get("ci-cluster", "host"), 20 | port=conf.getint("default-test", "port"), 21 | user=conf.get("ci-cluster", "test_user"), 22 | password=conf.get("ci-cluster", "test_password"), 23 | ssl=True, 24 | sslmode=conf.get("default-test", "sslmode"), 25 | iam=False, 26 | ) 27 | 28 | 29 | print("Reading data from csv file") 30 | with open(root_path + "/bulk_insert_data.csv", "r", encoding="utf8") as csv_data: 31 | reader = csv.reader(csv_data, delimiter="\t") 32 | next(reader) 33 | data = [] 34 | for row in reader: 35 | data.append(row) 36 | print("Inserting {} rows having {} columns".format(len(data), len(data[0]))) 37 | 38 | print("\nCursor.insert_data_bulk()..") 39 | for batch_size in [1e2, 1e3, 1e4]: 40 | with perf_conn() as conn: 41 | with conn.cursor() as cursor: 42 | cursor.execute("drop table if exists bulk_insert_perf;") 43 | cursor.execute("create table bulk_insert_perf (c1 int, c2 int, c3 int, c4 int, c5 int);") 44 | start_time: float = time.time() 45 | cursor.insert_data_bulk( 46 | 
filename=root_path + "/bulk_insert_data.csv", 47 | table_name="bulk_insert_perf", 48 | parameter_indices=[0, 1, 2, 3, 4], 49 | column_names=["c1", "c2", "c3", "c4", "c5"], 50 | delimiter="\t", 51 | batch_size=batch_size, 52 | ) 53 | 54 | print("batch_size={0} {1} seconds.".format(batch_size, time.time() - start_time)) 55 | 56 | print("Cursor.executemany()") 57 | with perf_conn() as conn: 58 | with conn.cursor() as cursor: 59 | cursor.execute("drop table if exists bulk_insert_perf;") 60 | cursor.execute("create table bulk_insert_perf (c1 int, c2 int, c3 int, c4 int, c5 int);") 61 | start_time = time.time() 62 | cursor.executemany("insert into bulk_insert_perf(c1, c2, c3, c4, c5) values(%s, %s, %s, %s, %s)", data) 63 | print("{0} seconds.".format(time.time() - start_time)) 64 | -------------------------------------------------------------------------------- /test/performance/performance.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | import time 4 | import typing 5 | 6 | import redshift_connector 7 | 8 | conf: configparser.ConfigParser = configparser.ConfigParser() 9 | root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | conf.read(root_path + "/config.ini") 11 | 12 | root_path = os.path.dirname(os.path.abspath(__file__)) 13 | sql: typing.TextIO = open(root_path + "/test.sql", "r", encoding="utf8") 14 | sqls: typing.List[str] = sql.readlines() 15 | sqls = [_sql.replace("\n", "") for _sql in sqls] 16 | sql.close() 17 | 18 | conn: redshift_connector.Connection = redshift_connector.connect( 19 | database=conf.get("ci-cluster", "database"), 20 | host=conf.get("ci-cluster", "host"), 21 | port=conf.getint("default-test", "port"), 22 | user=conf.get("ci-cluster", "test_user"), 23 | password=conf.get("ci-cluster", "test_password"), 24 | ssl=True, 25 | sslmode=conf.get("default-test", "sslmode"), 26 | iam=False, 27 | ) 28 | 29 | cursor: redshift_connector.Cursor = 
import configparser
import os
import time
import typing

import redshift_connector

# Read cluster coordinates from the repository-level config.ini.
conf: configparser.ConfigParser = configparser.ConfigParser()
root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
conf.read(root_path + "/config.ini")

root_path = os.path.dirname(os.path.abspath(__file__))

# Fix: the SQL file was opened and closed manually, so an exception while
# reading leaked the handle; a context manager releases it unconditionally.
with open(root_path + "/test.sql", "r", encoding="utf8") as sql_file:
    sqls: typing.List[str] = [_sql.rstrip("\n") for _sql in sql_file]

conn: redshift_connector.Connection = redshift_connector.connect(
    database=conf.get("ci-cluster", "database"),
    host=conf.get("ci-cluster", "host"),
    port=conf.getint("default-test", "port"),
    user=conf.get("ci-cluster", "test_user"),
    password=conf.get("ci-cluster", "test_password"),
    ssl=True,
    sslmode=conf.get("default-test", "sslmode"),
    iam=False,
)

# Fix: the cursor and connection were never closed on error (and the
# connection was never closed at all); wrap the benchmark in try/finally.
try:
    with conn.cursor() as cursor:
        # Build and populate the "performance" table from test.sql.
        for _sql in sqls:
            cursor.execute(_sql)

        result: typing.Tuple[typing.List[int], ...] = cursor.fetchall()
        print("fetch {result} rows".format(result=result))

        print("start calculate fetch time")
        for val in [True, False]:
            print("merge_socket_read={val}".format(val=val))
            start_time: float = time.time()
            cursor.execute("select * from performance", merge_socket_read=val)
            results: typing.Tuple[typing.List[int], ...] = cursor.fetchall()
            print("Took {0} seconds.".format(time.time() - start_time))
            print("fetch {result} rows".format(result=len(results)))
    conn.commit()
finally:
    conn.close()
perf_varchar); 19 | insert into perf_varchar (select * from perf_varchar); 20 | insert into perf_varchar (select * from perf_varchar); 21 | insert into perf_varchar (select * from perf_varchar); 22 | insert into perf_varchar (select * from perf_varchar); 23 | insert into perf_varchar (select * from perf_varchar); 24 | insert into perf_varchar (select * from perf_varchar); 25 | insert into perf_varchar (select * from perf_varchar); 26 | insert into perf_varchar (select * from perf_varchar); 27 | insert into perf_varchar (select * from perf_varchar); 28 | insert into perf_varchar (select * from perf_varchar); 29 | insert into perf_varchar (select * from perf_varchar); 30 | insert into perf_varchar (select * from perf_varchar); 31 | insert into perf_varchar (select * from perf_varchar); 32 | insert into perf_varchar (select * from perf_varchar); 33 | insert into perf_varchar (select * from perf_varchar); 34 | insert into perf_varchar (select * from perf_varchar); 35 | insert into perf_varchar (select * from perf_varchar); 36 | insert into perf_varchar (select * from perf_varchar); 37 | insert into perf_varchar (select * from perf_varchar); 38 | insert into perf_varchar (select * from perf_varchar); 39 | 40 | insert into perf_time values('12:13:14', '12:13:14', '12:13:14', '12:13:14', '12:13:14'); 41 | insert into perf_time (select * from perf_time); 42 | insert into perf_time (select * from perf_time); 43 | insert into perf_time (select * from perf_time); 44 | insert into perf_time (select * from perf_time); 45 | insert into perf_time (select * from perf_time); 46 | insert into perf_time (select * from perf_time); 47 | insert into perf_time (select * from perf_time); 48 | insert into perf_time (select * from perf_time); 49 | insert into perf_time (select * from perf_time); 50 | insert into perf_time (select * from perf_time); 51 | insert into perf_time (select * from perf_time); 52 | insert into perf_time (select * from perf_time); 53 | insert into perf_time (select * 
from perf_time); 54 | insert into perf_time (select * from perf_time); 55 | insert into perf_time (select * from perf_time); 56 | insert into perf_time (select * from perf_time); 57 | insert into perf_time (select * from perf_time); 58 | insert into perf_time (select * from perf_time); 59 | insert into perf_time (select * from perf_time); 60 | insert into perf_time (select * from perf_time); 61 | insert into perf_time (select * from perf_time); 62 | insert into perf_time (select * from perf_time); 63 | 64 | insert into perf_timetz values('20:13:14.123456', '20:13:14.123456', '20:13:14.123456', '20:13:14.123456', '20:13:14.123456'); 65 | insert into perf_timetz (select * from perf_timetz); 66 | insert into perf_timetz (select * from perf_timetz); 67 | insert into perf_timetz (select * from perf_timetz); 68 | insert into perf_timetz (select * from perf_timetz); 69 | insert into perf_timetz (select * from perf_timetz); 70 | insert into perf_timetz (select * from perf_timetz); 71 | insert into perf_timetz (select * from perf_timetz); 72 | insert into perf_timetz (select * from perf_timetz); 73 | insert into perf_timetz (select * from perf_timetz); 74 | insert into perf_timetz (select * from perf_timetz); 75 | insert into perf_timetz (select * from perf_timetz); 76 | insert into perf_timetz (select * from perf_timetz); 77 | insert into perf_timetz (select * from perf_timetz); 78 | insert into perf_timetz (select * from perf_timetz); 79 | insert into perf_timetz (select * from perf_timetz); 80 | insert into perf_timetz (select * from perf_timetz); 81 | insert into perf_timetz (select * from perf_timetz); 82 | insert into perf_timetz (select * from perf_timetz); 83 | 84 | insert into perf_timestamptz values('1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16', '1997-10-11 07:37:16'); 85 | insert into perf_timestamptz (select * from perf_timestamptz); 86 | insert into perf_timestamptz (select * from perf_timestamptz); 87 | insert into 
/* Builds and populates the "performance" table consumed by
   test/performance/performance.py fetch benchmarks. */
drop table if exists performance CASCADE;
create table performance(c1 integer);
/* Seed nine rows. */
insert into performance values(1);
insert into performance values(2);
insert into performance values(3);
insert into performance values(4);
insert into performance values(5);
insert into performance values(6);
insert into performance values(7);
insert into performance values(8);
insert into performance values(9);
/* Double the row count with each self-insert (20 doublings). */
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
insert into performance select * from performance;
/* Grow further, capped per statement by LIMIT. */
insert into performance select * from performance LIMIT 10000000;
insert into performance select * from performance LIMIT 10000000;
insert into performance select * from performance LIMIT 10000000;
insert into performance select * from performance LIMIT 10000000;
insert into performance select * from performance LIMIT 1457280;
select count(*) from performance;
"""Unit tests for AWSCredentialsProvider: parameter wiring, cache keys,
credential caching, and refresh behavior."""
import typing
from unittest.mock import MagicMock, patch

import pytest  # type: ignore

from redshift_connector import InterfaceError
from redshift_connector.auth.aws_credentials_provider import (
    AWSCredentialsProvider,
    AWSDirectCredentialsHolder,
)
from redshift_connector.credentials_holder import AWSProfileCredentialsHolder
from redshift_connector.redshift_property import RedshiftProperty


def _make_aws_credentials_obj_with_profile() -> AWSCredentialsProvider:
    """Return a provider configured with only an AWS profile name."""
    cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
    rp: RedshiftProperty = RedshiftProperty()
    rp.profile = "myProfile"
    cred_provider.add_parameter(rp)
    return cred_provider


def _make_aws_credentials_obj_with_credentials() -> AWSCredentialsProvider:
    """Return a provider configured with explicit access key / secret / token."""
    cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
    rp: RedshiftProperty = RedshiftProperty()
    rp.access_key_id = "my_access"
    rp.secret_access_key = "my_secret"
    rp.session_token = "my_session"
    cred_provider.add_parameter(rp)
    return cred_provider


def test_create_aws_credentials_provider_with_profile() -> None:
    """A profile-based provider carries no explicit key material."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
    assert cred_provider.profile == "myProfile"
    assert cred_provider.access_key_id is None
    assert cred_provider.secret_access_key is None
    assert cred_provider.session_token is None


def test_create_aws_credentials_provider_with_credentials() -> None:
    """A key-based provider carries no profile name."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
    assert cred_provider.profile is None
    assert cred_provider.access_key_id == "my_access"
    assert cred_provider.secret_access_key == "my_secret"
    assert cred_provider.session_token == "my_session"


def test_get_cache_key_with_profile() -> None:
    """Cache key is the hash of the profile name when a profile is used."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
    assert cred_provider.get_cache_key() == hash(cred_provider.profile)


def test_get_cache_key_with_credentials() -> None:
    """Cache key is the hash of the access key id when explicit keys are used."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
    assert cred_provider.get_cache_key() == hash("my_access")


def test_get_credentials_checks_cache_first(mocker) -> None:
    """get_credentials serves a pre-populated cache entry, consulting the key once."""
    mocked_credential_holder = MagicMock()

    def mock_set_cache(cp: AWSCredentialsProvider, key: str = "tomato") -> None:
        cp.cache[key] = mocked_credential_holder  # type: ignore

    cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
    mocker.patch("redshift_connector.auth.AWSCredentialsProvider.get_cache_key", return_value="tomato")

    with patch("redshift_connector.auth.AWSCredentialsProvider.refresh"):
        # Bug fix: the original assigned `mock_set_cache(cred_provider)` -- i.e. the
        # *result* (None) of an immediate call -- to mocked_refresh.side_effect.
        # The cache was thus populated eagerly and side_effect was a no-op.
        # Populate the cache explicitly, which is what actually happened.
        mock_set_cache(cred_provider)
        get_cache_key_spy = mocker.spy(cred_provider, "get_cache_key")

        assert cred_provider.get_credentials() == mocked_credential_holder

        assert get_cache_key_spy.called is True
        assert get_cache_key_spy.call_count == 1


def test_get_credentials_refresh_error_is_raised(mocker) -> None:
    """A failure inside refresh surfaces to the caller as InterfaceError."""
    cred_provider: AWSCredentialsProvider = AWSCredentialsProvider()
    mocker.patch("redshift_connector.auth.AWSCredentialsProvider.get_cache_key", return_value="tomato")
    expected_exception = "Refreshing IdP credentials failed"

    with patch("redshift_connector.auth.AWSCredentialsProvider.refresh") as mocked_refresh:
        mocked_refresh.side_effect = Exception(expected_exception)

        with pytest.raises(InterfaceError, match=expected_exception):
            cred_provider.get_credentials()


def test_refresh_uses_profile_if_present(mocker) -> None:
    """refresh caches an AWSProfileCredentialsHolder keyed on the profile hash."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_profile()
    mocked_boto_session: MagicMock = MagicMock()

    with patch("boto3.Session", return_value=mocked_boto_session):
        cred_provider.refresh()

    assert hash("myProfile") in cred_provider.cache
    assert isinstance(cred_provider.cache[hash("myProfile")], AWSProfileCredentialsHolder)
    assert typing.cast(AWSProfileCredentialsHolder, cred_provider.cache[hash("myProfile")]).profile == "myProfile"


def test_refresh_uses_credentials_if_present(mocker) -> None:
    """refresh caches an AWSDirectCredentialsHolder keyed on the access-key hash."""
    cred_provider: AWSCredentialsProvider = _make_aws_credentials_obj_with_credentials()
    mocked_boto_session: MagicMock = MagicMock()

    with patch("boto3.Session", return_value=mocked_boto_session):
        cred_provider.refresh()

    assert hash("my_access") in cred_provider.cache
    holder = typing.cast(AWSDirectCredentialsHolder, cred_provider.cache[hash("my_access")])
    assert isinstance(holder, AWSDirectCredentialsHolder)
    assert holder.access_key_id == "my_access"
    assert holder.secret_access_key == "my_secret"
    assert holder.session_token == "my_session"
"""Unit tests for RedshiftOID constants and datatype-name lookup."""
import typing

import pytest

from redshift_connector.utils.oids import RedshiftOID, get_datatype_name

all_oids: typing.List[int] = list(RedshiftOID)
all_datatypes: typing.List[typing.Tuple[int, str]] = [(member, member.name) for member in RedshiftOID]


@pytest.mark.parametrize("oid", all_oids)
def test_RedshiftOID_has_type_int(oid):
    """Every OID constant behaves as a plain int."""
    assert isinstance(oid, int)


@pytest.mark.parametrize("oid, datatype", all_datatypes)
def test_get_datatype_name(oid, datatype):
    """get_datatype_name maps each OID back to its member name."""
    assert get_datatype_name(oid) == datatype


def test_get_datatype_name_invalid_oid_raises() -> None:
    """Looking up an unknown OID raises ValueError."""
    with pytest.raises(ValueError, match="not a valid RedshiftOID"):
        get_datatype_name(-9)


def test_modify_oid() -> None:
    """OID constants are read-only; assignment raises AttributeError."""
    with pytest.raises(AttributeError, match="Cannot modify OID constant 'VARCHAR'"):
        RedshiftOID.VARCHAR = 0
def make_redshift_property() -> RedshiftProperty:
    """Build a RedshiftProperty pre-filled with representative IdP test values."""
    settings = {
        "user_name": "mario@luigi.com",
        "password": "bowser",
        "db_name": "dev",
        "cluster_identifier": "something",
        "idp_host": "8000",
        "duration": 100,
        "preferred_role": "analyst",
        "ssl_insecure": False,
        "db_user": "primary",
        "db_groups": ["employees"],
        "force_lowercase": True,
        "auto_create": False,
        "region": "us-west-1",
        "principal": "arn:aws:iam::123456789012:user/Development/product_1234/*",
        "client_protocol_version": ClientProtocolVersion.BASE_SERVER,
    }
    rp: RedshiftProperty = RedshiftProperty()
    for attribute, value in settings.items():
        setattr(rp, attribute, value)
    return rp
"SamlCredentialsProvider"): 6 | return "mocked" 7 | -------------------------------------------------------------------------------- /test/unit/mocks/mock_socket.py: -------------------------------------------------------------------------------- 1 | import socket 2 | import typing 3 | 4 | 5 | class MockSocket(socket.socket): 6 | mocked_data: typing.Optional[bytes] = None 7 | 8 | def __init__(self, family=-1, type=-1, proto=-1, fileno=None): 9 | pass 10 | 11 | def __exit__(self, exc_type, exc_val, exc_tb): 12 | pass 13 | 14 | def setsockopt(self, level, optname, value): 15 | pass 16 | 17 | def bind(self, address): 18 | pass 19 | 20 | def listen(self, __backlog=...): 21 | pass 22 | 23 | def settimeout(self, value): 24 | pass 25 | 26 | def accept(self): 27 | return (MockSocket(), "127.0.0.1") 28 | 29 | def recv(self, bufsize, flags=...): 30 | return self.mocked_data 31 | 32 | def close(self) -> None: 33 | pass 34 | 35 | def send(self, *args) -> None: # type: ignore 36 | pass 37 | -------------------------------------------------------------------------------- /test/unit/plugin/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/plugin/__init__.py -------------------------------------------------------------------------------- /test/unit/plugin/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws/amazon-redshift-python-driver/5bbd061b03cbe555dbc783f44f716bcf8b4864c4/test/unit/plugin/data/__init__.py -------------------------------------------------------------------------------- /test/unit/plugin/data/browser_azure_data.py: -------------------------------------------------------------------------------- 1 | code: str = "helloworld" 2 | state: str = "abcdefghij" 3 | valid_response: bytes = ( 4 | b"POST /redshift/ 
HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n" 5 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n" 6 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " 7 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html," 8 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8," 9 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n" 10 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n" 11 | b"\r\ncode=" + code.encode("utf-8") + b"&state=" + state.encode("utf-8") + b"&session_state=hooplah" 12 | ) 13 | 14 | missing_state_response: bytes = ( 15 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n" 16 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n" 17 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " 18 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html," 19 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8," 20 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n" 21 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n" 22 | b"\r\ncode=" + code.encode("utf-8") 23 | ) 24 | 25 | mismatched_state_response: bytes = ( 26 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n" 27 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n" 28 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 
10_15_7) " 29 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html," 30 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8," 31 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n" 32 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n" 33 | b"\r\ncode=" + code.encode("utf-8") + b"&state=" + state[::-1].encode("utf-8") + b"&session_state=hooplah" 34 | ) 35 | 36 | missing_code_response: bytes = ( 37 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n" 38 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n" 39 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " 40 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html," 41 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8," 42 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n" 43 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n" 44 | b"&state=" + state.encode("utf-8") + b"&session_state=hooplah" 45 | ) 46 | 47 | empty_code_response: bytes = ( 48 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:51187\r\nConnection: keep-alive\r\nContent-Length: 695\r\n" 49 | b"Cache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\n" 50 | b"Content-Type: application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) " 51 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/86.0.4240.75 Safari/537.36\r\nAccept: text/html," 52 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8," 53 | 
b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: navigate\r\n" 54 | b"Sec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US,en;q=0.9\r\n" 55 | b"\r\ncode=" + b"&state=" + state.encode("utf-8") + b"&session_state=hooplah" 56 | ) 57 | 58 | 59 | saml_response = b"my_access_token" 60 | 61 | valid_json_response: dict = { 62 | "token_type": "Bearer", 63 | "expires_in": "3599", 64 | "ext_expires_in": "3599", 65 | "expires_on": "1602782647", 66 | "resource": "spn:1234567891011121314151617181920", 67 | "access_token": "bXlfYWNjZXNzX3Rva2Vu", # base64.urlsafe_64encode(saml_response) 68 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2", 69 | "refresh_token": "my_refresh_token", 70 | "id_token": "my_id_token", 71 | } 72 | 73 | json_response_no_access_token: dict = { 74 | "token_type": "Bearer", 75 | "expires_in": "3599", 76 | "ext_expires_in": "3599", 77 | "expires_on": "1602782647", 78 | "resource": "spn:1234567891011121314151617181920", 79 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2", 80 | "refresh_token": "my_refresh_token", 81 | "id_token": "my_id_token", 82 | } 83 | 84 | json_response_empty_access_token: dict = { 85 | "token_type": "Bearer", 86 | "expires_in": "3599", 87 | "ext_expires_in": "3599", 88 | "expires_on": "1602782647", 89 | "resource": "spn:1234567891011121314151617181920", 90 | "access_token": "", 91 | "issued_token_type": "urn:ietf:params:oauth:token-type:saml2", 92 | "refresh_token": "my_refresh_token", 93 | "id_token": "my_id_token", 94 | } 95 | -------------------------------------------------------------------------------- /test/unit/plugin/data/mock_adfs_saml_response.html: -------------------------------------------------------------------------------- 1 | 2 | Working... 3 | 4 |
7 | 8 |
9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /test/unit/plugin/data/saml_response.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | http://www.okta.com/exki777jkSgkPNu9L4x6 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | testDigestValue1 17 | 18 | 19 | testSignature1 20 | 21 | 22 | testCertificate1 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | http://www.okta.com/exki777jkSgkPNu9L4x6 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | testDigestValue2= 44 | 45 | 46 | testSignature2 47 | 48 | 49 | testCertificate2 50 | 51 | 52 | 53 | 54 | example@example.com 55 | 56 | 57 | 58 | 59 | 60 | 61 | urn:amazon:webservices 62 | 63 | 64 | 65 | 66 | urn:oasis:names:tc:SAML:2.0:ac:classes:PasswordProtectedTransport 67 | 68 | 69 | 70 | 71 | arn:aws:iam::123456789012:role/myRole,arn:aws:iam::123456789012:saml-provider/myProvider 72 | 73 | 74 | example@example.com 75 | 76 | 77 | true 78 | 79 | 80 | example@example.com 81 | 82 | 83 | 84 | 85 | -------------------------------------------------------------------------------- /test/unit/plugin/data/saml_response_data.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import zlib 3 | 4 | # uses sample (valid) SAML response from data folder 5 | saml_response: bytes = open("test/unit/plugin/data/saml_response.xml").read().encode("utf-8") 6 | 7 | # the SAML response as received in HTTP response 8 | encoded_saml_response: bytes = base64.b64encode(zlib.compress(saml_response)[2:-4]) 9 | 10 | 11 | # HTTP response containing SAML response can be formatted two ways. 12 | # 1. 
HTTP response containing things other than the SAML response 13 | valid_http_response_with_header_equal_delim: bytes = ( 14 | b"POST /redshift/ HTTP/1.1\r\nHost: localhost:7890\r\nConnection: keep-alive\r\nContent-Length: " 15 | b"11639\r\nCache-Control: max-age=0\r\nUpgrade-Insecure-Requests: 1\r\nOrigin: null\r\nContent-Type: " 16 | b"application/x-www-form-urlencoded\r\nUser-Agent: Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_6) " 17 | b"AppleWebKit/537.36 (KHTML, like Gecko) Chrome/85.0.4183.102 Safari/537.36\r\nAccept: text/html," 18 | b"application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,/;q=0.8," 19 | b"application/signed-exchange;v=b3;q=0.9\r\nSec-Fetch-Site: cross-site\r\nSec-Fetch-Mode: " 20 | b"navigate\r\nSec-Fetch-Dest: document\r\nAccept-Encoding: gzip, deflate, br\r\nAccept-Language: en-US," 21 | b"en;q=0.9\r\n\r\nSAMLResponse=" + encoded_saml_response + b"&RelayState=" 22 | ) 23 | 24 | valid_http_response_with_header_colon_delim: bytes = ( 25 | b"POST\r\nRelayState: \r\nSAMLResponse:\r\n" + encoded_saml_response 26 | ) 27 | 28 | # 2. 
"""Unit tests for BrowserAzureOAuth2CredentialsProvider: parameter wiring,
required-parameter validation, and the get_jwt_assertion pipeline."""
import typing

import pytest
from pytest_mock import mocker  # type: ignore

from redshift_connector.error import InterfaceError
from redshift_connector.plugin.browser_azure_oauth2_credentials_provider import (
    BrowserAzureOAuth2CredentialsProvider,
)
from redshift_connector.redshift_property import RedshiftProperty


def make_valid_azure_oauth2_provider() -> typing.Tuple[BrowserAzureOAuth2CredentialsProvider, RedshiftProperty]:
    """Return a provider populated with a minimal valid set of OAuth2 properties."""
    rp: RedshiftProperty = RedshiftProperty()
    rp.idp_tenant = "my_idp_tenant"
    rp.client_id = "my_client_id"
    rp.scope = "my_scope"
    rp.idp_response_timeout = 900
    rp.listen_port = 1099
    cp: BrowserAzureOAuth2CredentialsProvider = BrowserAzureOAuth2CredentialsProvider()
    cp.add_parameter(rp)
    return cp, rp


def test_default_parameters_azure_oauth2_specific() -> None:
    """SSL verification defaults: insecure off, certificate verification on."""
    acp, _ = make_valid_azure_oauth2_provider()
    # Fixed `== False` / `== True` comparisons (flake8 E712) to truthiness asserts.
    assert not acp.ssl_insecure
    assert acp.do_verify_ssl_cert()


def test_add_parameter_sets_azure_oauth2_specific() -> None:
    """add_parameter copies all OAuth2-specific fields onto the provider."""
    acp, rp = make_valid_azure_oauth2_provider()
    assert acp.idp_tenant == rp.idp_tenant
    assert acp.client_id == rp.client_id
    assert acp.scope == rp.scope
    assert acp.idp_response_timeout == rp.idp_response_timeout
    assert acp.listen_port == rp.listen_port


@pytest.mark.parametrize("value", [None, ""])
def test_check_required_parameters_raises_if_idp_tenant_missing_or_too_small(value) -> None:
    """Missing/empty idp_tenant is rejected before any browser interaction."""
    acp, _ = make_valid_azure_oauth2_provider()
    acp.idp_tenant = value

    with pytest.raises(InterfaceError, match="Missing required connection property: idp_tenant"):
        acp.get_jwt_assertion()


@pytest.mark.parametrize("value", [None, ""])
def test_check_required_parameters_raises_if_client_id_missing(value) -> None:
    """Missing/empty client_id is rejected before any browser interaction."""
    acp, _ = make_valid_azure_oauth2_provider()
    acp.client_id = value

    with pytest.raises(InterfaceError, match="Missing required connection property: client_id"):
        acp.get_jwt_assertion()


@pytest.mark.parametrize("value", [None, ""])
def test_check_required_parameters_raises_if_idp_response_timeout_missing(value) -> None:
    """Missing idp_response_timeout is rejected with the 10-second-minimum message."""
    acp, _ = make_valid_azure_oauth2_provider()
    acp.idp_response_timeout = value

    with pytest.raises(
        InterfaceError,
        match="Invalid value specified for connection property: idp_response_timeout. Must be 10 seconds or greater",
    ):
        acp.get_jwt_assertion()


def test_get_jwt_assertion_fetches_and_extracts(mocker) -> None:
    """get_jwt_assertion chains fetch_authorization_token -> fetch_jwt_response
    -> extract_jwt_assertion, passing each stage's output to the next."""
    mock_token: str = "mock_token"
    mock_content: str = "mock_content"
    mock_jwt_assertion: str = "mock_jwt_assertion"
    mocker.patch(
        "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
        "BrowserAzureOAuth2CredentialsProvider.fetch_authorization_token",
        return_value=mock_token,
    )
    mocker.patch(
        "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
        "BrowserAzureOAuth2CredentialsProvider.fetch_jwt_response",
        return_value=mock_content,
    )
    mocker.patch(
        "redshift_connector.plugin.browser_azure_oauth2_credentials_provider."
        "BrowserAzureOAuth2CredentialsProvider.extract_jwt_assertion",
        return_value=mock_jwt_assertion,
    )
    acp, rp = make_valid_azure_oauth2_provider()

    fetch_token_spy = mocker.spy(acp, "fetch_authorization_token")
    fetch_jwt_spy = mocker.spy(acp, "fetch_jwt_response")
    extract_jwt_spy = mocker.spy(acp, "extract_jwt_assertion")

    jwt_assertion: str = acp.get_jwt_assertion()

    assert fetch_token_spy.called is True
    assert fetch_token_spy.call_count == 1

    assert fetch_jwt_spy.called is True
    assert fetch_jwt_spy.call_count == 1
    assert fetch_jwt_spy.call_args[0][0] == mock_token

    assert extract_jwt_spy.called is True
    assert extract_jwt_spy.call_count == 1
    assert extract_jwt_spy.call_args[0][0] == mock_content

    assert jwt_assertion == mock_jwt_assertion
http_response_datas) 27 | def test_run_server_parses_saml_response(http_response) -> None: 28 | MockSocket.mocked_data = http_response 29 | with patch("socket.socket", return_value=MockSocket()): 30 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 31 | parsed_saml_response: str = browser_saml_credentials.run_server(listen_port=0, idp_response_timeout=5) 32 | assert parsed_saml_response == saml_response_data.encoded_saml_response.decode("utf-8") 33 | 34 | 35 | invalid_login_url_vals: typing.List[typing.Optional[str]] = ["", None] 36 | 37 | 38 | @pytest.mark.parametrize("login_url", invalid_login_url_vals) 39 | def test_get_saml_assertion_no_login_url_should_fail(login_url) -> None: 40 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 41 | browser_saml_credentials.login_url = login_url 42 | 43 | with pytest.raises(InterfaceError) as ex: 44 | browser_saml_credentials.get_saml_assertion() 45 | assert "Missing required connection property: login_url" in str(ex.value) 46 | 47 | 48 | def test_get_saml_assertion_low_idp_response_timeout_should_fail() -> None: 49 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 50 | browser_saml_credentials.login_url = "https://www.example.com" 51 | browser_saml_credentials.idp_response_timeout = -1 52 | 53 | with pytest.raises(InterfaceError) as ex: 54 | browser_saml_credentials.get_saml_assertion() 55 | assert ( 56 | "Invalid value specified for connection property: idp_response_timeout. 
Must be 10 seconds or greater" 57 | in str(ex.value) 58 | ) 59 | 60 | 61 | invalid_listen_port_vals: typing.List[int] = [-1, 0, 65536] 62 | 63 | 64 | @pytest.mark.parametrize("listen_port", invalid_listen_port_vals) 65 | def test_get_saml_assertion_invalid_listen_port_should_fail(listen_port) -> None: 66 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 67 | browser_saml_credentials.login_url = "https://www.example.com" 68 | browser_saml_credentials.idp_response_timeout = 11 69 | browser_saml_credentials.listen_port = listen_port 70 | 71 | with pytest.raises(InterfaceError) as ex: 72 | browser_saml_credentials.get_saml_assertion() 73 | assert "Invalid value specified for connection property: listen_port. Must be in range [1,65535]" in str(ex.value) 74 | 75 | 76 | def test_authenticate_returns_authorization_token(mocker) -> None: 77 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 78 | mock_authorization_token: str = "my_authorization_token" 79 | 80 | mocker.patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.open_browser", return_value=None) 81 | mocker.patch( 82 | "redshift_connector.plugin.BrowserSamlCredentialsProvider.run_server", return_value=mock_authorization_token 83 | ) 84 | 85 | assert browser_saml_credentials.authenticate() == mock_authorization_token 86 | 87 | 88 | def test_authenticate_errors_should_fail(mocker) -> None: 89 | browser_saml_credentials: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider() 90 | 91 | mocker.patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.open_browser", return_value=None) 92 | with patch("redshift_connector.plugin.BrowserSamlCredentialsProvider.run_server") as mocked_server: 93 | mocked_server.side_effect = Exception("bad mistake") 94 | 95 | with pytest.raises(Exception, match="bad mistake"): 96 | browser_saml_credentials.authenticate() 97 | 98 | 99 | def test_open_browser_no_url_should_fail() -> 
def test_open_browser_no_url_should_fail() -> None:
    """open_browser without a login_url must raise InterfaceError."""
    provider: BrowserSamlCredentialsProvider = BrowserSamlCredentialsProvider()

    with pytest.raises(InterfaceError) as ex:
        provider.open_browser()
    assert "Missing required connection property: login_url" in str(ex.value)


# ---- test/unit/plugin/test_credentials_providers.py ----
import configparser
import os
import typing
from test import (
    adfs_idp,
    azure_browser_idp,
    azure_idp,
    idp_arg,
    jumpcloud_browser_idp,
    jwt_azure_v2_idp,
    jwt_google_idp,
    okta_browser_idp,
    okta_idp,
    ping_browser_idp,
    redshift_browser_idc,
    redshift_idp_token_auth_plugin,
)

import pytest  # type: ignore

import redshift_connector

# Integration configuration lives at the repository root.
conf: configparser.ConfigParser = configparser.ConfigParser()
root_path: str = os.path.dirname(os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
conf.read(root_path + "/config.ini")


NON_BROWSER_IDP: typing.List[str] = ["okta_idp", "azure_idp", "adfs_idp"]
NON_BROWSER_IDC: typing.List[str] = ["redshift_idp_token_auth_plugin"]

ALL_IDP: typing.List[str] = [
    "okta_browser_idp",
    "azure_browser_idp",
    "jumpcloud_browser_idp",
    "ping_browser_idp",
    "jwt_google_idp",
    "jwt_azure_v2_idp",
] + NON_BROWSER_IDP


@pytest.mark.parametrize("idp_arg", ALL_IDP, indirect=True)
def test_credential_provider_dne_should_fail(idp_arg) -> None:
    """An unknown credentials_provider name is rejected for every IdP config."""
    idp_arg["credentials_provider"] = "WrongProvider"
    with pytest.raises(
        redshift_connector.InterfaceError, match="Invalid IdP specified in credential_provider connection parameter"
    ):
        redshift_connector.connect(**idp_arg)


@pytest.mark.parametrize("idp_arg", ALL_IDP, indirect=True)
def test_ssl_and_iam_invalid_should_fail(idp_arg) -> None:
    """IAM authentication requires SSL; a misspelled provider name also fails."""
    idp_arg["ssl"] = False
    idp_arg["iam"] = True
    with pytest.raises(
        redshift_connector.InterfaceError,
        match="Invalid connection property setting. SSL must be enabled when using IAM",
    ):
        redshift_connector.connect(**idp_arg)

    idp_arg["ssl"] = True
    idp_arg["credentials_provider"] = "OktacredentialSProvider"
    with pytest.raises(
        redshift_connector.InterfaceError,
        match="Invalid IdP specified in credential_provider connection parameter",
    ):
        redshift_connector.connect(**idp_arg)


@pytest.mark.parametrize("idp_arg", NON_BROWSER_IDC, indirect=True)
def test_using_iam_should_fail(idp_arg) -> None:
    """IdC token plugins may not be combined with iam=True."""
    idp_arg["iam"] = True
    with pytest.raises(
        redshift_connector.InterfaceError,
        match="You can not use this authentication plugin with IAM enabled.",
    ):
        redshift_connector.connect(**idp_arg)


# ---- test/unit/plugin/test_idp_token_auth_plugin.py ----
import pytest

from redshift_connector import IamHelper, InterfaceError, RedshiftProperty
from redshift_connector.plugin import IdpTokenAuthPlugin
from redshift_connector.plugin.native_token_holder import NativeTokenHolder
def _configured_plugin(token: str = "hello world", token_type: str = "happy token") -> IdpTokenAuthPlugin:
    """Build a plugin through add_parameter, mirroring how the connector wires it up."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    rp: RedshiftProperty = RedshiftProperty()
    rp.token = token
    rp.token_type = token_type
    itap.add_parameter(rp)
    return itap


def test_should_fail_without_token():
    """A token_type alone is insufficient; the token itself is mandatory."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    itap.token_type = "blah"

    with pytest.raises(
        InterfaceError, match="IdC authentication failed: The token must be included in the connection parameters."
    ):
        itap.check_required_parameters()


def test_should_fail_without_token_type():
    """A token alone is insufficient; the token type is mandatory."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    itap.token = "blah"

    with pytest.raises(
        InterfaceError, match="IdC authentication failed: The token type must be included in the connection parameters."
    ):
        itap.check_required_parameters()


def test_get_auth_token_calls_check_required_parameters(mocker):
    """get_auth_token must validate parameters exactly once before returning."""
    spy = mocker.spy(IdpTokenAuthPlugin, "check_required_parameters")
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    itap.token = "my_token"
    itap.token_type = "testing_token"

    itap.get_auth_token()
    assert spy.called
    assert spy.call_count == 1


def test_get_auth_token_returns_token():
    """get_auth_token returns the configured token verbatim."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    itap.token = "my_token"
    itap.token_type = "testing_token"

    result = itap.get_auth_token()
    assert result == "my_token"


def test_add_parameter_sets_token():
    """add_parameter copies RedshiftProperty.token onto the plugin."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    rp: RedshiftProperty = RedshiftProperty()
    token_value: str = "a token of appreciation"
    rp.token = token_value
    itap.add_parameter(rp)
    assert itap.token == token_value


def test_add_parameter_sets_token_type():
    """add_parameter copies RedshiftProperty.token_type onto the plugin."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    rp: RedshiftProperty = RedshiftProperty()
    token_type_value: str = "appreciative token"
    rp.token_type = token_type_value
    itap.add_parameter(rp)
    assert itap.token_type == token_type_value


def test_get_sub_type_is_idc():
    """The plugin reports the IdC sub-type to IamHelper."""
    itap: IdpTokenAuthPlugin = IdpTokenAuthPlugin()
    assert itap.get_sub_type() == IamHelper.IDC_PLUGIN


def test_cache_disabled_by_default():
    """Credential caching must be off for token-based auth."""
    itap: IdpTokenAuthPlugin = _configured_plugin()
    # E712: compare identity to True rather than using ==.
    assert itap.disable_cache is True


def test_get_credentials_calls_refresh(mocker):
    """get_credentials always refreshes (cache is disabled)."""
    itap: IdpTokenAuthPlugin = _configured_plugin()
    mocker.patch("redshift_connector.plugin.IdpTokenAuthPlugin.refresh", return_value=None)
    spy = mocker.spy(IdpTokenAuthPlugin, "refresh")
    itap.get_credentials()
    assert spy.called
    assert spy.call_count == 1


def test_refresh_sets_credential():
    """refresh stores the raw token in a NativeTokenHolder."""
    itap: IdpTokenAuthPlugin = _configured_plugin(token="hello world", token_type="happy token")

    itap.refresh()
    result: NativeTokenHolder = itap.last_refreshed_credentials
    assert result is not None
    assert isinstance(result, NativeTokenHolder)
    assert result.access_token == "hello world"


def test_get_credentials_returns_credential():
    """get_credentials returns a NativeTokenHolder wrapping the configured token."""
    itap: IdpTokenAuthPlugin = _configured_plugin(token="hello world", token_type="happy token")

    result: NativeTokenHolder = itap.get_credentials()
    assert result is not None
    assert isinstance(result, NativeTokenHolder)
    assert result.access_token == "hello world"


# ---- test/unit/plugin/test_plugin_inheritance.py ----
import pytest

from redshift_connector import plugin
from redshift_connector.plugin.i_native_plugin import INativePlugin
from redshift_connector.plugin.i_plugin import IPlugin
from redshift_connector.plugin.idp_credentials_provider import IdpCredentialsProvider
def test_i_native_plugin_inherits_from_i_plugin() -> None:
    assert issubclass(INativePlugin, IPlugin)


def test_idp_credentials_provider_inherits_from_i_plugin() -> None:
    assert issubclass(IdpCredentialsProvider, IPlugin)


def test_saml_provider_plugin_inherit_from_idp_credentials_provider() -> None:
    assert issubclass(plugin.SamlCredentialsProvider, IdpCredentialsProvider)


def test_jwt_abc_inherit_from_idp_credentials_provider() -> None:
    assert issubclass(plugin.JwtCredentialsProvider, IdpCredentialsProvider)


# BUG FIX: BrowserSamlCredentialsProvider was listed twice, producing a
# duplicate parametrized test case. Each SAML plugin now appears once.
saml_provider_plugins = (
    plugin.BrowserSamlCredentialsProvider,
    plugin.OktaCredentialsProvider,
    plugin.AdfsCredentialsProvider,
    plugin.AzureCredentialsProvider,
    plugin.BrowserAzureCredentialsProvider,
)


@pytest.mark.parametrize("saml_plugin", saml_provider_plugins)
def test_saml_provider_plugins_inherit_from_saml_credentials_provider(saml_plugin):
    assert issubclass(saml_plugin, plugin.SamlCredentialsProvider)


jwt_plugins = (plugin.BrowserAzureOAuth2CredentialsProvider, plugin.BasicJwtCredentialsProvider)


@pytest.mark.parametrize("jwt_plugin", jwt_plugins)
def test_jwt_plugins_inherit_from_jwt_abc(jwt_plugin):
    assert issubclass(jwt_plugin, plugin.JwtCredentialsProvider)


# ---- test/unit/plugin/test_saml_credentials_provider.py ----
import base64
import typing
from test.unit.plugin.data import saml_response_data
from unittest.mock import MagicMock, patch

import pytest  # type: ignore

from redshift_connector import InterfaceError, RedshiftProperty
from redshift_connector.credentials_holder import CredentialsHolder
from redshift_connector.plugin import SamlCredentialsProvider
@patch.multiple(SamlCredentialsProvider, __abstractmethods__=set())
def make_valid_saml_credentials_provider() -> typing.Tuple[SamlCredentialsProvider, RedshiftProperty]:
    """Instantiate the (abstract) SamlCredentialsProvider by clearing its abstract-method set."""
    rp: RedshiftProperty = RedshiftProperty()
    rp.user_name = "AzureDiamond"
    rp.password = "hunter2"
    scp: SamlCredentialsProvider = SamlCredentialsProvider()  # type: ignore
    scp.add_parameter(rp)
    return scp, rp


def test_default_parameters_saml_credentials_provider() -> None:
    """SSL verification is on by default (ssl_insecure off)."""
    acp, _ = make_valid_saml_credentials_provider()
    # E712: assert truthiness directly rather than comparing with == True/False.
    assert not acp.ssl_insecure
    assert acp.do_verify_ssl_cert()


def test_get_cache_key_format_as_expected() -> None:
    """The cache key is the concatenation of identity + IdP endpoint + role fields."""
    scp, _ = make_valid_saml_credentials_provider()
    expected_cache_key: str = "{username}{password}{idp_host}{idp_port}{duration}{preferred_role}".format(
        username=scp.user_name,
        password=scp.password,
        idp_host=scp.idp_host,
        idp_port=scp.idpPort,
        duration=scp.duration,
        preferred_role=scp.preferred_role,
    )
    assert scp.get_cache_key() == expected_cache_key


def test_get_credentials_uses_cache_when_exists(mocker) -> None:
    """Unexpired cached credentials are returned without a refresh."""
    scp, _ = make_valid_saml_credentials_provider()
    mock_credentials = MagicMock()
    mock_credentials.is_expired.return_value = False
    scp.cache[scp.get_cache_key()] = mock_credentials

    spy = mocker.spy(SamlCredentialsProvider, "refresh")

    assert scp.get_credentials() == mock_credentials
    assert spy.called is False


def test_get_credentials_calls_refresh_when_cache_expired(mocker) -> None:
    """Expired cached credentials force exactly one refresh."""
    scp, _ = make_valid_saml_credentials_provider()
    mock_credentials = MagicMock()
    mock_credentials.is_expired.return_value = True
    scp.cache[scp.get_cache_key()] = mock_credentials

    mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.refresh", return_value=None)
    spy = mocker.spy(SamlCredentialsProvider, "refresh")

    scp.get_credentials()

    assert spy.called
    assert spy.call_count == 1


def test_get_credentials_sets_db_user_when_present(mocker) -> None:
    """A configured db_user is propagated onto the credential metadata."""
    scp, _ = make_valid_saml_credentials_provider()
    mocked_db_user: str = "test_db_user"
    scp.db_user = mocked_db_user
    mock_credentials = MagicMock()
    mock_credentials.is_expired.return_value = True
    mock_credentials.metadata = MagicMock()
    scp.cache[scp.get_cache_key()] = mock_credentials

    mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.refresh", return_value=None)
    spy = mocker.spy(mock_credentials.metadata, "set_db_user")

    scp.get_credentials()

    assert spy.called
    assert spy.call_count == 1
    assert spy.call_args[0][0] == mocked_db_user


def test_refresh_get_saml_assertion_fails(mocker) -> None:
    """Failures inside get_saml_assertion are wrapped in an InterfaceError."""
    scp, _ = make_valid_saml_credentials_provider()

    with patch("redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion") as mocked_get_saml_assertion:
        mocked_get_saml_assertion.side_effect = Exception("bad robot")

        with pytest.raises(InterfaceError, match="Failed to get SAML assertion"):
            scp.refresh()
def test_refresh_saml_assertion_missing_role_should_fail(mocker) -> None:
    """An assertion without role ARNs is rejected with a descriptive error."""
    scp, _ = make_valid_saml_credentials_provider()
    mocked_data: str = "test"
    mocker.patch("redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion", return_value=mocked_data)

    with pytest.raises(
        InterfaceError,
        match="No roles were found in SAML assertion. Please verify IdP configuration provides ARNs in the SAML https://aws.amazon.com/SAML/Attributes/Role Attribute.",
    ):
        scp.refresh()


def test_refresh_saml_assertion_passed_to_boto(mocker) -> None:
    """refresh forwards boto's assume_role_with_saml credentials into CredentialsHolder."""
    scp, _ = make_valid_saml_credentials_provider()
    mocker.patch(
        "redshift_connector.plugin.SamlCredentialsProvider.get_saml_assertion",
        return_value=base64.b64encode(saml_response_data.saml_response),
    )

    mocked_response: typing.Dict[str, typing.Any] = {
        "Credentials": {"Expiration": "test_expiry", "Things": "much data", "Other Things": "more data"}
    }

    mocked_boto = MagicMock()
    mocked_boto_response = MagicMock()
    mocked_boto_response.return_value = mocked_response
    mocked_boto.assume_role_with_saml = mocked_boto_response

    credential_holder_spy = mocker.spy(CredentialsHolder, "__init__")

    mocker.patch("boto3.client", return_value=mocked_boto)

    scp.refresh()
    assert credential_holder_spy.called
    assert credential_holder_spy.call_count == 1
    assert credential_holder_spy.call_args[0][1] == mocked_response["Credentials"]


# ---- test/unit/test_array_util.py ----
import typing

import pytest  # type: ignore

from redshift_connector.utils import array_util

# Pairs of (input nested list, expected (containing_list, index, value) walk order).
walk_array_data: typing.List = [
    (
        [10, 9, 8, 7, 6],
        [
            ([10, 9, 8, 7, 6], 0, 10),
            ([10, 9, 8, 7, 6], 1, 9),
            ([10, 9, 8, 7, 6], 2, 8),
            ([10, 9, 8, 7, 6], 3, 7),
            ([10, 9, 8, 7, 6], 4, 6),
        ],
    ),
    (
        [1, 2, [3, 4, 5], 6],
        [
            ([1, 2, [3, 4, 5], 6], 0, 1),
            ([1, 2, [3, 4, 5], 6], 1, 2),
            ([3, 4, 5], 0, 3),
            ([3, 4, 5], 1, 4),
            ([3, 4, 5], 2, 5),
            ([1, 2, [3, 4, 5], 6], 3, 6),
        ],
    ),
]
@pytest.mark.parametrize("_input", walk_array_data) 33 | def test_walk_array(_input) -> None: 34 | in_val, exp_vals = _input 35 | x: typing.Generator = array_util.walk_array(in_val) 36 | idx: int = 0 37 | for a, b, c in x: 38 | assert a == exp_vals[idx][0] 39 | assert b == exp_vals[idx][1] 40 | assert c == exp_vals[idx][2] 41 | idx += 1 42 | 43 | 44 | array_flatten_data: typing.List = [ 45 | ([1, 2, 3, 4], [1, 2, 3, 4]), 46 | ([1, [2], 3, 4], [1, 2, 3, 4]), 47 | ([1, [2, [3]], 4, [5, [6]]], [1, 2, 3, 4, 5, 6]), 48 | ([[1]], [1]), 49 | ] 50 | 51 | 52 | @pytest.mark.parametrize("_input", array_flatten_data) 53 | def test_array_flatten(_input) -> None: 54 | in_val, exp_val = _input 55 | assert 1 == 1 56 | assert list(array_util.array_flatten(in_val)) == exp_val 57 | 58 | 59 | array_find_first_element_data: typing.List = [ 60 | ([1], 1), 61 | ([None, None, [None, None, 1], 2], 1), 62 | ([[[1]]], 1), 63 | ([None, None, [None]], None), 64 | ] 65 | 66 | 67 | @pytest.mark.parametrize("_input", array_find_first_element_data) 68 | def test_array_find_first_element(_input) -> None: 69 | in_val, exp_val = _input 70 | assert array_util.array_find_first_element(in_val) == exp_val 71 | 72 | 73 | array_has_null_data: typing.List = [ 74 | ([None], True), 75 | ([1, 2, 3, 4, [[[None]]]], True), 76 | ([1, 2, 3, 4], False), 77 | ] 78 | 79 | 80 | @pytest.mark.parametrize("_input", array_has_null_data) 81 | def test_array_has_null(_input) -> None: 82 | in_val, exp_val = _input 83 | assert array_util.array_has_null(in_val) is exp_val 84 | -------------------------------------------------------------------------------- /test/unit/test_credentials_holder.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import typing 3 | from unittest.mock import MagicMock 4 | 5 | import pytest # type: ignore 6 | 7 | from redshift_connector.credentials_holder import ( 8 | ABCAWSCredentialsHolder, 9 | ABCCredentialsHolder, 10 | 
AWSDirectCredentialsHolder, 11 | AWSProfileCredentialsHolder, 12 | CredentialsHolder, 13 | ) 14 | 15 | 16 | @pytest.mark.parametrize("cred_holder", (AWSDirectCredentialsHolder, AWSProfileCredentialsHolder)) 17 | def test_aws_credentials_holder_inherit_from_abc(cred_holder) -> None: 18 | assert issubclass(cred_holder, ABCAWSCredentialsHolder) 19 | 20 | 21 | def test_credentials_holder_inherits_from_abc() -> None: 22 | assert issubclass(CredentialsHolder, ABCCredentialsHolder) 23 | 24 | 25 | def test_aws_direct_credentials_holder_should_have_session() -> None: 26 | mocked_session: MagicMock = MagicMock() 27 | obj: AWSDirectCredentialsHolder = AWSDirectCredentialsHolder( 28 | access_key_id="something", 29 | secret_access_key="secret", 30 | session_token="fornow", 31 | session=mocked_session, 32 | ) 33 | 34 | assert isinstance(obj, ABCAWSCredentialsHolder) 35 | assert hasattr(obj, "get_boto_session") 36 | assert obj.has_associated_session == True 37 | assert obj.get_boto_session() == mocked_session 38 | 39 | 40 | valid_aws_direct_credential_params: typing.List[typing.Dict[str, typing.Optional[str]]] = [ 41 | { 42 | "access_key_id": "something", 43 | "secret_access_key": "secret", 44 | "session_token": "fornow", 45 | }, 46 | { 47 | "access_key_id": "something", 48 | "secret_access_key": "secret", 49 | "session_token": None, 50 | }, 51 | ] 52 | 53 | 54 | @pytest.mark.parametrize("input", valid_aws_direct_credential_params) 55 | def test_aws_direct_credentials_holder_get_session_credentials(input) -> None: 56 | input["session"] = MagicMock() 57 | obj: AWSDirectCredentialsHolder = AWSDirectCredentialsHolder(**input) 58 | 59 | ret_value: typing.Dict[str, str] = obj.get_session_credentials() 60 | 61 | assert len(ret_value) == 3 if input["session_token"] is not None else 2 62 | 63 | assert ret_value["aws_access_key_id"] == input["access_key_id"] 64 | assert ret_value["aws_secret_access_key"] == input["secret_access_key"] 65 | 66 | if input["session_token"] is not None: 67 | 
assert ret_value["aws_session_token"] == input["session_token"] 68 | 69 | 70 | def test_aws_profile_credentials_holder_should_have_session() -> None: 71 | mocked_session: MagicMock = MagicMock() 72 | obj: AWSProfileCredentialsHolder = AWSProfileCredentialsHolder(profile="myprofile", session=mocked_session) 73 | 74 | assert isinstance(obj, ABCAWSCredentialsHolder) 75 | assert hasattr(obj, "get_boto_session") 76 | assert obj.has_associated_session == True 77 | assert obj.get_boto_session() == mocked_session 78 | 79 | 80 | def test_aws_profile_credentials_holder_get_session_credentials() -> None: 81 | profile_val: str = "myprofile" 82 | obj: AWSProfileCredentialsHolder = AWSProfileCredentialsHolder(profile=profile_val, session=MagicMock()) 83 | 84 | ret_value = obj.get_session_credentials() 85 | assert len(ret_value) == 1 86 | 87 | assert ret_value["profile"] == profile_val 88 | 89 | 90 | @pytest.mark.parametrize( 91 | "expiration_delta", 92 | [ 93 | datetime.timedelta(hours=3), # expired 3 hrs ago 94 | datetime.timedelta(days=1), # expired 1 day ago 95 | datetime.timedelta(weeks=1), # expired 1 week ago 96 | ], 97 | ) 98 | def test_is_expired_true(expiration_delta) -> None: 99 | credentials: typing.Dict[str, typing.Any] = { 100 | "AccessKeyId": "something", 101 | "SecretAccessKey": "secret", 102 | "SessionToken": "fornow", 103 | "Expiration": datetime.datetime.now(datetime.timezone.utc) - expiration_delta, 104 | } 105 | 106 | obj: CredentialsHolder = CredentialsHolder(credentials=credentials) 107 | 108 | assert obj.is_expired() == True 109 | 110 | 111 | @pytest.mark.parametrize( 112 | "expiration_delta", 113 | [ 114 | datetime.timedelta(minutes=2), # expired 1 minute ago 115 | datetime.timedelta(hours=3), # expired 3 hrs ago 116 | datetime.timedelta(days=1), # expired 1 day ago 117 | datetime.timedelta(weeks=1), # expired 1 week ago 118 | ], 119 | ) 120 | def test_is_expired_false(expiration_delta) -> None: 121 | credentials: typing.Dict[str, typing.Any] = { 122 | 
"AccessKeyId": "something", 123 | "SecretAccessKey": "secret", 124 | "SessionToken": "fornow", 125 | "Expiration": datetime.datetime.now(datetime.timezone.utc) + expiration_delta, 126 | } 127 | 128 | obj: CredentialsHolder = CredentialsHolder(credentials=credentials) 129 | 130 | assert obj.is_expired() == False 131 | -------------------------------------------------------------------------------- /test/unit/test_dbapi20.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import redshift_connector 4 | 5 | driver = redshift_connector 6 | 7 | 8 | def test_apilevel() -> None: 9 | # Must exist 10 | apilevel: str = driver.apilevel 11 | 12 | # Must equal 2.0 13 | assert apilevel == "2.0" 14 | 15 | 16 | def test_threadsafety() -> None: 17 | try: 18 | # Must exist 19 | threadsafety: int = driver.threadsafety 20 | # Must be a valid value 21 | assert threadsafety in (0, 1, 2, 3) 22 | except AttributeError: 23 | assert False, "Driver doesn't define threadsafety" 24 | 25 | 26 | def test_paramstyle() -> None: 27 | from redshift_connector.config import DbApiParamstyle 28 | 29 | try: 30 | # Must exist 31 | paramstyle: str = driver.paramstyle 32 | # Must be a valid value 33 | assert paramstyle in DbApiParamstyle.list() 34 | except AttributeError: 35 | assert False, "Driver doesn't define paramstyle" 36 | 37 | 38 | def test_Exceptions() -> None: 39 | # Make sure required exceptions exist, and are in the 40 | # defined heirarchy. 
41 | assert issubclass(driver.Warning, Exception) 42 | assert issubclass(driver.Error, Exception) 43 | assert issubclass(driver.InterfaceError, driver.Error) 44 | assert issubclass(driver.DatabaseError, driver.Error) 45 | assert issubclass(driver.OperationalError, driver.Error) 46 | assert issubclass(driver.IntegrityError, driver.Error) 47 | assert issubclass(driver.InternalError, driver.Error) 48 | assert issubclass(driver.ProgrammingError, driver.Error) 49 | assert issubclass(driver.NotSupportedError, driver.Error) 50 | 51 | 52 | def test_Date() -> None: 53 | driver.Date(2002, 12, 25) 54 | driver.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0))) 55 | 56 | 57 | def test_Time() -> None: 58 | driver.Time(13, 45, 30) 59 | driver.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0))) 60 | 61 | 62 | def test_Timestamp() -> None: 63 | driver.Timestamp(2002, 12, 25, 13, 45, 30) 64 | driver.TimestampFromTicks(time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0))) 65 | 66 | 67 | def test_Binary() -> None: 68 | driver.Binary(b"Something") 69 | driver.Binary(b"") 70 | 71 | 72 | def test_STRING() -> None: 73 | assert hasattr(driver, "STRING"), "module.STRING must be defined" 74 | 75 | 76 | def test_BINARY() -> None: 77 | assert hasattr(driver, "BINARY"), "module.BINARY must be defined." 78 | 79 | 80 | def test_NUMBER() -> None: 81 | assert hasattr(driver, "NUMBER"), "module.NUMBER must be defined." 82 | 83 | 84 | def test_DATETIME() -> None: 85 | assert hasattr(driver, "DATETIME"), "module.DATETIME must be defined." 86 | 87 | 88 | def test_ROWID() -> None: 89 | assert hasattr(driver, "ROWID"), "module.ROWID must be defined." 
# ---- test/unit/test_driver_info.py ----
import re

from redshift_connector.utils import DriverInfo


def test_version_is_not_none() -> None:
    assert DriverInfo.version() is not None


def test_version_is_str() -> None:
    assert isinstance(DriverInfo.version(), str)


def test_version_proper_format() -> None:
    # Accept either X.Y.Z or X.Y.Z.W version strings.
    version_regex: re.Pattern = re.compile(r"^\d+(\.\d+){2,3}$")
    assert version_regex.match(DriverInfo.version())


def test_driver_name_is_not_none() -> None:
    assert DriverInfo.driver_name() is not None


def test_driver_short_name_is_not_none() -> None:
    assert DriverInfo.driver_short_name() is not None


def test_driver_full_name_is_not_none() -> None:
    assert DriverInfo.driver_full_name() is not None


def test_driver_full_name_contains_name() -> None:
    assert DriverInfo.driver_name() in DriverInfo.driver_full_name()


def test_driver_full_name_contains_version() -> None:
    assert DriverInfo.version() in DriverInfo.driver_full_name()


# ---- test/unit/test_idp_auth_helper.py ----
from unittest.mock import MagicMock

from packaging.version import Version

from redshift_connector.idp_auth_helper import IdpAuthHelper


def test_get_pkg_version(mocker) -> None:
    """When importlib.metadata yields nothing, fall back to the module's __version__."""
    mocker.patch("importlib.metadata.version", return_value=None)

    module_mock = MagicMock()
    module_mock.__version__ = "9.8.7"
    mocker.patch("importlib.import_module", return_value=module_mock)

    actual_version: Version = IdpAuthHelper.get_pkg_version("test_module")

    assert actual_version == Version("9.8.7")
Version("9.8.7") 17 | -------------------------------------------------------------------------------- /test/unit/test_import.py: -------------------------------------------------------------------------------- 1 | def test_import_redshift_connector() -> None: 2 | import redshift_connector 3 | -------------------------------------------------------------------------------- /test/unit/test_logging_utils.py: -------------------------------------------------------------------------------- 1 | import pytest # type: ignore 2 | 3 | from redshift_connector import RedshiftProperty 4 | from redshift_connector.utils.logging_utils import mask_secure_info_in_props 5 | 6 | secret_rp_values = ( 7 | "password", 8 | "access_key_id", 9 | "session_token", 10 | "secret_access_key", 11 | "client_id", 12 | "client_secret", 13 | "web_identity_token", 14 | ) 15 | 16 | 17 | @pytest.mark.parametrize("rp_arg", secret_rp_values) 18 | def test_mask_secure_info_in_props_obscures_secret_value(rp_arg) -> None: 19 | rp: RedshiftProperty = RedshiftProperty() 20 | secret_value: str = "SECRET_VALUE" 21 | rp.put(rp_arg, secret_value) 22 | result = mask_secure_info_in_props(rp) 23 | assert result.__getattribute__(rp_arg) != secret_value 24 | 25 | 26 | def test_mask_secure_info_in_props_no_info() -> None: 27 | assert mask_secure_info_in_props(None) is None # type: ignore 28 | -------------------------------------------------------------------------------- /test/unit/test_module.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import redshift_connector 4 | from redshift_connector.utils.oids import RedshiftOID 5 | 6 | 7 | @pytest.mark.parametrize( 8 | "error_class", 9 | ( 10 | "Warning", 11 | "Error", 12 | "InterfaceError", 13 | "DatabaseError", 14 | "OperationalError", 15 | "IntegrityError", 16 | "InternalError", 17 | "ProgrammingError", 18 | "NotSupportedError", 19 | "ArrayContentNotSupportedError", 20 | "ArrayContentNotHomogenousError", 21 
        "ArrayDimensionsNotConsistentError",
    ),
)
def test_errors_available_on_module(error_class) -> None:
    import importlib

    # Raises AttributeError (failing the test) if the name is missing.
    getattr(importlib.import_module("redshift_connector"), error_class)


def test_cursor_on_module() -> None:
    import importlib

    getattr(importlib.import_module("redshift_connector"), "Cursor")


def test_connection_on_module() -> None:
    import importlib

    getattr(importlib.import_module("redshift_connector"), "Connection")


def test_version_on_module() -> None:
    import importlib

    getattr(importlib.import_module("redshift_connector"), "__version__")


@pytest.mark.parametrize("datatype", [d.name for d in RedshiftOID])
def test_datatypes_on_module(datatype) -> None:
    # Every OID name must be re-exported via the package's __all__.
    assert datatype in redshift_connector.__all__
--------------------------------------------------------------------------------
/test/unit/test_native_plugin_helper.py:
--------------------------------------------------------------------------------
from unittest.mock import MagicMock

from pytest_mock import mocker  # type: ignore

from redshift_connector.iam_helper import IamHelper, IdpAuthHelper
from redshift_connector.native_plugin_helper import NativeAuthPluginHelper


def test_set_native_auth_plugin_properties_gets_idp_token_when_credentials_provider(mocker) -> None:
    """set_native_auth_plugin_properties must fetch the IdP token exactly once
    and store it on the property object under "web_identity_token"."""
    mocked_idp_token: str = "my_idp_token"
    mocker.patch("redshift_connector.iam_helper.IdpAuthHelper.set_auth_properties", return_value=None)
    mocker.patch(
        "redshift_connector.native_plugin_helper.NativeAuthPluginHelper.get_native_auth_plugin_credentials",
        return_value=mocked_idp_token,
    )
    spy = mocker.spy(NativeAuthPluginHelper, "get_native_auth_plugin_credentials")
    mock_rp: MagicMock = MagicMock()

    NativeAuthPluginHelper.set_native_auth_plugin_properties(mock_rp)

    assert spy.called is True
    assert spy.call_count == 1
    # First positional arg of the spied call is the property object itself.
    assert spy.call_args[0][0] == mock_rp
    # The helper records the token via rp.put("web_identity_token", token).
    assert mock_rp.method_calls[0][0] == "put"
    assert mock_rp.method_calls[0][1] == ("web_identity_token", mocked_idp_token)
--------------------------------------------------------------------------------
/test/unit/test_paramstyle.py:
--------------------------------------------------------------------------------
import typing

import pytest

from redshift_connector.config import DbApiParamstyle
from redshift_connector.core import convert_paramstyle as convert

# Tests of the convert_paramstyle function.


@pytest.mark.parametrize(
    "in_statement,out_statement,args",
    [
        (
            # '?' inside quoted identifiers/strings must NOT be converted.
            'SELECT ?, ?, "field_?" FROM t ' "WHERE a='say ''what?''' AND b=? AND c=E'?\\'test\\'?'",
            'SELECT $1, $2, "field_?" FROM t WHERE ' "a='say ''what?''' AND b=$3 AND c=E'?\\'test\\'?'",
            (1, 2, 3),
        ),
        (
            "SELECT ?, ?, * FROM t WHERE a=? AND b='are you ''sure?'",
            "SELECT $1, $2, * FROM t WHERE a=$3 AND b='are you ''sure?'",
            (1, 2, 3),
        ),
    ],
)
def test_qmark(in_statement, out_statement, args) -> None:
    # qmark placeholders become positional $N placeholders; args pass through.
    new_query, make_args = convert(DbApiParamstyle.QMARK.value, in_statement)
    assert new_query == out_statement
    assert make_args(args) == args


def test_numeric() -> None:
    # :N placeholders map to $N, preserving the stated ordinal.
    new_query, make_args = convert(
        DbApiParamstyle.NUMERIC.value, "SELECT sum(x)::decimal(5, 2) :2, :1, * FROM t WHERE a=:3"
    )
    expected: str = "SELECT sum(x)::decimal(5, 2) $2, $1, * FROM t WHERE a=$3"
    assert new_query == expected
    assert make_args((1, 2, 3)) == (1, 2, 3)


def test_numeric_default_parameter() -> None:
    # ':=' (named-argument syntax) must not be mistaken for a placeholder.
    new_query, make_args = convert(DbApiParamstyle.NUMERIC.value, "make_interval(days := 10)")

    assert new_query == "make_interval(days := 10)"
    assert make_args((1, 2, 3)) == (1, 2, 3)


def test_named() -> None:
    # Repeated named placeholder :f_2 maps to the same $1 both times.
    new_query, make_args = convert(
        DbApiParamstyle.NAMED.value, "SELECT sum(x)::decimal(5, 2) :f_2, :f1 FROM t WHERE a=:f_2"
    )
    expected: str = "SELECT sum(x)::decimal(5, 2) $1, $2 FROM t WHERE a=$1"
    assert new_query == expected
    assert make_args({"f_2": 1, "f1": 2}) == (1, 2)


def test_format() -> None:
    # %s becomes $N; literal %% and %-in-comment/strings are left alone.
    new_query, make_args = convert(
        DbApiParamstyle.FORMAT.value,
        "SELECT %s, %s, \"f1_%%\", E'txt_%%' FROM t WHERE a=%s AND b='75%%' AND c = '%' -- Comment with %",
    )
    expected: str = (
        "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%' AND c = '%' -- Comment with %"
    )
    assert new_query == expected
    assert make_args((1, 2, 3)) == (1, 2, 3)


def test_format_multiline() -> None:
    new_query, make_args = convert(DbApiParamstyle.FORMAT.value, "SELECT -- Comment\n%s FROM t")
    assert new_query == "SELECT -- Comment\n$1 FROM t"


# NOTE(review): leading whitespace inside the SQL literals below could not be
# fully recovered from the collapsed source rendering — confirm against the
# original file. The test outcome is whitespace-invariant, because
# `expected = statement.format("$1")` is derived from the very same literal.
@pytest.mark.parametrize("paramstyle", DbApiParamstyle.list())
@pytest.mark.parametrize(
    "statement",
    (
        """
EXPLAIN
/* blabla
something 100% with percent
*/
SELECT {}
""",
        """
EXPLAIN
/* blabla
%% %s :blah %sbooze $1 %%s
*/
SELECT {}
""",
        """/* multiple line comment here */""",
        """
/* this is my multi-line sql comment */
select
pk_id,
{},
-- shared_id, disabled until 12/12/2020
order_date
from my_table
""",
        """/**/select {}""",
        """select {}
/*\n
some comments about the logic
*/ -- redo later""",
        r"""COMMENT ON TABLE test_schema.comment_test """ r"""IS 'the test % '' " \ table comment'""",
    ),
)
def test_multiline_single_parameter(paramstyle, statement) -> None:
    """A single placeholder embedded among SQL comments converts to $1
    regardless of the paramstyle in use."""
    in_statement = statement
    format_char = None
    expected = statement.format("$1")

    # Pick the placeholder token appropriate for the paramstyle under test.
    if paramstyle == DbApiParamstyle.FORMAT.value:
        format_char = "%s"
    elif
paramstyle == DbApiParamstyle.PYFORMAT.value: 118 | format_char = "%(f1)s" 119 | elif paramstyle == DbApiParamstyle.NAMED.value: 120 | format_char = ":beer" 121 | elif paramstyle == DbApiParamstyle.NUMERIC.value: 122 | format_char = ":1" 123 | elif paramstyle == DbApiParamstyle.QMARK.value: 124 | format_char = "?" 125 | in_statement = in_statement.format(format_char) 126 | 127 | new_query, make_args = convert(paramstyle, in_statement) 128 | assert new_query == expected 129 | 130 | 131 | def test_py_format() -> None: 132 | new_query, make_args = convert( 133 | DbApiParamstyle.PYFORMAT.value, 134 | "SELECT %(f2)s, %(f1)s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%(f2)s AND b='75%%'", 135 | ) 136 | expected: str = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$1 AND " "b='75%%'" 137 | assert new_query == expected 138 | assert make_args({"f2": 1, "f1": 2, "f3": 3}) == (1, 2) 139 | 140 | # pyformat should support %s and an array, too: 141 | new_query, make_args = convert( 142 | DbApiParamstyle.PYFORMAT.value, "SELECT %s, %s, \"f1_%%\", E'txt_%%' " "FROM t WHERE a=%s AND b='75%%'" 143 | ) 144 | expected = "SELECT $1, $2, \"f1_%%\", E'txt_%%' FROM t WHERE a=$3 AND " "b='75%%'" 145 | assert new_query, expected 146 | assert make_args((1, 2, 3)) == (1, 2, 3) 147 | -------------------------------------------------------------------------------- /test/unit/test_type_utils.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from datetime import date, datetime, time 3 | from decimal import Decimal 4 | from enum import Enum 5 | 6 | import pytest # type: ignore 7 | 8 | from redshift_connector.config import ( 9 | EPOCH, 10 | INFINITY_MICROSECONDS, 11 | MINUS_INFINITY_MICROSECONDS, 12 | ) 13 | from redshift_connector.utils import type_utils 14 | 15 | 16 | @pytest.mark.parametrize("_input", [(True, b"\x01"), (False, b"\x00")]) 17 | def test_bool_send(_input) -> None: 18 | in_val, exp_val = _input 19 | assert 
type_utils.bool_send(in_val) == exp_val


@pytest.mark.parametrize("_input", [None, 1])
def test_null_send(_input) -> None:
    # null_send ignores its argument and always yields the NULL sentinel.
    assert type_utils.null_send(_input) == type_utils.NULL


# Two distinct Enum types used to check enum_out serializes by value.
class Apple(Enum):
    macintosh: int = 1
    granny_smith: int = 2
    ambrosia: int = 3


class Orange(Enum):
    navel: int = 1
    blood: int = 2
    cara_cara: int = 3


@pytest.mark.parametrize("_input", [(Apple.macintosh, b"1"), (Orange.cara_cara, b"3")])
def test_enum_out(_input) -> None:
    in_val, exp_val = _input
    assert type_utils.enum_out(in_val) == exp_val


@pytest.mark.parametrize(
    "_input", [(time(hour=0, minute=0, second=0), b"00:00:00"), (time(hour=12, minute=34, second=56), b"12:34:56")]
)
def test_time_out(_input) -> None:
    in_val, exp_val = _input
    assert type_utils.time_out(in_val) == exp_val


@pytest.mark.parametrize(
    "_input", [(date(month=1, day=1, year=1), b"0001-01-01"), (date(month=1, day=31, year=2020), b"2020-01-31")]
)
def test_date_out(_input) -> None:
    # Dates serialize as zero-padded ISO-8601 bytes (YYYY-MM-DD).
    in_val, exp_val = _input
    assert type_utils.date_out(in_val) == exp_val


@pytest.mark.parametrize(
    "_input",
    [
        # Expected bytes are the exact base-10 expansion of the binary
        # float these Decimals are constructed from (Decimal(float) keeps
        # the float's full precision).
        (Decimal(123.45678), b"123.4567799999999948568074614740908145904541015625"),
        (Decimal(123456789.012345), b"123456789.01234500110149383544921875"),
    ],
)
def test_numeric_out(_input) -> None:
    in_val, exp_val = _input
    assert type_utils.numeric_out(in_val) == exp_val


# 8-byte big-endian payloads and the datetime each should decode to.
# The first two are ASCII digit runs whose integer value is far outside the
# representable microsecond range — they appear to clamp to datetime.max,
# as the +/- infinity sentinels clamp to datetime.max / datetime.min.
timestamp_send_integer_data: typing.List[typing.Tuple[bytes, datetime]] = [
    (b"00000000", datetime.max),
    (b"12345678", datetime.max),
    (INFINITY_MICROSECONDS.to_bytes(length=8, byteorder="big"), datetime.max),
    (MINUS_INFINITY_MICROSECONDS.to_bytes(signed=True, length=8, byteorder="big"), datetime.min),
]


@pytest.mark.parametrize("_input", timestamp_send_integer_data)
def test_timestamp_recv_integer(_input) -> None:
    in_val, exp_val = _input
    # NOTE(review): leftover debug prints; consider removing.
    print(type_utils.timestamp_recv_integer(in_val, 0, 0))
    print(EPOCH.timestamp() * 1000)
    assert type_utils.timestamp_recv_integer(in_val, 0, 0) == exp_val
--------------------------------------------------------------------------------
/test/unit/test_typeobjects.py:
--------------------------------------------------------------------------------
# Unit tests for the Interval type object (months/days/microseconds triple).
import typing
import unittest

import pytest

from redshift_connector.interval import Interval


@pytest.mark.parametrize(
    "kwargs, exp_months, exp_days, exp_microseconds",
    (
        (
            {"months": 1},
            1,
            0,
            0,
        ),
        ({"days": 1}, 0, 1, 0),
        ({"microseconds": 1}, 0, 0, 1),
        ({"months": 1, "days": 2, "microseconds": 3}, 1, 2, 3),
    ),
)
def test_interval_constructor(kwargs, exp_months, exp_days, exp_microseconds) -> None:
    # Each keyword sets its own field; unspecified fields default to 0.
    i = Interval(**kwargs)
    assert i.months == exp_months
    assert i.days == exp_days
    assert i.microseconds == exp_microseconds


def test_default_constructor() -> None:
    i: Interval = Interval()
    assert i.months == 0
    assert i.days == 0
    assert i.microseconds == 0


def interval_range_test(parameter, in_range, out_of_range):
    """Shared helper: out-of-range values must raise OverflowError,
    in-range boundary values must construct cleanly."""
    for v in out_of_range:
        try:
            Interval(**{parameter: v})
            pytest.fail("expected OverflowError")
        except OverflowError:
            pass
    for v in in_range:
        Interval(**{parameter: v})


def test_interval_days_range() -> None:
    # days is bounded by signed 32-bit integer range.
    out_of_range_days = (
        -2147483649,
        +2147483648,
    )
    in_range_days = (
        -2147483648,
        +2147483647,
    )
    interval_range_test("days", in_range_days, out_of_range_days)


def test_interval_months_range() -> None:
    # months is bounded by signed 32-bit integer range.
    out_of_range_months = (
        -2147483649,
        +2147483648,
    )
    in_range_months = (
        -2147483648,
        +2147483647,
    )
    interval_range_test("months", in_range_months, out_of_range_months)


def test_interval_microseconds_range() -> None:
    # microseconds is bounded by signed 64-bit integer range.
    out_of_range_microseconds = (
        -9223372036854775809,
        +9223372036854775808,
    )
    in_range_microseconds = (
        -9223372036854775808,
        +9223372036854775807,
    )
    interval_range_test("microseconds", in_range_microseconds, out_of_range_microseconds)


@pytest.mark.parametrize(
    "kwargs, exp_total_seconds",
    (
        # A month has no fixed second length, so it contributes 0.
        ({"months": 1}, 0),
        ({"days": 1}, 86400),
        ({"microseconds": 1}, 1e-6),
        ({"months": 1, "days": 2, "microseconds": 3}, 172800.000003),
    ),
)
def test_total_seconds(kwargs, exp_total_seconds) -> None:
    i: Interval = Interval(**kwargs)
    assert i.total_seconds() == exp_total_seconds


def test_set_months_raises_type_error() -> None:
    with pytest.raises(TypeError):
        Interval(months="foobar")  # type: ignore


def test_set_days_raises_type_error() -> None:
    with pytest.raises(TypeError):
        Interval(days="foobar")  # type: ignore


def test_set_microseconds_raises_type_error() -> None:
    with pytest.raises(TypeError):
        Interval(microseconds="foobar")  # type: ignore


# (left operand, right operand, expected equality) triples shared by the
# __eq__ and __neq__ tests below.
interval_equality_test_vals: typing.Tuple[
    typing.Tuple[typing.Optional[Interval], typing.Optional[Interval], bool], ...
] = (
    (Interval(months=1), Interval(months=1), True),
    (Interval(months=1), Interval(), False),
    (Interval(months=1), Interval(months=2), False),
    (Interval(days=1), Interval(days=1), True),
    (Interval(days=1), Interval(), False),
    (Interval(days=1), Interval(days=2), False),
    (Interval(microseconds=1), Interval(microseconds=1), True),
    (Interval(microseconds=1), Interval(), False),
    (Interval(microseconds=1), Interval(microseconds=2), False),
    (Interval(), Interval(), True),
)


@pytest.mark.parametrize("i1, i2, exp_eq", interval_equality_test_vals)
def test__eq__(i1, i2, exp_eq) -> None:
    actual_eq = i1.__eq__(i2)
    assert actual_eq == exp_eq


@pytest.mark.parametrize("i1, i2, exp_eq", interval_equality_test_vals)
def test__neq__(i1, i2, exp_eq) -> None:
    # __neq__ is not a Python dunder; Interval defines it explicitly.
    exp_neq = not exp_eq
    actual_neq = i1.__neq__(i2)
    assert actual_neq == exp_neq
--------------------------------------------------------------------------------
/test/utils/__init__.py:
--------------------------------------------------------------------------------
from .decorators import numpy_only, pandas_only
--------------------------------------------------------------------------------
/test/utils/decorators.py:
--------------------------------------------------------------------------------
# skipif markers that gate tests on optional numpy / pandas availability.
import pytest  # type: ignore


def is_numpy_installed() -> bool:
    """Return True if numpy can be imported in this environment."""
    try:
        import numpy  # type: ignore

        return True
    except ModuleNotFoundError:
        return False


def is_pandas_installed() -> bool:
    """Return True if pandas can be imported in this environment."""
    try:
        import pandas  # type: ignore

        return True
    except ModuleNotFoundError:
        return False


numpy_only = pytest.mark.skipif(not is_numpy_installed(), reason="requires numpy")

pandas_only = pytest.mark.skipif(not is_pandas_installed(), reason="requires pandas")
-------------------------------------------------------------------------------- /tutorials/003 - Amazon Redshift Feature Support.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": { 7 | "collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "# Amazon Redshift Feature Support" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# Overview\n", 19 | "`redshift_connector` aims to support the latest and greatest features provided by Amazon Redshift so you can get the most out of your data." 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "## COPY and UNLOAD Support - Amazon S3\n", 27 | "`redshift_connector` provides the ability to `COPY` and `UNLOAD` data from an Amazon S3 bucket. Shown below is a sample workflow which copies and unloads data from an Amazon S3 bucket" 28 | ] 29 | }, 30 | { 31 | "cell_type": "markdown", 32 | "metadata": { 33 | "pycharm": { 34 | "name": "#%% md\n" 35 | } 36 | }, 37 | "source": [ 38 | "1. 
Upload the following text file to an Amazon S3 bucket and name it `category_csv.txt`\n", 39 | "\n", 40 | "```text\n", 41 | " 12,Shows,Musicals,Musical theatre\n", 42 | " 13,Shows,Plays,\"All \"\"non-musical\"\" theatre\"\n", 43 | " 14,Shows,Opera,\"All opera, light, and \"\"rock\"\" opera\"\n", 44 | " 15,Concerts,Classical,\"All symphony, concerto, and choir concerts\"\n", 45 | "```" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": { 52 | "pycharm": { 53 | "name": "#%%\n" 54 | } 55 | }, 56 | "outputs": [], 57 | "source": [ 58 | "import redshift_connector\n", 59 | "\n", 60 | "with redshift_connector.connect(\n", 61 | " host='examplecluster.abc123xyz789.us-west-1.redshift.amazonaws.com',\n", 62 | " database='dev',\n", 63 | " user='awsuser',\n", 64 | " password='my_password'\n", 65 | ") as conn:\n", 66 | " with conn.cursor() as cursor:\n", 67 | " cursor.execute(\"create table category (catid int, cargroup varchar, catname varchar, catdesc varchar)\")\n", 68 | " cursor.execute(\"copy category from 's3://testing/category_csv.txt' iam_role 'arn:aws:iam::123:role/RedshiftCopyUnload' csv;\")\n", 69 | " cursor.execute(\"select * from category\")\n", 70 | " print(cursor.fetchall())\n", 71 | " cursor.execute(\"unload ('select * from category') to 's3://testing/unloaded_category_csv.txt' iam_role 'arn:aws:iam::123:role/RedshiftCopyUnload' csv;\")\n", 72 | " print('done')\n" 73 | ] 74 | }, 75 | { 76 | "cell_type": "markdown", 77 | "metadata": {}, 78 | "source": [ 79 | "After executing the above code block, we can see the requested data was unloaded into the following file, `unloaded_category_csv.text0000_part00`, in the specified Amazon s3 bucket\n" 80 | ] 81 | } 82 | ], 83 | "metadata": { 84 | "kernelspec": { 85 | "display_name": "Python 3 (ipykernel)", 86 | "language": "python", 87 | "name": "python3" 88 | }, 89 | "language_info": { 90 | "codemirror_mode": { 91 | "name": "ipython", 92 | "version": 3 93 | }, 94 | "file_extension": 
".py", 95 | "mimetype": "text/x-python", 96 | "name": "python", 97 | "nbconvert_exporter": "python", 98 | "pygments_lexer": "ipython3", 99 | "version": "3.9.7" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 1 104 | } 105 | --------------------------------------------------------------------------------