├── .circleci └── config.yml ├── .flake8 ├── .gitignore ├── CONTRIBUTORS ├── LICENSE ├── README.md ├── mypy.ini ├── poetry.lock ├── pyproject.toml ├── pytest.ini ├── storage ├── __init__.py ├── cloudfiles_storage.py ├── ftp_storage.py ├── google_storage.py ├── local_storage.py ├── py.typed ├── retry.py ├── s3_storage.py ├── storage.py ├── swift_storage.py └── url_parser.py ├── stubs ├── boto3 │ ├── __init__.pyi │ ├── s3 │ │ ├── __init__.pyi │ │ └── transfer.pyi │ └── session.pyi ├── botocore │ ├── __init__.pyi │ ├── config.pyi │ ├── exceptions.pyi │ └── session.pyi ├── google │ ├── __init__.pyi │ ├── cloud │ │ ├── __init__.pyi │ │ ├── exceptions.pyi │ │ └── storage │ │ │ ├── __init__.pyi │ │ │ ├── blob.pyi │ │ │ ├── bucket.pyi │ │ │ └── client.pyi │ └── oauth2 │ │ ├── __init__.pyi │ │ └── service_account.pyi ├── keystoneauth1 │ ├── __init__.pyi │ ├── exceptions │ │ ├── __init__.pyi │ │ └── http.pyi │ ├── identity │ │ ├── __init__.pyi │ │ └── v2.pyi │ └── session.pyi └── swiftclient │ ├── __init__.pyi │ ├── client.pyi │ ├── exceptions.pyi │ └── utils.pyi ├── test_integration.py └── tests ├── __init__.py ├── helpers.py ├── service_test_case.py ├── storage_test_case.py ├── swift_service_test_case.py ├── test_cloudfiles_storage.py ├── test_ftp_storage.py ├── test_google_storage.py ├── test_local_storage.py ├── test_retry.py ├── test_s3_storage.py ├── test_storage.py ├── test_swift_storage.py └── test_url_parser.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | jobs: 3 | test: 4 | parameters: 5 | python_version: 6 | description: "The Python version to use for running the tests" 7 | type: string 8 | docker: 9 | - image: cimg/python:<< parameters.python_version >> 10 | environment: 11 | TEST_STORAGE_FILE_URI: file:///tmp/integration-tests 12 | TEST_STORAGE_FTP_URI: ftp://files:password@localhost/integration-tests 13 | # TEST_STORAGE_GS_URI in CircleCI context 14 | # TEST_STORAGE_S3_URI in CircleCI context 15 | - image: onjin/alpine-vsftpd:latest 16 | environment: 17 | PASSWORD: password 18 | command: > 19 | sh -c 20 | "sed -i -e 's#/home/./files#/home/files/./#' /etc/passwd; 21 | echo 'allow_writeable_chroot=YES' >> /etc/vsftpd/vsftpd.conf; 22 | /docker-entrypoint.sh" 23 | 24 | steps: 25 | - checkout 26 | 27 | - run: 28 | name: Save Python Version 29 | command: | 30 | python --version > pythonversion 31 | 32 | - restore_cache: 33 | keys: 34 | - v1-python-{{ checksum "pythonversion" }}-dependencies-{{ checksum "poetry.lock" }} 35 | 36 | - run: 37 | name: install dependencies 38 | command: | 39 | poetry self update --no-ansi -- 1.8.4 40 | poetry install --no-ansi 41 | 42 | mkdir -p test-reports 43 | 44 | - save_cache: 45 | paths: 46 | - ~/.cache/pypoetry/virtualenvs 47 | key: v1-python-{{ checksum "pythonversion" }}-dependencies-{{ checksum "poetry.lock" }} 48 | 49 | - run: 50 | name: run tests 51 | command: | 52 | poetry run pytest --verbose --junit-xml=test-reports/pytest.xml 53 | 54 | - run: 55 | name: run linter 56 | command: | 57 | poetry run flake8 | tee test-reports/flake8-errors 58 | 59 | - run: 60 | name: run typechecks 61 | command: | 62 | poetry run mypy --junit-xml=test-reports/mypy.xml 63 | 64 | - store_artifacts: 65 | path: test-reports 66 | prefix: python-<< parameters.python_version >> 67 | 68 | - store_test_results: 69 | path: test-reports 70 | prefix: python-<< parameters.python_version >> 71 | 72 | publish: 73 | docker: 74 | - image: cimg/python:3.13 75 | working_directory: ~/repo 76 | steps: 77 | - 
checkout 78 | 79 | - run: 80 | name: Publish to PyPI 81 | command: | 82 | export POETRY_HTTP_BASIC_PYPI_USERNAME=$PYPI_USERNAME 83 | export POETRY_HTTP_BASIC_PYPI_PASSWORD=$PYPI_PASSWORD 84 | 85 | poetry publish --build 86 | 87 | workflows: 88 | version: 2 89 | test-and-build: 90 | jobs: 91 | - test: 92 | name: test-3.9 93 | python_version: "3.9" 94 | filters: 95 | tags: 96 | only: /.*/ 97 | context: storage-library-tester 98 | - test: 99 | name: test-3.10 100 | python_version: "3.10" 101 | filters: 102 | tags: 103 | only: /.*/ 104 | context: storage-library-tester 105 | - test: 106 | name: test-3.11 107 | python_version: "3.11" 108 | filters: 109 | tags: 110 | only: /.*/ 111 | context: storage-library-tester 112 | - test: 113 | name: test-3.12 114 | python_version: "3.12" 115 | filters: 116 | tags: 117 | only: /.*/ 118 | context: storage-library-tester 119 | - test: 120 | name: test-3.13 121 | python_version: "3.13" 122 | filters: 123 | tags: 124 | only: /.*/ 125 | context: storage-library-tester 126 | - publish: 127 | requires: 128 | - test-3.9 129 | - test-3.10 130 | - test-3.11 131 | - test-3.12 132 | - test-3.13 133 | filters: 134 | tags: 135 | only: /^v[0-9]+(\.[0-9]+)*.*/ 136 | branches: 137 | ignore: /.*/ 138 | context: storage-library-publisher 139 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | ignore= 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # C extensions 6 | *.so 7 | 8 | # Distribution / packaging 9 | .Python 10 | env/ 11 | bin/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | eggs/ 16 | lib/ 17 | lib64/ 18 | parts/ 19 | sdist/ 20 | var/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | 26 | # Installer logs 27 | pip-log.txt 28 | pip-delete-this-directory.txt 29 | 30 | # Unit test / coverage reports 31 | htmlcov/ 32 | .tox/ 33 | .coverage 34 | .cache 35 | nosetests.xml 36 | coverage.xml 37 | *.bak 38 | .mypy_cache/ 39 | 40 | # Translations 41 | *.mo 42 | 43 | # Mr Developer 44 | .mr.developer.cfg 45 | .project 46 | .pydevproject 47 | 48 | # Rope 49 | .ropeproject 50 | 51 | # Django stuff: 52 | *.log 53 | *.pot 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | 58 | # Virtualenv Stuff 59 | include/ 60 | man/ 61 | share/ 62 | 63 | # Vim Stuff 64 | .vimrc 65 | *.swp 66 | *.swo 67 | .venv 68 | -------------------------------------------------------------------------------- /CONTRIBUTORS: -------------------------------------------------------------------------------- 1 | 2 | uStudio, Inc 3 | John Turner 4 | Alex Tantona 5 | Ricardo Contreras 6 | Thomas Stephens 7 | Josh Marshall 8 | 9 | Brice Grichy 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Licensed under the Apache License, Version 2.0 (the "License"); 2 | you may not use this file except in compliance with the License. 
3 | You may obtain a copy of the License at 4 | 5 | http://www.apache.org/licenses/LICENSE-2.0 6 | 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | storage 2 | ======= 3 | 4 | [![build status](https://dl.circleci.com/status-badge/img/gh/ustudio/storage/tree/master.svg?style=svg)](https://dl.circleci.com/status-badge/redirect/gh/ustudio/storage/tree/master) 5 | 6 | Python library for accessing files over various file transfer protocols. 7 | 8 | ## Installation ## 9 | 10 | Install via pip: 11 | 12 | ```sh 13 | pip install object-storage 14 | ``` 15 | 16 | ## Quick Start ## 17 | 18 | ```python 19 | from storage import get_storage 20 | 21 | # Create a reference to a local or remote file, by URI 22 | source_file = get_storage("file:///path/to/file.txt") 23 | 24 | # Upload contents from a local file 25 | source_file.load_from_filename("/path/to/new-source.txt") 26 | 27 | # Save the contents to a local file 28 | source_file.save_to_filename("/path/to/other-file.txt") 29 | 30 | # Delete the remote file 31 | source_file.delete() 32 | ``` 33 | 34 | ## API ## 35 | 36 | ### `get_storage(uri)` ### 37 | 38 | The main entry point to the storage library is the `get_storage` 39 | function, which takes a URI to a file and returns an object which can 40 | perform operations on that file. 41 | 42 | ### `Storage` ### 43 | 44 | The value returned by `get_storage` is a `Storage` object, which 45 | represents a file accessible by the scheme provided to 46 | `get_storage`. This object has the following methods: 47 | 48 | #### `load_from_filename(filename)` #### 49 | 50 | Uploads the contents of the file at `filename` to the location 51 | specified by the URI to `get_storage`. 52 | 53 | #### `load_from_file(file_object)` #### 54 | 55 | Uploads to the location specified by the URI to `get_storage` by 56 | reading from the specified file-like object. 57 | 58 | #### `load_from_directory(directory_path)` #### 59 | 60 | Uploads all of the contents of the directory at `directory_path` to 61 | the location specified by the URI to `get_storage`. 62 | 63 | #### `save_to_filename(filename)` #### 64 | 65 | Downloads the contents of the file specified by the URI to 66 | `get_storage` into a local file at `filename`. 67 | 68 | #### `save_to_file(file_object)` #### 69 | 70 | Downloads the contents of the file specified by the URI to 71 | `get_storage` by writing into a file-like object. 72 | 73 | #### `save_to_directory(directory_path)` #### 74 | 75 | Downloads the contents of the directory specified by the URI to 76 | `get_storage` into the directory at `directory_path`. 77 | 78 | #### `delete()` #### 79 | 80 | Deletes the file specified by the URI to `get_storage`. 81 | 82 | #### `delete_directory()` #### 83 | 84 | Recursively deletes the directory structure specified by the URI to `get_storage()`. 85 | 86 | 87 | #### `get_download_url(seconds=60, key=None)` #### 88 | 89 | Returns a download URL to the object specified by the URI to `get_storage`.
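For example, a minimal sketch (the scheme, credentials, bucket and object path below are placeholders):

```python
from storage import get_storage

# Request a URL that is valid for five minutes instead of
# the default 60 seconds
remote_file = get_storage("s3://KEY-ID:SECRET@bucket/path/to/file.txt")
url = remote_file.get_download_url(seconds=300)
```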
90 | 91 | For **swift** and **s3** based protocols, this will return a time-limited temporary 92 | URL which can be used to GET the object directly from the container in the 93 | object store. By default the URL will only be valid for 60 seconds, but a 94 | different timeout can be specified by using the `seconds` parameter. 95 | 96 | Note that for **swift** based protocols the container must already have a temp URL key 97 | set. If it does not have a temp URL key, an exception will be raised. 98 | 99 | For local file storage, the call will return a URL formed by joining the `download_url_base` 100 | (included in the URI that was passed to `get_storage`) with the object name. If no 101 | `download_url_base` query param was included in the storage URI, `get_download_url` 102 | will raise a `DownloadUrlBaseUndefinedError` exception. (*see* [**file**](#file) *below*) 103 | 104 | 105 | #### `get_sanitized_uri()` #### 106 | 107 | Removes the username/password, as well as all query parameters, from the URL. 108 | 109 | ### Supported Protocols ### 110 | 111 | The following protocols are supported, and can be selected by 112 | specifying them in the scheme/protocol section of the URI: 113 | 114 | #### file #### 115 | 116 | A reference to a local file. This is primarily useful for running code 117 | in a development environment. 118 | 119 | Example: 120 | 121 | ``` 122 | 123 | file:///home/user/awesome-file.txt[?download_url_base=] 124 | 125 | ``` 126 | 127 | If the intermediate directories specified in the URI passed to 128 | `get_storage` do not exist, the local file storage object will attempt 129 | to create them when using `load_from_file` or `load_from_filename`. 130 | 131 | If a `download_url_base` is included in the URI specified to `get_storage`, `get_download_url` will 132 | return a URL that joins the `download_url_base` with the object name. 133 | 134 | For example, if a `download_url_base` of `http://hostname/some/path/` is included in the URI: 135 | 136 | ``` 137 | file:///home/user/awesome-file.txt?download_url_base=http%3A%2F%2Fhostname%2Fsome%2Fpath%2F 138 | ``` 139 | 140 | then a call to `get_download_url` will return: 141 | 142 | ``` 143 | http://hostname/some/path/awesome-file.txt 144 | ``` 145 | 146 | For local storage objects both the `seconds` and `key` parameters to `get_download_url` are ignored. 147 | 148 | 149 | #### swift #### 150 | 151 | A reference to an Object in a Container in an **OpenStack Swift** object store. 152 | With this scheme, the `host` section of the URI is the Container name, and 153 | the `path` is the Object. Credentials are specified in the `username` 154 | and `password` fields. 155 | 156 | In addition, the following parameters are **required** and should be passed as 157 | query parameters in the URI: 158 | 159 | | Query Param | Description | 160 | |:----------------|:------------------------------------------------------------------------| 161 | | `auth_endpoint` | The authentication endpoint that should be used by the storage library. | 162 | | `tenant_id` | The tenant ID to be used during authentication. Typically an account or project ID.| 163 | | `region` | The region which the storage library will use when obtaining the appropriate **object_store** client.
| 164 | 165 | Example: 166 | 167 | ``` 168 | 169 | swift://username:password@container/file.txt?region=REG&auth_endpoint=http://identity.svr.com:1234/v2&tenant_id=123456 170 | 171 | ``` 172 | 173 | In addition to the required parameters mentioned above, swift will also 174 | accept the following optional parameters: 175 | 176 | | Query Param | Description | 177 | |:----------------|:------------------------------------------------------------------------| 178 | | `public` | Whether or not to use the internal ServiceNet network. This saves bandwidth if you are accessing CloudFiles from within the same datacenter. (default: true) | 179 | | `api_key` | API key to be used during authentication. | 180 | | `temp_url_key` | Key to be used when retrieving a temp download url to the storage object from the **Swift** object store (see `get_download_url()`)| 181 | 182 | **Note** The connection will have a default 60 second timeout on network 183 | operations, which can be set by changing 184 | `storage.storage.DEFAULT_SWIFT_TIMEOUT`, specified in seconds. The 185 | timeout is per data chunk, not for transfer of the entire object. 186 | 187 | 188 | #### cloudfiles #### 189 | 190 | A reference to an Object in a Container in Rackspace CloudFiles. This scheme is similar to 191 | the [**swift**](#swift) scheme with the following differences: 192 | 193 | - The `auth_endpoint` and `tenant_id` need not be specified. These are automatically determined 194 | by Rackspace. 195 | - The `region` parameter is optional, and will default to `DFW` if not 196 | specified. 197 | 198 | 199 | Example: 200 | 201 | ``` 202 | 203 | cloudfiles://username:apikey@container/awesome-file.txt 204 | 205 | ``` 206 | 207 | Because of the way CloudFiles handles "virtual folders," if the 208 | filename specified in `get_storage` includes subfolders, they will be 209 | created automatically if they do not exist. 210 | 211 | **Note**: Currently, the storage library will always connect to the DFW 212 | region in Rackspace; there is no way to specify a region at this 213 | time. It is possible that the URI scheme will change when this support 214 | is added. 215 | 216 | **Note** The connection will have a default 60 second timeout on network 217 | operations, which can be set by changing 218 | `storage.storage.DEFAULT_SWIFT_TIMEOUT`, specified in seconds. The 219 | timeout is per data chunk, not for transfer of the entire object. 220 | 221 | ### Amazon S3 ### 222 | 223 | A reference to an object in an Amazon S3 bucket. The `s3` scheme can be used when storing 224 | files using the Amazon S3 service. 225 | 226 | A `region` parameter is not required, but can be specified. 227 | 228 | **Note:** Chunked transfer encoding is only used for 229 | `save_to_filename` and `load_from_filename`. If you use `save_to_file` 230 | or `load_from_file`, the entire contents of the file will be loaded 231 | into memory. 232 | 233 | Example: 234 | 235 | ``` 236 | 237 | s3://aws_access_key_id:aws_secret_access_key@bucket/path/to/file[?region=us-west-2] 238 | 239 | 240 | ``` 241 | 242 | Note that the `aws_access_key` and `aws_secret_access_key` should be URL encoded, to quote 243 | unsafe characters, if necessary. This may be necessary as AWS sometimes includes characters 244 | such as a `/`. 245 | 246 | #### JSON Credentials 247 | 248 | Credentials can also be provided as JSON in the URI's username field. Do not specify a 249 | password when providing credentials as JSON. The JSON must be URL encoded to quote 250 | special characters. 
For example: 251 | 252 | ``` 253 | s3://%7B%22version%22%3A1%2C%22key_id%22%3A%22ACCESS-KEY%22%2C%22access_secret%22%3A%22ACCESS-SECRET%22%7D@bucket/path/to/file 254 | ``` 255 | 256 | The JSON credentials must contain the following required attributes: 257 | 258 | ```json 259 | { 260 | "version": 1, 261 | "key_id": "AKIAIOSFODNN7EXAMPLE", 262 | "access_secret": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY" 263 | } 264 | ``` 265 | 266 | The `version` must be `1`. Support for other version numbers may be added in the future. 267 | 268 | If the client needs to assume a role before accessing the bucket (for example, to 269 | access a bucket owned by a third party), include the `role`, `role_session_name`, and 270 | `external_id` attributes in the JSON: 271 | 272 | ```json 273 | { 274 | "version": 1, 275 | "key_id": "AKIAIOSFODNN7EXAMPLE", 276 | "access_secret": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY", 277 | "role": "arn:aws:iam::123456789012:role/demo", 278 | "role_session_name": "testAssumeRoleSession", 279 | "external_id": "123ABC" 280 | } 281 | ``` 282 | 283 | 284 | #### ftp #### 285 | 286 | A reference to a file on an FTP server. Usernames and passwords are supported. 287 | 288 | Example: 289 | 290 | ``` 291 | 292 | ftp://username:password@my-ftp-server/directory/awesome-file.txt[?download_url_base=] 293 | 294 | ``` 295 | 296 | **Note** The FTP connection will have a default 60 second timeout on 297 | network operations, which can be set by changing 298 | `storage.storage.DEFAULT_FTP_TIMEOUT`, specified in seconds. The 299 | timeout is per data chunk, not for transfer of the entire object. 300 | 301 | **Note** The FTP connection's command socket will have TCP_KEEPALIVE 302 | turned on by default, as configurable by 303 | `storage.storage.DEFAULT_FTP_KEEPALIVE_ENABLE`, and will configure 304 | TCP keepalive options when the platform supports them, using 305 | similar configuration globals. 306 | 307 | #### ftps #### 308 | 309 | A reference to a file on an FTP server, served using the FTPS 310 | (a.k.a. FTP\_TLS) encrypted protocol. 311 | 312 | Example: 313 | 314 | ``` 315 | ftps://username:password@my-secure-ftp-server/directory/awesome-file.txt[?download_url_base=] 316 | ``` 317 | 318 | **Note** The FTP_TLS connection will have a default timeout and TCP 319 | keepalive specified in the same manner as the `ftp` protocol (see 320 | above). 321 | 322 | #### Google Cloud Storage #### 323 | 324 | A reference to an object in a Google Cloud Storage bucket. The `gs` scheme can 325 | be used when storing files using the Google Cloud Storage service. 326 | 327 | Example: 328 | 329 | ``` 330 | gs://SERVICE-ACCOUNT-DATA@bucket/path/to/file 331 | ``` 332 | 333 | Note that the `SERVICE-ACCOUNT-DATA` should be a URL-safe base64 encoding of 334 | the JSON key for the service account to be used when accessing the storage. 335 | 336 | ### retry ### 337 | 338 | The `retry` module provides a means for client code to attempt to 339 | transfer a file multiple times, in case of network or other 340 | failures. Exponential backoff is used to wait between retries, and the 341 | operation will be tried a maximum of 5 times before giving up. 342 | 343 | No guarantees are made as to the idempotency of the operations.
For 344 | example, if your FTP server handles file-naming conflicts by writing 345 | duplicate files to a different location, and the operation retries 346 | because of a network failure *after* some or all of the file has been 347 | transferred, the second attempt might be stored at a different 348 | location. 349 | 350 | In general, this is not a problem as long as the remote servers are 351 | configured to overwrite files by default. 352 | 353 | #### Quick Start #### 354 | 355 | ```python 356 | from storage import get_storage 357 | from storage.retry import attempt 358 | 359 | # Create a reference to a local or remote file, by URI 360 | source_file = get_storage("file:///path/to/file.txt") 361 | 362 | # Upload contents from a local file 363 | attempt(source_file.load_from_filename, "/path/to/new-source.txt") 364 | 365 | # Save the contents to a local file 366 | attempt(source_file.save_to_filename, "/path/to/other-file.txt") 367 | 368 | # Delete the remote file 369 | attempt(source_file.delete) 370 | ``` 371 | 372 | #### API #### 373 | 374 | ##### `attempt(function, *args, **kwargs)` ##### 375 | 376 | Call `function`, passing in `*args` and `**kwargs`. If the function 377 | raises an exception, sleep and try again, using exponential backoff 378 | after each retry. 379 | 380 | If the exception raised has an attribute, `do_not_retry`, set to 381 | `True`, then do not retry the operation. This can be used by the 382 | function to indicate that a failure is not worth retrying 383 | (e.g. the username/password is incorrect) or that the operation is not safe to 384 | retry. 385 | 386 | Currently, no methods in the storage library mark exceptions as 387 | `do_not_retry`. 388 | 389 | ### url_parser ### 390 | 391 | The `url_parser` module provides a means for client code to sanitize URIs in 392 | the way most appropriate to how each URI encodes secret data. 393 | 394 | #### API #### 395 | 396 | ##### `sanitize_resource_uri(parsed_uri)` ##### 397 | 398 | The implementation is overly restrictive -- it returns only the scheme, hostname, 399 | port and path, with no query parameters. 400 | 401 | ##### `remove_user_info(parsed_uri)` ##### 402 | 403 | Removes all credential information before the hostname (if present), and 404 | returns the scheme, hostname, port, path, and query parameters. 405 | 406 | ### Extending ### 407 | 408 | There are two decorators that can be used when extending the storage library. 409 | 410 | #### `register_storage_protocol` #### 411 | 412 | This class decorator will register a scheme and its associated class with the storage library. 413 | For example, if a new storage class were implemented (*subclassing from* `storage.Storage`), 414 | a scheme could be registered with the storage library using the `register_storage_protocol` decorator. 415 | 416 | ```python 417 | 418 | @register_storage_protocol("xstorage") 419 | class XStorage(storage.Storage): 420 | ... 421 | 422 | ``` 423 | 424 | This would allow the `XStorage` class to be used by making a call to `get_storage()` using the 425 | specified scheme (`"xstorage"`): 426 | 427 | ```python 428 | 429 | xs = storage.get_storage("xstorage://some/xstorage/path") 430 | 431 | ``` 432 | 433 | #### `register_swift_protocol` #### 434 | 435 | This class decorator is used for registering OpenStack Swift storage classes. It is similar to the 436 | `register_storage_protocol` decorator but is specific to classes that are subclasses of 437 | `storage.SwiftStorage`. It accepts two arguments.
The first is the scheme it should be 438 | registered under; the second is the authentication endpoint that should be used when 439 | authenticating. 440 | 441 | ```python 442 | 443 | @register_swift_protocol(scheme="ystorage", 444 | auth_endpoint="http://identity.svr.com:1234/v1.0/") 445 | class YStorage(storage.SwiftStorage): 446 | pass 447 | 448 | ``` 449 | 450 | This will register the swift-based storage protocol under the "ystorage" scheme using the specified 451 | authentication endpoint. 452 | 453 | ```python 454 | 455 | ys = storage.get_storage("ystorage://user:pass@container/obj?region=REG&tenant_id=1234") 456 | 457 | ``` 458 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | warn_unused_configs = True 3 | disallow_subclassing_any = True 4 | disallow_any_generics = True 5 | disallow_untyped_calls = True 6 | disallow_untyped_defs = True 7 | disallow_incomplete_defs = True 8 | check_untyped_defs = True 9 | disallow_untyped_decorators = True 10 | no_implicit_optional = True 11 | warn_redundant_casts = True 12 | warn_unused_ignores = True 13 | warn_return_any = True 14 | implicit_reexport = False 15 | strict_equality = True 16 | extra_checks = True 17 | mypy_path = stubs 18 | files = storage/,tests/,test_integration.py 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "object-storage" 3 | version = "0.15.3" 4 | description = "Python library for accessing files over various file transfer protocols." 5 | authors = ["uStudio Developers "] 6 | repository = "https://github.com/ustudio/storage" 7 | license = "Apache-2.0" 8 | readme = "README.md" 9 | packages = [{include = "storage"}] 10 | 11 | [tool.poetry.dependencies] 12 | python = "^3.9" 13 | boto3 = "^1.35.68" 14 | google-cloud-storage = "^2.18.2" 15 | python-swiftclient = "^4.6.0" 16 | keystoneauth1 = "^5.9.0" 17 | 18 | 19 | [tool.poetry.group.dev.dependencies] 20 | flake8 = "^7.1.1" 21 | mypy = "^1.13.0" 22 | pytest = "^8.3.3" 23 | 24 | [build-system] 25 | requires = ["poetry-core"] 26 | build-backend = "poetry.core.masonry.api" 27 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | error 4 | once:HTTPResponse\.getheader\(\) is deprecated and will be removed.*:DeprecationWarning:swiftclient\.client 5 | # See https://github.com/boto/boto3/issues/3889 6 | once:datetime\.datetime\.utcnow\(\) is deprecated.*:DeprecationWarning:botocore\..* 7 | -------------------------------------------------------------------------------- /storage/__init__.py: -------------------------------------------------------------------------------- 1 | from storage.storage import get_storage, register_storage_protocol, NotFoundError # noqa: F401 2 | from storage import local_storage # noqa: F401 3 | from storage import swift_storage # noqa: F401 4 | from storage import cloudfiles_storage # noqa: F401 5 | from storage import ftp_storage # noqa: F401 6 | from storage import s3_storage # noqa: F401 7 | from storage import google_storage # noqa: F401 8 | from storage.swift_storage import register_swift_protocol # noqa: F401 9 | 10 | __all__ = [ 11 | "get_storage", 12 | "NotFoundError" 13 | ] 14 |
-------------------------------------------------------------------------------- /storage/cloudfiles_storage.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import parse_qs 2 | 3 | from keystoneauth1 import session 4 | from keystoneauth1.identity import v2 5 | import swiftclient 6 | 7 | from typing import Any, Dict 8 | 9 | from storage.storage import get_optional_query_parameter, InvalidStorageUri, DEFAULT_SWIFT_TIMEOUT 10 | from storage.swift_storage import register_swift_protocol, SwiftStorage 11 | 12 | 13 | class RackspaceAuth(v2.Password): 14 | 15 | def get_auth_data(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: 16 | auth_data = super().get_auth_data(*args, **kwargs) 17 | return { 18 | "RAX-KSKEY:apiKeyCredentials": { 19 | "username": auth_data["passwordCredentials"]["username"], 20 | "apiKey": auth_data["passwordCredentials"]["password"] 21 | } 22 | } 23 | 24 | 25 | @register_swift_protocol("cloudfiles", "https://identity.api.rackspacecloud.com/v2.0") 26 | class CloudFilesStorage(SwiftStorage): 27 | 28 | def _validate_parsed_uri(self) -> None: 29 | query = parse_qs(self._parsed_storage_uri.query) 30 | public_value = get_optional_query_parameter(query, "public") 31 | public_value = public_value if public_value is not None else "true" 32 | self.public_endpoint = "publicURL" if public_value.lower() == "true" else "internalURL" 33 | region = get_optional_query_parameter(query, "region") 34 | self.region = region if region is not None else "DFW" 35 | self.download_url_key = get_optional_query_parameter(query, "download_url_key") 36 | 37 | if self._parsed_storage_uri.username is None or self._parsed_storage_uri.username == "": 38 | raise InvalidStorageUri("Missing username") 39 | if self._parsed_storage_uri.password is None or self._parsed_storage_uri.password == "": 40 | raise InvalidStorageUri("Missing API key") 41 | if self._parsed_storage_uri.hostname is None: 42 | raise InvalidStorageUri("Missing hostname") 43 | 44 | self._username = self._parsed_storage_uri.username 45 | self._password = self._parsed_storage_uri.password 46 | self._hostname = self._parsed_storage_uri.hostname 47 | 48 | def get_connection(self) -> swiftclient.client.Connection: 49 | if not hasattr(self, "_connection"): 50 | os_options = { 51 | "region_name": self.region, 52 | "endpoint_type": self.public_endpoint 53 | } 54 | 55 | auth = RackspaceAuth( 56 | auth_url=self.auth_endpoint, username=self._username, password=self._password) 57 | 58 | keystone_session = session.Session(auth=auth) 59 | 60 | connection = swiftclient.client.Connection( 61 | session=keystone_session, os_options=os_options, timeout=DEFAULT_SWIFT_TIMEOUT) 62 | 63 | if self.download_url_key is None: 64 | for header_key, header_value in connection.head_account().items(): 65 | if header_key.endswith("temp-url-key"): 66 | self.download_url_key = header_value 67 | break 68 | 69 | self._connection = connection 70 | return self._connection 71 | -------------------------------------------------------------------------------- /storage/ftp_storage.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import ftplib 3 | from ftplib import FTP, error_perm 4 | import os 5 | import re 6 | import socket 7 | from urllib.parse import parse_qsl 8 | 9 | from typing import BinaryIO, Generator, List, Optional, Tuple 10 | 11 | from storage.storage import Storage, register_storage_protocol, _generate_download_url_from_base 12 | from storage.storage 
import InvalidStorageUri 13 | from storage.storage import DEFAULT_FTP_TIMEOUT, DEFAULT_FTP_KEEPALIVE_ENABLE, DEFAULT_FTP_KEEPCNT 14 | from storage.storage import DEFAULT_FTP_KEEPIDLE, DEFAULT_FTP_KEEPINTVL, NotFoundError 15 | from storage.url_parser import remove_user_info 16 | 17 | 18 | class FTPStorageError(Exception): 19 | pass 20 | 21 | 22 | @register_storage_protocol("ftp") 23 | class FTPStorage(Storage): 24 | """FTP storage. 25 | 26 | The URI for working with FTP storage has the following format: 27 | 28 | ftp://username:password@hostname/path/to/file.txt[?download_url_base=] 29 | 30 | ftp://username:password@hostname/path/to/directory[?download_url_base=] 31 | 32 | If the ftp storage has access via HTTP, then a download_url_base can be specified 33 | that will allow get_download_url() to return access to that object via HTTP. 34 | """ 35 | 36 | _download_url_base: Optional[str] 37 | 38 | def __init__(self, storage_uri: str) -> None: 39 | super(FTPStorage, self).__init__(storage_uri) 40 | if self._parsed_storage_uri.username is None: 41 | raise InvalidStorageUri("Missing username") 42 | if self._parsed_storage_uri.password is None: 43 | raise InvalidStorageUri("Missing password") 44 | if self._parsed_storage_uri.hostname is None: 45 | raise InvalidStorageUri("Missing hostname") 46 | 47 | self._username = self._parsed_storage_uri.username 48 | self._password = self._parsed_storage_uri.password 49 | self._hostname = self._parsed_storage_uri.hostname 50 | self._port = \ 51 | self._parsed_storage_uri.port if self._parsed_storage_uri.port is not None else 21 52 | query = dict(parse_qsl(self._parsed_storage_uri.query)) 53 | self._download_url_base = query.get("download_url_base", None) 54 | 55 | def _configure_keepalive(self, ftp_client: FTP) -> None: 56 | sock = ftp_client.sock 57 | if sock is None: 58 | raise FTPStorageError("FTP Client not fully initialized") 59 | 60 | sock.setsockopt( 61 | socket.SOL_SOCKET, socket.SO_KEEPALIVE, DEFAULT_FTP_KEEPALIVE_ENABLE) 62 | 63 | if hasattr(socket, "TCP_KEEPCNT"): 64 | sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPCNT, DEFAULT_FTP_KEEPCNT) 65 | 66 | if hasattr(socket, "TCP_KEEPIDLE"): 67 | sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPIDLE, DEFAULT_FTP_KEEPIDLE) 68 | 69 | if hasattr(socket, "TCP_KEEPINTVL"): 70 | sock.setsockopt(socket.SOL_TCP, socket.TCP_KEEPINTVL, DEFAULT_FTP_KEEPINTVL) 71 | 72 | @contextlib.contextmanager 73 | def _connect(self) -> Generator[FTP, None, None]: 74 | ftp_client = ftplib.FTP(timeout=DEFAULT_FTP_TIMEOUT) 75 | try: 76 | ftp_client.connect(self._hostname, port=self._port) 77 | 78 | self._configure_keepalive(ftp_client) 79 | 80 | ftp_client.login(self._username, self._password) 81 | 82 | yield ftp_client 83 | finally: 84 | ftp_client.close() 85 | 86 | def _cd_to_file(self, ftp_client: FTP) -> str: 87 | directory, filename = os.path.split(self._parsed_storage_uri.path.lstrip("/")) 88 | ftp_client.cwd(directory) 89 | return filename 90 | 91 | def _list(self, ftp_client: FTP) -> Tuple[List[str], List[str]]: 92 | directory_listing: List[str] = [] 93 | 94 | ftp_client.retrlines('LIST', directory_listing.append) 95 | 96 | directories = [] 97 | files = [] 98 | 99 | for line in directory_listing: 100 | name = re.split(r"\s+", line, maxsplit=8)[-1] 101 | 102 | if line.lower().startswith("d"): 103 | directories.append(name) 104 | else: 105 | files.append(name) 106 | 107 | return directories, files 108 | 109 | def _walk( 110 | self, ftp_client: FTP, 111 | target_directory: Optional[str] = None) -> \ 112 | Generator[Tuple[str, 
List[str], List[str]], None, None]: 113 | if target_directory: 114 | ftp_client.cwd(target_directory) 115 | else: 116 | target_directory = ftp_client.pwd() 117 | 118 | dirs, files = self._list(ftp_client) 119 | 120 | yield target_directory, dirs, files 121 | 122 | for name in dirs: 123 | new_target = os.path.join(target_directory, name) 124 | 125 | for result in self._walk(ftp_client, target_directory=new_target): 126 | yield result 127 | 128 | def _create_directory_structure( 129 | self, ftp_client: FTP, target_path: str, restore: Optional[bool] = False) -> None: 130 | directories = target_path.lstrip('/').split('/') 131 | 132 | if restore: 133 | dirpath = ftp_client.pwd() 134 | 135 | for target_directory in directories: 136 | dirs, _ = self._list(ftp_client) 137 | 138 | # TODO (phd): warn the user that a file exists with the name of their target dir 139 | if target_directory not in dirs: 140 | ftp_client.mkd(target_directory) 141 | 142 | ftp_client.cwd(target_directory) 143 | 144 | if restore: 145 | ftp_client.cwd(dirpath) 146 | 147 | def save_to_filename(self, file_path: str) -> None: 148 | with open(file_path, "wb") as output_file: 149 | self.save_to_file(output_file) 150 | 151 | def save_to_file(self, out_file: BinaryIO) -> None: 152 | with self._connect() as ftp_client: 153 | filename = self._cd_to_file(ftp_client) 154 | 155 | try: 156 | ftp_client.retrbinary("RETR {0}".format(filename), callback=out_file.write) 157 | except error_perm as original_exc: 158 | if original_exc.args[0][:3] == "550": 159 | raise NotFoundError("No File Found") from original_exc 160 | raise original_exc 161 | 162 | def save_to_directory(self, destination_directory: str) -> None: 163 | with self._connect() as ftp_client: 164 | base_ftp_path = self._parsed_storage_uri.path 165 | 166 | try: 167 | ftp_client.cwd(base_ftp_path) 168 | 169 | for root, dirs, files in self._walk(ftp_client): 170 | relative_path = "/{}".format(root).replace( 171 | base_ftp_path, destination_directory, 1) 172 | 173 | if not os.path.exists(relative_path): 174 | os.makedirs(relative_path) 175 | 176 | os.chdir(relative_path) 177 | 178 | for filename in files: 179 | with open(os.path.join(relative_path, filename), "wb") as output_file: 180 | ftp_client.retrbinary( 181 | "RETR {0}".format(filename), callback=output_file.write) 182 | except error_perm as original_exc: 183 | if original_exc.args[0][:3] == "550": 184 | raise NotFoundError("No File Found") from original_exc 185 | raise original_exc 186 | 187 | def load_from_filename(self, file_path: str) -> None: 188 | with open(file_path, "rb") as input_file: 189 | self.load_from_file(input_file) 190 | 191 | def load_from_file(self, in_file: BinaryIO) -> None: 192 | with self._connect() as ftp_client: 193 | filename = self._cd_to_file(ftp_client) 194 | 195 | ftp_client.storbinary("STOR {0}".format(filename), in_file) 196 | 197 | def load_from_directory(self, source_directory: str) -> None: 198 | with self._connect() as ftp_client: 199 | base_ftp_path = self._parsed_storage_uri.path 200 | 201 | self._create_directory_structure(ftp_client, base_ftp_path) 202 | 203 | for root, dirs, files in os.walk(source_directory): 204 | relative_ftp_path = root.replace(source_directory, base_ftp_path, 1) 205 | 206 | ftp_client.cwd(relative_ftp_path) 207 | 208 | for directory in dirs: 209 | self._create_directory_structure(ftp_client, directory, restore=True) 210 | 211 | for file in files: 212 | file_path = os.path.join(root, file) 213 | 214 | with open(file_path, "rb") as input_file: 215 | 
ftp_client.storbinary("STOR {0}".format(file), input_file) 216 | 217 | def delete(self) -> None: 218 | with self._connect() as ftp_client: 219 | filename = self._cd_to_file(ftp_client) 220 | 221 | try: 222 | ftp_client.delete(filename) 223 | except error_perm as original_exc: 224 | if original_exc.args[0][:3] == "550": 225 | raise NotFoundError("No File Found") from original_exc 226 | raise original_exc 227 | 228 | def delete_directory(self) -> None: 229 | with self._connect() as ftp_client: 230 | base_ftp_path = self._parsed_storage_uri.path 231 | 232 | try: 233 | ftp_client.cwd(base_ftp_path) 234 | 235 | directories_to_remove = [] 236 | for root, directories, files in self._walk(ftp_client): 237 | for filename in files: 238 | ftp_client.delete("/{}/{}".format(root, filename)) 239 | 240 | directories_to_remove.append("/{}".format(root)) 241 | except error_perm as original_exc: 242 | if original_exc.args[0][:3] == "550": 243 | raise NotFoundError("No File Found") from original_exc 244 | raise original_exc 245 | 246 | # delete directories _after_ removing files from directories 247 | # directories should be removed in reverse order - leaf directories before 248 | # parent directories - since there is no recursive delete 249 | directories_to_remove.sort(reverse=True) 250 | for directory in directories_to_remove: 251 | ftp_client.rmd("{}".format(directory)) 252 | 253 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 254 | """ 255 | Return a temporary URL allowing access to the storage object. 256 | 257 | If a download_url_base is specified in the storage URI, then a call to get_download_url() 258 | will return the download_url_base joined with the object name. 259 | 260 | For example, if "http://www.someserver.com:1234/path/to/" were passed (urlencoded) as the 261 | download_url_base query parameter of the storage URI: 262 | 263 | ftp://username:password@hostname/some/path/to/a/file.txt?download_url_base=http%3A%2F%2Fwww 264 | .someserver.com%3A1234%2Fpath%2Fto%2 265 | 266 | then a call to get_download_url() would yield: 267 | 268 | http://www.someserver.com:1234/path/to/file.txt 269 | 270 | 271 | :param seconds: ignored for ftp storage 272 | :param key: ignored for ftp storage 273 | :return: the download url that can be used to access the storage object 274 | :raises: DownloadUrlBaseUndefinedError 275 | """ 276 | base = self._download_url_base 277 | object_name = self._parsed_storage_uri.path.split('/')[-1] 278 | return _generate_download_url_from_base(base, object_name) 279 | 280 | def get_sanitized_uri(self) -> str: 281 | return remove_user_info(self._parsed_storage_uri) 282 | 283 | 284 | @register_storage_protocol("ftps") 285 | class FTPSStorage(FTPStorage): 286 | @contextlib.contextmanager 287 | def _connect(self) -> Generator[FTP, None, None]: 288 | ftp_client = ftplib.FTP_TLS(timeout=DEFAULT_FTP_TIMEOUT) 289 | try: 290 | ftp_client.connect(self._hostname, port=self._port) 291 | self._configure_keepalive(ftp_client) 292 | ftp_client.login(self._username, self._password) 293 | ftp_client.prot_p() 294 | 295 | yield ftp_client 296 | finally: 297 | ftp_client.close() 298 | -------------------------------------------------------------------------------- /storage/google_storage.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import datetime 3 | import json 4 | import mimetypes 5 | import os 6 | 7 | from google.cloud.exceptions import NotFound 8 | import google.cloud.storage.client 9 | from 
google.cloud.storage.bucket import Bucket 10 | from google.cloud.storage.blob import Blob 11 | import google.oauth2.service_account 12 | 13 | from typing import BinaryIO, Optional 14 | 15 | from storage import retry 16 | from storage.storage import Storage, register_storage_protocol, NotFoundError, InvalidStorageUri 17 | 18 | 19 | @register_storage_protocol("gs") 20 | class GoogleStorage(Storage): 21 | 22 | def _validate_parsed_uri(self) -> None: 23 | if self._parsed_storage_uri.username is None: 24 | raise InvalidStorageUri("Missing username") 25 | if self._parsed_storage_uri.hostname is None: 26 | raise InvalidStorageUri("Missing hostname") 27 | 28 | self._username = self._parsed_storage_uri.username 29 | self._hostname = self._parsed_storage_uri.hostname 30 | 31 | def _get_bucket(self) -> Bucket: 32 | credentials_data = json.loads(base64.urlsafe_b64decode(self._username)) 33 | credentials = google.oauth2.service_account.Credentials.from_service_account_info( 34 | credentials_data) 35 | client = google.cloud.storage.client.Client( 36 | project=credentials_data["project_id"], credentials=credentials) 37 | return client.get_bucket(self._hostname) 38 | 39 | def _get_blob(self) -> Blob: 40 | bucket = self._get_bucket() 41 | blob = bucket.blob(self._parsed_storage_uri.path[1:]) 42 | return blob 43 | 44 | def save_to_filename(self, file_path: str) -> None: 45 | blob = self._get_blob() 46 | try: 47 | blob.download_to_filename(file_path) 48 | except NotFound as original_exc: 49 | raise NotFoundError("No File Found") from original_exc 50 | 51 | def save_to_file(self, out_file: BinaryIO) -> None: 52 | blob = self._get_blob() 53 | try: 54 | blob.download_to_file(out_file) 55 | except NotFound as original_exc: 56 | raise NotFoundError("No File Found") from original_exc 57 | 58 | def load_from_filename(self, file_path: str) -> None: 59 | blob = self._get_blob() 60 | blob.upload_from_filename(file_path) 61 | 62 | def load_from_file(self, in_file: BinaryIO) -> None: 63 | blob = self._get_blob() 64 | blob.upload_from_file( 65 | in_file, 66 | content_type=mimetypes.guess_type(self._storage_uri)[0]) 67 | 68 | def delete(self) -> None: 69 | blob = self._get_blob() 70 | try: 71 | blob.delete() 72 | except NotFound as original_exc: 73 | raise NotFoundError("No File Found") from original_exc 74 | 75 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 76 | blob = self._get_blob() 77 | return blob.generate_signed_url( 78 | expiration=datetime.timedelta(seconds=seconds), 79 | response_disposition="attachment") 80 | 81 | def save_to_directory(self, directory_path: str) -> None: 82 | bucket = self._get_bucket() 83 | 84 | prefix = self._parsed_storage_uri.path[1:] + "/" 85 | 86 | count = 0 87 | 88 | for blob in bucket.list_blobs(prefix=prefix): 89 | count += 1 90 | relative_path = blob.name.replace(prefix, "", 1) 91 | local_file_path = os.path.join(directory_path, relative_path) 92 | local_directory = os.path.dirname(local_file_path) 93 | 94 | if not os.path.exists(local_directory): 95 | os.makedirs(local_directory) 96 | 97 | if not relative_path[-1] == "/": 98 | unversioned_blob = bucket.blob(blob.name) 99 | try: 100 | retry.attempt(unversioned_blob.download_to_filename, local_file_path) 101 | except NotFound as original_exc: 102 | raise NotFoundError("No File Found") from original_exc 103 | 104 | if count == 0: 105 | raise NotFoundError("No Files Found") 106 | 107 | def load_from_directory(self, directory_path: str) -> None: 108 | bucket = self._get_bucket() 109 | 110 | prefix = 
self._parsed_storage_uri.path[1:] 111 | 112 | for root, _, files in os.walk(directory_path): 113 | remote_path = root.replace(directory_path, prefix, 1) 114 | 115 | for filename in files: 116 | blob = bucket.blob("/".join([remote_path, filename])) 117 | retry.attempt(blob.upload_from_filename, os.path.join(root, filename)) 118 | 119 | def delete_directory(self) -> None: 120 | bucket = self._get_bucket() 121 | 122 | count = 0 123 | 124 | for blob in bucket.list_blobs(prefix=self._parsed_storage_uri.path[1:] + "/"): 125 | count += 1 126 | unversioned_blob = bucket.blob(blob.name) 127 | try: 128 | unversioned_blob.delete() 129 | except NotFound as original_exc: 130 | raise NotFoundError("No File Found") from original_exc 131 | 132 | if count == 0: 133 | raise NotFoundError("No Files Found") 134 | -------------------------------------------------------------------------------- /storage/local_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from urllib.parse import parse_qs 4 | 5 | from typing import BinaryIO, Optional 6 | 7 | from storage.storage import get_optional_query_parameter, Storage, register_storage_protocol 8 | from storage.storage import _generate_download_url_from_base, NotFoundError 9 | from storage.url_parser import remove_user_info 10 | 11 | 12 | @register_storage_protocol("file") 13 | class LocalStorage(Storage): 14 | """LocalStorage is a local file storage object. 15 | 16 | The URI for working with local file storage has the following format: 17 | 18 | file:///some/path/to/a/file.txt?[download_url_base=] 19 | 20 | file:///some/path/to/a/directory?[download_url_base=] 21 | 22 | """ 23 | 24 | def _validate_parsed_uri(self) -> None: 25 | query = parse_qs(self._parsed_storage_uri.query) 26 | self._download_url_base = get_optional_query_parameter(query, "download_url_base") 27 | 28 | def save_to_filename(self, file_path: str) -> None: 29 | try: 30 | shutil.copy(self._parsed_storage_uri.path, file_path) 31 | except FileNotFoundError: 32 | raise NotFoundError("No File Found") 33 | 34 | def save_to_file(self, out_file: BinaryIO) -> None: 35 | try: 36 | with open(self._parsed_storage_uri.path, "rb") as in_file: 37 | for chunk in in_file: 38 | out_file.write(chunk) 39 | except FileNotFoundError: 40 | raise NotFoundError("No File Found") 41 | 42 | def save_to_directory(self, destination_directory: str) -> None: 43 | try: 44 | shutil.copytree( 45 | self._parsed_storage_uri.path, destination_directory, dirs_exist_ok=True) 46 | except FileNotFoundError: 47 | raise NotFoundError("No Files Found") 48 | 49 | def _ensure_exists(self) -> None: 50 | dirname = os.path.dirname(self._parsed_storage_uri.path) 51 | 52 | if not os.path.exists(dirname): 53 | os.makedirs(dirname) 54 | 55 | def load_from_filename(self, file_path: str) -> None: 56 | self._ensure_exists() 57 | 58 | shutil.copy(file_path, self._parsed_storage_uri.path) 59 | 60 | def load_from_file(self, in_file: BinaryIO) -> None: 61 | self._ensure_exists() 62 | 63 | with open(self._parsed_storage_uri.path, "wb") as out_file: 64 | for chunk in in_file: 65 | out_file.write(chunk) 66 | 67 | def load_from_directory(self, source_directory: str) -> None: 68 | self._ensure_exists() 69 | shutil.copytree(source_directory, self._parsed_storage_uri.path, dirs_exist_ok=True) 70 | 71 | def delete(self) -> None: 72 | try: 73 | os.remove(self._parsed_storage_uri.path) 74 | except FileNotFoundError: 75 | raise NotFoundError("No File Found") 76 | 77 | def delete_directory(self) 
-> None: 78 | try: 79 | shutil.rmtree(self._parsed_storage_uri.path) 80 | except FileNotFoundError: 81 | raise NotFoundError("No Files Found") 82 | 83 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 84 | """ 85 | Return a temporary URL allowing access to the storage object. 86 | 87 | If a download_url_base is specified in the storage URI, then a call to get_download_url() 88 | will return the download_url_base joined with the object name. 89 | 90 | For example, if "http://www.someserver.com:1234/path/to/" were passed (urlencoded) as the 91 | download_url_base query parameter of the storage URI: 92 | 93 | file://some/path/to/a/file.txt?download_url_base=http%3A%2F%2Fwww.someserver.com%3A1234%2Fpath%2Fto%2 94 | 95 | then a call to get_download_url() would yield: 96 | 97 | http://www.someserver.com:1234/path/to/file.txt 98 | 99 | 100 | :param seconds: ignored for local storage 101 | :param key: ignored for local storage 102 | :return: the download url that can be used to access the storage object 103 | :raises: DownloadUrlBaseUndefinedError 104 | """ 105 | return _generate_download_url_from_base( 106 | self._download_url_base, self._parsed_storage_uri.path.split('/')[-1]) 107 | 108 | def get_sanitized_uri(self) -> str: 109 | return remove_user_info(self._parsed_storage_uri) 110 | -------------------------------------------------------------------------------- /storage/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/storage/py.typed -------------------------------------------------------------------------------- /storage/retry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import random 3 | import time 4 | 5 | from typing import Any, Callable, TypeVar 6 | 7 | 8 | max_attempts: int = 5 9 | 10 | 11 | T = TypeVar("T") 12 | 13 | 14 | def attempt(f: Callable[..., T], *args: Any, **kwargs: Any) -> T: 15 | attempts = 0 16 | 17 | while True: 18 | try: 19 | return f(*args, **kwargs) 20 | except Exception as e: 21 | if getattr(e, "do_not_retry", False): 22 | raise 23 | 24 | attempts += 1 25 | 26 | if attempts >= max_attempts: 27 | raise 28 | 29 | sleep_time = random.uniform(0, (2 ** attempts) - 1) 30 | time.sleep(sleep_time) 31 | 32 | logging.warning(f"Retry attempt #{attempts}", exc_info=True) 33 | -------------------------------------------------------------------------------- /storage/s3_storage.py: -------------------------------------------------------------------------------- 1 | import json 2 | import mimetypes 3 | import os 4 | from urllib.parse import parse_qs, unquote 5 | 6 | import boto3.session 7 | import boto3.s3.transfer 8 | import botocore.config 9 | from botocore.exceptions import ClientError 10 | from botocore.session import Session 11 | 12 | from typing import BinaryIO, Optional 13 | 14 | from storage import retry 15 | from storage.storage import Storage, NotFoundError, register_storage_protocol, _LARGE_CHUNK 16 | from storage.storage import get_optional_query_parameter, InvalidStorageUri 17 | from storage.url_parser import remove_user_info 18 | 19 | 20 | @register_storage_protocol("s3") 21 | class S3Storage(Storage): 22 | 23 | def __init__(self, storage_uri: str) -> None: 24 | super(S3Storage, self).__init__(storage_uri) 25 | if self._parsed_storage_uri.username is None: 26 | raise InvalidStorageUri("Missing username") 27 | if 
self._parsed_storage_uri.hostname is None: 28 | raise InvalidStorageUri("Missing hostname") 29 | 30 | self._username = self._parsed_storage_uri.username 31 | self._password = self._parsed_storage_uri.password 32 | self._bucket = self._parsed_storage_uri.hostname 33 | self._keyname = self._parsed_storage_uri.path.replace("/", "", 1) 34 | 35 | def _validate_parsed_uri(self) -> None: 36 | query = parse_qs(self._parsed_storage_uri.query) 37 | self._region = get_optional_query_parameter(query, "region") 38 | 39 | def _connect(self) -> Session: 40 | session_token = None 41 | role = None 42 | role_session_name = None 43 | external_id = None 44 | 45 | if self._password is None: 46 | credentials = json.loads(unquote(self._username)) 47 | 48 | if credentials.get("version") != 1: 49 | raise InvalidStorageUri("Invalid credentials version") 50 | if "key_id" not in credentials: 51 | raise InvalidStorageUri("Missing credentials key_id") 52 | if "access_secret" not in credentials: 53 | raise InvalidStorageUri("Missing credentials access_secret") 54 | 55 | access_key = credentials["key_id"] 56 | access_secret = credentials["access_secret"] 57 | role = credentials.get("role") 58 | role_session_name = credentials.get("role_session_name") 59 | external_id = credentials.get("external_id") 60 | else: 61 | access_key = unquote(self._username) 62 | access_secret = unquote(self._password) 63 | 64 | if role is not None: 65 | if role_session_name is None: 66 | raise InvalidStorageUri("Missing credentials role_session_name") 67 | 68 | sts_client = boto3.client( 69 | "sts", aws_access_key_id=access_key, aws_secret_access_key=access_secret) 70 | 71 | response = sts_client.assume_role( 72 | RoleArn=role, RoleSessionName=role_session_name, ExternalId=external_id) 73 | 74 | access_key = response["Credentials"]["AccessKeyId"] 75 | access_secret = response["Credentials"]["SecretAccessKey"] 76 | session_token = response["Credentials"]["SessionToken"] 77 | 78 | aws_session = boto3.session.Session( 79 | aws_access_key_id=access_key, 80 | aws_secret_access_key=access_secret, 81 | aws_session_token=session_token, 82 | region_name=self._region) 83 | 84 | return aws_session.client("s3", config=botocore.config.Config(signature_version="v4")) 85 | 86 | def save_to_filename(self, file_path: str) -> None: 87 | client = self._connect() 88 | 89 | transfer = boto3.s3.transfer.S3Transfer(client) 90 | try: 91 | transfer.download_file(self._bucket, self._keyname, file_path) 92 | except ClientError as original_exc: 93 | if original_exc.response["Error"]["Code"] == "404": 94 | raise NotFoundError("No File Found") from original_exc 95 | raise original_exc 96 | 97 | def save_to_file(self, out_file: BinaryIO) -> None: 98 | client = self._connect() 99 | 100 | response = client.get_object(Bucket=self._bucket, Key=self._keyname) 101 | 102 | if "Body" not in response: 103 | raise NotFoundError("No File Found") 104 | 105 | while True: 106 | chunk = response["Body"].read(_LARGE_CHUNK) 107 | out_file.write(chunk) 108 | if not chunk: 109 | break 110 | 111 | def save_to_directory(self, directory_path: str) -> None: 112 | client = self._connect() 113 | directory_prefix = "{}/".format(self._keyname) 114 | dir_object = client.list_objects(Bucket=self._bucket, Prefix=directory_prefix) 115 | 116 | if "Contents" not in dir_object: 117 | raise NotFoundError("No Files Found") 118 | 119 | dir_contents = dir_object["Contents"] 120 | 121 | while dir_object["IsTruncated"]: 122 | dir_object = client.list_objects( 123 | Bucket=self._bucket, 124 | 
Prefix=directory_prefix, 125 | Marker=dir_object["Contents"][-1]["Key"]) 126 | 127 | dir_contents += dir_object["Contents"] 128 | 129 | for obj in dir_contents: 130 | file_key = obj["Key"].replace(self._keyname, "", 1) 131 | 132 | if file_key and not file_key.endswith("/"): 133 | file_path = os.path.dirname(file_key) 134 | 135 | if not os.path.exists(directory_path + file_path): 136 | os.makedirs(directory_path + file_path) 137 | 138 | try: 139 | retry.attempt( 140 | client.download_file, self._bucket, obj["Key"], directory_path + file_key) 141 | except ClientError as original_exc: 142 | if original_exc.response["Error"]["Code"] == "404": 143 | raise NotFoundError("No File Found") from original_exc 144 | raise original_exc 145 | 146 | def load_from_filename(self, file_path: str) -> None: 147 | client = self._connect() 148 | 149 | extra_args = None 150 | content_type = mimetypes.guess_type(file_path)[0] 151 | if content_type is not None: 152 | extra_args = {"ContentType": content_type} 153 | 154 | transfer = boto3.s3.transfer.S3Transfer(client) 155 | transfer.upload_file(file_path, self._bucket, self._keyname, extra_args=extra_args) 156 | 157 | def load_from_file(self, in_file: BinaryIO) -> None: 158 | client = self._connect() 159 | 160 | extra_args: dict[str, str] = {} 161 | 162 | content_type = mimetypes.guess_type(self._storage_uri)[0] 163 | if content_type is not None: 164 | extra_args["ContentType"] = content_type 165 | 166 | client.put_object(Bucket=self._bucket, Key=self._keyname, Body=in_file, **extra_args) 167 | 168 | def load_from_directory(self, source_directory: str) -> None: 169 | client = self._connect() 170 | 171 | for root, _, files in os.walk(source_directory): 172 | relative_path = root.replace(source_directory, self._keyname, 1) 173 | 174 | for filename in files: 175 | upload_path = os.path.join(relative_path, filename) 176 | extra_args = None 177 | content_type = mimetypes.guess_type(filename)[0] 178 | if content_type is not None: 179 | extra_args = {"ContentType": content_type} 180 | retry.attempt( 181 | client.upload_file, os.path.join(root, filename), self._bucket, upload_path, 182 | ExtraArgs=extra_args) 183 | 184 | def delete(self) -> None: 185 | client = self._connect() 186 | response = client.delete_object(Bucket=self._bucket, Key=self._keyname) 187 | 188 | if "DeleteMarker" not in response: 189 | raise NotFoundError("No File Found") 190 | 191 | def delete_directory(self) -> None: 192 | client = self._connect() 193 | directory_prefix = "{}/".format(self._keyname) 194 | 195 | dir_object = client.list_objects(Bucket=self._bucket, Prefix=directory_prefix) 196 | 197 | if "Contents" not in dir_object: 198 | raise NotFoundError("No Files Found") 199 | 200 | object_keys = [[{"Key": o.get("Key", None)} for o in dir_object["Contents"]]] 201 | 202 | while dir_object["IsTruncated"]: 203 | dir_object = client.list_objects( 204 | Bucket=self._bucket, 205 | Prefix=directory_prefix, 206 | Marker=dir_object["Contents"][-1]["Key"]) 207 | 208 | object_keys.append([{"Key": o.get("Key", None)} for o in dir_object["Contents"]]) 209 | 210 | for key_page in object_keys: 211 | try: 212 | client.delete_objects(Bucket=self._bucket, Delete={"Objects": key_page}) 213 | except ClientError as original_exc: 214 | if original_exc.response["Error"]["Code"] == "404": 215 | raise NotFoundError("No File Found") from original_exc 216 | raise original_exc 217 | 218 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 219 | client = self._connect() 220 | 221 | return 
client.generate_presigned_url( 222 | "get_object", Params={"Bucket": self._bucket, "Key": self._keyname}, ExpiresIn=seconds) 223 | 224 | def get_sanitized_uri(self) -> str: 225 | return remove_user_info(self._parsed_storage_uri) 226 | -------------------------------------------------------------------------------- /storage/storage.py: -------------------------------------------------------------------------------- 1 | import queue 2 | import threading 3 | from urllib.parse import ParseResult, urljoin, urlparse, uses_query 4 | 5 | from typing import BinaryIO, Callable, Dict, List, Optional, Type, TypeVar, Union 6 | 7 | from storage.url_parser import sanitize_resource_uri 8 | 9 | 10 | _STORAGE_TYPES = {}  # maintains supported storage protocols 11 | _LARGE_CHUNK = 32 * 1024 * 1024 12 | 13 | DEFAULT_SWIFT_TIMEOUT = 60 14 | 15 | # Socket timeout (float seconds) for FTP transfers. 16 | DEFAULT_FTP_TIMEOUT = 60.0 17 | 18 | # Enable (1) or disable (0) KEEPALIVE probes for the FTP command socket. 19 | DEFAULT_FTP_KEEPALIVE_ENABLE = 1 20 | 21 | # Socket KEEPALIVE probe count for FTP transfers. 22 | DEFAULT_FTP_KEEPCNT = 5 23 | 24 | # Socket KEEPALIVE idle timeout for FTP transfers. 25 | DEFAULT_FTP_KEEPIDLE = 60 26 | 27 | # Socket KEEPALIVE interval for FTP transfers. 28 | DEFAULT_FTP_KEEPINTVL = 60 29 | 30 | 31 | def register_storage_protocol(scheme: str) -> Callable[[Type["Storage"]], Type["Storage"]]: 32 | """Register a storage protocol with the storage library by associating 33 | a scheme with the specified storage class (aClass).""" 34 | 35 | def decorate_storage_protocol(aClass: Type["Storage"]) -> Type["Storage"]: 36 | 37 | _STORAGE_TYPES[scheme] = aClass 38 | uses_query.append(scheme) 39 | return aClass 40 | 41 | return decorate_storage_protocol 42 | 43 | 44 | class NotFoundError(Exception): 45 | pass 46 | 47 | 48 | class DownloadUrlBaseUndefinedError(Exception): 49 | """Exception raised when a download url has been requested and 50 | no download_url_base has been defined for the storage object. 51 | 52 | This exception is used with local file storage and FTP file storage objects.
53 | """ 54 | pass 55 | 56 | 57 | class TimeoutError(IOError): 58 | """Exception raised by timeout when a blocking operation times out.""" 59 | pass 60 | 61 | 62 | T = TypeVar("T") 63 | 64 | 65 | def timeout(seconds: int, worker: Callable[[], T]) -> T: 66 | result_queue: queue.Queue[Union[BaseException, T]] = queue.Queue() 67 | 68 | def wrapper() -> None: 69 | try: 70 | result_queue.put(worker()) 71 | except BaseException as e: 72 | result_queue.put(e) 73 | 74 | thread = threading.Thread(target=wrapper) 75 | thread.daemon = True 76 | thread.start() 77 | 78 | try: 79 | result = result_queue.get(True, seconds) 80 | except queue.Empty: 81 | raise TimeoutError() 82 | 83 | if isinstance(result, BaseException): 84 | raise result 85 | return result 86 | 87 | 88 | class Storage(object): 89 | _storage_uri: str 90 | _parsed_storage_uri: ParseResult 91 | 92 | def __init__(self, storage_uri: str) -> None: 93 | self._storage_uri = storage_uri 94 | self._parsed_storage_uri = urlparse(storage_uri) 95 | self._validate_parsed_uri() 96 | 97 | def _validate_parsed_uri(self) -> None: 98 | pass 99 | 100 | def _class_name(self) -> str: 101 | return self.__class__.__name__ 102 | 103 | def save_to_filename(self, file_path: str) -> None: 104 | raise NotImplementedError( 105 | "{} does not implement 'save_to_filename'".format(self._class_name())) 106 | 107 | def save_to_file(self, out_file: BinaryIO) -> None: 108 | raise NotImplementedError( 109 | "{} does not implement 'save_to_file'".format(self._class_name())) 110 | 111 | def save_to_directory(self, directory_path: str) -> None: 112 | raise NotImplementedError( 113 | "{} does not implement 'save_to_directory'".format(self._class_name())) 114 | 115 | def load_from_filename(self, file_path: str) -> None: 116 | raise NotImplementedError( 117 | "{} does not implement 'load_from_filename'".format(self._class_name())) 118 | 119 | def load_from_file(self, in_file: BinaryIO) -> None: 120 | raise NotImplementedError( 121 | "{} does not implement 'load_from_file'".format(self._class_name())) 122 | 123 | def load_from_directory(self, directory_path: str) -> None: 124 | raise NotImplementedError( 125 | "{} does not implement 'load_from_directory'".format(self._class_name())) 126 | 127 | def delete(self) -> None: 128 | raise NotImplementedError("{} does not implement 'delete'".format(self._class_name())) 129 | 130 | def delete_directory(self) -> None: 131 | raise NotImplementedError( 132 | "{} does not implement 'delete_directory'".format(self._class_name())) 133 | 134 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 135 | raise NotImplementedError( 136 | "{} does not implement 'get_download_url'".format(self._class_name())) 137 | 138 | def get_sanitized_uri(self) -> str: 139 | return sanitize_resource_uri(self._parsed_storage_uri) 140 | 141 | 142 | def _generate_download_url_from_base(base: Union[str, None], object_name: str) -> str: 143 | """Generate a download url by joining the base with the storage object_name. 144 | 145 | If the base is not defined, raise an exception. 
146 | """ 147 | if base is None: 148 | raise DownloadUrlBaseUndefinedError("The storage uri has no download_url_base defined.") 149 | 150 | return urljoin(base, object_name) 151 | 152 | 153 | class InvalidStorageUri(RuntimeError): 154 | """Invalid storage URI was specified.""" 155 | pass 156 | 157 | 158 | def get_storage(storage_uri: str) -> Storage: 159 | storage_type = urlparse(storage_uri).scheme 160 | try: 161 | return _STORAGE_TYPES[storage_type](storage_uri) 162 | except KeyError: 163 | raise InvalidStorageUri(f"Invalid storage type '{storage_type}'") 164 | 165 | 166 | ParsedQuery = Dict[str, List[str]] 167 | 168 | 169 | def get_optional_query_parameter(parsed_query: ParsedQuery, parameter: str) -> Optional[str]: 170 | query_arg = parsed_query.get(parameter, []) 171 | if len(query_arg) > 1: 172 | raise InvalidStorageUri(f"Too many `{parameter}` query values.") 173 | return query_arg[0] if len(query_arg) else None 174 | -------------------------------------------------------------------------------- /storage/swift_storage.py: -------------------------------------------------------------------------------- 1 | import mimetypes 2 | import os 3 | from urllib.parse import parse_qs, parse_qsl, ParseResult, urlencode, urljoin, urlparse 4 | 5 | from keystoneauth1 import session 6 | from keystoneauth1.identity import v2 7 | import swiftclient.client 8 | from swiftclient.exceptions import ClientException 9 | import swiftclient.utils 10 | from typing import BinaryIO, Callable, cast, Dict, List, Optional, Tuple, Type, TypeVar 11 | 12 | from storage import retry 13 | from storage.storage import InvalidStorageUri, register_storage_protocol, Storage, NotFoundError 14 | from storage.storage import get_optional_query_parameter, _LARGE_CHUNK, DEFAULT_SWIFT_TIMEOUT 15 | 16 | 17 | def register_swift_protocol( 18 | scheme: str, auth_endpoint: str) -> Callable[[Type["SwiftStorage"]], Type["SwiftStorage"]]: 19 | def wrapper(cls: Type["SwiftStorage"]) -> Type["SwiftStorage"]: 20 | cls.auth_endpoint = auth_endpoint 21 | return cast(Type[SwiftStorage], register_storage_protocol(scheme)(cls)) 22 | return wrapper 23 | 24 | 25 | class SwiftStorageError(Exception): 26 | 27 | def __init__(self, message: str, do_not_retry: bool = False) -> None: 28 | super().__init__(message) 29 | self.do_not_retry = do_not_retry 30 | 31 | 32 | T = TypeVar("T") 33 | 34 | 35 | @register_storage_protocol("swift") 36 | class SwiftStorage(Storage): 37 | 38 | download_url_key: Optional[str] 39 | 40 | def _validate_parsed_uri(self) -> None: 41 | query = parse_qs(self._parsed_storage_uri.query) 42 | 43 | auth_endpoint = get_optional_query_parameter(query, "auth_endpoint") 44 | if auth_endpoint is None: 45 | raise SwiftStorageError("Required field is missing: auth_endpoint") 46 | self.auth_endpoint = auth_endpoint 47 | 48 | region_name = get_optional_query_parameter(query, "region") 49 | if region_name is None: 50 | raise SwiftStorageError("Required field is missing: region_name") 51 | self.region_name = region_name 52 | 53 | tenant_id = get_optional_query_parameter(query, "tenant_id") 54 | if tenant_id is None: 55 | raise SwiftStorageError("Required field is missing: tenant_id") 56 | self.tenant_id = tenant_id 57 | 58 | self.download_url_key = get_optional_query_parameter(query, "download_url_key") 59 | 60 | if self._parsed_storage_uri.username is None or self._parsed_storage_uri.username == "": 61 | raise InvalidStorageUri("Missing username") 62 | if self._parsed_storage_uri.password is None or self._parsed_storage_uri.password == "": 63 
| raise InvalidStorageUri("Missing API key") 64 | if self._parsed_storage_uri.hostname is None: 65 | raise InvalidStorageUri("Missing hostname") 66 | 67 | self._username = self._parsed_storage_uri.username 68 | self._password = self._parsed_storage_uri.password 69 | self._hostname = self._parsed_storage_uri.hostname 70 | 71 | # cache get connections 72 | def get_connection(self) -> swiftclient.client.Connection: 73 | if not hasattr(self, "_connection"): 74 | os_options = { 75 | "tenant_id": self.tenant_id, 76 | "region_name": self.region_name 77 | } 78 | 79 | auth = v2.Password( 80 | auth_url=self.auth_endpoint, username=self._username, password=self._password, 81 | tenant_name=self.tenant_id) 82 | 83 | keystone_session = session.Session(auth=auth) 84 | 85 | connection = swiftclient.client.Connection( 86 | session=keystone_session, os_options=os_options, timeout=DEFAULT_SWIFT_TIMEOUT, 87 | retries=0) 88 | connection.get_auth() 89 | 90 | self._connection = connection 91 | return self._connection 92 | 93 | def get_container_and_object_names(self) -> Tuple[str, str]: 94 | _, container = self._parsed_storage_uri.netloc.split("@") 95 | object_name = self._parsed_storage_uri.path[1:] 96 | return container, object_name 97 | 98 | def _download_object_to_file( 99 | self, container: str, object_name: str, out_file: BinaryIO) -> None: 100 | connection = self.get_connection() 101 | 102 | # may need to truncate file if retrying... 103 | out_file.seek(0) 104 | 105 | try: 106 | resp_headers, object_contents = connection.get_object( 107 | container, object_name, resp_chunk_size=_LARGE_CHUNK) 108 | except ClientException as original_exc: 109 | if original_exc.http_status == 404: 110 | raise NotFoundError("No File Found") from original_exc 111 | raise original_exc 112 | 113 | for object_content in object_contents: 114 | out_file.write(object_content) 115 | 116 | def _download_object_to_filename(self, container: str, object_name: str, filename: str) -> None: 117 | with open(filename, "wb") as out_file: 118 | self._download_object_to_file(container, object_name, out_file) 119 | 120 | def save_to_file(self, out_file: BinaryIO) -> None: 121 | container, object_name = self.get_container_and_object_names() 122 | self._download_object_to_file(container, object_name, out_file) 123 | 124 | def save_to_filename(self, file_path: str) -> None: 125 | container, object_name = self.get_container_and_object_names() 126 | self._download_object_to_filename(container, object_name, file_path) 127 | 128 | def load_from_file(self, in_file: BinaryIO) -> None: 129 | connection = self.get_connection() 130 | container, object_name = self.get_container_and_object_names() 131 | 132 | mimetype = mimetypes.guess_type(object_name)[0] or "application/octet-stream" 133 | 134 | connection.put_object(container, object_name, in_file, content_type=mimetype) 135 | 136 | def load_from_filename(self, in_path: str) -> None: 137 | with open(in_path, "rb") as fp: 138 | self.load_from_file(fp) 139 | 140 | def delete(self) -> None: 141 | connection = self.get_connection() 142 | container, object_name = self.get_container_and_object_names() 143 | 144 | try: 145 | connection.delete_object(container, object_name) 146 | except ClientException as original_exc: 147 | if original_exc.http_status == 404: 148 | raise NotFoundError("No File Found") from original_exc 149 | raise original_exc 150 | 151 | def get_download_url(self, seconds: int = 60, key: Optional[str] = None) -> str: 152 | connection = self.get_connection() 153 | 154 | download_url_key = key or 
self.download_url_key 155 | 156 | if download_url_key is None: 157 | raise SwiftStorageError( 158 | "Missing required `download_url_key` for `get_download_url`.") 159 | 160 | host, _ = connection.get_service_auth() 161 | container, object_name = self.get_container_and_object_names() 162 | 163 | storage_url, _ = connection.get_auth() 164 | storage_path = urlparse(storage_url).path 165 | 166 | path = swiftclient.utils.generate_temp_url( 167 | f"{storage_path}/{container}/{object_name}", 168 | seconds=seconds, key=download_url_key, method="GET") 169 | 170 | return urljoin(host, path) 171 | 172 | def get_sanitized_uri(self) -> str: 173 | parsed_uri = self._parsed_storage_uri 174 | new_query = dict(parse_qsl(parsed_uri.query)) 175 | 176 | if "download_url_key" in new_query: 177 | del new_query["download_url_key"] 178 | 179 | new_uri = ParseResult( 180 | parsed_uri.scheme, self._hostname, parsed_uri.path, parsed_uri.params, 181 | urlencode(new_query), parsed_uri.fragment) 182 | 183 | return new_uri.geturl() 184 | 185 | def _find_storage_objects_with_prefix( 186 | self, container: str, prefix: str) -> List[Dict[str, str]]: 187 | connection = self.get_connection() 188 | try: 189 | _, container_objects = connection.get_container(container, prefix=prefix) 190 | if len(container_objects) == 0: 191 | raise NotFoundError("No Files Found") 192 | return container_objects 193 | except ClientException as original_exc: 194 | if original_exc.http_status == 404: 195 | raise NotFoundError("No File Found") from original_exc 196 | raise original_exc 197 | 198 | def save_to_directory(self, directory_path: str) -> None: 199 | container, object_name = self.get_container_and_object_names() 200 | 201 | prefix = self._parsed_storage_uri.path[1:] + "/" 202 | 203 | for container_object in self._find_storage_objects_with_prefix(container, prefix): 204 | if container_object["name"].endswith("/"): 205 | continue 206 | 207 | base_path = container_object["name"].split(prefix)[1] 208 | relative_path = os.path.sep.join(base_path.split("/")) 209 | file_path = os.path.join(directory_path, relative_path) 210 | object_path = container_object["name"] 211 | 212 | while object_path.startswith("/"): 213 | object_path = object_path[1:] 214 | 215 | dir_name = os.path.dirname(file_path) 216 | os.makedirs(dir_name, exist_ok=True) 217 | 218 | retry.attempt(self._download_object_to_filename, container, object_path, file_path) 219 | 220 | def load_from_directory(self, directory_path: str) -> None: 221 | connection = self.get_connection() 222 | container, object_name = self.get_container_and_object_names() 223 | 224 | prefix = self._parsed_storage_uri.path[1:] 225 | 226 | for root, _, files in os.walk(directory_path): 227 | base = root.split(directory_path, 1)[1] 228 | while base.startswith("/"): 229 | base = base[1:] 230 | while base.endswith("/"): 231 | base = base[:-1] 232 | for filename in files: 233 | local_path = os.path.join(root, filename) 234 | remote_path = "/".join(filter(lambda x: x != "", [prefix, base, filename])) 235 | 236 | mimetype = mimetypes.guess_type(remote_path)[0] or "application/octet-stream" 237 | 238 | with open(local_path, "rb") as fp: 239 | retry.attempt( 240 | connection.put_object, container, remote_path, fp, content_type=mimetype) 241 | 242 | def delete_directory(self) -> None: 243 | connection = self.get_connection() 244 | container, object_name = self.get_container_and_object_names() 245 | 246 | prefix = self._parsed_storage_uri.path[1:] + "/" 247 | 248 | for container_object in 
self._find_storage_objects_with_prefix(container, prefix): 249 | object_path = container_object["name"] 250 | 251 | while object_path.startswith("/"): 252 | object_path = object_path[1:] 253 | 254 | connection.delete_object(container, object_path) 255 | -------------------------------------------------------------------------------- /storage/url_parser.py: -------------------------------------------------------------------------------- 1 | from urllib.parse import parse_qsl, ParseResult, urlencode 2 | 3 | from typing import Dict 4 | 5 | 6 | def _new_uri(parsed_uri: ParseResult, new_netloc: str, new_query: Dict[str, str]) -> ParseResult: 7 | return ParseResult( 8 | parsed_uri.scheme, new_netloc, parsed_uri.path, parsed_uri.params, urlencode(new_query), 9 | parsed_uri.fragment) 10 | 11 | 12 | def remove_user_info(parsed_uri: ParseResult) -> str: 13 | new_netloc = "" if parsed_uri.hostname is None else parsed_uri.hostname 14 | 15 | if parsed_uri.port is not None: 16 | new_netloc = ":".join((new_netloc, str(parsed_uri.port))) 17 | 18 | new_uri = _new_uri(parsed_uri, new_netloc, dict(parse_qsl(parsed_uri.query))) 19 | 20 | return new_uri.geturl() 21 | 22 | 23 | def sanitize_resource_uri(parsed_uri: ParseResult) -> str: 24 | new_netloc = "" if parsed_uri.hostname is None else parsed_uri.hostname 25 | 26 | if parsed_uri.port is not None: 27 | new_netloc = ":".join((new_netloc, str(parsed_uri.port))) 28 | 29 | new_uri = _new_uri(parsed_uri, new_netloc, {}) 30 | 31 | return new_uri.geturl() 32 | -------------------------------------------------------------------------------- /stubs/boto3/__init__.pyi: -------------------------------------------------------------------------------- 1 | import botocore.session 2 | from typing import Literal, Optional, TypedDict 3 | 4 | 5 | class _AssumeRoleResponseCredentials(TypedDict): 6 | AccessKeyId: str 7 | SecretAccessKey: str 8 | SessionToken: str 9 | # Other fields omitted 10 | 11 | 12 | class _AssumeRoleResponse(TypedDict): 13 | Credentials: _AssumeRoleResponseCredentials 14 | # Other fields omitted 15 | 16 | 17 | class _STSClient: 18 | def assume_role( 19 | self, 20 | *, 21 | RoleArn: str, 22 | RoleSessionName: str, 23 | ExternalId: Optional[str] = None 24 | # Other arguments omitted 25 | ) -> _AssumeRoleResponse: 26 | ... 27 | 28 | 29 | def client( 30 | service_name: Literal["sts"], 31 | aws_access_key_id: Optional[str] = None, 32 | aws_secret_access_key: Optional[str] = None, 33 | aws_session_token: Optional[str] = None, 34 | region_name: Optional[str] = None, 35 | botocore_session: Optional[botocore.session.Session] = None, 36 | profile_name: Optional[str] = None 37 | ) -> _STSClient: 38 | ... 39 | -------------------------------------------------------------------------------- /stubs/boto3/s3/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/boto3/s3/__init__.pyi -------------------------------------------------------------------------------- /stubs/boto3/s3/transfer.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | from botocore.session import Session 3 | 4 | 5 | # these three are actually in s3transfer module, but they are just pure 6 | # type stubs for now since we aren't using them... 7 | 8 | class TransferConfig(object): ... 9 | 10 | class OSUtils(object): ... 11 | 12 | class TransferManager(object): ... 
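# A minimal sketch of how this library exercises the S3Transfer stub below
# (mirroring the calls made in storage/s3_storage.py; the bucket, key, and
# local paths here are illustrative placeholders, not values from the code):
#
#     transfer = S3Transfer(client)
#     transfer.download_file("some-bucket", "some/key", "/tmp/downloaded")
#     transfer.upload_file("/tmp/source.mp4", "some-bucket", "some/key",
#                          extra_args={"ContentType": "video/mp4"})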
13 | 14 | 15 | class S3Transfer(object): 16 | def __init__( 17 | self, 18 | client: Optional[Session] = None, 19 | config: Optional[TransferConfig] = None, 20 | osutil: Optional[OSUtils] = None, 21 | manager: Optional[TransferManager] = None) -> None: ... 22 | 23 | def download_file( 24 | self, 25 | bucket: str, 26 | key: str, 27 | filename: str, 28 | *args: Any, 29 | **kwargs: Any) -> None: ... 30 | 31 | def upload_file( 32 | self, 33 | filename: str, 34 | bucket: str, 35 | key: str, 36 | *args: Any, 37 | **kwargs: Any) -> None: ... 38 | -------------------------------------------------------------------------------- /stubs/boto3/session.pyi: -------------------------------------------------------------------------------- 1 | import botocore.config 2 | import botocore.session 3 | 4 | from typing import Optional, Union 5 | 6 | 7 | class Session(object): 8 | 9 | def __init__( 10 | self, 11 | aws_access_key_id: Optional[str] = None, 12 | aws_secret_access_key: Optional[str] = None, 13 | aws_session_token: Optional[str] = None, 14 | region_name: Optional[str] = None, 15 | botocore_session: Optional[botocore.session.Session] = None, 16 | profile_name: Optional[str] = None) -> None: ... 17 | 18 | def client( 19 | self, 20 | service_name: str, 21 | region_name: Optional[str] = None, 22 | api_version: Optional[str] = None, 23 | use_ssl: bool = True, 24 | verify: Union[None, bool, str] = None, 25 | endpoint_url: Optional[str] = None, 26 | aws_access_key_id: Optional[str] = None, 27 | aws_secret_access_key: Optional[str] = None, 28 | aws_session_token: Optional[str] = None, 29 | config: Optional[botocore.config.Config] = None 30 | ) -> botocore.session.Session: ... 31 | -------------------------------------------------------------------------------- /stubs/botocore/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/botocore/__init__.pyi -------------------------------------------------------------------------------- /stubs/botocore/config.pyi: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | 4 | class Config(): 5 | def __init__( 6 | self, 7 | region_name: Optional[str] = None, 8 | signature_version: Optional[str] = None, 9 | user_agent: Optional[str] = None, 10 | user_agent_extra: Optional[str] = None, 11 | user_agent_appid: Optional[str] = None, 12 | connect_timeout: Optional[Union[float, int]] = None, 13 | read_timeout: Optional[Union[float, int]] = None, 14 | parameter_validation: Optional[bool] = None, 15 | max_pool_connections: Optional[int] = None, 16 | proxies: Optional[dict[str, object]] = None, 17 | proxies_config: Optional[dict[str, object]] = None, 18 | s3: Optional[dict[str, object]] = None, 19 | retries: Optional[dict[str, object]] = None, 20 | client_cert: Optional[Union[str, tuple[str, str]]] = None, 21 | inject_host_prefix: Optional[bool] = None, 22 | use_dualstack_endpoint: Optional[bool] = None, 23 | use_fips_endpoint: Optional[bool] = None, 24 | ignore_configured_endpoint_urls: Optional[bool] = None, 25 | tcp_keepalive: Optional[bool] = None, 26 | request_min_compression_size_bytes: Optional[int] = None, 27 | disable_request_compression: Optional[bool] = None, 28 | sigv4a_signing_region_set: Optional[str] = None, 29 | client_context_params: Optional[dict[str, object]] = None 30 | ) -> None: 31 | ... 
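# Of these parameters, only signature_version appears to be exercised by this
# library; storage/s3_storage.py builds its client configuration roughly as
# (a sketch of the call, not a new API):
#
#     Config(signature_version="v4")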
32 | -------------------------------------------------------------------------------- /stubs/botocore/exceptions.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict 2 | 3 | class ClientError(BaseException): 4 | def __init__(self, error_response: Dict[Any, Any], operation_name: object) -> None: 5 | self.response = error_response 6 | self.operation_name = operation_name 7 | -------------------------------------------------------------------------------- /stubs/botocore/session.pyi: -------------------------------------------------------------------------------- 1 | # this is actually a very dynamic class instance with dynamic attributes 2 | # performing underlying calls against the configured services, but we're 3 | # stubbing it out here for S3 for simplicity's sake. 4 | 5 | 6 | from io import BytesIO 7 | from typing import BinaryIO, Dict, List, Optional 8 | from mypy_extensions import TypedDict 9 | 10 | 11 | ObjectResponse = TypedDict("ObjectResponse", {"Body": BytesIO}) 12 | 13 | ListObjectResponse = TypedDict( 14 | "ListObjectResponse", { 15 | "Key": str, 16 | "LastModified": str, 17 | "ETag": str, 18 | "Size": int, 19 | "StorageClass": str 20 | }) 21 | 22 | ListResponse = TypedDict( 23 | "ListResponse", { 24 | "Contents": List[ListObjectResponse], 25 | "IsTruncated": bool, 26 | "NextMarker": Optional[str] 27 | }) 28 | 29 | DeleteEntries = TypedDict( 30 | "DeleteEntries", {"Objects": List[Dict[str, Optional[str]]]}) 31 | 32 | ParamEntries = TypedDict("ParamEntries", {"Bucket": str, "Key": str}) 33 | 34 | 35 | class Session(object): 36 | 37 | def get_object(self, Bucket: str, Key: str) -> ObjectResponse: ... 38 | 39 | def list_objects( 40 | self, 41 | Bucket: str, 42 | Prefix: str, 43 | Marker: Optional[str] = None 44 | ) -> ListResponse: 45 | ... 46 | 47 | def download_file(self, Bucket: str, Key: str, filepath: str) -> None: ... 48 | 49 | def put_object( 50 | self, Bucket: str, Key: str, Body: BinaryIO, ContentType: Optional[str]) -> None: ... 51 | 52 | def upload_file(self, Filename: str, Bucket: str, Key: str) -> None: ... 53 | 54 | def delete_object(self, Bucket: str, Key: str) -> ListResponse: ... 55 | 56 | def delete_objects(self, Bucket: str, Delete: DeleteEntries) -> None: ... 57 | 58 | def generate_presigned_url( 59 | self, 60 | Permission: str, 61 | Params: ParamEntries, 62 | ExpiresIn: int) -> str: ... 63 | -------------------------------------------------------------------------------- /stubs/google/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/google/__init__.pyi -------------------------------------------------------------------------------- /stubs/google/cloud/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/google/cloud/__init__.pyi -------------------------------------------------------------------------------- /stubs/google/cloud/exceptions.pyi: -------------------------------------------------------------------------------- 1 | class NotFound(BaseException): 2 | 3 | def __init__(self, message: str) -> None: ... 
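# google-cloud-storage raises NotFound for missing blobs or buckets; a caller
# would typically translate it into this library's NotFoundError, e.g. (an
# illustrative sketch following the 404-handling pattern used by the S3 and
# Swift backends, not the literal google_storage.py code):
#
#     try:
#         blob.download_to_filename(path)
#     except NotFound as error:
#         raise NotFoundError("No File Found") from error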
4 | -------------------------------------------------------------------------------- /stubs/google/cloud/storage/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/google/cloud/storage/__init__.pyi -------------------------------------------------------------------------------- /stubs/google/cloud/storage/blob.pyi: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | from typing import BinaryIO, Optional 4 | 5 | 6 | class Blob(object): 7 | 8 | name: str 9 | 10 | def download_to_filename(self, path: str) -> None: ... 11 | 12 | def download_to_file(self, fp: BinaryIO) -> None: ... 13 | 14 | def upload_from_filename(self, path: str) -> None: ... 15 | 16 | def upload_from_file(self, fp: BinaryIO, content_type: Optional[str]) -> None: ... 17 | 18 | def delete(self) -> None: ... 19 | 20 | def generate_signed_url( 21 | self, expiration: Optional[timedelta] = None, 22 | response_disposition: Optional[str] = None) -> str: ... 23 | -------------------------------------------------------------------------------- /stubs/google/cloud/storage/bucket.pyi: -------------------------------------------------------------------------------- 1 | from google.cloud.storage.blob import Blob 2 | 3 | from typing import Iterator 4 | 5 | 6 | class Bucket(object): 7 | 8 | def blob(self, blob_name: str) -> Blob: ... 9 | 10 | def list_blobs(self, prefix: str) -> Iterator[Blob]: ... 11 | -------------------------------------------------------------------------------- /stubs/google/cloud/storage/client.pyi: -------------------------------------------------------------------------------- 1 | from google.oauth2.service_account import Credentials 2 | from google.cloud.storage.bucket import Bucket 3 | 4 | 5 | class Client(object): 6 | 7 | def __init__(self, project: str, credentials: Credentials) -> None: ... 8 | 9 | def get_bucket(self, bucket_name: str) -> Bucket: ... 10 | -------------------------------------------------------------------------------- /stubs/google/oauth2/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/google/oauth2/__init__.pyi -------------------------------------------------------------------------------- /stubs/google/oauth2/service_account.pyi: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | 4 | class Credentials(object): 5 | 6 | @classmethod 7 | def from_service_account_info( 8 | cls, account_info: Dict[str, str]) -> "Credentials": ...
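# Sketch of the typical construction path (account_info is assumed to be a
# dict parsed from a service-account JSON credential; Client refers to
# google.cloud.storage.client.Client):
#
#     credentials = Credentials.from_service_account_info(account_info)
#     client = Client(project="example-project", credentials=credentials)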
9 | -------------------------------------------------------------------------------- /stubs/keystoneauth1/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/keystoneauth1/__init__.pyi -------------------------------------------------------------------------------- /stubs/keystoneauth1/exceptions/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/keystoneauth1/exceptions/__init__.pyi -------------------------------------------------------------------------------- /stubs/keystoneauth1/exceptions/http.pyi: -------------------------------------------------------------------------------- 1 | class BadGateway(Exception): 2 | pass 3 | 4 | 5 | class Forbidden(Exception): 6 | pass 7 | 8 | 9 | class InternalServerError(Exception): 10 | pass 11 | 12 | 13 | class Unauthorized(Exception): 14 | pass 15 | -------------------------------------------------------------------------------- /stubs/keystoneauth1/identity/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/keystoneauth1/identity/__init__.pyi -------------------------------------------------------------------------------- /stubs/keystoneauth1/identity/v2.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Union 2 | 3 | 4 | class _NotPassed(object): 5 | pass 6 | 7 | 8 | _NOT_PASSED = _NotPassed() 9 | 10 | 11 | class Password(object): 12 | 13 | def __init__( 14 | self, 15 | auth_url: Optional[str], 16 | username: Union[_NotPassed, str] = _NOT_PASSED, 17 | password: Optional[str] = None, 18 | user_id: Union[_NotPassed, str] = _NOT_PASSED, 19 | **kwargs: Any) -> None: ... 20 | 21 | def get_auth_data( 22 | self, 23 | *args: Any, 24 | **kwargs: Any) -> Dict[str, Any]: ... 25 | -------------------------------------------------------------------------------- /stubs/keystoneauth1/session.pyi: -------------------------------------------------------------------------------- 1 | # This is purposefully typed only for the subset that we use, may need to be 2 | # expanded as new arguments, etc. are used. 3 | 4 | 5 | from keystoneauth1.identity.v2 import Password 6 | from typing import Optional 7 | 8 | 9 | class Session(object): 10 | def __init__(self, auth: Optional[Password] = None) -> None: ... 
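# This mirrors how SwiftStorage.get_connection wires up authentication:
#
#     auth = Password(auth_url=auth_endpoint, username=username,
#                     password=password, tenant_name=tenant_id)
#     keystone_session = Session(auth=auth)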
11 | -------------------------------------------------------------------------------- /stubs/swiftclient/__init__.pyi: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/stubs/swiftclient/__init__.pyi -------------------------------------------------------------------------------- /stubs/swiftclient/client.pyi: -------------------------------------------------------------------------------- 1 | from typing import Any, BinaryIO, Dict, Generator, Iterable, List, Optional 2 | from typing import overload, Tuple, Union 3 | 4 | from keystoneauth1.session import Session 5 | 6 | 7 | ObjectContents = Union[str, BinaryIO, Iterable[str]] 8 | 9 | 10 | class Connection(object): 11 | 12 | def __init__( 13 | self, 14 | authurl: Optional[str] = None, 15 | user: Optional[str] = None, 16 | key: Optional[str] = None, 17 | retries: int = 5, 18 | preauthurl: Optional[str] = None, 19 | preauthtoken: Optional[str] = None, 20 | snet: bool = False, 21 | starting_backoff: int = 1, 22 | max_backoff: int = 64, 23 | tenant_name: Optional[str] = None, 24 | os_options: Optional[Dict[str, Any]] = None, 25 | auth_version: str = '1', 26 | cacert: Optional[str] = None, 27 | insecure: bool = False, 28 | cert: Optional[str] = None, 29 | cert_key: Optional[str] = None, 30 | ssl_compression: bool = True, 31 | retry_on_ratelimit: bool = False, 32 | timeout: Optional[int] = None, 33 | session: Optional[Session] = None, 34 | force_auth_retry: bool = False) -> None: ... 35 | 36 | @overload 37 | def get_object( 38 | self, 39 | container: str, 40 | obj: str, 41 | resp_chunk_size: None = None, 42 | query_string: Optional[str] = None, 43 | response_dict: Optional[Dict[str, Any]] = None, 44 | headers: Optional[Dict[str, str]] = None 45 | ) -> Tuple[Dict[str, str], bytes]: ... 46 | 47 | @overload 48 | def get_object( 49 | self, 50 | container: str, 51 | obj: str, 52 | resp_chunk_size: int, 53 | query_string: Optional[str] = None, 54 | response_dict: Optional[Dict[str, Any]] = None, 55 | headers: Optional[Dict[str, str]] = None 56 | ) -> Tuple[Dict[str, str], Generator[bytes, None, None]]: ... 57 | 58 | def put_object( 59 | self, 60 | container: str, 61 | obj: str, 62 | contents: Optional[ObjectContents], 63 | content_length: Optional[int] = None, 64 | etag: Optional[str] = None, 65 | chunk_size: Optional[int] = None, 66 | content_type: Optional[str] = None, 67 | headers: Optional[Dict[str, str]] = None, 68 | query_string: Optional[str] = None, 69 | response_dict: Optional[Dict[str, Any]] = None) -> str: ... 70 | 71 | def delete_object( 72 | self, 73 | container: str, 74 | obj: str, 75 | query_string: Optional[str] = None, 76 | response_dict: Optional[Dict[str, Any]] = None, 77 | headers: Optional[Dict[str, str]] = None) -> None: ... 78 | 79 | def get_service_auth(self) -> Tuple[str, Dict[str, str]]: ... 80 | 81 | def get_auth(self) -> Tuple[str, str]: ... 82 | 83 | def get_container( 84 | self, 85 | container: str, 86 | marker: Optional[str] = None, 87 | limit: Optional[str] = None, 88 | prefix: Optional[str] = None, 89 | delimiter: Optional[str] = None, 90 | end_marker: Optional[str] = None, 91 | path: Optional[str] = None, 92 | full_listing: bool = False, 93 | headers: Optional[Dict[str, str]] = None, 94 | query_string: Optional[str] = None 95 | ) -> Tuple[Dict[str, str], List[Dict[str, str]]]: ... 96 | 97 | def head_account(self, headers: Optional[Dict[str, str]] = None) -> Dict[str, str]: ... 
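# Example wiring, matching the usage in storage/swift_storage.py:
#
#     connection = Connection(
#         session=keystone_session,
#         os_options={"tenant_id": tenant_id, "region_name": region_name},
#         timeout=DEFAULT_SWIFT_TIMEOUT, retries=0)
#     headers, chunks = connection.get_object(
#         container, object_name, resp_chunk_size=_LARGE_CHUNK)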
98 | -------------------------------------------------------------------------------- /stubs/swiftclient/exceptions.pyi: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | class ClientException(Exception): 5 | # Since underlying library doesn't provide this attribute, 6 | # we're "mixing it in" dynamically later. 7 | do_not_retry: Optional[bool] 8 | http_status: int 9 | 10 | pass 11 | -------------------------------------------------------------------------------- /stubs/swiftclient/utils.pyi: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | 4 | def generate_temp_url( 5 | path: str, 6 | seconds: int, 7 | key: str, 8 | method: str, 9 | absolute: bool = False, 10 | prefix: bool = False, 11 | iso8601: bool = False, 12 | ip_range: Optional[str] = None) -> str: ... 13 | -------------------------------------------------------------------------------- /test_integration.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import random 4 | import subprocess 5 | import shutil 6 | import string 7 | import tempfile 8 | import time 9 | import unittest 10 | from urllib.request import urlopen 11 | 12 | from storage import get_storage 13 | 14 | 15 | class IntegrationTests(unittest.TestCase): 16 | dest_prefix: str 17 | directory: str 18 | 19 | @classmethod 20 | def setUpClass(cls) -> None: 21 | cls.dest_prefix = "".join([random.choice(string.ascii_letters) for i in range(16)]) 22 | cls.directory = tempfile.mkdtemp(prefix="source-root") 23 | contents = "storage-test-{}".format(time.time()).encode("utf8") 24 | 25 | for i in range(3): 26 | tempfile.mkstemp( 27 | prefix="source-empty", 28 | dir=tempfile.mkdtemp(prefix="source-emptyfiledir", dir=cls.directory)) 29 | 30 | handle, _ = tempfile.mkstemp( 31 | prefix="spaces rock", 32 | dir=tempfile.mkdtemp(prefix="source-spacedir", dir=cls.directory)) 33 | os.write(handle, contents) 34 | os.close(handle) 35 | 36 | handle, _ = tempfile.mkstemp( 37 | prefix="source-contentfile", 38 | dir=tempfile.mkdtemp(prefix="source-contentdir", dir=cls.directory)) 39 | os.write(handle, contents) 40 | os.close(handle) 41 | 42 | handle, _ = tempfile.mkstemp( 43 | prefix="source-mimetyped-file", suffix=".png", 44 | dir=tempfile.mkdtemp(prefix="source-contentdir", dir=cls.directory)) 45 | os.write(handle, contents) 46 | os.close(handle) 47 | 48 | @classmethod 49 | def tearDownClass(cls) -> None: 50 | shutil.rmtree(cls.directory) 51 | 52 | def assert_transport_handles_directories(self, transport: str) -> None: 53 | variable = "TEST_STORAGE_{}_URI".format(transport) 54 | uri = os.getenv(variable, None) 55 | 56 | if not uri: 57 | raise unittest.SkipTest("Skipping {} - define {} to test".format(transport, variable)) 58 | 59 | uri += "/{}".format(self.dest_prefix) 60 | storage = get_storage(uri) 61 | print(f"Testing using: {storage.get_sanitized_uri()}") 62 | 63 | print("Transport:", transport) 64 | print("\t* Uploading") 65 | storage.load_from_directory(self.directory) 66 | 67 | target_directory = tempfile.mkdtemp(prefix="dest-root") 68 | 69 | print("\t* Downloading") 70 | storage.save_to_directory(target_directory) 71 | 72 | print("\t* Checking") 73 | try: 74 | subprocess.check_output( 75 | ["diff", "-r", self.directory, target_directory], stderr=subprocess.STDOUT) 76 | except subprocess.CalledProcessError as error: 77 | print("Diff output:\n{}".format(error.output)) 78 | raise 
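    # The URIs above come from environment variables; an S3 example might look
    # like this (hypothetical, URL-quoted credentials):
    #
    #   TEST_STORAGE_S3_URI="s3://ACCESS_KEY:SECRET@bucket-name?region=us-east-1"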
79 | 80 | def assert_download_url_generated_correctly(self, transport: str) -> None: 81 | variable = "TEST_STORAGE_{}_URI".format(transport) 82 | uri = os.getenv(variable, None) 83 | 84 | if not uri: 85 | raise unittest.SkipTest("Skipping {} - define {} to test".format(transport, variable)) 86 | 87 | uri += "/download-test" 88 | upload_storage = get_storage(uri) 89 | print(f"Testing using: {upload_storage.get_sanitized_uri()}") 90 | 91 | print("Transport:", transport) 92 | 93 | upload_storage.load_from_file(io.BytesIO(b"Test data")) 94 | 95 | download_storage = get_storage(uri) 96 | 97 | download_url = download_storage.get_download_url(seconds=3600) 98 | 99 | print("Downloading from download URL") 100 | 101 | with urlopen(download_url) as download: 102 | assert download.read() == b"Test data" 103 | 104 | def test_file_transport_can_upload_and_download_directories(self) -> None: 105 | self.assert_transport_handles_directories("FILE") 106 | 107 | def test_ftp_transport_can_upload_and_download_directories(self) -> None: 108 | self.assert_transport_handles_directories("FTP") 109 | 110 | def test_s3_transport_can_upload_and_download_directories(self) -> None: 111 | self.assert_transport_handles_directories("S3") 112 | 113 | def test_s3_transport_with_json_credentials_can_upload_and_download_directories(self) -> None: 114 | self.assert_transport_handles_directories("S3_JSON") 115 | 116 | def test_swift_transport_can_upload_and_download_directories(self) -> None: 117 | self.assert_transport_handles_directories("SWIFT") 118 | 119 | def test_gs_transport_can_upload_and_download_directories(self) -> None: 120 | self.assert_transport_handles_directories("GS") 121 | 122 | def test_s3_transport_can_generate_valid_download_urls(self) -> None: 123 | self.assert_download_url_generated_correctly("S3") 124 | 125 | def test_s3_transport_with_json_credentials_can_generate_valid_download_urls(self) -> None: 126 | self.assert_download_url_generated_correctly("S3_JSON") 127 | 128 | def test_gs_transport_can_generate_valid_download_urls(self) -> None: 129 | self.assert_download_url_generated_correctly("GS") 130 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ustudio/storage/f72b2a71e3082de1b3dd728b32b939d80531fe4e/tests/__init__.py -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import tempfile 4 | 5 | from typing import Any, cast, List, Optional, Union 6 | from typing_extensions import Buffer 7 | from mypy_extensions import TypedDict 8 | 9 | 10 | class NamedIO(io.BufferedReader): 11 | 12 | name: str 13 | 14 | 15 | class TempDirectory(object): 16 | 17 | def __init__(self, parent: Optional[str] = None) -> None: 18 | self.directory = tempfile.TemporaryDirectory(dir=parent) 19 | self.subdirectories: List["TempDirectory"] = [] 20 | self.files: List[NamedIO] = [] 21 | 22 | @property 23 | def name(self) -> str: 24 | return self.directory.name 25 | 26 | def add_file(self, contents: bytes, suffix: str = "") -> NamedIO: 27 | temp = cast(NamedIO, tempfile.NamedTemporaryFile(dir=self.directory.name, suffix=suffix)) 28 | temp.write(contents) 29 | temp.flush() 30 | temp.seek(0) 31 | self.files.append(temp) 32 | return temp 33 | 34 | def add_dir(self) -> "TempDirectory": 35 | temp = 
TempDirectory(parent=self.directory.name) 36 | self.subdirectories.append(temp) 37 | return temp 38 | 39 | def cleanup(self) -> None: 40 | for subdir in self.subdirectories: 41 | subdir.cleanup() 42 | for temp in self.files: 43 | temp.close() 44 | self.directory.cleanup() 45 | 46 | def __enter__(self) -> "TempDirectory": 47 | return self 48 | 49 | def __exit__(self, *args: Any, **kwargs: Any) -> None: 50 | self.cleanup() 51 | 52 | 53 | NestedFileInfo = TypedDict("NestedFileInfo", { 54 | "file": NamedIO, 55 | "path": str, 56 | "name": str 57 | }) 58 | 59 | 60 | NestedDirectoryInfo = TypedDict("NestedDirectoryInfo", { 61 | "path": str, 62 | "name": str, 63 | "object": TempDirectory 64 | }) 65 | 66 | 67 | NestedDirectoryTempInfo = TypedDict("NestedDirectoryTempInfo", { 68 | "path": str, 69 | "object": TempDirectory 70 | }) 71 | 72 | 73 | NestedDirectoryDict = TypedDict("NestedDirectoryDict", { 74 | "temp_directory": NestedDirectoryTempInfo, 75 | "nested_temp_directory": NestedDirectoryInfo, 76 | "temp_input_one": NestedFileInfo, 77 | "temp_input_two": NestedFileInfo, 78 | "nested_temp_input": NestedFileInfo 79 | }) 80 | 81 | 82 | def cleanup(value: Union[TempDirectory, NamedIO, str]) -> None: 83 | if isinstance(value, TempDirectory): 84 | value.cleanup() 85 | else: 86 | raise ValueError(f"Cannot call cleanup on {type(value)}") 87 | 88 | 89 | def cleanup_nested_directory(value: NestedDirectoryDict) -> None: 90 | value["temp_directory"]["object"].cleanup() 91 | 92 | 93 | def create_temp_nested_directory_with_files( 94 | suffixes: List[str] = ["", "", ""] 95 | ) -> NestedDirectoryDict: 96 | # temp_directory/ 97 | # temp_input_one 98 | # temp_input_two 99 | # nested_temp_directory/ 100 | # nested_temp_input 101 | 102 | directory = TempDirectory() 103 | new_file_1 = directory.add_file(b"FOO", suffixes[0]) 104 | new_file_2 = directory.add_file(b"BAR", suffixes[1]) 105 | 106 | nested_directory = directory.add_dir() 107 | nested_file = nested_directory.add_file(b"FOOBAR", suffixes[2]) 108 | 109 | return { 110 | "temp_directory": { 111 | "path": directory.name, 112 | "object": directory 113 | }, 114 | "nested_temp_directory": { 115 | "path": nested_directory.name, 116 | "name": os.path.basename(nested_directory.name), 117 | "object": nested_directory 118 | }, 119 | "temp_input_one": { 120 | "file": new_file_1, 121 | "path": new_file_1.name, 122 | "name": os.path.basename(new_file_1.name) 123 | }, 124 | "temp_input_two": { 125 | "file": new_file_2, 126 | "path": new_file_2.name, 127 | "name": os.path.basename(new_file_2.name) 128 | }, 129 | "nested_temp_input": { 130 | "file": nested_file, 131 | "path": nested_file.name, 132 | "name": os.path.basename(nested_file.name) 133 | } 134 | } 135 | 136 | 137 | class FileSpy(io.BytesIO): 138 | 139 | def __init__(self) -> None: 140 | self.chunks: List[bytes] = [] 141 | self.index = 0 142 | self.name = "" 143 | 144 | def write(self, chunk: Buffer) -> int: 145 | raw = bytes(chunk) 146 | rawlen = len(raw) 147 | self.chunks.append(raw) 148 | self.index += rawlen 149 | return rawlen 150 | 151 | def seek(self, index: int, whence: int = 0) -> int: 152 | if whence != 0: 153 | raise ValueError("FileSpy can only seek absolutely.") 154 | self.index = index 155 | return self.index 156 | 157 | def assert_written(self, assertion: bytes) -> None: 158 | assert b"".join(self.chunks) == assertion 159 | 160 | def assert_number_of_chunks(self, n: int) -> None: 161 | assert n == len(self.chunks) 162 | -------------------------------------------------------------------------------- 
/tests/service_test_case.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import logging 4 | import unittest 5 | import socket 6 | from threading import Thread, Event 7 | from wsgiref.headers import Headers 8 | from wsgiref.simple_server import make_server 9 | from wsgiref.util import request_uri 10 | from urllib.parse import urlparse 11 | 12 | from typing import Any, Callable, cast, Dict, Generator, Iterable, List 13 | from typing import Optional, Tuple, TYPE_CHECKING 14 | 15 | if TYPE_CHECKING: 16 | # The "type: ignore" on the next line is needed for Python 3.9 and 3.10 support 17 | from wsgiref.types import StartResponse # type: ignore[import-not-found, unused-ignore] 18 | 19 | Environ = Dict[str, Any] 20 | 21 | Handler = Callable[ 22 | [ 23 | Environ, 24 | StartResponse 25 | ], 26 | Iterable[bytes] 27 | ] 28 | 29 | HandlerIdentifier = Tuple[str, str] 30 | 31 | 32 | def get_port() -> int: 33 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 34 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 35 | s.bind(("localhost", 0)) 36 | 37 | port = cast(int, s.getsockname()[1]) 38 | 39 | s.close() 40 | return port 41 | 42 | 43 | class ServiceRequest(object): 44 | # eventually this can contain headers, body, etc. as necessary for comparison 45 | 46 | def __init__( 47 | self, headers: Dict[str, str], method: str, path: str, body: Optional[bytes]) -> None: 48 | self.headers = Headers([(key.replace("_", "-"), value) for key, value in headers.items()]) 49 | self.method = method 50 | self.path = path 51 | self.body = body 52 | 53 | def assert_header_equals(self, header_key: str, header_value: str) -> None: 54 | assert header_key in self.headers, \ 55 | f"Expected header {header_key} not in request headers." 56 | actual_value = self.headers[header_key] 57 | assert actual_value == header_value, \ 58 | f"Request header {header_key} unexpectedly set to " \ 59 | f"`{actual_value}` instead of `{header_value}`." 
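    # e.g. request.assert_header_equals("Content-Type", "application/json")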
60 | 61 | def assert_body_equals(self, body: bytes) -> None: 62 | assert body == self.body, f"Body unexpectedly equals {self.body!r} instead of {body!r}" 63 | 64 | 65 | class Service(object): 66 | 67 | fetches: Dict[Tuple[str, str], List[ServiceRequest]] 68 | server_started: Optional[Event] 69 | stop_server: Optional[Event] 70 | thread: Optional[Thread] 71 | 72 | def __init__(self) -> None: 73 | self.port: int = get_port() 74 | self.handlers: "Dict[HandlerIdentifier, Handler]" = {} 75 | self.thread = None 76 | self.fetches = {} 77 | self.server_started = None 78 | self.stop_server = None 79 | 80 | def url(self, path: str) -> str: 81 | return f"http://localhost:{self.port}{path}" 82 | 83 | def add_handler(self, method: str, path: str, callback: "Handler") -> None: 84 | identifier: "HandlerIdentifier" = (method, path) 85 | self.handlers[identifier] = callback 86 | 87 | def handler(self, environ: "Environ", start_response: "StartResponse") -> Iterable[bytes]: 88 | uri = request_uri(environ, include_query=True) 89 | path = urlparse(uri).path 90 | method = environ["REQUEST_METHOD"] 91 | body: Optional[bytes] = None 92 | if method in ("POST", "PUT"): 93 | content_length = int(environ.get("CONTENT_LENGTH", 0) or "0") 94 | if content_length > 0: 95 | body = environ["wsgi.input"].read(content_length) 96 | if body is None: 97 | body = b"" 98 | environ["wsgi.input"] = io.BytesIO(body) 99 | else: 100 | logging.warning(f"Unable to determine content length for request {method} {path}") 101 | 102 | headers = { 103 | key: value for key, value in environ.copy().items() 104 | if key.startswith("HTTP_") or key == "CONTENT_TYPE" 105 | } 106 | request = ServiceRequest(headers=headers, method=method, path=path, body=body) 107 | 108 | logging.info(f"Received {method} request for localhost:{self.port}{path}.") 109 | 110 | identifier = (method, path) 111 | if identifier not in self.handlers: 112 | logging.warning( 113 | f"No handler registered for {method} " 114 | f"localhost:{self.port}{path}") 115 | start_response("404 Not Found", [("Content-type", "text/plain")]) 116 | return [f"No handler registered for {identifier}".encode("utf8")] 117 | 118 | environ["REQUEST_PATH"] = path 119 | self.fetches.setdefault(identifier, []) 120 | self.fetches[identifier].append(request) 121 | return self.handlers[identifier](environ, start_response) 122 | 123 | def start(self) -> None: 124 | if self.server_started is not None or self.stop_server is not None: 125 | raise Exception(f"Service already started on port {self.port}") 126 | 127 | self.server_started = Event() 128 | self.stop_server = Event() 129 | 130 | # work around mypy failing to infer that these variables can't be None 131 | server_started = self.server_started 132 | stop_server = self.stop_server 133 | 134 | self.thread = Thread(target=lambda: self.loop(server_started, stop_server)) 135 | self.thread.start() 136 | 137 | logging.info(f"Starting server on port {self.port}...") 138 | 139 | server_started.wait() 140 | 141 | logging.info(f"Server on port {self.port} ready for requests.") 142 | 143 | def stop(self) -> None: 144 | if self.server_started is not None and self.stop_server is not None \ 145 | and self.thread is not None: 146 | self.stop_server.set() 147 | self.thread.join() 148 | self.server_started = None 149 | self.stop_server = None 150 | self.thread = None 151 | 152 | def loop(self, server_started: Event, stop_server: Event) -> None: 153 | with make_server("localhost", self.port, self.handler) as httpd: 154 | httpd.timeout = 0.01 155 | 156 | 
server_started.set() 157 | while not stop_server.is_set(): 158 | httpd.handle_request() 159 | 160 | def assert_requested( 161 | self, method: str, path: str, 162 | headers: Optional[Dict[str, str]] = None) -> ServiceRequest: 163 | identifier = (method, path) 164 | assert identifier in self.fetches, f"Could not find request matching {method} {path}" 165 | request = self.fetches[identifier][0] 166 | if headers is not None: 167 | for expected_header, expected_value in headers.items(): 168 | request.assert_header_equals(expected_header, expected_value) 169 | return request 170 | 171 | def get_all_requests(self, method: str, path: str) -> List[ServiceRequest]: 172 | identifier = (method, path) 173 | return self.fetches.get(identifier, []) 174 | 175 | def assert_not_requested(self, method: str, path: str) -> None: 176 | identifier = (method, path) 177 | assert identifier not in self.fetches, f"Unexpected request found for {method} {path}" 178 | 179 | def assert_requested_n_times( 180 | self, method: str, path: str, n: int) -> List[ServiceRequest]: 181 | requests = self.get_all_requests(method, path) 182 | assert len(requests) == n, \ 183 | f"Expected request count for {method} {path} ({n}) did not match " \ 184 | f"actual count: {len(requests)}" 185 | return requests 186 | 187 | 188 | class ServiceTestCase(unittest.TestCase): 189 | 190 | def setUp(self) -> None: 191 | super().setUp() 192 | self.services: List[Service] = [] 193 | 194 | def tearDown(self) -> None: 195 | super().tearDown() 196 | self.stop_services() 197 | 198 | def add_service(self) -> Service: 199 | service = Service() 200 | self.services.append(service) 201 | return service 202 | 203 | def start_services(self) -> None: 204 | for service in self.services: 205 | service.start() 206 | 207 | def stop_services(self) -> None: 208 | for service in self.services: 209 | service.stop() 210 | 211 | @contextlib.contextmanager 212 | def run_services(self) -> Generator[None, None, None]: 213 | self.start_services() 214 | try: 215 | yield 216 | finally: 217 | self.stop_services() 218 | -------------------------------------------------------------------------------- /tests/storage_test_case.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | from storage.storage import get_storage, InvalidStorageUri 3 | 4 | from typing import Dict, Optional, Sequence 5 | 6 | 7 | class StorageTestCase(TestCase): 8 | 9 | def _generate_storage_uri( 10 | self, object_path: str, parameters: Optional[Dict[str, str]] = None) -> str: 11 | raise NotImplementedError(f"_generate_storage_uri is not implemented on {self.__class__}") 12 | 13 | def assert_rejects_multiple_query_values( 14 | self, object_path: str, query_arg: str, 15 | values: Sequence[str] = ["a", "b"]) -> None: 16 | base_uri = self._generate_storage_uri(object_path) 17 | query_args = [] 18 | list_values = list(values) 19 | for value in list_values: 20 | query_args.append(f"{query_arg}={value}") 21 | 22 | separator = "&" if "?" in base_uri else "?" 
23 | uri = f"{base_uri}{separator}{'&'.join(query_args)}" 24 | 25 | with self.assertRaises(InvalidStorageUri): 26 | get_storage(uri) 27 | -------------------------------------------------------------------------------- /tests/swift_service_test_case.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import json 3 | import os 4 | import tempfile 5 | from unittest import mock 6 | from urllib.parse import parse_qsl 7 | 8 | from tests.helpers import NamedIO 9 | from tests.service_test_case import ServiceTestCase 10 | 11 | from typing import Any, cast, Dict, Generator, List, Optional, TYPE_CHECKING, Union 12 | 13 | if TYPE_CHECKING: 14 | from tests.service_test_case import Environ 15 | # The "type: ignore" on the next line is needed for Python 3.9 and 3.10 support 16 | from wsgiref.types import StartResponse  # type: ignore[import-not-found, unused-ignore] 17 | 18 | 19 | def strip_slashes(path: str) -> str: 20 | while path.endswith(os.path.sep): 21 | path = path[:-1] 22 | while path.startswith(os.path.sep): 23 | path = path[1:] 24 | return path 25 | 26 | 27 | class SwiftServiceTestCase(ServiceTestCase): 28 | 29 | def setUp(self) -> None: 30 | super().setUp() 31 | self.tmp_dir = tempfile.TemporaryDirectory() 32 | self.tmp_files: List[NamedIO] = [] 33 | 34 | self.remaining_container_failures: List[str] = [] 35 | self.remaining_file_failures: List[str] = [] 36 | self.remaining_object_put_failures: List[str] = [] 37 | self.remaining_file_delete_failures: List[str] = [] 38 | 39 | self.container_contents: Union[Dict[str, bytes], Any] = {} 40 | self.directory_contents: Dict[str, bytes] = {} 41 | self.object_contents: Dict[str, bytes] = {} 42 | 43 | # this is fragile, but they import and use the function directly, so we mock in-module 44 | self.mock_swiftclient_sleep_patch = mock.patch("swiftclient.client.sleep") 45 | self.mock_swiftclient_sleep = self.mock_swiftclient_sleep_patch.start() 46 | self.mock_swiftclient_sleep.side_effect = lambda x: None 47 | 48 | self.swift_service = self.add_service() 49 | 50 | def tearDown(self) -> None: 51 | super().tearDown() 52 | for fp in self.tmp_files: 53 | fp.close() 54 | self.tmp_dir.cleanup() 55 | self.mock_swiftclient_sleep_patch.stop() 56 | 57 | def _add_file_to_directory(self, filepath: str, file_content: bytes) -> None: 58 | if type(file_content) is not bytes: 59 | raise Exception("Object file contents must be bytes") 60 | 61 | self.container_contents[filepath] = file_content 62 | container_path = "/v2.0/1234/CONTAINER" 63 | self.swift_service.add_handler("GET", container_path, self.swift_container_handler) 64 | self.add_container_object(container_path, filepath, file_content) 65 | 66 | def _add_tmp_file_to_dir( 67 | self, directory: str, 68 | file_content: bytes, 69 | suffix: Optional[str] = None) -> NamedIO: 70 | if type(file_content) is not bytes: 71 | raise Exception("Object file contents must be bytes") 72 | 73 | os.makedirs(directory, exist_ok=True) 74 | 75 | tmp_file = cast( 76 | NamedIO, tempfile.NamedTemporaryFile(dir=directory, suffix=suffix)) 77 | tmp_file.write(file_content) 78 | tmp_file.flush() 79 | 80 | self.tmp_files.append(tmp_file) 81 | return tmp_file 82 | 83 | def add_file_error(self, error: str) -> None: 84 | self.remaining_file_failures.append(error) 85 | 86 | def add_container_object( 87 | self, container_path: str, object_path: str, content: bytes) -> None: 88 | if type(content) is not bytes: 89 | raise Exception("Object file contents must be bytes") 90 | 91 |
self.object_contents[object_path] = content 92 | 93 | get_path = f"{container_path}{object_path}" 94 | self.swift_service.add_handler("GET", get_path, self.object_handler) 95 | 96 | @contextlib.contextmanager 97 | def expect_put_object( 98 | self, 99 | container_path: str, 100 | object_path: str, 101 | content: bytes) -> Generator[None, None, None]: 102 | put_path = f"{container_path}{object_path}" 103 | self.swift_service.add_handler("PUT", put_path, self.object_put_handler) 104 | yield 105 | self.assertEqual(content, self.container_contents[object_path]) 106 | self.swift_service.assert_requested("PUT", put_path) 107 | 108 | @contextlib.contextmanager 109 | def expect_delete_object( 110 | self, container_path: str, object_path: str) -> Generator[None, None, None]: 111 | self.container_contents[object_path] = b"UNDELETED!" 112 | delete_path = f"{container_path}{object_path}" 113 | self.swift_service.add_handler("DELETE", delete_path, self.object_delete_handler) 114 | yield 115 | self.assertNotIn( 116 | object_path, self.container_contents, 117 | f"File {object_path} was not deleted as expected.") 118 | 119 | @contextlib.contextmanager 120 | def expect_directory(self, filepath: str) -> Generator[None, None, None]: 121 | for root, _, files in os.walk(self.tmp_dir.name): 122 | dirpath = strip_slashes(root.split(self.tmp_dir.name)[1]) 123 | for basepath in files: 124 | relative_path = os.path.join(dirpath, basepath) 125 | remote_path = "/".join([filepath, relative_path]) 126 | 127 | put_path = f"/v2.0/1234/CONTAINER{remote_path}" 128 | self.swift_service.add_handler("PUT", put_path, self.object_put_handler) 129 | yield 130 | 131 | self.assert_container_contents_equal(filepath) 132 | 133 | @contextlib.contextmanager 134 | def expect_delete_directory(self, filepath: str) -> Generator[None, None, None]: 135 | expected_delete_paths = [] 136 | for name in self.container_contents: 137 | delete_path = f"/v2.0/1234/CONTAINER/{strip_slashes(name)}" 138 | expected_delete_paths.append(delete_path) 139 | self.swift_service.add_handler("DELETE", delete_path, self.object_delete_handler) 140 | 141 | yield 142 | 143 | for delete_path in expected_delete_paths: 144 | self.swift_service.assert_requested("DELETE", delete_path) 145 | 146 | def object_handler(self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 147 | path = environ["REQUEST_PATH"].split("CONTAINER")[1] 148 | 149 | if len(self.remaining_file_failures) > 0: 150 | failure = self.remaining_file_failures.pop(0) 151 | 152 | start_response(failure, [("Content-type", "text/plain")]) 153 | return [b"Internal Server Error"] 154 | 155 | if path not in self.object_contents: 156 | start_response("404 NOT FOUND", [("Content-Type", "text/plain")]) 157 | return [f"Object file {path} not in file contents dictionary".encode("utf8")] 158 | 159 | start_response("200 OK", [("Content-type", "video/mp4")]) 160 | return [self.object_contents[path]] 161 | 162 | def object_put_handler( 163 | self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 164 | path = environ["REQUEST_PATH"].split("CONTAINER")[1] 165 | 166 | contents = b"" 167 | while True: 168 | header = b"" 169 | while not header.endswith(b"\r\n"): 170 | header += environ["wsgi.input"].read(1) 171 | 172 | body_size = int(header.strip()) 173 | contents += environ["wsgi.input"].read(body_size) 174 | environ["wsgi.input"].read(2) # read trailing "\r\n" 175 | 176 | if body_size == 0: 177 | break 178 | 179 | self.container_contents[path] = contents 180 | 181 | if 
len(self.remaining_object_put_failures) > 0: 182 | failure = self.remaining_object_put_failures.pop(0) 183 | start_response(failure, [("Content-type", "text/plain")]) 184 | return [b"Internal server error"] 185 | 186 | start_response("201 Created", [("Content-type", "text/plain")]) 187 | return [b""] 188 | 189 | def object_delete_handler( 190 | self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 191 | path = environ["REQUEST_PATH"].split("CONTAINER")[1] 192 | 193 | if len(self.remaining_file_delete_failures) > 0: 194 | failure = self.remaining_file_delete_failures.pop(0) 195 | start_response(failure, [("Content-type", "text/plain")]) 196 | return [b"Internal server error"] 197 | 198 | del self.container_contents[path] 199 | start_response("204 No Content", [("Content-type", "text/plain")]) 200 | return [b""] 201 | 202 | def swift_container_handler( 203 | self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 204 | if len(self.remaining_container_failures) > 0: 205 | failure = self.remaining_container_failures.pop(0) 206 | 207 | start_response(failure, [("Content-type", "text/plain")]) 208 | return [b"Internal server error"] 209 | 210 | parsed_args = dict(parse_qsl(environ["QUERY_STRING"])) 211 | 212 | if "json" == parsed_args.get("format"): 213 | start_response("200 OK", [("Content-Type", "application/json")]) 214 | if len(self.container_contents) == 0: 215 | return [json.dumps([]).encode("utf8")] 216 | 217 | return [json.dumps([ 218 | {"name": v} for v in self.container_contents.keys() 219 | ]).encode("utf8")] 220 | 221 | start_response("200 OK", [("Content-Type", "text/plain")]) 222 | return ["\n".join(self.container_contents).encode("utf8")] 223 | 224 | def assert_container_contents_equal(self, object_path: str) -> None: 225 | written_files = {} 226 | expected_files = { 227 | strip_slashes(f.split(object_path)[1]): v 228 | for f, v in self.container_contents.items() if not f.endswith("/") 229 | } 230 | 231 | for root, dirs, files in os.walk(self.tmp_dir.name): 232 | dirpath = strip_slashes(root.split(self.tmp_dir.name)[1]) 233 | for basepath in files: 234 | fullpath = os.path.join(root, basepath) 235 | relpath = os.path.join(dirpath, basepath) 236 | with open(fullpath, "rb") as fp: 237 | written_files[relpath] = fp.read() 238 | 239 | self.assertCountEqual(written_files, expected_files) 240 | -------------------------------------------------------------------------------- /tests/test_cloudfiles_storage.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import io 3 | import json 4 | from unittest import mock 5 | from urllib.parse import parse_qsl, urlencode, urlparse 6 | 7 | from keystoneauth1.exceptions.http import Forbidden, Unauthorized 8 | from typing import Dict, Generator, List, Optional, TYPE_CHECKING 9 | 10 | from storage.storage import get_storage, InvalidStorageUri 11 | from tests.storage_test_case import StorageTestCase 12 | from tests.swift_service_test_case import SwiftServiceTestCase 13 | 14 | if TYPE_CHECKING: 15 | from tests.service_test_case import Environ 16 | # The "type: ignore" on the next line is needed for Python 3.9 and 3.10 support 17 | from wsgiref.types import StartResponse # type: ignore[import-not-found, unused-ignore] 18 | 19 | 20 | class TestCloudFilesStorageProvider(StorageTestCase, SwiftServiceTestCase): 21 | def setUp(self) -> None: 22 | super().setUp() 23 | 24 | self.download_url_key = {"download_url_key": "KEY"} 25 | self.keystone_credentials = { 26 | "username":
"USER", 27 | "key": "TOKEN" 28 | } 29 | 30 | self.identity_service = self.add_service() 31 | self.identity_service.add_handler("GET", "/v2.0", self.identity_handler) 32 | self.identity_service.add_handler("POST", "/v2.0/tokens", self.authentication_handler) 33 | 34 | self.alt_cloudfiles_service = self.add_service() 35 | self.internal_cloudfiles_service = self.add_service() 36 | 37 | self.mock_sleep_patch = mock.patch("time.sleep") 38 | self.mock_sleep = self.mock_sleep_patch.start() 39 | self.mock_sleep.side_effect = lambda x: None 40 | 41 | def tearDown(self) -> None: 42 | super().tearDown() 43 | self.mock_sleep_patch.stop() 44 | 45 | def _generate_storage_uri( 46 | self, object_path: str, parameters: Optional[Dict[str, str]] = None) -> str: 47 | base_uri = f"cloudfiles://USER:TOKEN@CONTAINER{object_path}" 48 | if parameters is not None: 49 | return f"{base_uri}?{urlencode(parameters)}" 50 | return base_uri 51 | 52 | def _has_valid_credentials(self, auth_data: Dict[str, str]) -> bool: 53 | if auth_data["username"] == self.keystone_credentials["username"] and \ 54 | auth_data["apiKey"] == self.keystone_credentials["key"]: 55 | return True 56 | else: 57 | return False 58 | 59 | def assert_requires_all_parameters(self, object_path: str) -> None: 60 | for auth_string in ["USER:@", ":TOKEN@", ""]: 61 | cloudfiles_uri = f"cloudfiles://{auth_string}CONTAINER{object_path}" 62 | 63 | with self.assertRaises(InvalidStorageUri): 64 | get_storage(cloudfiles_uri) 65 | 66 | with self.assertRaises(InvalidStorageUri): 67 | get_storage(f"cloudfiles://{object_path}") 68 | 69 | @contextlib.contextmanager 70 | def assert_raises_on_forbidden_access(self) -> Generator[None, None, None]: 71 | self.keystone_credentials["username"] = "nobody" 72 | with self.run_services(): 73 | with self.assertRaises(Forbidden): 74 | yield 75 | 76 | @contextlib.contextmanager 77 | def assert_raises_on_unauthorized_access(self) -> Generator[None, None, None]: 78 | self.keystone_credentials = {} 79 | with self.run_services(): 80 | with self.assertRaises(Unauthorized): 81 | yield 82 | 83 | @contextlib.contextmanager 84 | def use_local_identity_service(self) -> Generator[None, None, None]: 85 | with mock.patch( 86 | "storage.cloudfiles_storage.CloudFilesStorage.auth_endpoint", 87 | new_callable=mock.PropertyMock) as mock_endpoint: 88 | mock_endpoint.return_value = self.identity_service.url("/v2.0") 89 | yield 90 | 91 | @contextlib.contextmanager 92 | def expect_head_account_object(self, path: str) -> Generator[None, None, None]: 93 | self.swift_service.add_handler("HEAD", path, self.object_head_account_handler) 94 | yield 95 | self.swift_service.assert_requested("HEAD", path) 96 | 97 | def identity_handler(self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 98 | start_response("200 OK", [("Content-Type", "application/json")]) 99 | return [json.dumps({ 100 | "version": { 101 | "media-types": { 102 | "values": [ 103 | { 104 | "type": "application/vnd.openstack.identity+json;version=2.0", 105 | "base": "application/json" 106 | } 107 | ] 108 | }, 109 | "links": [ 110 | { 111 | "rel": "self", 112 | "href": self.identity_service.url("/v2.0") 113 | } 114 | ], 115 | "id": "v2.0", 116 | "status": "CURRENT" 117 | } 118 | }).encode("utf8")] 119 | 120 | def authentication_handler( 121 | self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 122 | body_size = int(environ.get("CONTENT_LENGTH", 0)) 123 | body = json.loads(environ["wsgi.input"].read(body_size)) 124 | 125 | # Forcing a 401 since swift service 
won't let us provide it 126 | if self.keystone_credentials == {}: 127 | start_response("401 Unauthorized", [("Content-type", "text/plain")]) 128 | return [b"Unauthorized keystone credentials."] 129 | if not self._has_valid_credentials(body["auth"]["RAX-KSKEY:apiKeyCredentials"]): 130 | start_response("403 Forbidden", [("Content-type", "text/plain")]) 131 | return [b"Invalid keystone credentials."] 132 | 133 | start_response("200 OK", [("Content-type", "application/json")]) 134 | return [json.dumps({ 135 | "access": { 136 | "serviceCatalog": [{ 137 | "endpoints": [ 138 | { 139 | "tenantId": "MOSSO-TENANT", 140 | "publicURL": self.swift_service.url("/v2.0/MOSSO-TENANT"), 141 | "internalURL": self.internal_cloudfiles_service.url( 142 | "/v2.0/MOSSO-TENANT"), 143 | "region": "DFW" 144 | }, 145 | { 146 | "tenantId": "MOSSO-TENANT", 147 | "publicURL": self.alt_cloudfiles_service.url("/v2.0/MOSSO-TENANT"), 148 | "internalURL": self.alt_cloudfiles_service.url("/v2.0/MOSSO-TENANT"), 149 | "region": "ORD" 150 | } 151 | ], 152 | "name": "cloudfiles", 153 | "type": "object-store" 154 | }], 155 | "user": { 156 | "RAX-AUTH:defaultRegion": "DFW", 157 | "roles": [{ 158 | "name": "object-store:default", 159 | "tenantId": "MOSSO-TENANT", 160 | "id": "ID" 161 | }], 162 | "name": "USER", 163 | "id": "IDENTIFIER" 164 | }, 165 | "token": { 166 | "expires": "2019-07-18T05:47:13.090Z", 167 | "RAX-AUTH:authenticatedBy": ["APIKEY"], 168 | "id": "KEY", 169 | "tenant": { 170 | "name": "MOSSO-TENANT", 171 | "id": "MOSSO-TENANT" 172 | } 173 | } 174 | } 175 | }).encode("utf8")] 176 | 177 | def object_head_account_handler( 178 | self, environ: "Environ", start_response: "StartResponse") -> List[bytes]: 179 | start_response("204 No Content", [ 180 | ("Content-type", "text/plain"), 181 | ("X-Account-Meta-Temp-Url-Key", "TEMPKEY") 182 | ]) 183 | return [b""] 184 | 185 | def test_cloudfiles_default_auth_endpoint_points_to_correct_host(self) -> None: 186 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 187 | storage_object = get_storage(cloudfiles_uri) 188 | 189 | self.assertEqual( 190 | "https://identity.api.rackspacecloud.com/v2.0", 191 | storage_object.auth_endpoint) # type: ignore 192 | 193 | def test_save_to_file_raises_exception_when_missing_required_parameters(self) -> None: 194 | self.assert_requires_all_parameters("/path/to/file.mp4") 195 | 196 | def test_save_to_file_raises_on_forbidden_credentials(self) -> None: 197 | self.add_container_object("/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4", b"FOOBAR") 198 | 199 | temp = io.BytesIO() 200 | 201 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 202 | storage_object = get_storage(cloudfiles_uri) 203 | 204 | with self.use_local_identity_service(): 205 | with self.assert_raises_on_forbidden_access(): 206 | storage_object.save_to_file(temp) 207 | 208 | with self.use_local_identity_service(): 209 | with self.assert_raises_on_unauthorized_access(): 210 | storage_object.save_to_file(temp) 211 | 212 | def test_save_to_file_writes_file_contents_to_file_object(self) -> None: 213 | self.add_container_object("/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4", b"FOOBAR") 214 | 215 | temp = io.BytesIO() 216 | 217 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 218 | storage_object = get_storage(cloudfiles_uri) 219 | 220 | with self.use_local_identity_service(): 221 | with self.run_services(): 222 | storage_object.save_to_file(temp) 223 | 224 | temp.seek(0) 225 |
self.assertEqual(b"FOOBAR", temp.read()) 226 | 227 | def test_save_to_file_uses_default_region_when_one_is_not_provided(self) -> None: 228 | self.add_container_object("/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4", b"FOOBAR") 229 | 230 | temp = io.BytesIO() 231 | 232 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 233 | storage_object = get_storage(cloudfiles_uri) 234 | 235 | with self.use_local_identity_service(): 236 | with self.run_services(): 237 | storage_object.save_to_file(temp) 238 | 239 | self.swift_service.assert_requested_n_times( 240 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 1) 241 | self.alt_cloudfiles_service.assert_requested_n_times( 242 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 0) 243 | 244 | def test_save_to_file_uses_provided_region_parameter(self) -> None: 245 | self.object_contents["/path/to/file.mp4"] = b"FOOBAR" 246 | 247 | get_path = "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4" 248 | self.alt_cloudfiles_service.add_handler("GET", get_path, self.object_handler) 249 | 250 | temp = io.BytesIO() 251 | 252 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", { 253 | "region": "ORD", 254 | "download_url_key": "KEY" 255 | }) 256 | storage_object = get_storage(cloudfiles_uri) 257 | 258 | with self.use_local_identity_service(): 259 | with self.run_services(): 260 | storage_object.save_to_file(temp) 261 | 262 | self.alt_cloudfiles_service.assert_requested_n_times( 263 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 1) 264 | self.swift_service.assert_requested_n_times( 265 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 0) 266 | 267 | def test_save_to_file_uses_default_endpoint_type_when_one_is_not_provided(self) -> None: 268 | self.add_container_object("/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4", b"FOOBAR") 269 | 270 | temp = io.BytesIO() 271 | 272 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 273 | storage_object = get_storage(cloudfiles_uri) 274 | 275 | with self.use_local_identity_service(): 276 | with self.run_services(): 277 | storage_object.save_to_file(temp) 278 | 279 | self.swift_service.assert_requested_n_times( 280 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 1) 281 | self.internal_cloudfiles_service.assert_requested_n_times( 282 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 0) 283 | 284 | def test_save_to_file_uses_provided_public_parameter(self) -> None: 285 | self.object_contents["/path/to/file.mp4"] = b"FOOBAR" 286 | 287 | get_path = "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4" 288 | self.internal_cloudfiles_service.add_handler("GET", get_path, self.object_handler) 289 | 290 | temp = io.BytesIO() 291 | 292 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", { 293 | "public": "false", 294 | "download_url_key": "KEY" 295 | }) 296 | storage_object = get_storage(cloudfiles_uri) 297 | 298 | with self.use_local_identity_service(): 299 | with self.run_services(): 300 | storage_object.save_to_file(temp) 301 | 302 | self.internal_cloudfiles_service.assert_requested_n_times( 303 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 1) 304 | self.swift_service.assert_requested_n_times( 305 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 0) 306 | 307 | def test_save_to_file_uses_provided_public_parameter_case_insensitive(self) -> None: 308 | self.object_contents["/path/to/file.mp4"] = b"FOOBAR" 309 | 310 | get_path = 
"/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4" 311 | self.internal_cloudfiles_service.add_handler("GET", get_path, self.object_handler) 312 | 313 | temp = io.BytesIO() 314 | 315 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", { 316 | "public": "False", 317 | "download_url_key": "KEY" 318 | }) 319 | storage_object = get_storage(cloudfiles_uri) 320 | 321 | with self.use_local_identity_service(): 322 | with self.run_services(): 323 | storage_object.save_to_file(temp) 324 | 325 | self.internal_cloudfiles_service.assert_requested_n_times( 326 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 1) 327 | self.swift_service.assert_requested_n_times( 328 | "GET", "/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4", 0) 329 | 330 | def test_load_from_file_puts_file_contents_at_object_endpoint(self) -> None: 331 | temp = io.BytesIO(b"FOOBAR") 332 | 333 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 334 | storage_object = get_storage(cloudfiles_uri) 335 | 336 | with self.use_local_identity_service(): 337 | with self.expect_put_object( 338 | "/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4", b"FOOBAR"): 339 | with self.run_services(): 340 | storage_object.load_from_file(temp) 341 | 342 | def test_delete_makes_delete_request_against_swift_service(self) -> None: 343 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 344 | storage_object = get_storage(cloudfiles_uri) 345 | 346 | with self.use_local_identity_service(): 347 | with self.expect_delete_object("/v2.0/MOSSO-TENANT/CONTAINER", "/path/to/file.mp4"): 348 | with self.run_services(): 349 | storage_object.delete() 350 | 351 | @mock.patch("time.time") 352 | def test_get_download_url_returns_signed_url(self, mock_time: mock.Mock) -> None: 353 | mock_time.return_value = 9000 354 | 355 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 356 | storage_object = get_storage(cloudfiles_uri) 357 | 358 | with self.use_local_identity_service(): 359 | with self.run_services(): 360 | url = storage_object.get_download_url() 361 | 362 | parsed = urlparse(url) 363 | expected = urlparse(self.swift_service.url("/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4")) 364 | 365 | self.assertEqual(parsed.path, expected.path) 366 | self.assertEqual(parsed.netloc, expected.netloc) 367 | 368 | query = dict(parse_qsl(parsed.query)) 369 | 370 | self.assertEqual("9060", query["temp_url_expires"]) 371 | self.assertTrue("temp_url_sig" in query) 372 | 373 | @mock.patch("time.time") 374 | def test_get_download_url_uses_temp_url_key_when_download_url_key_not_present( 375 | self, mock_time: mock.Mock) -> None: 376 | mock_time.return_value = 9000 377 | 378 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4") 379 | storage_object = get_storage(cloudfiles_uri) 380 | 381 | with self.use_local_identity_service(): 382 | with self.run_services(): 383 | with self.expect_head_account_object("/v2.0/MOSSO-TENANT"): 384 | url = storage_object.get_download_url() 385 | 386 | parsed = urlparse(url) 387 | expected = urlparse(self.swift_service.url("/v2.0/MOSSO-TENANT/CONTAINER/path/to/file.mp4")) 388 | 389 | self.assertEqual(parsed.path, expected.path) 390 | self.assertEqual(parsed.netloc, expected.netloc) 391 | 392 | def test_cloudfiles_rejects_multiple_query_values_for_public_setting(self) -> None: 393 | self.assert_rejects_multiple_query_values( 394 | "object.mp4", "public", values=["public", "private"]) 395 | 396 | def 
test_cloudfiles_rejects_multiple_query_values_for_region_setting(self) -> None: 397 | self.assert_rejects_multiple_query_values("object.mp4", "region", values=["DFW", "ORD"]) 398 | 399 | def test_cloudfiles_rejects_multiple_query_values_for_download_url_key_setting(self) -> None: 400 | self.assert_rejects_multiple_query_values("object.mp4", "download_url_key") 401 | 402 | def test_get_sanitized_uri_returns_storage_uri_without_username_and_password(self) -> None: 403 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4") 404 | storage_object = get_storage(cloudfiles_uri) 405 | 406 | with self.use_local_identity_service(): 407 | with self.run_services(): 408 | sanitized_uri = storage_object.get_sanitized_uri() 409 | 410 | self.assertEqual("cloudfiles://container/path/to/file.mp4", sanitized_uri) 411 | 412 | def test_get_sanitized_uri_returns_storage_uri_without_download_url_key(self) -> None: 413 | cloudfiles_uri = self._generate_storage_uri("/path/to/file.mp4", self.download_url_key) 414 | storage_object = get_storage(cloudfiles_uri) 415 | 416 | with self.use_local_identity_service(): 417 | with self.run_services(): 418 | sanitized_uri = storage_object.get_sanitized_uri() 419 | 420 | self.assertEqual("cloudfiles://container/path/to/file.mp4", sanitized_uri) 421 | -------------------------------------------------------------------------------- /tests/test_google_storage.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import datetime 3 | import json 4 | from unittest import TestCase, mock 5 | 6 | from google.cloud.exceptions import NotFound 7 | from storage.storage import get_storage, NotFoundError, InvalidStorageUri 8 | 9 | 10 | class TestGoogleStorage(TestCase): 11 | def setUp(self) -> None: 12 | super(TestGoogleStorage, self).setUp() 13 | 14 | self.credentials = base64.urlsafe_b64encode(json.dumps({ 15 | "SOME": "CREDENTIALS", 16 | "project_id": "PROJECT-ID" 17 | }).encode("utf8")).decode("utf8") 18 | 19 | service_account_patcher = mock.patch( 20 | "google.oauth2.service_account.Credentials.from_service_account_info") 21 | self.mock_from_service_account_info = service_account_patcher.start() 22 | self.addCleanup(service_account_patcher.stop) 23 | 24 | client_patcher = mock.patch("google.cloud.storage.client.Client") 25 | self.mock_client_class = client_patcher.start() 26 | self.addCleanup(client_patcher.stop) 27 | 28 | self.mock_credentials = self.mock_from_service_account_info.return_value 29 | self.mock_client = self.mock_client_class.return_value 30 | self.mock_bucket = self.mock_client.get_bucket.return_value 31 | self.mock_blob = self.mock_bucket.blob.return_value 32 | 33 | def assert_gets_bucket_with_credentials(self) -> None: 34 | self.mock_from_service_account_info.assert_called_once_with( 35 | {"SOME": "CREDENTIALS", "project_id": "PROJECT-ID"}) 36 | self.mock_client_class.assert_called_once_with( 37 | project="PROJECT-ID", credentials=self.mock_credentials) 38 | self.mock_client.get_bucket.assert_called_once_with("bucketname") 39 | 40 | def test_requires_username_in_uri(self) -> None: 41 | with self.assertRaises(InvalidStorageUri): 42 | get_storage("gs://bucket/path") 43 | 44 | def test_requires_hostname_in_uri(self) -> None: 45 | with self.assertRaises(InvalidStorageUri): 46 | get_storage("gs://username@/path") 47 | 48 | def test_save_to_filename_downloads_blob_to_file_location(self) -> None: 49 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 50 | 51 | 
storage.save_to_filename("SOME-FILE") 52 | 53 | self.assert_gets_bucket_with_credentials() 54 | 55 | self.mock_bucket.blob.assert_called_once_with("path/filename") 56 | self.mock_blob.download_to_filename.assert_called_once_with("SOME-FILE") 57 | 58 | def test_save_to_filename_raises_when_file_does_not_exist(self) -> None: 59 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 60 | self.mock_blob.download_to_filename.side_effect = NotFound("File Not Found") 61 | 62 | with self.assertRaises(NotFoundError): 63 | storage.save_to_filename("SOME-FILE") 64 | 65 | def test_save_to_file_downloads_blob_to_file_object(self) -> None: 66 | mock_file = mock.Mock() 67 | 68 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 69 | 70 | storage.save_to_file(mock_file) 71 | 72 | self.assert_gets_bucket_with_credentials() 73 | 74 | self.mock_bucket.blob.assert_called_once_with("path/filename") 75 | self.mock_blob.download_to_file.assert_called_once_with(mock_file) 76 | 77 | def test_save_to_file_raises_when_filename_does_not_exist(self) -> None: 78 | mock_file = mock.Mock() 79 | 80 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 81 | self.mock_blob.download_to_file.side_effect = NotFound("File Not Found") 82 | 83 | with self.assertRaises(NotFoundError): 84 | storage.save_to_file(mock_file) 85 | 86 | def test_load_from_filename_uploads_blob_from_file_location(self) -> None: 87 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 88 | 89 | storage.load_from_filename("SOME-FILE") 90 | 91 | self.assert_gets_bucket_with_credentials() 92 | 93 | self.mock_bucket.blob.assert_called_once_with("path/filename") 94 | self.mock_blob.upload_from_filename.assert_called_once_with("SOME-FILE") 95 | 96 | def test_load_from_file_uploads_blob_from_file_object(self) -> None: 97 | mock_file = mock.Mock() 98 | 99 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 100 | 101 | storage.load_from_file(mock_file) 102 | 103 | self.assert_gets_bucket_with_credentials() 104 | 105 | self.mock_bucket.blob.assert_called_once_with("path/filename") 106 | self.mock_blob.upload_from_file.assert_called_once_with(mock_file, content_type=None) 107 | 108 | def test_load_from_file_guesses_content_type_based_on_filename(self) -> None: 109 | mock_file = mock.Mock() 110 | 111 | storage = get_storage("gs://{}@bucketname/path/whatever.html".format(self.credentials)) 112 | 113 | storage.load_from_file(mock_file) 114 | 115 | self.mock_blob.upload_from_file.assert_called_once_with(mock_file, content_type="text/html") 116 | 117 | def test_delete_deletes_blob(self) -> None: 118 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 119 | 120 | storage.delete() 121 | 122 | self.assert_gets_bucket_with_credentials() 123 | 124 | self.mock_bucket.blob.assert_called_once_with("path/filename") 125 | self.mock_blob.delete.assert_called_once_with() 126 | 127 | def test_delete_raises_when_file_does_not_exist(self) -> None: 128 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 129 | self.mock_blob.delete.side_effect = NotFound("File Not Found") 130 | 131 | with self.assertRaises(NotFoundError): 132 | storage.delete() 133 | 134 | def test_get_download_url_returns_signed_url_with_default_expiration(self) -> None: 135 | mock_signed_url = self.mock_blob.generate_signed_url.return_value 136 | 137 | storage = 
get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 138 | 139 | result = storage.get_download_url() 140 | 141 | self.assertEqual(mock_signed_url, result) 142 | 143 | self.assert_gets_bucket_with_credentials() 144 | 145 | self.mock_bucket.blob.assert_called_once_with("path/filename") 146 | self.mock_blob.generate_signed_url.assert_called_once_with( 147 | expiration=datetime.timedelta(seconds=60), 148 | response_disposition="attachment") 149 | 150 | def test_get_download_url_returns_signed_url_with_provided_expiration(self) -> None: 151 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 152 | 153 | storage.get_download_url(1000) 154 | 155 | self.mock_blob.generate_signed_url.assert_called_once_with( 156 | expiration=datetime.timedelta(seconds=1000), 157 | response_disposition="attachment") 158 | 159 | def test_get_download_url_does_not_use_key_when_provided(self) -> None: 160 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 161 | 162 | storage.get_download_url(key="KEY") 163 | 164 | self.mock_blob.generate_signed_url.assert_called_once_with( 165 | expiration=datetime.timedelta(seconds=60), 166 | response_disposition="attachment") 167 | 168 | def test_get_sanitized_uri_returns_storage_uri_without_username_and_password(self) -> None: 169 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 170 | 171 | sanitized_uri = storage.get_sanitized_uri() 172 | 173 | self.assertEqual("gs://bucketname/path/filename", sanitized_uri) 174 | 175 | def _mock_blob(self, name: str) -> mock.Mock: 176 | blob = mock.Mock() 177 | blob.name = name 178 | return blob 179 | 180 | @mock.patch("os.path.exists") 181 | @mock.patch("os.makedirs") 182 | def test_save_to_directory_downloads_blobs_matching_prefix_to_directory_location( 183 | self, mock_makedirs: mock.Mock, mock_exists: mock.Mock) -> None: 184 | mock_exists.side_effect = [True, False, False] 185 | 186 | mock_listed_blobs = [ 187 | self._mock_blob("path/filename/file1"), 188 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 189 | self._mock_blob("path/filename/subdir3/path/filename/file3") 190 | ] 191 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 192 | 193 | mock_unversioned_blobs = [ 194 | self._mock_blob("path/filename/file1"), 195 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 196 | self._mock_blob("path/filename/subdir3/path/filename/file3") 197 | ] 198 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 199 | 200 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 201 | 202 | storage.save_to_directory("directory-name") 203 | 204 | self.assert_gets_bucket_with_credentials() 205 | 206 | self.mock_bucket.list_blobs.assert_called_once_with(prefix="path/filename/") 207 | 208 | self.assertEqual(3, self.mock_bucket.blob.call_count) 209 | self.mock_bucket.blob.assert_has_calls([ 210 | mock.call("path/filename/file1"), 211 | mock.call("path/filename/subdir1/subdir2/file2"), 212 | mock.call("path/filename/subdir3/path/filename/file3") 213 | ]) 214 | 215 | mock_unversioned_blobs[0].download_to_filename.assert_called_once_with( 216 | "directory-name/file1") 217 | mock_unversioned_blobs[1].download_to_filename.assert_called_once_with( 218 | "directory-name/subdir1/subdir2/file2") 219 | mock_unversioned_blobs[2].download_to_filename.assert_called_once_with( 220 | "directory-name/subdir3/path/filename/file3") 221 | 222 | self.assertEqual( 223 | [ 224 | 
mock.call("directory-name"), 225 | mock.call("directory-name/subdir1/subdir2"), 226 | mock.call("directory-name/subdir3/path/filename") 227 | ], 228 | mock_exists.call_args_list) 229 | 230 | self.assertEqual( 231 | [ 232 | mock.call("directory-name/subdir1/subdir2"), 233 | mock.call("directory-name/subdir3/path/filename") 234 | ], 235 | mock_makedirs.call_args_list) 236 | 237 | @mock.patch("os.path.exists") 238 | @mock.patch("os.makedirs") 239 | def test_save_to_directory_ignores_placeholder_directory_entries_when_present( 240 | self, mock_makedirs: mock.Mock, mock_exists: mock.Mock) -> None: 241 | mock_exists.side_effect = [False, True, False, False] 242 | 243 | mock_listed_blobs = [ 244 | self._mock_blob("path/filename/dir/"), 245 | self._mock_blob("path/filename/dir/file.txt"), 246 | self._mock_blob("path/filename/dir/emptysubdir/"), 247 | self._mock_blob("path/filename/emptydir/") 248 | ] 249 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 250 | 251 | mock_unversioned_blobs = [ 252 | self._mock_blob("path/filename/dir/file.txt") 253 | ] 254 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 255 | 256 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 257 | 258 | storage.save_to_directory("directory-name") 259 | 260 | mock_unversioned_blobs[0].download_to_filename.assert_called_once_with( 261 | "directory-name/dir/file.txt") 262 | 263 | self.assertEqual( 264 | [ 265 | mock.call("directory-name/dir"), 266 | mock.call("directory-name/dir"), 267 | mock.call("directory-name/dir/emptysubdir"), 268 | mock.call("directory-name/emptydir") 269 | ], 270 | mock_exists.call_args_list) 271 | 272 | self.assertEqual( 273 | [ 274 | mock.call("directory-name/dir"), 275 | mock.call("directory-name/dir/emptysubdir"), 276 | mock.call("directory-name/emptydir") 277 | ], 278 | mock_makedirs.call_args_list) 279 | 280 | @mock.patch("os.path.exists") 281 | @mock.patch("random.uniform") 282 | @mock.patch("time.sleep") 283 | def test_save_to_directory_retries_file_download_on_error( 284 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock, mock_exists: mock.Mock) -> None: 285 | mock_exists.return_value = True 286 | 287 | mock_listed_blobs = [ 288 | self._mock_blob("path/filename/file1"), 289 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 290 | self._mock_blob("path/filename/subdir3/path/filename/file3") 291 | ] 292 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 293 | 294 | mock_unversioned_blobs = [ 295 | self._mock_blob("path/filename/file1"), 296 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 297 | self._mock_blob("path/filename/subdir3/path/filename/file3") 298 | ] 299 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 300 | mock_unversioned_blobs[1].download_to_filename.side_effect = [Exception, None] 301 | 302 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 303 | 304 | storage.save_to_directory("directory-name") 305 | 306 | mock_unversioned_blobs[0].download_to_filename.assert_called_once_with( 307 | "directory-name/file1") 308 | 309 | self.assertEqual(2, mock_unversioned_blobs[1].download_to_filename.call_count) 310 | mock_unversioned_blobs[1].download_to_filename.assert_called_with( 311 | "directory-name/subdir1/subdir2/file2") 312 | 313 | mock_unversioned_blobs[2].download_to_filename.assert_called_once_with( 314 | "directory-name/subdir3/path/filename/file3") 315 | 316 | mock_uniform.assert_called_once_with(0, 1) 317 | 
mock_sleep.assert_called_once_with(mock_uniform.return_value) 318 | 319 | @mock.patch("os.path.exists") 320 | @mock.patch("random.uniform") 321 | @mock.patch("time.sleep") 322 | def test_save_to_directory_fails_after_five_unsuccessful_download_attempts( 323 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock, mock_exists: mock.Mock) -> None: 324 | mock_uniform_results = [mock.Mock() for i in range(4)] 325 | mock_uniform.side_effect = mock_uniform_results 326 | mock_exists.return_value = True 327 | 328 | mock_listed_blobs = [ 329 | self._mock_blob("path/filename/file1"), 330 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 331 | self._mock_blob("path/filename/subdir3/path/filename/file3") 332 | ] 333 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 334 | 335 | mock_unversioned_blobs = [ 336 | self._mock_blob("path/filename/file1"), 337 | self._mock_blob("path/filename/subdir1/subdir2/file2") 338 | ] 339 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 340 | mock_unversioned_blobs[1].download_to_filename.side_effect = Exception 341 | 342 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 343 | 344 | with self.assertRaises(Exception): 345 | storage.save_to_directory("directory-name") 346 | 347 | self.assertEqual(2, self.mock_bucket.blob.call_count) 348 | 349 | mock_unversioned_blobs[0].download_to_filename.assert_called_once_with( 350 | "directory-name/file1") 351 | 352 | self.assertEqual(5, mock_unversioned_blobs[1].download_to_filename.call_count) 353 | mock_unversioned_blobs[1].download_to_filename.assert_called_with( 354 | "directory-name/subdir1/subdir2/file2") 355 | 356 | mock_uniform.assert_has_calls([ 357 | mock.call(0, 1), 358 | mock.call(0, 3), 359 | mock.call(0, 7), 360 | mock.call(0, 15) 361 | ]) 362 | mock_sleep.assert_has_calls([ 363 | mock.call(mock_uniform_results[0]), 364 | mock.call(mock_uniform_results[1]), 365 | mock.call(mock_uniform_results[2]), 366 | mock.call(mock_uniform_results[3]) 367 | ]) 368 | 369 | @mock.patch("os.path.exists") 370 | @mock.patch("random.uniform") 371 | @mock.patch("time.sleep") 372 | def test_save_to_directory_raises_when_listed_blobs_is_empty( 373 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock, mock_exists: mock.Mock) -> None: 374 | mock_uniform_results = [mock.Mock() for i in range(4)] 375 | mock_uniform.side_effect = mock_uniform_results 376 | mock_exists.return_value = True 377 | 378 | self.mock_bucket.list_blobs.return_value = iter([]) 379 | 380 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 381 | 382 | with self.assertRaises(NotFoundError): 383 | storage.save_to_directory("directory-name") 384 | 385 | self.assertEqual(0, self.mock_bucket.blob.call_count) 386 | 387 | @mock.patch("os.path.exists") 388 | @mock.patch("random.uniform") 389 | @mock.patch("time.sleep") 390 | def test_save_to_directory_raises_when_file_not_found( 391 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock, mock_exists: mock.Mock) -> None: 392 | mock_uniform_results = [mock.Mock() for i in range(4)] 393 | mock_uniform.side_effect = mock_uniform_results 394 | mock_exists.return_value = True 395 | 396 | mock_listed_blobs = [ 397 | self._mock_blob("path/filename/file1"), 398 | self._mock_blob("path/filename/subdir1/subdir2/file2"), 399 | self._mock_blob("path/filename/subdir3/path/filename/file3") 400 | ] 401 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 402 | 403 | mock_unversioned_blobs = [ 404 | 
self._mock_blob("path/filename/file1"), 405 | self._mock_blob("path/filename/subdir1/subdir2/file2") 406 | ] 407 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 408 | mock_unversioned_blobs[1].download_to_filename.side_effect = NotFound("File Not Found") 409 | 410 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 411 | 412 | with self.assertRaises(NotFoundError): 413 | storage.save_to_directory("directory-name") 414 | 415 | self.assertEqual(2, self.mock_bucket.blob.call_count) 416 | 417 | mock_unversioned_blobs[0].download_to_filename.assert_called_once_with( 418 | "directory-name/file1") 419 | 420 | self.assertEqual(5, mock_unversioned_blobs[1].download_to_filename.call_count) 421 | mock_unversioned_blobs[1].download_to_filename.assert_called_with( 422 | "directory-name/subdir1/subdir2/file2") 423 | 424 | mock_uniform.assert_has_calls([ 425 | mock.call(0, 1), 426 | mock.call(0, 3), 427 | mock.call(0, 7), 428 | mock.call(0, 15) 429 | ]) 430 | mock_sleep.assert_has_calls([ 431 | mock.call(mock_uniform_results[0]), 432 | mock.call(mock_uniform_results[1]), 433 | mock.call(mock_uniform_results[2]), 434 | mock.call(mock_uniform_results[3]) 435 | ]) 436 | 437 | @mock.patch("os.walk") 438 | def test_load_from_directory_uploads_files_to_bucket_with_prefix( 439 | self, mock_walk: mock.Mock) -> None: 440 | mock_blobs = [mock.Mock(), mock.Mock(), mock.Mock(), mock.Mock()] 441 | self.mock_bucket.blob.side_effect = mock_blobs 442 | 443 | mock_walk.return_value = [ 444 | ("/path/to/directory-name", ["subdir", "emptysubdir"], ["root1"]), 445 | ("/path/to/directory-name/subdir", ["nesteddir"], ["sub1", "sub2"]), 446 | ("/path/to/directory-name/subdir/nesteddir", [], ["nested1"]), 447 | ("/path/to/directory-name/emptysubdir", [], []) 448 | ] 449 | 450 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 451 | 452 | storage.load_from_directory("/path/to/directory-name") 453 | 454 | self.assert_gets_bucket_with_credentials() 455 | 456 | mock_walk.assert_called_once_with("/path/to/directory-name") 457 | 458 | self.mock_bucket.blob.assert_has_calls([ 459 | mock.call("path/filename/root1"), 460 | mock.call("path/filename/subdir/sub1"), 461 | mock.call("path/filename/subdir/sub2"), 462 | mock.call("path/filename/subdir/nesteddir/nested1") 463 | ]) 464 | 465 | mock_blobs[0].upload_from_filename.assert_called_once_with("/path/to/directory-name/root1") 466 | mock_blobs[1].upload_from_filename.assert_called_once_with( 467 | "/path/to/directory-name/subdir/sub1") 468 | mock_blobs[2].upload_from_filename.assert_called_once_with( 469 | "/path/to/directory-name/subdir/sub2") 470 | mock_blobs[3].upload_from_filename.assert_called_once_with( 471 | "/path/to/directory-name/subdir/nesteddir/nested1") 472 | 473 | @mock.patch("os.walk") 474 | def test_load_from_directory_handles_repeated_directory_structure( 475 | self, mock_walk: mock.Mock) -> None: 476 | mock_blobs = [mock.Mock(), mock.Mock()] 477 | self.mock_bucket.blob.side_effect = mock_blobs 478 | 479 | mock_walk.return_value = [ 480 | ("dir/name", ["dir"], []), 481 | ("dir/name/dir", ["name"], []), 482 | ("dir/name/dir/name", ["foo"], ["file1"]), 483 | ("dir/name/dir/name/foo", [], ["file2"]) 484 | ] 485 | 486 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 487 | 488 | storage.load_from_directory("dir/name") 489 | 490 | self.assert_gets_bucket_with_credentials() 491 | 492 | mock_walk.assert_called_once_with("dir/name") 493 | 494 | 
self.mock_bucket.blob.assert_has_calls([ 495 | mock.call("path/filename/dir/name/file1"), 496 | mock.call("path/filename/dir/name/foo/file2") 497 | ]) 498 | 499 | mock_blobs[0].upload_from_filename.assert_called_once_with("dir/name/dir/name/file1") 500 | mock_blobs[1].upload_from_filename.assert_called_once_with("dir/name/dir/name/foo/file2") 501 | 502 | @mock.patch("os.walk") 503 | @mock.patch("time.sleep") 504 | def test_load_from_directory_retries_file_upload_on_error( 505 | self, mock_sleep: mock.Mock, mock_walk: mock.Mock) -> None: 506 | mock_blobs = [mock.Mock(), mock.Mock(), mock.Mock()] 507 | mock_blobs[1].upload_from_filename.side_effect = [Exception, None] 508 | self.mock_bucket.blob.side_effect = mock_blobs 509 | 510 | mock_walk.return_value = [ 511 | ("/dir", [], ["file1", "file2", "file3"]) 512 | ] 513 | 514 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 515 | 516 | storage.load_from_directory("/dir") 517 | 518 | self.mock_bucket.blob.assert_has_calls([ 519 | mock.call("path/filename/file1"), 520 | mock.call("path/filename/file2"), 521 | mock.call("path/filename/file3") 522 | ]) 523 | 524 | mock_blobs[0].upload_from_filename.assert_called_once_with("/dir/file1") 525 | 526 | self.assertEqual(2, mock_blobs[1].upload_from_filename.call_count) 527 | mock_blobs[1].upload_from_filename.assert_called_with("/dir/file2") 528 | 529 | mock_blobs[2].upload_from_filename.assert_called_once_with("/dir/file3") 530 | 531 | @mock.patch("os.walk") 532 | @mock.patch("time.sleep") 533 | def test_load_from_directory_fails_after_five_unsuccessful_upload_attempts( 534 | self, mock_sleep: mock.Mock, mock_walk: mock.Mock) -> None: 535 | mock_blobs = [mock.Mock(), mock.Mock(), mock.Mock()] 536 | mock_blobs[1].upload_from_filename.side_effect = Exception 537 | self.mock_bucket.blob.side_effect = mock_blobs 538 | 539 | mock_walk.return_value = [ 540 | ("/dir", [], ["file1", "file2", "file3"]) 541 | ] 542 | 543 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 544 | 545 | with self.assertRaises(Exception): 546 | storage.load_from_directory("/dir") 547 | 548 | mock_blobs[0].upload_from_filename.assert_called_once_with("/dir/file1") 549 | 550 | self.assertEqual(5, mock_blobs[1].upload_from_filename.call_count) 551 | mock_blobs[1].upload_from_filename.assert_called_with("/dir/file2") 552 | 553 | self.assertEqual(0, mock_blobs[2].upload_from_filename.call_count) 554 | 555 | def test_delete_directory_deletes_blobs_with_prefix(self) -> None: 556 | mock_listed_blobs = [ 557 | self._mock_blob("path/filename/file1"), 558 | self._mock_blob("path/filename/file2"), 559 | self._mock_blob("path/filename/file3") 560 | ] 561 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 562 | 563 | mock_unversioned_blobs = [ 564 | self._mock_blob("path/filename/file1"), 565 | self._mock_blob("path/filename/file2"), 566 | self._mock_blob("path/filename/file3") 567 | ] 568 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 569 | 570 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 571 | 572 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 573 | 574 | storage.delete_directory() 575 | 576 | self.assert_gets_bucket_with_credentials() 577 | 578 | self.mock_bucket.list_blobs.assert_called_once_with(prefix="path/filename/") 579 | 580 | self.mock_bucket.blob.assert_has_calls([ 581 | mock.call("path/filename/file1"), 582 | mock.call("path/filename/file2"), 583 | 
mock.call("path/filename/file3") 584 | ]) 585 | 586 | mock_unversioned_blobs[0].delete.assert_called_once_with() 587 | mock_unversioned_blobs[1].delete.assert_called_once_with() 588 | mock_unversioned_blobs[2].delete.assert_called_once_with() 589 | 590 | def test_delete_directory_raises_when_file_does_not_exist(self) -> None: 591 | mock_listed_blobs = [ 592 | self._mock_blob("path/filename/file1"), 593 | self._mock_blob("path/filename/file2"), 594 | self._mock_blob("path/filename/file3") 595 | ] 596 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 597 | 598 | mock_unversioned_blobs = [ 599 | self._mock_blob("path/filename/file1"), 600 | self._mock_blob("path/filename/file2"), 601 | self._mock_blob("path/filename/file3") 602 | ] 603 | self.mock_bucket.blob.side_effect = mock_unversioned_blobs 604 | mock_unversioned_blobs[1].delete.side_effect = NotFound("File Not Found") 605 | 606 | self.mock_bucket.list_blobs.return_value = iter(mock_listed_blobs) 607 | 608 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 609 | 610 | with self.assertRaises(NotFoundError): 611 | storage.delete_directory() 612 | 613 | self.mock_bucket.list_blobs.assert_called_once_with(prefix="path/filename/") 614 | 615 | self.mock_bucket.blob.assert_has_calls([ 616 | mock.call("path/filename/file1"), 617 | mock.call("path/filename/file2") 618 | ]) 619 | 620 | mock_unversioned_blobs[0].delete.assert_called_once_with() 621 | mock_unversioned_blobs[1].delete.assert_called_once_with() 622 | mock_unversioned_blobs[2].delete.assert_not_called() 623 | 624 | def test_delete_directory_raises_when_list_blobs_is_empty(self) -> None: 625 | self.mock_bucket.list_blobs.return_value = iter([]) 626 | 627 | storage = get_storage("gs://{}@bucketname/path/filename".format(self.credentials)) 628 | 629 | with self.assertRaises(NotFoundError): 630 | storage.delete_directory() 631 | 632 | self.mock_bucket.list_blobs.assert_called_once_with(prefix="path/filename/") 633 | -------------------------------------------------------------------------------- /tests/test_local_storage.py: -------------------------------------------------------------------------------- 1 | import os 2 | from io import BytesIO 3 | import tempfile 4 | from unittest import mock 5 | from urllib.parse import quote_plus 6 | 7 | from typing import Dict, Optional 8 | 9 | from storage.storage import get_storage, DownloadUrlBaseUndefinedError, NotFoundError 10 | from tests.helpers import create_temp_nested_directory_with_files, NestedDirectoryDict 11 | from tests.storage_test_case import StorageTestCase 12 | from tests.helpers import TempDirectory 13 | 14 | 15 | class TestLocalStorage(StorageTestCase): 16 | 17 | temp_directory: Optional[NestedDirectoryDict] 18 | 19 | def setUp(self) -> None: 20 | super().setUp() 21 | self.temp_directory = None 22 | 23 | def tearDown(self) -> None: 24 | super().tearDown() 25 | if self.temp_directory is not None: 26 | self.temp_directory["temp_directory"]["object"].cleanup() 27 | 28 | def _generate_storage_uri( 29 | self, object_path: str, parameters: Optional[Dict[str, str]] = None) -> str: 30 | return "file:///path/to/file.mp4" 31 | 32 | def test_local_storage_save_to_filename(self) -> None: 33 | temp_input = tempfile.NamedTemporaryFile() 34 | temp_input.write(b"FOOBAR") 35 | temp_input.flush() 36 | 37 | temp_output = tempfile.NamedTemporaryFile() 38 | 39 | storage = get_storage("file://%s" % (temp_input.name)) 40 | storage.save_to_filename(temp_output.name) 41 | 42 | with open(temp_output.name, 
"rb") as temp_output_fp: 43 | self.assertEqual(b"FOOBAR", temp_output_fp.read()) 44 | 45 | def test_local_storage_raises_when_filename_does_not_exist(self) -> None: 46 | with tempfile.NamedTemporaryFile() as fp: 47 | removed_path = fp.name 48 | 49 | self.assertFalse(os.path.exists(removed_path)) 50 | temp_output = tempfile.NamedTemporaryFile() 51 | 52 | storage = get_storage(f"file://{removed_path}") 53 | with self.assertRaises(NotFoundError): 54 | storage.save_to_filename(temp_output.name) 55 | 56 | @mock.patch("os.makedirs", autospec=True) 57 | def test_local_storage_load_from_filename(self, mock_makedirs: mock.Mock) -> None: 58 | temp_input = tempfile.NamedTemporaryFile() 59 | temp_input.write(b"FOOBAR") 60 | temp_input.flush() 61 | 62 | temp_output = tempfile.NamedTemporaryFile() 63 | storage = get_storage("file://%s" % (temp_output.name)) 64 | storage.load_from_filename(temp_input.name) 65 | 66 | self.assertEqual(0, mock_makedirs.call_count) 67 | 68 | with open(temp_output.name, "rb") as temp_output_fp: 69 | self.assertEqual(b"FOOBAR", temp_output_fp.read()) 70 | 71 | def test_local_storage_save_to_directory(self) -> None: 72 | self.temp_directory = create_temp_nested_directory_with_files() 73 | 74 | storage = get_storage("file://{0}".format(self.temp_directory["temp_directory"]["path"])) 75 | 76 | with TempDirectory() as temp_output: 77 | temp_output_dir = temp_output.name 78 | destination_directory_path = os.path.join(temp_output_dir, "tmp") 79 | storage.save_to_directory(destination_directory_path) 80 | 81 | destination_input_one_path = os.path.join( 82 | destination_directory_path, 83 | self.temp_directory["temp_input_one"]["name"]) 84 | destination_input_two_path = os.path.join( 85 | destination_directory_path, 86 | self.temp_directory["temp_input_two"]["name"]) 87 | nested_temp_input_path = os.path.join( 88 | destination_directory_path, 89 | self.temp_directory["nested_temp_directory"]["name"], 90 | self.temp_directory["nested_temp_input"]["name"]) 91 | 92 | self.assertTrue(os.path.exists(destination_input_one_path)) 93 | self.assertTrue(os.path.exists(destination_input_two_path)) 94 | self.assertTrue(os.path.exists(nested_temp_input_path)) 95 | 96 | with open(destination_input_one_path, "rb") as temp_output_fp: 97 | self.assertEqual(b"FOO", temp_output_fp.read()) 98 | 99 | with open(destination_input_two_path, "rb") as temp_output_fp: 100 | self.assertEqual(b"BAR", temp_output_fp.read()) 101 | 102 | with open(nested_temp_input_path, "rb") as temp_output_fp: 103 | self.assertEqual(b"FOOBAR", temp_output_fp.read()) 104 | 105 | def test_local_storage_save_to_directory_overwrites_existing_files(self) -> None: 106 | self.temp_directory = create_temp_nested_directory_with_files() 107 | 108 | storage = get_storage("file://{0}".format(self.temp_directory["temp_directory"]["path"])) 109 | 110 | with TempDirectory() as temp_output: 111 | temp_output_dir = temp_output.name 112 | destination_directory_path = os.path.join(temp_output_dir, "tmp") 113 | destination_input_one_path = os.path.join( 114 | destination_directory_path, 115 | self.temp_directory["temp_input_one"]["name"]) 116 | 117 | os.makedirs(destination_directory_path) 118 | with open(destination_input_one_path, "wb") as to_be_overwritten: 119 | to_be_overwritten.write(b"TO-BE-OVERWRITTEN") 120 | 121 | storage.save_to_directory(destination_directory_path) 122 | 123 | with open(destination_input_one_path, "rb") as temp_output_fp: 124 | self.assertEqual(b"FOO", temp_output_fp.read()) 125 | 126 | def 
test_local_storage_save_to_directory_raises_when_source_directory_does_not_exist( 127 | self) -> None: 128 | self.temp_directory = create_temp_nested_directory_with_files() 129 | fake_path = os.path.join(self.temp_directory["temp_directory"]["path"], "invalid") 130 | self.assertFalse(os.path.exists(fake_path)) 131 | 132 | storage = get_storage("file://{0}".format(fake_path)) 133 | 134 | with TempDirectory() as temp_output: 135 | temp_output_dir = temp_output.name 136 | destination_directory_path = os.path.join(temp_output_dir, "tmp") 137 | 138 | with self.assertRaises(NotFoundError): 139 | storage.save_to_directory(destination_directory_path) 140 | 141 | def test_local_storage_load_from_directory(self) -> None: 142 | self.temp_directory = create_temp_nested_directory_with_files() 143 | 144 | with TempDirectory() as temp_output: 145 | temp_output_dir = temp_output.name 146 | storage = get_storage("file://{0}/{1}".format(temp_output_dir, "tmp")) 147 | 148 | storage.load_from_directory(self.temp_directory["temp_directory"]["path"]) 149 | 150 | destination_directory_path = os.path.join( 151 | temp_output_dir, "tmp") 152 | destination_input_one_path = os.path.join( 153 | temp_output_dir, destination_directory_path, 154 | self.temp_directory["temp_input_one"]["name"]) 155 | destination_input_two_path = os.path.join( 156 | temp_output_dir, destination_directory_path, 157 | self.temp_directory["temp_input_two"]["name"]) 158 | nested_temp_input_path = os.path.join( 159 | temp_output_dir, destination_directory_path, 160 | self.temp_directory["nested_temp_directory"]["path"], 161 | self.temp_directory["nested_temp_input"]["name"]) 162 | 163 | self.assertTrue(os.path.exists(destination_input_one_path)) 164 | self.assertTrue(os.path.exists(destination_input_two_path)) 165 | self.assertTrue(os.path.exists(nested_temp_input_path)) 166 | 167 | with open(destination_input_one_path, "rb") as temp_output_fp: 168 | self.assertEqual(b"FOO", temp_output_fp.read()) 169 | 170 | with open(destination_input_two_path, "rb") as temp_output_fp: 171 | self.assertEqual(b"BAR", temp_output_fp.read()) 172 | 173 | with open(nested_temp_input_path, "rb") as temp_output_fp: 174 | self.assertEqual(b"FOOBAR", temp_output_fp.read()) 175 | 176 | def test_local_storage_load_from_directory_overwrites_existing_files(self) -> None: 177 | self.temp_directory = create_temp_nested_directory_with_files() 178 | 179 | with TempDirectory() as temp_output: 180 | temp_output_dir = temp_output.name 181 | destination_directory_path = os.path.join( 182 | temp_output_dir, "tmp") 183 | destination_input_one_path = os.path.join( 184 | temp_output_dir, destination_directory_path, 185 | self.temp_directory["temp_input_one"]["name"]) 186 | 187 | os.makedirs(destination_directory_path) 188 | with open(destination_input_one_path, "wb") as to_be_overwritten: 189 | to_be_overwritten.write(b"TO-BE-OVERWRITTEN") 190 | 191 | storage = get_storage("file://{0}/{1}".format(temp_output_dir, "tmp")) 192 | 193 | storage.load_from_directory(self.temp_directory["temp_directory"]["path"]) 194 | 195 | with open(destination_input_one_path, "rb") as temp_output_fp: 196 | self.assertEqual(b"FOO", temp_output_fp.read()) 197 | 198 | @mock.patch("shutil.copy", autospec=True) 199 | @mock.patch("os.makedirs", autospec=True) 200 | @mock.patch("os.path.exists", autospec=True) 201 | def test_load_from_file_creates_intermediate_dirs( 202 | self, mock_exists: mock.Mock, mock_makedirs: mock.Mock, mock_copy: mock.Mock) -> None: 203 | mock_exists.return_value = False 204 | 205 | 
storage = get_storage("file:///foo/bar/file") 206 | storage.load_from_filename("input_file") 207 | 208 | mock_exists.assert_called_with("/foo/bar") 209 | mock_makedirs.assert_called_with("/foo/bar") 210 | mock_copy.assert_called_with("input_file", "/foo/bar/file") 211 | 212 | @mock.patch("os.remove", autospec=True) 213 | def test_local_storage_delete(self, mock_remove: mock.Mock) -> None: 214 | storage = get_storage("file:///folder/file") 215 | storage.delete() 216 | 217 | mock_remove.assert_called_with("/folder/file") 218 | 219 | def test_local_storage_delete_raises_when_file_does_not_exist(self) -> None: 220 | with tempfile.NamedTemporaryFile() as fp: 221 | removed_path = fp.name 222 | 223 | self.assertFalse(os.path.exists(removed_path)) 224 | 225 | storage = get_storage(f"file://{removed_path}") 226 | with self.assertRaises(NotFoundError): 227 | storage.delete() 228 | 229 | @mock.patch("shutil.rmtree", autospec=True) 230 | @mock.patch("os.remove", autospec=True) 231 | def test_local_storage_delete_directory( 232 | self, mock_remove: mock.Mock, mock_rmtree: mock.Mock) -> None: 233 | self.temp_directory = create_temp_nested_directory_with_files() 234 | 235 | storage = get_storage("file://{0}".format(self.temp_directory["temp_directory"]["path"])) 236 | storage.delete_directory() 237 | 238 | self.assertFalse(mock_remove.called) 239 | mock_rmtree.assert_called_once_with(self.temp_directory["temp_directory"]["path"]) 240 | 241 | def test_local_storage_delete_directory_raises_when_directory_does_not_exist(self) -> None: 242 | self.temp_directory = create_temp_nested_directory_with_files() 243 | fake_path = os.path.join(self.temp_directory["temp_directory"]["path"], "invalid") 244 | 245 | self.assertFalse(os.path.exists(fake_path)) 246 | 247 | storage = get_storage("file://{0}".format(fake_path)) 248 | with self.assertRaises(NotFoundError): 249 | storage.delete_directory() 250 | 251 | def test_local_storage_save_to_file(self) -> None: 252 | temp_input = tempfile.NamedTemporaryFile() 253 | temp_input.write(b"FOOBAR") 254 | temp_input.flush() 255 | 256 | out_file = BytesIO() 257 | 258 | storage = get_storage("file://%s" % (temp_input.name)) 259 | storage.save_to_file(out_file) 260 | 261 | self.assertEqual(b"FOOBAR", out_file.getvalue()) 262 | 263 | def test_local_storage_raises_when_file_does_not_exist(self) -> None: 264 | with tempfile.NamedTemporaryFile() as fp: 265 | removed_path = fp.name 266 | 267 | self.assertFalse(os.path.exists(removed_path)) 268 | 269 | out_file = BytesIO() 270 | 271 | storage = get_storage(f"file://{removed_path}") 272 | with self.assertRaises(NotFoundError): 273 | storage.save_to_file(out_file) 274 | 275 | def test_local_storage_load_from_file(self) -> None: 276 | in_file = BytesIO(b"foobar") 277 | temp_output = tempfile.NamedTemporaryFile() 278 | 279 | storage = get_storage("file://{0}".format(temp_output.name)) 280 | storage.load_from_file(in_file) 281 | 282 | with open(temp_output.name, "rb") as temp_output_fp: 283 | self.assertEqual(b"foobar", temp_output_fp.read()) 284 | 285 | @mock.patch("os.makedirs") 286 | @mock.patch("os.path.exists") 287 | @mock.patch("builtins.open") 288 | def test_load_from_file_creates_dirs_if_not_present( 289 | self, mock_open: mock.Mock, mock_exists: mock.Mock, mock_makedirs: mock.Mock) -> None: 290 | mock_exists.return_value = False 291 | in_file = BytesIO(b"foobar") 292 | 293 | mock_file = mock_open.return_value.__enter__.return_value 294 | mock_file.read.side_effect = ["FOOBAR", None] 295 | 296 | out_storage =
get_storage("file:///foobar/is/out") 297 | out_storage.load_from_file(in_file) 298 | 299 | mock_open.assert_has_calls([ 300 | mock.call("/foobar/is/out", "wb") 301 | ]) 302 | 303 | mock_exists.assert_called_with("/foobar/is") 304 | mock_makedirs.assert_called_with("/foobar/is") 305 | mock_file.write.assert_called_with(b"foobar") 306 | self.assertEqual(1, mock_open.return_value.__exit__.call_count) 307 | 308 | @mock.patch("os.makedirs") 309 | @mock.patch("os.path.exists") 310 | @mock.patch("builtins.open") 311 | def test_load_from_file_does_not_create_dirs_if_present( 312 | self, mock_open: mock.Mock, mock_exists: mock.Mock, mock_makedirs: mock.Mock) -> None: 313 | mock_exists.return_value = True 314 | in_file = BytesIO(b"foobar") 315 | 316 | out_storage = get_storage("file:///foobar/is/out") 317 | out_storage.load_from_file(in_file) 318 | 319 | mock_exists.assert_called_with("/foobar/is") 320 | self.assertEqual(0, mock_makedirs.call_count) 321 | 322 | def test_local_storage_get_download_url(self) -> None: 323 | temp_input = tempfile.NamedTemporaryFile() 324 | temp_input.write(b"FOOBAR") 325 | temp_input.flush() 326 | 327 | download_url_base = "http://host:123/path/to/" 328 | download_url_base_encoded = quote_plus(download_url_base) 329 | 330 | storage_uri = f"file://{temp_input.name}?download_url_base={download_url_base_encoded}" 331 | out_storage = get_storage(storage_uri) 332 | temp_url = out_storage.get_download_url() 333 | 334 | self.assertEqual( 335 | f"http://host:123/path/to/{os.path.basename(temp_input.name)}", temp_url) 336 | 337 | def test_local_storage_get_download_url_ignores_args(self) -> None: 338 | temp_input = tempfile.NamedTemporaryFile() 339 | temp_input.write(b"FOOBAR") 340 | temp_input.flush() 341 | 342 | download_url_base = "http://host:123/path/to/" 343 | download_url_base_encoded = quote_plus(download_url_base) 344 | 345 | storage_uri = f"file://{temp_input.name}?download_url_base={download_url_base_encoded}" 346 | 347 | out_storage = get_storage(storage_uri) 348 | temp_url = out_storage.get_download_url(seconds=900) 349 | 350 | self.assertEqual( 351 | f"http://host:123/path/to/{os.path.basename(temp_input.name)}", temp_url) 352 | 353 | temp_url = out_storage.get_download_url(key="secret") 354 | 355 | self.assertEqual( 356 | f"http://host:123/path/to/{os.path.basename(temp_input.name)}", temp_url) 357 | 358 | def test_local_storage_get_download_url_returns_none_on_empty_base(self) -> None: 359 | temp_input = tempfile.NamedTemporaryFile() 360 | temp_input.write(b"FOOBAR") 361 | temp_input.flush() 362 | 363 | storage_uri = "file://{fpath}".format(fpath=temp_input.name) 364 | out_storage = get_storage(storage_uri) 365 | 366 | with self.assertRaises(DownloadUrlBaseUndefinedError): 367 | out_storage.get_download_url() 368 | 369 | def test_local_storage_rejects_multiple_query_values_for_download_url_key_setting(self) -> None: 370 | self.assert_rejects_multiple_query_values("/foo/bar/object.mp4", "download_url_base") 371 | 372 | def test_local_storage_get_sanitized_uri_returns_filepath(self) -> None: 373 | temp_input = tempfile.NamedTemporaryFile() 374 | temp_input.write(b"FOOBAR") 375 | temp_input.flush() 376 | 377 | download_url_base = "http://host:123/path/to/" 378 | download_url_base_encoded = quote_plus(download_url_base) 379 | 380 | storage_uri = f"file://{temp_input.name}?download_url_base={download_url_base_encoded}" 381 | out_storage = get_storage(storage_uri) 382 | 383 | sanitized_uri = out_storage.get_sanitized_uri() 384 | 385 | self.assertEqual( 386 | 
f"file://{temp_input.name}?download_url_base={download_url_base_encoded}", 387 | sanitized_uri) 388 | -------------------------------------------------------------------------------- /tests/test_retry.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest import mock 3 | 4 | from storage import retry 5 | 6 | 7 | class UnRetriableError(Exception): 8 | do_not_retry = True 9 | 10 | 11 | class TestRetry(unittest.TestCase): 12 | @mock.patch("random.uniform") 13 | @mock.patch("time.sleep") 14 | def test_does_not_retry_on_success( 15 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock) -> None: 16 | successful_function = mock.Mock(return_value="result") 17 | 18 | result = retry.attempt(successful_function, 1, 2, foo="bar", biz="baz") 19 | 20 | self.assertEqual(0, mock_uniform.call_count) 21 | self.assertEqual(0, mock_sleep.call_count) 22 | 23 | successful_function.assert_called_with(1, 2, foo="bar", biz="baz") 24 | self.assertEqual("result", result) 25 | 26 | @mock.patch("random.uniform") 27 | @mock.patch("time.sleep") 28 | def test_retries_on_failure( 29 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock) -> None: 30 | mock_uniform.side_effect = [0.5, 2.4] 31 | 32 | failing_function = mock.Mock(side_effect=[RuntimeError, RuntimeError, "result"]) 33 | 34 | result = retry.attempt(failing_function, 1, 2, foo="bar", biz="baz") 35 | 36 | mock_uniform.assert_has_calls([ 37 | mock.call(0, 1), 38 | mock.call(0, 3) 39 | ]) 40 | 41 | mock_sleep.assert_has_calls([ 42 | mock.call(0.5), 43 | mock.call(2.4) 44 | ]) 45 | 46 | self.assertEqual(3, failing_function.call_count) 47 | failing_function.assert_called_with(1, 2, foo="bar", biz="baz") 48 | self.assertEqual("result", result) 49 | 50 | @mock.patch("random.uniform") 51 | @mock.patch("time.sleep") 52 | def test_reraises_last_exception_on_attempt_exhaustion( 53 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock) -> None: 54 | mock_uniform.side_effect = [0.5, 2.4, 3.6, 5.6] 55 | 56 | failing_function = mock.Mock( 57 | side_effect=[RuntimeError, RuntimeError, RuntimeError, RuntimeError, RuntimeError]) 58 | 59 | with self.assertRaises(RuntimeError): 60 | retry.attempt(failing_function, 1, 2, foo="bar", biz="baz") 61 | 62 | self.assertEqual(4, mock_uniform.call_count) 63 | self.assertEqual(4, mock_sleep.call_count) 64 | 65 | self.assertEqual(5, failing_function.call_count) 66 | failing_function.assert_called_with(1, 2, foo="bar", biz="baz") 67 | 68 | @mock.patch("random.uniform") 69 | @mock.patch("time.sleep") 70 | def test_does_not_retry_unretriable_errors( 71 | self, mock_sleep: mock.Mock, mock_uniform: mock.Mock) -> None: 72 | failing_function = mock.Mock(side_effect=UnRetriableError) 73 | 74 | with self.assertRaises(UnRetriableError): 75 | retry.attempt(failing_function, 1, 2, foo="bar", biz="baz") 76 | 77 | self.assertEqual(0, mock_uniform.call_count) 78 | self.assertEqual(0, mock_sleep.call_count) 79 | 80 | self.assertEqual(1, failing_function.call_count) 81 | failing_function.assert_called_with(1, 2, foo="bar", biz="baz") 82 | -------------------------------------------------------------------------------- /tests/test_storage.py: -------------------------------------------------------------------------------- 1 | import threading 2 | 3 | from storage.storage import get_storage, InvalidStorageUri, register_storage_protocol, Storage 4 | from storage.storage import timeout, TimeoutError, _STORAGE_TYPES 5 | from unittest import mock, TestCase 6 | 7 | 8 | class 
8 | class TestTimeout(TestCase):
9 |     @mock.patch("threading.Thread", wraps=threading.Thread)
10 |     def test_calls_function_in_thread(self, mock_thread_class: mock.Mock) -> None:
11 |         def worker() -> str:
12 |             self.assertTrue(threading.current_thread().daemon)
13 |             return "success"
14 | 
15 |         self.assertEqual("success", timeout(5, worker))
16 | 
17 |         mock_thread_class.assert_called_once_with(target=mock.ANY)
18 | 
19 |     def test_reraises_exception_raised_by_worker(self) -> None:
20 |         def worker() -> None:
21 |             raise Exception("some error")
22 | 
23 |         with self.assertRaisesRegex(Exception, "^some error$"):
24 |             timeout(5, worker)
25 | 
26 |     def test_raises_timeout_error_when_worker_does_not_complete_within_timeout(self) -> None:
27 |         event = threading.Event()
28 | 
29 |         def worker() -> None:
30 |             event.wait()
31 | 
32 |         try:
33 |             with self.assertRaises(TimeoutError):
34 |                 timeout(0, worker)
35 |         finally:
36 |             event.set()
37 | 
38 | 
39 | class TestRegisterStorageProtocol(TestCase):
40 | 
41 |     def setUp(self) -> None:
42 |         self.scheme = "myscheme"
43 | 
44 |     def test_register_storage_protocol_updates_storage_types(self) -> None:
45 | 
46 |         @register_storage_protocol(scheme=self.scheme)
47 |         class MyStorageClass(Storage):
48 |             pass
49 | 
50 |         self.assertIn(self.scheme, _STORAGE_TYPES)
51 | 
52 |         uri = "{0}://some/uri/path".format(self.scheme)
53 |         store_obj = get_storage(uri)
54 |         self.assertIsInstance(store_obj, MyStorageClass)
55 | 
56 |     def test_storage_provider_calls_validation_on_implementation(self) -> None:
57 | 
58 |         @register_storage_protocol(scheme=self.scheme)
59 |         class ValidatingStorageClass(Storage):
60 |             def _validate_parsed_uri(self) -> None:
61 |                 raise InvalidStorageUri("Nope I don't like it.")
62 | 
63 |         with self.assertRaises(InvalidStorageUri):
64 |             get_storage(f"{self.scheme}://some/uri/path")
65 | 
66 | 
67 | class TestGetStorage(TestCase):
68 |     def test_raises_for_unsupported_scheme(self) -> None:
69 |         with self.assertRaises(InvalidStorageUri) as error:
70 |             get_storage("unsupported://creds:secret@bucket/path")
71 | 
72 |         self.assertEqual("Invalid storage type 'unsupported'", str(error.exception))
73 | 
74 |     def test_raises_for_missing_scheme(self) -> None:
75 |         with self.assertRaises(InvalidStorageUri) as error:
76 |             get_storage("//creds:secret@invalid/storage/uri")
77 | 
78 |         self.assertEqual("Invalid storage type ''", str(error.exception))
79 | 
80 |     def test_raises_for_missing_scheme_and_netloc(self) -> None:
81 |         with self.assertRaises(InvalidStorageUri) as error:
82 |             get_storage("invalid/storage/uri")
83 | 
84 |         self.assertEqual("Invalid storage type ''", str(error.exception))
85 | 
86 | 
87 | class TestStorage(TestCase):
88 |     def test_get_sanitized_uri_removes_username_and_password(self) -> None:
89 |         storage = Storage(storage_uri="https://username:password@bucket/path/filename")
90 |         sanitized_uri = storage.get_sanitized_uri()
91 | 
92 |         self.assertEqual("https://bucket/path/filename", sanitized_uri)
93 | 
94 |     def test_get_sanitized_uri_does_not_preserve_parameters(self) -> None:
95 |         storage = Storage(storage_uri="https://username:password@bucket/path/filename?other=param")
96 |         sanitized_uri = storage.get_sanitized_uri()
97 | 
98 |         self.assertEqual("https://bucket/path/filename", sanitized_uri)
99 | 
100 |     def test_get_sanitized_uri_preserves_port_number(self) -> None:
101 |         storage = Storage(storage_uri="ftp://username:password@ftp.foo.com:8080/path/filename")
102 |         sanitized_uri = storage.get_sanitized_uri()
103 | 
104 |         self.assertEqual("ftp://ftp.foo.com:8080/path/filename", sanitized_uri)
105 | 
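# Editor's sketch (hypothetical, not part of the suite): the registration
# pattern exercised above, shown as a usage example. The "memory" scheme and
# MemoryStorage class are illustrative assumptions only.
if __name__ == "__main__":
    @register_storage_protocol(scheme="memory")
    class MemoryStorage(Storage):
        pass

    # Any URI using the registered scheme now dispatches to the new class.
    assert isinstance(get_storage("memory://bucket/path"), MemoryStorage)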
--------------------------------------------------------------------------------
/tests/test_url_parser.py:
--------------------------------------------------------------------------------
1 | from urllib.parse import urlparse
2 | 
3 | from unittest import TestCase
4 | 
5 | from storage import url_parser
6 | 
7 | 
8 | class TestUrlParser(TestCase):
9 | 
10 |     def test_remove_user_info_removes_username_and_password(self) -> None:
11 |         storage_uri = urlparse("https://username:password@bucket/path/filename")
12 |         sanitized_uri = url_parser.remove_user_info(storage_uri)
13 | 
14 |         self.assertEqual("https://bucket/path/filename", sanitized_uri)
15 | 
16 |     def test_remove_user_info_preserves_parameters(self) -> None:
17 |         storage_uri = urlparse("https://username:password@bucket/path/filename?other=parameter")
18 |         sanitized_uri = url_parser.remove_user_info(storage_uri)
19 | 
20 |         self.assertEqual("https://bucket/path/filename?other=parameter", sanitized_uri)
21 | 
22 |     def test_remove_user_info_preserves_port_number(self) -> None:
23 |         storage_uri = urlparse("ftp://username:password@ftp.foo.com:8080/path/filename")
24 |         sanitized_uri = url_parser.remove_user_info(storage_uri)
25 | 
26 |         self.assertEqual("ftp://ftp.foo.com:8080/path/filename", sanitized_uri)
27 | 
28 |     def test_remove_user_info_accepts_urls_without_hostnames(self) -> None:
29 |         storage_uri = urlparse("file:///path/filename")
30 |         sanitized_uri = url_parser.remove_user_info(storage_uri)
31 | 
32 |         self.assertEqual("file:///path/filename", sanitized_uri)
33 | 
34 |     def test_sanitize_resource_uri_removes_username_and_password(self) -> None:
35 |         storage_uri = urlparse("https://username:password@bucket/path/filename")
36 |         sanitized_uri = url_parser.sanitize_resource_uri(storage_uri)
37 | 
38 |         self.assertEqual("https://bucket/path/filename", sanitized_uri)
39 | 
40 |     def test_sanitize_resource_uri_does_not_preserve_parameters(self) -> None:
41 |         storage_uri = urlparse("https://username:password@bucket/path/filename?other=parameter")
42 |         sanitized_uri = url_parser.sanitize_resource_uri(storage_uri)
43 | 
44 |         self.assertEqual("https://bucket/path/filename", sanitized_uri)
45 | 
46 |     def test_sanitize_resource_uri_preserves_port_number(self) -> None:
47 |         storage_uri = urlparse("ftp://username:password@ftp.foo.com:8080/path/filename")
48 |         sanitized_uri = url_parser.sanitize_resource_uri(storage_uri)
49 | 
50 |         self.assertEqual("ftp://ftp.foo.com:8080/path/filename", sanitized_uri)
51 | 
52 |     def test_sanitize_resource_uri_accepts_urls_without_hostnames(self) -> None:
53 |         storage_uri = urlparse("file:///path/filename")
54 |         sanitized_uri = url_parser.sanitize_resource_uri(storage_uri)
55 | 
56 |         self.assertEqual("file:///path/filename", sanitized_uri)
57 | 
--------------------------------------------------------------------------------
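The two url_parser helpers exercised above differ only in how much of the URI
survives: remove_user_info keeps the query string, while sanitize_resource_uri
drops it. A minimal sketch of that behavior, reconstructed from the assertions
(function names here are illustrative; storage/url_parser.py is authoritative):

from urllib.parse import ParseResult

def remove_user_info_sketch(parsed: ParseResult) -> str:
    # Rebuild the netloc from hostname and port only, discarding credentials.
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc = f"{netloc}:{parsed.port}"
    query = f"?{parsed.query}" if parsed.query else ""
    return f"{parsed.scheme}://{netloc}{parsed.path}{query}"

def sanitize_resource_uri_sketch(parsed: ParseResult) -> str:
    # Same netloc rebuild, but the query string is dropped as well.
    netloc = parsed.hostname or ""
    if parsed.port:
        netloc = f"{netloc}:{parsed.port}"
    return f"{parsed.scheme}://{netloc}{parsed.path}"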