├── .circleci └── config.yml ├── .flake8 ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── Makefile ├── NOTICE ├── README.md ├── docs ├── Makefile ├── api.txt ├── conf.py ├── examples.txt ├── index.txt ├── storage.txt └── yogadl.txt ├── examples ├── custom_data_ref.py ├── mnist.py └── walkthrough.py ├── mypy.ini ├── pyproject.toml ├── readthedocs.yaml ├── requirements.txt ├── setup.py ├── tests ├── __init__.py ├── integration │ ├── __init__.py │ ├── aws │ │ ├── __init__.py │ │ └── test_s3_system.py │ ├── gcp │ │ ├── __init__.py │ │ └── test_gcs_system.py │ ├── local │ │ ├── __init__.py │ │ ├── test_examples.py │ │ └── test_lfs_system.py │ └── util.py ├── performance │ ├── __init__.py │ └── imagenet │ │ ├── __init__.py │ │ ├── resnet_preprocessing.py │ │ └── test_imagenet.py └── unit │ ├── __init__.py │ ├── aws │ ├── __init__.py │ └── test_s3_storage.py │ ├── gcp │ ├── __init__.py │ └── test_gcs_storage.py │ ├── local │ ├── __init__.py │ ├── test_lfs_storage.py │ ├── test_lmdb_access.py │ ├── test_local_lmdb_dataref.py │ ├── test_rw_coordinator.py │ └── test_tensorflow_util.py │ └── util.py └── yogadl ├── __init__.py ├── _core.py ├── _keys_operator.py ├── _lmdb_handler.py ├── constants.py ├── dataref ├── __init__.py └── _local_lmdb_dataref.py ├── rw_coordinator ├── __init__.py ├── _client.py ├── _server.py └── communication_protocol.py ├── storage ├── __init__.py ├── _cloud_storage.py ├── _gcs_storage.py ├── _lfs_storage.py └── _s3_storage.py └── tensorflow.py /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | # Use the latest 2.1 version of CircleCI pipeline process engine. See: https://circleci.com/docs/2.0/configuration-reference 2 | version: 2.1 3 | 4 | orbs: 5 | gcp-cli: circleci/gcp-cli@1.8.4 6 | 7 | commands: 8 | activate-service-account: 9 | steps: 10 | - run: 11 | name: Activate GCP service account 12 | command: | 13 | GOOGLE_APPLICATION_CREDENTIALS=${HOME}/gcloud-service-key.json 14 | echo ${GCLOUD_SERVICE_KEY} > ${GOOGLE_APPLICATION_CREDENTIALS} 15 | echo "export GOOGLE_APPLICATION_CREDENTIALS=\"${GOOGLE_APPLICATION_CREDENTIALS}\"" >> $BASH_ENV 16 | gcloud auth activate-service-account --key-file=${GOOGLE_APPLICATION_CREDENTIALS} 17 | 18 | gcloud --quiet config set project ${GOOGLE_PROJECT_ID} 19 | gcloud --quiet config set compute/zone ${GOOGLE_COMPUTE_ZONE} 20 | 21 | setup-python-venv: 22 | description: Set up and create Python venv. 23 | parameters: 24 | yogadl: 25 | type: boolean 26 | default: false 27 | extras-requires: 28 | type: string 29 | default: "" 30 | extra-requirements-file: 31 | type: string 32 | default: "" 33 | use-pyenv: 34 | type: boolean 35 | default: false 36 | steps: 37 | - when: 38 | condition: << parameters.use-pyenv >> 39 | steps: 40 | - run: pyenv install 3.7.10 41 | - run: pyenv global 3.7.10 42 | - run: pip3 install --upgrade pip wheel setuptools 43 | 44 | # Put all the pip requirements into a single /tmp/requirements.txt file. 45 | - run: echo << parameters.extras-requires >> > /tmp/requirements.txt 46 | - run: cat << parameters.extra-requirements-file >> >> /tmp/requirements.txt 47 | - run: cat /tmp/requirements.txt >> /tmp/cachefile 48 | - when: 49 | condition: << parameters.yogadl >> 50 | steps: 51 | - run: python3 setup.py bdist_wheel -d /tmp 52 | - run: pip3 install /tmp/yogadl*.whl 53 | - run: pip3 install --no-deps --force-reinstall /tmp/yogadl*.whl 54 | - run: pip3 install -r /tmp/requirements.txt 55 | # Useful diagnostics for test failures.
56 | - run: pip3 freeze 57 | 58 | jobs: 59 | lint: 60 | docker: 61 | - image: cimg/python:3.7 62 | steps: 63 | - checkout 64 | - setup-python-venv: 65 | extra-requirements-file: "requirements.txt" 66 | - run: make check 67 | 68 | test: 69 | parameters: 70 | test-target: 71 | type: string 72 | tensorflow-version: 73 | type: string 74 | default: "2.4.1" 75 | gcp: 76 | type: boolean 77 | default: false 78 | machine: 79 | image: ubuntu-2004:202104-01 80 | steps: 81 | - checkout 82 | - setup-python-venv: 83 | yogadl: true 84 | use-pyenv: true 85 | extras-requires: "tensorflow==<< parameters.tensorflow-version >>" 86 | extra-requirements-file: "requirements.txt" 87 | - when: 88 | condition: << parameters.gcp >> 89 | steps: 90 | - gcp-cli/install 91 | - activate-service-account 92 | - run: make << parameters.test-target >> 93 | 94 | workflows: 95 | lint: 96 | jobs: 97 | - lint 98 | 99 | test: 100 | jobs: 101 | - test: 102 | context: aws 103 | matrix: 104 | parameters: 105 | test-target: ["test-integration-aws"] 106 | tensorflow-version: ["1.15.5", "2.4.1"] 107 | 108 | - test: 109 | context: aws 110 | matrix: 111 | parameters: 112 | test-target: ["test-unit-aws"] 113 | 114 | - test: 115 | context: gcp 116 | matrix: 117 | parameters: 118 | test-target: ["test-integration-gcp"] 119 | tensorflow-version: ["1.15.5", "2.4.1"] 120 | gcp: [true] 121 | 122 | - test: 123 | context: gcp 124 | matrix: 125 | parameters: 126 | test-target: ["test-unit-gcp"] 127 | gcp: [true] 128 | 129 | - test: 130 | matrix: 131 | parameters: 132 | test-target: ["test-integration-local"] 133 | tensorflow-version: ["2.4.1"] # Note: local examples don't support 1.15.5 134 | 135 | - test: 136 | matrix: 137 | parameters: 138 | test-target: ["test-unit-local"] 139 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 100 3 | 4 | # We ignore F401 in __init__.py because it is expected for there to be 5 | # "unused imports" when defining a "regular" package. (This file is 6 | # implicitly executed when the package is imported, and the imports would 7 | # be used by the importer.) We ignore patch_saver_restore.py because it includes 8 | # a near-verbatim TensorFlow function with a small patch.
9 | per-file-ignores = __init__.py:F401 patch_saver_restore.py:E111,E114, 10 | 11 | # Explanations for ignored error codes: 12 | # - D1* (no missing docstrings): too much effort to start enforcing 13 | # - D200 (short docstring must fit in one line with quotes): stylistic choice 14 | # - D202 (no blank lines after function docstrings): stylistic choice 15 | # - D203 (blank line before class docstring): stylistic choice 16 | # - D205 (blank line between summary and description): not enforcing single-line summaries 17 | # - D212 (docstring should start on first line): stylistic choice (prefer D213, docstrings start on second line) 18 | # - D4* (docstring content warnings): too much effort to start enforcing 19 | # - E203 (no space before colon): not PEP8-compliant; triggered by Black-formatted code 20 | # - W503 (no line breaks before binary operator): not PEP8-compliant; triggered by Black-formatted code 21 | # - C812-C816 (missing trailing comma): stylistic choice 22 | ignore = D1,D200,D202,D203,D205,D212,D4,E203,W503,C812,C813,C814,C815,C816 23 | 24 | show_source = true 25 | 26 | # flake8-colors 27 | format = ${cyan}%(path)s${reset}:${yellow_bold}%(row)d${reset}:${green_bold}%(col)d${reset}: ${red_bold}%(code)s${reset} %(text)s 28 | 29 | # flake8-docstrings 30 | docstring-convention = google 31 | 32 | # flake8-import-order 33 | application-import-names = yogadl 34 | import-order-style = edited 35 | 36 | # flake8-quotes 37 | inline-quotes = " 38 | multiline-quotes = """ 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | yogadl.egg-info 2 | __pycache__ 3 | .mypy_cache 4 | docs/site 5 | *.sw[op] 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include NOTICE 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: check test 2 | 3 | check: black flake8 mypy 4 | 5 | black: 6 | black --check yogadl tests 7 | 8 | flake8: 9 | flake8 yogadl tests 10 | 11 | mypy: 12 | mypy yogadl tests 13 | 14 | fmt: 15 | black yogadl tests 16 | 17 | TEST_EXPR ?= "" 18 | 19 | test-unit-aws: 20 | pytest -v -k $(TEST_EXPR) tests/unit/aws 21 | 22 | test-unit-gcp: 23 | pytest -v -k $(TEST_EXPR) tests/unit/gcp 24 | 25 | test-unit-local: 26 | pytest -v -k $(TEST_EXPR) tests/unit/local 27 | 28 | test-integration-aws: 29 | pytest -v -k $(TEST_EXPR) tests/integration/aws 30 | 31 | test-integration-gcp: 32 | pytest -v -k $(TEST_EXPR) tests/integration/gcp 33 | 34 | test-integration-local: 35 | pytest -v -k $(TEST_EXPR) tests/integration/local 36 | 37 | test-local: test-unit-local test-integration-local 38 | 39 | test-aws: test-unit-aws test-integration-aws 40 | 41 | test-gcp: test-unit-gcp test-integration-gcp 42 | 43 | test: test-local 44 | 45 | test-all: test-local test-gcp 46 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | YogaDL 2 | Copyright 2020, Determined AI. 3 | 4 | 5 | YogaDL includes derived work from the following: 6 | 7 | Tensorpack 8 | Copyright (c) 2016 Yuxin Wu 9 | 10 | Licensed under the Apache License, Version 2.0 (the "License"); 11 | you may not use this file except in compliance with the License. 12 | You may obtain a copy of the License at 13 | 14 | http://www.apache.org/licenses/LICENSE-2.0 15 | 16 | Unless required by applicable law or agreed to in writing, software 17 | distributed under the License is distributed on an "AS IS" BASIS, 18 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | See the License for the specific language governing permissions and 20 | limitations under the License. 21 | 22 | The derived work can be found in this file: 23 | 24 | - yogadl/_lmdb_handler.py 25 | 26 | 27 | Tensorflow/tpu 28 | Copyright (c) 2017 Tensorflow Authors 29 | 30 | Licensed under the Apache License, Version 2.0 (the "License"); 31 | you may not use this file except in compliance with the License. 32 | You may obtain a copy of the License at 33 | 34 | http://www.apache.org/licenses/LICENSE-2.0 35 | 36 | Unless required by applicable law or agreed to in writing, software 37 | distributed under the License is distributed on an "AS IS" BASIS, 38 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 39 | See the License for the specific language governing permissions and 40 | limitations under the License. 41 | 42 | The derived work can be found in this file: 43 | 44 | - tests/performance/imagenet/test_imagenet.py 45 | - tests/performance/imagenet/resnet_preprocessing.py 46 | 47 | 48 | Tensorflow/docs 49 | Copyright 2018 The TensorFlow Authors. All rights reserved. 50 | 51 | Licensed under the Apache License, Version 2.0 (the "License"); 52 | you may not use this file except in compliance with the License. 
53 | You may obtain a copy of the License at 54 | 55 | http://www.apache.org/licenses/LICENSE-2.0 56 | 57 | Unless required by applicable law or agreed to in writing, software 58 | distributed under the License is distributed on an "AS IS" BASIS, 59 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 60 | See the License for the specific language governing permissions and 61 | limitations under the License. 62 | 63 | The derived work can be found in this file: 64 | 65 | - examples/mnist.py 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Yoga Data Layer: The _Flexible_ Data Layer 2 | 3 | A better approach to data loading for Deep Learning. API-transparent caching to disk, GCS, or S3. 4 | 5 | ## Why `yogadl`? 6 | 7 | At [Determined AI](https://determined.ai), we help many of our customers perform high-performance data 8 | loading for deep learning models. We believe every data loader should have two layers: the 9 | **random-access layer** and the **sequential layer**. 10 | 11 | The **random-access layer** is critical for good training infrastructure. Direct random access to 12 | any record enables: 13 | 14 | - Shuffling (potentially every epoch) 15 | - Pausing/continuing training mid-epoch 16 | - Sharding the dataset efficiently for distributed training 17 | 18 | The **sequential layer** starts as soon as you decide the order in which you will access the records in 19 | the dataset. Often the transition is implicit, in which case it starts as soon as you are done 20 | modifying the order of access (i.e. via shuffling, sharding, or splitting). This layer is vital to 21 | performance optimizations because it enables: 22 | 23 | - Prefetching data loading to hide latency costs 24 | - Parallelizing data loading to hide compute costs 25 | 26 | Here is a simple code snippet to illustrate what the transition from random-access layer to 27 | sequential layer looks like: 28 | 29 | ```python 30 | # Start of random-access layer. 31 | indices = list(range(100)) 32 | indices = indices[skip:] 33 | np.random.shuffle(indices) 34 | 35 | # Start of sequential layer. 36 | 37 | def record_gen(): 38 | for i in indices: 39 | yield read_file_at_index(i) 40 | 41 | record_ds = tf.data.Dataset.from_generator(record_gen, ...) 42 | final_ds = record_ds.prefetch(...) 43 | 44 | ``` 45 | 46 | Notice that in the above example, the `tf.data` API is used, but only in the sequential layer. 47 | This is because `tf.data` has no concept of the random access layer. As a result: 48 | 49 | - `tf.data.Dataset.shuffle()` can only approximate a shuffle. Calling `.shuffle(N)` will read 50 | `N` records into a buffer and choose samples randomly from **those `N` records**, while a true 51 | shuffle chooses samples randomly from the **entire dataset**. This shortcoming forces you 52 | to choose between memory footprint and the quality of your shuffle. The only way to get a true 53 | shuffle with `tf.data.Dataset.shuffle()` is to read the entire dataset into memory. 54 | - `tf.data.Dataset.skip(N)` is as inefficient as possible. Each of the `N` skipped records will 55 | still be read from disk and processed normally, according to all of the operations preceding 56 | the `.skip()` call, making `.skip()` prohibitively expensive for most use cases. 57 | - Pausing and continuing training is only possible by saving the state of a `tf.data.Iterator`.
58 | However, saving a `tf.data.Iterator` does not work with all datasets. In particular, it does 59 | not work with datasets created using `from_generator()`, which is the easiest way to create a 60 | `tf.data.Dataset`. 61 | 62 | We have seen countless instances where `tf.data.Dataset` shortcomings have made life harder for 63 | deep learning practitioners, so we set out to build something better. We set out to build a new 64 | data layer which could augment an existing `tf.data.Dataset` data loader with the properties that should 65 | come standard with every data loader. 66 | 67 | At the same time, we wanted this new data layer to relieve another key pain point: high-performance 68 | dataset caching and dataset versioning. 69 | 70 | ## What is `yogadl`? 71 | 72 | We designed `yogadl` to be two things: a standalone caching layer to imbue existing data loaders 73 | with the properties that come from a random-access layer, and a better interface for defining data 74 | loaders in general. 75 | 76 | ### A standalone caching tool 77 | 78 | Since `tf.data.Dataset`-based datasets have no random-access layer, `yogadl` caches them to disk in 79 | a random-access-friendly way. The storage mechanism is, in fact, nearly identical to how 80 | [TensorPack caches datasets to disk](https://tensorpack.readthedocs.io/modules/dataflow.html#tensorpack.dataflow.LMDBSerializer), 81 | only with some additional abstractions to allow dataset versioning, cloud storage, and all of the 82 | wonderful features that a data loader with a random-access layer ought to have. 83 | 84 | What does all this do for you? A few things: 85 | 86 | - **Better training**: A `yogadl`-cached `tf.data.Dataset` will have better shuffling than a 87 | native `tf.data.Dataset`. Additionally, pausing and continuing training mid-epoch will be 88 | simple and robust, and efficient sharding for distributed training comes standard. 89 | - **Faster data loading**: Slow data loader? Don't waste your time optimizing it. `yogadl` will 90 | save it in a high-performance cache the first time it is used, and all future uses will be 91 | fast and efficient. 92 | - **API-transparent**: Not all operations in the data loader are cacheable. Data augmentation 93 | must be done at run time. `yogadl` allows you to keep your existing data augmentation code. 94 | 95 | ### A better interface 96 | 97 | At the core of `yogadl` is the `DataRef` interface, which creates an explicit boundary between the 98 | random-access layer and the sequential layer. 99 | 100 | We are not the first people to think of this: PyTorch separates the `DataSet` (the random-access 101 | layer) from the `Sampler` (which defines the sequential layer). Keras has a `Sequence` object 102 | which defines the random-access layer, leaving the order of access (the sequential layer) to be 103 | decided by the arguments to `model.fit()`. Both `DataSet` and `Sequence` are already 100% 104 | compatible with `yogadl`'s `DataRef` interface (although `yogadl` does not yet include those 105 | adapters). 106 | 107 | And yet, the world is still full of data loaders which are lacking. At Determined AI, we are 108 | dedicated to advancing the state of the art for training Deep Learning models, and we believe that 109 | a better interface for data loading is a critical piece of that goal. Any data loader which 110 | implements the `DataRef` interface is capable of proper shuffling, pausing and continuing training 111 | mid-epoch, and efficient multi-machine distributed training.
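To make that boundary concrete, here is a minimal sketch of reading back a cached dataset. It reuses the `LFSStorage`, `DataRef.stream()`, and `make_tf_dataset()` calls from `examples/walkthrough.py`; the cache path, dataset name, and version are illustrative, and it assumes the dataset was previously saved with `storage.submit()`:

```python
import yogadl.storage
import yogadl.tensorflow

# Random-access layer: fetch a DataRef and decide the order of access.
config = yogadl.storage.LFSConfigurations("/tmp/yogadl_cache")  # illustrative cache path
storage = yogadl.storage.LFSStorage(config)
dataref = storage.fetch("my_dataset", "1.0")  # illustrative name and version
stream = dataref.stream(start_offset=0, shuffle=True, shuffle_seed=777)

# Sequential layer: everything from here on is ordinary tf.data code.
records = yogadl.tensorflow.make_tf_dataset(stream)
final_ds = records.batch(32).prefetch(1)
```

Shuffling, sharding, and mid-epoch continuation are decided entirely in the `stream()` call, while batching and prefetching stay in the familiar `tf.data` pipeline.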
112 | 113 | ## What is `yogadl` _not_? 114 | 115 | `yogadl` is not a data manipulation API. 116 | [This](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) 117 | [world](https://tensorpack.readthedocs.io/tutorial/dataflow.html) 118 | [has](https://keras.io/preprocessing/image/) 119 | [more](https://pytorch.org/docs/stable/torchvision/ops.html) 120 | [than](https://numpy.org/) 121 | [enough](https://pandas.pydata.org/) 122 | [of](https://docs.nvidia.com/deeplearning/sdk/dali-developer-guide/docs/index.html) 123 | [those](https://opencv-python-tutroals.readthedocs.io/en/latest/). 124 | Instead, `yogadl` seeks to be API-transparent so that you can continue to use your existing data 125 | loading code, but with all the benefits of a high-performance, random-access cache. If you have 126 | data augmentation steps which cannot be cached, that code should continue to work without any 127 | modifications. 128 | 129 | `yogadl` does not (at this time) work with any data frameworks other than `tf.data.Dataset.` 130 | First-class support for (tf.)Keras `Sequence` objects, PyTorch `DataSet` objects, and TensorPack 131 | `DataFlow` objects is on the near-term roadmap. 132 | 133 | `yogadl` offers basic dataset versioning, but it is not (at this time) a full-blown version control 134 | for datasets. Offering something like version control for datasets is on the roadmap as well. 135 | 136 | ## Installing `yogadl` 137 | 138 | `yogadl` can be installed via `pip install yogadl`. 139 | 140 | ## Further Information 141 | 142 | Please refer to the following links for more information: 143 | - [YogaDL official documentation](https://yogadl.readthedocs.io/) 144 | - [YogaDL examples](https://yogadl.readthedocs.io/en/latest/examples.html) 145 | - [YogaDL API Reference](https://yogadl.readthedocs.io/en/latest/api.html) 146 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXOPTS = -W 2 | SPHINXBUILD = sphinx-build 3 | 4 | .PHONY: build 5 | build: sp-html 6 | 7 | .PHONY: clean 8 | clean: 9 | rm -rf site 10 | 11 | live: 12 | npx nodemon --ext txt --exec "$(MAKE) build" --ignore site 13 | 14 | # Catch-all target: route all unknown targets to Sphinx using the new 15 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 16 | .PHONY: ALWAYS 17 | sp-%: ALWAYS 18 | @$(SPHINXBUILD) -M $* . site $(SPHINXOPTS) $(O) 19 | -------------------------------------------------------------------------------- /docs/api.txt: -------------------------------------------------------------------------------- 1 | API Reference 2 | ============= 3 | 4 | .. toctree:: 5 | :maxdepth: 1 6 | :name: api_reference 7 | 8 | yogadl 9 | storage 10 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # http://www.sphinx-doc.org/en/master/config 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | 13 | import os 14 | import pathlib 15 | import sys 16 | 17 | import determined_ai_sphinx_theme 18 | 19 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "YogaDL" 24 | html_title = "YogaDL Documentation" 25 | copyright = "2020, Determined AI" 26 | author = "hello@determined.ai" 27 | 28 | # The version info for the project you"re documenting, acts as replacement for 29 | # |version| and |release|, also used in various other places throughout the 30 | # built documents. 31 | # 32 | # The short X.Y version. 33 | version = "0.1" 34 | 35 | # The full version, including alpha/beta/rc tags. 36 | release = version 37 | 38 | # -- General configuration --------------------------------------------------- 39 | 40 | # Add any Sphinx extension module names here, as strings. They can be 41 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 42 | # ones. 43 | extensions = [ 44 | "sphinx.ext.autodoc", 45 | "sphinx.ext.extlinks", 46 | "sphinx.ext.intersphinx", 47 | "sphinx.ext.mathjax", 48 | "sphinx.ext.napoleon", 49 | "sphinxarg.ext", 50 | # "sphinx_gallery.gen_gallery", 51 | "sphinx_copybutton", 52 | "m2r", 53 | ] 54 | 55 | autosummary_generate = True 56 | autoclass_content = "class" 57 | 58 | # List of patterns, relative to source directory, that match files and 59 | # directories to ignore when looking for source files. 60 | # This pattern also effect to html_static_path and html_extra_path 61 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store", "examples", "requirements.txt"] 62 | 63 | # The suffix of source filenames. 64 | source_suffix = {".rst": "restructuredtext", ".txt": "restructuredtext"} 65 | 66 | highlight_language = "none" 67 | 68 | # -- Options for HTML output ------------------------------------------------- 69 | 70 | # The theme to use for HTML and HTML Help pages. See the documentation for 71 | # a list of builtin themes. 72 | # 73 | 74 | # Add any paths that contain custom static files (such as style sheets) here, 75 | # relative to this directory. They are copied after the builtin static files, 76 | # so a file named 'default.css' will overwrite the builtin 'default.css'. 77 | # html_static_path = ["_static"] 78 | 79 | # -- HTML theme settings ------------------------------------------------ 80 | 81 | html_show_sourcelink = False 82 | html_show_sphinx = False 83 | html_last_updated_fmt = None 84 | # html_sidebars = {"**": ["logo-text.html", "globaltoc.html", "localtoc.html", "searchbox.html"]} 85 | 86 | html_theme_path = [determined_ai_sphinx_theme.get_html_theme_path()] 87 | html_theme = "determined_ai_sphinx_theme" 88 | # html_logo = "assets/images/logo.png" 89 | # html_favicon = "assets/images/favicon.ico" 90 | 91 | html_theme_options = { 92 | "analytics_id": "UA-110089850-1", 93 | "collapse_navigation": False, 94 | "display_version": True, 95 | "logo_only": False, 96 | } 97 | 98 | language = "en" 99 | 100 | todo_include_todos = True 101 | 102 | html_use_index = True 103 | html_domain_indices = True 104 | 105 | # -- Sphinx Gallery settings ------------------------------------------- 106 | 107 | sphinx_gallery_conf = { 108 | # Subsections are sorted by number of code lines per example. Override this 109 | # to sort via the explicit ordering. 
110 | # "within_subsection_order": CustomOrdering, 111 | # "download_all_examples": True, 112 | "plot_gallery": False, 113 | "min_reported_time": float("inf"), 114 | } 115 | -------------------------------------------------------------------------------- /docs/examples.txt: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | The following examples will walk you through the core concepts of YogaDL: 5 | storing, fetching, and streaming datasets. 6 | 7 | 8 | Creating a yogadl.Storage 9 | ^^^^^^^^^^^^^^^^^^^^^^^^^ 10 | 11 | Most users will interact with :class:`yogadl.Storage` object as a mechanism for 12 | storing and fetching datasets. The simplest ``Storage`` is the 13 | :class:`yogadl.storage.LFSStorage`, or "local filesystem storage". Let's create 14 | one: 15 | 16 | .. literalinclude:: ../examples/walkthrough.py 17 | :language: python 18 | :start-after: START creating a yogadl.Storage 19 | :end-before: END creating a yogadl.Storage 20 | 21 | YogaDL also comes with built-in support for GCS via 22 | :class:`yogadl.storage.GCSStorage` and for S3 via 23 | :class:`yogadl.storage.S3Storage`. 24 | 25 | Storing a dataset 26 | ^^^^^^^^^^^^^^^^^ 27 | 28 | Let's create a silly 10-record dataset and store it in the ``yogadl.Storage``. 29 | This is done via ``storage.submit()``. During ``storage.submit()``, the entire 30 | dataset will be read and written to the storage backend (in this case, to a 31 | file). 32 | 33 | .. literalinclude:: ../examples/walkthrough.py 34 | :language: python 35 | :start-after: START storing a dataset 36 | :end-before: END storing a dataset 37 | 38 | 39 | Fetching a dataset 40 | ^^^^^^^^^^^^^^^^^^ 41 | 42 | Later (possibly in a different process), you can fetch a 43 | :class:`yogadl.DataRef` representing the dataset via ``storage.fetch()``. 44 | 45 | A ``DataRef`` is just a reference to a dataset. In this case, the dataset will 46 | be stored in a file on your computer, but a ``DataRef`` could just as easily 47 | refer to a dataset on some remote machine; the interface would be the same. 48 | 49 | To actually access the dataset, you need to first call ``dataref.stream()``, 50 | which will return a :class:`yogadl.Stream` object. Then you can convert the 51 | ``Stream`` object to a framework-native data loader format (currently only 52 | ``tf.data.Dataset`` is supported). 53 | 54 | .. literalinclude:: ../examples/walkthrough.py 55 | :language: python 56 | :start-after: START fetching a dataset 57 | :end-before: END fetching a dataset 58 | 59 | This should print: 60 | 61 | .. code:: 62 | 63 | tf.Tensor([5 1 9 6 7], shape=(5,), dtype=int64) 64 | tf.Tensor([1 7 3 9 8], shape=(5,), dtype=int64) 65 | tf.Tensor([2 6 0 4 5], shape=(5,), dtype=int64) 66 | tf.Tensor([9 5 3 0 8], shape=(5,), dtype=int64) 67 | tf.Tensor([6 7 4 1 2], shape=(5,), dtype=int64) 68 | 69 | Notice that: 70 | 71 | - The start_offset is only applied to the first epoch, so in this example 72 | .repeat(3) gave us 2.5 epochs of data since we skipped the first epoch. 73 | - The shuffle is a true shuffle. The shuffled stream samples from the whole 74 | dataset without any concept of a "buffer", as with 75 | ``tf.data.Dataset.shuffle()`` 76 | - The shuffle is reproducible because we chose a shuffle seed. 77 | - Each epoch is reshuffled. 78 | 79 | 80 | Can I get the same features in fewer steps? 81 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 82 | 83 | As a matter of fact, you can! 
In order to support the common use-case of 84 | running the same dataset through many different models during model development 85 | or hyperparameter search, you can use the ``storage.cacheable()`` decorator to 86 | decorate a function that returns a dataset. 87 | 88 | When the decorated function is called the first time, it will run one time and 89 | save its output to ``storage``. On subsequent calls, the original function 90 | will not run, but its cached output will be returned instead. 91 | 92 | In this way, you can get the benefit of caching within a single script and with 93 | only a single call against the ``storage`` object: 94 | 95 | .. literalinclude:: ../examples/walkthrough.py 96 | :language: python 97 | :start-after: START can I get the same features in fewer steps? 98 | :end-before: END can I get the same features in fewer steps? 99 | 100 | The ``storage.cacheable()`` decorator is multi-processing safe, so if two 101 | identical processes are configured to use the same storage, only one of them 102 | will create and save the dataset. The other one will wait for the dataset to 103 | be saved and will then read the dataset from the cache. 104 | 105 | 106 | End-to-end training example: 107 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 108 | 109 | Here is an example of how you might use YogaDL to train on the second half of 110 | an MNIST dataset. This illustrates the ability to continue training mid-dataset 111 | that is simply not natively possible with tf.keras. Without YogaDL, you could 112 | imitate this behavior using ``tf.data.Dataset.skip(N)``, but that is 113 | prohibitively expensive for large values of ``N``. 114 | 115 | .. note:: 116 | 117 | MNIST is such a small dataset that YogaDL is not going to outperform 118 | any example that treats MNIST as an in-memory dataset. 119 | 120 | .. literalinclude:: ../examples/mnist.py 121 | :language: python 122 | :start-after: INCLUDE IN DOCS 123 | 124 | 125 | Advanced Use Case: Distributed Training 126 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 127 | 128 | Sharding a dataset for use with distributed training is easy. If you are using 129 | Horovod for distributed training, you only need to alter the arguments of your 130 | call to ``DataRef.stream()``. 131 | 132 | .. code:: python 133 | 134 | import horovod.tensorflow as hvd 135 | 136 | ... 137 | 138 | stream = dataref.stream( 139 | shard_rank=hvd.rank(), num_shards=hvd.size() 140 | ) 141 | 142 | 143 | Advanced Use Case: Custom DataRef Objects 144 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 145 | 146 | If you have an advanced use case, like generating data on an external machine 147 | and streaming it to another machine for training or something, and you would 148 | like to integrate with a platform that allows you to submit your dataset as a 149 | ``yogadl.DataRef``, you can implement a custom :class:`yogadl.DataRef`. By 150 | implementing the ``yogadl.DataRef`` interface, you can fully customize the 151 | behavior of how the platform interacts with your dataset. Here is a toy example 152 | of what that might look like: 153 | 154 | .. literalinclude:: ../examples/custom_data_ref.py 155 | :language: python 156 | :start-after: INCLUDE IN DOCS 157 | -------------------------------------------------------------------------------- /docs/index.txt: -------------------------------------------------------------------------------- 1 | .. toctree:: 2 | :hidden: 3 | :maxdepth: 2 4 | 5 | api 6 | examples 7 | 8 | ..
mdinclude:: ../README.md 9 | -------------------------------------------------------------------------------- /docs/storage.txt: -------------------------------------------------------------------------------- 1 | yogadl.storage 2 | ============== 3 | 4 | .. autoclass:: yogadl.storage.LFSConfigurations 5 | :noindex: 6 | 7 | .. autoclass:: yogadl.storage.LFSStorage 8 | :noindex: 9 | :members: 10 | 11 | .. autoclass:: yogadl.storage.S3Configurations 12 | :noindex: 13 | 14 | .. autoclass:: yogadl.storage.S3Storage 15 | :noindex: 16 | :inherited-members: 17 | 18 | .. autoclass:: yogadl.storage.GCSConfigurations 19 | :noindex: 20 | 21 | .. autoclass:: yogadl.storage.GCSStorage 22 | :noindex: 23 | :inherited-members: 24 | -------------------------------------------------------------------------------- /docs/yogadl.txt: -------------------------------------------------------------------------------- 1 | yogadl 2 | ====== 3 | 4 | .. autoclass:: yogadl.Stream 5 | :noindex: 6 | :members: __iter__, __len__ 7 | 8 | .. autoclass:: yogadl.DataRef 9 | :noindex: 10 | :members: stream, __len__ 11 | 12 | .. autoclass:: yogadl.Storage 13 | :noindex: 14 | :members: 15 | -------------------------------------------------------------------------------- /examples/custom_data_ref.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | If you have an advanced use case, like generating data on an external machine 17 | and streaming it to another machine for training or something, and you would 18 | like to integrate with a platform that allows you to submit your dataset as a 19 | ``yogadl.DataRef``, you can implement a custom ``yogadl.DataRef``. By 20 | implementing the ``yogadl.DataRef`` interface, you can fully customize the 21 | behavior of how the platform interacts with your dataset. Here is a toy example 22 | of what that might look like: 23 | """ 24 | 25 | # INCLUDE IN DOCS 26 | import os 27 | import yogadl 28 | import yogadl.tensorflow 29 | import tensorflow as tf 30 | 31 | class RandomDataRef(yogadl.DataRef): 32 | """ 33 | A DataRef to a a non-reproducible dataset that just produces random 34 | int32 values. 35 | """ 36 | 37 | def __len__(self): 38 | return 10 39 | 40 | def stream( 41 | self, 42 | start_offset = 0, 43 | shuffle = False, 44 | skip_shuffle_at_epoch_end = False, 45 | shuffle_seed = None, 46 | shard_rank = 0, 47 | num_shards = 1, 48 | drop_shard_remainder = False, 49 | ) -> yogadl.Stream: 50 | """ 51 | For custom DataRefs, .stream() will often be a pretty beefy 52 | function. This example simplifies it by assuming that the dataset 53 | is non-reproducible, meaning that shuffle and shuffle_seed 54 | arguments are meaningless, and the shard_rank is only used to 55 | determine how many records will be yielded during each epoch. 
56 | """ 57 | 58 | first_epoch = True 59 | 60 | def iterator_fn(): 61 | nonlocal first_epoch 62 | if first_epoch: 63 | first_epoch = False 64 | start = start_offset + shard_rank 65 | else: 66 | start = shard_rank 67 | 68 | if drop_shard_remainder: 69 | end = len(self) - (len(self) % num_shards) 70 | else: 71 | end = len(self) 72 | 73 | for _ in range(start, end, num_shards): 74 | # Make a uint32 out of 4 random bytes 75 | r = os.urandom(4) 76 | yield r[0] + (r[1] << 8) + (r[2] << 16) + (r[3] << 24) 77 | 78 | # Since we will later convert to tf.data.Dataset, 79 | # we will supply output_types and shapes. 80 | return yogadl.Stream( 81 | iterator_fn, 82 | len(self), 83 | output_types=tf.uint32, 84 | output_shapes=tf.TensorShape([]) 85 | ) 86 | 87 | dataref = RandomDataRef() 88 | stream = dataref.stream() 89 | records = yogadl.tensorflow.make_tf_dataset(stream) 90 | batches = records.batch(5) 91 | for batch in batches: 92 | print(batch) 93 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # `normalize_img()` and the model definition are derived from the TensorFlow 17 | # documentation: https://www.tensorflow.org/datasets/keras_example 18 | # 19 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 20 | # 21 | # Licensed under the Apache License, Version 2.0 (the "License"); 22 | # you may not use this file except in compliance with the License. 23 | # You may obtain a copy of the License at 24 | # 25 | # http://www.apache.org/licenses/LICENSE-2.0 26 | # 27 | # Unless required by applicable law or agreed to in writing, software 28 | # distributed under the License is distributed on an "AS IS" BASIS, 29 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 30 | # See the License for the specific language governing permissions and 31 | # limitations under the License. 32 | # ============================================================================== 33 | """ 34 | End-to-end training example: 35 | 36 | Here is an example of how you might use YogaDL to train on the second half of 37 | an MNIST dataset. This illustrates the ability to continue training mid-dataset 38 | that is simply not natively possible with tf.keras. Without YogaDL, you could 39 | imitate this behavior using tf.data.Dataset.skip(N), but that is 40 | prohibitively expensive for large values of N. 41 | """ 42 | 43 | # INCLUDE IN DOCS 44 | import math 45 | import os 46 | import tensorflow as tf 47 | import tensorflow_datasets as tfds 48 | import yogadl 49 | import yogadl.tensorflow 50 | import yogadl.storage 51 | 52 | BATCH_SIZE = 32 53 | 54 | # Configure the yogadl storage. 
55 | storage_path = "/tmp/yogadl_cache" 56 | os.makedirs(storage_path, exist_ok=True) 57 | lfs_config = yogadl.storage.LFSConfigurations(storage_path) 58 | storage = yogadl.storage.LFSStorage(lfs_config) 59 | 60 | @storage.cacheable("mnist", "1.0") 61 | def make_data(): 62 | mnist = tfds.image.MNIST() 63 | mnist.download_and_prepare() 64 | dataset = mnist.as_dataset(as_supervised=True)["train"] 65 | 66 | # Apply dataset transformations from the TensorFlow docs: 67 | # (https://www.tensorflow.org/datasets/keras_example) 68 | 69 | def normalize_img(image, label): 70 | """Normalizes images: `uint8` -> `float32`.""" 71 | return tf.cast(image, tf.float32) / 255., label 72 | 73 | return dataset.map(normalize_img) 74 | 75 | # Get the DataRef from the storage via the decorated function. 76 | dataref = make_data() 77 | 78 | # Stream the dataset starting halfway through it. 79 | num_batches = math.ceil(len(dataref) / BATCH_SIZE) 80 | batches_to_skip = num_batches // 2 81 | records_to_skip = batches_to_skip * BATCH_SIZE 82 | stream = dataref.stream( 83 | start_offset=records_to_skip, shuffle=True, shuffle_seed=777 84 | ) 85 | 86 | # Convert the stream to a tf.data.Dataset object. 87 | dataset = yogadl.tensorflow.make_tf_dataset(stream) 88 | 89 | # Apply normal data augmentation and prefetch steps. 90 | dataset = dataset.batch(BATCH_SIZE) 91 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 92 | 93 | # Model is straight from the TensorFlow docs: 94 | # https://www.tensorflow.org/datasets/keras_example 95 | model = tf.keras.models.Sequential([ 96 | tf.keras.layers.Flatten(input_shape=(28, 28, 1)), 97 | tf.keras.layers.Dense(128,activation='relu'), 98 | tf.keras.layers.Dense(10, activation='softmax') 99 | ]) 100 | model.compile( 101 | loss='sparse_categorical_crossentropy', 102 | optimizer=tf.keras.optimizers.Adam(0.001), 103 | metrics=['accuracy'], 104 | ) 105 | model.fit(dataset) 106 | -------------------------------------------------------------------------------- /examples/walkthrough.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | This file contains the code snippets which are used in the Examples section of 17 | the documentation. It is compiled in one place here for testing purposes. 18 | 19 | The START and END comments denote sections of this file which will become code 20 | snippets in the documentation. 21 | """ 22 | 23 | # START creating a yogadl.Storage 24 | import os 25 | import yogadl 26 | import yogadl.storage 27 | 28 | # Create a yogadl.Storage object backed by the local filesystem. 
29 | storage_path = "/tmp/yogadl_cache" 30 | os.makedirs(storage_path, exist_ok=True) 31 | lfs_config = yogadl.storage.LFSConfigurations(storage_path) 32 | storage = yogadl.storage.LFSStorage(lfs_config) 33 | # END creating a yogadl.Storage 34 | 35 | 36 | # START storing a dataset 37 | import tensorflow as tf 38 | 39 | # Create a dataset we can store. 40 | records = tf.data.Dataset.range(10) 41 | 42 | # Store this dataset as "range" version "1.0". 43 | storage.submit(records, "range", "1.0") 44 | # END storing a dataset 45 | 46 | 47 | # START fetching a dataset 48 | import yogadl.tensorflow 49 | 50 | # Get the DataRef. 51 | dataref = storage.fetch("range", "1.0") 52 | 53 | # Tell the DataRef how to stream the dataset. 54 | stream = dataref.stream(start_offset=5, shuffle=True, shuffle_seed=777) 55 | 56 | # Interpret the stream as a tensorflow dataset 57 | records = yogadl.tensorflow.make_tf_dataset(stream) 58 | 59 | # It's a real tf.data.Dataset; you can use normal tf.data operations on it. 60 | batches = records.repeat(3).batch(5) 61 | 62 | # (this part requires TensorFlow >= 2.0) 63 | for batch in batches: 64 | print(batch) 65 | # END fetching a dataset 66 | 67 | 68 | # START can I get the same features in fewer steps? 69 | @storage.cacheable("range", "2.0") 70 | def make_records(): 71 | print("Cache not found, making range v2 dataset...") 72 | records = tf.data.Dataset.range(10).map(lambda x: 2*x) 73 | return records 74 | 75 | # Follow the same steps as before. 76 | dataref = make_records() 77 | stream = dataref.stream() 78 | records = yogadl.tensorflow.make_tf_dataset(stream) 79 | batches = records.repeat(3).batch(5) 80 | 81 | for batch in batches: 82 | print(batch) 83 | # END can I get the same features in fewer steps? 84 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.6 3 | follow_imports = silent 4 | ignore_missing_imports = True 5 | 6 | # All strict checks. 7 | check_untyped_defs = True 8 | disallow_incomplete_defs = True 9 | disallow_subclassing_any = True 10 | disallow_untyped_calls = True 11 | disallow_untyped_decorators = True 12 | disallow_untyped_defs = True 13 | no_implicit_optional = True 14 | strict_equality = True 15 | warn_redundant_casts = True 16 | warn_return_any = True 17 | warn_unused_configs = True 18 | warn_unused_ignores = True 19 | 20 | [mypy-tests.performance.imagenet.resnet_preprocessing] 21 | ignore_errors = True 22 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 100 3 | -------------------------------------------------------------------------------- /readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | # Build documentation in the docs/ directory with Sphinx 7 | sphinx: 8 | configuration: docs/conf.py 9 | 10 | python: 11 | version: 3.7 12 | install: 13 | - requirements: requirements.txt 14 | - method: pip 15 | path: . 
16 | extra_requirements: 17 | - tf 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | flake8 3 | mypy 4 | pytest 5 | tensorflow_datasets<2 6 | tl.testing 7 | pytest 8 | 9 | docutils==0.15.2 10 | sphinx==2.4.4 11 | sphinx-argparse>=0.2.5 12 | git+https://github.com/determined-ai/determined_sphinx_theme.git@accf3fbb639c88deb2e60293bdd15b3cae3afc01 13 | sphinx-gallery>=0.6.1 14 | pillow 15 | sphinx-copybutton 16 | m2r 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | #!/usr/bin/env python3 16 | import pathlib 17 | from setuptools import find_packages, setup 18 | 19 | HERE = pathlib.Path(__file__).parent 20 | README = (HERE / "README.md").read_text() 21 | 22 | setup( 23 | name="yogadl", 24 | version="0.1.4", 25 | author="Determined AI", 26 | author_email="hello@determined.ai", 27 | url="https://www.github.com/determined-ai/yogadl/", 28 | description="Yoga Data Layer, a flexible data layer for machine learning", 29 | license="Apache License 2.0", 30 | classifiers=["License :: OSI Approved :: Apache Software License"], 31 | long_description=README, 32 | long_description_content_type="text/markdown", 33 | packages=find_packages(exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 34 | python_requires=">=3.6.0", 35 | install_requires=[ 36 | "async_generator", 37 | "boto3", 38 | "filelock", 39 | "google-cloud-storage", 40 | "lmdb", 41 | "lomond", 42 | # We use the ConnectionClosedError not present in version 7. 
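        # (websockets 8.0 introduced ConnectionClosedOK / ConnectionClosedError,
        #  which is why the floor below is 8.0.)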
43 | "websockets>=8.0", 44 | ], 45 | extras_require={"tf": ["tensorflow"]}, 46 | zip_safe=False, 47 | include_package_data=True, 48 | ) 49 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/integration/aws/__init__.py -------------------------------------------------------------------------------- /tests/integration/aws/test_s3_system.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import pytest 16 | from tl.testing import thread 17 | 18 | import tests.integration.util as util # noqa: I202, I100 19 | 20 | from yogadl import dataref, storage, tensorflow 21 | 22 | 23 | def create_s3_configuration(access_server_port: int) -> storage.S3Configurations: 24 | return storage.S3Configurations( 25 | bucket="yogadl-test", 26 | bucket_directory_path="integration-tests", 27 | url=f"ws://localhost:{access_server_port}", 28 | local_cache_dir="/tmp/", 29 | ) 30 | 31 | 32 | def worker_using_cacheable( 33 | config: storage.S3Configurations, dataset_id: str, dataset_version: str 34 | ) -> None: 35 | s3_storage = storage.S3Storage(configurations=config) 36 | 37 | @s3_storage.cacheable(dataset_id=dataset_id, dataset_version=dataset_version) 38 | def make_dataset() -> dataref.LMDBDataRef: 39 | return util.make_mnist_test_dataset() # type: ignore 40 | 41 | stream_from_cache = make_dataset().stream() 42 | dataset_from_stream = tensorflow.make_tf_dataset(stream_from_cache) 43 | original_dataset = util.make_mnist_test_dataset() 44 | 45 | data_samples = util.compare_datasets(original_dataset, dataset_from_stream) 46 | assert data_samples == 10000 47 | assert stream_from_cache.length == data_samples 48 | 49 | 50 | @pytest.mark.gcp 51 | def test_mnist_single_threaded() -> None: 52 | dataset_id = "mnist" 53 | dataset_version = "1" 54 | config = create_s3_configuration(access_server_port=29243) 55 | 56 | util.cleanup_s3_storage( 57 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 58 | ) 59 | 60 | access_server_handler = util.AccessServerHandler(hostname="localhost", port=29243) 61 | access_server_handler.run_server_in_thread() 62 | 63 | try: 64 | worker_using_cacheable( 65 | config=config, dataset_id=dataset_id, dataset_version=dataset_version 66 | ) 67 | finally: 68 | access_server_handler.stop_server() 69 | util.cleanup_s3_storage( 70 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 71 | ) 72 | 73 | 74 | class MultiThreadedTests(thread.ThreadAwareTestCase): # type: ignore 75 | @pytest.mark.gcp 76 | def test_mnist_multi_threaded(self) -> None: 77 | dataset_id = "mnist" 78 | dataset_version = "1" 79 | num_threads = 4 80 | 81 | config = create_s3_configuration(access_server_port=29243) 82 | 83 | util.cleanup_s3_storage( 84 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 85 | ) 86 | 87 | access_server_handler = util.AccessServerHandler(hostname="localhost", port=29243) 88 | access_server_handler.run_server_in_thread() 89 | 90 | try: 91 | with thread.ThreadJoiner(60): 92 | for _ in range(num_threads): 93 | self.run_in_thread( 94 | lambda: worker_using_cacheable( 95 | config=config, dataset_id=dataset_id, dataset_version=dataset_version 96 | ) 97 | ) 98 | finally: 99 | access_server_handler.stop_server() 100 | util.cleanup_s3_storage( 101 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 102 | ) 103 | -------------------------------------------------------------------------------- /tests/integration/gcp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/integration/gcp/__init__.py -------------------------------------------------------------------------------- /tests/integration/gcp/test_gcs_system.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import pytest 16 | from tl.testing import thread 17 | 18 | import tests.integration.util as util # noqa: I202, I100 19 | 20 | from yogadl import dataref, storage, tensorflow 21 | 22 | 23 | def create_gcs_configuration(access_server_port: int) -> storage.GCSConfigurations: 24 | return storage.GCSConfigurations( 25 | bucket="yogadl-test", 26 | bucket_directory_path="integration-tests", 27 | url=f"ws://localhost:{access_server_port}", 28 | local_cache_dir="/tmp/", 29 | ) 30 | 31 | 32 | def worker_using_cacheable( 33 | config: storage.GCSConfigurations, dataset_id: str, dataset_version: str 34 | ) -> None: 35 | gcs_storage = storage.GCSStorage(configurations=config) 36 | 37 | @gcs_storage.cacheable(dataset_id=dataset_id, dataset_version=dataset_version) 38 | def make_dataset() -> dataref.LMDBDataRef: 39 | return util.make_mnist_test_dataset() # type: ignore 40 | 41 | stream_from_cache = make_dataset().stream() 42 | dataset_from_stream = tensorflow.make_tf_dataset(stream_from_cache) 43 | original_dataset = util.make_mnist_test_dataset() 44 | 45 | data_samples = util.compare_datasets(original_dataset, dataset_from_stream) 46 | assert data_samples == 10000 47 | assert stream_from_cache.length == data_samples 48 | 49 | 50 | @pytest.mark.gcp 51 | def test_mnist_single_threaded() -> None: 52 | dataset_id = "mnist" 53 | dataset_version = "1" 54 | config = create_gcs_configuration(access_server_port=29243) 55 | 56 | util.cleanup_gcs_storage( 57 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 58 | ) 59 | 60 | access_server_handler = util.AccessServerHandler(hostname="localhost", port=29243) 61 | access_server_handler.run_server_in_thread() 62 | 63 | try: 64 | worker_using_cacheable( 65 | config=config, dataset_id=dataset_id, dataset_version=dataset_version 66 | ) 67 | finally: 68 | access_server_handler.stop_server() 69 | util.cleanup_gcs_storage( 70 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 71 | ) 72 | 73 | 74 | class MultiThreadedTests(thread.ThreadAwareTestCase): # type: ignore 75 | @pytest.mark.gcp 76 | def test_mnist_multi_threaded(self) -> None: 77 | dataset_id = "mnist" 78 | dataset_version = "1" 79 | num_threads = 4 80 | 81 | config = create_gcs_configuration(access_server_port=29243) 82 | 83 | util.cleanup_gcs_storage( 84 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 85 | ) 86 | 87 | access_server_handler = util.AccessServerHandler(hostname="localhost", port=29243) 88 | access_server_handler.run_server_in_thread() 89 | 90 | try: 91 | with thread.ThreadJoiner(60): 92 | for _ in range(num_threads): 93 | self.run_in_thread( 94 | lambda: worker_using_cacheable( 95 | config=config, 
dataset_id=dataset_id, dataset_version=dataset_version 96 | ) 97 | ) 98 | finally: 99 | access_server_handler.stop_server() 100 | util.cleanup_gcs_storage( 101 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 102 | ) 103 | -------------------------------------------------------------------------------- /tests/integration/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/integration/local/__init__.py -------------------------------------------------------------------------------- /tests/integration/local/test_examples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import pathlib 16 | import runpy 17 | 18 | 19 | def examples_dir() -> pathlib.Path: 20 | here = pathlib.Path(__file__).parent 21 | return here.parent.parent.parent.joinpath("examples") 22 | 23 | 24 | def test_walkthrough() -> None: 25 | runpy.run_path(str(examples_dir().joinpath("walkthrough.py"))) 26 | 27 | 28 | def test_mnist() -> None: 29 | runpy.run_path(str(examples_dir().joinpath("mnist.py"))) 30 | 31 | 32 | def test_custom_data_ref() -> None: 33 | runpy.run_path(str(examples_dir().joinpath("custom_data_ref.py"))) 34 | -------------------------------------------------------------------------------- /tests/integration/local/test_lfs_system.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import tests.integration.util as util # noqa: I202, I100 16 | 17 | from yogadl import dataref, storage, tensorflow 18 | 19 | 20 | def test_mnist_single_threaded() -> None: 21 | config = storage.LFSConfigurations(storage_dir_path="/tmp/") 22 | lfs_storage = storage.LFSStorage(configurations=config) 23 | 24 | dataset_id = "mnist" 25 | dataset_version = "1" 26 | 27 | util.cleanup_lfs_storage( 28 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 29 | ) 30 | 31 | @lfs_storage.cacheable(dataset_id=dataset_id, dataset_version=dataset_version) 32 | def make_dataset() -> dataref.LMDBDataRef: 33 | return util.make_mnist_test_dataset() # type: ignore 34 | 35 | stream_from_cache = make_dataset().stream() 36 | dataset_from_stream = tensorflow.make_tf_dataset(stream_from_cache) 37 | original_dataset = util.make_mnist_test_dataset() 38 | 39 | data_samples = util.compare_datasets(original_dataset, dataset_from_stream) 40 | assert data_samples == 10000 41 | assert stream_from_cache.length == data_samples 42 | util.cleanup_lfs_storage( 43 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 44 | ) 45 | -------------------------------------------------------------------------------- /tests/integration/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import asyncio 16 | import threading 17 | from typing import Optional 18 | 19 | import boto3 20 | import google.cloud.storage as google_storage 21 | import numpy as np 22 | import tensorflow as tf 23 | import tensorflow_datasets as tfds 24 | 25 | from yogadl import rw_coordinator, storage 26 | 27 | 28 | def make_mnist_test_dataset() -> tf.data.Dataset: 29 | mnist_builder = tfds.builder("mnist") 30 | mnist_builder.download_and_prepare() 31 | # We use test because for tfds version < 1.3 the 32 | # train split is automatically shuffled, breaking 33 | # the test. 
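    # (A hypothetical alternative for newer tensorflow_datasets releases would
    #  be to request unshuffled file reads explicitly, e.g.
    #  `mnist_builder.as_dataset(split="train", shuffle_files=False)`; that
    #  keyword is assumed to be unavailable under the `tensorflow_datasets<2`
    #  pin used here, so the test split is read instead.)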
34 | mnist_test = mnist_builder.as_dataset(split="test") 35 | return mnist_test 36 | 37 | 38 | def cleanup_lfs_storage( 39 | configurations: storage.LFSConfigurations, dataset_id: str, dataset_version: str 40 | ) -> None: 41 | cache_filepath = ( 42 | configurations.storage_dir_path.joinpath(dataset_id) 43 | .joinpath(dataset_version) 44 | .joinpath("cache.mdb") 45 | ) 46 | if cache_filepath.exists(): 47 | cache_filepath.unlink() 48 | 49 | 50 | def cleanup_gcs_storage( 51 | configurations: storage.GCSConfigurations, dataset_id: str, dataset_version: str 52 | ) -> None: 53 | gcs_cache_filepath = ( 54 | configurations.bucket_directory_path.joinpath(dataset_id) 55 | .joinpath(dataset_version) 56 | .joinpath("cache.mdb") 57 | ) 58 | 59 | client = google_storage.Client() 60 | bucket = client.bucket(configurations.bucket) 61 | blob = bucket.blob(str(gcs_cache_filepath)) 62 | if blob.exists(): 63 | blob.delete() 64 | 65 | 66 | def cleanup_s3_storage( 67 | configurations: storage.S3Configurations, dataset_id: str, dataset_version: str 68 | ) -> None: 69 | s3_cache_filepath = ( 70 | configurations.bucket_directory_path.joinpath(dataset_id) 71 | .joinpath(dataset_version) 72 | .joinpath("cache.mdb") 73 | ) 74 | 75 | client = boto3.client("s3") 76 | client.delete_object(Bucket=configurations.bucket, Key=str(s3_cache_filepath)) 77 | 78 | 79 | class AccessServerHandler: 80 | def __init__(self, hostname: str, port: int) -> None: 81 | self._access_server = rw_coordinator.RwCoordinatorServer(hostname=hostname, port=port) 82 | 83 | self._thread_running_server = None # type: Optional[threading.Thread] 84 | 85 | def run_server_in_thread(self) -> None: 86 | asyncio.get_event_loop().run_until_complete(self._access_server.run_server()) 87 | self._thread_running_server = threading.Thread(target=asyncio.get_event_loop().run_forever) 88 | self._thread_running_server.start() 89 | 90 | def stop_server(self) -> None: 91 | self._access_server.stop_server() 92 | 93 | assert self._thread_running_server 94 | self._thread_running_server.join() 95 | 96 | 97 | def compare_datasets_graph_mode( 98 | original_dataset: tf.data.Dataset, dataset_from_stream: tf.data.Dataset 99 | ) -> int: 100 | next_element_from_stream = dataset_from_stream.make_one_shot_iterator().get_next() 101 | next_element_from_orig = original_dataset.make_one_shot_iterator().get_next() 102 | data_samples = 0 103 | 104 | with tf.Session() as sess: 105 | while True: 106 | try: 107 | element_from_stream = sess.run(next_element_from_stream) 108 | element_from_dataset = sess.run(next_element_from_orig) 109 | assert element_from_stream["label"] == element_from_dataset["label"] 110 | assert np.array_equal(element_from_stream["image"], element_from_dataset["image"]) 111 | data_samples += 1 112 | except tf.errors.OutOfRangeError: 113 | break 114 | 115 | return data_samples 116 | 117 | 118 | def compare_datasets_eager_mode( 119 | original_dataset: tf.data.Dataset, dataset_from_stream: tf.data.Dataset 120 | ) -> int: 121 | next_element_from_stream = dataset_from_stream.as_numpy_iterator() 122 | next_element_from_orig = original_dataset.as_numpy_iterator() 123 | data_samples = 0 124 | 125 | for orig_dict, from_stream_dict in zip(next_element_from_orig, next_element_from_stream): 126 | for orig_data, from_stream_data in zip(orig_dict, from_stream_dict): 127 | assert np.array_equal(orig_data, from_stream_data) 128 | data_samples += 1 129 | 130 | return data_samples 131 | 132 | 133 | def compare_datasets( 134 | original_dataset: tf.data.Dataset, dataset_from_stream: 
tf.data.Dataset 135 | ) -> int: 136 | if tf.executing_eagerly(): 137 | return compare_datasets_eager_mode(original_dataset, dataset_from_stream) 138 | else: 139 | return compare_datasets_graph_mode(original_dataset, dataset_from_stream) 140 | -------------------------------------------------------------------------------- /tests/performance/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/performance/__init__.py -------------------------------------------------------------------------------- /tests/performance/imagenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/performance/imagenet/__init__.py -------------------------------------------------------------------------------- /tests/performance/imagenet/resnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 17 | # 18 | # Licensed under the Apache License, Version 2.0 (the "License"); 19 | # you may not use this file except in compliance with the License. 20 | # You may obtain a copy of the License at 21 | # 22 | # http://www.apache.org/licenses/LICENSE-2.0 23 | # 24 | # Unless required by applicable law or agreed to in writing, software 25 | # distributed under the License is distributed on an "AS IS" BASIS, 26 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 27 | # See the License for the specific language governing permissions and 28 | # limitations under the License. 
29 | # ============================================================================== 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow.compat.v1 as tf 35 | 36 | IMAGE_SIZE = 224 37 | CROP_PADDING = 32 38 | 39 | 40 | def distorted_bounding_box_crop( 41 | image_bytes, 42 | bbox, 43 | min_object_covered=0.1, 44 | aspect_ratio_range=(0.75, 1.33), 45 | area_range=(0.05, 1.0), 46 | max_attempts=100, 47 | scope=None, 48 | ): 49 | with tf.name_scope(scope, "distorted_bounding_box_crop", [image_bytes, bbox]): 50 | shape = tf.image.extract_jpeg_shape(image_bytes) 51 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 52 | shape, 53 | bounding_boxes=bbox, 54 | min_object_covered=min_object_covered, 55 | aspect_ratio_range=aspect_ratio_range, 56 | area_range=area_range, 57 | max_attempts=max_attempts, 58 | use_image_if_no_bounding_boxes=True, 59 | ) 60 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 61 | 62 | # Crop the image to the specified bounding box. 63 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 64 | target_height, target_width, _ = tf.unstack(bbox_size) 65 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 66 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 67 | 68 | return image 69 | 70 | 71 | def _at_least_x_are_equal(a, b, x): 72 | match = tf.equal(a, b) 73 | match = tf.cast(match, tf.int32) 74 | return tf.greater_equal(tf.reduce_sum(match), x) 75 | 76 | 77 | def _decode_and_random_crop(image_bytes): 78 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 79 | image = distorted_bounding_box_crop( 80 | image_bytes, 81 | bbox, 82 | min_object_covered=0.1, 83 | aspect_ratio_range=(3.0 / 4, 4.0 / 3.0), 84 | area_range=(0.08, 1.0), 85 | max_attempts=10, 86 | scope=None, 87 | ) 88 | original_shape = tf.image.extract_jpeg_shape(image_bytes) 89 | bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) 90 | 91 | image = tf.cond( 92 | bad, 93 | lambda: _decode_and_center_crop(image_bytes), 94 | lambda: tf.image.resize_bicubic( 95 | [image], [IMAGE_SIZE, IMAGE_SIZE] # pylint: disable=g-long-lambda 96 | )[0], 97 | ) 98 | 99 | return image 100 | 101 | 102 | def _decode_and_center_crop(image_bytes): 103 | shape = tf.image.extract_jpeg_shape(image_bytes) 104 | image_height = shape[0] 105 | image_width = shape[1] 106 | 107 | padded_center_crop_size = tf.cast( 108 | ( 109 | (IMAGE_SIZE / (IMAGE_SIZE + CROP_PADDING)) 110 | * tf.cast(tf.minimum(image_height, image_width), tf.float32) 111 | ), 112 | tf.int32, 113 | ) 114 | 115 | offset_height = ((image_height - padded_center_crop_size) + 1) // 2 116 | offset_width = ((image_width - padded_center_crop_size) + 1) // 2 117 | crop_window = tf.stack( 118 | [offset_height, offset_width, padded_center_crop_size, padded_center_crop_size] 119 | ) 120 | image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) 121 | image = tf.image.resize_bicubic([image], [IMAGE_SIZE, IMAGE_SIZE])[0] 122 | 123 | return image 124 | 125 | 126 | def _flip(image): 127 | image = tf.image.random_flip_left_right(image) 128 | return image 129 | 130 | 131 | def preprocess_for_train(image_bytes, use_bfloat16): 132 | image = _decode_and_random_crop(image_bytes) 133 | image = _flip(image) 134 | image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3]) 135 | image = tf.image.convert_image_dtype(image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) 136 | 
return image 137 | 138 | 139 | def preprocess_for_eval(image_bytes, use_bfloat16): 140 | image = _decode_and_center_crop(image_bytes) 141 | image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, 3]) 142 | image = tf.image.convert_image_dtype(image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32) 143 | return image 144 | 145 | 146 | def preprocess_image(image_bytes, is_training=False, use_bfloat16=False): 147 | if is_training: 148 | return preprocess_for_train(image_bytes, use_bfloat16) 149 | else: 150 | return preprocess_for_eval(image_bytes, use_bfloat16) 151 | -------------------------------------------------------------------------------- /tests/performance/imagenet/test_imagenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # `dataset_parser()` and `make_dataset_from_tf_records()` are derived from: 17 | # 18 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 19 | # 20 | # Licensed under the Apache License, Version 2.0 (the "License"); 21 | # you may not use this file except in compliance with the License. 22 | # You may obtain a copy of the License at 23 | # 24 | # http://www.apache.org/licenses/LICENSE-2.0 25 | # 26 | # Unless required by applicable law or agreed to in writing, software 27 | # distributed under the License is distributed on an "AS IS" BASIS, 28 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 29 | # See the License for the specific language governing permissions and 30 | # limitations under the License. 31 | # ============================================================================== 32 | import os 33 | import pathlib 34 | import time 35 | from typing import Any, Tuple 36 | 37 | import tensorflow as tf 38 | 39 | from tests.performance.imagenet import resnet_preprocessing 40 | 41 | from yogadl import dataref, storage, tensorflow 42 | 43 | 44 | def cleanup_lfs_storage( 45 | configurations: storage.LFSConfigurations, dataset_id: str, dataset_version: str 46 | ) -> None: 47 | cache_filepath = ( 48 | configurations.storage_dir_path.joinpath(dataset_id) 49 | .joinpath(dataset_version) 50 | .joinpath("cache.mdb") 51 | ) 52 | if cache_filepath.exists(): 53 | cache_filepath.unlink() 54 | 55 | 56 | def dataset_parser(value: Any) -> Any: 57 | """ 58 | Based on [1]. 
59 | 60 | [1] - https://github.com/tensorflow/tpu/blob/master/models/ 61 | experimental/resnet50_keras/imagenet_input.py 62 | """ 63 | keys_to_features = { 64 | "image/encoded": tf.FixedLenFeature((), tf.string, ""), 65 | "image/format": tf.FixedLenFeature((), tf.string, "jpeg"), 66 | "image/class/label": tf.FixedLenFeature([], tf.int64, -1), 67 | "image/class/text": tf.FixedLenFeature([], tf.string, ""), 68 | "image/object/bbox/xmin": tf.VarLenFeature(dtype=tf.float32), 69 | "image/object/bbox/ymin": tf.VarLenFeature(dtype=tf.float32), 70 | "image/object/bbox/xmax": tf.VarLenFeature(dtype=tf.float32), 71 | "image/object/bbox/ymax": tf.VarLenFeature(dtype=tf.float32), 72 | "image/object/class/label": tf.VarLenFeature(dtype=tf.int64), 73 | } 74 | 75 | parsed = tf.parse_single_example(value, keys_to_features) 76 | image_bytes = tf.reshape(parsed["image/encoded"], shape=[]) 77 | 78 | image = resnet_preprocessing.preprocess_image( # type: ignore 79 | image_bytes=image_bytes, is_training=True, use_bfloat16=False 80 | ) 81 | 82 | # Subtract one so that labels are in [0, 1000), and cast to float32 for 83 | # Keras model. 84 | label = tf.cast( 85 | tf.cast(tf.reshape(parsed["image/class/label"], shape=[1]), dtype=tf.int32) - 1, 86 | dtype=tf.float32, 87 | ) 88 | 89 | return image, label 90 | 91 | 92 | def make_dataset_from_tf_records(data_dir: pathlib.Path, training: bool) -> tf.data.Dataset: 93 | """ 94 | Based on [1]. 95 | 96 | [1] - https://github.com/tensorflow/tpu/blob/master/models/ 97 | experimental/resnet50_keras/imagenet_input.py 98 | """ 99 | # Process 100 out of 1024 record files. 100 | file_pattern = os.path.join( 101 | str(data_dir), "train/train-003*" if training else "validation/validation-*" 102 | ) 103 | dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False) 104 | 105 | def fetch_tf_record_file(filename: str) -> tf.data.TFRecordDataset: 106 | buffer_size = 8 * 1024 * 1024 # 8 MiB per file 107 | tf_record_dataset = tf.data.TFRecordDataset(filename, buffer_size=buffer_size) 108 | return tf_record_dataset 109 | 110 | dataset = dataset.interleave( 111 | fetch_tf_record_file, cycle_length=16, num_parallel_calls=tf.data.experimental.AUTOTUNE 112 | ) 113 | 114 | return dataset 115 | 116 | 117 | def read_dataset(dataset: tf.data.Dataset) -> Tuple[float, int]: 118 | dataset = dataset.apply( 119 | tf.data.experimental.map_and_batch( 120 | dataset_parser, batch_size=1, num_parallel_batches=2, drop_remainder=True 121 | ) 122 | ) 123 | dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) 124 | next_element_from_dataset = dataset.make_one_shot_iterator().get_next() 125 | 126 | with tf.Session() as sess: 127 | data_samples = 0 128 | dataset_read_start_time = time.time() 129 | 130 | while True: 131 | try: 132 | sess.run(next_element_from_dataset) 133 | data_samples += 1 134 | except tf.errors.OutOfRangeError: 135 | break 136 | 137 | dataset_read_time = time.time() - dataset_read_start_time 138 | 139 | return dataset_read_time, data_samples 140 | 141 | 142 | def compare_performance_tf_record_dataset(data_dir: pathlib.Path) -> None: 143 | config = storage.LFSConfigurations(storage_dir_path="/tmp/") 144 | lfs_storage = storage.LFSStorage(configurations=config) 145 | 146 | dataset_id = "imagenet-train" 147 | dataset_version = "0" 148 | training = True 149 | 150 | cleanup_lfs_storage( 151 | configurations=config, dataset_id=dataset_id, dataset_version=dataset_version 152 | ) 153 | 154 | @lfs_storage.cacheable(dataset_id=dataset_id, dataset_version=dataset_version) 155 | def 
make_dataset() -> dataref.LMDBDataRef: 156 | return make_dataset_from_tf_records(data_dir=data_dir, training=training) # type: ignore 157 | 158 | cache_creation_start_time = time.time() 159 | stream_from_cache = make_dataset().stream() 160 | cache_creation_time = time.time() - cache_creation_start_time 161 | print(f"Cache creation took: {cache_creation_time} seconds.") 162 | 163 | dataset_from_stream = tensorflow.make_tf_dataset(stream_from_cache) 164 | cache_read_time, cache_data_items = read_dataset(dataset=dataset_from_stream) 165 | print(f"Cache read took: {cache_read_time} seconds.") 166 | 167 | original_dataset_read_time, original_data_items = read_dataset( 168 | dataset=make_dataset_from_tf_records(data_dir=data_dir, training=training) 169 | ) 170 | print(f"Original read took: {original_dataset_read_time} seconds.") 171 | 172 | assert cache_data_items == original_data_items 173 | 174 | 175 | def test_lfs_imagenet() -> None: 176 | # This test requires that the imagenet dataset be present in TFRecords format. 177 | error_message = ( 178 | "Please set `IMAGENET_DIRECTORY` environment variable to " 179 | "be the directory containing the TFRecords." 180 | ) 181 | 182 | imagenet_directory_path = os.environ.get("IMAGENET_DIRECTORY") 183 | assert imagenet_directory_path, error_message 184 | imagenet_directory = pathlib.Path(imagenet_directory_path) 185 | assert imagenet_directory.is_dir(), error_message 186 | 187 | compare_performance_tf_record_dataset(data_dir=imagenet_directory) 188 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/unit/__init__.py -------------------------------------------------------------------------------- /tests/unit/aws/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/unit/aws/__init__.py -------------------------------------------------------------------------------- /tests/unit/aws/test_s3_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import json 16 | import pathlib 17 | 18 | import boto3 19 | import botocore.client as boto_client 20 | import tensorflow as tf 21 | from tl.testing import thread 22 | 23 | import tests.unit.util as test_util 24 | 25 | from yogadl import dataref, storage 26 | 27 | 28 | def create_s3_configuration(access_server_port: int) -> storage.S3Configurations: 29 | return storage.S3Configurations( 30 | bucket="yogadl-test", 31 | bucket_directory_path="unit-tests", 32 | url=f"ws://localhost:{access_server_port}", 33 | local_cache_dir="/tmp/", 34 | ) 35 | 36 | 37 | def get_local_cache_filepath( 38 | configurations: storage.S3Configurations, dataset_id: str, dataset_version: str 39 | ) -> pathlib.Path: 40 | return ( 41 | configurations.local_cache_dir.joinpath("yogadl_local_cache") 42 | .joinpath(dataset_id) 43 | .joinpath(dataset_version) 44 | .joinpath("cache.mdb") 45 | ) 46 | 47 | 48 | def get_local_metadata_filepath( 49 | configurations: storage.S3Configurations, dataset_id: str, dataset_version: str 50 | ) -> pathlib.Path: 51 | return ( 52 | configurations.local_cache_dir.joinpath("yogadl_local_cache") 53 | .joinpath(dataset_id) 54 | .joinpath(dataset_version) 55 | .joinpath("local_metadata.json") 56 | ) 57 | 58 | 59 | def get_s3_filepath( 60 | configurations: storage.S3Configurations, dataset_id: str, dataset_version: str 61 | ) -> pathlib.Path: 62 | return ( 63 | configurations.bucket_directory_path.joinpath(dataset_id) 64 | .joinpath(dataset_version) 65 | .joinpath("cache.mdb") 66 | ) 67 | 68 | 69 | def test_s3_storage_submit() -> None: 70 | range_size = 10 71 | dataset_id = "range-dataset" 72 | dataset_version = "0" 73 | dataset = tf.data.Dataset.range(range_size) 74 | configurations = create_s3_configuration(access_server_port=15032) 75 | 76 | client = boto3.client("s3") 77 | aws_cache_filepath = get_s3_filepath( 78 | configurations=configurations, 79 | dataset_id=dataset_id, 80 | dataset_version=dataset_version, 81 | ) 82 | 83 | try: 84 | blob_info = client.head_object(Bucket=configurations.bucket, Key=str(aws_cache_filepath)) 85 | previous_creation_time = blob_info.get("LastModified") 86 | except boto_client.ClientError: 87 | previous_creation_time = None 88 | 89 | s3_storage = storage.S3Storage(configurations=configurations) 90 | s3_storage.submit( 91 | data=dataset, 92 | dataset_id=dataset_id, 93 | dataset_version=dataset_version, 94 | ) 95 | 96 | blob_info = client.head_object(Bucket=configurations.bucket, Key=str(aws_cache_filepath)) 97 | assert blob_info.get("LastModified") is not None 98 | assert previous_creation_time != blob_info.get("LastModified") 99 | 100 | if previous_creation_time is not None: 101 | assert previous_creation_time < blob_info.get("LastModified") 102 | 103 | 104 | def test_s3_storage_local_metadata() -> None: 105 | range_size = 10 106 | dataset_id = "range-dataset" 107 | dataset_version = "0" 108 | dataset = tf.data.Dataset.range(range_size) 109 | configurations = create_s3_configuration(access_server_port=15032) 110 | 111 | client = boto3.client("s3") 112 | aws_cache_filepath = get_s3_filepath( 113 | configurations=configurations, 114 | dataset_id=dataset_id, 115 | dataset_version=dataset_version, 116 | ) 117 | 118 | s3_storage = storage.S3Storage(configurations=configurations) 119 | s3_storage.submit( 120 | data=dataset, 121 | dataset_id=dataset_id, 122 | dataset_version=dataset_version, 123 | ) 124 | 125 | local_metadata_filepath = get_local_metadata_filepath( 126 | 
configurations=configurations, dataset_id=dataset_id, dataset_version=dataset_version 127 | ) 128 | with open(str(local_metadata_filepath), "r") as metadata_file: 129 | metadata = json.load(metadata_file) 130 | 131 | blob_info = client.head_object(Bucket=configurations.bucket, Key=str(aws_cache_filepath)) 132 | creation_time = blob_info.get("LastModified") 133 | 134 | assert metadata.get("time_created") 135 | assert creation_time.timestamp() == metadata["time_created"] 136 | 137 | local_metadata_filepath.unlink() 138 | _ = s3_storage.fetch(dataset_id=dataset_id, dataset_version=dataset_version) 139 | with open(str(local_metadata_filepath), "r") as metadata_file: 140 | metadata = json.load(metadata_file) 141 | 142 | assert metadata.get("time_created") 143 | assert creation_time.timestamp() == metadata["time_created"] 144 | 145 | 146 | def test_s3_storage_submit_and_fetch() -> None: 147 | range_size = 20 148 | dataset_id = "range-dataset" 149 | dataset_version = "0" 150 | dataset = tf.data.Dataset.range(range_size) 151 | configurations = create_s3_configuration(access_server_port=15032) 152 | 153 | s3_storage = storage.S3Storage(configurations=configurations) 154 | s3_storage.submit( 155 | data=dataset, 156 | dataset_id=dataset_id, 157 | dataset_version=dataset_version, 158 | ) 159 | dataref = s3_storage.fetch(dataset_id=dataset_id, dataset_version=dataset_version) 160 | stream = dataref.stream() 161 | 162 | assert stream.length == range_size 163 | data_generator = stream.iterator_fn() 164 | generator_length = 0 165 | for idx, data in enumerate(data_generator): 166 | assert idx == data 167 | generator_length += 1 168 | assert generator_length == range_size 169 | 170 | 171 | def test_s3_storage_cacheable_single_threaded() -> None: 172 | original_range_size = 120 173 | updated_range_size = 55 174 | dataset_id = "range-dataset" 175 | dataset_version = "0" 176 | configurations = create_s3_configuration(access_server_port=15032) 177 | 178 | access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032) 179 | access_server_handler.run_server_in_thread() 180 | 181 | s3_cache_filepath = get_s3_filepath( 182 | configurations=configurations, 183 | dataset_id=dataset_id, 184 | dataset_version=dataset_version, 185 | ) 186 | client = boto3.client("s3") 187 | client.delete_object(Bucket=configurations.bucket, Key=str(s3_cache_filepath)) 188 | 189 | s3_storage = storage.S3Storage(configurations=configurations) 190 | 191 | @s3_storage.cacheable(dataset_id, dataset_version) 192 | def make_dataref(range_size: int) -> dataref.LMDBDataRef: 193 | return tf.data.Dataset.range(range_size) # type: ignore 194 | 195 | original_data_stream = make_dataref(range_size=original_range_size).stream() 196 | assert original_data_stream.length == original_range_size 197 | data_generator = original_data_stream.iterator_fn() 198 | generator_length = 0 199 | for idx, data in enumerate(data_generator): 200 | assert idx == data 201 | generator_length += 1 202 | assert generator_length == original_range_size 203 | 204 | updated_data_stream = make_dataref(range_size=updated_range_size).stream() 205 | assert updated_data_stream.length == original_range_size 206 | 207 | access_server_handler.stop_server() 208 | 209 | 210 | def worker(configurations: storage.S3Configurations, dataset_id: str, dataset_version: str) -> None: 211 | range_size = 120 212 | s3_storage = storage.S3Storage(configurations=configurations) 213 | 214 | @s3_storage.cacheable(dataset_id, dataset_version) 215 | def 
make_dataref(input_range_size: int) -> dataref.LMDBDataRef: 216 | return tf.data.Dataset.range(input_range_size) # type: ignore 217 | 218 | stream = make_dataref(input_range_size=range_size).stream() 219 | assert stream.length == range_size 220 | 221 | data_generator = stream.iterator_fn() 222 | generator_length = 0 223 | for idx, data in enumerate(data_generator): 224 | assert idx == data 225 | generator_length += 1 226 | assert generator_length == range_size 227 | 228 | 229 | class MultiThreadedTests(thread.ThreadAwareTestCase): # type: ignore 230 | def test_gcs_storage_cacheable_multi_threaded(self) -> None: 231 | dataset_id = "range-dataset" 232 | dataset_version = "0" 233 | num_threads = 20 234 | configurations = create_s3_configuration(access_server_port=15032) 235 | 236 | access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032) 237 | access_server_handler.run_server_in_thread() 238 | 239 | s3_cache_filepath = get_s3_filepath( 240 | configurations=configurations, 241 | dataset_id=dataset_id, 242 | dataset_version=dataset_version, 243 | ) 244 | client = boto3.client("s3") 245 | client.delete_object(Bucket=configurations.bucket, Key=str(s3_cache_filepath)) 246 | 247 | try: 248 | with thread.ThreadJoiner(10): 249 | for _ in range(num_threads): 250 | self.run_in_thread(lambda: worker(configurations, dataset_id, dataset_version)) 251 | finally: 252 | access_server_handler.stop_server() 253 | -------------------------------------------------------------------------------- /tests/unit/gcp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/unit/gcp/__init__.py -------------------------------------------------------------------------------- /tests/unit/gcp/test_gcs_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import json 16 | import pathlib 17 | 18 | import google.cloud.storage as google_storage 19 | import tensorflow as tf 20 | from tl.testing import thread 21 | 22 | import tests.unit.util as test_util 23 | 24 | from yogadl import dataref, storage 25 | 26 | 27 | def create_gcs_configuration(access_server_port: int) -> storage.GCSConfigurations: 28 | return storage.GCSConfigurations( 29 | bucket="yogadl-test", 30 | bucket_directory_path="unit-tests", 31 | url=f"ws://localhost:{access_server_port}", 32 | local_cache_dir="/tmp/", 33 | ) 34 | 35 | 36 | def get_local_cache_filepath( 37 | configurations: storage.GCSConfigurations, dataset_id: str, dataset_version: str 38 | ) -> pathlib.Path: 39 | return ( 40 | configurations.local_cache_dir.joinpath("yogadl_local_cache") 41 | .joinpath(dataset_id) 42 | .joinpath(dataset_version) 43 | .joinpath("cache.mdb") 44 | ) 45 | 46 | 47 | def get_local_metadata_filepath( 48 | configurations: storage.GCSConfigurations, dataset_id: str, dataset_version: str 49 | ) -> pathlib.Path: 50 | return ( 51 | configurations.local_cache_dir.joinpath("yogadl_local_cache") 52 | .joinpath(dataset_id) 53 | .joinpath(dataset_version) 54 | .joinpath("local_metadata.json") 55 | ) 56 | 57 | 58 | def get_gcs_filepath( 59 | configurations: storage.GCSConfigurations, dataset_id: str, dataset_version: str 60 | ) -> pathlib.Path: 61 | return ( 62 | configurations.bucket_directory_path.joinpath(dataset_id) 63 | .joinpath(dataset_version) 64 | .joinpath("cache.mdb") 65 | ) 66 | 67 | 68 | def test_gcs_storage_submit() -> None: 69 | range_size = 10 70 | dataset_id = "range-dataset" 71 | dataset_version = "0" 72 | dataset = tf.data.Dataset.range(range_size) 73 | configurations = create_gcs_configuration(access_server_port=15032) 74 | 75 | client = google_storage.Client() 76 | bucket = client.bucket(configurations.bucket) 77 | gcs_cache_filepath = get_gcs_filepath( 78 | configurations=configurations, 79 | dataset_id=dataset_id, 80 | dataset_version=dataset_version, 81 | ) 82 | blob = bucket.blob(str(gcs_cache_filepath)) 83 | 84 | previous_creation_time = None 85 | if blob.exists(): 86 | blob.reload() 87 | previous_creation_time = blob.time_created 88 | 89 | gcs_storage = storage.GCSStorage(configurations=configurations) 90 | gcs_storage.submit( 91 | data=dataset, 92 | dataset_id=dataset_id, 93 | dataset_version=dataset_version, 94 | ) 95 | 96 | blob = bucket.blob(str(gcs_cache_filepath)) 97 | blob.reload() 98 | assert blob.exists() 99 | assert blob.time_created is not None 100 | assert previous_creation_time != blob.time_created 101 | 102 | if previous_creation_time is not None: 103 | assert previous_creation_time < blob.time_created 104 | 105 | 106 | def test_gcs_storage_local_metadata() -> None: 107 | range_size = 10 108 | dataset_id = "range-dataset" 109 | dataset_version = "0" 110 | dataset = tf.data.Dataset.range(range_size) 111 | configurations = create_gcs_configuration(access_server_port=15032) 112 | 113 | client = google_storage.Client() 114 | bucket = client.bucket(configurations.bucket) 115 | gcs_cache_filepath = get_gcs_filepath( 116 | configurations=configurations, 117 | dataset_id=dataset_id, 118 | dataset_version=dataset_version, 119 | ) 120 | 121 | gcs_storage = storage.GCSStorage(configurations=configurations) 122 | gcs_storage.submit( 123 | data=dataset, 124 | dataset_id=dataset_id, 125 | dataset_version=dataset_version, 126 | ) 127 | 128 | local_metadata_filepath = 
get_local_metadata_filepath( 129 | configurations=configurations, dataset_id=dataset_id, dataset_version=dataset_version 130 | ) 131 | with open(str(local_metadata_filepath), "r") as metadata_file: 132 | metadata = json.load(metadata_file) 133 | 134 | blob = bucket.blob(str(gcs_cache_filepath)) 135 | blob.reload() 136 | 137 | assert metadata.get("time_created") 138 | assert blob.time_created.timestamp() == metadata["time_created"] 139 | 140 | local_metadata_filepath.unlink() 141 | _ = gcs_storage.fetch(dataset_id=dataset_id, dataset_version=dataset_version) 142 | with open(str(local_metadata_filepath), "r") as metadata_file: 143 | metadata = json.load(metadata_file) 144 | 145 | assert metadata.get("time_created") 146 | assert blob.time_created.timestamp() == metadata["time_created"] 147 | 148 | 149 | def test_gcs_storage_submit_and_fetch() -> None: 150 | range_size = 20 151 | dataset_id = "range-dataset" 152 | dataset_version = "0" 153 | dataset = tf.data.Dataset.range(range_size) 154 | configurations = create_gcs_configuration(access_server_port=15032) 155 | 156 | gcs_storage = storage.GCSStorage(configurations=configurations) 157 | gcs_storage.submit( 158 | data=dataset, 159 | dataset_id=dataset_id, 160 | dataset_version=dataset_version, 161 | ) 162 | dataref = gcs_storage.fetch(dataset_id=dataset_id, dataset_version=dataset_version) 163 | stream = dataref.stream() 164 | 165 | assert stream.length == range_size 166 | data_generator = stream.iterator_fn() 167 | generator_length = 0 168 | for idx, data in enumerate(data_generator): 169 | assert idx == data 170 | generator_length += 1 171 | assert generator_length == range_size 172 | 173 | 174 | def test_gcs_storage_cacheable_single_threaded() -> None: 175 | original_range_size = 120 176 | updated_range_size = 55 177 | dataset_id = "range-dataset" 178 | dataset_version = "0" 179 | configurations = create_gcs_configuration(access_server_port=15032) 180 | 181 | access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032) 182 | access_server_handler.run_server_in_thread() 183 | 184 | gcs_cache_filepath = get_gcs_filepath( 185 | configurations=configurations, 186 | dataset_id=dataset_id, 187 | dataset_version=dataset_version, 188 | ) 189 | client = google_storage.Client() 190 | bucket = client.bucket(configurations.bucket) 191 | blob = bucket.blob(str(gcs_cache_filepath)) 192 | if blob.exists(): 193 | blob.delete() 194 | 195 | gcs_storage = storage.GCSStorage(configurations=configurations) 196 | 197 | @gcs_storage.cacheable(dataset_id, dataset_version) 198 | def make_dataref(range_size: int) -> dataref.LMDBDataRef: 199 | return tf.data.Dataset.range(range_size) # type: ignore 200 | 201 | original_data_stream = make_dataref(range_size=original_range_size).stream() 202 | assert original_data_stream.length == original_range_size 203 | data_generator = original_data_stream.iterator_fn() 204 | generator_length = 0 205 | for idx, data in enumerate(data_generator): 206 | assert idx == data 207 | generator_length += 1 208 | assert generator_length == original_range_size 209 | 210 | updated_data_stream = make_dataref(range_size=updated_range_size).stream() 211 | assert updated_data_stream.length == original_range_size 212 | 213 | access_server_handler.stop_server() 214 | 215 | 216 | def worker( 217 | configurations: storage.GCSConfigurations, dataset_id: str, dataset_version: str 218 | ) -> None: 219 | range_size = 120 220 | gcs_storage = storage.GCSStorage(configurations=configurations) 221 | 222 | 
@gcs_storage.cacheable(dataset_id, dataset_version) 223 | def make_dataref(input_range_size: int) -> dataref.LMDBDataRef: 224 | return tf.data.Dataset.range(input_range_size) # type: ignore 225 | 226 | stream = make_dataref(input_range_size=range_size).stream() 227 | assert stream.length == range_size 228 | 229 | data_generator = stream.iterator_fn() 230 | generator_length = 0 231 | for idx, data in enumerate(data_generator): 232 | assert idx == data 233 | generator_length += 1 234 | assert generator_length == range_size 235 | 236 | 237 | class MultiThreadedTests(thread.ThreadAwareTestCase): # type: ignore 238 | def test_gcs_storage_cacheable_multi_threaded(self) -> None: 239 | dataset_id = "range-dataset" 240 | dataset_version = "0" 241 | num_threads = 20 242 | configurations = create_gcs_configuration(access_server_port=15032) 243 | 244 | access_server_handler = test_util.AccessServerHandler(hostname="localhost", port=15032) 245 | access_server_handler.run_server_in_thread() 246 | 247 | gcs_cache_filepath = get_gcs_filepath( 248 | configurations=configurations, 249 | dataset_id=dataset_id, 250 | dataset_version=dataset_version, 251 | ) 252 | client = google_storage.Client() 253 | bucket = client.bucket(configurations.bucket) 254 | blob = bucket.blob(str(gcs_cache_filepath)) 255 | if blob.exists(): 256 | blob.delete() 257 | 258 | try: 259 | with thread.ThreadJoiner(10): 260 | for _ in range(num_threads): 261 | self.run_in_thread(lambda: worker(configurations, dataset_id, dataset_version)) 262 | finally: 263 | access_server_handler.stop_server() 264 | -------------------------------------------------------------------------------- /tests/unit/local/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/determined-ai/yogadl/7f4233dd76d53664b913558f742728203ee9406a/tests/unit/local/__init__.py -------------------------------------------------------------------------------- /tests/unit/local/test_lfs_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import pathlib 16 | 17 | import tensorflow as tf 18 | 19 | from yogadl import dataref, storage 20 | 21 | 22 | def create_configurations() -> storage.LFSConfigurations: 23 | return storage.LFSConfigurations(storage_dir_path="/tmp/") 24 | 25 | 26 | def get_cache_filepath( 27 | configurations: storage.LFSConfigurations, dataset_id: str, dataset_version: str 28 | ) -> pathlib.Path: 29 | return ( 30 | configurations.storage_dir_path.joinpath(dataset_id) 31 | .joinpath(dataset_version) 32 | .joinpath("cache.mdb") 33 | ) 34 | 35 | 36 | def test_storage_submit() -> None: 37 | range_size = 10 38 | dataset_id = "range-dataset" 39 | dataset_version = "0" 40 | dataset = tf.data.Dataset.range(range_size) 41 | configurations = create_configurations() 42 | if get_cache_filepath(configurations, dataset_id, dataset_version).exists(): 43 | get_cache_filepath(configurations, dataset_id, dataset_version).unlink() 44 | 45 | lfs_storage = storage.LFSStorage(configurations=configurations) 46 | lfs_storage.submit(data=dataset, dataset_id=dataset_id, dataset_version=dataset_version) 47 | 48 | assert get_cache_filepath(configurations, dataset_id, dataset_version).is_file() 49 | 50 | 51 | def test_storage_cacheable_single_threaded() -> None: 52 | original_range_size = 120 53 | updated_range_size = 126 54 | dataset_id = "range-dataset" 55 | dataset_version = "1" 56 | configurations = create_configurations() 57 | if get_cache_filepath(configurations, dataset_id, dataset_version).exists(): 58 | get_cache_filepath(configurations, dataset_id, dataset_version).unlink() 59 | 60 | lfs_storage = storage.LFSStorage(configurations=configurations) 61 | 62 | @lfs_storage.cacheable(dataset_id, dataset_version) 63 | def make_dataref(range_size: int) -> dataref.LMDBDataRef: 64 | return tf.data.Dataset.range(range_size) # type: ignore 65 | 66 | original_data_stream = make_dataref(range_size=original_range_size).stream() 67 | assert original_data_stream.length == original_range_size 68 | data_generator = original_data_stream.iterator_fn() 69 | for idx in range(original_range_size): 70 | assert idx == next(data_generator) 71 | 72 | updated_data_stream = make_dataref(range_size=updated_range_size).stream() 73 | assert updated_data_stream.length == original_range_size 74 | -------------------------------------------------------------------------------- /tests/unit/local/test_lmdb_access.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import pytest 16 | from typing import List 17 | 18 | import tensorflow as tf 19 | 20 | import tests.unit.util as util # noqa: I202, I100 21 | 22 | import yogadl 23 | 24 | 25 | def shard_and_get_keys( 26 | lmdb_reader: yogadl.LmdbAccess, 27 | shard_index: int, 28 | num_shards: int, 29 | sequential: bool, 30 | drop_shard_remainder: bool, 31 | ) -> List[bytes]: 32 | keys = lmdb_reader.get_keys() 33 | keys = yogadl.shard_keys( 34 | keys=keys, 35 | shard_index=shard_index, 36 | num_shards=num_shards, 37 | sequential=sequential, 38 | drop_shard_remainder=drop_shard_remainder, 39 | ) 40 | return keys 41 | 42 | 43 | def convert_int_to_byte_string(input_int: int) -> bytes: 44 | return u"{:08}".format(input_int).encode("ascii") 45 | 46 | 47 | def test_lmdb_access_keys() -> None: 48 | range_size = 10 49 | lmdb_reader = yogadl.LmdbAccess( 50 | lmdb_path=util.create_lmdb_checkpoint_using_range(range_size=range_size) 51 | ) 52 | keys = lmdb_reader.get_keys() 53 | assert len(keys) == range_size 54 | for idx, key in enumerate(keys): 55 | assert convert_int_to_byte_string(idx) == key 56 | 57 | 58 | @pytest.mark.parametrize("drop_remainder", [True, False]) 59 | def test_lmdb_access_keys_sequential_shard(drop_remainder: bool) -> None: 60 | range_size = 10 61 | num_shards = 3 62 | lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 63 | key_shards = [] 64 | for shard_id in range(num_shards): 65 | lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 66 | key_shards.append( 67 | shard_and_get_keys( 68 | lmdb_reader=lmdb_reader, 69 | shard_index=shard_id, 70 | num_shards=num_shards, 71 | sequential=True, 72 | drop_shard_remainder=drop_remainder, 73 | ) 74 | ) 75 | 76 | merged_keys = [] 77 | for key_shard in key_shards: 78 | merged_keys.extend(key_shard) 79 | 80 | expected_range_size = ( 81 | range_size if not drop_remainder else range_size - (range_size % num_shards) 82 | ) 83 | assert len(merged_keys) == expected_range_size 84 | for idx, key in enumerate(merged_keys): 85 | assert convert_int_to_byte_string(idx) == key 86 | 87 | 88 | @pytest.mark.parametrize("drop_remainder", [True, False]) 89 | def test_lmdb_access_keys_non_sequential_shard(drop_remainder: bool) -> None: 90 | range_size = 10 91 | num_shards = 3 92 | lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 93 | key_shards = [] 94 | for shard_id in range(num_shards): 95 | lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 96 | key_shards.append( 97 | shard_and_get_keys( 98 | lmdb_reader=lmdb_reader, 99 | shard_index=shard_id, 100 | num_shards=num_shards, 101 | sequential=False, 102 | drop_shard_remainder=drop_remainder, 103 | ) 104 | ) 105 | 106 | merged_keys = [] 107 | for idx in range(len(key_shards[0])): 108 | for key_shard in key_shards: 109 | if idx < len(key_shard): 110 | merged_keys.append(key_shard[idx]) 111 | 112 | expected_range_size = ( 113 | range_size if not drop_remainder else range_size - (range_size % num_shards) 114 | ) 115 | assert len(merged_keys) == expected_range_size 116 | for idx, key in enumerate(merged_keys): 117 | assert convert_int_to_byte_string(idx) == key 118 | 119 | 120 | def test_lmdb_access_shuffle() -> None: 121 | range_size = 10 122 | seed_one = 41 123 | seed_two = 421 124 | lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 125 | 126 | lmdb_reader_one = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 
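    # Shuffling with the same seed must reproduce the same key order, while a
    # different seed is expected to produce a different order (asserted below).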
127 | keys_one = lmdb_reader_one.get_keys() 128 | keys_one = yogadl.shuffle_keys(keys=keys_one, seed=seed_one) 129 | 130 | lmdb_reader_two = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 131 | keys_two = lmdb_reader_two.get_keys() 132 | keys_two = yogadl.shuffle_keys(keys=keys_two, seed=seed_one) 133 | 134 | lmdb_reader_three = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 135 | keys_three = lmdb_reader_three.get_keys() 136 | keys_three = yogadl.shuffle_keys(keys=keys_three, seed=seed_two) 137 | 138 | assert keys_one == keys_two 139 | assert keys_one != keys_three 140 | 141 | 142 | def test_lmdb_access_read_values() -> None: 143 | range_size = 10 144 | lmdb_checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 145 | lmdb_reader = yogadl.LmdbAccess(lmdb_path=lmdb_checkpoint_path) 146 | keys = lmdb_reader.get_keys() 147 | 148 | for idx, key in enumerate(keys): 149 | assert lmdb_reader.read_value_by_key(key=key) == idx 150 | 151 | 152 | def test_lmdb_access_shapes_and_types() -> None: 153 | range_size = 10 154 | lmdb_reader = yogadl.LmdbAccess( 155 | lmdb_path=util.create_lmdb_checkpoint_using_range(range_size=range_size) 156 | ) 157 | matching_dataset = tf.data.Dataset.range(range_size) 158 | assert lmdb_reader.get_shapes() == tf.compat.v1.data.get_output_shapes(matching_dataset) 159 | assert lmdb_reader.get_types() == tf.compat.v1.data.get_output_types(matching_dataset) 160 | -------------------------------------------------------------------------------- /tests/unit/local/test_local_lmdb_dataref.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import copy 16 | 17 | import numpy as np 18 | 19 | import tests.unit.util as util 20 | 21 | from yogadl import dataref 22 | 23 | 24 | def test_lfs_dataref_from_checkpoint() -> None: 25 | range_size = 10 26 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 27 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 28 | stream = lfs_dataref.stream() 29 | 30 | for _ in range(3): 31 | idx = 0 32 | data_generator = stream.iterator_fn() 33 | for data in data_generator: 34 | assert data == idx 35 | idx += 1 36 | assert idx == range_size 37 | 38 | 39 | def test_lfs_dataref_with_offset() -> None: 40 | range_size = 10 41 | offset = 5 42 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 43 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 44 | stream = lfs_dataref.stream(start_offset=offset) 45 | 46 | for epoch in range(3): 47 | idx = 5 if epoch == 0 else 0 48 | data_generator = stream.iterator_fn() 49 | for data in data_generator: 50 | assert data == idx 51 | idx += 1 52 | assert idx == range_size 53 | 54 | 55 | def test_lfs_dataref_with_shuffle() -> None: 56 | range_size = 10 57 | seed = 325 58 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 59 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 60 | stream = lfs_dataref.stream(shuffle=True, skip_shuffle_at_epoch_end=True, shuffle_seed=seed) 61 | shuffled_keys = list(range(range_size)) 62 | shuffler = np.random.RandomState(seed) 63 | shuffler.shuffle(shuffled_keys) 64 | 65 | for _ in range(3): 66 | data_generator = stream.iterator_fn() 67 | idx = 0 68 | for data, shuffled_key in zip(data_generator, shuffled_keys): 69 | assert data == shuffled_key 70 | idx += 1 71 | assert idx == range_size 72 | 73 | 74 | def test_lfs_dataref_with_shuffle_after_epoch() -> None: 75 | range_size = 10 76 | seed = 325 77 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 78 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 79 | stream = lfs_dataref.stream(shuffle=True, skip_shuffle_at_epoch_end=False, shuffle_seed=seed) 80 | un_shuffled_keys = list(range(range_size)) 81 | 82 | for epoch in range(3): 83 | shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys) 84 | shuffler = np.random.RandomState(seed + epoch) 85 | shuffler.shuffle(shuffled_keys_for_epoch) 86 | 87 | data_generator = stream.iterator_fn() 88 | idx = 0 89 | for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch): 90 | assert data == shuffled_key 91 | idx += 1 92 | assert idx == range_size 93 | 94 | 95 | def test_lfs_dataref_with_offset_and_shuffle_after_epoch() -> None: 96 | range_size = 10 97 | seed = 325 98 | offset = 15 99 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 100 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 101 | stream = lfs_dataref.stream( 102 | shuffle=True, skip_shuffle_at_epoch_end=False, shuffle_seed=seed, start_offset=offset 103 | ) 104 | un_shuffled_keys = list(range(range_size)) 105 | 106 | for epoch in range(offset // range_size, 5): 107 | shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys) 108 | shuffler = np.random.RandomState(seed + epoch) 109 | shuffler.shuffle(shuffled_keys_for_epoch) 110 | 111 | if offset // range_size == epoch: 112 | shuffled_keys_for_epoch = shuffled_keys_for_epoch[offset % range_size :] 113 | 
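        # The first epoch visited resumes mid-epoch, so only the keys after the
        # offset are expected; later epochs should yield the full reshuffled list.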
114 | data_generator = stream.iterator_fn() 115 | idx = 0 116 | for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch): 117 | assert data == shuffled_key 118 | idx += 1 119 | assert idx == len(shuffled_keys_for_epoch) 120 | 121 | 122 | def test_lfs_dataref_with_shuffle_zero_seed() -> None: 123 | range_size = 10 124 | seed = 0 125 | checkpoint_path = util.create_lmdb_checkpoint_using_range(range_size=range_size) 126 | lfs_dataref = dataref.LMDBDataRef(cache_filepath=checkpoint_path) 127 | stream = lfs_dataref.stream(shuffle=True, skip_shuffle_at_epoch_end=False, shuffle_seed=seed) 128 | un_shuffled_keys = list(range(range_size)) 129 | 130 | for epoch in range(3): 131 | shuffled_keys_for_epoch = copy.deepcopy(un_shuffled_keys) 132 | shuffler = np.random.RandomState(seed + epoch) 133 | shuffler.shuffle(shuffled_keys_for_epoch) 134 | 135 | data_generator = stream.iterator_fn() 136 | idx = 0 137 | for data, shuffled_key in zip(data_generator, shuffled_keys_for_epoch): 138 | assert data == shuffled_key 139 | idx += 1 140 | assert idx == range_size 141 | -------------------------------------------------------------------------------- /tests/unit/local/test_rw_coordinator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import pathlib 16 | import time 17 | import urllib.parse 18 | from typing import List 19 | 20 | import lomond 21 | from tl.testing import thread 22 | 23 | import tests.unit.util as test_util 24 | 25 | import yogadl.constants as constants 26 | from yogadl import rw_coordinator 27 | 28 | 29 | def read_and_sleep( 30 | access_client: rw_coordinator.RwCoordinatorClient, 31 | sleep_time: int, 32 | bucket: str, 33 | cache_path: pathlib.Path, 34 | ) -> None: 35 | with access_client.read_lock( 36 | storage_type=constants.GCS_STORAGE, bucket=bucket, cache_path=cache_path 37 | ): 38 | time.sleep(sleep_time) 39 | 40 | 41 | def write_and_sleep( 42 | shared_data: List[int], 43 | access_client: rw_coordinator.RwCoordinatorClient, 44 | sleep_time: int, 45 | bucket: str, 46 | cache_path: pathlib.Path, 47 | ) -> None: 48 | with access_client.write_lock( 49 | storage_type=constants.GCS_STORAGE, bucket=bucket, cache_path=cache_path 50 | ): 51 | shared_data[0] += 1 52 | time.sleep(sleep_time) 53 | 54 | 55 | def send_and_die(lock_request_url: str) -> None: 56 | with lomond.WebSocket(lock_request_url) as socket: 57 | for event in socket.connect(): 58 | if isinstance(event, lomond.events.Text): 59 | return 60 | 61 | 62 | def read_and_die(bucket: str, cache_path: pathlib.Path, ip_address: str, port: int) -> None: 63 | lock_request_url = ( 64 | f"ws://{ip_address}:{port}/{constants.GCS_STORAGE}/{bucket}/" 65 | f"{str(cache_path)}?{urllib.parse.urlencode({'read_lock': True})}" 66 | ) 67 | 68 | send_and_die(lock_request_url=lock_request_url) 69 | 70 | 71 | def write_and_die(bucket: str, cache_path: pathlib.Path, ip_address: str, port: int) -> None: 72 | lock_request_url = ( 73 | f"ws://{ip_address}:{port}/{constants.GCS_STORAGE}/{bucket}/" 74 | f"{str(cache_path)}?{urllib.parse.urlencode({'read_lock': False})}" 75 | ) 76 | 77 | send_and_die(lock_request_url=lock_request_url) 78 | 79 | 80 | class MultiThreadedTests(thread.ThreadAwareTestCase): # type: ignore 81 | def test_rw_coordinator(self) -> None: 82 | ip_address = "localhost" 83 | port = 10245 84 | bucket = "my_bucket" 85 | cache_path = pathlib.Path("/tmp.mdb") 86 | num_threads = 5 87 | shared_data = [0] 88 | 89 | access_server_handler = test_util.AccessServerHandler(hostname=ip_address, port=port) 90 | access_server_handler.run_server_in_thread() 91 | access_client = rw_coordinator.RwCoordinatorClient(url=f"ws://{ip_address}:{port}") 92 | 93 | try: 94 | with thread.ThreadJoiner(45): 95 | for i in range(num_threads): 96 | self.run_in_thread( 97 | lambda: read_and_sleep( 98 | access_client=access_client, 99 | sleep_time=i + 1, 100 | bucket=bucket, 101 | cache_path=cache_path, 102 | ) 103 | ) 104 | self.run_in_thread( 105 | lambda: write_and_sleep( 106 | shared_data=shared_data, 107 | access_client=access_client, 108 | sleep_time=i, 109 | bucket=bucket, 110 | cache_path=cache_path, 111 | ) 112 | ) 113 | finally: 114 | access_server_handler.stop_server() 115 | 116 | assert shared_data[0] == num_threads 117 | 118 | def test_rw_coordinator_connections_die(self) -> None: 119 | ip_address = "localhost" 120 | port = 10245 121 | bucket = "my_bucket" 122 | cache_path = pathlib.Path("/tmp.mdb") 123 | num_threads = 5 124 | shared_data = [0] 125 | threads_to_die = [2, 3] 126 | 127 | access_server_handler = test_util.AccessServerHandler(hostname=ip_address, port=port) 128 | access_server_handler.run_server_in_thread() 129 | access_client = 
rw_coordinator.RwCoordinatorClient(url=f"ws://{ip_address}:{port}") 130 | 131 | try: 132 | with thread.ThreadJoiner(45): 133 | for i in range(num_threads): 134 | if i in threads_to_die: 135 | self.run_in_thread( 136 | lambda: read_and_die( 137 | bucket=bucket, 138 | cache_path=cache_path, 139 | ip_address=ip_address, 140 | port=port, 141 | ) 142 | ) 143 | self.run_in_thread( 144 | lambda: write_and_die( 145 | bucket=bucket, 146 | cache_path=cache_path, 147 | ip_address=ip_address, 148 | port=port, 149 | ) 150 | ) 151 | else: 152 | self.run_in_thread( 153 | lambda: read_and_sleep( 154 | access_client=access_client, 155 | sleep_time=i + 1, 156 | bucket=bucket, 157 | cache_path=cache_path, 158 | ) 159 | ) 160 | self.run_in_thread( 161 | lambda: write_and_sleep( 162 | shared_data=shared_data, 163 | access_client=access_client, 164 | sleep_time=i, 165 | bucket=bucket, 166 | cache_path=cache_path, 167 | ) 168 | ) 169 | finally: 170 | access_server_handler.stop_server() 171 | 172 | assert shared_data[0] == num_threads - len(threads_to_die) 173 | -------------------------------------------------------------------------------- /tests/unit/local/test_tensorflow_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import os 16 | import pathlib 17 | 18 | import tensorflow as tf 19 | 20 | from yogadl import tensorflow 21 | 22 | 23 | def test_read_tf_dataset() -> None: 24 | range_size = 10 25 | dataset = tf.data.Dataset.range(range_size) 26 | yield_output = list(tensorflow.read_tf_dataset(dataset=dataset, tf_config=None)) 27 | original_dataset = range(range_size) 28 | assert len(original_dataset) == len(yield_output) 29 | for original_data, yielded_data in zip(original_dataset, yield_output): 30 | assert original_data == yielded_data # type: ignore 31 | 32 | 33 | def test_serialize_tf_dataset_to_lmdb_metadata() -> None: 34 | range_size = 10 35 | dataset = tf.data.Dataset.range(range_size) 36 | checkpoint_path = pathlib.Path("/tmp/test_lmdb_checkpoint.mdb") 37 | if checkpoint_path.exists(): 38 | os.unlink(str(checkpoint_path)) 39 | assert not checkpoint_path.exists() 40 | 41 | dataset_entries = tensorflow.serialize_tf_dataset_to_lmdb( 42 | dataset=dataset, checkpoint_path=checkpoint_path, tf_config=None 43 | ) 44 | assert dataset_entries == range_size 45 | assert checkpoint_path.exists() 46 | -------------------------------------------------------------------------------- /tests/unit/util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import asyncio 16 | import pathlib 17 | import threading 18 | from typing import Optional 19 | 20 | import tensorflow as tf 21 | 22 | from yogadl import rw_coordinator, tensorflow 23 | 24 | 25 | def create_lmdb_checkpoint_using_range(range_size: int) -> pathlib.Path: 26 | dataset = tf.data.Dataset.range(range_size) 27 | checkpoint_path = pathlib.Path("/tmp/test_lmdb_checkpoint.mdb") 28 | if checkpoint_path.exists(): 29 | checkpoint_path.unlink() 30 | 31 | tensorflow.serialize_tf_dataset_to_lmdb( 32 | dataset=dataset, checkpoint_path=checkpoint_path, tf_config=None 33 | ) 34 | 35 | return checkpoint_path 36 | 37 | 38 | class AccessServerHandler: 39 | def __init__(self, hostname: str, port: int) -> None: 40 | self._access_server = rw_coordinator.RwCoordinatorServer(hostname=hostname, port=port) 41 | 42 | self._thread_running_server = None # type: Optional[threading.Thread] 43 | 44 | def run_server_in_thread(self) -> None: 45 | asyncio.get_event_loop().run_until_complete(self._access_server.run_server()) 46 | self._thread_running_server = threading.Thread(target=asyncio.get_event_loop().run_forever) 47 | self._thread_running_server.start() 48 | 49 | def stop_server(self) -> None: 50 | self._access_server.stop_server() 51 | 52 | assert self._thread_running_server 53 | self._thread_running_server.join() 54 | -------------------------------------------------------------------------------- /yogadl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from ._core import DataRef, Storage, Stream, Submittable 16 | from ._keys_operator import ( 17 | GeneratorFromKeys, 18 | non_sequential_shard, 19 | sequential_shard, 20 | shard_keys, 21 | shuffle_keys, 22 | ) 23 | from ._lmdb_handler import LmdbAccess, serialize_generator_to_lmdb 24 | -------------------------------------------------------------------------------- /yogadl/_core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """ 16 | The core interfaces of the yoga data layer. 17 | """ 18 | 19 | import abc 20 | from typing import Any, Callable, Optional, Union 21 | 22 | import tensorflow as tf 23 | 24 | 25 | # TODO: Make sure users are not required to have TF, PyTorch, 26 | # and TP dataflows all installed to use this. 27 | Submittable = Union[ 28 | tf.data.Dataset, 29 | ] 30 | 31 | 32 | class Stream: 33 | """ 34 | Stream contains a generator of data and other required information 35 | to feed into framework specific data APIs. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | iterator_fn: Callable, 41 | length: int, 42 | output_types: Any = None, 43 | output_shapes: Any = None, 44 | ): 45 | self.iterator_fn = iterator_fn 46 | self.length = length 47 | self.output_types = output_types 48 | self.output_shapes = output_shapes 49 | 50 | def __iter__(self) -> Any: 51 | """ 52 | Iterate through the records in the stream. 53 | """ 54 | return self.iterator_fn() 55 | 56 | def __len__(self) -> int: 57 | """ 58 | Return the length of the stream, which may differ from the length of the dataset. 59 | """ 60 | return self.length 61 | 62 | 63 | class DataRef(metaclass=abc.ABCMeta): 64 | """ 65 | The base interface for a reference to a dataset in the yogadl framework. 66 | 67 | The DataRef may refer to a dataset in a remote storage location; it need not refer to locally- 68 | available data. The only mechanism for accessing the records inside the dataset is to create a 69 | Stream and to iterate through them. 70 | 71 | By specifying all of the random-access options up front, the backend which provides the DataRef 72 | can provide performance-optimized streaming, since it is guaranteed with yogadl that lower 73 | layers will operate without random access. 74 | """ 75 | 76 | @abc.abstractmethod 77 | def stream( 78 | self, 79 | start_offset: int = 0, 80 | shuffle: bool = False, 81 | skip_shuffle_at_epoch_end: bool = False, 82 | shuffle_seed: Optional[int] = None, 83 | shard_rank: int = 0, 84 | num_shards: int = 1, 85 | drop_shard_remainder: bool = False, 86 | ) -> Stream: 87 | """ 88 | Create a sequentially accessible set of records from the dataset, according to the 89 | random-access arguments given as parameters. 90 | """ 91 | pass 92 | 93 | @abc.abstractmethod 94 | def __len__(self) -> int: 95 | """ 96 | Return the length of the dataset that the DataRef refers to. 97 | """ 98 | pass 99 | 100 | 101 | class Storage(metaclass=abc.ABCMeta): 102 | """ 103 | Storage is a cache for datasets. 104 | 105 | Storage accepts datasets in various forms via submit(), and returns DataRef objects via 106 | fetch(). 107 | 108 | Conceptually, Storage is sort of like a DataRef factory. It stores datasets 109 | in an unspecified format, and returns objects which implement the DataRef 110 | interface. 111 | 112 | Note that submit() and fetch() are not multiprocessing-safe by default. 113 | The @cacheable decorator should be safe to call simultaneously from 114 | many threads, processes, or machines. 
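
    A minimal usage sketch (the dataset id/version and cache directory are
    illustrative; tests/unit/local/test_lfs_storage.py exercises the same
    pattern with LFSStorage and LFSConfigurations from yogadl.storage):

        storage = LFSStorage(configurations=LFSConfigurations(storage_dir_path="/tmp/"))

        @storage.cacheable("range-dataset", "0")
        def make_dataset():
            return tf.data.Dataset.range(10)

        stream = make_dataset().stream()
        for record in stream:
            ...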
115 | """ 116 | 117 | @abc.abstractmethod 118 | def submit(self, data: Submittable, dataset_id: str, dataset_version: str) -> None: 119 | """ 120 | Stores dataset to a cache. 121 | """ 122 | pass 123 | 124 | @abc.abstractmethod 125 | def fetch(self, dataset_id: str, dataset_version: str) -> DataRef: 126 | """ 127 | Fetch a dataset from storage and provide a DataRef for streaming it. 128 | """ 129 | pass 130 | 131 | @abc.abstractmethod 132 | def cacheable(self, dataset_id: str, dataset_version: str) -> Callable: 133 | """ 134 | A decorator that calls submit and fetch and is responsible for coordinating 135 | amongst instances of Storage in different processes. 136 | """ 137 | pass 138 | -------------------------------------------------------------------------------- /yogadl/_keys_operator.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import copy 16 | from typing import Any, Callable, Generator, List, Optional 17 | 18 | import numpy as np 19 | 20 | 21 | def sequential_shard(keys: List[bytes], shard_index: int, num_shards: int) -> List[bytes]: 22 | num_keys = len(keys) // num_shards 23 | if shard_index < len(keys) % num_shards: 24 | num_keys += 1 25 | start_index = num_keys * shard_index + min(len(keys) % num_shards, shard_index) 26 | return keys[start_index : start_index + num_keys] 27 | 28 | 29 | def non_sequential_shard(keys: List[bytes], shard_index: int, num_shards: int) -> List[bytes]: 30 | key_indexes = list(range(shard_index, len(keys), num_shards)) 31 | return [keys[idx] for idx in key_indexes] 32 | 33 | 34 | def shard_keys( 35 | keys: List[bytes], 36 | shard_index: int, 37 | num_shards: int, 38 | sequential: bool = False, 39 | drop_shard_remainder: bool = False, 40 | ) -> List[bytes]: 41 | assert shard_index >= 0, "Shard index must be greater or equal to zero." 42 | assert shard_index < num_shards, "Shard index must be less than num_shards." 43 | 44 | if drop_shard_remainder: 45 | assert len(keys) >= num_shards, f"Too few keys to shard across {num_shards} ranks." 
46 | keys = keys[: len(keys) - (len(keys) % num_shards)] 47 | 48 | if sequential: 49 | return sequential_shard(keys=keys, shard_index=shard_index, num_shards=num_shards) 50 | else: 51 | return non_sequential_shard(keys=keys, shard_index=shard_index, num_shards=num_shards) 52 | 53 | 54 | def shuffle_keys(keys: List[bytes], seed: Optional[int] = None) -> List[bytes]: 55 | shuffler = np.random.RandomState(seed) 56 | shuffler.shuffle(keys) 57 | return keys 58 | 59 | 60 | class GeneratorFromKeys: 61 | def __init__( 62 | self, 63 | keys: List[bytes], 64 | initial_offset: int, 65 | read_val_from_key_fn: Callable, 66 | shuffle_at_start: bool, 67 | shuffle_after_epoch: bool, 68 | shuffle_seed: Optional[int], 69 | ) -> None: 70 | assert initial_offset >= 0 71 | self._keys = keys 72 | self._initial_offset = initial_offset % len(self._keys) 73 | self._current_epoch = initial_offset // len(self._keys) 74 | self._read_val_from_key_fn = read_val_from_key_fn 75 | self._shuffle_enabled = shuffle_at_start 76 | self._shuffle_after_epoch = shuffle_after_epoch 77 | self._shuffle_seed = shuffle_seed 78 | self._initial_epoch = True 79 | 80 | self._validate_args() 81 | 82 | def _validate_args(self) -> None: 83 | if self._shuffle_after_epoch: 84 | assert self._shuffle_enabled, "`shuffle` must be enabled to use `shuffle_after_epoch`." 85 | assert ( 86 | self._shuffle_seed is not None 87 | ), "`shuffle_seed` must be set to use `shuffle_after_epoch`." 88 | 89 | def instantiate_generator(self) -> Generator[Any, None, None]: 90 | keys = self._shuffle_keys() if self._shuffle_enabled else self._keys 91 | self._current_epoch += 1 92 | 93 | key_index = self._initial_offset if self._initial_epoch else 0 94 | self._initial_epoch = False 95 | 96 | while key_index < len(keys): 97 | yield self._read_val_from_key_fn(keys[key_index]) 98 | key_index += 1 99 | 100 | def _shuffle_keys(self) -> List[bytes]: 101 | shuffle_seed = self._shuffle_seed 102 | if self._current_epoch > 0 and self._shuffle_after_epoch: 103 | assert shuffle_seed is not None 104 | shuffle_seed += self._current_epoch 105 | 106 | return shuffle_keys(keys=copy.deepcopy(self._keys), seed=shuffle_seed) 107 | -------------------------------------------------------------------------------- /yogadl/_lmdb_handler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import logging 16 | import pathlib 17 | import pickle 18 | import platform 19 | from typing import Any, cast, Generator, List 20 | 21 | import lmdb 22 | 23 | 24 | # serialize_generator_to_lmdb is derived from: 25 | # 26 | # Copyright 2016 Yuxin Wu. All Rights Reserved. 27 | # 28 | # Licensed under the Apache License, Version 2.0 (the "License"); 29 | # you may not use this file except in compliance with the License. 
30 | # You may obtain a copy of the License at 31 | # 32 | # http://www.apache.org/licenses/LICENSE-2.0 33 | # 34 | # Unless required by applicable law or agreed to in writing, software 35 | # distributed under the License is distributed on an "AS IS" BASIS, 36 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 37 | # See the License for the specific language governing permissions and 38 | # limitations under the License. 39 | # ============================================================================== 40 | def serialize_generator_to_lmdb( 41 | dataset_generator: Generator, 42 | data_shapes: Any, 43 | data_types: Any, 44 | lmdb_path: pathlib.Path, 45 | write_frequency: int = 5000, 46 | ) -> int: 47 | """ 48 | Serialize a generator to a single LMDB file. Adapted from [1]. 49 | 50 | [1] https://tensorpack.readthedocs.io/_modules/tensorpack/dataflow/serialize.html 51 | """ 52 | assert lmdb_path.parent.is_dir(), "Checkpoint directory does not exist." 53 | assert not lmdb_path.exists(), "Checkpoint path already exists." 54 | # It's OK to use super large map_size on Linux, but not on other platforms 55 | # See: https://github.com/NVIDIA/DIGITS/issues/206 56 | map_size = 1099511627776 * 2 if platform.system() == "Linux" else 128 * 10 ** 6 57 | db = lmdb.open( 58 | str(lmdb_path), 59 | subdir=False, 60 | map_size=map_size, 61 | readonly=False, 62 | meminit=False, 63 | map_async=True, 64 | ) # need sync() at the end 65 | 66 | # put data into lmdb, and doubling the size if full. 67 | # Ref: https://github.com/NVIDIA/DIGITS/pull/209/files 68 | def put_or_grow(txn: lmdb.Transaction, key: Any, value: Any) -> lmdb.Transaction: 69 | try: 70 | txn.put(key, value) 71 | return txn 72 | except lmdb.MapFullError: 73 | pass 74 | txn.abort() 75 | curr_size = db.info()["map_size"] 76 | new_size = curr_size * 2 77 | logging.info(f"Doubling LMDB map_size to {new_size / 10 ** 9} GB") 78 | db.set_mapsize(new_size) 79 | txn = db.begin(write=True) 80 | txn = put_or_grow(txn, key, value) 81 | return txn 82 | 83 | # LMDB transaction is not exception-safe! 84 | # although it has a context manager interface 85 | txn = db.begin(write=True) 86 | dataset_entries = 0 87 | for data in dataset_generator: 88 | txn = put_or_grow( 89 | txn=txn, 90 | key="{:08}".format(dataset_entries).encode("ascii"), 91 | value=pickle.dumps(data, protocol=-1), 92 | ) 93 | if dataset_entries % write_frequency == 0: 94 | txn.commit() 95 | txn = db.begin(write=True) 96 | dataset_entries += 1 97 | txn.commit() 98 | 99 | keys = ["{:08}".format(k).encode("ascii") for k in range(dataset_entries)] 100 | with db.begin(write=True) as txn: 101 | put_or_grow(txn=txn, key=b"__keys__", value=pickle.dumps(keys, protocol=-1)) 102 | put_or_grow(txn=txn, key=b"__shapes__", value=pickle.dumps(data_shapes, protocol=-1)) 103 | put_or_grow(txn=txn, key=b"__types__", value=pickle.dumps(data_types, protocol=-1)) 104 | 105 | logging.debug("Flushing database ...") 106 | db.sync() 107 | db.close() 108 | 109 | return dataset_entries 110 | 111 | 112 | class LmdbAccess: 113 | """ 114 | Provides random access to an LMDB store file. Adopted from [1]. 115 | 116 | [1] https://github.com/tensorpack/tensorpack/blob/master/tensorpack/dataflow/format.py 117 | """ 118 | 119 | def __init__(self, lmdb_path: pathlib.Path) -> None: 120 | assert lmdb_path.exists(), f"Unable to load LMDB database from {lmdb_path}." 
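        # Open the database once up front to read the stored metadata (__keys__,
        # __shapes__, __types__) and the entry count, then close it; the connection
        # is reopened lazily on the first read_value_by_key() call.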
121 | self._lmdb_path = lmdb_path 122 | self._db_connection_open = False 123 | 124 | self._open_lmdb() 125 | self._size = cast(int, self._txn.stat()["entries"]) 126 | self._read_keys_from_db() 127 | self._read_shapes_from_db() 128 | self._read_types_from_db() 129 | logging.debug(f"Found {self._size} entries in {self._lmdb_path}.") 130 | self._close_lmdb() 131 | 132 | def __exit__(self, *_: Any) -> None: 133 | self._close_lmdb() 134 | 135 | def _open_lmdb(self) -> None: 136 | self._lmdb = lmdb.open( 137 | str(self._lmdb_path), 138 | subdir=False, 139 | readonly=True, 140 | lock=False, 141 | readahead=True, 142 | map_size=1099511627776 * 2, 143 | max_readers=100, 144 | ) 145 | self._txn = self._lmdb.begin() 146 | self._db_connection_open = True 147 | 148 | def _read_keys_from_db(self) -> None: 149 | self._keys = self._txn.get(b"__keys__") 150 | assert self._keys is not None 151 | self._keys = cast(List[bytes], pickle.loads(self._keys)) 152 | self._size -= 1 # delete this item 153 | 154 | def _read_shapes_from_db(self) -> None: 155 | self._shapes = self._txn.get(b"__shapes__") 156 | assert self._shapes is not None 157 | self._shapes = pickle.loads(self._shapes) 158 | 159 | def _read_types_from_db(self) -> None: 160 | self._types = self._txn.get(b"__types__") 161 | assert self._types is not None 162 | self._types = pickle.loads(self._types) 163 | 164 | def _close_lmdb(self) -> None: 165 | self._lmdb.close() 166 | del self._lmdb 167 | del self._txn 168 | self._db_connection_open = False 169 | 170 | def get_keys(self) -> List[bytes]: 171 | return cast(List[bytes], self._keys) 172 | 173 | def get_shapes(self) -> Any: 174 | return self._shapes 175 | 176 | def get_types(self) -> Any: 177 | return self._types 178 | 179 | def read_value_by_key(self, key: bytes) -> Any: 180 | if not self._db_connection_open: 181 | self._open_lmdb() 182 | 183 | return pickle.loads(self._txn.get(key)) 184 | -------------------------------------------------------------------------------- /yogadl/constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | AWS_STORAGE = "aws_storage" 16 | 17 | GCS_STORAGE = "gcs_storage" 18 | -------------------------------------------------------------------------------- /yogadl/dataref/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from yogadl.dataref._local_lmdb_dataref import LMDBDataRef 16 | -------------------------------------------------------------------------------- /yogadl/dataref/_local_lmdb_dataref.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import pathlib 16 | from typing import List, Optional 17 | 18 | import yogadl 19 | 20 | 21 | class LMDBDataRef(yogadl.DataRef): 22 | def __init__(self, cache_filepath: pathlib.Path): 23 | self._lmdb_access = yogadl.LmdbAccess(lmdb_path=cache_filepath) 24 | self._keys = self._lmdb_access.get_keys() 25 | 26 | def stream( 27 | self, 28 | start_offset: int = 0, 29 | shuffle: bool = False, 30 | skip_shuffle_at_epoch_end: bool = False, 31 | shuffle_seed: Optional[int] = None, 32 | shard_rank: int = 0, 33 | num_shards: int = 1, 34 | drop_shard_remainder: bool = False, 35 | ) -> yogadl.Stream: 36 | """ 37 | Create a stream from a cache. 38 | """ 39 | if shuffle and not skip_shuffle_at_epoch_end: 40 | assert shuffle_seed is not None, ( 41 | "Please set `shuffle_seed` if enabling `shuffle` and not enabling " 42 | "`skip_shuffle_at_epoch_end`." 
43 | ) 44 | 45 | generated_keys = self._shard_keys( 46 | shard_rank=shard_rank, 47 | num_shards=num_shards, 48 | drop_shard_remainder=drop_shard_remainder, 49 | ) 50 | 51 | generator_from_keys = yogadl.GeneratorFromKeys( 52 | keys=generated_keys, 53 | initial_offset=start_offset, 54 | read_val_from_key_fn=self._lmdb_access.read_value_by_key, 55 | shuffle_at_start=shuffle, 56 | shuffle_after_epoch=shuffle and not skip_shuffle_at_epoch_end, 57 | shuffle_seed=shuffle_seed, 58 | ) 59 | 60 | return yogadl.Stream( 61 | iterator_fn=generator_from_keys.instantiate_generator, 62 | length=len(generated_keys), 63 | output_types=self._lmdb_access.get_types(), 64 | output_shapes=self._lmdb_access.get_shapes(), 65 | ) 66 | 67 | def __len__(self) -> int: 68 | return len(self._keys) 69 | 70 | def _shard_keys( 71 | self, shard_rank: int, num_shards: int, drop_shard_remainder: bool 72 | ) -> List[bytes]: 73 | generated_keys = yogadl.shard_keys( 74 | keys=self._keys, 75 | shard_index=shard_rank, 76 | num_shards=num_shards, 77 | sequential=False, 78 | drop_shard_remainder=drop_shard_remainder, 79 | ) 80 | 81 | return generated_keys 82 | -------------------------------------------------------------------------------- /yogadl/rw_coordinator/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from . import communication_protocol as communication_protocol 16 | from yogadl.rw_coordinator._client import RwCoordinatorClient as RwCoordinatorClient 17 | from yogadl.rw_coordinator._server import RwCoordinatorServer as RwCoordinatorServer 18 | -------------------------------------------------------------------------------- /yogadl/rw_coordinator/_client.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | import contextlib 16 | import logging 17 | import pathlib 18 | import socket 19 | import ssl 20 | import urllib.parse 21 | from typing import Generator, Optional 22 | 23 | import lomond 24 | 25 | from yogadl.rw_coordinator import communication_protocol 26 | 27 | 28 | class CustomSSLWebsocketSession(lomond.session.WebsocketSession): # type: ignore 29 | """ 30 | A session class that allows for the TLS verification mode of a WebSocket connection to be 31 | configured. 32 | """ 33 | 34 | def __init__( 35 | self, 36 | socket: lomond.WebSocket, 37 | skip_verify: bool, 38 | coordinator_cert_file: Optional[str], 39 | coordinator_cert_name: Optional[str], 40 | ) -> None: 41 | super().__init__(socket) 42 | self._coordinator_cert_name = coordinator_cert_name 43 | 44 | self.ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) 45 | 46 | if skip_verify: 47 | return 48 | self.ctx.verify_mode = ssl.CERT_REQUIRED 49 | self.ctx.check_hostname = True 50 | self.ctx.load_default_certs() 51 | if coordinator_cert_file is not None: 52 | self.ctx.load_verify_locations(cafile=coordinator_cert_file) 53 | 54 | def _wrap_socket(self, sock: socket.SocketType, host: str) -> socket.SocketType: 55 | return self.ctx.wrap_socket(sock, server_hostname=self._coordinator_cert_name or host) 56 | 57 | 58 | class RwCoordinatorClient: 59 | """ 60 | RwCoordinatorClient acquires locks from RwCoordinatorServer. 61 | 62 | RwCoordinatorClient provides read and write locks. An instance of 63 | AccessServer must be running during lock request. 64 | """ 65 | 66 | def __init__( 67 | self, 68 | url: str, 69 | skip_verify: bool = False, 70 | coordinator_cert_file: Optional[str] = None, 71 | coordinator_cert_name: Optional[str] = None, 72 | ): 73 | self._url = url 74 | self._skip_verify = skip_verify 75 | self._coordinator_cert_file = coordinator_cert_file 76 | self._coordinator_cert_name = coordinator_cert_name 77 | 78 | def _construct_url_request( 79 | self, storage_type: str, bucket: str, cache_path: pathlib.Path, read_lock: bool 80 | ) -> str: 81 | lock_request_url = ( 82 | f"{self._url}/{storage_type}/{bucket}/{str(cache_path)}" 83 | f"?{urllib.parse.urlencode({'read_lock': read_lock})}" 84 | ) 85 | 86 | logging.debug(f"Generated lock request url: {lock_request_url}.") 87 | 88 | return lock_request_url 89 | 90 | @contextlib.contextmanager 91 | def _request_lock( 92 | self, lock_request_url: str, expected_response: str 93 | ) -> Generator[None, None, None]: 94 | with lomond.WebSocket(lock_request_url) as socket: 95 | for event in socket.connect( 96 | session_class=lambda socket: CustomSSLWebsocketSession( 97 | socket, 98 | self._skip_verify, 99 | self._coordinator_cert_file, 100 | self._coordinator_cert_name, 101 | ) 102 | ): 103 | if isinstance(event, lomond.events.ConnectFail): 104 | raise ConnectionError(f"connect({self._url}): {event}") 105 | elif isinstance(event, lomond.events.Text): 106 | assert event.text == expected_response 107 | yield 108 | socket.close() 109 | 110 | @contextlib.contextmanager 111 | def read_lock( 112 | self, storage_type: str, bucket: str, cache_path: pathlib.Path 113 | ) -> Generator[None, None, None]: 114 | lock_request_url = self._construct_url_request( 115 | storage_type=storage_type, 116 | bucket=bucket, 117 | cache_path=cache_path, 118 | read_lock=True, 119 | ) 120 | 121 | with self._request_lock( 122 | lock_request_url=lock_request_url, 123 | expected_response=communication_protocol.READ_LOCK_GRANTED, 124 | ): 125 | yield 
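    # Example client usage, mirroring tests/unit/local/test_rw_coordinator.py
    # (URL, bucket, and cache path are illustrative):
    #
    #     client = RwCoordinatorClient(url="ws://localhost:10245")
    #     with client.read_lock(
    #         storage_type=constants.GCS_STORAGE,  # yogadl.constants
    #         bucket="my_bucket",
    #         cache_path=pathlib.Path("/tmp.mdb"),
    #     ):
    #         ...  # read the shared cache while the lock is held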
126 | 127 | @contextlib.contextmanager 128 | def write_lock( 129 | self, storage_type: str, bucket: str, cache_path: pathlib.Path 130 | ) -> Generator[None, None, None]: 131 | lock_request_url = self._construct_url_request( 132 | storage_type=storage_type, 133 | bucket=bucket, 134 | cache_path=cache_path, 135 | read_lock=False, 136 | ) 137 | 138 | with self._request_lock( 139 | lock_request_url=lock_request_url, 140 | expected_response=communication_protocol.WRITE_LOCK_GRANTED, 141 | ): 142 | yield 143 | -------------------------------------------------------------------------------- /yogadl/rw_coordinator/_server.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import asyncio 16 | import async_generator 17 | import logging 18 | import time 19 | import urllib.parse 20 | from typing import Any, Dict, AsyncIterator, Optional 21 | 22 | from websockets import WebSocketServerProtocol, ConnectionClosedError, serve 23 | 24 | from yogadl.rw_coordinator import communication_protocol 25 | 26 | 27 | class RWLock: 28 | def __init__(self) -> None: 29 | self.rw_cond = asyncio.Condition() 30 | self.writers_waiting = 0 31 | self.active_readers = 0 32 | self.active_writer = False 33 | 34 | @async_generator.asynccontextmanager # type: ignore 35 | async def read_lock(self) -> AsyncIterator[str]: 36 | async with self.rw_cond: 37 | while self.writers_waiting > 0 or self.active_writer: 38 | await self.rw_cond.wait() 39 | self.active_readers += 1 40 | 41 | try: 42 | yield communication_protocol.READ_LOCK_GRANTED 43 | finally: 44 | async with self.rw_cond: 45 | self.active_readers -= 1 46 | self.rw_cond.notify_all() 47 | 48 | @async_generator.asynccontextmanager # type: ignore 49 | async def write_lock(self) -> AsyncIterator[str]: 50 | async with self.rw_cond: 51 | self.writers_waiting += 1 52 | while self.active_readers > 0 or self.active_writer: 53 | await self.rw_cond.wait() 54 | self.active_writer = True 55 | self.writers_waiting -= 1 56 | 57 | try: 58 | yield communication_protocol.WRITE_LOCK_GRANTED 59 | finally: 60 | async with self.rw_cond: 61 | self.active_writer = False 62 | self.rw_cond.notify_all() 63 | 64 | 65 | class RwCoordinatorServer: 66 | """ 67 | RwCoordinatorServer provides RWlocks for clients that connect to it. 68 | 69 | RwCoordinatorServer provides unique RWLock for each instance of dataset id 70 | and dataset version. In cases of connection loss, the lock issued 71 | to the client is revoked. 72 | 73 | RwCoordinatorServer does not support synchronization across multiple instances 74 | of RwCoordinatorServer. Users should avoid running more than one instance of 75 | RwCoordinatorServer concurrently. 
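
    A rough sketch of running the server on a background thread (hostname and
    port are illustrative; AccessServerHandler in tests/unit/util.py uses this
    pattern):

        server = RwCoordinatorServer(hostname="localhost", port=10245)
        asyncio.get_event_loop().run_until_complete(server.run_server())
        threading.Thread(target=asyncio.get_event_loop().run_forever).start()
        ...
        server.stop_server()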
76 | """ 77 | 78 | def __init__( 79 | self, 80 | hostname: Optional[str] = None, 81 | port: Optional[int] = None, 82 | ssl_context: Optional[Any] = None, 83 | ) -> None: 84 | self._hostname = hostname 85 | self._port = port 86 | self._ssl_context = ssl_context 87 | 88 | # Used to access rw_locks dictionary. 89 | self._global_lock = asyncio.Lock() 90 | 91 | # Unique RWLock per cache. 92 | self._rw_locks = {} # type: Dict[str, RWLock] 93 | 94 | async def run_server(self) -> None: 95 | self._server = await serve( 96 | self._process_lock_request, 97 | self._hostname, 98 | self._port, 99 | ssl=self._ssl_context, 100 | ) 101 | 102 | def stop_server(self) -> None: 103 | asyncio.get_event_loop().call_soon_threadsafe(asyncio.get_event_loop().stop) 104 | 105 | while asyncio.get_event_loop().is_running(): 106 | time.sleep(1) 107 | 108 | self._server.close() 109 | 110 | @async_generator.asynccontextmanager # type: ignore 111 | async def _get_lock(self, rw_lock: RWLock, read_lock: bool) -> AsyncIterator[str]: 112 | if read_lock: 113 | async with rw_lock.read_lock() as response: 114 | yield response 115 | else: 116 | async with rw_lock.write_lock() as response: 117 | yield response 118 | 119 | async def _process_lock_request(self, websocket: WebSocketServerProtocol, path: str) -> None: 120 | parsed_url = urllib.parse.urlparse(path) 121 | resource = parsed_url.path 122 | parsed_query = urllib.parse.parse_qs(parsed_url.query) 123 | assert "read_lock" in parsed_query.keys() 124 | read_lock = parsed_query["read_lock"][0] == "True" 125 | 126 | async with self._global_lock: 127 | rw_lock = self._rw_locks.setdefault(resource, RWLock()) 128 | 129 | try: 130 | async with self._get_lock(rw_lock=rw_lock, read_lock=read_lock) as response: 131 | await websocket.send(response) 132 | 133 | async for _ in websocket: 134 | pass 135 | 136 | except ConnectionClosedError: 137 | logging.warning("Client connection closed unexpectedly.") 138 | pass 139 | -------------------------------------------------------------------------------- /yogadl/rw_coordinator/communication_protocol.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | READ_LOCK_GRANTED = "read_lock_granted" 16 | WRITE_LOCK_GRANTED = "write_lock_granted" 17 | -------------------------------------------------------------------------------- /yogadl/storage/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | from ._cloud_storage import BaseCloudConfigurations, BaseCloudStorage 16 | from ._gcs_storage import GCSConfigurations, GCSStorage 17 | from ._lfs_storage import LFSConfigurations, LFSStorage 18 | from ._s3_storage import S3Configurations, S3Storage 19 | -------------------------------------------------------------------------------- /yogadl/storage/_cloud_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import abc 16 | import contextlib 17 | import datetime 18 | import json 19 | import logging 20 | import pathlib 21 | from typing import Any, Callable, Dict, Generator, Optional, cast 22 | 23 | import filelock 24 | import tensorflow as tf 25 | 26 | import yogadl 27 | from yogadl import dataref, rw_coordinator, tensorflow 28 | 29 | 30 | class BaseCloudConfigurations(metaclass=abc.ABCMeta): 31 | """ 32 | Configurations for BaseCloudStorage. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | bucket: str, 38 | bucket_directory_path: str, 39 | url: str, 40 | local_cache_dir: str, 41 | skip_verify: bool, 42 | coordinator_cert_file: Optional[str], 43 | coordinator_cert_name: Optional[str], 44 | ) -> None: 45 | self.bucket = bucket 46 | self.bucket_directory_path = pathlib.Path(bucket_directory_path) 47 | self.url = url 48 | self.local_cache_dir = pathlib.Path(local_cache_dir) 49 | self.cache_format = "LMDB" 50 | self.skip_verify = skip_verify 51 | self.coordinator_cert_file = coordinator_cert_file 52 | self.coordinator_cert_name = coordinator_cert_name 53 | 54 | 55 | class BaseCloudStorage(yogadl.Storage): 56 | """ 57 | Base class for using cloud storage. 58 | 59 | This class should never be used directly. Instead users should use 60 | S3Storage or GCSStorage. 
61 | """ 62 | 63 | def __init__( 64 | self, 65 | configurations: BaseCloudConfigurations, 66 | tensorflow_config: Optional[tf.compat.v1.ConfigProto], 67 | ) -> None: 68 | self._configurations = configurations 69 | self._rw_client = rw_coordinator.RwCoordinatorClient( 70 | url=self._configurations.url, 71 | skip_verify=self._configurations.skip_verify, 72 | coordinator_cert_file=self._configurations.coordinator_cert_file, 73 | coordinator_cert_name=self._configurations.coordinator_cert_name, 74 | ) 75 | self._supported_cache_formats = ["LMDB"] 76 | self._tensorflow_config = tensorflow_config 77 | 78 | @property 79 | @abc.abstractmethod 80 | def _storage_type(self) -> str: 81 | pass 82 | 83 | @abc.abstractmethod 84 | def _is_cloud_cache_present(self, dataset_id: str, dataset_version: str) -> bool: 85 | pass 86 | 87 | @abc.abstractmethod 88 | def _download_from_cloud_storage( 89 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 90 | ) -> datetime.datetime: 91 | pass 92 | 93 | @abc.abstractmethod 94 | def _upload_to_cloud_storage( 95 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 96 | ) -> datetime.datetime: 97 | pass 98 | 99 | @abc.abstractmethod 100 | def _get_remote_cache_timestamp( 101 | self, dataset_id: str, dataset_version: str 102 | ) -> datetime.datetime: 103 | pass 104 | 105 | def submit(self, data: yogadl.Submittable, dataset_id: str, dataset_version: str) -> None: 106 | """ 107 | Stores dataset by creating a local cache and uploading it to cloud storage. 108 | 109 | If a cache with a matching filepath already exists in cloud storage, it will be overwritten. 110 | 111 | `submit()` is not safe for concurrent accesses. For concurrent accesses use 112 | `cacheable()`. 113 | """ 114 | local_cache_filepath = self._get_local_cache_filepath( 115 | dataset_id=dataset_id, 116 | dataset_version=dataset_version, 117 | ) 118 | local_cache_filepath.parent.mkdir(parents=True, exist_ok=True) 119 | 120 | if local_cache_filepath.exists(): 121 | logging.debug(f"Removing old local cache: {local_cache_filepath}.") 122 | local_cache_filepath.unlink() 123 | 124 | # TODO: remove TF hardcoding. 125 | tensorflow.serialize_tf_dataset_to_lmdb( 126 | dataset=data, checkpoint_path=local_cache_filepath, tf_config=self._tensorflow_config 127 | ) 128 | logging.info( 129 | f"Serialized dataset {dataset_id}:{dataset_version} to local cache: " 130 | f"{local_cache_filepath} and uploading to remote storage." 131 | ) 132 | 133 | timestamp = self._upload_to_cloud_storage( 134 | dataset_id=dataset_id, 135 | dataset_version=dataset_version, 136 | local_cache_filepath=local_cache_filepath, 137 | ).timestamp() 138 | logging.info("Cache upload to remote storage finished.") 139 | 140 | # Update metadata with new upload time. 141 | local_metadata = self._get_local_metadata( 142 | dataset_id=dataset_id, dataset_version=dataset_version 143 | ) 144 | 145 | local_metadata["time_created"] = timestamp 146 | self._save_local_metadata( 147 | dataset_id=dataset_id, 148 | dataset_version=dataset_version, 149 | metadata=local_metadata, 150 | ) 151 | 152 | def fetch(self, dataset_id: str, dataset_version: str) -> dataref.LMDBDataRef: 153 | """ 154 | Fetch a dataset from cloud storage and provide a DataRef for streaming it. 155 | 156 | The timestamp of the cache in cloud storage is compared to the creation 157 | time of the local cache, if they are not identical, the local cache 158 | is overwritten. 159 | 160 | `fetch()` is not safe for concurrent accesses. 
For concurrent accesses use 161 | `cacheable()`. 162 | """ 163 | 164 | local_metadata = self._get_local_metadata( 165 | dataset_id=dataset_id, dataset_version=dataset_version 166 | ) 167 | local_cache_filepath = self._get_local_cache_filepath( 168 | dataset_id=dataset_id, 169 | dataset_version=dataset_version, 170 | ) 171 | 172 | remote_cache_timestamp = self._get_remote_cache_timestamp( 173 | dataset_id=dataset_id, dataset_version=dataset_version 174 | ).timestamp() 175 | 176 | if local_metadata.get("time_created") == remote_cache_timestamp: 177 | logging.info("Local cache matches remote cache.") 178 | else: 179 | logging.info(f"Downloading remote cache to {local_cache_filepath}.") 180 | local_metadata["time_created"] = self._download_from_cloud_storage( 181 | dataset_id=dataset_id, 182 | dataset_version=dataset_version, 183 | local_cache_filepath=local_cache_filepath, 184 | ).timestamp() 185 | logging.info("Cache download finished.") 186 | 187 | self._save_local_metadata( 188 | dataset_id=dataset_id, 189 | dataset_version=dataset_version, 190 | metadata=local_metadata, 191 | ) 192 | 193 | assert local_cache_filepath.exists() 194 | 195 | return dataref.LMDBDataRef(cache_filepath=local_cache_filepath) 196 | 197 | def cacheable(self, dataset_id: str, dataset_version: str) -> Callable: 198 | """ 199 | A decorator that calls submit and fetch and is responsible for coordinating 200 | amongst instantiations of Storage in different processes. 201 | 202 | Initially requests a read lock, if cache is not present in cloud storage, will request 203 | a write lock and submit to cloud storage. Once file is present in cloud storage, will 204 | request a read lock and fetch. 205 | """ 206 | 207 | def wrap(f: Callable) -> Callable: 208 | def create_dataref(*args: Any, **kwargs: Any) -> dataref.LMDBDataRef: 209 | local_lmdb_dataref = self._try_reading_from_cloud_storage( 210 | dataset_id=dataset_id, dataset_version=dataset_version 211 | ) 212 | 213 | if not local_lmdb_dataref: 214 | self._try_writing_to_cloud_storage( 215 | dataset_id=dataset_id, 216 | dataset_version=dataset_version, 217 | f=f, 218 | args=args, 219 | kwargs=kwargs, 220 | ) 221 | 222 | local_lmdb_dataref = self._try_reading_from_cloud_storage( 223 | dataset_id=dataset_id, dataset_version=dataset_version 224 | ) 225 | 226 | assert local_lmdb_dataref, "Unable to create dataref from cloud cache." 
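cacheable() is the entry point to reach for when several processes may race to build the same cache. A minimal usage sketch, assuming `storage` is an already-constructed GCSStorage or S3Storage instance (constructions are shown with those classes later in this listing); the dataset id and version are placeholders, and dataref.stream() is used with its defaults:

import tensorflow as tf

from yogadl import tensorflow as yogadl_tf

# `storage` is assumed to be a GCSStorage or S3Storage instance, constructed as
# shown later in this listing.


@storage.cacheable(dataset_id="mnist", dataset_version="1")
def make_dataset() -> tf.data.Dataset:
    # Only the process that wins the write lock and still finds no remote cache
    # runs this body; its output is serialized to LMDB and uploaded, and every
    # other caller streams the uploaded cache instead.
    return tf.data.Dataset.range(10)


dataref = make_dataset()                       # LMDBDataRef over the local cache
dataset = yogadl_tf.make_tf_dataset(dataref.stream())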
227 | 228 | return local_lmdb_dataref 229 | 230 | return create_dataref 231 | 232 | return wrap 233 | 234 | def _try_reading_from_cloud_storage( 235 | self, dataset_id: str, dataset_version: str 236 | ) -> Optional[dataref.LMDBDataRef]: 237 | remote_cache_path = self._get_remote_cache_filepath( 238 | dataset_id=dataset_id, dataset_version=dataset_version 239 | ) 240 | local_lmdb_dataref = None # type: Optional[dataref.LMDBDataRef] 241 | with self._rw_client.read_lock( 242 | storage_type=self._storage_type, 243 | bucket=self._configurations.bucket, 244 | cache_path=remote_cache_path, 245 | ): 246 | if self._is_cloud_cache_present(dataset_id=dataset_id, dataset_version=dataset_version): 247 | with self._lock_local_cache( 248 | dataset_id=dataset_id, 249 | dataset_version=dataset_version, 250 | ): 251 | local_lmdb_dataref = self.fetch( 252 | dataset_id=dataset_id, dataset_version=dataset_version 253 | ) 254 | 255 | return local_lmdb_dataref 256 | 257 | def _try_writing_to_cloud_storage( 258 | self, 259 | dataset_id: str, 260 | dataset_version: str, 261 | f: Callable, 262 | args: Any, 263 | kwargs: Any, 264 | ) -> None: 265 | remote_cache_path = self._get_remote_cache_filepath( 266 | dataset_id=dataset_id, dataset_version=dataset_version 267 | ) 268 | with self._rw_client.write_lock( 269 | storage_type=self._storage_type, 270 | bucket=self._configurations.bucket, 271 | cache_path=remote_cache_path, 272 | ): 273 | # It is possible that the cache was created while 274 | # the write lock was being acquired. 275 | if not self._is_cloud_cache_present( 276 | dataset_id=dataset_id, dataset_version=dataset_version 277 | ): 278 | with self._lock_local_cache( 279 | dataset_id=dataset_id, 280 | dataset_version=dataset_version, 281 | ): 282 | self.submit( 283 | data=f(*args, **kwargs), 284 | dataset_id=dataset_id, 285 | dataset_version=dataset_version, 286 | ) 287 | 288 | @contextlib.contextmanager 289 | def _lock_local_cache( 290 | self, dataset_id: str, dataset_version: str 291 | ) -> Generator[None, None, None]: 292 | lock_filepath = ( 293 | self._configurations.local_cache_dir.joinpath("yogadl_local_cache") 294 | .joinpath(dataset_id) 295 | .joinpath(dataset_version) 296 | .joinpath("yogadl.lock") 297 | ) 298 | lock_filepath.parent.mkdir(parents=True, exist_ok=True) 299 | 300 | # Blocks until access is granted. 301 | access_lock = filelock.FileLock(str(lock_filepath)) 302 | with access_lock.acquire(): 303 | yield 304 | 305 | def _get_remote_cache_filepath(self, dataset_id: str, dataset_version: str) -> pathlib.Path: 306 | assert dataset_id, "`dataset_id` must be a non-empty string." 307 | assert dataset_version, "`dataset_version` must be a non-empty string." 308 | return ( 309 | self._configurations.bucket_directory_path.joinpath(dataset_id) 310 | .joinpath(dataset_version) 311 | .joinpath("cache.mdb") 312 | ) 313 | 314 | def _get_local_cache_filepath(self, dataset_id: str, dataset_version: str) -> pathlib.Path: 315 | assert dataset_id, "`dataset_id` must be a non-empty string." 316 | assert dataset_version, "`dataset_version` must be a non-empty string." 
317 | return ( 318 | self._configurations.local_cache_dir.joinpath("yogadl_local_cache") 319 | .joinpath(dataset_id) 320 | .joinpath(dataset_version) 321 | .joinpath("cache.mdb") 322 | ) 323 | 324 | def _get_local_metadata_filepath(self, dataset_id: str, dataset_version: str) -> pathlib.Path: 325 | return ( 326 | self._configurations.local_cache_dir.joinpath("yogadl_local_cache") 327 | .joinpath(dataset_id) 328 | .joinpath(dataset_version) 329 | .joinpath("local_metadata.json") 330 | ) 331 | 332 | def _get_local_metadata(self, dataset_id: str, dataset_version: str) -> Dict[str, Any]: 333 | metadata_filepath = self._get_local_metadata_filepath( 334 | dataset_id=dataset_id, 335 | dataset_version=dataset_version, 336 | ) 337 | 338 | if not metadata_filepath.exists(): 339 | return {} 340 | 341 | with open(str(metadata_filepath), "r") as metadata_file: 342 | return cast(Dict[str, Any], json.load(metadata_file)) 343 | 344 | def _save_local_metadata( 345 | self, dataset_id: str, dataset_version: str, metadata: Dict[str, Any] 346 | ) -> None: 347 | metadata_filepath = self._get_local_metadata_filepath( 348 | dataset_id=dataset_id, 349 | dataset_version=dataset_version, 350 | ) 351 | 352 | with open(str(metadata_filepath), "w") as metadata_file: 353 | json.dump(metadata, metadata_file) 354 | -------------------------------------------------------------------------------- /yogadl/storage/_gcs_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import datetime 16 | import pathlib 17 | from typing import Optional 18 | 19 | import google.api_core.exceptions as gcp_exceptions 20 | import google.cloud.storage as google_storage 21 | import tensorflow as tf 22 | 23 | import yogadl.constants as constants 24 | from yogadl import storage 25 | 26 | 27 | class GCSConfigurations(storage.BaseCloudConfigurations): 28 | def __init__( 29 | self, 30 | bucket: str, 31 | bucket_directory_path: str, 32 | url: str, 33 | local_cache_dir: str = "/tmp/", 34 | skip_verify: bool = False, 35 | coordinator_cert_file: Optional[str] = None, 36 | coordinator_cert_name: Optional[str] = None, 37 | ) -> None: 38 | super().__init__( 39 | bucket=bucket, 40 | bucket_directory_path=bucket_directory_path, 41 | url=url, 42 | local_cache_dir=local_cache_dir, 43 | skip_verify=skip_verify, 44 | coordinator_cert_file=coordinator_cert_file, 45 | coordinator_cert_name=coordinator_cert_name, 46 | ) 47 | 48 | 49 | class GCSStorage(storage.BaseCloudStorage): 50 | """ 51 | Stores dataset cache in Google Cloud Storage (GCS). 52 | 53 | GCSStorage creates a local cache from a dataset and then uploads 54 | it to the specified GCS bucket. 
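A minimal construction sketch for GCSStorage, using the inherited submit()/fetch() pair for the single-process case; prefer cacheable() whenever concurrent access is possible. The bucket, prefix, coordinator URL, and dataset identifiers below are placeholders.

import tensorflow as tf

from yogadl.storage import GCSConfigurations, GCSStorage

# Placeholder configuration values; see GCSConfigurations above for all options.
storage = GCSStorage(
    GCSConfigurations(
        bucket="my-dataset-cache",
        bucket_directory_path="yogadl",
        url="ws://coordinator.example.com:10245",  # rw_coordinator server address
    )
)

# Build and upload the cache once (overwrites any existing cache for this id/version)...
storage.submit(tf.data.Dataset.range(10), dataset_id="mnist", dataset_version="1")

# ...then fetch it, getting an LMDBDataRef backed by the local copy.
dataref = storage.fetch(dataset_id="mnist", dataset_version="1")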
When fetching from GCS, the creation 55 | time of the local cache (recorded in metadata), is compared to the 56 | creation time of the GCS cache, if they are not equivalent, the 57 | local cache is overwritten. 58 | 59 | The GCS cache, and the local cache are potentially shared across a 60 | number of concurrent processes. `cacheable()` provides synchronization 61 | guarantees. Users should not call `submit()` and `fetch()` if they 62 | anticipate concurrent data accesses. 63 | 64 | Authentication is currently only supported via the "Application 65 | Default Credentials" method in GCP. Typical configuration: 66 | ensure your VM runs in a service account that has sufficient 67 | permissions to read/write/delete from the GCS bucket where 68 | checkpoints will be stored (this only works when running in GCE). 69 | """ 70 | 71 | def __init__( 72 | self, 73 | configurations: GCSConfigurations, 74 | tensorflow_config: Optional[tf.compat.v1.ConfigProto] = None, 75 | ): 76 | super().__init__(configurations=configurations, tensorflow_config=tensorflow_config) 77 | 78 | self._gcs_client = google_storage.Client() 79 | self._bucket = self._gcs_client.bucket(self._configurations.bucket) 80 | 81 | self._check_configurations() 82 | 83 | def _check_configurations(self) -> None: 84 | assert self._configurations.local_cache_dir.is_dir() 85 | assert self._configurations.cache_format in self._supported_cache_formats 86 | assert self._bucket.exists() 87 | 88 | @property 89 | def _storage_type(self) -> str: 90 | return constants.GCS_STORAGE 91 | 92 | def _is_cloud_cache_present(self, dataset_id: str, dataset_version: str) -> bool: 93 | 94 | gcs_cache_filepath = self._get_remote_cache_filepath( 95 | dataset_id=dataset_id, 96 | dataset_version=dataset_version, 97 | ) 98 | blob = self._bucket.blob(str(gcs_cache_filepath)) 99 | 100 | blob_exists = blob.exists() 101 | assert isinstance(blob_exists, bool) 102 | 103 | return blob_exists 104 | 105 | def _download_from_cloud_storage( 106 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 107 | ) -> datetime.datetime: 108 | 109 | gcs_cache_filepath = self._get_remote_cache_filepath( 110 | dataset_id=dataset_id, 111 | dataset_version=dataset_version, 112 | ) 113 | blob = self._bucket.blob(str(gcs_cache_filepath)) 114 | 115 | assert ( 116 | blob.exists() 117 | ), f"Downloading non-existent blob {self._configurations.bucket}/{gcs_cache_filepath}." 118 | 119 | try: 120 | blob.download_to_filename(str(local_cache_filepath)) 121 | except gcp_exceptions.GoogleAPICallError as e: 122 | raise AssertionError( 123 | f"Downloading blob {self._configurations.bucket}" 124 | f"/{gcs_cache_filepath} failed with exception {e}." 125 | ) 126 | 127 | return self._get_remote_cache_timestamp( 128 | dataset_id=dataset_id, dataset_version=dataset_version 129 | ) 130 | 131 | def _upload_to_cloud_storage( 132 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 133 | ) -> datetime.datetime: 134 | 135 | gcs_cache_filepath = self._get_remote_cache_filepath( 136 | dataset_id=dataset_id, 137 | dataset_version=dataset_version, 138 | ) 139 | blob = self._bucket.blob(str(gcs_cache_filepath)) 140 | 141 | try: 142 | blob.upload_from_filename(str(local_cache_filepath)) 143 | except gcp_exceptions.GoogleAPICallError as e: 144 | raise AssertionError( 145 | f"Upload from {local_cache_filepath} to {self._configurations.bucket}" 146 | f"/{gcs_cache_filepath} failed with exception {e}." 
147 | ) 148 | 149 | # Do not need to `reload()` to get latest blob metadata after upload. 150 | assert blob.time_created 151 | assert isinstance(blob.time_created, datetime.datetime) 152 | 153 | return blob.time_created 154 | 155 | def _get_remote_cache_timestamp( 156 | self, dataset_id: str, dataset_version: str 157 | ) -> datetime.datetime: 158 | 159 | gcs_cache_filepath = self._get_remote_cache_filepath( 160 | dataset_id=dataset_id, 161 | dataset_version=dataset_version, 162 | ) 163 | blob = self._bucket.blob(str(gcs_cache_filepath)) 164 | 165 | try: 166 | blob.reload() 167 | except gcp_exceptions.GoogleAPICallError as e: 168 | raise AssertionError( 169 | f"Getting metadata of {self._configurations.bucket}" 170 | f"/{gcs_cache_filepath} failed with exception {e}." 171 | ) 172 | 173 | assert isinstance(blob.time_created, datetime.datetime) 174 | 175 | return blob.time_created 176 | -------------------------------------------------------------------------------- /yogadl/storage/_lfs_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import contextlib 16 | import logging 17 | import pathlib 18 | from typing import Any, Callable, Generator, Optional 19 | 20 | import filelock 21 | import tensorflow as tf 22 | 23 | import yogadl 24 | from yogadl import dataref, tensorflow 25 | 26 | 27 | class LFSConfigurations: 28 | """ 29 | Configurations for LFSStorage. 30 | """ 31 | 32 | def __init__(self, storage_dir_path: str): 33 | self.storage_dir_path = pathlib.Path(storage_dir_path) 34 | self.cache_format = "LMDB" 35 | 36 | 37 | class LFSStorage(yogadl.Storage): 38 | """ 39 | Storage for local file system (not NFS). 40 | """ 41 | 42 | def __init__( 43 | self, 44 | configurations: LFSConfigurations, 45 | tensorflow_config: Optional[tf.compat.v1.ConfigProto] = None, 46 | ): 47 | self._configurations = configurations 48 | self._supported_cache_formats = ["LMDB"] 49 | self._tensorflow_config = tensorflow_config 50 | 51 | self._check_configurations() 52 | 53 | def _check_configurations(self) -> None: 54 | assert self._configurations.storage_dir_path.is_dir() 55 | assert self._configurations.cache_format in self._supported_cache_formats 56 | 57 | def submit(self, data: yogadl.Submittable, dataset_id: str, dataset_version: str) -> None: 58 | """ 59 | Stores dataset to a cache and updates metadata file with information. 60 | 61 | If a cache with a matching filepath already exists, it will be overwritten. 
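A minimal usage sketch for this local-file-system backend with the cacheable() decorator defined further down in this class; unlike the cloud storages it needs no rw_coordinator URL, only a storage directory that already exists (enforced by _check_configurations above). The path and dataset identifiers are placeholders.

import pathlib

import tensorflow as tf

from yogadl import tensorflow as yogadl_tf
from yogadl.storage import LFSConfigurations, LFSStorage

# Placeholder cache directory; it must already exist, so create it first.
cache_dir = pathlib.Path("/tmp/yogadl_cache")
cache_dir.mkdir(parents=True, exist_ok=True)

storage = LFSStorage(LFSConfigurations(storage_dir_path=str(cache_dir)))


@storage.cacheable(dataset_id="mnist", dataset_version="1")
def make_dataset() -> tf.data.Dataset:
    return tf.data.Dataset.range(10)


dataset = yogadl_tf.make_tf_dataset(make_dataset().stream())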
62 | """ 63 | cache_filepath = self._get_cache_filepath( 64 | dataset_id=dataset_id, 65 | dataset_version=dataset_version, 66 | ) 67 | cache_filepath.parent.mkdir(parents=True, exist_ok=True) 68 | 69 | if cache_filepath.exists(): 70 | logging.info(f"Removing old cache: {cache_filepath}.") 71 | cache_filepath.unlink() 72 | 73 | # TODO: remove TF hardcoding. 74 | tensorflow.serialize_tf_dataset_to_lmdb( 75 | dataset=data, checkpoint_path=cache_filepath, tf_config=self._tensorflow_config 76 | ) 77 | logging.info(f"Serialized dataset {dataset_id}:{dataset_version} to: {cache_filepath}.") 78 | 79 | def fetch(self, dataset_id: str, dataset_version: str) -> dataref.LMDBDataRef: 80 | """ 81 | Fetch a dataset from storage and provide a DataRef 82 | for streaming it. 83 | """ 84 | cache_filepath = self._get_cache_filepath( 85 | dataset_id=dataset_id, dataset_version=dataset_version 86 | ) 87 | assert cache_filepath.exists() 88 | 89 | return dataref.LMDBDataRef(cache_filepath=cache_filepath) 90 | 91 | def cacheable(self, dataset_id: str, dataset_version: str) -> Callable: 92 | """ 93 | A decorator that calls submit and fetch and is responsible 94 | for coordinating amongst instantiations of Storage in different 95 | processes. 96 | """ 97 | 98 | def wrap(f: Callable) -> Callable: 99 | def create_dataref(*args: Any, **kwargs: Any) -> dataref.LMDBDataRef: 100 | with self._lock_this_dataset_version(dataset_id, dataset_version): 101 | cache_filepath = self._get_cache_filepath( 102 | dataset_id=dataset_id, dataset_version=dataset_version 103 | ) 104 | 105 | if not cache_filepath.exists(): 106 | self.submit( 107 | data=f(*args, **kwargs), 108 | dataset_id=dataset_id, 109 | dataset_version=dataset_version, 110 | ) 111 | 112 | return self.fetch(dataset_id=dataset_id, dataset_version=dataset_version) 113 | 114 | return create_dataref 115 | 116 | return wrap 117 | 118 | @contextlib.contextmanager 119 | def _lock_this_dataset_version( 120 | self, dataset_id: str, dataset_version: str 121 | ) -> Generator[None, None, None]: 122 | lock_filepath = ( 123 | self._configurations.storage_dir_path.joinpath(dataset_id) 124 | .joinpath(dataset_version) 125 | .joinpath("yogadl.lock") 126 | ) 127 | lock_filepath.parent.mkdir(parents=True, exist_ok=True) 128 | 129 | # Blocks until access is granted. 130 | access_lock = filelock.FileLock(str(lock_filepath)) 131 | with access_lock.acquire(): 132 | yield 133 | 134 | def _get_cache_filepath(self, dataset_id: str, dataset_version: str) -> pathlib.Path: 135 | assert dataset_id, "`dataset_id` must be a non-empty string." 136 | assert dataset_version, "`dataset_version` must be a non-empty string." 137 | return ( 138 | self._configurations.storage_dir_path.joinpath(dataset_id) 139 | .joinpath(dataset_version) 140 | .joinpath("cache.mdb") 141 | ) 142 | -------------------------------------------------------------------------------- /yogadl/storage/_s3_storage.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import datetime 16 | import pathlib 17 | from typing import Optional 18 | 19 | import boto3 20 | import botocore.client as boto_client 21 | import tensorflow as tf 22 | 23 | import yogadl.constants as constants 24 | from yogadl import storage 25 | 26 | 27 | class S3Configurations(storage.BaseCloudConfigurations): 28 | def __init__( 29 | self, 30 | bucket: str, 31 | bucket_directory_path: str, 32 | url: str, 33 | access_key: Optional[str] = None, 34 | secret_key: Optional[str] = None, 35 | endpoint_url: Optional[str] = None, 36 | local_cache_dir: str = "/tmp/", 37 | skip_verify: bool = False, 38 | coordinator_cert_file: Optional[str] = None, 39 | coordinator_cert_name: Optional[str] = None, 40 | ) -> None: 41 | super().__init__( 42 | bucket=bucket, 43 | bucket_directory_path=bucket_directory_path, 44 | url=url, 45 | local_cache_dir=local_cache_dir, 46 | skip_verify=skip_verify, 47 | coordinator_cert_file=coordinator_cert_file, 48 | coordinator_cert_name=coordinator_cert_name, 49 | ) 50 | self.access_key = access_key 51 | self.secret_key = secret_key 52 | self.endpoint_url = endpoint_url 53 | 54 | 55 | class S3Storage(storage.BaseCloudStorage): 56 | """ 57 | Stores dataset cache in AWS S3. 58 | 59 | S3Storage creates a local cache from a dataset and then uploads 60 | it to the specified S3 bucket. When fetching from S3, the creation 61 | time of the local cache (recorded in metadata), is compared to the 62 | creation time of the S3 cache, if they are not equivalent, the 63 | local cache is overwritten. 64 | 65 | The S3 cache, and the local cache are potentially shared across a 66 | number of concurrent processes. `cacheable()` provides synchronization 67 | guarantees. Users should not call `submit()` and `fetch()` if they 68 | anticipate concurrent data accesses. 69 | """ 70 | 71 | def __init__( 72 | self, 73 | configurations: S3Configurations, 74 | tensorflow_config: Optional[tf.compat.v1.ConfigProto] = None, 75 | ) -> None: 76 | super().__init__(configurations=configurations, tensorflow_config=tensorflow_config) 77 | 78 | assert isinstance(self._configurations, S3Configurations) 79 | self._client = boto3.client( 80 | "s3", 81 | aws_access_key_id=self._configurations.access_key, 82 | aws_secret_access_key=self._configurations.secret_key, 83 | endpoint_url=self._configurations.endpoint_url, 84 | ) 85 | 86 | self._check_configurations() 87 | 88 | def _check_configurations(self) -> None: 89 | assert self._configurations.local_cache_dir.is_dir() 90 | assert self._configurations.cache_format in self._supported_cache_formats 91 | 92 | try: 93 | self._client.head_bucket(Bucket=self._configurations.bucket) 94 | except boto_client.ClientError as error: 95 | raise AssertionError( 96 | f"Unable to access bucket {self._configurations.bucket}. " 97 | f"Failed with exception: {error}." 
98 | ) 99 | 100 | @property 101 | def _storage_type(self) -> str: 102 | return constants.AWS_STORAGE 103 | 104 | def _is_cloud_cache_present(self, dataset_id: str, dataset_version: str) -> bool: 105 | 106 | s3_cache_filepath = self._get_remote_cache_filepath( 107 | dataset_id=dataset_id, 108 | dataset_version=dataset_version, 109 | ) 110 | 111 | try: 112 | self._client.head_object(Bucket=self._configurations.bucket, Key=str(s3_cache_filepath)) 113 | cloud_cache_exists = True 114 | except boto_client.ClientError: 115 | cloud_cache_exists = False 116 | 117 | return cloud_cache_exists 118 | 119 | def _download_from_cloud_storage( 120 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 121 | ) -> datetime.datetime: 122 | 123 | s3_cache_filepath = self._get_remote_cache_filepath( 124 | dataset_id=dataset_id, 125 | dataset_version=dataset_version, 126 | ) 127 | 128 | try: 129 | self._client.download_file( 130 | Bucket=self._configurations.bucket, 131 | Key=str(s3_cache_filepath), 132 | Filename=str(local_cache_filepath), 133 | ) 134 | except boto_client.ClientError as error: 135 | raise AssertionError( 136 | f"Downloading blob {self._configurations.bucket}" 137 | f"/{s3_cache_filepath}. Failed with exception {error}." 138 | ) 139 | 140 | return self._get_remote_cache_timestamp( 141 | dataset_id=dataset_id, dataset_version=dataset_version 142 | ) 143 | 144 | def _upload_to_cloud_storage( 145 | self, dataset_id: str, dataset_version: str, local_cache_filepath: pathlib.Path 146 | ) -> datetime.datetime: 147 | 148 | s3_cache_filepath = self._get_remote_cache_filepath( 149 | dataset_id=dataset_id, 150 | dataset_version=dataset_version, 151 | ) 152 | 153 | try: 154 | self._client.upload_file( 155 | Filename=str(local_cache_filepath), 156 | Bucket=self._configurations.bucket, 157 | Key=str(s3_cache_filepath), 158 | ) 159 | except boto3.exceptions.S3UploadFailedError as error: 160 | raise AssertionError(f"Failed to upload file to S3 with exception: {error}.") 161 | 162 | return self._get_remote_cache_timestamp( 163 | dataset_id=dataset_id, 164 | dataset_version=dataset_version, 165 | ) 166 | 167 | def _get_remote_cache_timestamp( 168 | self, dataset_id: str, dataset_version: str 169 | ) -> datetime.datetime: 170 | 171 | s3_cache_filepath = self._get_remote_cache_filepath( 172 | dataset_id=dataset_id, 173 | dataset_version=dataset_version, 174 | ) 175 | 176 | try: 177 | s3_object_info = self._client.head_object( 178 | Bucket=self._configurations.bucket, Key=str(s3_cache_filepath) 179 | ) 180 | except boto_client.ClientError as error: 181 | raise AssertionError( 182 | f"Unable to look up metadata for {self._configurations.bucket}/" 183 | f"{s3_cache_filepath}. Failed with exception: {error}." 184 | ) 185 | 186 | timestamp = s3_object_info.get("LastModified") 187 | assert isinstance(timestamp, datetime.datetime) 188 | 189 | return timestamp 190 | -------------------------------------------------------------------------------- /yogadl/tensorflow.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 Determined AI. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | import pathlib 16 | from typing import Any, Generator, Optional, Tuple 17 | 18 | import tensorflow as tf 19 | 20 | import yogadl 21 | 22 | 23 | def read_tf_dataset_eager_mode(dataset: tf.data.Dataset) -> Generator[Tuple[Any, bool], None, None]: 24 | # TODO: If repeat() has been applied we will hit an infinite 25 | # loop here. Probably best approach is to include log message 26 | # specifying how many data items we have read and this should 27 | # alert the user if we are stuck in an infinite loop. 28 | for next_element in dataset.as_numpy_iterator(): 29 | yield next_element 30 | 31 | 32 | def read_tf_dataset_graph_mode( 33 | dataset: tf.data.Dataset, tf_config: Optional[tf.compat.v1.ConfigProto] 34 | ) -> Generator[Tuple[Any, bool], None, None]: 35 | get_next_element = tf.compat.v1.data.make_one_shot_iterator(dataset).get_next() 36 | with tf.compat.v1.Session(config=tf_config) as sess: 37 | while True: 38 | try: 39 | # TODO: If repeat() has been applied we will hit an infinite 40 | # loop here. Probably best approach is to include log message 41 | # specifying how many data items we have read and this should 42 | # alert the user if we are stuck in an infinite loop. 43 | yield sess.run(get_next_element) 44 | except tf.errors.OutOfRangeError: 45 | break 46 | 47 | 48 | def read_tf_dataset( 49 | dataset: tf.data.Dataset, tf_config: Optional[tf.compat.v1.ConfigProto] 50 | ) -> Generator[Tuple[Any, bool], None, None]: 51 | if tf.executing_eagerly(): 52 | return read_tf_dataset_eager_mode(dataset) 53 | else: 54 | return read_tf_dataset_graph_mode(dataset, tf_config=tf_config) 55 | 56 | 57 | def serialize_tf_dataset_to_lmdb( 58 | dataset: tf.data.Dataset, 59 | checkpoint_path: pathlib.Path, 60 | tf_config: Optional[tf.compat.v1.ConfigProto], 61 | write_frequency: int = 5000, 62 | ) -> int: 63 | assert isinstance(dataset, tf.data.Dataset) 64 | return yogadl.serialize_generator_to_lmdb( 65 | dataset_generator=read_tf_dataset(dataset=dataset, tf_config=tf_config), 66 | data_shapes=tf.compat.v1.data.get_output_shapes(dataset), 67 | data_types=tf.compat.v1.data.get_output_types(dataset), 68 | lmdb_path=checkpoint_path, 69 | write_frequency=write_frequency, 70 | ) 71 | 72 | 73 | def make_tf_dataset(stream: yogadl.Stream) -> tf.data.Dataset: 74 | """ 75 | Produce a tf.data.Dataset from a yogadl.Stream. 76 | """ 77 | return tf.data.Dataset.from_generator( 78 | stream.iterator_fn, output_types=stream.output_types, output_shapes=stream.output_shapes 79 | ) 80 | --------------------------------------------------------------------------------
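To tie the pieces together, a sketch that caches a dataset through the S3 backend and streams it back through make_tf_dataset() above; the result is an ordinary tf.data.Dataset (built with from_generator), so the usual transformations still apply. Bucket, prefix, coordinator URL, and dataset identifiers are placeholders; with access_key and secret_key left unset, boto3 falls back to its standard credential lookup.

import tensorflow as tf

from yogadl import tensorflow as yogadl_tf
from yogadl.storage import S3Configurations, S3Storage

# Placeholder configuration values.
config = S3Configurations(
    bucket="my-dataset-cache",
    bucket_directory_path="yogadl",
    url="ws://coordinator.example.com:10245",  # rw_coordinator server address
)
storage = S3Storage(config)


@storage.cacheable(dataset_id="imagenet", dataset_version="1")
def make_dataset() -> tf.data.Dataset:
    return tf.data.Dataset.range(1000)


dataref = make_dataset()                       # LMDBDataRef over the local cache
dataset = yogadl_tf.make_tf_dataset(dataref.stream())
dataset = dataset.batch(32)                    # normal tf.data transformations still apply

for batch in dataset.take(1):
    print(batch)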