├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dmvr
    ├── __init__.py
    ├── builders.py
    ├── builders_test.py
    ├── modalities.py
    ├── modalities_test.py
    ├── processors.py
    ├── processors_test.py
    ├── sources.py
    ├── sources_test.py
    ├── testdata
    │   ├── sample.jpeg
    │   └── tokenizers
    │   │   ├── bert_word_vocab.txt
    │   │   ├── spiece.model.1000.model
    │   │   └── word_vocab.txt
    ├── tokenizers.py
    ├── tokenizers_test.py
    ├── utils.py
    ├── utils_test.py
    ├── video_dataset.py
    └── video_dataset_test.py
├── examples
    ├── README.md
    ├── generate_from_file.py
    ├── generate_hmdb_csv.py
    ├── hmdb.py
    └── linear_mmv_hmdb.py
├── requirements-test.txt
├── requirements.txt
├── setup.py
└── test.sh

/CONTRIBUTING.md:
--------------------------------------------------------------------------------
 1 | # How to Contribute
 2 | 
 3 | We'd love to accept your patches and contributions to this project. There are
 4 | just a few small guidelines you need to follow.
 5 | 
 6 | ## Contributor License Agreement
 7 | 
 8 | Contributions to this project must be accompanied by a Contributor License
 9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 | 
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 | 
19 | ## Code Reviews
20 | 
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 | 
26 | ## Community Guidelines
27 | 
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
  1 | 
  2 |                                  Apache License
  3 |                            Version 2.0, January 2004
  4 |                         http://www.apache.org/licenses/
  5 | 
  6 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
  7 | 
  8 |    1. Definitions.
  9 | 
 10 |       "License" shall mean the terms and conditions for use, reproduction,
 11 |       and distribution as defined by Sections 1 through 9 of this document.
 12 | 
 13 |       "Licensor" shall mean the copyright owner or entity authorized by
 14 |       the copyright owner that is granting the License.
 15 | 
 16 |       "Legal Entity" shall mean the union of the acting entity and all
 17 |       other entities that control, are controlled by, or are under common
 18 |       control with that entity. For the purposes of this definition,
 19 |       "control" means (i) the power, direct or indirect, to cause the
 20 |       direction or management of such entity, whether by contract or
 21 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
 22 |       outstanding shares, or (iii) beneficial ownership of such entity.
 23 | 
 24 |       "You" (or "Your") shall mean an individual or Legal Entity
 25 |       exercising permissions granted by this License.
 26 | 
 27 |       "Source" form shall mean the preferred form for making modifications,
 28 |       including but not limited to software source code, documentation
 29 |       source, and configuration files.
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. 
If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 
203 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # DMVR: DeepMind Video Readers
  2 | 
  3 | DMVR is a library providing a framework for easily reading raw data and
  4 | producing `tf.data.Dataset` objects ready to be consumed by models.
  5 | 
  6 | ## Design principles
  7 | 
  8 | ### Data processing graph
  9 | 
 10 | The main idea of the framework is to build a customizable and reusable data
 11 | processing graph that, when applied to raw data files, will produce final dataset
 12 | objects. Building blocks called Builders are used to interact with the graph by
 13 | adding, removing or replacing data processing blocks.
 14 | 
 15 | Dataset providers can write a Factory with a default data processing graph for
 16 | each dataset. Dataset consumers can customize the graph to their needs either by
 17 | creating a child Factory or just appending a given instance. Factory objects
 18 | expose instances of Builders allowing control of the multiple phases of the data
 19 | processing graph. The Factory is then able to generate `tf.data.Dataset`
 20 | objects.
 21 | 
 22 | ### Phases
 23 | 
 24 | The data processing graph is split into multiple phases. This abstraction is
 25 | purely semantic, which makes code easier to reuse. The phases are:
 26 | 
 27 | - Parse
 28 | - Sample
 29 | - Decode
 30 | - Preprocess
 31 | - Postprocess
 32 | 
 33 | ### Modalities
 34 | 
 35 | In order to easily add different modalities to the dataset from the raw data,
 36 | sub graphs for some modalities with default processing (e.g. sample, decode and
 37 | crop for images) are provided. These sub graphs can be added by simply calling
 38 | the corresponding methods on the Builders.
 39 | 
 40 | ## Usage
 41 | 
 42 | ### Dataset providers
 43 | 
 44 | Dataset providers should implement a factory populating the default graph.
 45 | 
 46 | Example:
 47 | 
 48 | - Data is stored in TFRecords as `tf.train.SequenceExample` objects.
 49 | 
 50 | ```python
 51 | from typing import List
 52 | 
 53 | from dmvr import modalities
 54 | from dmvr import video_dataset
 55 | 
 56 | class Kinetics700Factory(video_dataset.BaseVideoDatasetFactory):
 57 | 
 58 |   _NUM_CLASSES = 700
 59 | 
 60 |   def __init__(self, subset: str):
 61 |     self._is_training = subset == 'train'
 62 |     shards: List[str] = path_to_the_data(subset)
 63 |     super().__init__(shards)
 64 | 
 65 |   def _build(self,
 66 |              # Video related parameters.
 67 |              num_frames: int = 32,
 68 |              stride: int = 1,
 69 |              num_test_clips: int = 1,
 70 |              min_resize: int = 224,
 71 |              crop_size: int = 200,
 72 |              zero_centering_image: bool = False,
 73 |              # Label related parameters.
 74 |              one_hot_label: bool = True,
 75 |              add_label_name: bool = False):
 76 |     """Build default data processing graph."""
 77 |     modalities.add_image(parser_builder=self.parser_builder,
 78 |                          sampler_builder=self.sampler_builder,
 79 |                          decoder_builder=self.decoder_builder,
 80 |                          preprocessor_builder=self.preprocessor_builder,
 81 |                          postprocessor_builder=self.postprocessor_builder,
 82 |                          is_training=self._is_training,
 83 |                          num_frames=num_frames,
 84 |                          stride=stride,
 85 |                          min_resize=min_resize,
 86 |                          crop_size=crop_size,
 87 |                          zero_centering_image=zero_centering_image)
 88 | 
 89 |     modalities.add_label(parser_builder=self.parser_builder,
 90 |                          decoder_builder=self.decoder_builder,
 91 |                          preprocessor_builder=self.preprocessor_builder,
 92 |                          one_hot_label=one_hot_label,
 93 |                          num_classes=self._NUM_CLASSES,
 94 |                          add_label_name=add_label_name)
 95 | ```
 96 | 
 97 | ### Dataset consumers
 98 | 
 99 | Dataset consumers can create `tf.data.Dataset` objects from a factory instance.
100 | 
101 | Example:
102 | 
103 | ```python
104 | factory = Kinetics700Factory('train')
105 | factory.configure(num_frames=16)
106 | ds = factory.make_dataset(batch_size=8)
107 | ```
108 | 
109 | The user can also customize the data processing graph by adding more functions:
110 | 
111 | ```python
112 | from dmvr import builders
113 | from dmvr import processors
114 | 
115 | factory = Kinetics700Factory('train')
116 | factory.configure(num_frames=16)
117 | 
118 | factory.preprocessor_builder.add_fn(processors.scale_jitter_augm,
119 |                                     feature_name=builders.IMAGE_FEATURE_NAME)
120 | factory.preprocessor_builder.add_fn(processors.color_default_augm,
121 |                                     feature_name=builders.IMAGE_FEATURE_NAME)
122 | 
123 | ds = factory.make_dataset(batch_size=8)
124 | ```
125 | 
126 | ## Installation
127 | 
128 | DMVR can be installed with pip directly from GitHub with the following command:
129 | 
130 |     pip install git+https://github.com/deepmind/dmvr.git
131 | 
132 | Python 3.9+ is required in order for all features to be available.
--------------------------------------------------------------------------------
/dmvr/__init__.py:
--------------------------------------------------------------------------------
 1 | # Copyright 2021 DeepMind Technologies Limited.
 2 | #
 3 | # Licensed under the Apache License, Version 2.0 (the "License");
 4 | # you may not use this file except in compliance with the License.
 5 | # You may obtain a copy of the License at
 6 | #
 7 | #     https://www.apache.org/licenses/LICENSE-2.0
 8 | #
 9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
--------------------------------------------------------------------------------
/dmvr/builders_test.py:
--------------------------------------------------------------------------------
  1 | # Copyright 2021 DeepMind Technologies Limited.
  2 | #
  3 | # Licensed under the Apache License, Version 2.0 (the "License");
  4 | # you may not use this file except in compliance with the License.
  5 | # You may obtain a copy of the License at
  6 | #
  7 | #     https://www.apache.org/licenses/LICENSE-2.0
  8 | #
  9 | # Unless required by applicable law or agreed to in writing, software
 10 | # distributed under the License is distributed on an "AS IS" BASIS,
 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for builders.""" 16 | 17 | from dmvr import builders 18 | from parameterized import parameterized 19 | import tensorflow as tf 20 | 21 | 22 | class SequenceExampleParserBuilderTest(tf.test.TestCase): 23 | 24 | def setUp(self): 25 | super().setUp() 26 | 27 | # Prepare SequenceExample. 28 | seq_example = tf.train.SequenceExample() 29 | seq_example.context.feature.get_or_create( 30 | 'my_context_feature').int64_list.value[:] = [0, 1] 31 | seq_example.feature_lists.feature_list.get_or_create( 32 | 'my_seq_feature').feature.add().int64_list.value[:] = [2, 3] 33 | seq_example.feature_lists.feature_list.get( 34 | 'my_seq_feature').feature.add().int64_list.value[:] = [4, 5] 35 | seq_example.feature_lists.feature_list.get_or_create( 36 | 'my_var_len_seq_feature').feature.add().int64_list.value[:] = [6] 37 | seq_example.feature_lists.feature_list.get( 38 | 'my_var_len_seq_feature').feature.add().int64_list.value[:] = [7, 8] 39 | 40 | # Put SequenceExample in expected format. 41 | self._raw_seq_example = tf.constant(seq_example.SerializeToString()) 42 | 43 | def test_parse(self): 44 | parse_fn = ( 45 | builders.SequenceExampleParserBuilder() 46 | .parse_feature('my_context_feature', 47 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 48 | 'context_name', True) 49 | .parse_feature('my_seq_feature', 50 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64), 51 | 'seq_name') 52 | .parse_feature('my_var_len_seq_feature', 53 | tf.io.VarLenFeature(dtype=tf.int64), 54 | 'var_len_seq_name') 55 | .build()) 56 | features_dict = parse_fn(self._raw_seq_example) 57 | 58 | self.assertSetEqual(set(['context_name', 'seq_name', 'var_len_seq_name']), 59 | set(features_dict.keys())) 60 | self.assertAllEqual(features_dict['context_name'], [0, 1]) 61 | self.assertAllEqual(features_dict['seq_name'], [[2, 3], [4, 5]]) 62 | self.assertAllEqual(features_dict['var_len_seq_name'].values, [6, 7, 8]) 63 | self.assertAllEqual(features_dict['var_len_seq_name'].indices, 64 | [[0, 0], [1, 0], [1, 1]]) 65 | self.assertAllEqual(features_dict['var_len_seq_name'].dense_shape, [2, 2]) 66 | 67 | def test_fake_data(self): 68 | parser = builders.SequenceExampleParserBuilder() 69 | parser.parse_feature('my_context_feature', 70 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 71 | 'context_name', True) 72 | parser.parse_feature('my_seq_feature', 73 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64), 74 | 'seq_name') 75 | parser.parse_feature('my_var_len_seq_feature', 76 | tf.io.VarLenFeature(dtype=tf.int64), 77 | 'var_len_seq_name') 78 | fake_data = parser.get_fake_data(default_values={ 79 | 'context_name': (0, 1), 'var_len_seq_name': ((1, 2), (3, 4, 5))}) 80 | self.assertSetEqual(set(['context_name', 'seq_name', 'var_len_seq_name']), 81 | set(fake_data.keys())) 82 | self.assertAllEqual(fake_data['context_name'], [0, 1]) 83 | self.assertAllEqual(fake_data['seq_name'], [[0, 0]]) 84 | self.assertAllEqual(fake_data['var_len_seq_name'].values, [1, 2, 3, 4, 5]) 85 | self.assertAllEqual(fake_data['var_len_seq_name'].dense_shape, [2, 3]) 86 | 87 | def test_no_output_name(self): 88 | parse_fn = ( 89 | builders.SequenceExampleParserBuilder() 90 | .parse_feature('my_context_feature', 91 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 92 | is_context=True) 93 | .build()) 94 | features_dict = parse_fn(self._raw_seq_example) 95 | 96 | self.assertSetEqual(set(['my_context_feature']), set(features_dict.keys())) 97 | 98 | def 
test_same_output_name(self): 99 | parser_builder = builders.SequenceExampleParserBuilder() 100 | parser_builder.parse_feature('my_context_feature', 101 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 102 | 'same_name', True) 103 | parser_builder.parse_feature('my_context_feature', 104 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 105 | 'other_name', True) 106 | 107 | with self.assertRaises(ValueError) as _: 108 | parser_builder.parse_feature( 109 | 'my_seq_feature', tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64), 110 | 'same_name') 111 | 112 | def test_different_types_for_same_feature(self): 113 | parser_builder = builders.SequenceExampleParserBuilder() 114 | parser_builder.parse_feature('my_context_feature', 115 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 116 | 'context_name', True) 117 | 118 | with self.assertRaises(ValueError) as _: 119 | parser_builder.parse_feature('my_context_feature', 120 | tf.io.FixedLenFeature((3,), dtype=tf.int64), 121 | 'context_name_2', True) 122 | 123 | with self.assertRaises(ValueError) as _: 124 | parser_builder.parse_feature('my_context_feature', 125 | tf.io.FixedLenFeature((2,), dtype=tf.string), 126 | 'context_name_3', True) 127 | 128 | 129 | class ExampleParserBuilderTest(tf.test.TestCase): 130 | 131 | def setUp(self): 132 | super().setUp() 133 | 134 | # Prepare Example. 135 | tf_example = tf.train.Example() 136 | tf_example.features.feature.get_or_create( 137 | 'my_fixed_len_feature').int64_list.value[:] = [0, 1] 138 | tf_example.features.feature.get_or_create( 139 | 'my_var_len_feature').int64_list.value[:] = [2, 3, 4] 140 | 141 | # Put Example in expected format. 142 | self._raw_tf_example = tf.constant(tf_example.SerializeToString()) 143 | 144 | def test_parse(self): 145 | parse_fn = ( 146 | builders.ExampleParserBuilder() 147 | .parse_feature('my_fixed_len_feature', 148 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 149 | 'fixed_name') 150 | .parse_feature('my_var_len_feature', 151 | tf.io.VarLenFeature(dtype=tf.int64), 'var_name') 152 | .build()) 153 | features_dict = parse_fn(self._raw_tf_example) 154 | 155 | self.assertSetEqual(set(['fixed_name', 'var_name']), 156 | set(features_dict.keys())) 157 | self.assertAllEqual(features_dict['fixed_name'], [0, 1]) 158 | self.assertAllEqual(features_dict['var_name'].values, [2, 3, 4]) 159 | self.assertAllEqual(features_dict['var_name'].indices, [[0], [1], [2]]) 160 | self.assertAllEqual(features_dict['var_name'].dense_shape, [3]) 161 | 162 | def test_fake_data(self): 163 | fake_data = ( 164 | builders.ExampleParserBuilder() 165 | .parse_feature('my_fixed_len_feature', 166 | tf.io.FixedLenFeature((2,), dtype=tf.string), 167 | 'fixed_name') 168 | .parse_feature('my_var_len_feature', 169 | tf.io.VarLenFeature(dtype=tf.int64), 'var_name') 170 | .get_fake_data(default_values={'fixed_name': (b'42', b'25')})) 171 | self.assertSetEqual(set(['fixed_name', 'var_name']), 172 | set(fake_data.keys())) 173 | self.assertAllEqual(fake_data['fixed_name'], [b'42', b'25']) 174 | self.assertAllEqual(fake_data['var_name'].values, [0]) 175 | self.assertAllEqual(fake_data['var_name'].indices, [[0]]) 176 | self.assertAllEqual(fake_data['var_name'].dense_shape, [1]) 177 | 178 | def test_no_output_name(self): 179 | parse_fn = ( 180 | builders.ExampleParserBuilder() 181 | .parse_feature('my_fixed_len_feature', 182 | tf.io.FixedLenFeature((2,), dtype=tf.int64)) 183 | .build()) 184 | features_dict = parse_fn(self._raw_tf_example) 185 | 186 | self.assertSetEqual(set(['my_fixed_len_feature']), 187 | set(features_dict.keys())) 188 
| 189 | def test_same_output_name(self): 190 | parser_builder = builders.ExampleParserBuilder() 191 | parser_builder.parse_feature('my_fixed_len_feature', 192 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 193 | 'same_name') 194 | parser_builder.parse_feature('my_fixed_len_feature', 195 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 196 | 'other_name') 197 | 198 | with self.assertRaises(ValueError) as _: 199 | parser_builder.parse_feature( 200 | 'my_var_len_feature', 201 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64), 'same_name') 202 | 203 | def test_different_types_for_same_feature(self): 204 | parser_builder = builders.SequenceExampleParserBuilder() 205 | parser_builder.parse_feature('my_fixed_len_feature', 206 | tf.io.FixedLenFeature((2,), dtype=tf.int64), 207 | 'fixed_name') 208 | 209 | with self.assertRaises(ValueError) as _: 210 | parser_builder.parse_feature('my_fixed_len_feature', 211 | tf.io.FixedLenFeature((3,), dtype=tf.int64), 212 | 'fixed_name_2') 213 | 214 | with self.assertRaises(ValueError) as _: 215 | parser_builder.parse_feature('my_fixed_len_feature', 216 | tf.io.FixedLenFeature((2,), dtype=tf.string), 217 | 'fixed_name_3') 218 | 219 | 220 | def _add_one(x): 221 | return tf.math.add(x, 1) 222 | 223 | 224 | def _subtract_one(x): 225 | return tf.math.subtract(x, 1) 226 | 227 | 228 | def _upper_text(x): 229 | return tf.strings.upper(x) 230 | 231 | 232 | def _add_text_len(features_dict): 233 | features_dict['feature_3'] = tf.strings.length( 234 | input=features_dict['feature_2']) 235 | return features_dict 236 | 237 | 238 | def _set_state(x, state): 239 | state['value'] = x 240 | return x 241 | 242 | 243 | def _use_state(features_dict, state): 244 | features_dict['feature_4'] = state['value'] 245 | return features_dict 246 | 247 | 248 | class BuilderTest(tf.test.TestCase): 249 | 250 | def setUp(self): 251 | super().setUp() 252 | 253 | # Prepare features dictionary. 
254 | self._input_features_dict = { 255 | 'feature_1': tf.constant(0), 256 | 'feature_2': tf.constant('text') 257 | } 258 | 259 | def test_basic(self): 260 | process_fn = ( 261 | builders._Builder() 262 | .add_fn(_add_one, 'feature_1') 263 | .add_fn(_upper_text, 'feature_2') 264 | .add_fn(_add_text_len) 265 | .add_fn(_add_one, 'feature_1') 266 | .build()) 267 | output_features_dict = process_fn(self._input_features_dict) 268 | 269 | self.assertSetEqual( 270 | set(['feature_1', 'feature_2', 'feature_3']), 271 | set(output_features_dict.keys())) 272 | self.assertEqual(output_features_dict['feature_1'], 2) 273 | self.assertEqual(output_features_dict['feature_2'], b'TEXT') 274 | self.assertEqual(output_features_dict['feature_3'], 4) 275 | 276 | def test_replace(self): 277 | process_fn = ( 278 | builders._Builder() 279 | .add_fn(_add_one, 'feature_1', 'add_one') 280 | .add_fn(_upper_text, 'feature_2') 281 | .replace_fn('add_one', _subtract_one) 282 | .build()) 283 | output_features_dict = process_fn(self._input_features_dict) 284 | 285 | self.assertSetEqual(set(['feature_1', 'feature_2']), 286 | set(output_features_dict.keys())) 287 | self.assertEqual(output_features_dict['feature_1'], -1) 288 | self.assertEqual(output_features_dict['feature_2'], b'TEXT') 289 | 290 | def test_remove(self): 291 | process_fn = ( 292 | builders._Builder() 293 | .add_fn(_add_one, 'feature_1', 'add_one') 294 | .add_fn(_upper_text, 'feature_2') 295 | .remove_fn('add_one') 296 | .build()) 297 | output_features_dict = process_fn(self._input_features_dict) 298 | 299 | self.assertSetEqual(set(['feature_1', 'feature_2']), 300 | set(output_features_dict.keys())) 301 | self.assertEqual(output_features_dict['feature_1'], 0) 302 | self.assertEqual(output_features_dict['feature_2'], b'TEXT') 303 | 304 | def test_reset(self): 305 | process_fn = ( 306 | builders._Builder() 307 | .add_fn(_add_one, 'feature_1') 308 | .add_fn(_upper_text, 'feature_2') 309 | .reset() 310 | .build()) 311 | output_features_dict = process_fn(self._input_features_dict) 312 | 313 | self.assertSetEqual(set(['feature_1', 'feature_2']), 314 | set(output_features_dict.keys())) 315 | self.assertEqual(output_features_dict['feature_1'], 0) 316 | self.assertEqual(output_features_dict['feature_2'], b'text') 317 | 318 | def test_stateful(self): 319 | process_fn = ( 320 | builders._Builder() 321 | .add_fn(_set_state, 'feature_1', stateful=True) 322 | .add_fn(_use_state, stateful=True) 323 | .build()) 324 | output_features_dict = process_fn(self._input_features_dict) 325 | 326 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_4']), 327 | set(output_features_dict.keys())) 328 | self.assertEqual(output_features_dict['feature_1'], 0) 329 | self.assertEqual(output_features_dict['feature_4'], 0) 330 | 331 | def test_same_fn_name(self): 332 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one') 333 | 334 | with self.assertRaises(ValueError) as _: 335 | builder.add_fn(_add_one, 'feature_1', 'add_one') 336 | 337 | def test_replace_wrong_fn_name(self): 338 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one') 339 | 340 | with self.assertRaises(ValueError) as _: 341 | builder.replace_fn('add_one_wrong', _add_one) 342 | 343 | def test_insert(self): 344 | def replace_string(_): 345 | return tf.constant('replaced_text') 346 | 347 | builder = builders._Builder() .add_fn(_add_text_len, fn_name='text_len') 348 | output_features_dict = builder.build()(self._input_features_dict) 349 | 350 | builder.add_fn(replace_string, 'feature_2', 
add_before_fn_name='text_len') 351 | output_features_dict_2 = builder.build()(self._input_features_dict) 352 | 353 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_3']), 354 | set(output_features_dict.keys())) 355 | self.assertEqual(output_features_dict['feature_2'], b'text') 356 | self.assertEqual(output_features_dict['feature_3'], 4) 357 | 358 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_3']), 359 | set(output_features_dict_2.keys())) 360 | self.assertEqual(output_features_dict_2['feature_2'], b'replaced_text') 361 | self.assertEqual(output_features_dict_2['feature_3'], 13) 362 | 363 | def test_wrong_add_before_fn_name(self): 364 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one') 365 | 366 | with self.assertRaises(ValueError) as _: 367 | builder.add_fn(_add_one, 'feature_1', add_before_fn_name='add_one_wrong') 368 | 369 | 370 | class FilterBuilderTest(tf.test.TestCase): 371 | 372 | def setUp(self): 373 | super().setUp() 374 | 375 | # Prepare features dictionary. 376 | self._input_features_dict = { 377 | 'feature_1': tf.constant(0), 378 | 'feature_2': tf.constant('text'), 379 | 'feature_3': tf.zeros((16, 200, 200, 3)) 380 | } 381 | 382 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,), 383 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,), 384 | (builders.Phase.PREPROCESS,), 385 | (builders.Phase.POSTPROCESS,))) 386 | def test_drop(self, phase): 387 | filter_fn = ( 388 | builders.FilterBuilder() 389 | .add_filter_fn(lambda fd: tf.equal(fd['feature_1'], 0), phase) 390 | .add_filter_fn(lambda fd: tf.equal(fd['feature_2'], 'no_text'), phase) 391 | .add_filter_fn( 392 | lambda fd: tf.equal(tf.shape(input=fd['feature_3'])[3], 3), phase) 393 | .build(phase)) 394 | keep = filter_fn(self._input_features_dict) 395 | 396 | self.assertEqual(keep, False) 397 | 398 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,), 399 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,), 400 | (builders.Phase.PREPROCESS,), 401 | (builders.Phase.POSTPROCESS,))) 402 | def test_keep(self, phase): 403 | filter_fn = ( 404 | builders.FilterBuilder() 405 | .add_filter_fn(lambda fd: tf.equal(fd['feature_1'], 0), phase) 406 | .add_filter_fn(lambda fd: tf.equal(fd['feature_2'], 'text'), phase) 407 | .add_filter_fn( 408 | lambda fd: tf.equal(tf.shape(input=fd['feature_3'])[3], 3), phase) 409 | .build(phase)) 410 | keep = filter_fn(self._input_features_dict) 411 | 412 | self.assertEqual(keep, True) 413 | 414 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,), 415 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,), 416 | (builders.Phase.PREPROCESS,), 417 | (builders.Phase.POSTPROCESS,))) 418 | def test_empty(self, phase): 419 | filter_fn = builders.FilterBuilder().build(phase) 420 | keep = filter_fn(self._input_features_dict) 421 | 422 | self.assertEqual(keep, True) 423 | 424 | 425 | if __name__ == '__main__': 426 | tf.test.main() 427 | -------------------------------------------------------------------------------- /dmvr/modalities.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils for adding modalities.""" 16 | 17 | import functools 18 | from typing import Optional 19 | from typing import Union 20 | 21 | from absl import logging 22 | from dmvr import builders 23 | from dmvr import processors 24 | from dmvr import tokenizers 25 | import tensorflow as tf 26 | 27 | 28 | # ---------------------------------------------------------------------- 29 | # -------- Methods aggregating functions for a given modality. --------- 30 | # ---------------------------------------------------------------------- 31 | 32 | 33 | def add_image( 34 | parser_builder: builders.BaseParserBuilder, 35 | sampler_builder: builders.SamplerBuilder, 36 | decoder_builder: builders.DecoderBuilder, 37 | preprocessor_builder: builders.PreprocessorBuilder, 38 | postprocessor_builder: builders.PostprocessorBuilder, 39 | input_feature_name: str = 'image/encoded', 40 | output_feature_name: str = builders.IMAGE_FEATURE_NAME, 41 | is_training: bool = True, 42 | # Video related parameters. 43 | num_frames: int = 32, 44 | stride: int = 1, 45 | num_test_clips: int = 1, 46 | min_resize: int = 224, 47 | resize_method: str = tf.image.ResizeMethod.BILINEAR, 48 | crop_size: int = 200, 49 | zero_centering_image: bool = False, 50 | sync_random_state: bool = True, 51 | is_rgb: Optional[bool] = True, 52 | is_flow: bool = False, 53 | random_flip: bool = True, 54 | normalization_mean: Union[tf.Tensor, float] = 0, 55 | normalization_std: Union[tf.Tensor, float] = 1, 56 | ) -> None: 57 | """Adds functions to process image feature to builders. 58 | 59 | This function expects the input to be either a `tf.train.SequenceExample` (for 60 | videos) and have the following structure: 61 | ``` 62 | feature_lists { 63 | feature_list { 64 | key: input_feature_name 65 | value { 66 | feature { 67 | bytes_list { 68 | value: jpeg_bytes 69 | } 70 | } 71 | } 72 | } 73 | } 74 | ``` 75 | 76 | Or a `tf.train.Example` (for image only) and have the following structure: 77 | ``` 78 | features { 79 | feature { 80 | key: input_feature_name 81 | value { 82 | bytes_list { 83 | value: "JPEG" 84 | } 85 | } 86 | } 87 | } 88 | ``` 89 | 90 | The corresponding `builders.ExampleParserBuilder` or 91 | `builders.SequenceExampleParserBuilder` has to be given as parameter. 92 | 93 | Args: 94 | parser_builder: An instance of a `builders.BaseParserBuilder`. 95 | sampler_builder: An instance of a `builders.SamplerBuilder`. 96 | decoder_builder: An instance of a `builders.DecoderBuilder`. 97 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`. 98 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`. 99 | input_feature_name: Name of the feature in the input `tf.train.Example` or 100 | `tf.train.SequenceExample`. Exposing this as an argument allows using this 101 | function for different image features within a single dataset. 102 | output_feature_name: Name of the feature in the output features dictionary. 103 | Exposing this as an argument allows using this function for different 104 | image features within a single dataset. 105 | is_training: Whether in training mode. 
If `True`, random sample, crop and 106 | left right flip is used. 107 | num_frames: Number of frames per subclip. For single images, use 1. 108 | stride: Temporal stride to sample frames. 109 | num_test_clips: Number of test clips (1 by default). If more than 1, this 110 | will sample multiple linearly spaced clips within each video at test time. 111 | If 1, then a single clip in the middle of the video is sampled. The clips 112 | are aggregated in the batch dimension. 113 | min_resize: Frames are resized so that `min(height, width)` is `min_resize`. 114 | resize_method: A resizing method. 115 | crop_size: Final size of the frame after cropping the resized frames. Both 116 | height and width are the same. 117 | zero_centering_image: If `True`, frames are normalized to values in [-1, 1]. 118 | If `False`, values in [0, 1]. 119 | sync_random_state: Whether to use stateful option to keep random operations 120 | in sync between different modalities. All modalities having this option 121 | `True` will use the same outcome in random operations such as sampling and 122 | cropping. 123 | is_rgb: If `True`, the number of channels in the JPEG is 3, if False, 1. If 124 | is_flow is `True`, `is_rgb` should be set to `None` (see below). 125 | is_flow: If `True`, the image is assumed to contain flow and will be 126 | processed as such. Note that the number of channels in the JPEG for flow 127 | is 3, but only two channels will be output corresponding to the valid 128 | horizontal and vertical displacement. 129 | random_flip: If `True`, a random horizontal flip is applied to the input 130 | image. This augmentation may not be used if the label set contains 131 | direction related classes, such as `pointing left`, `pointing right`, etc. 132 | normalization_mean: value to subtract from the input image to normalize it. 133 | normalization_std: value to divide by from the input image to normalize it. 134 | """ 135 | 136 | # Validate parameters. 137 | if is_flow and is_rgb is not None: 138 | raise ValueError('`is_rgb` should be `None` when requesting flow.') 139 | 140 | if is_flow and not zero_centering_image: 141 | raise ValueError('Flow contains displacement values that can be negative, ' 142 | 'but `zero_centering_image` was set to `False`.') 143 | 144 | if is_training and num_test_clips != 1: 145 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.', 146 | num_test_clips) 147 | 148 | # Parse frames or single image. 149 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder): 150 | parser_builder.parse_feature( 151 | feature_name=input_feature_name, 152 | feature_type=tf.io.FixedLenSequenceFeature((), dtype=tf.string), 153 | output_name=output_feature_name) 154 | elif isinstance(parser_builder, builders.ExampleParserBuilder): 155 | parser_builder.parse_feature( 156 | feature_name=input_feature_name, 157 | feature_type=tf.io.FixedLenFeature((), dtype=tf.string), 158 | output_name=output_feature_name) 159 | # Expand dimensions so single images have the same structure as videos. 160 | sampler_builder.add_fn( 161 | fn=lambda x: tf.expand_dims(x, axis=0), 162 | feature_name=output_feature_name, 163 | fn_name=f'{output_feature_name}_expand_dims') 164 | else: 165 | raise ValueError('`parser_builder` has an unexpected type.') 166 | 167 | # Temporal sampler. 168 | if is_training: 169 | # Sample random clip. 
170 | sampler_builder.add_fn( 171 | # pylint: disable=g-long-lambda 172 | fn=lambda x, s=None: processors.sample_sequence( 173 | x, num_frames, True, stride, state=s), 174 | # pylint: enable=g-long-lambda 175 | feature_name=output_feature_name, 176 | fn_name=f'{output_feature_name}_random_sample', 177 | # Use state to keep coherence between modalities if requested. 178 | stateful=sync_random_state) 179 | else: 180 | if num_test_clips > 1: 181 | # Sample linspace clips. 182 | sampler_builder.add_fn( 183 | # pylint: disable=g-long-lambda 184 | fn=lambda x: processors.sample_linspace_sequence( 185 | x, num_test_clips, num_frames, stride), 186 | # pylint: enable=g-long-lambda 187 | feature_name=output_feature_name, 188 | fn_name=f'{output_feature_name}_linspace_sample') 189 | else: 190 | # Sample middle clip. 191 | sampler_builder.add_fn( 192 | fn=lambda x: processors.sample_sequence(x, num_frames, False, stride), 193 | feature_name=output_feature_name, 194 | fn_name=f'{output_feature_name}_middle_sample') 195 | 196 | # Decode JPEG string to `tf.uint8`. 197 | # Note that for flow, 3 channels are stored in the JPEG: the first two 198 | # corresponds to horizontal and vertical displacement, respectively. 199 | # The last channel contains zeros and is dropped later in the preprocessing. 200 | # Hence, the output number of channels for flow is 2. 201 | num_raw_channels = 3 if (is_rgb or is_flow) else 1 202 | decoder_builder.add_fn( 203 | fn=lambda x: processors.decode_jpeg(x, channels=num_raw_channels), 204 | feature_name=output_feature_name, 205 | fn_name=f'{output_feature_name}_decode_jpeg') 206 | 207 | if is_flow: 208 | # Cast the flow to `tf.float32`, normalizing between [-1.0, 1.0]. 209 | preprocessor_builder.add_fn( 210 | fn=lambda x: processors.normalize_image(x, zero_centering_image=True), 211 | feature_name=output_feature_name, 212 | fn_name=f'{output_feature_name}_normalize') 213 | 214 | # Resize images (resize happens only if necessary to save compute). 215 | preprocessor_builder.add_fn( 216 | # pylint: disable=g-long-lambda 217 | fn=lambda x: processors.resize_smallest( 218 | x, min_resize, is_flow=is_flow, method=resize_method), 219 | # pylint: enable=g-long-lambda 220 | feature_name=output_feature_name, 221 | fn_name=f'{output_feature_name}_resize_smallest') 222 | 223 | if is_training: 224 | # Standard image data augmentation: random crop and random flip. 225 | preprocessor_builder.add_fn( 226 | # pylint: disable=g-long-lambda 227 | fn=lambda x, s=None: processors.crop_image( 228 | x, crop_size, crop_size, True, state=s), 229 | # pylint: enable=g-long-lambda 230 | feature_name=output_feature_name, 231 | fn_name=f'{output_feature_name}_random_crop', 232 | # Use state to keep coherence between modalities if requested. 233 | stateful=sync_random_state) 234 | if random_flip: 235 | preprocessor_builder.add_fn( 236 | # pylint: disable=g-long-lambda 237 | fn=lambda x, s=None: processors.random_flip_left_right( 238 | x, state=s, is_flow=is_flow), 239 | # pylint: enable=g-long-lambda 240 | feature_name=output_feature_name, 241 | fn_name=f'{output_feature_name}_random_flip', 242 | # Use state to keep coherence between modalities if requested. 243 | stateful=sync_random_state) 244 | else: 245 | # Central crop of the frames. 
246 | preprocessor_builder.add_fn( 247 | fn=lambda x: processors.crop_image(x, crop_size, crop_size, False), 248 | feature_name=output_feature_name, 249 | fn_name=f'{output_feature_name}_central_crop') 250 | 251 | if is_flow: 252 | # Keep only two channels for the flow: horizontal and vertical displacement. 253 | preprocessor_builder.add_fn( 254 | fn=lambda x: x[:, :, :, :2], 255 | feature_name=output_feature_name, 256 | fn_name=f'{output_feature_name}_extract_flow_channels') 257 | 258 | # Clip the flow to stay between [-1.0 and 1.0] 259 | preprocessor_builder.add_fn( 260 | fn=lambda x: tf.clip_by_value(x, -1.0, 1.0), 261 | feature_name=output_feature_name, 262 | fn_name=f'{output_feature_name}_clip_flow') 263 | else: 264 | # Cast the frames to `tf.float32`, normalizing according to 265 | # `zero_centering_image`. 266 | preprocessor_builder.add_fn( 267 | fn=lambda x: processors.normalize_image(x, zero_centering_image), 268 | feature_name=output_feature_name, 269 | fn_name=f'{output_feature_name}_normalize') 270 | 271 | preprocessor_builder.add_fn( 272 | fn=lambda x: x - normalization_mean, 273 | feature_name=output_feature_name, 274 | fn_name=f'{output_feature_name}_subtract_given_mean') 275 | 276 | preprocessor_builder.add_fn( 277 | fn=lambda x: x / normalization_std, 278 | feature_name=output_feature_name, 279 | fn_name=f'{output_feature_name}_divide_by_given_std') 280 | 281 | if num_test_clips > 1 and not is_training: 282 | # In this case, multiple clips are merged together in batch dimension which 283 | # will be `B * num_test_clips`. 284 | postprocessor_builder.add_fn( 285 | fn=lambda x: tf.reshape( # pylint: disable=g-long-lambda 286 | x, (-1, num_frames, x.shape[2], x.shape[3], x.shape[4])), 287 | feature_name=output_feature_name, 288 | fn_name=f'{output_feature_name}_reshape') 289 | 290 | 291 | def add_label( 292 | parser_builder: builders.BaseParserBuilder, 293 | decoder_builder: builders.DecoderBuilder, 294 | preprocessor_builder: builders.PreprocessorBuilder, 295 | input_label_index_feature_name: str = 'clip/label/index', 296 | output_label_index_feature_name: str = builders.LABEL_INDEX_FEATURE_NAME, 297 | input_label_name_feature_name: Optional[str] = 'clip/label/text', 298 | output_label_name_feature_name: Optional[str] = builders 299 | .LABEL_NAME_FEATURE_NAME, 300 | # Label related parameters. 301 | is_multi_label: bool = False, 302 | one_hot_label: bool = True, 303 | num_classes: Optional[int] = None, 304 | add_label_name: bool = False): 305 | """Adds functions to process label feature to builders. 306 | 307 | This function expects the input to be either a `tf.train.SequenceExample` 308 | (with the features in the context) or a `tf.train.Example`. The expected 309 | structure is (or equivalent for `tf.train.Example`): 310 | ``` 311 | context { 312 | feature { 313 | key: input_label_index_feature_name 314 | value { 315 | int64_list { 316 | value: 42 317 | ... 318 | } 319 | } 320 | } 321 | feature { 322 | key: input_label_name_feature_name 323 | value { 324 | bytes_list { 325 | value: "label_42" 326 | ... 327 | } 328 | } 329 | } 330 | } 331 | ``` 332 | 333 | The corresponding `builders.ExampleParserBuilder` or 334 | `builders.SequenceExampleParserBuilder` has to be given as parameter. 335 | 336 | Args: 337 | parser_builder: An instance of a `builders.BaseParserBuilder`. 338 | decoder_builder: An instance of a `builders.DecoderBuilder`. 339 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`. 
340 | input_label_index_feature_name: Name of the label index feature in the input 341 | `tf.train.Example` or `tf.train.SequenceExample`. Exposing this as an 342 | argument allows using this function for different label features within a 343 | single dataset. 344 | output_label_index_feature_name: Name of the label index feature in the 345 | output features dictionary. Exposing this as an argument allows using this 346 | function for different label features within a single dataset. 347 | input_label_name_feature_name: Name of the label name feature in the input 348 | `tf.train.Example` or `tf.train.SequenceExample`. If `add_label_name` is 349 | false, this option is ignored. Exposing this as an argument allows using 350 | this function for different label features within a single dataset. 351 | output_label_name_feature_name: Name of the label name feature in the output 352 | features dictionary. If `add_label_name` is false, this option is ignored. 353 | Exposing this as an argument allows using this function for different 354 | label features within a single dataset. 355 | is_multi_label: Whether raw data contains multiple labels per example. 356 | one_hot_label: Return labels as one hot tensors. If `is_multi_label` is 357 | `True`, one hot tensor might have multiple ones. 358 | num_classes: Total number of classes in the dataset. It has to be provided 359 | if `one_hot_label` is `True`. 360 | add_label_name: Also return the name of the label. Not yet supported for 361 | multi label. 362 | """ 363 | # Validate parameters. 364 | if one_hot_label and not num_classes: 365 | raise ValueError( 366 | '`num_classes` must be given when requesting one hot label.') 367 | if is_multi_label and not one_hot_label: 368 | logging.warning( 369 | 'Multi label indices will be returned in a non fixed size dimension.') 370 | if add_label_name and (input_label_name_feature_name is None or 371 | output_label_name_feature_name is None): 372 | raise ValueError( 373 | '`input_label_name_feature_name` and `output_label_name_feature_name` ' 374 | 'must be given when `add_label_name` is true.') 375 | 376 | # Parse label. 377 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder): 378 | parser_builder.parse_feature( 379 | feature_name=input_label_index_feature_name, 380 | feature_type=tf.io.VarLenFeature(dtype=tf.int64), 381 | output_name=output_label_index_feature_name, 382 | is_context=True) 383 | if add_label_name: 384 | parser_builder.parse_feature( 385 | feature_name=input_label_name_feature_name, 386 | feature_type=tf.io.VarLenFeature(dtype=tf.string), 387 | output_name=output_label_name_feature_name, 388 | is_context=True) 389 | elif isinstance(parser_builder, builders.ExampleParserBuilder): 390 | parser_builder.parse_feature( 391 | feature_name=input_label_index_feature_name, 392 | feature_type=tf.io.VarLenFeature(dtype=tf.int64), 393 | output_name=output_label_index_feature_name) 394 | if add_label_name: 395 | parser_builder.parse_feature( 396 | feature_name=input_label_name_feature_name, 397 | feature_type=tf.io.VarLenFeature(dtype=tf.string), 398 | output_name=output_label_name_feature_name) 399 | else: 400 | raise ValueError('`parser_builder` has an unexpected type.') 401 | 402 | # Densify labels tensor in order to support multi label case. 
403 | decoder_builder.add_fn( 404 | fn=tf.sparse.to_dense, 405 | feature_name=output_label_index_feature_name, 406 | fn_name=f'{output_label_index_feature_name}_sparse_to_dense') 407 | if add_label_name: 408 | decoder_builder.add_fn( 409 | fn=tf.sparse.to_dense, 410 | feature_name=output_label_name_feature_name, 411 | fn_name=f'{output_label_name_feature_name}_sparse_to_dense') 412 | 413 | if one_hot_label: 414 | # Replace label index by one hot representation. 415 | preprocessor_builder.add_fn( 416 | fn=lambda x: tf.reduce_sum( # pylint: disable=g-long-lambda 417 | input_tensor=tf.one_hot(x, num_classes), 418 | axis=0), 419 | feature_name=output_label_index_feature_name, 420 | fn_name=f'{output_label_index_feature_name}_one_hot') 421 | elif not is_multi_label: 422 | preprocessor_builder.add_fn( 423 | fn=lambda x: processors.set_shape(x, (1,)), 424 | feature_name=output_label_index_feature_name, 425 | fn_name=f'{output_label_index_feature_name}_set_shape') 426 | 427 | if add_label_name and not is_multi_label: 428 | preprocessor_builder.add_fn( 429 | fn=lambda x: processors.set_shape(x, (1,)), 430 | feature_name=output_label_name_feature_name, 431 | fn_name=f'{output_label_name_feature_name}_set_shape') 432 | 433 | 434 | def add_text( 435 | parser_builder: builders.BaseParserBuilder, 436 | decoder_builder: builders.DecoderBuilder, 437 | preprocessor_builder: builders.PreprocessorBuilder, 438 | tokenizer: tokenizers.TextTokenizer, 439 | is_training: bool = True, 440 | input_feature_name: str = 'caption/string', 441 | output_raw_string_name: str = builders.TEXT_FEATURE_NAME, 442 | output_feature_name: str = builders.TEXT_INDICES_FEATURE_NAME, 443 | # Text related parameters. 444 | prepend_bos: bool = False, 445 | append_eos: bool = False, 446 | keep_raw_string: bool = False, 447 | max_num_captions: int = 1, 448 | max_num_tokens: Optional[int] = 16, 449 | sync_random_state: bool = False): 450 | """Adds functions to process text feature to builders. 451 | 452 | This function expects the input to be either a `tf.train.SequenceExample` 453 | (with the features in the context) or a `tf.train.Example`. The expected 454 | structure is (or equivalent for `tf.train.Example`): 455 | ``` 456 | context { 457 | feature { 458 | key: input_feature_name 459 | value { 460 | bytes_list { 461 | value: "Hello world!" 462 | value: "This is a caption." 463 | ... 464 | } 465 | } 466 | } 467 | } 468 | ``` 469 | 470 | The corresponding `builders.ExampleParserBuilder` or 471 | `builders.SequenceExampleParserBuilder` has to be given as parameter. 472 | 473 | Args: 474 | parser_builder: An instance of a `builders.BaseParserBuilder`. 475 | decoder_builder: An instance of a `builders.DecoderBuilder`. 476 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`. 477 | tokenizer: An instance of a tokenizer. 478 | is_training: Whether in training mode. This will be used to randomly sample 479 | the captions. 480 | input_feature_name: Name of the feature in the input `tf.train.Example` or 481 | `tf.train.SequenceExample`. Exposing this as an argument allows using this 482 | function for different text features within a single dataset. 483 | output_raw_string_name: Name of the raw string in the output features 484 | dictionary. Exposing this as an argument allows using this function for 485 | different text features within a single dataset. 486 | output_feature_name: Name of the feature in the output features dictionary. 
487 | Exposing this as an argument allows using this function for different text 488 | features. 489 | prepend_bos: Whether to prepend BOS token. 490 | append_eos: Whether to append EOS token. 491 | keep_raw_string: Whether to keep raw string. 492 | max_num_captions: Maximum number of captions to keep. If there are more 493 | captions in the proto, only the first `max_num_captions` will be returned 494 | is `is_training` is set to `False`. If `is_training` is `True`, then 495 | `max_num_captions` will be randomly sampled. Finally, if the proto 496 | contains less than `max_num_captions`, we pad with empty strings to make 497 | sure there are `max_num_captions` in total. 498 | max_num_tokens: Maximum number of tokens to keep from the text for each 499 | caption. If there are more tokens, sequence is cropped, if less, the 500 | caption is padded using the tokenizer pad id. The sequence is unmodified 501 | if max_num_tokens is None. 502 | sync_random_state: Whether to use stateful option to keep random operations 503 | in sync between different modalities. All modalities having this option 504 | `True` will use the same outcome in random operations used for sampling 505 | the captions. 506 | """ 507 | # Parse text indices. 508 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder): 509 | parser_builder.parse_feature( 510 | feature_name=input_feature_name, 511 | feature_type=tf.io.VarLenFeature(dtype=tf.string), 512 | output_name=output_raw_string_name, 513 | is_context=True) 514 | elif isinstance(parser_builder, builders.ExampleParserBuilder): 515 | parser_builder.parse_feature( 516 | feature_name=input_feature_name, 517 | feature_type=tf.io.VarLenFeature(dtype=tf.string), 518 | output_name=output_raw_string_name) 519 | 520 | # Densify text tensor. 521 | decoder_builder.add_fn( 522 | fn=tf.sparse.to_dense, 523 | feature_name=output_raw_string_name, 524 | fn_name=f'{output_feature_name}_sparse_to_dense') 525 | 526 | preprocessor_builder.add_fn( 527 | # pylint: disable=g-long-lambda 528 | lambda x, s=None: processors.sample_or_pad_non_sorted_sequence( 529 | x, max_num_captions, b'', random=is_training, state=s), 530 | # pylint: enable=g-long-lambda 531 | feature_name=output_raw_string_name, 532 | fn_name=f'{output_feature_name}_sample_captions', 533 | # Use state to keep coherence between modalities if requested. 534 | stateful=sync_random_state) 535 | 536 | # Tokenize the sentence. 537 | preprocessor_builder.add_fn( 538 | fn=lambda x: processors.tokenize( # pylint: disable=g-long-lambda 539 | x, tokenizer, output_raw_string_name, output_feature_name, 540 | prepend_bos, append_eos, max_num_tokens, keep_raw_string), 541 | fn_name=f'{output_feature_name}_tokenization') 542 | 543 | if max_num_tokens is not None: 544 | # Set text shape. 545 | shape = (max_num_captions, max_num_tokens) 546 | preprocessor_builder.add_fn( 547 | fn=lambda x: processors.set_shape(x, shape), 548 | feature_name=output_feature_name, 549 | fn_name=f'{output_feature_name}_set_shape') 550 | 551 | 552 | def add_audio( 553 | parser_builder: builders.BaseParserBuilder, 554 | sampler_builder: builders.SamplerBuilder, 555 | postprocessor_builder: builders.PostprocessorBuilder, 556 | preprocessor_builder: Optional[builders.PreprocessorBuilder] = None, 557 | input_feature_name: str = 'WAVEFORM/feature/floats', 558 | output_feature_name: str = builders.AUDIO_FEATURE_NAME, 559 | is_training: bool = True, 560 | # Audio related parameters. 
561 | num_samples: int = 30720, 562 | stride: int = 1, 563 | sample_rate: Optional[int] = 48000, 564 | target_sample_rate: Optional[int] = None, 565 | num_test_clips: int = 1, 566 | sync_random_state: bool = True): 567 | """Adds functions to process audio feature to builders. 568 | 569 | This function expects the input to be either a `tf.train.SequenceExample` (for 570 | videos) with the following structure: 571 | ``` 572 | feature_lists { 573 | feature_list { 574 | key: input_feature_name 575 | value { 576 | feature { 577 | float_list { 578 | value: 0.0 579 | value: 0.1 580 | value: 0.2 581 | ... 582 | } 583 | } 584 | } 585 | } 586 | } 587 | ``` 588 | 589 | or a `tf.train.Example` (for audio only) with the following structure: 590 | ``` 591 | features { 592 | feature { 593 | key: input_feature_name 594 | value { 595 | float_list { 596 | value: 0.0 597 | value: 0.1 598 | value: 0.2 599 | ... 600 | } 601 | } 602 | } 603 | } 604 | ``` 605 | 606 | The corresponding `builders.ExampleParserBuilder` or 607 | `builders.SequenceExampleParserBuilder` has to be given as a parameter. 608 | 609 | Args: 610 | parser_builder: An instance of a `builders.BaseParserBuilder`. 611 | sampler_builder: An instance of a `builders.SamplerBuilder`. 612 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`. 613 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`. 614 | input_feature_name: Name of the feature in the input `tf.train.Example` or 615 | `tf.train.SequenceExample`. Exposing this as an argument allows using this 616 | function for different audio features within a single dataset. 617 | output_feature_name: Name of the feature in the output features dictionary. 618 | Exposing this as an argument allows using this function for different 619 | audio features within a single dataset. 620 | is_training: Whether in training mode. If `True`, a random clip is sampled. 621 | num_samples: Number of samples per subclip. 622 | stride: Temporal stride to sample audio signal. 623 | sample_rate: The original sample rate of the input audio stored in the input 624 | files. 624 | target_sample_rate: If this is not None, the waveforms will be resampled to 625 | this sample rate (resampling uses Fourier transforms and is only triggered 625 | when this is set). 626 | num_test_clips: Number of test clips (1 by default). If more than 1, this 627 | will sample multiple linearly spaced clips within each audio at test time. 628 | If 1, then a single clip in the middle of the audio is sampled. The clips 629 | are aggregated in the batch dimension. 630 | sync_random_state: Whether to use stateful option to keep random operations 631 | in sync between different modalities. All modalities having this option 632 | `True` will use the same outcome in random operations such as sampling and 633 | cropping. 634 | """ 635 | # Validate parameters. 636 | if is_training and num_test_clips != 1: 637 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.', 638 | num_test_clips) 639 | 640 | # Keep audio signal. 641 | parser_builder.parse_feature( 642 | feature_name=input_feature_name, 643 | # Entire signal stored in one Feature. 644 | feature_type=tf.io.VarLenFeature(dtype=tf.float32), 645 | output_name=output_feature_name) 646 | 647 | # Densify. 648 | sampler_builder.add_fn( 649 | fn=lambda x: tf.sparse.to_dense(x)[0], 650 | feature_name=output_feature_name, 651 | fn_name=f'{output_feature_name}_sparse_to_dense') 652 | 653 | # Temporal sampler. 654 | if is_training: 655 | # Sample random clip.
656 | sampler_builder.add_fn( 657 | # pylint: disable=g-long-lambda 658 | fn=lambda x, s=None: processors.sample_sequence( 659 | x, num_samples, True, stride, state=s), 660 | # pylint: enable=g-long-lambda 661 | feature_name=output_feature_name, 662 | fn_name=f'{output_feature_name}_random_sample', 663 | # Use state to keep coherence between modalities if requested. 664 | stateful=sync_random_state) 665 | else: 666 | if num_test_clips > 1: 667 | # Sample linspace clips. 668 | sampler_builder.add_fn( 669 | # pylint: disable=g-long-lambda 670 | fn=lambda x: processors.sample_linspace_sequence( 671 | x, num_test_clips, num_samples, stride), 672 | # pylint: enable=g-long-lambda 673 | feature_name=output_feature_name, 674 | fn_name=f'{output_feature_name}_linspace_sample') 675 | else: 676 | # Sample middle clip. 677 | sampler_builder.add_fn( 678 | # pylint: disable=g-long-lambda 679 | fn=lambda x: processors.sample_sequence( 680 | x, num_samples, False, stride), 681 | # pylint: enable=g-long-lambda 682 | feature_name=output_feature_name, 683 | fn_name=f'{output_feature_name}_middle_sample') 684 | 685 | # Apply FFTs to change the sample rate of the waveforms. 686 | if preprocessor_builder is not None and target_sample_rate is not None: 687 | preprocessor_builder.add_fn( 688 | functools.partial( 689 | processors.resample_audio, 690 | num_subclips=num_test_clips, 691 | in_sample_rate=sample_rate, 692 | out_sample_rate=target_sample_rate, 693 | is_training=is_training), 694 | feature_name=builders.AUDIO_FEATURE_NAME) 695 | 696 | if num_test_clips > 1 and not is_training: 697 | # In this case, multiple clips are merged together in batch dimension which 698 | # will be `B * num_test_clips`. 699 | postprocessor_builder.add_fn( 700 | fn=lambda x: tf.reshape(x, (-1, x.shape[-1])), 701 | feature_name=output_feature_name, 702 | fn_name=f'{output_feature_name}_reshape') 703 | 704 | 705 | def add_spectrogram( 706 | preprocessor_builder: builders.PreprocessorBuilder, 707 | postprocessor_builder: builders.PostprocessorBuilder, 708 | input_feature_name: str = builders.AUDIO_FEATURE_NAME, 709 | output_feature_name: str = builders.AUDIO_MEL_FEATURE_NAME, 710 | is_training: bool = True, 711 | sample_rate: int = 48000, 712 | spectrogram_type: str = 'logmf', 713 | frame_length: int = 2048, 714 | frame_step: int = 1024, 715 | num_features: int = 80, 716 | lower_edge_hertz: float = 80.0, 717 | upper_edge_hertz: float = 7600.0, 718 | preemphasis: Optional[float] = None, 719 | normalize_audio: bool = False, 720 | num_test_clips: int = 1): 721 | """Adds functions to process audio spectrogram feature to builders. 722 | 723 | Note that this function does not extract and parse audio feature. Instead, it 724 | should be used after a `add_audio` function. The output spectrogram is of the 725 | shape [batch_size, num_frames, num_features]. 726 | 727 | Args: 728 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`. 729 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`. 730 | input_feature_name: Name of the feature in the input features dictionary. 731 | Exposing this as an argument allows using this function for different 732 | audio features. 733 | output_feature_name: Name of the feature in the output features dictionary. 734 | Exposing this as an argument allows using this function for different 735 | audio features. 736 | is_training: If the current mode is training or not. 737 | sample_rate: The sample rate of the input audio. 
738 | spectrogram_type: The type of the spectrogram to be extracted from the 739 | waveform. Can be one of `spectrogram`, `logmf`, or `mfcc`. 740 | frame_length: The length of each spectrogram frame. 741 | frame_step: The stride of spectrogram frames. 742 | num_features: The number of spectrogram features. 743 | lower_edge_hertz: Lowest frequency to consider. 744 | upper_edge_hertz: Highest frequency to consider. 745 | preemphasis: The strength of pre-emphasis on the waveform. If None, no 746 | pre-emphasis will be applied. 747 | normalize_audio: Whether to normalize the waveform or not. 748 | num_test_clips: Number of test clips (1 by default). If more than 1, this 749 | will sample multiple linearly spaced clips within each audio at test time. 750 | If 1, then a single clip in the middle of the audio is sampled. The clips 751 | are aggregated in the batch dimension. 752 | """ 753 | # Validate parameters. 754 | if is_training and num_test_clips != 1: 755 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.', 756 | num_test_clips) 757 | 758 | # Extract audio spectrograms. 759 | preprocessor_builder.add_fn( 760 | functools.partial( 761 | processors.compute_audio_spectrogram, 762 | num_subclips=num_test_clips, 763 | sample_rate=sample_rate, 764 | spectrogram_type=spectrogram_type, 765 | frame_length=frame_length, 766 | frame_step=frame_step, 767 | num_features=num_features, 768 | lower_edge_hertz=lower_edge_hertz, 769 | upper_edge_hertz=upper_edge_hertz, 770 | normalize=normalize_audio, 771 | preemphasis=preemphasis, 772 | audio_feature_name=input_feature_name, 773 | spectrogram_feature_name=output_feature_name)) 774 | 775 | if num_test_clips > 1 and not is_training: 776 | # In this case, multiple clips are merged together in batch dimension which 777 | # will be `B * num_test_clips`. 778 | postprocessor_builder.add_fn( 779 | fn=lambda x: tf.reshape(x, (-1, x.shape[-2], x.shape[-1])), 780 | feature_name=output_feature_name, 781 | fn_name=f'{output_feature_name}_reshape') 782 | -------------------------------------------------------------------------------- /dmvr/modalities_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
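The `add_*` helpers in modalities.py above only register parsing, sampling, decoding and (post)processing functions on the corresponding builders; they do not read any data themselves. A minimal sketch of how they might be wired together for audio, spectrograms and text, assuming the default feature names; the builder construction mirrors the test below, and the vocabulary path is an illustrative placeholder:

```python
import tensorflow as tf

from dmvr import builders
from dmvr import modalities
from dmvr import tokenizers

# One builder per phase of the pipeline
# (parse -> sample -> decode -> preprocess -> batch -> postprocess).
parser_builder = builders.SequenceExampleParserBuilder()
sampler_builder = builders.SamplerBuilder()
decoder_builder = builders.DecoderBuilder()
preprocessor_builder = builders.PreprocessorBuilder()
postprocessor_builder = builders.PostprocessorBuilder()

# Audio: parse the waveform, sample a fixed-size clip, reshape test clips.
modalities.add_audio(
    parser_builder, sampler_builder, postprocessor_builder,
    preprocessor_builder=preprocessor_builder,
    is_training=True, num_samples=30720)

# Log-mel spectrograms computed from the audio feature registered above.
modalities.add_spectrogram(
    preprocessor_builder, postprocessor_builder,
    is_training=True, spectrogram_type='logmf')

# Text: sample captions, tokenize them and pad to a fixed shape.
tokenizer = tokenizers.WordTokenizer('/path/to/word_vocab.txt')  # Illustrative path.
tokenizer.initialize()
modalities.add_text(
    parser_builder, decoder_builder, preprocessor_builder,
    tokenizer=tokenizer, max_num_captions=1, max_num_tokens=16)

# Each `builder.build()` then returns a function that can be applied to the
# serialized examples of a `tf.data.Dataset`, as `_process_examples` below does.
```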
14 | 15 | """Tests for modalities.""" 16 | 17 | import os 18 | 19 | from dmvr import builders 20 | from dmvr import modalities 21 | from dmvr import tokenizers 22 | import numpy as np 23 | from parameterized import parameterized 24 | import tensorflow as tf 25 | 26 | # Removed: Internal pyglib dependencies 27 | 28 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata') 29 | _SAMPLE_IMAGE_PATH = os.path.join(_TESTDATA_DIR, 'sample.jpeg') 30 | _VOCAB_PATH = os.path.join(_TESTDATA_DIR, 'tokenizers', 'word_vocab.txt') 31 | 32 | 33 | class ModalitiesTest(tf.test.TestCase): 34 | 35 | def setUp(self): 36 | super().setUp() 37 | seq_example = tf.train.SequenceExample() 38 | 39 | # Create stub frames and inject them in the SequenceExample. 40 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image_bytes = f.read() 41 | for _ in range(10 * 5): 42 | seq_example.feature_lists.feature_list.get_or_create( 43 | 'image/encoded').feature.add().bytes_list.value[:] = [raw_image_bytes] 44 | 45 | # Create stub flow and inject it in the SequenceExample. 46 | for _ in range(10 * 5): 47 | seq_example.feature_lists.feature_list.get_or_create( 48 | 'flow/encoded').feature.add().bytes_list.value[:] = [raw_image_bytes] 49 | 50 | # Create stub label and inject it in the SequenceExample. 51 | raw_label_index = 42 52 | raw_label_name = b'label' 53 | seq_example.context.feature.get_or_create( 54 | 'clip/label/index').int64_list.value[:] = [raw_label_index] 55 | seq_example.context.feature.get_or_create( 56 | 'clip/label/string').bytes_list.value[:] = [raw_label_name] 57 | 58 | # Create stub raw text and inject it in SequenceExample. 59 | raw_text = b'hello world' 60 | seq_example.context.feature.get_or_create( 61 | 'caption/string').bytes_list.value[:] = [raw_text, raw_text] 62 | 63 | # Create stub audio and inject it in SequenceExample. 64 | raw_audio = np.linspace(-1, 1, 48000 * 5) 65 | seq_example.feature_lists.feature_list.get_or_create( 66 | 'WAVEFORM/feature/floats').feature.add().float_list.value[:] = raw_audio 67 | 68 | serialized_seq_example = seq_example.SerializeToString() 69 | self._seq_examples = [serialized_seq_example] * 8 # Batch size is 8. 70 | 71 | # Create builders. 72 | self._parser_builder = builders.SequenceExampleParserBuilder() 73 | self._sampler_builder = builders.SamplerBuilder() 74 | self._decoder_builder = builders.DecoderBuilder() 75 | self._preprocessor_builder = builders.PreprocessorBuilder() 76 | self._postprocessor_builder = builders.PostprocessorBuilder() 77 | 78 | def _process_examples(self): 79 | """Process input examples simulating dataset object creation.""" 80 | def pre_batch_process(raw_seq_example): 81 | output = self._parser_builder.build()(raw_seq_example) 82 | output = self._sampler_builder.build()(output) 83 | output = self._decoder_builder.build()(output) 84 | output = self._preprocessor_builder.build()(output) 85 | return output 86 | 87 | # Batch and postprocess. 
88 | output = [pre_batch_process(rse) for rse in self._seq_examples] 89 | batched_output = {} 90 | for k in output[0].keys(): 91 | batched_output[k] = tf.stack([out[k] for out in output]) 92 | output = batched_output 93 | output = self._postprocessor_builder.build()(output) 94 | 95 | return output 96 | 97 | @parameterized.expand(( 98 | (True, 1, False, True, ['image_random_sample'], [ 99 | 'image_resize_smallest', 'image_random_crop', 'image_random_flip', 100 | 'image_normalize' 101 | ], []), 102 | (True, 1, False, False, ['image_random_sample'], 103 | ['image_resize_smallest', 'image_random_crop', 'image_normalize'], []), 104 | (False, 1, False, True, ['image_middle_sample'], 105 | ['image_resize_smallest', 'image_central_crop', 'image_normalize'], []), 106 | (False, 2, False, True, ['image_linspace_sample'], 107 | ['image_resize_smallest', 'image_central_crop', 108 | 'image_normalize'], ['image_reshape']), 109 | (True, 1, True, True, ['image_random_sample'], [ 110 | 'image_normalize', 'image_resize_smallest', 'image_random_crop', 111 | 'image_random_flip', 'image_extract_flow_channels', 'image_clip_flow' 112 | ], []), 113 | )) 114 | def test_add_image(self, is_training, num_test_clips, is_flow, random_flip, 115 | sample_ops, preprocess_ops, postprocess_ops): 116 | is_rgb = None if is_flow else True 117 | zero_centering_image = is_flow 118 | modalities.add_image( 119 | self._parser_builder, # `parser_builder` 120 | self._sampler_builder, # `sampler_builder` 121 | self._decoder_builder, # `decoder_builder` 122 | self._preprocessor_builder, # `preprocessor_builder` 123 | self._postprocessor_builder, # `postprocessor_builder` 124 | 'image/encoded', # `input_feature_name` 125 | 'image', # `output_feature_name` 126 | is_training, # `is_training` 127 | 32, # `num_frames` 128 | 1, # `stride` 129 | num_test_clips, # `num_test_clips` 130 | 224, # `min_resize` 131 | 200, # `crop_size` 132 | zero_centering_image, # `zero_centering_image` 133 | True, # `sync_random_state` 134 | is_rgb, # `is_rgb` 135 | is_flow, # `is_flow` 136 | random_flip) # `random_flip` 137 | output = self._process_examples() 138 | 139 | self.assertAllEqual( 140 | [fd.fn_name for fd in self._sampler_builder.get_summary()], sample_ops) 141 | self.assertAllEqual( 142 | [fd.fn_name for fd in self._decoder_builder.get_summary()], 143 | ['image_decode_jpeg']) 144 | self.assertAllEqual( 145 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()], 146 | preprocess_ops) 147 | self.assertAllEqual( 148 | [fd.fn_name for fd in self._postprocessor_builder.get_summary()], 149 | postprocess_ops) 150 | 151 | # Assert static shape. 
152 | self.assertNotIn(None, output['image'].shape.as_list()) 153 | self.assertSetEqual(set(output.keys()), set(['image'])) 154 | num_output_channels = 2 if is_flow else 3 155 | self.assertAllEqual(output['image'].shape, 156 | (8 * num_test_clips, 32, 200, 200, num_output_channels)) 157 | 158 | @parameterized.expand(((False, False), (False, True), (True, True))) 159 | def test_add_label(self, one_hot_label, add_label_name): 160 | modalities.add_label( 161 | self._parser_builder, # `parser_builder` 162 | self._decoder_builder, # `decoder_builder` 163 | self._preprocessor_builder, # `preprocessor_builder` 164 | 'clip/label/index', # `input_label_index_feature_name` 165 | 'label', # `output_label_index_feature_name` 166 | 'clip/label/string', # `input_label_name_feature_name` 167 | 'label_name', # `output_label_name_feature_name` 168 | False, # `is_multi_label` 169 | one_hot_label, # `one_hot_label` 170 | 50, # `num_classes` 171 | add_label_name) # `add_label_name` 172 | output = self._process_examples() 173 | 174 | decoder_ops = ['label_sparse_to_dense'] 175 | if add_label_name: 176 | decoder_ops.append('label_name_sparse_to_dense') 177 | self.assertAllEqual( 178 | [fd.fn_name for fd in self._decoder_builder.get_summary()], 179 | decoder_ops) 180 | if one_hot_label: 181 | preprocess_ops = ['label_one_hot'] 182 | else: 183 | preprocess_ops = ['label_set_shape'] 184 | if add_label_name: 185 | preprocess_ops.append('label_name_set_shape') 186 | self.assertAllEqual( 187 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()], 188 | preprocess_ops) 189 | 190 | # Assert static shape. 191 | self.assertNotIn(None, output['label'].shape.as_list()) 192 | 193 | keys = set(['label']) 194 | if add_label_name: 195 | keys.add('label_name') 196 | self.assertSetEqual(set(output.keys()), keys) 197 | if one_hot_label: 198 | self.assertAllEqual(output['label'], [[0] * 42 + [1] + [0] * 7] * 8) 199 | else: 200 | self.assertAllEqual(output['label'], [[42]] * 8) 201 | if add_label_name: 202 | self.assertAllEqual(output['label_name'], [[b'label']] * 8) 203 | 204 | @parameterized.expand(((16,), (1,))) 205 | def test_add_text(self, max_num_words): 206 | tokenizer_model = tokenizers.WordTokenizer( 207 | _VOCAB_PATH) # OSS: removed internal filename loading. 208 | tokenizer_model.initialize() 209 | 210 | modalities.add_text( 211 | self._parser_builder, # `parser_builder` 212 | self._decoder_builder, # `decoder_builder` 213 | self._preprocessor_builder, # `preprocessor_builder` 214 | tokenizer_model, # `tokenizer` 215 | True, # `is_training` 216 | 'caption/string', # `input_feature_name` 217 | builders.TEXT_FEATURE_NAME, # `output_raw_name` 218 | builders.TEXT_INDICES_FEATURE_NAME, # `output_feature_name` 219 | False, # `prepend_bos` 220 | False, # `append_eos` 221 | True, # `keep_raw_string` 222 | 2, # `max_num_captions` 223 | max_num_words, # `max_num_words` 224 | True) # `sync_random_state` 225 | 226 | output = self._process_examples() 227 | self.assertAllEqual( 228 | [fd.fn_name for fd in self._decoder_builder.get_summary()], 229 | ['text_indices_sparse_to_dense']) 230 | self.assertAllEqual( 231 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()], 232 | ['text_indices_sample_captions', 'text_indices_tokenization', 233 | 'text_indices_set_shape']) 234 | 235 | # Assert static shape. 
236 | self.assertNotIn( 237 | None, output[builders.TEXT_INDICES_FEATURE_NAME].shape.as_list()) 238 | self.assertSetEqual(set(output.keys()), 239 | set([builders.TEXT_INDICES_FEATURE_NAME, 240 | builders.TEXT_FEATURE_NAME])) 241 | words = [4, 5][:min(2, max_num_words)] 242 | padding = [0] * max(0, max_num_words - 2) 243 | self.assertAllEqual( 244 | output[builders.TEXT_INDICES_FEATURE_NAME], 245 | [[words + padding, words + padding]] * 8) 246 | 247 | @parameterized.expand(( 248 | (True, 1, ['audio_sparse_to_dense', 'audio_random_sample'], []), 249 | (False, 1, ['audio_sparse_to_dense', 'audio_middle_sample'], []), 250 | (False, 2, ['audio_sparse_to_dense', 'audio_linspace_sample'], 251 | ['audio_reshape']))) 252 | def test_add_audio(self, is_training, num_test_clips, sample_ops, 253 | postprocess_ops): 254 | modalities.add_audio( 255 | self._parser_builder, # `parser_builder` 256 | self._sampler_builder, # `sampler_builder` 257 | self._postprocessor_builder, # `postprocessor_builder` 258 | 'WAVEFORM/feature/floats', # `input_feature_name` 259 | builders.AUDIO_FEATURE_NAME, # `output_feature_name` 260 | is_training, # `is_training` 261 | 30720, # `num_samples` 262 | 1, # `stride` 263 | num_test_clips) # `num_test_clips` 264 | output = self._process_examples() 265 | 266 | self.assertAllEqual( 267 | [fd.fn_name for fd in self._sampler_builder.get_summary()], 268 | sample_ops) 269 | self.assertAllEqual( 270 | [fd.fn_name for fd in self._postprocessor_builder.get_summary()], 271 | postprocess_ops) 272 | 273 | # Assert static shape. 274 | self.assertNotIn( 275 | None, output[builders.AUDIO_FEATURE_NAME].shape.as_list()) 276 | self.assertSetEqual(set(output.keys()), 277 | set([builders.AUDIO_FEATURE_NAME])) 278 | self.assertAllEqual(output[builders.AUDIO_FEATURE_NAME].shape, 279 | (8 * num_test_clips, 30720)) 280 | 281 | def test_all_modalities(self): 282 | # Add RGB image. 283 | modalities.add_image(self._parser_builder, self._sampler_builder, 284 | self._decoder_builder, self._preprocessor_builder, 285 | self._postprocessor_builder) 286 | # Add flow image. Note that in this test this will read from a RGB 287 | # flow/encoded since we store flow on disk as RGB images where only the two 288 | # first channels (RG) corresponds to the relevant horizontal and vertical 289 | # displacement vector. 290 | modalities.add_image( 291 | self._parser_builder, 292 | self._sampler_builder, 293 | self._decoder_builder, 294 | self._preprocessor_builder, 295 | self._postprocessor_builder, 296 | input_feature_name='flow/encoded', 297 | output_feature_name=builders.FLOW_FEATURE_NAME, 298 | is_rgb=None, 299 | zero_centering_image=True, 300 | is_flow=True) 301 | modalities.add_label( 302 | self._parser_builder, 303 | self._decoder_builder, 304 | self._preprocessor_builder, 305 | num_classes=50) 306 | tokenizer = tokenizers.WordTokenizer( 307 | _VOCAB_PATH) # OSS: removed internal filename loading. 
308 | tokenizer.initialize() 309 | modalities.add_text( 310 | self._parser_builder, 311 | self._decoder_builder, 312 | self._preprocessor_builder, 313 | tokenizer=tokenizer) 314 | modalities.add_audio(self._parser_builder, self._sampler_builder, 315 | self._postprocessor_builder) 316 | output = self._process_examples() 317 | 318 | self.assertSetEqual( 319 | set(output.keys()), 320 | set([ 321 | builders.IMAGE_FEATURE_NAME, builders.FLOW_FEATURE_NAME, 322 | builders.LABEL_INDEX_FEATURE_NAME, 323 | builders.TEXT_INDICES_FEATURE_NAME, builders.AUDIO_FEATURE_NAME 324 | ])) 325 | self.assertAllEqual(output[builders.IMAGE_FEATURE_NAME].shape, 326 | (8, 32, 200, 200, 3)) 327 | self.assertAllEqual(output[builders.FLOW_FEATURE_NAME].shape, 328 | (8, 32, 200, 200, 2)) 329 | self.assertAllEqual(output[builders.LABEL_INDEX_FEATURE_NAME], 330 | [[0] * 42 + [1] + [0] * 7] * 8) 331 | self.assertAllEqual(output[builders.TEXT_INDICES_FEATURE_NAME], 332 | [[[4, 5] + [0] * 14]] * 8) 333 | self.assertAllEqual(output[builders.AUDIO_FEATURE_NAME].shape, (8, 30720)) 334 | 335 | 336 | if __name__ == '__main__': 337 | tf.test.main() 338 | -------------------------------------------------------------------------------- /dmvr/processors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
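The temporal sampling functions exercised by the tests below follow a simple contract: `sample_sequence` cuts a fixed-size clip (centred, or at a random offset that can be shared across modalities via a `state` dict), while `sample_linspace_sequence` concatenates several linearly spaced clips for test-time evaluation. A small illustrative sketch, with the expected values taken from the tests that follow:

```python
import tensorflow as tf
from dmvr import processors

sequence = tf.range(100)

# Deterministic middle clip of 10 consecutive elements -> [45, 46, ..., 54].
middle = processors.sample_sequence(sequence, 10, False)

# Random clip; reusing the same `state` dict reuses the sampled offset, which is
# how random crops are kept in sync between modalities.
state = {}
clip_a = processors.sample_sequence(sequence, 10, True, state=state)
clip_b = processors.sample_sequence(sequence, 10, True, state=state)  # Equal to clip_a.

# Test time: 7 linearly spaced windows of 10 elements each, concatenated
# along the first axis -> shape (70,).
linspace = processors.sample_linspace_sequence(sequence, 7, 10)
```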
14 | 15 | """Tests for processors.""" 16 | 17 | import itertools 18 | import os 19 | 20 | from absl.testing import parameterized 21 | from dmvr import processors 22 | from dmvr import tokenizers 23 | import numpy as np 24 | import tensorflow as tf 25 | 26 | # Removed: Internal pyglib dependencies 27 | 28 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata') 29 | _SAMPLE_IMAGE_PATH = os.path.join(_TESTDATA_DIR, 'sample.jpeg') 30 | _VOCAB_PATH = os.path.join(_TESTDATA_DIR, 'tokenizers', 'word_vocab.txt') 31 | 32 | 33 | class SampleTest(tf.test.TestCase, parameterized.TestCase): 34 | 35 | def setUp(self): 36 | super().setUp() 37 | self._sequence = tf.range(100) 38 | 39 | def test_sample_linspace_sequence(self): 40 | sampled_seq_1 = processors.sample_linspace_sequence(self._sequence, 10, 10) 41 | sampled_seq_2 = processors.sample_linspace_sequence(self._sequence, 7, 10) 42 | sampled_seq_3 = processors.sample_linspace_sequence(self._sequence, 7, 5, 2) 43 | sampled_seq_4 = processors.sample_linspace_sequence(self._sequence, 101, 1) 44 | self.assertAllEqual(sampled_seq_1, range(100)) 45 | # [0, 1, 2, 3, 4, ..., 8, 9, 15, 16, ..., 97, 98, 99] 46 | self.assertAllEqual( 47 | sampled_seq_2, 48 | [15 * i + j for i, j in itertools.product(range(7), range(10))]) 49 | # [0, 2, 4, 6, 8, 15, 17, 19, ..., 96, 98] 50 | self.assertAllEqual( 51 | sampled_seq_3, 52 | [15 * i + 2 * j for i, j in itertools.product(range(7), range(5))]) 53 | self.assertAllEqual(sampled_seq_4, [0] + list(range(100))) 54 | 55 | def test_sample_sequence(self): 56 | sampled_seq_1 = processors.sample_sequence(self._sequence, 10, False) 57 | sampled_seq_2 = processors.sample_sequence(self._sequence, 10, False, 2) 58 | sampled_seq_3 = processors.sample_sequence(self._sequence, 10, True) 59 | 60 | self.assertAllEqual(sampled_seq_1, range(45, 55)) 61 | self.assertAllEqual(sampled_seq_2, range(40, 60, 2)) 62 | 63 | offset_3 = sampled_seq_3[0] 64 | self.assertBetween(offset_3, 0, 99) 65 | self.assertAllEqual(sampled_seq_3, range(offset_3, offset_3 + 10)) 66 | 67 | def test_sample_sequence_with_state(self): 68 | state = {} 69 | sampled_seq_1 = processors.sample_sequence( 70 | self._sequence, 10, True, state=state) 71 | sampled_seq_2 = processors.sample_sequence( 72 | self._sequence, 10, True, state=state) 73 | 74 | self.assertAllEqual(sampled_seq_1, sampled_seq_2) 75 | 76 | def test_sample_or_pad_non_sorted_sequence(self): 77 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence( 78 | self._sequence, 10, 0, False) 79 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence( 80 | self._sequence, 110, 0, False) 81 | 82 | self.assertAllEqual(sampled_seq_1, range(10)) 83 | self.assertAllEqual(sampled_seq_2, list(range(100)) + [0] * 10) 84 | 85 | def test_sample_or_pad_non_sorted_sequence_with_state(self): 86 | state = {} 87 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence( 88 | self._sequence, 10, 0, True, state=state) 89 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence( 90 | self._sequence, 10, 0, True, state=state) 91 | 92 | self.assertAllEqual(sampled_seq_1, sampled_seq_2) 93 | self.assertRaises( 94 | tf.errors.InvalidArgumentError, 95 | processors.sample_or_pad_non_sorted_sequence, 96 | self._sequence[:10], 10, 0, True, state=state) 97 | 98 | def test_sample_or_pad_non_sorted_sequence_multidim_with_state(self): 99 | state = {} 100 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence( 101 | self._sequence, 10, 0, True, state=state) 102 | multi_dim_sequence = 
tf.tile(self._sequence[:, None], (1, 10)) 103 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence( 104 | multi_dim_sequence, 10, 0, True, state=state) 105 | self.assertAllEqual(sampled_seq_1, sampled_seq_2[:, 0]) 106 | 107 | @parameterized.named_parameters( 108 | { 109 | 'testcase_name': 'len(seq) < num_steps', 110 | 'sequence': np.array([1, 2, 3]), 111 | 'num_steps': 5, 112 | 'expected_sequence': np.array([1, 2, 3, 1, 2]) 113 | }, 114 | { 115 | 'testcase_name': 'len(seq) == num_steps', 116 | 'sequence': np.array([1, 2, 3]), 117 | 'num_steps': 3, 118 | 'expected_sequence': np.array([1, 2, 3]) 119 | }, 120 | { 121 | 'testcase_name': 'len(seq) < num_steps with stride', 122 | 'sequence': np.array([1, 2, 3]), 123 | 'num_steps': 5, 124 | 'expected_sequence': np.array([1, 3, 2, 1, 3]), 125 | 'stride': 2 126 | }, 127 | { 128 | 'testcase_name': 'len(seq) == num_steps with stride', 129 | 'sequence': np.array([1, 2, 3]), 130 | 'num_steps': 3, 131 | 'expected_sequence': np.array([1, 1, 1]), 132 | 'stride': 3 133 | }, 134 | ) 135 | def test_sample_sequence_fixed_offset(self, 136 | sequence: np.ndarray, 137 | num_steps: int, 138 | expected_sequence: np.ndarray, 139 | stride: int = 1): 140 | """Tests that offset is always 0.""" 141 | for seed in range(5): 142 | actual_sequence = processors.sample_sequence( 143 | sequence, num_steps=num_steps, random=True, stride=stride, seed=seed) 144 | np.testing.assert_array_equal(actual_sequence, expected_sequence) 145 | 146 | 147 | class DecodeTest(tf.test.TestCase): 148 | 149 | def test_decode_jpeg(self): 150 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image_bytes = f.read() 151 | raw_image = tf.constant([raw_image_bytes, raw_image_bytes]) 152 | decoded_image = processors.decode_jpeg(raw_image) 153 | decoded_image_with_static_channels = processors.decode_jpeg(raw_image, 3) 154 | self.assertEqual(decoded_image_with_static_channels.shape.as_list()[3], 3) 155 | self.assertAllEqual(decoded_image.shape, (2, 263, 320, 3)) 156 | self.assertAllEqual(decoded_image_with_static_channels.shape, 157 | (2, 263, 320, 3)) 158 | 159 | 160 | class PreprocessTest(tf.test.TestCase): 161 | 162 | def setUp(self): 163 | super().setUp() 164 | # [[0, 1, ..., 119], [1, 2, ..., 120], ..., [119, 120, ..., 218]]. 165 | self._frames = tf.stack([tf.range(i, i + 120) for i in range(90)]) 166 | self._frames = tf.cast(self._frames, tf.uint8) 167 | self._frames = self._frames[tf.newaxis, :, :, tf.newaxis] 168 | self._frames = tf.broadcast_to(self._frames, (6, 90, 120, 3)) 169 | 170 | # Create an equivalent numpy array for assertions. 
171 | self._np_frames = np.array([range(i, i + 120) for i in range(90)]) 172 | self._np_frames = self._np_frames[np.newaxis, :, :, np.newaxis] 173 | self._np_frames = np.broadcast_to(self._np_frames, (6, 90, 120, 3)) 174 | 175 | def test_set_shape(self): 176 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image = f.read() 177 | raw_image = tf.constant([raw_image]) 178 | decoded_image = processors.decode_jpeg(raw_image) 179 | decoded_image = processors.set_shape(decoded_image, (1, 263, 320, 3)) 180 | self.assertAllEqual(decoded_image.shape.as_list(), (1, 263, 320, 3)) 181 | 182 | def test_crop_image(self): 183 | cropped_image_1 = processors.crop_image(self._frames, 50, 70) 184 | cropped_image_2 = processors.crop_image(self._frames, 200, 200) 185 | cropped_image_3 = processors.crop_image(self._frames, 50, 70, True) 186 | 187 | self.assertAllEqual(cropped_image_1.shape, (6, 50, 70, 3)) 188 | self.assertAllEqual(cropped_image_1, self._np_frames[:, 20:70, 25:95, :]) 189 | 190 | self.assertAllEqual(cropped_image_2.shape, (6, 200, 200, 3)) 191 | expected = np.pad( 192 | self._np_frames, ((0, 0), (55, 55), (40, 40), (0, 0)), 'constant') 193 | self.assertAllEqual(cropped_image_2, expected) 194 | 195 | self.assertAllEqual(cropped_image_3.shape, (6, 50, 70, 3)) 196 | offset = cropped_image_3[0, 0, 0, 0] 197 | expected = np.array([range(i, i + 70) for i in range(offset, offset + 50)]) 198 | expected = expected[np.newaxis, :, :, np.newaxis] 199 | expected = np.broadcast_to(expected, (6, 50, 70, 3)) 200 | self.assertAllEqual(cropped_image_3, expected) 201 | 202 | def test_crop_image_with_state(self): 203 | state = {} 204 | cropped_image_1 = processors.crop_image(self._frames, 50, 70, state=state) 205 | cropped_image_2 = processors.crop_image(self._frames, 50, 70, state=state) 206 | 207 | self.assertAllEqual(cropped_image_1, cropped_image_2) 208 | 209 | def test_resize_smallest(self): 210 | resized_frames_1 = processors.resize_smallest(self._frames, 180) 211 | resized_frames_2 = processors.resize_smallest(self._frames, 45) 212 | resized_frames_3 = processors.resize_smallest(self._frames, 90) 213 | resized_frames_4 = processors.resize_smallest( 214 | tf.transpose(a=self._frames, perm=(0, 2, 1, 3)), 45) 215 | self.assertAllEqual(resized_frames_1.shape, (6, 180, 240, 3)) 216 | self.assertAllEqual(resized_frames_2.shape, (6, 45, 60, 3)) 217 | self.assertAllEqual(resized_frames_3.shape, (6, 90, 120, 3)) 218 | self.assertAllEqual(resized_frames_4.shape, (6, 60, 45, 3)) 219 | 220 | def test_resize_smallest_with_flow(self): 221 | flows = tf.cast(self._frames, tf.float32) 222 | resized_flows = processors.resize_smallest(flows, 180, True) 223 | resized_flows_expected = 2.0 * processors.resize_smallest(flows, 180, False) 224 | 225 | self.assertAllEqual(resized_flows, resized_flows_expected) 226 | 227 | def test_random_flip_left_right(self): 228 | flipped_frames = processors.random_flip_left_right(self._frames) 229 | flipped = np.fliplr(self._np_frames[0, :, :, 0]) 230 | flipped = flipped[np.newaxis, :, :, np.newaxis] 231 | flipped = np.broadcast_to(flipped, (6, 90, 120, 3)) 232 | self.assertTrue((flipped_frames == self._np_frames).numpy().all() or ( 233 | flipped_frames == flipped).numpy().all()) 234 | 235 | def test_random_flip_left_right_with_flow(self): 236 | flows = tf.cast(self._frames, tf.float32) 237 | flipped_flows = processors.random_flip_left_right(flows, is_flow=True) 238 | flipped = np.fliplr(self._np_frames[0, :, :, 0]) 239 | flipped = flipped[np.newaxis, :, :, np.newaxis] 240 | flipped = 
np.broadcast_to(flipped, (6, 90, 120, 3)) 241 | flipped_flow = flipped.astype(np.float32) 242 | flipped_flow[:, :, :, 0] *= -1.0 243 | self.assertTrue( 244 | (flipped_flows == self._np_frames.astype(np.float32)).numpy().all() or ( 245 | flipped_flows == flipped_flow).numpy().all()) 246 | 247 | def test_random_flip_left_right_with_state(self): 248 | state = {} 249 | flipped_frames_1 = processors.random_flip_left_right( 250 | self._frames, state=state) 251 | flipped_frames_2 = processors.random_flip_left_right( 252 | self._frames, state=state) 253 | self.assertAllEqual(flipped_frames_1, flipped_frames_2) 254 | 255 | def test_normalize_image(self): 256 | normalized_images_1 = processors.normalize_image( 257 | self._frames, False, tf.float32) 258 | normalized_images_2 = processors.normalize_image( 259 | self._frames, True, tf.float32) 260 | self.assertAllClose(normalized_images_1, self._np_frames / 255) 261 | self.assertAllClose(normalized_images_2, self._np_frames * 2 / 255 - 1.0) 262 | 263 | def test_scale_jitter_augm(self): 264 | no_jitter_images = processors.scale_jitter_augm(self._frames, 0.8, 1.0, 0.0) 265 | jitter_images = processors.scale_jitter_augm( 266 | self._frames, 2.0, 2.00001, 1.0) 267 | self.assertAllEqual(no_jitter_images.shape, (6, 90, 120, 3)) 268 | self.assertAllEqual(jitter_images.shape, (6, 180, 240, 3)) 269 | 270 | def test_scale_jitter_augm_with_state(self): 271 | state = {} 272 | jitter_image_1 = processors.scale_jitter_augm( 273 | self._frames, 0.8, 1.2, 1.0, state=state) 274 | jitter_image_2 = processors.scale_jitter_augm( 275 | self._frames, 0.8, 1.2, 1.0, state=state) 276 | self.assertAllEqual(jitter_image_1, jitter_image_2) 277 | 278 | def test_scale_jitter_augm_with_flow(self): 279 | state = {} 280 | flows = tf.cast(self._frames, tf.float32) 281 | jitter_flows = processors.scale_jitter_augm( 282 | flows, 0.8, 1.2, 1.0, state=state, is_flow=True) 283 | jitter_flows_expected = processors.scale_jitter_augm( 284 | flows, 0.8, 1.2, 1.0, state=state) 285 | h_s, w_s, _ = state['scale_jitter_augm_info'] 286 | jitter_flows_expected *= tf.stack([h_s, w_s, 1.0])[None, None, None, :] 287 | self.assertAllClose(jitter_flows, jitter_flows_expected) 288 | 289 | def test_color_default_augment(self): 290 | normalized_images = processors.normalize_image( 291 | self._frames, False, tf.float32) 292 | no_augmented_images = processors.color_default_augm( 293 | normalized_images, False, 0.0, 0.0) 294 | color_augmented_images = processors.color_default_augm( 295 | normalized_images, False, 1.0, 0.0) 296 | color_dropped_images = processors.color_default_augm( 297 | normalized_images, False, 0.0, 1.0) 298 | self.assertAllEqual(no_augmented_images.shape, normalized_images.shape) 299 | self.assertAllEqual(color_augmented_images.shape, normalized_images.shape) 300 | self.assertAllEqual(color_dropped_images.shape, normalized_images.shape) 301 | 302 | self.assertAllEqual(normalized_images, no_augmented_images) 303 | self.assertNotAllEqual(normalized_images, color_augmented_images) 304 | self.assertNotAllEqual(normalized_images, color_dropped_images) 305 | 306 | self.assertAllEqual(color_dropped_images[:, :, :, 0], 307 | color_dropped_images[:, :, :, 1]) 308 | self.assertAllEqual(color_dropped_images[:, :, :, 0], 309 | color_dropped_images[:, :, :, 2]) 310 | 311 | def test_space_to_depth(self): 312 | output_frames_1 = processors.space_to_depth(self._frames, 2, 3) 313 | output_frames_2 = processors.space_to_depth(self._frames, 3, 2) 314 | output_frames_3 = processors.space_to_depth( 315 | 
self._frames, spatial_block_size=2) 316 | self.assertAllEqual(output_frames_1.shape, (3, 30, 40, 54)) 317 | self.assertAllEqual(output_frames_2.shape, (2, 45, 60, 36)) 318 | self.assertAllEqual(output_frames_3.shape, (6, 45, 60, 12)) 319 | 320 | def test_crop_or_pad_words(self): 321 | input_words_indices = tf.expand_dims(tf.range(10, dtype=tf.int32), axis=0) 322 | 323 | output_words_indices_1 = processors.crop_or_pad_words( 324 | input_words_indices, 5) 325 | output_words_indices_2 = processors.crop_or_pad_words( 326 | input_words_indices, 15) 327 | self.assertAllEqual(output_words_indices_1, [list(range(5))]) 328 | self.assertAllEqual(output_words_indices_2, 329 | [[i for i in range(10)] + [0] * 5]) 330 | 331 | def test_tokenize(self): 332 | tokenizer = tokenizers.WordTokenizer( 333 | _VOCAB_PATH) # OSS: removed internal filename loading. 334 | tokenizer.initialize() 335 | input_features = {'text': tf.constant(['hello world', 'hello', 'world'])} 336 | 337 | output_features = processors.tokenize(input_features, tokenizer, 'text', 338 | 'indices', False, False, 4, True) 339 | self.assertAllEqual(output_features['text'], 340 | ['hello world', 'hello', 'world']) 341 | self.assertAllEqual(output_features['indices'], 342 | [[4, 5, 0, 0], [4, 0, 0, 0], [5, 0, 0, 0]]) 343 | 344 | 345 | class PostprocessTest(tf.test.TestCase): 346 | 347 | def test_batched_video_transpose(self): 348 | input_tensor = tf.constant([[[1, 2], [3, 4], [5, 6]]]) 349 | output_tensor = processors.batched_video_transpose(input_tensor, (0, 2, 1)) 350 | 351 | self.assertAllEqual(output_tensor, [[[1, 3, 5], [2, 4, 6]]]) 352 | 353 | def test_batched_space_to_depth(self): 354 | input_frames = tf.zeros((8, 30, 150, 210, 3)) 355 | 356 | output_frames_1 = processors.batched_space_to_depth(input_frames, 2, 3) 357 | output_frames_2 = processors.batched_space_to_depth(input_frames, 3, 2) 358 | output_frames_3 = processors.batched_space_to_depth( 359 | input_frames, spatial_block_size=2) 360 | 361 | self.assertAllEqual(output_frames_1.shape, (8, 15, 50, 70, 54)) 362 | self.assertAllEqual(output_frames_2.shape, (8, 10, 75, 105, 36)) 363 | self.assertAllEqual(output_frames_3.shape, (8, 30, 75, 105, 12)) 364 | 365 | 366 | if __name__ == '__main__': 367 | tf.test.main() 368 | -------------------------------------------------------------------------------- /dmvr/sources.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Sources for reading and decoding raw binary data files.""" 16 | 17 | import abc 18 | from typing import Optional, Union 19 | 20 | import tensorflow as tf 21 | 22 | 23 | class Source(abc.ABC): 24 | """Base class for sources. 25 | 26 | Sources are objects reading from binary files and generating an initial 27 | `tf.data.Dataset` with the serialized examples. 
Deserializing the examples is 28 | not responsibility of the `Source` (it should be done by the parser). 29 | 30 | For each different type of storage (e.g. TFRecords, image files, text files), 31 | a subclass can be implemented. 32 | """ 33 | 34 | @abc.abstractmethod 35 | def load_and_decode_shard( 36 | self, 37 | shard: Union[str, tf.Tensor] # Shape () and type `tf.string`. 38 | ) -> tf.data.Dataset: 39 | """Decodes a single raw input file into a `tf.data.Dataset`. 40 | 41 | Args: 42 | shard: Path to a single file with encoded data. 43 | 44 | Returns: 45 | A `tf.data.Dataset` object containing a key (this can be a file name, 46 | index, empty or any other useful bits) and a raw example (both encoded as 47 | bytes). Current supported types of examples are `tf.train.Example` and 48 | `tf.train.SequenceExample` (see `builders.BaseParserBuilder`). 49 | """ 50 | 51 | 52 | class TFRecordsSource(Source): 53 | """Source for TFRecords data format.""" 54 | 55 | def __init__(self, compression_type: Optional[str] = None): 56 | self._compression_type = compression_type 57 | 58 | def load_and_decode_shard( 59 | self, 60 | shard: Union[str, tf.Tensor] # Shape () and type `tf.string`. 61 | ) -> tf.data.Dataset: 62 | ds = tf.data.TFRecordDataset(shard, compression_type=self._compression_type) 63 | # TFRecords do not provide an index or key per example. Use shard path as 64 | # key, since it can be useful later for retrieval. 65 | key = shard.encode('utf-8') if isinstance(shard, str) else shard 66 | ds = ds.map(lambda example: (key, example)) 67 | return ds 68 | -------------------------------------------------------------------------------- /dmvr/sources_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for sources.""" 16 | 17 | import os 18 | 19 | from dmvr import sources 20 | import tensorflow as tf 21 | 22 | 23 | class TFRecordsSourceTest(tf.test.TestCase): 24 | 25 | def setUp(self): 26 | super().setUp() 27 | 28 | self._shard = os.path.join(self.get_temp_dir(), 'shard') 29 | # Generate a TFRecord single shard with one serialized SequenceExample 30 | # in the format ('sequence', [[0], [1], ..., [99]]). 
31 | with tf.io.TFRecordWriter(self._shard) as builder: 32 | self._seq_example = tf.train.SequenceExample() 33 | for i in range(100): 34 | self._seq_example.feature_lists.feature_list.get_or_create( 35 | 'sequence').feature.add().int64_list.value[:] = [i] 36 | builder.write(self._seq_example.SerializeToString()) 37 | 38 | def test_load_and_decode(self): 39 | source = sources.TFRecordsSource() 40 | ds = source.load_and_decode_shard(self._shard) 41 | it = iter(ds) 42 | 43 | data = next(it) 44 | self.assertEqual(data[0], self._shard.encode('utf-8')) 45 | self.assertEqual(data[1], self._seq_example.SerializeToString()) 46 | 47 | with self.assertRaises(StopIteration) as _: 48 | data = next(it) 49 | 50 | def test_input_as_tensor(self): 51 | source = sources.TFRecordsSource() 52 | ds = source.load_and_decode_shard(tf.constant(self._shard)) 53 | it = iter(ds) 54 | 55 | data = next(it) 56 | self.assertEqual(data[0], self._shard.encode('utf-8')) 57 | self.assertEqual(data[1], self._seq_example.SerializeToString()) 58 | 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /dmvr/testdata/sample.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dmvr/77ccedaa084d29239eaeafddb0b2e83843b613a1/dmvr/testdata/sample.jpeg -------------------------------------------------------------------------------- /dmvr/testdata/tokenizers/bert_word_vocab.txt: -------------------------------------------------------------------------------- 1 | [CLS] 2 | [PAD] 3 | [UNK] 4 | [SEP] 5 | hello 6 | world 7 | -------------------------------------------------------------------------------- /dmvr/testdata/tokenizers/spiece.model.1000.model: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dmvr/77ccedaa084d29239eaeafddb0b2e83843b613a1/dmvr/testdata/tokenizers/spiece.model.1000.model -------------------------------------------------------------------------------- /dmvr/testdata/tokenizers/word_vocab.txt: -------------------------------------------------------------------------------- 1 | 0 <pad> 2 | 1 <bos> 3 | 2 <eos> 4 | 3 <unk> 5 | 4 hello 6 | 5 world 7 | -------------------------------------------------------------------------------- /dmvr/tokenizers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
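`sources.Source.load_and_decode_shard` above defines a small contract: given one shard path, return a `tf.data.Dataset` of `(key, serialized_example)` pairs and leave deserialization to the parser. A hedged sketch of a custom source for shards that store one web-safe base64-encoded example per line; the class name and storage format are illustrative assumptions, not part of the library:

```python
import tensorflow as tf
from dmvr import sources


class Base64LinesSource(sources.Source):
  """Reads shards with one web-safe base64-encoded example per line (illustrative)."""

  def load_and_decode_shard(self, shard) -> tf.data.Dataset:
    ds = tf.data.TextLineDataset(shard)
    # As in `TFRecordsSource`, reuse the shard path as the per-example key.
    key = shard.encode('utf-8') if isinstance(shard, str) else shard
    return ds.map(lambda line: (key, tf.io.decode_base64(line)))


# The built-in TFRecords source follows the same contract:
ds = sources.TFRecordsSource().load_and_decode_shard('/tmp/shard.tfrecord')  # Illustrative path.
for key, raw in ds.take(1):
  example = tf.train.SequenceExample.FromString(raw.numpy())
```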
14 | """A simple tokenizer interface with basic implementations.""" 15 | 16 | import abc 17 | from typing import Optional, Sequence, Union 18 | 19 | import clip.simple_tokenizer 20 | import tensorflow as tf 21 | import tensorflow_text 22 | 23 | import sentencepiece as spm 24 | 25 | 26 | class TextTokenizer(abc.ABC): 27 | """Base class for text tokenizers.""" 28 | 29 | def initialize(self): 30 | """Initializes tensorflow tables and models.""" 31 | return 32 | 33 | @abc.abstractmethod 34 | def string_tensor_to_indices(self, 35 | string_tensor: Union[tf.Tensor, Sequence[str]], 36 | prepend_bos: bool = False, 37 | append_eos: bool = False, 38 | max_num_tokens: Optional[int] = 32) -> tf.Tensor: 39 | """Tokenizes input text, mapping a tensor of strings to a tensor of ints. 40 | 41 | Args: 42 | string_tensor: Input string tensor of shape [num_texts]. 43 | prepend_bos: Whether to prepend the BOS (beginning of sentence) token to 44 | the output tokens. 45 | append_eos: Whether to append the EOS (end of sentence) token to the 46 | output tokens. 47 | max_num_tokens: Maximum number of tokens to return per caption. If 48 | provided, the tokens will be padded / cut at the given size. If not, a 49 | tensor of unknown size will be returned. 50 | 51 | Returns: 52 | A `tf.int32` tensor of shape [num_texts, `max_num_tokens`] if 53 | `max_num_tokens` is provided or [num_texts, max_num_tokens_in_batch] 54 | otherwise. 55 | """ 56 | 57 | @abc.abstractmethod 58 | def indices_to_string(self, indices: Sequence[int]) -> str: 59 | """Detokenizes, mapping a python sequence of indices to a string.""" 60 | 61 | @property 62 | @abc.abstractmethod 63 | def vocab_size(self) -> int: 64 | """Returns the vocabulary size.""" 65 | 66 | @property 67 | @abc.abstractmethod 68 | def pad_token(self) -> int: 69 | """Returns index of the PAD token.""" 70 | 71 | @property 72 | @abc.abstractmethod 73 | def bos_token(self) -> int: 74 | """Returns index of the BOS token.""" 75 | 76 | @property 77 | @abc.abstractmethod 78 | def eos_token(self) -> int: 79 | """Returns index of the EOS token.""" 80 | 81 | @property 82 | @abc.abstractmethod 83 | def unk_token(self) -> int: 84 | """Returns index of the UNK token.""" 85 | 86 | 87 | class SentencePieceTokenizer(TextTokenizer): 88 | """SentencePiece tokenizer from a pre-trained SentencePiece model. 89 | 90 | Pre-trained models are provided in multiple repositories around the web. See 91 | https://github.com/google/sentencepiece for info on how to train new models on 92 | specific corpus. 93 | """ 94 | 95 | def __init__(self, model_path: str): 96 | """Initializes the `SentencePieceTokenizer`. 97 | 98 | Args: 99 | model_path: Path to the '.model' file. 
100 | """ 101 | self._model_path = model_path 102 | self._sp_model = spm.SentencePieceProcessor() 103 | self._sp_model.Load(model_path) 104 | 105 | self._vocab_size = self._sp_model.GetPieceSize() 106 | self._bos_token = self._sp_model.bos_id() 107 | self._eos_token = self._sp_model.eos_id() 108 | self._pad_token = self._sp_model.pad_id() 109 | self._unk_token = self._sp_model.unk_id() 110 | 111 | self._tf_sp_model = None 112 | 113 | def initialize(self): 114 | with tf.io.gfile.GFile(self._model_path, 'rb') as f: 115 | self._tf_sp_model = tensorflow_text.SentencepieceTokenizer( 116 | model=f.read(), out_type=tf.int32, add_bos=True, add_eos=True) 117 | 118 | def string_tensor_to_indices(self, 119 | string_tensor: Union[tf.Tensor, Sequence[str]], 120 | prepend_bos: bool = False, 121 | append_eos: bool = False, 122 | max_num_tokens: Optional[int] = 32) -> tf.Tensor: 123 | if self._tf_sp_model is None: 124 | raise RuntimeError('Model was not initialized. Call `initialize` method.') 125 | 126 | tokenized = self._tf_sp_model.tokenize(string_tensor) 127 | tokenized = tokenized if prepend_bos else tokenized[..., 1:] 128 | tokenized = tokenized if append_eos else tokenized[..., :-1] 129 | 130 | # Pad to `max_num_tokens`. 131 | shape = None if max_num_tokens is None else [None, max_num_tokens] 132 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape) 133 | return tokenized 134 | 135 | def indices_to_string(self, indices: Sequence[int]) -> str: 136 | return self._sp_model.DecodeIds(indices) 137 | 138 | def string_to_indices(self, 139 | string: str, 140 | prepend_bos: bool = False, 141 | append_eos: bool = False, 142 | max_num_tokens: Optional[int] = 32) -> Sequence[int]: 143 | """Tokenizes, mapping a python string to a sequence of indices.""" 144 | tokenized = self._sp_model.EncodeAsIds(string) 145 | tokenized = [self._bos_token] * prepend_bos + tokenized 146 | tokenized += [self._eos_token] * append_eos 147 | if max_num_tokens: 148 | tokenized = tokenized[:max_num_tokens] 149 | num_tokens = len(tokenized) 150 | tokenized = tokenized + [self._pad_token] * (max_num_tokens - num_tokens) 151 | return tokenized 152 | 153 | @property 154 | def vocab_size(self): 155 | return self._vocab_size 156 | 157 | @property 158 | def pad_token(self): 159 | return self._pad_token 160 | 161 | @property 162 | def bos_token(self): 163 | return self._bos_token 164 | 165 | @property 166 | def eos_token(self): 167 | return self._eos_token 168 | 169 | @property 170 | def unk_token(self): 171 | return self._unk_token 172 | 173 | 174 | class WordTokenizer(TextTokenizer): 175 | """Vocabulary based word tokenizer.""" 176 | 177 | PAD = '' 178 | BOS = '' 179 | EOS = '' 180 | UNK = '' 181 | 182 | def __init__(self, vocabulary_path: str): 183 | """Initializes the `WordTokenizer`. 184 | 185 | Args: 186 | vocabulary_path: A path to a vocabulary file. The vocabulary is a simple 187 | text file where each line is of the form: 'idx_word word' or simply 188 | 'word' (the line index will be used). The vocabulary should at least 189 | contain the words: '', '', '' and ''. 190 | """ 191 | # Parse the vocabulary. The expected format is either one word per line (and 192 | # the index for that word will be the line index) or an index and a word, 193 | # split by space. 
194 | idx2word = {} 195 | with tf.io.gfile.GFile(vocabulary_path) as f: 196 | for line_idx, line in enumerate(f): 197 | line = line.strip().split(' ') 198 | 199 | if len(line) not in [1, 2]: 200 | raise ValueError(f'Line {line_idx} of vocabulary file, with contents ' 201 | f'\'{line}\' is malformed') 202 | 203 | idx, word = line if len(line) == 2 else (line_idx, line[0]) 204 | idx = int(idx) 205 | 206 | if idx in idx2word: 207 | raise ValueError( 208 | f'Vocabulary contains two words with same index {idx}.') 209 | if word != word.lower(): 210 | raise ValueError(f'Word {word} with index {idx} is not lower case.') 211 | 212 | idx2word[idx] = word 213 | 214 | # Validate. 215 | if len(idx2word) != len(set(idx2word.values())): 216 | raise ValueError('Words in vocabulary are not unique.') 217 | basic_tokens = {self.PAD, self.BOS, self.EOS, self.UNK} 218 | if basic_tokens & set(idx2word.values()) != basic_tokens: 219 | raise ValueError( 220 | f'Vocabulary does not contain all basic tokens {basic_tokens}.') 221 | 222 | self._idx2word = idx2word 223 | self._word2idx = {v: k for k, v in idx2word.items()} 224 | 225 | self._vocab_size = len(idx2word) 226 | self._pad_token = self._word2idx[self.PAD] 227 | self._bos_token = self._word2idx[self.BOS] 228 | self._eos_token = self._word2idx[self.EOS] 229 | self._unk_token = self._word2idx[self.UNK] 230 | 231 | self._tf_word2idx = None 232 | self._tf_whitespace_tokenizer = None 233 | 234 | def initialize(self): 235 | ids_tensor = tf.constant([i for w, i in self._word2idx.items()], 236 | dtype=tf.int32) 237 | words_tensor = tf.constant([w for w, i in self._word2idx.items()], 238 | dtype=tf.string) 239 | self._tf_whitespace_tokenizer = tensorflow_text.WhitespaceTokenizer() 240 | self._tf_word2idx = tf.lookup.StaticHashTable( 241 | tf.lookup.KeyValueTensorInitializer(words_tensor, ids_tensor), 242 | self._unk_token) 243 | 244 | def string_tensor_to_indices(self, 245 | string_tensor: Union[tf.Tensor, Sequence[str]], 246 | prepend_bos: bool = False, 247 | append_eos: bool = False, 248 | max_num_tokens: Optional[int] = 32) -> tf.Tensor: 249 | if self._tf_word2idx is None or self._tf_whitespace_tokenizer is None: 250 | raise RuntimeError('Model was not initialized. Call `initialize` method.') 251 | 252 | # Remove punctuation. 253 | string_tensor = tf.strings.regex_replace(string_tensor, '[[:punct:]]', '') 254 | # Lower case. 255 | string_tensor = tf.strings.lower(string_tensor) 256 | if prepend_bos: 257 | string_tensor = self.BOS.encode('utf-8') + b' ' + string_tensor 258 | if append_eos: 259 | string_tensor += b' ' + self.EOS.encode('utf-8') 260 | 261 | # Separate words by whitespace. 262 | tokenized = self._tf_whitespace_tokenizer.tokenize(string_tensor) 263 | # Map word to indices. 264 | tokenized = self._tf_word2idx.lookup(tokenized) 265 | # Pad to `max_num_tokens`. 266 | shape = None if max_num_tokens is None else [None, max_num_tokens] 267 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape) 268 | return tokenized 269 | 270 | def indices_to_string(self, indices: Sequence[int]) -> str: 271 | # Cut at `EOS` or `PAD`. 272 | idx_list_cut = [] 273 | for token_id in indices: 274 | if token_id in [self._pad_token, self._eos_token]: 275 | break 276 | idx_list_cut.append(token_id) 277 | 278 | # Decode back to string. 
279 | words_list = [self._idx2word[idx] for idx in idx_list_cut] 280 | return ' '.join(words_list) 281 | 282 | def string_to_indices(self, 283 | string: str, 284 | prepend_bos: bool = False, 285 | append_eos: bool = False, 286 | max_num_tokens: Optional[int] = 32) -> Sequence[int]: 287 | """Tokenizes, mapping a python string to a sequence of indices.""" 288 | string = string.translate( 289 | str.maketrans('', '', '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')) 290 | string = string.lower() 291 | words = string.split(' ') 292 | tokenized = [self._word2idx.get(w, self._unk_token) for w in words] 293 | tokenized = [self._bos_token] * prepend_bos + tokenized 294 | tokenized += [self._eos_token] * append_eos 295 | if max_num_tokens: 296 | tokenized = tokenized[:max_num_tokens] 297 | num_tokens = len(tokenized) 298 | tokenized = tokenized + [self._pad_token] * (max_num_tokens - num_tokens) 299 | return tokenized 300 | 301 | @property 302 | def vocab_size(self): 303 | return self._vocab_size 304 | 305 | @property 306 | def pad_token(self): 307 | return self._pad_token 308 | 309 | @property 310 | def bos_token(self): 311 | return self._bos_token 312 | 313 | @property 314 | def eos_token(self): 315 | return self._eos_token 316 | 317 | @property 318 | def unk_token(self): 319 | return self._unk_token 320 | 321 | 322 | class BertTokenizer(TextTokenizer): 323 | """BERT tokenizer. 324 | 325 | Standard BERT vocabularies can be found in tf hub. 326 | """ 327 | 328 | PAD = '[PAD]' 329 | CLS = '[CLS]' 330 | SEP = '[SEP]' 331 | BOS = CLS 332 | EOS = SEP 333 | UNK = '[UNK]' 334 | 335 | def __init__(self, vocabulary_path: str): 336 | """Initializes the `BertTokenizer`. 337 | 338 | Args: 339 | vocabulary_path: A path to a vocabulary file. The vocabulary is a simple 340 | text file where each line is of the form: 'token'. The vocabulary should 341 | at least contain the words: '[PAD]', '[CLS]', '[SEP]' and '[UNK]'. 342 | """ 343 | # Parse the vocabulary. 344 | idx2word = {} 345 | self._vocabulary_path = vocabulary_path 346 | with tf.io.gfile.GFile(vocabulary_path) as f: 347 | for idx, line in enumerate(f): 348 | word = line.strip() 349 | idx2word[idx] = word 350 | 351 | # Validate. 352 | if len(idx2word) != len(set(idx2word.values())): 353 | raise ValueError('Words in vocabulary are not unique.') 354 | basic_tokens = {self.PAD, self.BOS, self.EOS, self.UNK} 355 | if basic_tokens & set(idx2word.values()) != basic_tokens: 356 | raise ValueError( 357 | f'Vocabulary does not contain all basic tokens {basic_tokens}.') 358 | 359 | self._idx2word = idx2word 360 | self._word2idx = {v: k for k, v in idx2word.items()} 361 | 362 | self._vocab_size = len(idx2word) 363 | self._pad_token = self._word2idx[self.PAD] 364 | self._bos_token = self._word2idx[self.BOS] 365 | self._eos_token = self._word2idx[self.EOS] 366 | self._unk_token = self._word2idx[self.UNK] 367 | 368 | self._tf_tokenizer = None 369 | 370 | def initialize(self): 371 | self._tf_tokenizer = tensorflow_text.BertTokenizer( 372 | self._vocabulary_path, 373 | token_out_type=tf.int32, 374 | unknown_token=self.UNK, 375 | lower_case=True) 376 | 377 | def string_tensor_to_indices(self, 378 | string_tensor: Union[tf.Tensor, Sequence[str]], 379 | prepend_bos: bool = False, 380 | append_eos: bool = False, 381 | max_num_tokens: Optional[int] = 32) -> tf.Tensor: 382 | if self._tf_tokenizer is None: 383 | raise RuntimeError('Model was not initialized. 
Call `initialize` method.') 384 | 385 | batch_size = tf.shape(input=string_tensor)[0] 386 | tokenized = self._tf_tokenizer.tokenize(string_tensor) 387 | tokenized = tokenized.merge_dims(-2, -1) 388 | 389 | if append_eos: 390 | eos_tensor = tf.ragged.constant([self._eos_token]) 391 | eos_tensor = tf.tile(eos_tensor, [batch_size]) 392 | eos_tensor = tf.expand_dims(eos_tensor, axis=1) 393 | tokenized = tf.concat([tokenized, eos_tensor], axis=1) 394 | if prepend_bos: 395 | bos_tensor = tf.ragged.constant([self._bos_token]) 396 | bos_tensor = tf.tile(bos_tensor, [batch_size]) 397 | bos_tensor = tf.expand_dims(bos_tensor, axis=1) 398 | tokenized = tf.concat([bos_tensor, tokenized], axis=1) 399 | 400 | # Pad to `max_num_tokens`. 401 | shape = None if max_num_tokens is None else [None, max_num_tokens] 402 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape) 403 | return tokenized 404 | 405 | def indices_to_string(self, indices: Sequence[int]) -> str: 406 | # Cut at `EOS` or `PAD`. 407 | idx_list_cut = [] 408 | for token_id in indices: 409 | if token_id in [self._pad_token, self._eos_token]: 410 | break 411 | idx_list_cut.append(token_id) 412 | 413 | # Decode back to string. 414 | word_iter = (self._idx2word[idx] for idx in idx_list_cut) 415 | return ' '.join(word_iter).replace(' ##', '') 416 | 417 | @property 418 | def vocab_size(self): 419 | return self._vocab_size 420 | 421 | @property 422 | def pad_token(self): 423 | return self._pad_token 424 | 425 | @property 426 | def bos_token(self): 427 | return self._bos_token 428 | 429 | @property 430 | def eos_token(self): 431 | return self._eos_token 432 | 433 | @property 434 | def unk_token(self): 435 | return self._unk_token 436 | 437 | @property 438 | def cls_token(self): 439 | return self._bos_token 440 | 441 | @property 442 | def sep_token(self): 443 | return self._eos_token 444 | 445 | 446 | class ClipTokenizer(TextTokenizer): 447 | """CLIP tokenizer.""" 448 | 449 | BOS = '<|startoftext|>' 450 | EOS = '<|endoftext|>' 451 | UNK = EOS 452 | 453 | def __init__( 454 | self, 455 | vocabulary_path: Optional[str] = None, 456 | ) -> None: 457 | """Initializes the `ClipTokenizer`. 458 | 459 | Args: 460 | vocabulary_path: A path to a CLIP-style vocabulary file. 461 | """ 462 | self._tokenizer = clip.simple_tokenizer.SimpleTokenizer(vocabulary_path) 463 | 464 | self._vocab_size = len(self._tokenizer.encoder) 465 | self._pad_token = 0 466 | self._bos_token = self._tokenizer.encoder[self.BOS] 467 | self._eos_token = self._tokenizer.encoder[self.EOS] 468 | self._unk_token = self._tokenizer.encoder[self.UNK] 469 | 470 | self._initialized = False 471 | 472 | def initialize(self) -> None: 473 | self._initialized = True 474 | 475 | def _clip_tokenize(self, texts: Union[tf.Tensor, 476 | Sequence[str]]) -> tf.RaggedTensor: 477 | if isinstance(texts, tf.Tensor): 478 | texts = [text.decode('utf-8') for text in texts._numpy().tolist()] # pylint: disable=protected-access 479 | return tf.ragged.constant([self._tokenizer.encode(text) for text in texts], 480 | dtype=tf.int32) 481 | 482 | def string_tensor_to_indices(self, 483 | string_tensor: Union[tf.Tensor, Sequence[str]], 484 | prepend_bos: bool = False, 485 | append_eos: bool = False, 486 | max_num_tokens: Optional[int] = 77) -> tf.Tensor: 487 | if not self._initialized: # To satisfy the tests. 488 | raise RuntimeError('Model was not initialized. 
Call `initialize` method.') 489 | 490 | batch_size = tf.shape(input=string_tensor)[0] 491 | 492 | tokenized = tf.py_function( 493 | func=self._clip_tokenize, 494 | inp=[string_tensor], 495 | Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32)) 496 | 497 | if append_eos: 498 | eos_tensor = tf.ragged.constant([self._eos_token]) 499 | eos_tensor = tf.tile(eos_tensor, [batch_size]) 500 | eos_tensor = tf.expand_dims(eos_tensor, axis=1) 501 | tokenized = tf.concat([tokenized, eos_tensor], axis=1) 502 | if prepend_bos: 503 | bos_tensor = tf.ragged.constant([self._bos_token]) 504 | bos_tensor = tf.tile(bos_tensor, [batch_size]) 505 | bos_tensor = tf.expand_dims(bos_tensor, axis=1) 506 | tokenized = tf.concat([bos_tensor, tokenized], axis=1) 507 | 508 | # Pad to `max_num_tokens`. 509 | shape = None if max_num_tokens is None else [None, max_num_tokens] 510 | return tokenized.to_tensor(default_value=self._pad_token, shape=shape) 511 | 512 | def indices_to_string(self, indices: Sequence[int]) -> str: 513 | text = self._tokenizer.decode(i for i in indices if i != self._pad_token) 514 | start_pos = len(self.BOS) if text.startswith(self.BOS) else 0 515 | end_pos = text.index(self.EOS) if self.EOS in text else None 516 | return text[start_pos:end_pos].strip() 517 | 518 | @property 519 | def vocab_size(self) -> int: 520 | return self._vocab_size 521 | 522 | @property 523 | def pad_token(self) -> int: 524 | return self._pad_token 525 | 526 | @property 527 | def bos_token(self) -> int: 528 | return self._bos_token 529 | 530 | @property 531 | def eos_token(self) -> int: 532 | return self._eos_token 533 | 534 | @property 535 | def unk_token(self) -> int: 536 | return self._unk_token 537 | -------------------------------------------------------------------------------- /dmvr/tokenizers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for tokenizers.""" 16 | 17 | from __future__ import annotations 18 | 19 | from collections.abc import Sequence 20 | import os 21 | from typing import Type, TypeVar 22 | 23 | import clip.simple_tokenizer 24 | from dmvr import tokenizers 25 | from parameterized import parameterized 26 | import tensorflow as tf 27 | 28 | # Removed: Internal pyglib dependencies 29 | 30 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata') 31 | _MOCK_DATA = os.path.join(_TESTDATA_DIR, 'tokenizers') 32 | 33 | _FILENAMES = { 34 | tokenizers.SentencePieceTokenizer: 'spiece.model.1000.model', 35 | tokenizers.WordTokenizer: 'word_vocab.txt', 36 | tokenizers.BertTokenizer: 'bert_word_vocab.txt', 37 | tokenizers.ClipTokenizer: clip.simple_tokenizer.default_bpe(), 38 | } 39 | 40 | T = TypeVar('T', bound=tokenizers.TextTokenizer) 41 | 42 | 43 | def _get_tokenizer(cls: Type[T]) -> T: 44 | filename = _FILENAMES[cls] 45 | path = os.path.join(_MOCK_DATA, filename) # OSS: removed internal filename loading. 46 | return cls(path) 47 | 48 | 49 | def _tokenize_with_original_clip( 50 | texts: str | Sequence[str], 51 | context_length: int = 77) -> Sequence[Sequence[int]]: 52 | # Code adapted from `clip.tokenize` because it's not importable (only 53 | # `clip.simple_tokenizer` is). 54 | 55 | if isinstance(texts, str): 56 | texts = [texts] 57 | 58 | tokenizer = clip.simple_tokenizer.SimpleTokenizer() 59 | sot_token = tokenizer.encoder['<|startoftext|>'] 60 | eot_token = tokenizer.encoder['<|endoftext|>'] 61 | all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] 62 | for text in texts] 63 | result = [] 64 | 65 | for i, tokens in enumerate(all_tokens): 66 | if len(tokens) > context_length: 67 | raise RuntimeError(f'Input {texts[i]} is too long for context length' 68 | f' {context_length}') 69 | result.append(tokens + [0] * (context_length - len(tokens))) 70 | 71 | return result 72 | 73 | 74 | def _decode_with_original_clip(tokens_ids: Sequence[int]) -> str: 75 | tokenizer = clip.simple_tokenizer.SimpleTokenizer() 76 | text = tokenizer.decode(tokens_ids) 77 | 78 | eos = '<|endoftext|>' 79 | return text[:text.index(eos) + len(eos)] 80 | 81 | 82 | class TokenizerTest(tf.test.TestCase): 83 | 84 | @parameterized.expand( 85 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,), 86 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,))) 87 | def test_tokenizer(self, cls): 88 | tokenizer = _get_tokenizer(cls) 89 | tokenizer.initialize() 90 | input_string = ['hello world'] 91 | 92 | tokenized = tokenizer.string_tensor_to_indices( 93 | input_string, max_num_tokens=42) 94 | self.assertEqual(tokenized.dtype, tf.int32) 95 | 96 | tokenized = tokenized.numpy().tolist()[0] 97 | self.assertLen(tokenized, 42) 98 | self.assertEqual(tokenized[-1], tokenizer.pad_token) 99 | 100 | detokenized = tokenizer.indices_to_string(tokenized) 101 | self.assertEqual(detokenized, 'hello world') 102 | 103 | @parameterized.expand( 104 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,), 105 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,))) 106 | def test_bos_eos(self, cls): 107 | tokenizer = _get_tokenizer(cls) 108 | tokenizer.initialize() 109 | input_string = ['hello world'] 110 | 111 | tokenized = tokenizer.string_tensor_to_indices( 112 | input_string, prepend_bos=True, append_eos=True) 113 | tokenized = tokenized.numpy().tolist()[0] 114 | self.assertEqual(tokenized[0], tokenizer.bos_token) 115 | 116 | if tokenizer.pad_token != tokenizer.eos_token: 117 | tokenized = [t for t 
in tokenized if t != tokenizer.pad_token] 118 | self.assertEqual(tokenized[-1], tokenizer.eos_token) 119 | 120 | @parameterized.expand( 121 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,), 122 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,))) 123 | def test_not_initialized(self, cls): 124 | tokenizer = _get_tokenizer(cls) 125 | input_string = ['hello world'] 126 | 127 | with self.assertRaises(RuntimeError): 128 | tokenizer.string_tensor_to_indices(input_string) 129 | 130 | @parameterized.expand(( 131 | (tokenizers.WordTokenizer,), 132 | (tokenizers.SentencePieceTokenizer,), 133 | )) 134 | def test_string_to_indices(self, cls): 135 | tokenizer = _get_tokenizer(cls) 136 | tokenizer.initialize() 137 | input_string = 'hello world' 138 | tokenized = tokenizer.string_to_indices( 139 | input_string, prepend_bos=True, append_eos=True, max_num_tokens=42) 140 | self.assertEqual(type(tokenized), list) 141 | self.assertEqual(tokenized[0], tokenizer.bos_token) 142 | tokenized = [t for t in tokenized if t != tokenizer.pad_token] 143 | self.assertEqual(tokenized[-1], tokenizer.eos_token) 144 | 145 | detokenized = tokenizer.indices_to_string(tokenized[1:-1]) 146 | self.assertEqual(detokenized, 'hello world') 147 | 148 | def test_clip_tokenizer(self): 149 | tokenizer = _get_tokenizer(tokenizers.ClipTokenizer) 150 | tokenizer.initialize() 151 | input_string = ['This is a test.', 'pushups'] 152 | actual_tokenized_tf = tokenizer.string_tensor_to_indices( 153 | input_string, prepend_bos=True, append_eos=True, max_num_tokens=77) 154 | 155 | expected_tokenized = _tokenize_with_original_clip(input_string) 156 | 157 | actual_tokenized1 = actual_tokenized_tf.numpy().tolist()[0] 158 | expected_tokenized1 = expected_tokenized[0] 159 | self.assertEqual(actual_tokenized1, expected_tokenized1) 160 | 161 | actual_decoded = tokenizer.indices_to_string(actual_tokenized1) 162 | self.assertEqual(actual_decoded, 'this is a test .') 163 | 164 | actual_tokenized2 = actual_tokenized_tf.numpy().tolist()[1] 165 | expected_tokenized2 = expected_tokenized[1] 166 | self.assertEqual(actual_tokenized2, expected_tokenized2) 167 | 168 | actual_decoded = tokenizer.indices_to_string(actual_tokenized2) 169 | self.assertEqual(actual_decoded, input_string[1]) 170 | 171 | 172 | if __name__ == '__main__': 173 | tf.test.main() 174 | -------------------------------------------------------------------------------- /dmvr/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utils.""" 16 | 17 | from typing import Optional, Sequence 18 | 19 | import tensorflow as tf 20 | 21 | 22 | # ---------------------------------------------------------------------- 23 | # ------------------------ Experimental utils. 
------------------------- 24 | # ---------------------------------------------------------------------- 25 | 26 | 27 | def combine_datasets(datasets: Sequence[tf.data.Dataset], 28 | batch_size: int = 1, 29 | weights: Optional[Sequence[float]] = None, 30 | seed: Optional[int] = None) -> tf.data.Dataset: 31 | """Combines multiple datasets into a single one. 32 | 33 | THIS IS AN EXPERIMENTAL FEATURE AND MIGHT BE REMOVED AT ANY TIME. 34 | 35 | This function combines multiple datasets into a single one by sampling 36 | elements from each one with the given probabilities. All input datasets must 37 | have the same structure and Tensor shapes. 38 | 39 | Args: 40 | datasets: A list of batched datasets. All datasets should have the same 41 | structure and Tensor shapes. 42 | batch_size: Batch size of the resulting dataset. 43 | weights: A list of the same length as datasets of floats where `weights[i]` 44 | represents the probability with which an element should be sampled from 45 | `datasets[i]`. If `None`, defaults to a uniform distribution across 46 | datasets. 47 | seed: A deterministic seed to use when sampling. 48 | 49 | Returns: 50 | A dataset that interleaves elements from datasets at random, according to 51 | weights if provided, otherwise with uniform probability. The resulting 52 | dataset is batched. 53 | """ 54 | datasets = [ds.unbatch() for ds in datasets] 55 | combined_ds = tf.data.experimental.sample_from_datasets( 56 | datasets, weights, seed) 57 | return combined_ds.batch(batch_size, drop_remainder=True) 58 | -------------------------------------------------------------------------------- /dmvr/utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for utils.""" 16 | 17 | from dmvr import utils 18 | import tensorflow as tf 19 | 20 | 21 | class UtilsTest(tf.test.TestCase): 22 | 23 | def test_combine_datasets(self): 24 | ds_0 = tf.data.Dataset.from_tensor_slices({ 25 | 'feature_0': [[[0] * 10] * 10] * 5, 26 | 'feature_1': [[0] * 10] * 5, 27 | }) 28 | ds_1 = tf.data.Dataset.from_tensor_slices({ 29 | 'feature_0': [[[1] * 10] * 10] * 5, 30 | 'feature_1': [[1] * 10] * 5, 31 | }) 32 | ds_2 = tf.data.Dataset.from_tensor_slices({ 33 | 'feature_0': [[[2] * 10] * 10] * 5, 34 | 'feature_1': [[2] * 10] * 5, 35 | }) 36 | 37 | # Dataset uniformly sampling from all 3 datasets. 38 | ds_uniform = utils.combine_datasets([ds_0, ds_1, ds_2], 7) 39 | data_uniform = next(iter(ds_uniform)) 40 | 41 | # Dataset sampling from ds_1 and ds_2. 
42 | ds_no_1 = utils.combine_datasets([ds_0, ds_1, ds_2], 7, [0.5, 0, 0.5]) 43 | data_no_1 = next(iter(ds_no_1)) 44 | 45 | self.assertSetEqual(set(data_uniform.keys()), 46 | set(['feature_0', 'feature_1'])) 47 | self.assertAllEqual(data_uniform['feature_0'].shape, (7, 10)) 48 | self.assertAllEqual(data_uniform['feature_1'].shape, (7,)) 49 | 50 | self.assertSetEqual(set(data_no_1.keys()), 51 | set(['feature_0', 'feature_1'])) 52 | self.assertAllEqual(data_no_1['feature_0'].shape, (7, 10)) 53 | self.assertAllEqual(data_no_1['feature_1'].shape, (7,)) 54 | 55 | self.assertAllInSet(data_uniform['feature_0'], (0, 1, 2)) 56 | self.assertAllInSet(data_uniform['feature_1'], (0, 1, 2)) 57 | self.assertAllInSet(data_no_1['feature_0'], (0, 2)) 58 | self.assertAllInSet(data_no_1['feature_1'], (0, 2)) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /dmvr/video_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Basic constructors for video datasets.""" 15 | 16 | import abc 17 | from typing import Any, List, Optional, Type, TypeVar 18 | 19 | from absl import logging 20 | from dmvr import builders 21 | from dmvr import sources 22 | import tensorflow as tf 23 | 24 | # Types. 25 | T = TypeVar('T', bound=builders.BaseParserBuilder) 26 | NestedStructure = Any 27 | 28 | 29 | class BaseVideoDatasetFactory(abc.ABC): 30 | """Base class to build final `tf.data.Dataset` objects from files. 31 | 32 | Glossary: 33 | 34 | - A source is an object reading binary files in disk (e.g. TFRecords, image 35 | files) and outputting serialized examples (e.g. `tf.train.SequenceExample`). 36 | - A parser is an object reading serialized examples (e.g. 37 | `tf.train.SequenceExample`) and outputting a `builders.FeaturesDict`. 38 | - A processor is an object transforming features dictionary. 39 | - The data processing pipeline is organised in phases. A phase is an unit of 40 | the data processing graph and will have one parser or processor. 41 | - Builders are helpers designed to allow the user to easily customize the data 42 | processing graph by adding functions to each phase. 43 | 44 | Principle: 45 | 46 | All datasets created with this factory follow the same abstraction: 47 | a `parse_fn`, a `sample_fn`, a `decode_fn`, a `preprocess_fn` and a 48 | `postprocess_fn` are used to control the flow of dataset creation besides 49 | normal dataset operations. These functions are created from builders, allowing 50 | the user to build a graph of data processing operations. In details, the 51 | following steps are followed when creating a dataset: 52 | - Read shards from file system using the given `source.Source`. 53 | - Apply `parse_fn` to output values of the `source` (as bytes) to build a 54 | dictionary of raw features. 
The parse function should only parse the useful 55 | bytes of the serialized input example (e.g. `tf.train.SequenceExample`) and 56 | put the features in a `builders.FeaturesDict` format. `parser_builder` can 57 | be used to easily add more features / modalities. 58 | - Apply `sample_fn` to sequence features contained in the dictionary in 59 | order to select the desired elements of the sequence, e.g. sample a subset 60 | of frames from the entire stored video. `sampler_builder` can be used to 61 | modify or add sampling options. 62 | - Apply `decode_fn` to convert raw formats to the final format. E.g. decode 63 | JPEG string `tf.Tensor` to a `tf.Tensor` of `uint8`. `decoder_builder` can 64 | be used. 65 | - Apply `preprocess_fn`. E.g. crop images, process audio and text. 66 | `preprocessor_builder` can be used. 67 | - Batch, shuffle, prefetch and do other basic operations with the dataset. 68 | - Apply `postprocess_fn` to batched examples. E.g. transpose batches. 69 | `postprocessor_builder` can be used. 70 | 71 | After each one of the data processing functions, a filter is applied in order 72 | to keep only desirable elements in the dataset. These filters can be 73 | customized by using the `filter_builder`. 74 | 75 | A conventional use of this factory consists of implementing a subclass for a 76 | specific dataset, overriding the `_build` method where all common processing 77 | of the specific dataset can be added using the builders. 78 | 79 | The client of the dataset is able to create a factory, configure it, possibly 80 | add custom extra processing steps and use it to make a dataset. 81 | 82 | Usage: 83 | 84 | ```python 85 | class KineticsFactory(BaseVideoDatasetFactory): 86 | 87 | def __init__(self, subset: str): 88 | shards = ['path/to/kinetics/tfrecords/records-00001-of-00500.tfrecord', 89 | ...] 90 | shards = filter_by_subset(shards, subset) 91 | super().__init__(shards) 92 | 93 | def _build(self, frame_height: int, frame_width: int, frame_count: int): 94 | self.parser_builder.parse_feature( 95 | image_seq_example_feature_name, 96 | tf.io.FixedLenSequenceFeature((), dtype=tf.string), 97 | builders.IMAGE_FEATURE_NAME) 98 | self.sampler_builder.add_fn( 99 | lambda x: sample_sequence_fn(x, frame_count), 100 | builders.IMAGE_FEATURE_NAME) 101 | self.decoder_builder.add_fn(decode_frames_fn, builders.IMAGE_FEATURE_NAME) 102 | self.preprocessor_builder.add_fn( 103 | lambda x: resize_frames(x, frame_height, frame_width), 104 | builders.IMAGE_FEATURE_NAME) 105 | # Other processing functions adding text and label. 106 | 107 | # Dataset client code: 108 | factory = KineticsFactory(subset='test').configure( 109 | frame_height=224, frame_width=224, frame_count=8) 110 | 111 | # Add extra custom preprocess functions: 112 | def my_custom_text_tokenizer(text: tf.Tensor) -> tf.Tensor: 113 | # Tokenize text string. 
114 | return tokenized_tensor 115 | 116 | def my_custom_add_word_indices( 117 | features_dict: builders.FeaturesDict) -> builders.FeaturesDict: 118 | tokenized_text = features_dict[builders.TEXT_FEATURE_NAME] 119 | features_dict[builders.TEXT_INDICES_FEATURE_NAME] = text_to_indices( 120 | tokenized_text) 121 | return features_dict 122 | 123 | (factory.preprocessor_builder.add_fn(my_custom_text_tokenizer, 124 | builders.TEXT_FEATURE_NAME) 125 | .add_fn(my_custom_add_word_indices)) 126 | 127 | # Add filter: 128 | def keep_only_label_zero(features_dict: builders.FeaturesDict) -> tf.Tensor: 129 | return tf.equal(features_dict[builders.LABEL_INDEX_FEATURE_NAME], 0) 130 | factory.filter_builder.add_filter_fn( 131 | keep_only_label_zero, builders.Phase.PARSE) 132 | 133 | # Create dataset: 134 | ds = factory.make_dataset(batch_size=16) 135 | ``` 136 | 137 | The factory exposes the process function builders to the client, allowing 138 | simple modifications to the functions. Common process functions, such as crop, 139 | resize, etc., should be implemented in common modules. 140 | 141 | See builders documentation for more details. 142 | """ 143 | 144 | def __init__(self, 145 | shards: List[str], 146 | parser_builder_class: Type[T] = builders 147 | .SequenceExampleParserBuilder, 148 | source: sources.Source = sources.TFRecordsSource()): 149 | """Initializes the `BaseVideoDatasetFactory`. 150 | 151 | Args: 152 | shards: List of paths to shards containing the data files. Each one of the 153 | paths will be passed to the `source`, which will read the data and output 154 | examples (that will be fed into the parse function generated by the 155 | `parser_builder_class`). Therefore, `shards`, `parser_builder_class` and 156 | `source` have to be consistent. 157 | parser_builder_class: A parser builder class able to parse examples of the 158 | types contained in `shards` files. 159 | source: Source to be used to load raw binary files and decode them into 160 | examples (encoded as bytes). 161 | """ 162 | 163 | self._shards = shards 164 | self._source = source 165 | 166 | # Initialize all function builders. 167 | self.parser_builder = parser_builder_class() 168 | self.sampler_builder = builders.SamplerBuilder() 169 | self.decoder_builder = builders.DecoderBuilder() 170 | self.preprocessor_builder = builders.PreprocessorBuilder() 171 | self.postprocessor_builder = builders.PostprocessorBuilder() 172 | 173 | # Initialize filters. 174 | self.filter_builder = builders.FilterBuilder() 175 | 176 | # Default tune parameters. 177 | self._shuffle_buffer = 256 178 | self._num_parser_threads = 16 179 | self._num_process_threads = tf.data.experimental.AUTOTUNE 180 | self._num_postprocess_threads = 4 181 | self._parser_buffer_size = 64 182 | self._postprocess_buffer_size = 1 183 | self._prefetch_buffer_size = 8 184 | self._cycle_length = None 185 | self._num_parallel_calls_interleave = tf.data.experimental.AUTOTUNE 186 | self._block_length = None 187 | self._seed = None 188 | self._duplicate_proto = None 189 | 190 | self._is_configured = False 191 | 192 | def configure(self, *args, **kwargs) -> 'BaseVideoDatasetFactory': 193 | """Configures all parse and process functions of this factory. 194 | 195 | This function should be called exactly once per factory instance and will 196 | delegate builders configuration to `_build` method. 197 | 198 | Args: 199 | *args: Positional arguments passed to `_build` function. 200 | **kwargs: Keyword arguments passed to `_build` function. 201 | 202 | Returns: 203 | This instance of the factory. 
204 | 205 | Raises: 206 | ValueError: Method has already been called. 207 | """ 208 | if self._is_configured: 209 | raise ValueError( 210 | '`configure` has already been called. The method should be called ' 211 | 'only once to avoid duplicated process functions.') 212 | self._is_configured = True 213 | self._build(*args, **kwargs) 214 | return self 215 | 216 | def tune(self, 217 | shuffle_buffer: Optional[int] = None, 218 | num_parser_threads: Optional[int] = None, 219 | num_process_threads: Optional[int] = None, 220 | num_postprocess_threads: Optional[int] = None, 221 | parser_buffer_size: Optional[int] = None, 222 | postprocess_buffer_size: Optional[int] = None, 223 | prefetch_buffer_size: Optional[int] = None, 224 | cycle_length: Optional[int] = None, 225 | num_parallel_calls_interleave: Optional[int] = None, 226 | block_length: Optional[int] = None, 227 | seed: Optional[int] = None, 228 | duplicate_proto: Optional[int] = None): 229 | """Changes the dataset creation parameters. 230 | 231 | This method should be used to change the default parameters used to create 232 | the dataset in order to improve speed, memory or other. Only given 233 | parameters will be changed, the others will remain the same. 234 | 235 | Args: 236 | shuffle_buffer: The buffer size for shuffle operation. This affects the 237 | randomness of the output. It must be specified if `shuffle` is `True`. 238 | num_parser_threads: Number of threads to use for the parsing operation. 239 | `tf.data.experimental.AUTOTUNE` can be used to auto-tune. 240 | num_process_threads: Number of threads to use for map operations in 241 | sample, decode and preprocess. `tf.data.experimental.AUTOTUNE` can be 242 | used to auto-tune. 243 | num_postprocess_threads: Number of threads to use for map operations in 244 | postprocess. `tf.data.experimental.AUTOTUNE` can be used to auto-tune. 245 | parser_buffer_size: Buffer size of the sample, decode and preprocess 246 | operation. 247 | postprocess_buffer_size: Buffer size of the postprocess operation. 248 | prefetch_buffer_size: Size of the final prefetch buffer. 249 | cycle_length: The number of shards that will be processed concurrently. 250 | `tf.data.experimental.AUTOTUNE` can be used to auto-tune. 251 | num_parallel_calls_interleave: The number of parallel calls to the 252 | interleave method. `tf.data.experimental.AUTOTUNE` can be used to 253 | auto-tune. 254 | block_length: The number of consecutive elements to produce from each 255 | shard. 256 | seed: Random seed of the shuffle operations. 257 | duplicate_proto: Number of duplicates to make for each loaded proto. 258 | Typically different augmentations will be applied for each copy, so 259 | this can reduce disk reads without harming training performance. 260 | This is applied after the post read function, but before the shuffle 261 | buffer. 262 | 263 | Returns: 264 | This instance of the factory. 
265 | """ 266 | self._shuffle_buffer = shuffle_buffer or self._shuffle_buffer 267 | self._num_parser_threads = num_parser_threads or self._num_parser_threads 268 | self._num_process_threads = num_process_threads or self._num_process_threads 269 | self._num_postprocess_threads = ( 270 | num_postprocess_threads or self._num_postprocess_threads) 271 | self._parser_buffer_size = parser_buffer_size or self._parser_buffer_size 272 | self._postprocess_buffer_size = ( 273 | postprocess_buffer_size or self._postprocess_buffer_size) 274 | self._prefetch_buffer_size = ( 275 | prefetch_buffer_size or self._prefetch_buffer_size) 276 | self._cycle_length = cycle_length or self._cycle_length 277 | self._num_parallel_calls_interleave = ( 278 | num_parallel_calls_interleave or self._num_parallel_calls_interleave) 279 | self._block_length = block_length or self._block_length 280 | self._seed = seed or self._seed 281 | self._duplicate_proto = duplicate_proto or self._duplicate_proto 282 | 283 | return self 284 | 285 | # ---------------------------------------------------------------------- 286 | # ---------- Methods that must be implemented by child class. ---------- 287 | # ---------------------------------------------------------------------- 288 | 289 | @abc.abstractmethod 290 | def _build(self, *args, **kwargs) -> None: 291 | """Builds the data processing graph.""" 292 | 293 | # ---------------------------------------------------------------------- 294 | # -------- Methods that should only be overridden if necessary. -------- 295 | # ---------------------------------------------------------------------- 296 | 297 | def make_dataset( 298 | self, 299 | shuffle: bool = True, 300 | num_epochs: Optional[int] = None, 301 | batch_size: Optional[int] = 16, 302 | padded_batch: bool = False, 303 | padded_batch_shapes: NestedStructure = None, 304 | drop_remainder: bool = True, 305 | keep_key: bool = False, 306 | cache: bool = False, 307 | override_preprocess_fn: Optional[builders.Processor] = None, 308 | **experimental_kwargs 309 | ) -> tf.data.Dataset: 310 | """Creates a `tf.data.Dataset` instance of the given dataset. 311 | 312 | Args: 313 | shuffle: Whether output data is shuffled. 314 | num_epochs: Number of epochs to cycle through before stopping. If `None`, 315 | this will read samples indefinitely. 316 | batch_size: If an int, an extra leading batch dimension will be present 317 | for all features. If `None`, then no batching is done and no extra batch 318 | dimension is added. 319 | padded_batch: Whether to use `padded_batch` instead of `batch` method. 320 | Padded batch pads a batch of examples to a given output shape. It pads 321 | all examples to the longest one in that batch. This could be used for 322 | sequence data. 323 | padded_batch_shapes: `padded_shapes` to be passed to `padded_batch`. 324 | drop_remainder: Whether to drop any remainder after the last full-size 325 | batch. If `True`, the batch dimension of the resulting op is known; 326 | otherwise, the batch dimension may be `None` in cases where `num_epochs` 327 | is finite and `batch_size` > 1, since the final remainder batch may be 328 | smaller than the usual batch size. 329 | keep_key: Whether to keep the `builders.Source` key as a feature in the 330 | final dictionary. The key for the key in the dictionary is 331 | `builders.KEY_FEATURE_NAME`. 332 | cache: Whether to cache the dataset in RAM. Note that this should only 333 | be used if the dataset can fit in RAM as otherwise it will lead to 334 | out of memory error. 
335 | override_preprocess_fn: Function to use instead of built preprocess_fn. 336 | **experimental_kwargs: Other arguments used for experimental features. 337 | These can be removed at any time without prior notice. 338 | 339 | Returns: 340 | An instance of the dataset. 341 | 342 | Raises: 343 | ValueError: Factory has not been configured. 344 | ValueError: `shuffle_buffer` is `None` when dataset is shuffled. 345 | ValueError: `batch_size` is not `None`, `padded_batch` is `False` and 346 | `padded_batch_shapes` is not `None`. 347 | """ 348 | 349 | if not self._is_configured: 350 | raise ValueError('Factory has not been configured. Call `configure` ' 351 | 'method before `make_dataset`.') 352 | 353 | # Build functions or use its overrides. 354 | parse_fn = self.parser_builder.build() 355 | sample_fn = self.sampler_builder.build() 356 | decode_fn = self.decoder_builder.build() 357 | preprocess_fn = override_preprocess_fn or self.preprocessor_builder.build() 358 | postprocess_fn = self.postprocessor_builder.build() 359 | 360 | # Filter functions. 361 | filter_fn_post_read = self.filter_builder.build(builders.Phase.READ) 362 | filter_fn_post_parse = self.filter_builder.build(builders.Phase.PARSE) 363 | filter_fn_post_sample = self.filter_builder.build(builders.Phase.SAMPLE) 364 | filter_fn_post_decode = self.filter_builder.build(builders.Phase.DECODE) 365 | filter_fn_post_preprocess = self.filter_builder.build( 366 | builders.Phase.PREPROCESS) 367 | filter_fn_post_postprocess = self.filter_builder.build( 368 | builders.Phase.POSTPROCESS) 369 | 370 | if shuffle and self._shuffle_buffer is None: 371 | raise ValueError( 372 | '`shuffle_buffer` cannot be `None` if dataset is shuffled.') 373 | 374 | def parse_example(key: tf.Tensor, 375 | raw_example: tf.Tensor) -> builders.FeaturesDict: 376 | """Decodes bytes of example and parse it into a features dictionary.""" 377 | output = parse_fn(raw_example) 378 | # Potentially parse the key. 379 | if keep_key: 380 | output[builders.KEY_FEATURE_NAME] = key 381 | return output 382 | 383 | ds = tf.data.Dataset.from_tensor_slices(self._shards) 384 | if shuffle: 385 | # Shuffling the shards and not only the examples later is important. 386 | ds = ds.shuffle(len(self._shards), seed=self._seed) 387 | 388 | ds = ds.interleave( 389 | self._source.load_and_decode_shard, 390 | cycle_length=self._cycle_length, 391 | block_length=self._block_length, 392 | num_parallel_calls=self._num_parallel_calls_interleave, 393 | deterministic=not shuffle) 394 | 395 | # At this point, the features dictionary is not yet created. We artificially 396 | # create one with the key only to make the interface uniform. 397 | ds = ds.filter( 398 | lambda key, _: filter_fn_post_read({builders.KEY_FEATURE_NAME: key})) 399 | 400 | if self._duplicate_proto is not None: 401 | 402 | def duplicate_fn(x, y): 403 | return (tf.stack([x] * self._duplicate_proto), 404 | tf.stack([y] * self._duplicate_proto)) 405 | 406 | ds = ds.map(duplicate_fn) 407 | ds = ds.unbatch() 408 | 409 | if not cache: 410 | ds = ds.repeat(num_epochs) 411 | if shuffle: 412 | ds = ds.shuffle(self._shuffle_buffer, seed=self._seed) 413 | 414 | # Parse. 415 | ds = ds.map( 416 | parse_example, 417 | num_parallel_calls=self._num_parser_threads, 418 | deterministic=not shuffle) 419 | ds = ds.filter(filter_fn_post_parse) 420 | 421 | if cache: 422 | # We cache the dataset after the parsing operation. This means that we 423 | # cache the raw protos before any random operations happen. 
This can avoid 424 | # IO issues when the dataset fits in RAM. Note that this is the optimal 425 | # place to cache the data (caching before would have no effect as that 426 | # would only be caching a list of files, caching after would be not 427 | # possible due to the random operations that needs to happen after the 428 | # `ds.repeat` operation, making it impossible to cache as the dataset 429 | # would be unbounded). 430 | ds = ds.cache() 431 | ds = ds.repeat(num_epochs) 432 | if shuffle: 433 | ds = ds.shuffle(self._shuffle_buffer, seed=self._seed) 434 | else: 435 | ds = ds.prefetch(self._parser_buffer_size) 436 | 437 | # Sample. 438 | ds = ds.map( 439 | sample_fn, 440 | num_parallel_calls=self._num_process_threads, 441 | deterministic=not shuffle) 442 | ds = ds.filter(filter_fn_post_sample) 443 | 444 | # Decode. 445 | ds = ds.map( 446 | decode_fn, 447 | num_parallel_calls=self._num_process_threads, 448 | deterministic=not shuffle) 449 | ds = ds.filter(filter_fn_post_decode) 450 | 451 | # Preprocess. 452 | ds = ds.map( 453 | preprocess_fn, 454 | num_parallel_calls=self._num_process_threads, 455 | deterministic=not shuffle) 456 | ds = ds.filter(filter_fn_post_preprocess) 457 | 458 | if experimental_kwargs.get('unbatch_after_preprocessing', False): 459 | ds = ds.unbatch() 460 | 461 | if experimental_kwargs.get('ignore_processing_errors', False): 462 | ds = ds.apply(tf.data.experimental.ignore_errors()) 463 | 464 | if batch_size is not None: 465 | if padded_batch: 466 | ds = ds.padded_batch( 467 | batch_size=batch_size, 468 | padded_shapes=padded_batch_shapes, 469 | drop_remainder=drop_remainder) 470 | else: 471 | if padded_batch_shapes is not None: 472 | raise ValueError( 473 | '`padded_batch` is `False`, `padded_batch_shapes` must be `None`,' 474 | f'but is {padded_batch_shapes}.') 475 | ds = ds.batch(batch_size, drop_remainder=drop_remainder) 476 | 477 | # Postprocess. 478 | ds = ds.prefetch(self._postprocess_buffer_size) 479 | ds = ds.map( 480 | postprocess_fn, 481 | num_parallel_calls=self._num_postprocess_threads, 482 | deterministic=not shuffle) 483 | ds = ds.filter(filter_fn_post_postprocess) 484 | 485 | ds = ds.prefetch(self._prefetch_buffer_size) 486 | 487 | logging.info('Dataset created successfully') 488 | 489 | return ds 490 | -------------------------------------------------------------------------------- /dmvr/video_dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for video_dataset.""" 16 | 17 | import os 18 | from typing import List, Union 19 | 20 | from dmvr import builders 21 | from dmvr import sources 22 | from dmvr import video_dataset 23 | from parameterized import parameterized 24 | import tensorflow as tf 25 | 26 | 27 | class _TestTFRecordsSource(sources.Source): 28 | 29 | def load_and_decode_shard(self, 30 | shard: Union[str, tf.Tensor]) -> tf.data.Dataset: 31 | ds = tf.data.TFRecordDataset(shard) 32 | ds = ds.map(lambda example: (b'test_key', example)) 33 | return ds 34 | 35 | 36 | class _TestVideoDatasetFactory( 37 | video_dataset.BaseVideoDatasetFactory): 38 | 39 | def __init__(self, shards: List[str]): 40 | super().__init__(shards, builders.SequenceExampleParserBuilder, 41 | _TestTFRecordsSource()) 42 | 43 | def _build(self, 44 | sample_offset: int = 0, 45 | multiply_by_2: bool = False, 46 | reduce_max: bool = False, 47 | keep_idx: bool = False): 48 | self.parser_builder.parse_feature( 49 | 'sequence', tf.io.FixedLenSequenceFeature((), dtype=tf.int64)) 50 | if keep_idx: 51 | self.parser_builder.parse_feature( 52 | 'idx', tf.io.FixedLenFeature((), dtype=tf.int64), is_context=True) 53 | 54 | self.sampler_builder.add_fn( 55 | lambda x: x[sample_offset:(sample_offset + 50)], 'sequence') 56 | 57 | self.decoder_builder.add_fn( 58 | lambda x: tf.cast(x, tf.uint8), 'sequence') 59 | 60 | if multiply_by_2: 61 | self.preprocessor_builder.add_fn(lambda x: 2 * x, 'sequence') 62 | 63 | if reduce_max: 64 | self.postprocessor_builder.add_fn( 65 | lambda x: tf.reduce_max(input_tensor=x, axis=1), 'sequence') 66 | 67 | 68 | class BaseVideoDatasetFactoryTest(tf.test.TestCase): 69 | 70 | def setUp(self): 71 | super().setUp() 72 | 73 | shards = [] 74 | tmp_dir = self.get_temp_dir() 75 | # Generate TFRecords of 5 shards with serialized SequenceExamples in the 76 | # format ('sequence', [[0], [1], ..., [99]]) plus the shard and element 77 | # indices. 78 | for shard_idx in range(5): 79 | shard = os.path.join(tmp_dir, 80 | 'example-{:05}-of-00005.tfrecord'.format(shard_idx)) 81 | shards.append(shard) 82 | 83 | # Create fake `tf.train.SequenceExample`. 
84 | seq_example = tf.train.SequenceExample() 85 | for i in range(100): 86 | seq_example.feature_lists.feature_list.get_or_create( 87 | 'sequence').feature.add().int64_list.value[:] = [i] 88 | 89 | with tf.io.TFRecordWriter(shard) as builder: 90 | for idx in range(10): 91 | seq_example.context.feature.get_or_create( 92 | 'idx').int64_list.value[:] = [shard_idx * 10 + idx] 93 | builder.write(seq_example.SerializeToString()) 94 | 95 | self._factory = _TestVideoDatasetFactory(shards) 96 | 97 | def test_basic(self): 98 | ds = self._factory.configure().make_dataset(batch_size=2) 99 | 100 | data = next(iter(ds)) 101 | self.assertSetEqual(set(data.keys()), set(['sequence'])) 102 | self.assertAllEqual(data['sequence'], [list(range(50))] * 2) 103 | 104 | def test_configure(self): 105 | ds = self._factory.configure(10, True, True).make_dataset(batch_size=2) 106 | 107 | data = next(iter(ds)) 108 | self.assertSetEqual(set(data.keys()), set(['sequence'])) 109 | self.assertAllEqual(data['sequence'], [59 * 2] * 2) 110 | 111 | def test_configure_exception(self): 112 | with self.assertRaises(ValueError) as _: 113 | self._factory.make_dataset(batch_size=2) 114 | 115 | with self.assertRaises(ValueError) as _: 116 | self._factory.configure().configure() 117 | 118 | def test_keep_key(self): 119 | ds = self._factory.configure().make_dataset(batch_size=2, keep_key=True) 120 | 121 | data = next(iter(ds)) 122 | self.assertSetEqual(set(data.keys()), 123 | set(['sequence', builders.KEY_FEATURE_NAME])) 124 | self.assertAllEqual(data[builders.KEY_FEATURE_NAME].shape, (2,)) 125 | self.assertEqual(data[builders.KEY_FEATURE_NAME][0].numpy(), b'test_key') 126 | self.assertEqual(data[builders.KEY_FEATURE_NAME][1].numpy(), b'test_key') 127 | 128 | def test_override_preprocess_fn(self): 129 | # Data shouldn't be multiplied by 2. 130 | ds = self._factory.configure(multiply_by_2=True).make_dataset( 131 | batch_size=2, override_preprocess_fn=lambda x: x) 132 | 133 | data = next(iter(ds)) 134 | self.assertSetEqual(set(data.keys()), set(['sequence'])) 135 | self.assertAllEqual(data['sequence'], [list(range(50))] * 2) 136 | 137 | def test_no_shuffle(self): 138 | # Set block_length to guarantee reading all examples from the first shard. 139 | ds = self._factory.configure(keep_idx=True).tune( 140 | block_length=5).make_dataset(shuffle=False, batch_size=5) 141 | 142 | data = next(iter(ds)) 143 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx'])) 144 | self.assertAllEqual(data['idx'], [0, 1, 2, 3, 4]) 145 | 146 | def test_filter_read(self): 147 | self._factory.filter_builder.add_filter_fn( 148 | lambda fd: tf.not_equal(fd[builders.KEY_FEATURE_NAME], 'test_key'), 149 | builders.Phase.READ) 150 | ds = self._factory.configure().make_dataset(batch_size=10, keep_key=True) 151 | 152 | with self.assertRaises(StopIteration) as _: 153 | next(iter(ds)) 154 | 155 | @parameterized.expand( 156 | ((builders.Phase.PARSE,), (builders.Phase.SAMPLE,), 157 | (builders.Phase.DECODE,), (builders.Phase.PREPROCESS,))) 158 | def test_filter(self, phase): 159 | 160 | def keep_even_idx(features_dict): 161 | idx = features_dict['idx'] 162 | return tf.equal(idx % 2, 0) 163 | 164 | self._factory.filter_builder.add_filter_fn(keep_even_idx, phase) 165 | # Set block_length to guarantee reading examples in key order. 
166 | ds = self._factory.configure(keep_idx=True).tune( 167 | block_length=10).make_dataset(shuffle=False, batch_size=10) 168 | 169 | data = next(iter(ds)) 170 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx'])) 171 | self.assertAllEqual(data['idx'], range(0, 20, 2)) 172 | 173 | def test_filter_postprocess(self): 174 | self._factory.filter_builder.add_filter_fn( 175 | lambda fd: tf.not_equal(fd['idx'][0], 0), # Filter first batch. 176 | builders.Phase.POSTPROCESS) 177 | # Set block_length to guarantee reading examples in key order. 178 | ds = self._factory.configure(keep_idx=True).tune( 179 | block_length=10).make_dataset(shuffle=False, batch_size=10) 180 | 181 | data = next(iter(ds)) 182 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx'])) 183 | self.assertAllEqual(data['idx'], range(10, 20)) 184 | 185 | def test_ignore_processing_errors(self): 186 | 187 | def fail_decode(idx): 188 | # Fail for all odd indices. 189 | error = tf.assert_equal(idx % 2, tf.zeros((), dtype=tf.int64)) 190 | with tf.control_dependencies([error]): 191 | return idx 192 | 193 | self._factory.decoder_builder.add_fn(fail_decode, 'idx') 194 | # Set block_length to guarantee reading examples in key order. 195 | ds = self._factory.configure(keep_idx=True).tune( 196 | block_length=10).make_dataset( 197 | shuffle=False, batch_size=10, ignore_processing_errors=True) 198 | 199 | data = next(iter(ds)) 200 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx'])) 201 | self.assertAllEqual(data['idx'], range(0, 20, 2)) 202 | 203 | 204 | if __name__ == '__main__': 205 | tf.test.main() 206 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # End-to-end walkthrough: from raw data to using the readers for learning 2 | 3 | This document walks you through all the steps that go from raw data (a list of 4 | mp4 files), to a format that is compatible with DMVR, to writing a reader to 5 | finally use it in an ML application. 6 | 7 | 8 | ## Requirements 9 | 10 | To run the code, you will need to install the following dependencies: 11 | 12 | - python3 13 | - numpy 14 | - absl-py 15 | - pandas 16 | - TensorFlow 17 | - [ffmpeg](https://johnvansickle.com/ffmpeg/) 18 | - [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) [Make sure you pip install ffmpeg-python and not python-ffmpeg] 19 | - unrar [Only for the HMDB-51 dataset generation example] 20 | - [scikit-learn](https://scikit-learn.org/) [Only for training linear model on HMDB] 21 | 22 | Please make sure the ffmpeg binaries (downloadable 23 | [here](https://johnvansickle.com/ffmpeg/)) are visible from the *PATH* 24 | environment variable and to install the ffmpeg-python wrapper (and not 25 | python-ffmpeg, which is a different package). Installing ffmpeg-python with pip 26 | can be done in one line with: 27 | 28 | ```sh 29 | pip install ffmpeg-python 30 | ``` 31 | 32 | ## Creating and reading your own DMVR dataset using open-source tools 33 | 34 | First, we will describe how to generate your own DMVR dataset as tfrecord files 35 | from your own videos using open-source tools. 36 | 37 | We then provide a step-by-step example of how to generate the popular 38 | [HMDB-51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) 39 | action recognition video dataset into the DMVR format. 
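To give a sense of where this walkthrough ends up, here is a minimal sketch of what consuming the generated data eventually looks like. This is only a sketch, not code from this repository: the `Hmdb51Factory` name, the shard path and the printed feature names are hypothetical placeholders. The `configure`/`make_dataset` calls follow the `BaseVideoDatasetFactory` API defined in `dmvr/video_dataset.py`, and `hmdb.py` (described below) contains a real reader.

```python
# Hypothetical reader built on dmvr.video_dataset.BaseVideoDatasetFactory;
# see hmdb.py in this directory for an actual implementation.
factory = Hmdb51Factory(shards=['/path/to/hmdb_shards/train_1-00000-of-00010'])
factory.configure()  # Sets up the parse/sample/decode/preprocess functions.
ds = factory.make_dataset(shuffle=True, batch_size=8)

for batch in ds.take(1):
  # `batch` is a dictionary of batched features (e.g. frames and labels).
  print({name: tensor.shape for name, tensor in batch.items()})
```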
40 | 41 | ### Generating your own tfrecord files 42 | 43 | #### Creating the input CSV for the generation 44 | 45 | To generate a DMVR compatible video dataset using this tool, all you need to 46 | provide is a csv file with the paths of the videos you want to process, together 47 | with additional metadata such as the start/end timestamps, a label or a text 48 | caption. As an example we are going to download two videos (Creative Commons 49 | license): 50 | 51 | ```sh 52 | wget https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4 \ 53 | -O /tmp/heic1608c.mp4 54 | wget https://upload.wikimedia.org/wikipedia/commons/1/18/BRKM_Javeline_Throw.webm \ 55 | -O /tmp/BRKM_Javeline_Throw.webm 56 | ``` 57 | 58 | We can create the following csv with the downloaded videos to process: 59 | 60 | ```sh 61 | video_path,start,end,label,caption 62 | /tmp/heic1608c.mp4,1.5,6.0,space,the view of the space from a telescope 63 | /tmp/BRKM_Javeline_Throw.webm,0.0,3.0,javeline_throw,someone is throwing a javeline 64 | ``` 65 | 66 | where a more precise description of each column is given below: 67 | 68 | | Column name | Description | Optional | 69 | | ----------- | -------------------------------------------------- | -------- | 70 | | video_path | the path of the video to process | No | 71 | | start | the clip start time (in seconds) | No | 72 | | end | the clip end time (in seconds) | No | 73 | | label | A label annotated with the clip (i.e. for | Yes | 74 | : : classification) : : 75 | | caption | A free-form text annotated with the clip (i.e. for | Yes | 76 | : : captioning or retrieval) : : 77 | 78 | Run the following command to create the csv: 79 | 80 | ```sh 81 | echo -e "video_path,start,end,label,caption\n/tmp/heic1608c.mp4,1.5,6.0,space,the view of the space from a telescope\n/tmp/BRKM_Javeline_Throw.webm,0.0,3.0,javeline_throw,someone is throwing a javeline" > /tmp/input.csv 82 | ``` 83 | 84 | #### Generating the tfrecords data using the CSV 85 | 86 | Now that we have created a CSV file with the videos we wish to process, we can 87 | generate the tfrecords using the provided code. This can be done by running the 88 | following commands: 89 | 90 | ```sh 91 | mkdir /tmp/generated_dataset 92 | python generate_from_file.py \ 93 | --csv_path=/tmp/input.csv \ 94 | --output_path=/tmp/generated_dataset 95 | ``` 96 | 97 | where a description of the arguments is given below: 98 | 99 | Arguments | Description 100 | ------------ | ------------------------------------------------------- 101 | csv_path | The path of the input CSV with all the video paths. 102 | output_path | The generated tfrecords output path 103 | num_shards | The number of tfrecord shards to create (default=-1, which adapts to sqrt(num_examples)) 104 | decode_audio | Decode and store audio in the tfrecords (default=False) 105 | shuffle_csv | Whether or not to shuffle the input csv (default=False) 106 | 107 | Congratulations! You have created a DMVR compatible dataset from your own videos! 108 | 109 | ### Example: Step-by-step generation of HMDB-51 in the DMVR format 110 | 111 | As another example, we provide step-by-step instructions for generating the 112 | [HMDB-51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/) 113 | video dataset in the DMVR format. 114 | 115 | #### Creating the HMDB-51 input CSV for the generation pipeline 116 | 117 | First, you need to download the original splits from the official 118 | [link](http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar). 
119 | 120 | ```sh 121 | wget http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar 122 | ``` 123 | 124 | If you have not 125 | [installed unrar](https://www.tecmint.com/how-to-open-extract-and-create-rar-files-in-linux/) 126 | yet, please install it to extract the rar, and then run: 127 | 128 | ```sh 129 | unrar x test_train_splits.rar 130 | rm test_train_splits.rar 131 | ``` 132 | 133 | Create the HMDB CSV files using our provided script: 134 | 135 | ```sh 136 | mkdir hmdb_csv 137 | python generate_hmdb_csv.py \ 138 | --input_path=testTrainMulti_7030_splits \ 139 | --output_path=hmdb_csv 140 | ``` 141 | 142 | This will generate 6 csv files in the *hmdb_csv* folder: train_1.csv, 143 | test_1.csv, train_2.csv, test_2.csv, train_3.csv and test_3.csv, which correspond 144 | to the three train/test splits. 145 | 146 | #### Generating the tfrecords data from the generated HMDB-51 CSV 147 | 148 | Now that you have generated the HMDB-51 csvs, you will need to download and 149 | extract the videos from the official website and store them in a newly created 150 | *hmdb_videos* directory: 151 | 152 | ```sh 153 | mkdir hmdb_videos 154 | wget https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar \ 155 | -P hmdb_videos 156 | cd hmdb_videos 157 | unrar x hmdb51_org.rar 158 | rm hmdb51_org.rar 159 | for video_dir in *rar; do unrar x $video_dir; done 160 | rm *.rar 161 | cd .. 162 | ``` 163 | 164 | You can now run the generation pipeline on any of the csv splits. For example, 165 | you can process the train set of the first split with the 166 | following command: 167 | 168 | ```sh 169 | python generate_from_file.py \ 170 | --csv_path=hmdb_csv/train_1.csv \ 171 | --video_root_path=hmdb_videos \ 172 | --output_path=/path/to/hmdb_shards 173 | ``` 174 | 175 | This will generate the tfrecords in the DMVR format for the HMDB-51 split 1 176 | train set, split into roughly *sqrt(num_clips)* shards, where *num_clips* is the 177 | number of video clips in that train set. 178 | 179 | ## Writing a DMVR reader 180 | 181 | See `hmdb.py` for an example reader for the data created above. 182 | 183 | ## Training a linear classifier on top of existing features 184 | 185 | The script `linear_mmv_hmdb.py` evaluates the linear 186 | performance of the recently introduced 187 | [MMV networks](https://arxiv.org/abs/2006.16228) on HMDB51. 
188 | 189 | To run the script simply do: 190 | 191 | ```shell 192 | python linear_mmv_hmdb.py \ 193 | --data_path=/path/to/hmdb_shards \ 194 | --model_name=s3d \ 195 | --hmdb51_split=1 196 | ``` 197 | 198 | It supports three different models and the script should reproduce 199 | the following results (as reported in the paper): 200 | 201 | Visual Backbone | Results on Linear HMDB51 (avg over 3 splits) 202 | ------- | -------- 203 | [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1) (`s3d`) | 62.6 204 | [Resnet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1): (`tsm-resnet50`) | 66.7 205 | [Resnet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1): (`tsm-resnet50x2`) | 67.1 206 | 207 | 208 | ### References 209 | 210 | ```bibtex 211 | @inproceedings{alayrac2020self, 212 | title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks}, 213 | author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew}, 214 | booktitle={NeurIPS}, 215 | year={2020} 216 | } 217 | ``` 218 | -------------------------------------------------------------------------------- /examples/generate_from_file.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Python script to generate TFRecords of SequenceExample from raw videos.""" 15 | 16 | import contextlib 17 | import math 18 | import os 19 | from typing import Dict, Optional, Sequence 20 | 21 | from absl import app 22 | from absl import flags 23 | import ffmpeg 24 | import numpy as np 25 | import pandas as pd 26 | import tensorflow as tf 27 | 28 | flags.DEFINE_string("csv_path", None, "Input csv") 29 | flags.DEFINE_string("output_path", None, "Tfrecords output path.") 30 | flags.DEFINE_string("video_root_path", None, 31 | "Root directory containing the raw videos.") 32 | flags.DEFINE_integer( 33 | "num_shards", -1, "Number of shards to output, -1 means" 34 | "it will automatically adapt to the sqrt(num_examples).") 35 | flags.DEFINE_bool("decode_audio", False, "Whether or not to decode the audio") 36 | flags.DEFINE_bool("shuffle_csv", False, "Whether or not to shuffle the csv.") 37 | FLAGS = flags.FLAGS 38 | 39 | 40 | _JPEG_HEADER = b"\xff\xd8" 41 | 42 | 43 | @contextlib.contextmanager 44 | def _close_on_exit(writers): 45 | """Call close on all writers on exit.""" 46 | try: 47 | yield writers 48 | finally: 49 | for writer in writers: 50 | writer.close() 51 | 52 | 53 | def add_float_list(key: str, values: Sequence[float], 54 | sequence: tf.train.SequenceExample): 55 | sequence.feature_lists.feature_list[key].feature.add( 56 | ).float_list.value[:] = values 57 | 58 | 59 | def add_bytes_list(key: str, values: Sequence[bytes], 60 | sequence: tf.train.SequenceExample): 61 | sequence.feature_lists.feature_list[key].feature.add( 62 | ).bytes_list.value[:] = values 63 | 64 | 65 | def add_int_list(key: str, values: Sequence[int], 66 | sequence: tf.train.SequenceExample): 67 | sequence.feature_lists.feature_list[key].feature.add( 68 | ).int64_list.value[:] = values 69 | 70 | 71 | def set_context_int_list(key: str, value: Sequence[int], 72 | sequence: tf.train.SequenceExample): 73 | sequence.context.feature[key].int64_list.value[:] = value 74 | 75 | 76 | def set_context_bytes(key: str, value: bytes, 77 | sequence: tf.train.SequenceExample): 78 | sequence.context.feature[key].bytes_list.value[:] = (value,) 79 | 80 | 81 | def set_context_float(key: str, value: float, 82 | sequence: tf.train.SequenceExample): 83 | sequence.context.feature[key].float_list.value[:] = (value,) 84 | 85 | 86 | def set_context_int(key: str, value: int, sequence: tf.train.SequenceExample): 87 | sequence.context.feature[key].int64_list.value[:] = (value,) 88 | 89 | 90 | def extract_frames(video_path: str, 91 | start: float, 92 | end: float, 93 | fps: int = 10, 94 | min_resize: int = 256): 95 | """Extract list of jpeg bytes from video_path using ffmpeg.""" 96 | new_width = "(iw/min(iw,ih))*{}".format(min_resize) 97 | cmd = ( 98 | ffmpeg 99 | .input(video_path) 100 | .trim(start=start, end=end) 101 | .filter("fps", fps=fps) 102 | .filter("scale", new_width, -1) 103 | .output("pipe:", format="image2pipe") 104 | ) 105 | jpeg_bytes, _ = cmd.run(capture_stdout=True, quiet=True) 106 | jpeg_bytes = jpeg_bytes.split(_JPEG_HEADER)[1:] 107 | jpeg_bytes = map(lambda x: _JPEG_HEADER + x, jpeg_bytes) 108 | return list(jpeg_bytes) 109 | 110 | 111 | def extract_audio(video_path: str, 112 | start: float, 113 | end: float, 114 | sampling_rate: int = 48000): 115 | """Extract raw mono audio float list from video_path with ffmpeg.""" 116 | cmd = ( 117 | ffmpeg 118 | .input(video_path, ss=start, t=end-start) 119 | .output("pipe:", ac=1, ar=sampling_rate, format="s32le") 120 | ) 121 | audio, _ = cmd.run(capture_stdout=True, quiet=True) 122 | 
audio = np.frombuffer(audio, np.int32).astype(np.float32) / np.iinfo(np.int32).max 123 | return list(audio) 124 | 125 | 126 | def generate_sequence_example(video_path: str, 127 | start: float, 128 | end: float, 129 | label_name: Optional[str] = None, 130 | caption: Optional[str] = None, 131 | label_map: Optional[Dict[str, int]] = None): 132 | """Generate a sequence example.""" 133 | if FLAGS.video_root_path: 134 | video_path = os.path.join(FLAGS.video_root_path, video_path) 135 | imgs_encoded = extract_frames(video_path, start, end) 136 | 137 | # Initialize the sequence example. 138 | seq_example = tf.train.SequenceExample() 139 | 140 | # Add the label list as text and indices. 141 | if label_name: 142 | set_context_int("clip/label/index", label_map[label_name], seq_example) 143 | set_context_bytes("clip/label/text", label_name.encode(), seq_example) 144 | if caption: 145 | set_context_bytes("caption/string", caption.encode(), seq_example) 146 | # Add the frames as one feature per frame. 147 | for img_encoded in imgs_encoded: 148 | add_bytes_list("image/encoded", [img_encoded], seq_example) 149 | 150 | # Add audio. 151 | if FLAGS.decode_audio: 152 | audio = extract_audio(video_path, start, end) 153 | add_float_list("WAVEFORM/feature/floats", audio, seq_example) 154 | 155 | # Add other metadata. 156 | set_context_bytes("video/filename", video_path.encode(), seq_example) 157 | # Add start and end timestamps in microseconds. 158 | set_context_int("clip/start/timestamp", int(1000000 * start), seq_example) 159 | set_context_int("clip/end/timestamp", int(1000000 * end), seq_example) 160 | return seq_example 161 | 162 | 163 | def main(argv): 164 | del argv 165 | # Read the input csv. 166 | input_csv = pd.read_csv(FLAGS.csv_path) 167 | if FLAGS.num_shards == -1: 168 | num_shards = int(math.sqrt(len(input_csv))) 169 | else: 170 | num_shards = FLAGS.num_shards 171 | # Set up the TFRecordWriters. 172 | basename = os.path.splitext(os.path.basename(FLAGS.csv_path))[0] 173 | shard_names = [ 174 | os.path.join(FLAGS.output_path, f"{basename}-{i:05d}-of-{num_shards:05d}") 175 | for i in range(num_shards) 176 | ] 177 | writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names] 178 | 179 | if "label" in input_csv: 180 | unique_labels = list(set(input_csv["label"].values)) 181 | l_map = {unique_labels[i]: i for i in range(len(unique_labels))} 182 | else: 183 | l_map = None 184 | 185 | if FLAGS.shuffle_csv: 186 | input_csv = input_csv.sample(frac=1) 187 | with _close_on_exit(writers) as writers: 188 | for i in range(len(input_csv)): 189 | print( 190 | "Processing example %d of %d (%d%%) \r" % 191 | (i, len(input_csv), i * 100 / len(input_csv)), 192 | end="") 193 | v = input_csv["video_path"].values[i] 194 | s = input_csv["start"].values[i] 195 | e = input_csv["end"].values[i] 196 | l = input_csv["label"].values[i] if "label" in input_csv else None 197 | c = input_csv["caption"].values[i] if "caption" in input_csv else None 198 | seq_ex = generate_sequence_example( 199 | v, s, e, label_name=l, caption=c, label_map=l_map) 200 | writers[i % len(writers)].write(seq_ex.SerializeToString()) 201 | 202 | 203 | if __name__ == "__main__": 204 | app.run(main) 205 | -------------------------------------------------------------------------------- /examples/generate_hmdb_csv.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | r"""Script to generate csvs for HMDB. 15 | 16 | You would need to download the official splits at: 17 | 18 | http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar 19 | 20 | and unrar the archive on your machine, e.g. /path/to/hmdb/ 21 | 22 | Usage: 23 | ``` 24 | python generate_hmdb_csv.py \ 25 | --input_path=/path/to/hmdb/testTrainMulti_7030_splits \ 26 | --output_path=/path/to/hmdb 27 | ``` 28 | """ 29 | 30 | import collections 31 | import csv 32 | import glob 33 | import os 34 | 35 | from absl import app 36 | from absl import flags 37 | 38 | 39 | flags.DEFINE_string( 40 | 'input_path', None, 'Path containing the metadata from HMDB51.') 41 | flags.DEFINE_string( 42 | 'output_path', None, 'Path containing the metadata from HMDB51.') 43 | 44 | FLAGS = flags.FLAGS 45 | 46 | InputVideo = collections.namedtuple( 47 | 'InputRow', 48 | ('video_id', 'split', 'subset', 'label_name')) 49 | 50 | OutputRow = collections.namedtuple( 51 | 'OutputRow', 52 | ('video_id', 'start_sec', 'end_sec', 'label_name', 'label_id')) 53 | 54 | 55 | def main(argv): 56 | del argv 57 | all_files = glob.glob(os.path.join(FLAGS.input_path, '*txt')) 58 | 59 | all_rows = [] 60 | label_names = set() 61 | 62 | # Read the files. 63 | for file_name in all_files: 64 | base_name = os.path.basename(file_name) 65 | base_name_split = base_name.split('_') 66 | label_name = ' '.join(base_name_split[:-2]) 67 | label_name = label_name.replace(' ', '_') 68 | label_names.add(label_name) 69 | split = int(base_name[-5]) 70 | with open(file_name, 'r') as f: 71 | lines = [x.strip().split(' ') for x in f.readlines()] 72 | 73 | for (video_id, ind) in lines: 74 | if ind == '1': 75 | all_rows.append( 76 | InputVideo(video_id, split, 'train', label_name)) 77 | elif ind == '2': 78 | all_rows.append( 79 | InputVideo(video_id, split, 'test', label_name)) 80 | 81 | # Sort the label names. 82 | label_names = list(label_names) 83 | label_names.sort() 84 | 85 | all_csvs = { 86 | 'train_1': [], 87 | 'train_2': [], 88 | 'train_3': [], 89 | 'test_1': [], 90 | 'test_2': [], 91 | 'test_3': [], 92 | } 93 | 94 | # Generate the csvs rows. 95 | for row in all_rows: 96 | csv_name = f'{row.subset}_{row.split}' 97 | all_csvs[csv_name].append(OutputRow( 98 | video_id=f'{row.label_name}/{row.video_id}', 99 | start_sec=0, 100 | end_sec=20, 101 | label_name=row.label_name, 102 | label_id=label_names.index(row.label_name) 103 | )) 104 | 105 | # Write the csvs. 
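# Each csv uses the columns (video_path, start, end, label) expected by
# generate_from_file.py; the label_id field of OutputRow is not written, since
# downstream scripts such as generate_from_file.py rebuild label indices from
# the label column.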
106 | for csv_name in all_csvs: 107 | output_path = os.path.join(FLAGS.output_path, f'{csv_name}.csv') 108 | print(f'Writing outputs to CSV file {output_path}') 109 | with open(output_path, 'w') as f: 110 | writer = csv.writer(f, delimiter=',') 111 | writer.writerow( 112 | ['video_path', 'start', 'end', 'label']) 113 | 114 | for row in all_csvs[csv_name]: 115 | writer.writerow([ 116 | row.video_id, row.start_sec, row.end_sec, row.label_name]) 117 | 118 | if __name__ == '__main__': 119 | app.run(main) 120 | -------------------------------------------------------------------------------- /examples/hmdb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """HMDB51 video dataset.""" 15 | 16 | import os 17 | from typing import Optional 18 | 19 | from dmvr import modalities 20 | from dmvr import video_dataset 21 | 22 | 23 | class HMDB51Factory(video_dataset.BaseVideoDatasetFactory): 24 | """HMDB51 reader.""" 25 | 26 | _SUBSETS = ('train', 'test') 27 | _SPLITS = (1, 2, 3) 28 | _NUM_CLASSES = 51 29 | 30 | _NUM_SHARDS = {'train': 59, 'test': 39} 31 | 32 | def __init__( 33 | self, 34 | base_dir: str, 35 | subset: str = 'train', 36 | split: int = 1): 37 | """Constructor of HMDB51Factory.""" 38 | 39 | if subset not in HMDB51Factory._SUBSETS: 40 | raise ValueError('Invalid subset "{}". The available subsets are: {}' 41 | .format(subset, HMDB51Factory._SUBSETS)) 42 | 43 | if split not in HMDB51Factory._SPLITS: 44 | raise ValueError('Invalid split "{}". The available splits are: {}' 45 | .format(split, HMDB51Factory._SPLITS)) 46 | 47 | num_shards = self._NUM_SHARDS[subset] 48 | shards = [f'{subset}_{split}-{i:05d}-of-{num_shards:05d}' 49 | for i in range(num_shards)] 50 | super().__init__(shards=[os.path.join(base_dir, s) for s in shards]) 51 | 52 | def _build(self, 53 | is_training: Optional[bool] = True, 54 | # Video related parameters. 55 | num_frames: int = 32, 56 | stride: int = 1, 57 | num_test_clips: int = 1, 58 | min_resize: int = 256, 59 | crop_size: int = 224, 60 | zero_centering_image: bool = False, 61 | # Label related parameters. 62 | one_hot_label: bool = True, 63 | add_label_name: bool = False): 64 | """Default build for this dataset. 65 | 66 | Args: 67 | is_training: Whether or not in training mode. 68 | num_frames: Number of frames per subclip. For single images, use 1. 69 | stride: Temporal stride to sample frames. 70 | num_test_clips: Number of test clips (1 by default). If more than 1, this 71 | will sample multiple linearly spaced clips within each video at test 72 | time. If 1, then a single clip in the middle of the video is sampled. 73 | The clips are aggreagated in the batch dimension. 74 | min_resize: Frames are resized so that `min(height, width)` is 75 | `min_resize`. 76 | crop_size: Final size of the frame after cropping the resized frames. Both 77 | height and width are the same. 
78 | zero_centering_image: If `True`, frames are normalized to values in 79 | [-1, 1]. If `False`, values in [0, 1]. 80 | one_hot_label: Return labels as one hot tensors. 81 | add_label_name: Also return the name of the label. 82 | """ 83 | modalities.add_image( 84 | parser_builder=self.parser_builder, 85 | sampler_builder=self.sampler_builder, 86 | decoder_builder=self.decoder_builder, 87 | preprocessor_builder=self.preprocessor_builder, 88 | postprocessor_builder=self.postprocessor_builder, 89 | is_training=is_training, 90 | num_frames=num_frames, stride=stride, 91 | num_test_clips=num_test_clips, 92 | min_resize=min_resize, crop_size=crop_size, 93 | zero_centering_image=zero_centering_image) 94 | 95 | modalities.add_label( 96 | parser_builder=self.parser_builder, 97 | decoder_builder=self.decoder_builder, 98 | preprocessor_builder=self.preprocessor_builder, 99 | one_hot_label=one_hot_label, 100 | num_classes=HMDB51Factory._NUM_CLASSES, 101 | add_label_name=add_label_name) 102 | -------------------------------------------------------------------------------- /examples/linear_mmv_hmdb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
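# Example invocation (flag values are placeholders, not prescribed by this
# script):
#
#   python linear_mmv_hmdb.py \
#     --model_name=s3d \
#     --data_path=/path/to/hmdb/ \
#     --hmdb51_split=1
#
# The shards under --data_path are expected to follow the naming convention
# used by hmdb.py, e.g. train_1-00000-of-00059, as produced by
# generate_hmdb_csv.py followed by generate_from_file.py.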
14 | 15 | """HMDB51 linear evaluation of MMV models.""" 16 | 17 | from absl import app 18 | from absl import flags 19 | from dmvr import builders 20 | import hmdb 21 | import numpy as np 22 | from sklearn import preprocessing 23 | from sklearn import svm 24 | import tensorflow.compat.v2 as tf 25 | import tensorflow_datasets as tfds 26 | import tensorflow_hub as hub 27 | 28 | 29 | flags.DEFINE_enum('model_name', 's3d', 30 | ['s3d', 'tsm-resnet50', 'tsm-resnet50x2'], 31 | 'Which MMV backbone to load.') 32 | flags.DEFINE_string('data_path', '/path/to/hmdb/', 'Path where shards live.') 33 | flags.DEFINE_integer('eval_batch_size', 1, 34 | 'The batch size for evaluation.') 35 | flags.DEFINE_integer('train_batch_size', 16, 36 | 'The batch size for training.') 37 | flags.DEFINE_integer('num_train_epochs', 10, 38 | 'How many epochs to collect features during training.') 39 | flags.DEFINE_integer('num_test_clips', 10, 40 | 'How many clips to average on during test.') 41 | flags.DEFINE_integer('min_resize', 224, 42 | 'Min value to resize images to during preprocessing.') 43 | flags.DEFINE_integer('crop_size', 200, 44 | 'Value to resize images to during preprocessing.') 45 | flags.DEFINE_integer('num_frames', 32, 'Number of video frames.') 46 | flags.DEFINE_integer('stride', 1, 'Stride for video frames.') 47 | flags.DEFINE_integer('hmdb51_split', 1, 'Which split of hmdb51 to use.') 48 | 49 | 50 | FLAGS = flags.FLAGS 51 | 52 | _MODELS2REG = {'s3d': 0.0003, 53 | 'tsm-resnet50': 0.0001, 54 | 'tsm-resnet50x2': 0.0003} 55 | 56 | 57 | def compute_accuracy_metrics(pred: np.ndarray, gt: np.ndarray, 58 | prefix: str = ''): 59 | order_pred = np.argsort(pred, axis=1) 60 | assert len(gt.shape) == len(order_pred.shape) == 2 61 | top1_pred = order_pred[:, -1:] 62 | top5_pred = order_pred[:, -5:] 63 | top1_acc = np.mean(top1_pred == gt) 64 | top5_acc = np.mean(np.max(top5_pred == gt, 1)) 65 | return {prefix + 'top1': top1_acc, 66 | prefix + 'top5': top5_acc} 67 | 68 | 69 | def main(argv): 70 | del argv 71 | 72 | # Load the model. 73 | sklearn_reg = _MODELS2REG[FLAGS.model_name] 74 | module = hub.load(f'https://tfhub.dev/deepmind/mmv/{FLAGS.model_name}/1') 75 | 76 | def get_features(input_frames: np.ndarray): 77 | vision_output = module.signatures['video']( 78 | tf.constant(tf.cast(input_frames, dtype=tf.float32))) 79 | return vision_output['before_head'].numpy() 80 | 81 | def collect_features_and_labels(ds: tf.data.Dataset, subset: str): 82 | """Collect features and labels.""" 83 | features = [] 84 | labels = [] 85 | print(f'Computing features on {subset}') 86 | examples = iter(tfds.as_numpy(ds)) 87 | num_examples = 0 88 | for ex in examples: 89 | vid_representation = get_features(ex[builders.IMAGE_FEATURE_NAME]) 90 | labels.append(ex[builders.LABEL_INDEX_FEATURE_NAME]) 91 | features.append(vid_representation) 92 | num_examples += ex[builders.LABEL_INDEX_FEATURE_NAME].shape[0] 93 | if num_examples % 100 == 0: 94 | print(f'Processed {num_examples} examples.') 95 | labels = np.concatenate(labels, axis=0) 96 | features = np.concatenate(features, axis=0) 97 | print(f'Finish collecting {subset} features of shape {features.shape}') 98 | return features, labels 99 | 100 | # Generate the training and testing datasets. 
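# configure(...) forwards these keyword arguments to HMDB51Factory._build
# (which adds the image and label modalities), and make_dataset(...) then
# builds the tf.data.Dataset pipeline that is iterated below.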
101 | conf_kwargs = dict( 102 | num_frames=FLAGS.num_frames, 103 | stride=FLAGS.stride, 104 | min_resize=FLAGS.min_resize, 105 | crop_size=FLAGS.crop_size, 106 | one_hot_label=False) 107 | 108 | train_ds = hmdb.HMDB51Factory( 109 | FLAGS.data_path, subset='train', split=FLAGS.hmdb51_split).configure( 110 | is_training=True, **conf_kwargs).make_dataset( 111 | shuffle=True, 112 | num_epochs=FLAGS.num_train_epochs, 113 | batch_size=FLAGS.train_batch_size) 114 | 115 | test_ds = hmdb.HMDB51Factory( 116 | FLAGS.data_path, subset='test', split=FLAGS.hmdb51_split).configure( 117 | is_training=False, num_test_clips=FLAGS.num_test_clips, 118 | **conf_kwargs).make_dataset(shuffle=False, 119 | num_epochs=1, 120 | batch_size=FLAGS.eval_batch_size) 121 | 122 | # Collect features and labels. 123 | train_features, train_labels = collect_features_and_labels(train_ds, 'train') 124 | test_features, test_labels = collect_features_and_labels(test_ds, 'test') 125 | 126 | # Train classifier 127 | print('Training linear classifier!') 128 | classifier = svm.LinearSVC(C=sklearn_reg) 129 | scaler = preprocessing.StandardScaler().fit(train_features) 130 | train_features = scaler.transform(train_features) 131 | classifier.fit(train_features, train_labels.ravel()) 132 | print('Training done !') 133 | 134 | # Evaluation. 135 | test_features = scaler.transform(test_features) 136 | print('Running inference on train') 137 | pred_train = classifier.decision_function(train_features) 138 | print('Running inference on test') 139 | pred_test = classifier.decision_function(test_features) 140 | if FLAGS.num_test_clips > 1: 141 | pred_test = np.reshape( 142 | pred_test, (test_labels.shape[0], -1, pred_test.shape[1])) 143 | pred_test = pred_test.mean(axis=1) 144 | 145 | # Compute accuracies. 146 | metrics = compute_accuracy_metrics(pred_train, train_labels, prefix='train_') 147 | metrics.update( 148 | compute_accuracy_metrics(pred_test, test_labels, prefix='test_')) 149 | print(metrics) 150 | 151 | if __name__ == '__main__': 152 | app.run(main) 153 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | parameterized 3 | tensorflow>=2.0.0 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py 2 | sentencepiece 3 | tensorflow>=2.0.0 4 | tensorflow_text>=2.0.0 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # https://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
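# As in test.sh below, a typical install from a checkout is to install
# requirements.txt first and then run `python3 setup.py install` (or,
# equivalently for most environments, `pip install .`).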
14 | 15 | """Setup for pip package.""" 16 | 17 | import setuptools 18 | 19 | _VERSION = '0.0.1' 20 | 21 | 22 | def _parse_requirements(requirements_txt_path): 23 | parse_line = lambda l: l.split('#')[0].strip() 24 | with open(requirements_txt_path) as f: 25 | return [parse_line(l) for l in f] 26 | 27 | 28 | setuptools.setup( 29 | name='dmvr', 30 | version=_VERSION, 31 | url='https://github.com/deepmind/dmvr', 32 | license='Apache 2.0', 33 | author='DeepMind', 34 | description=( 35 | 'DMVR is a library for reading and processing multimodal datasets.'), 36 | long_description=open('README.md').read(), 37 | long_description_content_type='text/markdown', 38 | author_email='dmvr-dev-os@google.com', 39 | # Contained modules and scripts. 40 | packages=setuptools.find_namespace_packages(exclude=['*_test.py']), 41 | install_requires=_parse_requirements('requirements.txt'), 42 | tests_require=_parse_requirements('requirements-test.txt'), 43 | python_requires='>=3.6', 44 | include_package_data=True, 45 | zip_safe=False, 46 | # PyPI package information. 47 | classifiers=[ 48 | 'Intended Audience :: Developers', 49 | 'Intended Audience :: Education', 50 | 'Intended Audience :: Science/Research', 51 | 'License :: OSI Approved :: Apache Software License', 52 | 'Programming Language :: Python :: 3', 53 | 'Programming Language :: Python :: 3.6', 54 | 'Programming Language :: Python :: 3.7', 55 | 'Topic :: Scientific/Engineering :: Mathematics', 56 | 'Topic :: Software Development :: Libraries :: Python Modules', 57 | 'Topic :: Software Development :: Libraries', 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright 2021 DeepMind Technologies Limited. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | # Pip installs the relevant dependencies and runs DMVR tests. 18 | 19 | set -e 20 | set -x 21 | 22 | python3 -m pip install --upgrade pip 23 | python3 -m pip install virtualenv 24 | virtualenv -p python3 . 25 | source bin/activate 26 | python3 --version 27 | 28 | # Install the dependencies first, then run setup.py install. 29 | python3 -m pip install -r requirements.txt 30 | python3 setup.py install 31 | 32 | # Python test dependencies. 33 | python3 -m pip install -r requirements-test.txt 34 | 35 | # Run all tests. 36 | python3 -m unittest discover -s 'dmvr' -p '*_test.py' 37 | 38 | deactivate 39 | --------------------------------------------------------------------------------