├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── dmvr
│   ├── __init__.py
│   ├── builders.py
│   ├── builders_test.py
│   ├── modalities.py
│   ├── modalities_test.py
│   ├── processors.py
│   ├── processors_test.py
│   ├── sources.py
│   ├── sources_test.py
│   ├── testdata
│   │   ├── sample.jpeg
│   │   └── tokenizers
│   │       ├── bert_word_vocab.txt
│   │       ├── spiece.model.1000.model
│   │       └── word_vocab.txt
│   ├── tokenizers.py
│   ├── tokenizers_test.py
│   ├── utils.py
│   ├── utils_test.py
│   ├── video_dataset.py
│   └── video_dataset_test.py
├── examples
│   ├── README.md
│   ├── generate_from_file.py
│   ├── generate_hmdb_csv.py
│   ├── hmdb.py
│   └── linear_mmv_hmdb.py
├── requirements-test.txt
├── requirements.txt
├── setup.py
└── test.sh
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement (CLA). You (or your employer) retain the copyright to your
10 | contribution; this simply gives us permission to use and redistribute your
11 | contributions as part of the project. Head over to
12 | <https://cla.developers.google.com/> to see your current agreements on file or
13 | to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code Reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
26 | ## Community Guidelines
27 |
28 | This project follows
29 | [Google's Open Source Community Guidelines](https://opensource.google/conduct/).
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DMVR: DeepMind Video Readers
2 |
3 | DMVR is a library providing a framework for easily reading raw data and
4 | producing `tf.data.Dataset` objects ready to be consumed by models.
5 |
6 | ## Design principles
7 |
8 | ### Data processing graph
9 |
10 | The main idea of the framework is to build a customizable and reusable data
11 | processing graph that, when applied to raw data files, will produce the final dataset
12 | objects. Building blocks called Builders are used to interact with the graph by
13 | adding, removing or replacing data processing blocks.
14 |
15 | Dataset providers can write a Factory with a default data processing graph for
16 | each dataset. Dataset consumers can customize the graph to their needs, either by
17 | creating a child Factory or by modifying a given instance directly. Factory objects
18 | expose instances of Builders, allowing control over each phase of the data
19 | processing graph. The Factory is then able to generate `tf.data.Dataset`
20 | objects.
21 |
22 | ### Phases
23 |
24 | The data processing graph is split into multiple phases. This abstraction is
25 | purely semantic and makes code easier to reuse. The phases (see the sketch below) are:
26 |
27 | - Parse
28 | - Sample
29 | - Decode
30 | - Preprocess
31 | - Postprocess
32 |
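Each phase is controlled through a dedicated builder exposed by the factory. The
sketch below is illustrative (it uses the `Kinetics700Factory` defined in the
Usage section below) and maps each phase to its builder:

```python
factory = Kinetics700Factory('train')

factory.parser_builder         # Parse: declare which features to read from each proto.
factory.sampler_builder        # Sample: select frames from the full sequence.
factory.decoder_builder        # Decode: turn raw bytes into tensors (e.g. JPEG decoding).
factory.preprocessor_builder   # Preprocess: per-example transforms (crop, normalize, augment).
factory.postprocessor_builder  # Postprocess: transforms applied after batching.
```
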
33 | ### Modalities
34 |
35 | In order to easily add different modalities to the dataset from the raw data,
36 | sub graphs with default processing (e.g. sample, decode and crop for images) are
37 | provided for some modalities. These sub graphs can be added by simply calling
38 | the corresponding `modalities` functions with the Builders as arguments.
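
For example, a hypothetical factory for a captioned dataset could add the default
text sub graph from its `_build`. This is a sketch: the class and tokenizer wiring
are assumptions, only `modalities.add_text` and its arguments come from the library.

```python
from dmvr import modalities
from dmvr import tokenizers
from dmvr import video_dataset

class CaptionedVideoFactory(video_dataset.BaseVideoDatasetFactory):
  """Hypothetical factory for a dataset with captions stored in the context."""

  def __init__(self, shards, tokenizer: tokenizers.TextTokenizer):
    self._tokenizer = tokenizer  # Any initialized `TextTokenizer`.
    super().__init__(shards)

  def _build(self, max_num_tokens: int = 16):
    # Default text sub graph: parse, densify, sample captions and tokenize.
    modalities.add_text(
        parser_builder=self.parser_builder,
        decoder_builder=self.decoder_builder,
        preprocessor_builder=self.preprocessor_builder,
        tokenizer=self._tokenizer,
        is_training=True,
        max_num_tokens=max_num_tokens)
```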
39 |
40 | ## Usage
41 |
42 | ### Dataset providers
43 |
44 | Dataset providers should implement a factory populating the default graph.
45 |
46 | Example:
47 |
48 | - Data is stored in TFRecords as `tf.train.SequenceExample` objects.
49 |
50 | ```python
51 | from typing import List
52 |
53 | from dmvr import modalities
54 | from dmvr import video_dataset
55 |
56 | class Kinetics700Factory(video_dataset.BaseVideoDatasetFactory):
57 |
58 | _NUM_CLASSES = 700
59 |
60 | def __init__(self, subset: str):
61 | self._is_training = subset == 'train'
62 | shards: List[str] = path_to_the_data(subset)
63 | super().__init__(shards)
64 |
65 | def _build(self,
66 | # Video related parameters.
67 | num_frames: int = 32,
68 | stride: int = 1,
69 | num_test_clips: int = 1,
70 | min_resize: int = 224,
71 | crop_size: int = 200,
72 | zero_centering_image: bool = False,
73 | # Label related parameters.
74 | one_hot_label: bool = True,
75 | add_label_name: bool = False):
76 | """Build default data processing graph."""
77 | modalities.add_image(parser_builder=self.parser_builder,
78 | sampler_builder=self.sampler_builder,
79 | decoder_builder=self.decoder_builder,
80 | preprocessor_builder=self.preprocessor_builder,
81 | postprocessor_builder=self.postprocessor_builder,
82 | is_training=self._is_training,
83 | num_frames=num_frames,
84 | stride=stride,
85 | min_resize=min_resize,
86 | crop_size=crop_size,
87 | zero_centering_image=zero_centering_image)
88 |
89 | modalities.add_label(parser_builder=self.parser_builder,
90 | decoder_builder=self.decoder_builder,
91 | preprocessor_builder=self.preprocessor_builder,
92 | one_hot_label=one_hot_label,
93 | num_classes=self._NUM_CLASSES,
94 | add_label_name=add_label_name)
95 | ```
96 |
97 | ### Dataset consumers
98 |
99 | Dataset consumers can create `tf.data.Dataset` objects from a factory instance.
100 |
101 | Example:
102 |
103 | ```python
104 | factory = Kinetics700Factory('train')
105 | factory.configure(num_frames=16)
106 | ds = factory.make_dataset(batch_size=8)
107 | ```
108 |
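Iterating over the resulting dataset yields a dictionary of features keyed by the
builder feature names. A minimal sketch (the shapes assume the default parameters
of the factory above):

```python
from dmvr import builders

for batch in ds.take(1):
  # With batch_size=8, num_frames=16 and the default crop_size=200, the frames
  # are expected to have shape [8, 16, 200, 200, 3].
  images = batch[builders.IMAGE_FEATURE_NAME]
  # With the default `one_hot_label=True`, labels are expected to have
  # shape [8, 700].
  labels = batch[builders.LABEL_INDEX_FEATURE_NAME]
```
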
109 | The user can also customize the data processing graph by adding more functions:
110 |
111 | ```python
112 | from dmvr import builders
113 | from dmvr import processors
114 |
115 | factory = Kinetics700Factory('train')
116 | factory.configure(num_frames=16)
117 |
118 | factory.preprocessor_builder.add_fn(processors.scale_jitter_augm,
119 |                                     feature_name=builders.IMAGE_FEATURE_NAME)
120 | factory.preprocessor_builder.add_fn(processors.color_default_augm,
121 |                                     feature_name=builders.IMAGE_FEATURE_NAME)
122 |
123 | ds = factory.make_dataset(batch_size=8)
124 | ```
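Default processing functions can also be replaced or removed through the names
they were registered under. A minimal sketch, assuming the default function names
registered by `modalities.add_image`:

```python
from dmvr import builders
from dmvr import processors

factory = Kinetics700Factory('train')
factory.configure(num_frames=16)

# Replace the default normalization with a zero-centering one.
factory.preprocessor_builder.replace_fn(
    f'{builders.IMAGE_FEATURE_NAME}_normalize',
    lambda x: processors.normalize_image(x, zero_centering_image=True))

# Drop the default random horizontal flip augmentation.
factory.preprocessor_builder.remove_fn(f'{builders.IMAGE_FEATURE_NAME}_random_flip')

ds = factory.make_dataset(batch_size=8)
```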
125 |
126 | ## Installation
127 |
128 | DMVR can be installed with pip directly from GitHub with the following command:
129 |
130 |     pip install git+https://github.com/deepmind/dmvr.git
131 |
132 | Python 3.9+ is required in order for all features to be available.
133 |
--------------------------------------------------------------------------------
/dmvr/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/dmvr/builders_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for builders."""
16 |
17 | from dmvr import builders
18 | from parameterized import parameterized
19 | import tensorflow as tf
20 |
21 |
22 | class SequenceExampleParserBuilderTest(tf.test.TestCase):
23 |
24 | def setUp(self):
25 | super().setUp()
26 |
27 | # Prepare SequenceExample.
28 | seq_example = tf.train.SequenceExample()
29 | seq_example.context.feature.get_or_create(
30 | 'my_context_feature').int64_list.value[:] = [0, 1]
31 | seq_example.feature_lists.feature_list.get_or_create(
32 | 'my_seq_feature').feature.add().int64_list.value[:] = [2, 3]
33 | seq_example.feature_lists.feature_list.get(
34 | 'my_seq_feature').feature.add().int64_list.value[:] = [4, 5]
35 | seq_example.feature_lists.feature_list.get_or_create(
36 | 'my_var_len_seq_feature').feature.add().int64_list.value[:] = [6]
37 | seq_example.feature_lists.feature_list.get(
38 | 'my_var_len_seq_feature').feature.add().int64_list.value[:] = [7, 8]
39 |
40 | # Put SequenceExample in expected format.
41 | self._raw_seq_example = tf.constant(seq_example.SerializeToString())
42 |
43 | def test_parse(self):
44 | parse_fn = (
45 | builders.SequenceExampleParserBuilder()
46 | .parse_feature('my_context_feature',
47 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
48 | 'context_name', True)
49 | .parse_feature('my_seq_feature',
50 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64),
51 | 'seq_name')
52 | .parse_feature('my_var_len_seq_feature',
53 | tf.io.VarLenFeature(dtype=tf.int64),
54 | 'var_len_seq_name')
55 | .build())
56 | features_dict = parse_fn(self._raw_seq_example)
57 |
58 | self.assertSetEqual(set(['context_name', 'seq_name', 'var_len_seq_name']),
59 | set(features_dict.keys()))
60 | self.assertAllEqual(features_dict['context_name'], [0, 1])
61 | self.assertAllEqual(features_dict['seq_name'], [[2, 3], [4, 5]])
62 | self.assertAllEqual(features_dict['var_len_seq_name'].values, [6, 7, 8])
63 | self.assertAllEqual(features_dict['var_len_seq_name'].indices,
64 | [[0, 0], [1, 0], [1, 1]])
65 | self.assertAllEqual(features_dict['var_len_seq_name'].dense_shape, [2, 2])
66 |
67 | def test_fake_data(self):
68 | parser = builders.SequenceExampleParserBuilder()
69 | parser.parse_feature('my_context_feature',
70 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
71 | 'context_name', True)
72 | parser.parse_feature('my_seq_feature',
73 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64),
74 | 'seq_name')
75 | parser.parse_feature('my_var_len_seq_feature',
76 | tf.io.VarLenFeature(dtype=tf.int64),
77 | 'var_len_seq_name')
78 | fake_data = parser.get_fake_data(default_values={
79 | 'context_name': (0, 1), 'var_len_seq_name': ((1, 2), (3, 4, 5))})
80 | self.assertSetEqual(set(['context_name', 'seq_name', 'var_len_seq_name']),
81 | set(fake_data.keys()))
82 | self.assertAllEqual(fake_data['context_name'], [0, 1])
83 | self.assertAllEqual(fake_data['seq_name'], [[0, 0]])
84 | self.assertAllEqual(fake_data['var_len_seq_name'].values, [1, 2, 3, 4, 5])
85 | self.assertAllEqual(fake_data['var_len_seq_name'].dense_shape, [2, 3])
86 |
87 | def test_no_output_name(self):
88 | parse_fn = (
89 | builders.SequenceExampleParserBuilder()
90 | .parse_feature('my_context_feature',
91 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
92 | is_context=True)
93 | .build())
94 | features_dict = parse_fn(self._raw_seq_example)
95 |
96 | self.assertSetEqual(set(['my_context_feature']), set(features_dict.keys()))
97 |
98 | def test_same_output_name(self):
99 | parser_builder = builders.SequenceExampleParserBuilder()
100 | parser_builder.parse_feature('my_context_feature',
101 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
102 | 'same_name', True)
103 | parser_builder.parse_feature('my_context_feature',
104 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
105 | 'other_name', True)
106 |
107 | with self.assertRaises(ValueError) as _:
108 | parser_builder.parse_feature(
109 | 'my_seq_feature', tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64),
110 | 'same_name')
111 |
112 | def test_different_types_for_same_feature(self):
113 | parser_builder = builders.SequenceExampleParserBuilder()
114 | parser_builder.parse_feature('my_context_feature',
115 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
116 | 'context_name', True)
117 |
118 | with self.assertRaises(ValueError) as _:
119 | parser_builder.parse_feature('my_context_feature',
120 | tf.io.FixedLenFeature((3,), dtype=tf.int64),
121 | 'context_name_2', True)
122 |
123 | with self.assertRaises(ValueError) as _:
124 | parser_builder.parse_feature('my_context_feature',
125 | tf.io.FixedLenFeature((2,), dtype=tf.string),
126 | 'context_name_3', True)
127 |
128 |
129 | class ExampleParserBuilderTest(tf.test.TestCase):
130 |
131 | def setUp(self):
132 | super().setUp()
133 |
134 | # Prepare Example.
135 | tf_example = tf.train.Example()
136 | tf_example.features.feature.get_or_create(
137 | 'my_fixed_len_feature').int64_list.value[:] = [0, 1]
138 | tf_example.features.feature.get_or_create(
139 | 'my_var_len_feature').int64_list.value[:] = [2, 3, 4]
140 |
141 | # Put Example in expected format.
142 | self._raw_tf_example = tf.constant(tf_example.SerializeToString())
143 |
144 | def test_parse(self):
145 | parse_fn = (
146 | builders.ExampleParserBuilder()
147 | .parse_feature('my_fixed_len_feature',
148 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
149 | 'fixed_name')
150 | .parse_feature('my_var_len_feature',
151 | tf.io.VarLenFeature(dtype=tf.int64), 'var_name')
152 | .build())
153 | features_dict = parse_fn(self._raw_tf_example)
154 |
155 | self.assertSetEqual(set(['fixed_name', 'var_name']),
156 | set(features_dict.keys()))
157 | self.assertAllEqual(features_dict['fixed_name'], [0, 1])
158 | self.assertAllEqual(features_dict['var_name'].values, [2, 3, 4])
159 | self.assertAllEqual(features_dict['var_name'].indices, [[0], [1], [2]])
160 | self.assertAllEqual(features_dict['var_name'].dense_shape, [3])
161 |
162 | def test_fake_data(self):
163 | fake_data = (
164 | builders.ExampleParserBuilder()
165 | .parse_feature('my_fixed_len_feature',
166 | tf.io.FixedLenFeature((2,), dtype=tf.string),
167 | 'fixed_name')
168 | .parse_feature('my_var_len_feature',
169 | tf.io.VarLenFeature(dtype=tf.int64), 'var_name')
170 | .get_fake_data(default_values={'fixed_name': (b'42', b'25')}))
171 | self.assertSetEqual(set(['fixed_name', 'var_name']),
172 | set(fake_data.keys()))
173 | self.assertAllEqual(fake_data['fixed_name'], [b'42', b'25'])
174 | self.assertAllEqual(fake_data['var_name'].values, [0])
175 | self.assertAllEqual(fake_data['var_name'].indices, [[0]])
176 | self.assertAllEqual(fake_data['var_name'].dense_shape, [1])
177 |
178 | def test_no_output_name(self):
179 | parse_fn = (
180 | builders.ExampleParserBuilder()
181 | .parse_feature('my_fixed_len_feature',
182 | tf.io.FixedLenFeature((2,), dtype=tf.int64))
183 | .build())
184 | features_dict = parse_fn(self._raw_tf_example)
185 |
186 | self.assertSetEqual(set(['my_fixed_len_feature']),
187 | set(features_dict.keys()))
188 |
189 | def test_same_output_name(self):
190 | parser_builder = builders.ExampleParserBuilder()
191 | parser_builder.parse_feature('my_fixed_len_feature',
192 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
193 | 'same_name')
194 | parser_builder.parse_feature('my_fixed_len_feature',
195 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
196 | 'other_name')
197 |
198 | with self.assertRaises(ValueError) as _:
199 | parser_builder.parse_feature(
200 | 'my_var_len_feature',
201 | tf.io.FixedLenSequenceFeature((2,), dtype=tf.int64), 'same_name')
202 |
203 | def test_different_types_for_same_feature(self):
204 | parser_builder = builders.SequenceExampleParserBuilder()
205 | parser_builder.parse_feature('my_fixed_len_feature',
206 | tf.io.FixedLenFeature((2,), dtype=tf.int64),
207 | 'fixed_name')
208 |
209 | with self.assertRaises(ValueError) as _:
210 | parser_builder.parse_feature('my_fixed_len_feature',
211 | tf.io.FixedLenFeature((3,), dtype=tf.int64),
212 | 'fixed_name_2')
213 |
214 | with self.assertRaises(ValueError) as _:
215 | parser_builder.parse_feature('my_fixed_len_feature',
216 | tf.io.FixedLenFeature((2,), dtype=tf.string),
217 | 'fixed_name_3')
218 |
219 |
220 | def _add_one(x):
221 | return tf.math.add(x, 1)
222 |
223 |
224 | def _subtract_one(x):
225 | return tf.math.subtract(x, 1)
226 |
227 |
228 | def _upper_text(x):
229 | return tf.strings.upper(x)
230 |
231 |
232 | def _add_text_len(features_dict):
233 | features_dict['feature_3'] = tf.strings.length(
234 | input=features_dict['feature_2'])
235 | return features_dict
236 |
237 |
238 | def _set_state(x, state):
239 | state['value'] = x
240 | return x
241 |
242 |
243 | def _use_state(features_dict, state):
244 | features_dict['feature_4'] = state['value']
245 | return features_dict
246 |
247 |
248 | class BuilderTest(tf.test.TestCase):
249 |
250 | def setUp(self):
251 | super().setUp()
252 |
253 | # Prepare features dictionary.
254 | self._input_features_dict = {
255 | 'feature_1': tf.constant(0),
256 | 'feature_2': tf.constant('text')
257 | }
258 |
259 | def test_basic(self):
260 | process_fn = (
261 | builders._Builder()
262 | .add_fn(_add_one, 'feature_1')
263 | .add_fn(_upper_text, 'feature_2')
264 | .add_fn(_add_text_len)
265 | .add_fn(_add_one, 'feature_1')
266 | .build())
267 | output_features_dict = process_fn(self._input_features_dict)
268 |
269 | self.assertSetEqual(
270 | set(['feature_1', 'feature_2', 'feature_3']),
271 | set(output_features_dict.keys()))
272 | self.assertEqual(output_features_dict['feature_1'], 2)
273 | self.assertEqual(output_features_dict['feature_2'], b'TEXT')
274 | self.assertEqual(output_features_dict['feature_3'], 4)
275 |
276 | def test_replace(self):
277 | process_fn = (
278 | builders._Builder()
279 | .add_fn(_add_one, 'feature_1', 'add_one')
280 | .add_fn(_upper_text, 'feature_2')
281 | .replace_fn('add_one', _subtract_one)
282 | .build())
283 | output_features_dict = process_fn(self._input_features_dict)
284 |
285 | self.assertSetEqual(set(['feature_1', 'feature_2']),
286 | set(output_features_dict.keys()))
287 | self.assertEqual(output_features_dict['feature_1'], -1)
288 | self.assertEqual(output_features_dict['feature_2'], b'TEXT')
289 |
290 | def test_remove(self):
291 | process_fn = (
292 | builders._Builder()
293 | .add_fn(_add_one, 'feature_1', 'add_one')
294 | .add_fn(_upper_text, 'feature_2')
295 | .remove_fn('add_one')
296 | .build())
297 | output_features_dict = process_fn(self._input_features_dict)
298 |
299 | self.assertSetEqual(set(['feature_1', 'feature_2']),
300 | set(output_features_dict.keys()))
301 | self.assertEqual(output_features_dict['feature_1'], 0)
302 | self.assertEqual(output_features_dict['feature_2'], b'TEXT')
303 |
304 | def test_reset(self):
305 | process_fn = (
306 | builders._Builder()
307 | .add_fn(_add_one, 'feature_1')
308 | .add_fn(_upper_text, 'feature_2')
309 | .reset()
310 | .build())
311 | output_features_dict = process_fn(self._input_features_dict)
312 |
313 | self.assertSetEqual(set(['feature_1', 'feature_2']),
314 | set(output_features_dict.keys()))
315 | self.assertEqual(output_features_dict['feature_1'], 0)
316 | self.assertEqual(output_features_dict['feature_2'], b'text')
317 |
318 | def test_stateful(self):
319 | process_fn = (
320 | builders._Builder()
321 | .add_fn(_set_state, 'feature_1', stateful=True)
322 | .add_fn(_use_state, stateful=True)
323 | .build())
324 | output_features_dict = process_fn(self._input_features_dict)
325 |
326 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_4']),
327 | set(output_features_dict.keys()))
328 | self.assertEqual(output_features_dict['feature_1'], 0)
329 | self.assertEqual(output_features_dict['feature_4'], 0)
330 |
331 | def test_same_fn_name(self):
332 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one')
333 |
334 | with self.assertRaises(ValueError) as _:
335 | builder.add_fn(_add_one, 'feature_1', 'add_one')
336 |
337 | def test_replace_wrong_fn_name(self):
338 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one')
339 |
340 | with self.assertRaises(ValueError) as _:
341 | builder.replace_fn('add_one_wrong', _add_one)
342 |
343 | def test_insert(self):
344 | def replace_string(_):
345 | return tf.constant('replaced_text')
346 |
347 |     builder = builders._Builder().add_fn(_add_text_len, fn_name='text_len')
348 | output_features_dict = builder.build()(self._input_features_dict)
349 |
350 | builder.add_fn(replace_string, 'feature_2', add_before_fn_name='text_len')
351 | output_features_dict_2 = builder.build()(self._input_features_dict)
352 |
353 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_3']),
354 | set(output_features_dict.keys()))
355 | self.assertEqual(output_features_dict['feature_2'], b'text')
356 | self.assertEqual(output_features_dict['feature_3'], 4)
357 |
358 | self.assertSetEqual(set(['feature_1', 'feature_2', 'feature_3']),
359 | set(output_features_dict_2.keys()))
360 | self.assertEqual(output_features_dict_2['feature_2'], b'replaced_text')
361 | self.assertEqual(output_features_dict_2['feature_3'], 13)
362 |
363 | def test_wrong_add_before_fn_name(self):
364 | builder = builders._Builder().add_fn(_add_one, 'feature_1', 'add_one')
365 |
366 | with self.assertRaises(ValueError) as _:
367 | builder.add_fn(_add_one, 'feature_1', add_before_fn_name='add_one_wrong')
368 |
369 |
370 | class FilterBuilderTest(tf.test.TestCase):
371 |
372 | def setUp(self):
373 | super().setUp()
374 |
375 | # Prepare features dictionary.
376 | self._input_features_dict = {
377 | 'feature_1': tf.constant(0),
378 | 'feature_2': tf.constant('text'),
379 | 'feature_3': tf.zeros((16, 200, 200, 3))
380 | }
381 |
382 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,),
383 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,),
384 | (builders.Phase.PREPROCESS,),
385 | (builders.Phase.POSTPROCESS,)))
386 | def test_drop(self, phase):
387 | filter_fn = (
388 | builders.FilterBuilder()
389 | .add_filter_fn(lambda fd: tf.equal(fd['feature_1'], 0), phase)
390 | .add_filter_fn(lambda fd: tf.equal(fd['feature_2'], 'no_text'), phase)
391 | .add_filter_fn(
392 | lambda fd: tf.equal(tf.shape(input=fd['feature_3'])[3], 3), phase)
393 | .build(phase))
394 | keep = filter_fn(self._input_features_dict)
395 |
396 | self.assertEqual(keep, False)
397 |
398 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,),
399 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,),
400 | (builders.Phase.PREPROCESS,),
401 | (builders.Phase.POSTPROCESS,)))
402 | def test_keep(self, phase):
403 | filter_fn = (
404 | builders.FilterBuilder()
405 | .add_filter_fn(lambda fd: tf.equal(fd['feature_1'], 0), phase)
406 | .add_filter_fn(lambda fd: tf.equal(fd['feature_2'], 'text'), phase)
407 | .add_filter_fn(
408 | lambda fd: tf.equal(tf.shape(input=fd['feature_3'])[3], 3), phase)
409 | .build(phase))
410 | keep = filter_fn(self._input_features_dict)
411 |
412 | self.assertEqual(keep, True)
413 |
414 | @parameterized.expand(((builders.Phase.READ,), (builders.Phase.PARSE,),
415 | (builders.Phase.SAMPLE,), (builders.Phase.DECODE,),
416 | (builders.Phase.PREPROCESS,),
417 | (builders.Phase.POSTPROCESS,)))
418 | def test_empty(self, phase):
419 | filter_fn = builders.FilterBuilder().build(phase)
420 | keep = filter_fn(self._input_features_dict)
421 |
422 | self.assertEqual(keep, True)
423 |
424 |
425 | if __name__ == '__main__':
426 | tf.test.main()
427 |
--------------------------------------------------------------------------------
/dmvr/modalities.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils for adding modalities."""
16 |
17 | import functools
18 | from typing import Optional
19 | from typing import Union
20 |
21 | from absl import logging
22 | from dmvr import builders
23 | from dmvr import processors
24 | from dmvr import tokenizers
25 | import tensorflow as tf
26 |
27 |
28 | # ----------------------------------------------------------------------
29 | # -------- Methods aggregating functions for a given modality. ---------
30 | # ----------------------------------------------------------------------
31 |
32 |
33 | def add_image(
34 | parser_builder: builders.BaseParserBuilder,
35 | sampler_builder: builders.SamplerBuilder,
36 | decoder_builder: builders.DecoderBuilder,
37 | preprocessor_builder: builders.PreprocessorBuilder,
38 | postprocessor_builder: builders.PostprocessorBuilder,
39 | input_feature_name: str = 'image/encoded',
40 | output_feature_name: str = builders.IMAGE_FEATURE_NAME,
41 | is_training: bool = True,
42 | # Video related parameters.
43 | num_frames: int = 32,
44 | stride: int = 1,
45 | num_test_clips: int = 1,
46 | min_resize: int = 224,
47 | resize_method: str = tf.image.ResizeMethod.BILINEAR,
48 | crop_size: int = 200,
49 | zero_centering_image: bool = False,
50 | sync_random_state: bool = True,
51 | is_rgb: Optional[bool] = True,
52 | is_flow: bool = False,
53 | random_flip: bool = True,
54 | normalization_mean: Union[tf.Tensor, float] = 0,
55 | normalization_std: Union[tf.Tensor, float] = 1,
56 | ) -> None:
57 | """Adds functions to process image feature to builders.
58 |
59 | This function expects the input to be either a `tf.train.SequenceExample` (for
60 |   videos) with the following structure:
61 | ```
62 | feature_lists {
63 | feature_list {
64 | key: input_feature_name
65 | value {
66 | feature {
67 | bytes_list {
68 | value: jpeg_bytes
69 | }
70 | }
71 | }
72 | }
73 | }
74 | ```
75 |
76 |   Or a `tf.train.Example` (for image only) with the following structure:
77 | ```
78 | features {
79 | feature {
80 | key: input_feature_name
81 | value {
82 | bytes_list {
83 | value: "JPEG"
84 | }
85 | }
86 | }
87 | }
88 | ```
89 |
90 | The corresponding `builders.ExampleParserBuilder` or
91 | `builders.SequenceExampleParserBuilder` has to be given as parameter.
92 |
93 | Args:
94 | parser_builder: An instance of a `builders.BaseParserBuilder`.
95 | sampler_builder: An instance of a `builders.SamplerBuilder`.
96 | decoder_builder: An instance of a `builders.DecoderBuilder`.
97 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`.
98 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`.
99 | input_feature_name: Name of the feature in the input `tf.train.Example` or
100 | `tf.train.SequenceExample`. Exposing this as an argument allows using this
101 | function for different image features within a single dataset.
102 | output_feature_name: Name of the feature in the output features dictionary.
103 | Exposing this as an argument allows using this function for different
104 | image features within a single dataset.
105 | is_training: Whether in training mode. If `True`, random sample, crop and
106 |       left-right flip are used.
107 | num_frames: Number of frames per subclip. For single images, use 1.
108 | stride: Temporal stride to sample frames.
109 | num_test_clips: Number of test clips (1 by default). If more than 1, this
110 | will sample multiple linearly spaced clips within each video at test time.
111 | If 1, then a single clip in the middle of the video is sampled. The clips
112 | are aggregated in the batch dimension.
113 | min_resize: Frames are resized so that `min(height, width)` is `min_resize`.
114 | resize_method: A resizing method.
115 | crop_size: Final size of the frame after cropping the resized frames. Both
116 | height and width are the same.
117 | zero_centering_image: If `True`, frames are normalized to values in [-1, 1].
118 | If `False`, values in [0, 1].
119 | sync_random_state: Whether to use stateful option to keep random operations
120 | in sync between different modalities. All modalities having this option
121 | `True` will use the same outcome in random operations such as sampling and
122 | cropping.
123 | is_rgb: If `True`, the number of channels in the JPEG is 3, if False, 1. If
124 | is_flow is `True`, `is_rgb` should be set to `None` (see below).
125 | is_flow: If `True`, the image is assumed to contain flow and will be
126 | processed as such. Note that the number of channels in the JPEG for flow
127 | is 3, but only two channels will be output corresponding to the valid
128 | horizontal and vertical displacement.
129 | random_flip: If `True`, a random horizontal flip is applied to the input
130 |       image. This augmentation should not be used if the label set contains
131 | direction related classes, such as `pointing left`, `pointing right`, etc.
132 | normalization_mean: value to subtract from the input image to normalize it.
133 |     normalization_std: value by which to divide the input image to normalize it.
134 | """
135 |
136 | # Validate parameters.
137 | if is_flow and is_rgb is not None:
138 | raise ValueError('`is_rgb` should be `None` when requesting flow.')
139 |
140 | if is_flow and not zero_centering_image:
141 | raise ValueError('Flow contains displacement values that can be negative, '
142 | 'but `zero_centering_image` was set to `False`.')
143 |
144 | if is_training and num_test_clips != 1:
145 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.',
146 | num_test_clips)
147 |
148 | # Parse frames or single image.
149 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder):
150 | parser_builder.parse_feature(
151 | feature_name=input_feature_name,
152 | feature_type=tf.io.FixedLenSequenceFeature((), dtype=tf.string),
153 | output_name=output_feature_name)
154 | elif isinstance(parser_builder, builders.ExampleParserBuilder):
155 | parser_builder.parse_feature(
156 | feature_name=input_feature_name,
157 | feature_type=tf.io.FixedLenFeature((), dtype=tf.string),
158 | output_name=output_feature_name)
159 | # Expand dimensions so single images have the same structure as videos.
160 | sampler_builder.add_fn(
161 | fn=lambda x: tf.expand_dims(x, axis=0),
162 | feature_name=output_feature_name,
163 | fn_name=f'{output_feature_name}_expand_dims')
164 | else:
165 | raise ValueError('`parser_builder` has an unexpected type.')
166 |
167 | # Temporal sampler.
168 | if is_training:
169 | # Sample random clip.
170 | sampler_builder.add_fn(
171 | # pylint: disable=g-long-lambda
172 | fn=lambda x, s=None: processors.sample_sequence(
173 | x, num_frames, True, stride, state=s),
174 | # pylint: enable=g-long-lambda
175 | feature_name=output_feature_name,
176 | fn_name=f'{output_feature_name}_random_sample',
177 | # Use state to keep coherence between modalities if requested.
178 | stateful=sync_random_state)
179 | else:
180 | if num_test_clips > 1:
181 | # Sample linspace clips.
182 | sampler_builder.add_fn(
183 | # pylint: disable=g-long-lambda
184 | fn=lambda x: processors.sample_linspace_sequence(
185 | x, num_test_clips, num_frames, stride),
186 | # pylint: enable=g-long-lambda
187 | feature_name=output_feature_name,
188 | fn_name=f'{output_feature_name}_linspace_sample')
189 | else:
190 | # Sample middle clip.
191 | sampler_builder.add_fn(
192 | fn=lambda x: processors.sample_sequence(x, num_frames, False, stride),
193 | feature_name=output_feature_name,
194 | fn_name=f'{output_feature_name}_middle_sample')
195 |
196 | # Decode JPEG string to `tf.uint8`.
197 | # Note that for flow, 3 channels are stored in the JPEG: the first two
198 |   # correspond to horizontal and vertical displacement, respectively.
199 | # The last channel contains zeros and is dropped later in the preprocessing.
200 | # Hence, the output number of channels for flow is 2.
201 | num_raw_channels = 3 if (is_rgb or is_flow) else 1
202 | decoder_builder.add_fn(
203 | fn=lambda x: processors.decode_jpeg(x, channels=num_raw_channels),
204 | feature_name=output_feature_name,
205 | fn_name=f'{output_feature_name}_decode_jpeg')
206 |
207 | if is_flow:
208 | # Cast the flow to `tf.float32`, normalizing between [-1.0, 1.0].
209 | preprocessor_builder.add_fn(
210 | fn=lambda x: processors.normalize_image(x, zero_centering_image=True),
211 | feature_name=output_feature_name,
212 | fn_name=f'{output_feature_name}_normalize')
213 |
214 | # Resize images (resize happens only if necessary to save compute).
215 | preprocessor_builder.add_fn(
216 | # pylint: disable=g-long-lambda
217 | fn=lambda x: processors.resize_smallest(
218 | x, min_resize, is_flow=is_flow, method=resize_method),
219 | # pylint: enable=g-long-lambda
220 | feature_name=output_feature_name,
221 | fn_name=f'{output_feature_name}_resize_smallest')
222 |
223 | if is_training:
224 | # Standard image data augmentation: random crop and random flip.
225 | preprocessor_builder.add_fn(
226 | # pylint: disable=g-long-lambda
227 | fn=lambda x, s=None: processors.crop_image(
228 | x, crop_size, crop_size, True, state=s),
229 | # pylint: enable=g-long-lambda
230 | feature_name=output_feature_name,
231 | fn_name=f'{output_feature_name}_random_crop',
232 | # Use state to keep coherence between modalities if requested.
233 | stateful=sync_random_state)
234 | if random_flip:
235 | preprocessor_builder.add_fn(
236 | # pylint: disable=g-long-lambda
237 | fn=lambda x, s=None: processors.random_flip_left_right(
238 | x, state=s, is_flow=is_flow),
239 | # pylint: enable=g-long-lambda
240 | feature_name=output_feature_name,
241 | fn_name=f'{output_feature_name}_random_flip',
242 | # Use state to keep coherence between modalities if requested.
243 | stateful=sync_random_state)
244 | else:
245 | # Central crop of the frames.
246 | preprocessor_builder.add_fn(
247 | fn=lambda x: processors.crop_image(x, crop_size, crop_size, False),
248 | feature_name=output_feature_name,
249 | fn_name=f'{output_feature_name}_central_crop')
250 |
251 | if is_flow:
252 | # Keep only two channels for the flow: horizontal and vertical displacement.
253 | preprocessor_builder.add_fn(
254 | fn=lambda x: x[:, :, :, :2],
255 | feature_name=output_feature_name,
256 | fn_name=f'{output_feature_name}_extract_flow_channels')
257 |
258 | # Clip the flow to stay between [-1.0 and 1.0]
259 | preprocessor_builder.add_fn(
260 | fn=lambda x: tf.clip_by_value(x, -1.0, 1.0),
261 | feature_name=output_feature_name,
262 | fn_name=f'{output_feature_name}_clip_flow')
263 | else:
264 | # Cast the frames to `tf.float32`, normalizing according to
265 | # `zero_centering_image`.
266 | preprocessor_builder.add_fn(
267 | fn=lambda x: processors.normalize_image(x, zero_centering_image),
268 | feature_name=output_feature_name,
269 | fn_name=f'{output_feature_name}_normalize')
270 |
271 | preprocessor_builder.add_fn(
272 | fn=lambda x: x - normalization_mean,
273 | feature_name=output_feature_name,
274 | fn_name=f'{output_feature_name}_subtract_given_mean')
275 |
276 | preprocessor_builder.add_fn(
277 | fn=lambda x: x / normalization_std,
278 | feature_name=output_feature_name,
279 | fn_name=f'{output_feature_name}_divide_by_given_std')
280 |
281 | if num_test_clips > 1 and not is_training:
282 | # In this case, multiple clips are merged together in batch dimension which
283 | # will be `B * num_test_clips`.
284 | postprocessor_builder.add_fn(
285 | fn=lambda x: tf.reshape( # pylint: disable=g-long-lambda
286 | x, (-1, num_frames, x.shape[2], x.shape[3], x.shape[4])),
287 | feature_name=output_feature_name,
288 | fn_name=f'{output_feature_name}_reshape')
289 |
290 |
291 | def add_label(
292 | parser_builder: builders.BaseParserBuilder,
293 | decoder_builder: builders.DecoderBuilder,
294 | preprocessor_builder: builders.PreprocessorBuilder,
295 | input_label_index_feature_name: str = 'clip/label/index',
296 | output_label_index_feature_name: str = builders.LABEL_INDEX_FEATURE_NAME,
297 | input_label_name_feature_name: Optional[str] = 'clip/label/text',
298 | output_label_name_feature_name: Optional[str] = builders
299 | .LABEL_NAME_FEATURE_NAME,
300 | # Label related parameters.
301 | is_multi_label: bool = False,
302 | one_hot_label: bool = True,
303 | num_classes: Optional[int] = None,
304 | add_label_name: bool = False):
305 | """Adds functions to process label feature to builders.
306 |
307 | This function expects the input to be either a `tf.train.SequenceExample`
308 | (with the features in the context) or a `tf.train.Example`. The expected
309 | structure is (or equivalent for `tf.train.Example`):
310 | ```
311 | context {
312 | feature {
313 | key: input_label_index_feature_name
314 | value {
315 | int64_list {
316 | value: 42
317 | ...
318 | }
319 | }
320 | }
321 | feature {
322 | key: input_label_name_feature_name
323 | value {
324 | bytes_list {
325 | value: "label_42"
326 | ...
327 | }
328 | }
329 | }
330 | }
331 | ```
332 |
333 | The corresponding `builders.ExampleParserBuilder` or
334 | `builders.SequenceExampleParserBuilder` has to be given as parameter.
335 |
336 | Args:
337 | parser_builder: An instance of a `builders.BaseParserBuilder`.
338 | decoder_builder: An instance of a `builders.DecoderBuilder`.
339 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`.
340 | input_label_index_feature_name: Name of the label index feature in the input
341 | `tf.train.Example` or `tf.train.SequenceExample`. Exposing this as an
342 | argument allows using this function for different label features within a
343 | single dataset.
344 | output_label_index_feature_name: Name of the label index feature in the
345 | output features dictionary. Exposing this as an argument allows using this
346 | function for different label features within a single dataset.
347 | input_label_name_feature_name: Name of the label name feature in the input
348 | `tf.train.Example` or `tf.train.SequenceExample`. If `add_label_name` is
349 | false, this option is ignored. Exposing this as an argument allows using
350 | this function for different label features within a single dataset.
351 | output_label_name_feature_name: Name of the label name feature in the output
352 | features dictionary. If `add_label_name` is false, this option is ignored.
353 | Exposing this as an argument allows using this function for different
354 | label features within a single dataset.
355 | is_multi_label: Whether raw data contains multiple labels per example.
356 | one_hot_label: Return labels as one hot tensors. If `is_multi_label` is
357 | `True`, one hot tensor might have multiple ones.
358 | num_classes: Total number of classes in the dataset. It has to be provided
359 | if `one_hot_label` is `True`.
360 | add_label_name: Also return the name of the label. Not yet supported for
361 | multi label.
362 | """
363 | # Validate parameters.
364 | if one_hot_label and not num_classes:
365 | raise ValueError(
366 | '`num_classes` must be given when requesting one hot label.')
367 | if is_multi_label and not one_hot_label:
368 | logging.warning(
369 | 'Multi label indices will be returned in a non fixed size dimension.')
370 | if add_label_name and (input_label_name_feature_name is None or
371 | output_label_name_feature_name is None):
372 | raise ValueError(
373 | '`input_label_name_feature_name` and `output_label_name_feature_name` '
374 | 'must be given when `add_label_name` is true.')
375 |
376 | # Parse label.
377 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder):
378 | parser_builder.parse_feature(
379 | feature_name=input_label_index_feature_name,
380 | feature_type=tf.io.VarLenFeature(dtype=tf.int64),
381 | output_name=output_label_index_feature_name,
382 | is_context=True)
383 | if add_label_name:
384 | parser_builder.parse_feature(
385 | feature_name=input_label_name_feature_name,
386 | feature_type=tf.io.VarLenFeature(dtype=tf.string),
387 | output_name=output_label_name_feature_name,
388 | is_context=True)
389 | elif isinstance(parser_builder, builders.ExampleParserBuilder):
390 | parser_builder.parse_feature(
391 | feature_name=input_label_index_feature_name,
392 | feature_type=tf.io.VarLenFeature(dtype=tf.int64),
393 | output_name=output_label_index_feature_name)
394 | if add_label_name:
395 | parser_builder.parse_feature(
396 | feature_name=input_label_name_feature_name,
397 | feature_type=tf.io.VarLenFeature(dtype=tf.string),
398 | output_name=output_label_name_feature_name)
399 | else:
400 | raise ValueError('`parser_builder` has an unexpected type.')
401 |
402 | # Densify labels tensor in order to support multi label case.
403 | decoder_builder.add_fn(
404 | fn=tf.sparse.to_dense,
405 | feature_name=output_label_index_feature_name,
406 | fn_name=f'{output_label_index_feature_name}_sparse_to_dense')
407 | if add_label_name:
408 | decoder_builder.add_fn(
409 | fn=tf.sparse.to_dense,
410 | feature_name=output_label_name_feature_name,
411 | fn_name=f'{output_label_name_feature_name}_sparse_to_dense')
412 |
413 | if one_hot_label:
414 | # Replace label index by one hot representation.
415 | preprocessor_builder.add_fn(
416 | fn=lambda x: tf.reduce_sum( # pylint: disable=g-long-lambda
417 | input_tensor=tf.one_hot(x, num_classes),
418 | axis=0),
419 | feature_name=output_label_index_feature_name,
420 | fn_name=f'{output_label_index_feature_name}_one_hot')
421 | elif not is_multi_label:
422 | preprocessor_builder.add_fn(
423 | fn=lambda x: processors.set_shape(x, (1,)),
424 | feature_name=output_label_index_feature_name,
425 | fn_name=f'{output_label_index_feature_name}_set_shape')
426 |
427 | if add_label_name and not is_multi_label:
428 | preprocessor_builder.add_fn(
429 | fn=lambda x: processors.set_shape(x, (1,)),
430 | feature_name=output_label_name_feature_name,
431 | fn_name=f'{output_label_name_feature_name}_set_shape')
432 |
433 |
434 | def add_text(
435 | parser_builder: builders.BaseParserBuilder,
436 | decoder_builder: builders.DecoderBuilder,
437 | preprocessor_builder: builders.PreprocessorBuilder,
438 | tokenizer: tokenizers.TextTokenizer,
439 | is_training: bool = True,
440 | input_feature_name: str = 'caption/string',
441 | output_raw_string_name: str = builders.TEXT_FEATURE_NAME,
442 | output_feature_name: str = builders.TEXT_INDICES_FEATURE_NAME,
443 | # Text related parameters.
444 | prepend_bos: bool = False,
445 | append_eos: bool = False,
446 | keep_raw_string: bool = False,
447 | max_num_captions: int = 1,
448 | max_num_tokens: Optional[int] = 16,
449 | sync_random_state: bool = False):
450 | """Adds functions to process text feature to builders.
451 |
452 | This function expects the input to be either a `tf.train.SequenceExample`
453 | (with the features in the context) or a `tf.train.Example`. The expected
454 | structure is (or equivalent for `tf.train.Example`):
455 | ```
456 | context {
457 | feature {
458 | key: input_feature_name
459 | value {
460 | bytes_list {
461 | value: "Hello world!"
462 | value: "This is a caption."
463 | ...
464 | }
465 | }
466 | }
467 | }
468 | ```
469 |
470 | The corresponding `builders.ExampleParserBuilder` or
471 | `builders.SequenceExampleParserBuilder` has to be given as parameter.
472 |
473 | Args:
474 | parser_builder: An instance of a `builders.BaseParserBuilder`.
475 | decoder_builder: An instance of a `builders.DecoderBuilder`.
476 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`.
477 | tokenizer: An instance of a tokenizer.
478 | is_training: Whether in training mode. This will be used to randomly sample
479 | the captions.
480 | input_feature_name: Name of the feature in the input `tf.train.Example` or
481 | `tf.train.SequenceExample`. Exposing this as an argument allows using this
482 | function for different text features within a single dataset.
483 | output_raw_string_name: Name of the raw string in the output features
484 | dictionary. Exposing this as an argument allows using this function for
485 | different text features within a single dataset.
486 | output_feature_name: Name of the feature in the output features dictionary.
487 | Exposing this as an argument allows using this function for different text
488 | features.
489 | prepend_bos: Whether to prepend BOS token.
490 | append_eos: Whether to append EOS token.
491 | keep_raw_string: Whether to keep raw string.
492 | max_num_captions: Maximum number of captions to keep. If there are more
493 | captions in the proto, only the first `max_num_captions` will be returned
494 | is `is_training` is set to `False`. If `is_training` is `True`, then
495 |       if `is_training` is set to `False`. If `is_training` is `True`, then
496 | contains less than `max_num_captions`, we pad with empty strings to make
497 | sure there are `max_num_captions` in total.
498 | max_num_tokens: Maximum number of tokens to keep from the text for each
499 |       caption. If there are more tokens, the sequence is cropped; if fewer, the
500 |       caption is padded using the tokenizer pad id. The sequence is unmodified
501 |       if `max_num_tokens` is None.
502 | sync_random_state: Whether to use stateful option to keep random operations
503 | in sync between different modalities. All modalities having this option
504 | `True` will use the same outcome in random operations used for sampling
505 | the captions.
506 | """
507 | # Parse text indices.
508 | if isinstance(parser_builder, builders.SequenceExampleParserBuilder):
509 | parser_builder.parse_feature(
510 | feature_name=input_feature_name,
511 | feature_type=tf.io.VarLenFeature(dtype=tf.string),
512 | output_name=output_raw_string_name,
513 | is_context=True)
514 | elif isinstance(parser_builder, builders.ExampleParserBuilder):
515 | parser_builder.parse_feature(
516 | feature_name=input_feature_name,
517 | feature_type=tf.io.VarLenFeature(dtype=tf.string),
518 | output_name=output_raw_string_name)
519 |
520 | # Densify text tensor.
521 | decoder_builder.add_fn(
522 | fn=tf.sparse.to_dense,
523 | feature_name=output_raw_string_name,
524 | fn_name=f'{output_feature_name}_sparse_to_dense')
525 |
526 | preprocessor_builder.add_fn(
527 | # pylint: disable=g-long-lambda
528 | lambda x, s=None: processors.sample_or_pad_non_sorted_sequence(
529 | x, max_num_captions, b'', random=is_training, state=s),
530 | # pylint: enable=g-long-lambda
531 | feature_name=output_raw_string_name,
532 | fn_name=f'{output_feature_name}_sample_captions',
533 | # Use state to keep coherence between modalities if requested.
534 | stateful=sync_random_state)
535 |
536 | # Tokenize the sentence.
537 | preprocessor_builder.add_fn(
538 | fn=lambda x: processors.tokenize( # pylint: disable=g-long-lambda
539 | x, tokenizer, output_raw_string_name, output_feature_name,
540 | prepend_bos, append_eos, max_num_tokens, keep_raw_string),
541 | fn_name=f'{output_feature_name}_tokenization')
542 |
543 | if max_num_tokens is not None:
544 | # Set text shape.
545 | shape = (max_num_captions, max_num_tokens)
546 | preprocessor_builder.add_fn(
547 | fn=lambda x: processors.set_shape(x, shape),
548 | feature_name=output_feature_name,
549 | fn_name=f'{output_feature_name}_set_shape')
550 |
551 |
552 | def add_audio(
553 | parser_builder: builders.BaseParserBuilder,
554 | sampler_builder: builders.SamplerBuilder,
555 | postprocessor_builder: builders.PostprocessorBuilder,
556 | preprocessor_builder: Optional[builders.PreprocessorBuilder] = None,
557 | input_feature_name: str = 'WAVEFORM/feature/floats',
558 | output_feature_name: str = builders.AUDIO_FEATURE_NAME,
559 | is_training: bool = True,
560 | # Audio related parameters.
561 | num_samples: int = 30720,
562 | stride: int = 1,
563 | sample_rate: Optional[int] = 48000,
564 | target_sample_rate: Optional[int] = None,
565 | num_test_clips: int = 1,
566 | sync_random_state: bool = True):
567 | """Adds functions to process audio feature to builders.
568 |
569 | This function expects the input to be either a `tf.train.SequenceExample` (for
570 |   videos) with the following structure:
571 | ```
572 | feature_lists {
573 | feature_list {
574 | key: input_feature_name
575 | value {
576 | feature {
577 | float_list {
578 | value: 0.0
579 | value: 0.1
580 | value: 0.2
581 | ...
582 | }
583 | }
584 | }
585 | }
586 | }
587 | ```
588 |
589 |   Or a `tf.train.Example` with the following structure:
590 | ```
591 | features {
592 | feature {
593 | key: input_feature_name
594 | value {
595 | float_list {
596 | value: 0.0
597 | value: 0.1
598 | value: 0.2
599 | ...
600 | }
601 | }
602 | }
603 | }
604 | ```
605 |
606 | The corresponding `builders.ExampleParserBuilder` or
607 | `builders.SequenceExampleParserBuilder` has to be given as parameter.
608 |
609 | Args:
610 | parser_builder: An instance of a `builders.BaseParserBuilder`.
611 | sampler_builder: An instance of a `builders.SamplerBuilder`.
612 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`.
613 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`.
614 | input_feature_name: Name of the feature in the input `tf.train.Example` or
615 | `tf.train.SequenceExample`. Exposing this as an argument allows using this
616 | function for different audio features within a single dataset.
617 | output_feature_name: Name of the feature in the output features dictionary.
618 | Exposing this as an argument allows using this function for different
619 |       audio features within a single dataset.
620 | is_training: Whether in training mode. If `True`, random sample is used.
621 | num_samples: Number of samples per subclip.
622 | stride: Temporal stride to sample audio signal.
623 | sample_rate: The original sample rate of the input audio stored in sstables.
624 |     target_sample_rate: If this is not None, the waveforms will be resampled
625 |       to this sample rate. Resampling is done via Fast Fourier Transforms.
626 | num_test_clips: Number of test clips (1 by default). If more than 1, this
627 | will sample multiple linearly spaced clips within each audio at test time.
628 | If 1, then a single clip in the middle of the audio is sampled. The clips
629 | are aggregated in the batch dimension.
630 | sync_random_state: Whether to use stateful option to keep random operations
631 | in sync between different modalities. All modalities having this option
632 | `True` will use the same outcome in random operations such as sampling and
633 | cropping.
634 | """
635 | # Validate parameters.
636 | if is_training and num_test_clips != 1:
637 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.',
638 | num_test_clips)
639 |
640 | # Keep audio signal.
641 | parser_builder.parse_feature(
642 | feature_name=input_feature_name,
643 | # Entire signal stored in one Feature.
644 | feature_type=tf.io.VarLenFeature(dtype=tf.float32),
645 | output_name=output_feature_name)
646 |
647 | # Densify.
648 | sampler_builder.add_fn(
649 | fn=lambda x: tf.sparse.to_dense(x)[0],
650 | feature_name=output_feature_name,
651 | fn_name=f'{output_feature_name}_sparse_to_dense')
652 |
653 | # Temporal sampler.
654 | if is_training:
655 | # Sample random clip.
656 | sampler_builder.add_fn(
657 | # pylint: disable=g-long-lambda
658 | fn=lambda x, s=None: processors.sample_sequence(
659 | x, num_samples, True, stride, state=s),
660 | # pylint: enable=g-long-lambda
661 | feature_name=output_feature_name,
662 | fn_name=f'{output_feature_name}_random_sample',
663 | # Use state to keep coherence between modalities if requested.
664 | stateful=sync_random_state)
665 | else:
666 | if num_test_clips > 1:
667 | # Sample linspace clips.
668 | sampler_builder.add_fn(
669 | # pylint: disable=g-long-lambda
670 | fn=lambda x: processors.sample_linspace_sequence(
671 | x, num_test_clips, num_samples, stride),
672 | # pylint: enable=g-long-lambda
673 | feature_name=output_feature_name,
674 | fn_name=f'{output_feature_name}_linspace_sample')
675 | else:
676 | # Sample middle clip.
677 | sampler_builder.add_fn(
678 | # pylint: disable=g-long-lambda
679 | fn=lambda x: processors.sample_sequence(
680 | x, num_samples, False, stride),
681 | # pylint: enable=g-long-lambda
682 | feature_name=output_feature_name,
683 | fn_name=f'{output_feature_name}_middle_sample')
684 |
685 | # Apply FFTs to change the sample rate of the waveforms.
686 | if preprocessor_builder is not None and target_sample_rate is not None:
687 | preprocessor_builder.add_fn(
688 | functools.partial(
689 | processors.resample_audio,
690 | num_subclips=num_test_clips,
691 | in_sample_rate=sample_rate,
692 | out_sample_rate=target_sample_rate,
693 | is_training=is_training),
694 | feature_name=builders.AUDIO_FEATURE_NAME)
695 |
696 | if num_test_clips > 1 and not is_training:
697 |     # In this case, multiple clips are merged together in the batch dimension,
698 |     # which will be `B * num_test_clips`.
699 | postprocessor_builder.add_fn(
700 | fn=lambda x: tf.reshape(x, (-1, x.shape[-1])),
701 | feature_name=output_feature_name,
702 | fn_name=f'{output_feature_name}_reshape')
703 |
704 |
705 | def add_spectrogram(
706 | preprocessor_builder: builders.PreprocessorBuilder,
707 | postprocessor_builder: builders.PostprocessorBuilder,
708 | input_feature_name: str = builders.AUDIO_FEATURE_NAME,
709 | output_feature_name: str = builders.AUDIO_MEL_FEATURE_NAME,
710 | is_training: bool = True,
711 | sample_rate: int = 48000,
712 | spectrogram_type: str = 'logmf',
713 | frame_length: int = 2048,
714 | frame_step: int = 1024,
715 | num_features: int = 80,
716 | lower_edge_hertz: float = 80.0,
717 | upper_edge_hertz: float = 7600.0,
718 | preemphasis: Optional[float] = None,
719 | normalize_audio: bool = False,
720 | num_test_clips: int = 1):
721 | """Adds functions to process audio spectrogram feature to builders.
722 |
723 |   Note that this function does not extract and parse the audio feature. Instead,
724 |   it should be used after an `add_audio` call. The output spectrogram is of
725 |   shape [batch_size, num_frames, num_features].
726 |
727 | Args:
728 | preprocessor_builder: An instance of a `builders.PreprocessorBuilder`.
729 | postprocessor_builder: An instance of a `builders.PostprocessorBuilder`.
730 | input_feature_name: Name of the feature in the input features dictionary.
731 | Exposing this as an argument allows using this function for different
732 | audio features.
733 | output_feature_name: Name of the feature in the output features dictionary.
734 | Exposing this as an argument allows using this function for different
735 | audio features.
736 | is_training: If the current mode is training or not.
737 | sample_rate: The sample rate of the input audio.
738 | spectrogram_type: The type of the spectrogram to be extracted from the
739 |       waveform. Can be one of `spectrogram`, `logmf` or `mfcc`.
740 | frame_length: The length of each spectrogram frame.
741 | frame_step: The stride of spectrogram frames.
742 | num_features: The number of spectrogram features.
743 | lower_edge_hertz: Lowest frequency to consider.
744 | upper_edge_hertz: Highest frequency to consider.
745 | preemphasis: The strength of pre-emphasis on the waveform. If None, no
746 | pre-emphasis will be applied.
747 | normalize_audio: Whether to normalize the waveform or not.
748 | num_test_clips: Number of test clips (1 by default). If more than 1, this
749 | will sample multiple linearly spaced clips within each audio at test time.
750 | If 1, then a single clip in the middle of the audio is sampled. The clips
751 | are aggregated in the batch dimension.
752 | """
753 | # Validate parameters.
754 | if is_training and num_test_clips != 1:
755 | logging.info('`num_test_clips` %d is ignored since `is_training` is true.',
756 | num_test_clips)
757 |
758 | # Extract audio spectrograms.
759 | preprocessor_builder.add_fn(
760 | functools.partial(
761 | processors.compute_audio_spectrogram,
762 | num_subclips=num_test_clips,
763 | sample_rate=sample_rate,
764 | spectrogram_type=spectrogram_type,
765 | frame_length=frame_length,
766 | frame_step=frame_step,
767 | num_features=num_features,
768 | lower_edge_hertz=lower_edge_hertz,
769 | upper_edge_hertz=upper_edge_hertz,
770 | normalize=normalize_audio,
771 | preemphasis=preemphasis,
772 | audio_feature_name=input_feature_name,
773 | spectrogram_feature_name=output_feature_name))
774 |
775 | if num_test_clips > 1 and not is_training:
776 |     # In this case, multiple clips are merged together in the batch dimension,
777 |     # which will be `B * num_test_clips`.
778 | postprocessor_builder.add_fn(
779 | fn=lambda x: tf.reshape(x, (-1, x.shape[-2], x.shape[-1])),
780 | feature_name=output_feature_name,
781 | fn_name=f'{output_feature_name}_reshape')
782 |
--------------------------------------------------------------------------------
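
The `add_text` flow above (parse, densify, sample captions, tokenize, set shape) mirrors how the tests in `modalities_test.py` drive the builders. Below is a minimal sketch, not part of the library, of wiring it up end to end; the vocabulary path and the stub caption are placeholders.

```python
import tensorflow as tf

from dmvr import builders
from dmvr import modalities
from dmvr import tokenizers

# Instantiate the builders that add_text plugs into.
parser_builder = builders.SequenceExampleParserBuilder()
decoder_builder = builders.DecoderBuilder()
preprocessor_builder = builders.PreprocessorBuilder()

# Hypothetical vocabulary path; any file in the word-vocab format works.
tokenizer = tokenizers.WordTokenizer('/path/to/word_vocab.txt')
tokenizer.initialize()
modalities.add_text(
    parser_builder, decoder_builder, preprocessor_builder,
    tokenizer=tokenizer, max_num_captions=1, max_num_tokens=16)

# Stub input: a SequenceExample with a single caption in its context.
seq_example = tf.train.SequenceExample()
seq_example.context.feature.get_or_create(
    'caption/string').bytes_list.value[:] = [b'hello world']

# Same per-example chain the tests use: parse -> decode -> preprocess.
features = parser_builder.build()(seq_example.SerializeToString())
features = decoder_builder.build()(features)
features = preprocessor_builder.build()(features)
indices = features[builders.TEXT_INDICES_FEATURE_NAME]  # Shape (1, 16).
```
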
/dmvr/modalities_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for modalities."""
16 |
17 | import os
18 |
19 | from dmvr import builders
20 | from dmvr import modalities
21 | from dmvr import tokenizers
22 | import numpy as np
23 | from parameterized import parameterized
24 | import tensorflow as tf
25 |
26 | # Removed: Internal pyglib dependencies
27 |
28 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata')
29 | _SAMPLE_IMAGE_PATH = os.path.join(_TESTDATA_DIR, 'sample.jpeg')
30 | _VOCAB_PATH = os.path.join(_TESTDATA_DIR, 'tokenizers', 'word_vocab.txt')
31 |
32 |
33 | class ModalitiesTest(tf.test.TestCase):
34 |
35 | def setUp(self):
36 | super().setUp()
37 | seq_example = tf.train.SequenceExample()
38 |
39 | # Create stub frames and inject them in the SequenceExample.
40 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image_bytes = f.read()
41 | for _ in range(10 * 5):
42 | seq_example.feature_lists.feature_list.get_or_create(
43 | 'image/encoded').feature.add().bytes_list.value[:] = [raw_image_bytes]
44 |
45 | # Create stub flow and inject it in the SequenceExample.
46 | for _ in range(10 * 5):
47 | seq_example.feature_lists.feature_list.get_or_create(
48 | 'flow/encoded').feature.add().bytes_list.value[:] = [raw_image_bytes]
49 |
50 | # Create stub label and inject it in the SequenceExample.
51 | raw_label_index = 42
52 | raw_label_name = b'label'
53 | seq_example.context.feature.get_or_create(
54 | 'clip/label/index').int64_list.value[:] = [raw_label_index]
55 | seq_example.context.feature.get_or_create(
56 | 'clip/label/string').bytes_list.value[:] = [raw_label_name]
57 |
58 | # Create stub raw text and inject it in SequenceExample.
59 | raw_text = b'hello world'
60 | seq_example.context.feature.get_or_create(
61 | 'caption/string').bytes_list.value[:] = [raw_text, raw_text]
62 |
63 | # Create stub audio and inject it in SequenceExample.
64 | raw_audio = np.linspace(-1, 1, 48000 * 5)
65 | seq_example.feature_lists.feature_list.get_or_create(
66 | 'WAVEFORM/feature/floats').feature.add().float_list.value[:] = raw_audio
67 |
68 | serialized_seq_example = seq_example.SerializeToString()
69 | self._seq_examples = [serialized_seq_example] * 8 # Batch size is 8.
70 |
71 | # Create builders.
72 | self._parser_builder = builders.SequenceExampleParserBuilder()
73 | self._sampler_builder = builders.SamplerBuilder()
74 | self._decoder_builder = builders.DecoderBuilder()
75 | self._preprocessor_builder = builders.PreprocessorBuilder()
76 | self._postprocessor_builder = builders.PostprocessorBuilder()
77 |
78 | def _process_examples(self):
79 | """Process input examples simulating dataset object creation."""
80 | def pre_batch_process(raw_seq_example):
81 | output = self._parser_builder.build()(raw_seq_example)
82 | output = self._sampler_builder.build()(output)
83 | output = self._decoder_builder.build()(output)
84 | output = self._preprocessor_builder.build()(output)
85 | return output
86 |
87 | # Batch and postprocess.
88 | output = [pre_batch_process(rse) for rse in self._seq_examples]
89 | batched_output = {}
90 | for k in output[0].keys():
91 | batched_output[k] = tf.stack([out[k] for out in output])
92 | output = batched_output
93 | output = self._postprocessor_builder.build()(output)
94 |
95 | return output
96 |
97 | @parameterized.expand((
98 | (True, 1, False, True, ['image_random_sample'], [
99 | 'image_resize_smallest', 'image_random_crop', 'image_random_flip',
100 | 'image_normalize'
101 | ], []),
102 | (True, 1, False, False, ['image_random_sample'],
103 | ['image_resize_smallest', 'image_random_crop', 'image_normalize'], []),
104 | (False, 1, False, True, ['image_middle_sample'],
105 | ['image_resize_smallest', 'image_central_crop', 'image_normalize'], []),
106 | (False, 2, False, True, ['image_linspace_sample'],
107 | ['image_resize_smallest', 'image_central_crop',
108 | 'image_normalize'], ['image_reshape']),
109 | (True, 1, True, True, ['image_random_sample'], [
110 | 'image_normalize', 'image_resize_smallest', 'image_random_crop',
111 | 'image_random_flip', 'image_extract_flow_channels', 'image_clip_flow'
112 | ], []),
113 | ))
114 | def test_add_image(self, is_training, num_test_clips, is_flow, random_flip,
115 | sample_ops, preprocess_ops, postprocess_ops):
116 | is_rgb = None if is_flow else True
117 | zero_centering_image = is_flow
118 | modalities.add_image(
119 | self._parser_builder, # `parser_builder`
120 | self._sampler_builder, # `sampler_builder`
121 | self._decoder_builder, # `decoder_builder`
122 | self._preprocessor_builder, # `preprocessor_builder`
123 | self._postprocessor_builder, # `postprocessor_builder`
124 | 'image/encoded', # `input_feature_name`
125 | 'image', # `output_feature_name`
126 | is_training, # `is_training`
127 | 32, # `num_frames`
128 | 1, # `stride`
129 | num_test_clips, # `num_test_clips`
130 | 224, # `min_resize`
131 | 200, # `crop_size`
132 | zero_centering_image, # `zero_centering_image`
133 | True, # `sync_random_state`
134 | is_rgb, # `is_rgb`
135 | is_flow, # `is_flow`
136 | random_flip) # `random_flip`
137 | output = self._process_examples()
138 |
139 | self.assertAllEqual(
140 | [fd.fn_name for fd in self._sampler_builder.get_summary()], sample_ops)
141 | self.assertAllEqual(
142 | [fd.fn_name for fd in self._decoder_builder.get_summary()],
143 | ['image_decode_jpeg'])
144 | self.assertAllEqual(
145 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()],
146 | preprocess_ops)
147 | self.assertAllEqual(
148 | [fd.fn_name for fd in self._postprocessor_builder.get_summary()],
149 | postprocess_ops)
150 |
151 | # Assert static shape.
152 | self.assertNotIn(None, output['image'].shape.as_list())
153 | self.assertSetEqual(set(output.keys()), set(['image']))
154 | num_output_channels = 2 if is_flow else 3
155 | self.assertAllEqual(output['image'].shape,
156 | (8 * num_test_clips, 32, 200, 200, num_output_channels))
157 |
158 | @parameterized.expand(((False, False), (False, True), (True, True)))
159 | def test_add_label(self, one_hot_label, add_label_name):
160 | modalities.add_label(
161 | self._parser_builder, # `parser_builder`
162 | self._decoder_builder, # `decoder_builder`
163 | self._preprocessor_builder, # `preprocessor_builder`
164 | 'clip/label/index', # `input_label_index_feature_name`
165 | 'label', # `output_label_index_feature_name`
166 | 'clip/label/string', # `input_label_name_feature_name`
167 | 'label_name', # `output_label_name_feature_name`
168 | False, # `is_multi_label`
169 | one_hot_label, # `one_hot_label`
170 | 50, # `num_classes`
171 | add_label_name) # `add_label_name`
172 | output = self._process_examples()
173 |
174 | decoder_ops = ['label_sparse_to_dense']
175 | if add_label_name:
176 | decoder_ops.append('label_name_sparse_to_dense')
177 | self.assertAllEqual(
178 | [fd.fn_name for fd in self._decoder_builder.get_summary()],
179 | decoder_ops)
180 | if one_hot_label:
181 | preprocess_ops = ['label_one_hot']
182 | else:
183 | preprocess_ops = ['label_set_shape']
184 | if add_label_name:
185 | preprocess_ops.append('label_name_set_shape')
186 | self.assertAllEqual(
187 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()],
188 | preprocess_ops)
189 |
190 | # Assert static shape.
191 | self.assertNotIn(None, output['label'].shape.as_list())
192 |
193 | keys = set(['label'])
194 | if add_label_name:
195 | keys.add('label_name')
196 | self.assertSetEqual(set(output.keys()), keys)
197 | if one_hot_label:
198 | self.assertAllEqual(output['label'], [[0] * 42 + [1] + [0] * 7] * 8)
199 | else:
200 | self.assertAllEqual(output['label'], [[42]] * 8)
201 | if add_label_name:
202 | self.assertAllEqual(output['label_name'], [[b'label']] * 8)
203 |
204 | @parameterized.expand(((16,), (1,)))
205 | def test_add_text(self, max_num_words):
206 | tokenizer_model = tokenizers.WordTokenizer(
207 | _VOCAB_PATH) # OSS: removed internal filename loading.
208 | tokenizer_model.initialize()
209 |
210 | modalities.add_text(
211 | self._parser_builder, # `parser_builder`
212 | self._decoder_builder, # `decoder_builder`
213 | self._preprocessor_builder, # `preprocessor_builder`
214 | tokenizer_model, # `tokenizer`
215 | True, # `is_training`
216 | 'caption/string', # `input_feature_name`
217 |         builders.TEXT_FEATURE_NAME,  # `output_raw_string_name`
218 | builders.TEXT_INDICES_FEATURE_NAME, # `output_feature_name`
219 | False, # `prepend_bos`
220 | False, # `append_eos`
221 | True, # `keep_raw_string`
222 | 2, # `max_num_captions`
223 |         max_num_words,  # `max_num_tokens`
224 | True) # `sync_random_state`
225 |
226 | output = self._process_examples()
227 | self.assertAllEqual(
228 | [fd.fn_name for fd in self._decoder_builder.get_summary()],
229 | ['text_indices_sparse_to_dense'])
230 | self.assertAllEqual(
231 | [fd.fn_name for fd in self._preprocessor_builder.get_summary()],
232 | ['text_indices_sample_captions', 'text_indices_tokenization',
233 | 'text_indices_set_shape'])
234 |
235 | # Assert static shape.
236 | self.assertNotIn(
237 | None, output[builders.TEXT_INDICES_FEATURE_NAME].shape.as_list())
238 | self.assertSetEqual(set(output.keys()),
239 | set([builders.TEXT_INDICES_FEATURE_NAME,
240 | builders.TEXT_FEATURE_NAME]))
241 | words = [4, 5][:min(2, max_num_words)]
242 | padding = [0] * max(0, max_num_words - 2)
243 | self.assertAllEqual(
244 | output[builders.TEXT_INDICES_FEATURE_NAME],
245 | [[words + padding, words + padding]] * 8)
246 |
247 | @parameterized.expand((
248 | (True, 1, ['audio_sparse_to_dense', 'audio_random_sample'], []),
249 | (False, 1, ['audio_sparse_to_dense', 'audio_middle_sample'], []),
250 | (False, 2, ['audio_sparse_to_dense', 'audio_linspace_sample'],
251 | ['audio_reshape'])))
252 | def test_add_audio(self, is_training, num_test_clips, sample_ops,
253 | postprocess_ops):
254 | modalities.add_audio(
255 | self._parser_builder, # `parser_builder`
256 | self._sampler_builder, # `sampler_builder`
257 | self._postprocessor_builder, # `postprocessor_builder`
258 | 'WAVEFORM/feature/floats', # `input_feature_name`
259 | builders.AUDIO_FEATURE_NAME, # `output_feature_name`
260 | is_training, # `is_training`
261 | 30720, # `num_samples`
262 | 1, # `stride`
263 | num_test_clips) # `num_test_clips`
264 | output = self._process_examples()
265 |
266 | self.assertAllEqual(
267 | [fd.fn_name for fd in self._sampler_builder.get_summary()],
268 | sample_ops)
269 | self.assertAllEqual(
270 | [fd.fn_name for fd in self._postprocessor_builder.get_summary()],
271 | postprocess_ops)
272 |
273 | # Assert static shape.
274 | self.assertNotIn(
275 | None, output[builders.AUDIO_FEATURE_NAME].shape.as_list())
276 | self.assertSetEqual(set(output.keys()),
277 | set([builders.AUDIO_FEATURE_NAME]))
278 | self.assertAllEqual(output[builders.AUDIO_FEATURE_NAME].shape,
279 | (8 * num_test_clips, 30720))
280 |
281 | def test_all_modalities(self):
282 | # Add RGB image.
283 | modalities.add_image(self._parser_builder, self._sampler_builder,
284 | self._decoder_builder, self._preprocessor_builder,
285 | self._postprocessor_builder)
286 |     # Add flow image. Note that in this test this will read from an RGB
287 |     # flow/encoded feature, since we store flow on disk as RGB images where
288 |     # only the first two channels (RG) correspond to the relevant horizontal
289 |     # and vertical displacement vectors.
290 | modalities.add_image(
291 | self._parser_builder,
292 | self._sampler_builder,
293 | self._decoder_builder,
294 | self._preprocessor_builder,
295 | self._postprocessor_builder,
296 | input_feature_name='flow/encoded',
297 | output_feature_name=builders.FLOW_FEATURE_NAME,
298 | is_rgb=None,
299 | zero_centering_image=True,
300 | is_flow=True)
301 | modalities.add_label(
302 | self._parser_builder,
303 | self._decoder_builder,
304 | self._preprocessor_builder,
305 | num_classes=50)
306 | tokenizer = tokenizers.WordTokenizer(
307 | _VOCAB_PATH) # OSS: removed internal filename loading.
308 | tokenizer.initialize()
309 | modalities.add_text(
310 | self._parser_builder,
311 | self._decoder_builder,
312 | self._preprocessor_builder,
313 | tokenizer=tokenizer)
314 | modalities.add_audio(self._parser_builder, self._sampler_builder,
315 | self._postprocessor_builder)
316 | output = self._process_examples()
317 |
318 | self.assertSetEqual(
319 | set(output.keys()),
320 | set([
321 | builders.IMAGE_FEATURE_NAME, builders.FLOW_FEATURE_NAME,
322 | builders.LABEL_INDEX_FEATURE_NAME,
323 | builders.TEXT_INDICES_FEATURE_NAME, builders.AUDIO_FEATURE_NAME
324 | ]))
325 | self.assertAllEqual(output[builders.IMAGE_FEATURE_NAME].shape,
326 | (8, 32, 200, 200, 3))
327 | self.assertAllEqual(output[builders.FLOW_FEATURE_NAME].shape,
328 | (8, 32, 200, 200, 2))
329 | self.assertAllEqual(output[builders.LABEL_INDEX_FEATURE_NAME],
330 | [[0] * 42 + [1] + [0] * 7] * 8)
331 | self.assertAllEqual(output[builders.TEXT_INDICES_FEATURE_NAME],
332 | [[[4, 5] + [0] * 14]] * 8)
333 | self.assertAllEqual(output[builders.AUDIO_FEATURE_NAME].shape, (8, 30720))
334 |
335 |
336 | if __name__ == '__main__':
337 | tf.test.main()
338 |
--------------------------------------------------------------------------------
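
Since `add_spectrogram` consumes the waveform produced by `add_audio` rather than parsing anything itself, the two are meant to be chained on the same builders. The following is a hedged sketch of that chaining, reusing the stub waveform from `ModalitiesTest.setUp`; the exact number of spectrogram frames depends on the frame parameters and is not asserted here.

```python
import numpy as np
import tensorflow as tf

from dmvr import builders
from dmvr import modalities

parser_builder = builders.SequenceExampleParserBuilder()
sampler_builder = builders.SamplerBuilder()
preprocessor_builder = builders.PreprocessorBuilder()
postprocessor_builder = builders.PostprocessorBuilder()

# Parse and temporally sample the raw waveform, then derive log-mel features
# from it in the preprocessing stage.
modalities.add_audio(
    parser_builder, sampler_builder, postprocessor_builder,
    num_samples=30720, sample_rate=48000)
modalities.add_spectrogram(
    preprocessor_builder, postprocessor_builder,
    sample_rate=48000, spectrogram_type='logmf', num_features=80)

# Stub input: 5 seconds of audio, as in ModalitiesTest.setUp.
seq_example = tf.train.SequenceExample()
raw_audio = np.linspace(-1, 1, 48000 * 5)
seq_example.feature_lists.feature_list.get_or_create(
    'WAVEFORM/feature/floats').feature.add().float_list.value[:] = raw_audio

features = parser_builder.build()(seq_example.SerializeToString())
features = sampler_builder.build()(features)
features = preprocessor_builder.build()(features)
# The per-example log-mel spectrogram is expected under
# `builders.AUDIO_MEL_FEATURE_NAME` with shape [num_frames, 80].
log_mel = features[builders.AUDIO_MEL_FEATURE_NAME]
```
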
/dmvr/processors_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for processors."""
16 |
17 | import itertools
18 | import os
19 |
20 | from absl.testing import parameterized
21 | from dmvr import processors
22 | from dmvr import tokenizers
23 | import numpy as np
24 | import tensorflow as tf
25 |
26 | # Removed: Internal pyglib dependencies
27 |
28 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata')
29 | _SAMPLE_IMAGE_PATH = os.path.join(_TESTDATA_DIR, 'sample.jpeg')
30 | _VOCAB_PATH = os.path.join(_TESTDATA_DIR, 'tokenizers', 'word_vocab.txt')
31 |
32 |
33 | class SampleTest(tf.test.TestCase, parameterized.TestCase):
34 |
35 | def setUp(self):
36 | super().setUp()
37 | self._sequence = tf.range(100)
38 |
39 | def test_sample_linspace_sequence(self):
40 | sampled_seq_1 = processors.sample_linspace_sequence(self._sequence, 10, 10)
41 | sampled_seq_2 = processors.sample_linspace_sequence(self._sequence, 7, 10)
42 | sampled_seq_3 = processors.sample_linspace_sequence(self._sequence, 7, 5, 2)
43 | sampled_seq_4 = processors.sample_linspace_sequence(self._sequence, 101, 1)
44 | self.assertAllEqual(sampled_seq_1, range(100))
45 | # [0, 1, 2, 3, 4, ..., 8, 9, 15, 16, ..., 97, 98, 99]
46 | self.assertAllEqual(
47 | sampled_seq_2,
48 | [15 * i + j for i, j in itertools.product(range(7), range(10))])
49 | # [0, 2, 4, 6, 8, 15, 17, 19, ..., 96, 98]
50 | self.assertAllEqual(
51 | sampled_seq_3,
52 | [15 * i + 2 * j for i, j in itertools.product(range(7), range(5))])
53 | self.assertAllEqual(sampled_seq_4, [0] + list(range(100)))
54 |
55 | def test_sample_sequence(self):
56 | sampled_seq_1 = processors.sample_sequence(self._sequence, 10, False)
57 | sampled_seq_2 = processors.sample_sequence(self._sequence, 10, False, 2)
58 | sampled_seq_3 = processors.sample_sequence(self._sequence, 10, True)
59 |
60 | self.assertAllEqual(sampled_seq_1, range(45, 55))
61 | self.assertAllEqual(sampled_seq_2, range(40, 60, 2))
62 |
63 | offset_3 = sampled_seq_3[0]
64 | self.assertBetween(offset_3, 0, 99)
65 | self.assertAllEqual(sampled_seq_3, range(offset_3, offset_3 + 10))
66 |
67 | def test_sample_sequence_with_state(self):
68 | state = {}
69 | sampled_seq_1 = processors.sample_sequence(
70 | self._sequence, 10, True, state=state)
71 | sampled_seq_2 = processors.sample_sequence(
72 | self._sequence, 10, True, state=state)
73 |
74 | self.assertAllEqual(sampled_seq_1, sampled_seq_2)
75 |
76 | def test_sample_or_pad_non_sorted_sequence(self):
77 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence(
78 | self._sequence, 10, 0, False)
79 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence(
80 | self._sequence, 110, 0, False)
81 |
82 | self.assertAllEqual(sampled_seq_1, range(10))
83 | self.assertAllEqual(sampled_seq_2, list(range(100)) + [0] * 10)
84 |
85 | def test_sample_or_pad_non_sorted_sequence_with_state(self):
86 | state = {}
87 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence(
88 | self._sequence, 10, 0, True, state=state)
89 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence(
90 | self._sequence, 10, 0, True, state=state)
91 |
92 | self.assertAllEqual(sampled_seq_1, sampled_seq_2)
93 | self.assertRaises(
94 | tf.errors.InvalidArgumentError,
95 | processors.sample_or_pad_non_sorted_sequence,
96 | self._sequence[:10], 10, 0, True, state=state)
97 |
98 | def test_sample_or_pad_non_sorted_sequence_multidim_with_state(self):
99 | state = {}
100 | sampled_seq_1 = processors.sample_or_pad_non_sorted_sequence(
101 | self._sequence, 10, 0, True, state=state)
102 | multi_dim_sequence = tf.tile(self._sequence[:, None], (1, 10))
103 | sampled_seq_2 = processors.sample_or_pad_non_sorted_sequence(
104 | multi_dim_sequence, 10, 0, True, state=state)
105 | self.assertAllEqual(sampled_seq_1, sampled_seq_2[:, 0])
106 |
107 | @parameterized.named_parameters(
108 | {
109 | 'testcase_name': 'len(seq) < num_steps',
110 | 'sequence': np.array([1, 2, 3]),
111 | 'num_steps': 5,
112 | 'expected_sequence': np.array([1, 2, 3, 1, 2])
113 | },
114 | {
115 | 'testcase_name': 'len(seq) == num_steps',
116 | 'sequence': np.array([1, 2, 3]),
117 | 'num_steps': 3,
118 | 'expected_sequence': np.array([1, 2, 3])
119 | },
120 | {
121 | 'testcase_name': 'len(seq) < num_steps with stride',
122 | 'sequence': np.array([1, 2, 3]),
123 | 'num_steps': 5,
124 | 'expected_sequence': np.array([1, 3, 2, 1, 3]),
125 | 'stride': 2
126 | },
127 | {
128 | 'testcase_name': 'len(seq) == num_steps with stride',
129 | 'sequence': np.array([1, 2, 3]),
130 | 'num_steps': 3,
131 | 'expected_sequence': np.array([1, 1, 1]),
132 | 'stride': 3
133 | },
134 | )
135 | def test_sample_sequence_fixed_offset(self,
136 | sequence: np.ndarray,
137 | num_steps: int,
138 | expected_sequence: np.ndarray,
139 | stride: int = 1):
140 | """Tests that offset is always 0."""
141 | for seed in range(5):
142 | actual_sequence = processors.sample_sequence(
143 | sequence, num_steps=num_steps, random=True, stride=stride, seed=seed)
144 | np.testing.assert_array_equal(actual_sequence, expected_sequence)
145 |
146 |
147 | class DecodeTest(tf.test.TestCase):
148 |
149 | def test_decode_jpeg(self):
150 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image_bytes = f.read()
151 | raw_image = tf.constant([raw_image_bytes, raw_image_bytes])
152 | decoded_image = processors.decode_jpeg(raw_image)
153 | decoded_image_with_static_channels = processors.decode_jpeg(raw_image, 3)
154 | self.assertEqual(decoded_image_with_static_channels.shape.as_list()[3], 3)
155 | self.assertAllEqual(decoded_image.shape, (2, 263, 320, 3))
156 | self.assertAllEqual(decoded_image_with_static_channels.shape,
157 | (2, 263, 320, 3))
158 |
159 |
160 | class PreprocessTest(tf.test.TestCase):
161 |
162 | def setUp(self):
163 | super().setUp()
164 |     # [[0, 1, ..., 119], [1, 2, ..., 120], ..., [89, 90, ..., 208]].
165 | self._frames = tf.stack([tf.range(i, i + 120) for i in range(90)])
166 | self._frames = tf.cast(self._frames, tf.uint8)
167 | self._frames = self._frames[tf.newaxis, :, :, tf.newaxis]
168 | self._frames = tf.broadcast_to(self._frames, (6, 90, 120, 3))
169 |
170 | # Create an equivalent numpy array for assertions.
171 | self._np_frames = np.array([range(i, i + 120) for i in range(90)])
172 | self._np_frames = self._np_frames[np.newaxis, :, :, np.newaxis]
173 | self._np_frames = np.broadcast_to(self._np_frames, (6, 90, 120, 3))
174 |
175 | def test_set_shape(self):
176 | with open(_SAMPLE_IMAGE_PATH, 'rb') as f: raw_image = f.read()
177 | raw_image = tf.constant([raw_image])
178 | decoded_image = processors.decode_jpeg(raw_image)
179 | decoded_image = processors.set_shape(decoded_image, (1, 263, 320, 3))
180 | self.assertAllEqual(decoded_image.shape.as_list(), (1, 263, 320, 3))
181 |
182 | def test_crop_image(self):
183 | cropped_image_1 = processors.crop_image(self._frames, 50, 70)
184 | cropped_image_2 = processors.crop_image(self._frames, 200, 200)
185 | cropped_image_3 = processors.crop_image(self._frames, 50, 70, True)
186 |
187 | self.assertAllEqual(cropped_image_1.shape, (6, 50, 70, 3))
188 | self.assertAllEqual(cropped_image_1, self._np_frames[:, 20:70, 25:95, :])
189 |
190 | self.assertAllEqual(cropped_image_2.shape, (6, 200, 200, 3))
191 | expected = np.pad(
192 | self._np_frames, ((0, 0), (55, 55), (40, 40), (0, 0)), 'constant')
193 | self.assertAllEqual(cropped_image_2, expected)
194 |
195 | self.assertAllEqual(cropped_image_3.shape, (6, 50, 70, 3))
196 | offset = cropped_image_3[0, 0, 0, 0]
197 | expected = np.array([range(i, i + 70) for i in range(offset, offset + 50)])
198 | expected = expected[np.newaxis, :, :, np.newaxis]
199 | expected = np.broadcast_to(expected, (6, 50, 70, 3))
200 | self.assertAllEqual(cropped_image_3, expected)
201 |
202 | def test_crop_image_with_state(self):
203 | state = {}
204 | cropped_image_1 = processors.crop_image(self._frames, 50, 70, state=state)
205 | cropped_image_2 = processors.crop_image(self._frames, 50, 70, state=state)
206 |
207 | self.assertAllEqual(cropped_image_1, cropped_image_2)
208 |
209 | def test_resize_smallest(self):
210 | resized_frames_1 = processors.resize_smallest(self._frames, 180)
211 | resized_frames_2 = processors.resize_smallest(self._frames, 45)
212 | resized_frames_3 = processors.resize_smallest(self._frames, 90)
213 | resized_frames_4 = processors.resize_smallest(
214 | tf.transpose(a=self._frames, perm=(0, 2, 1, 3)), 45)
215 | self.assertAllEqual(resized_frames_1.shape, (6, 180, 240, 3))
216 | self.assertAllEqual(resized_frames_2.shape, (6, 45, 60, 3))
217 | self.assertAllEqual(resized_frames_3.shape, (6, 90, 120, 3))
218 | self.assertAllEqual(resized_frames_4.shape, (6, 60, 45, 3))
219 |
220 | def test_resize_smallest_with_flow(self):
221 | flows = tf.cast(self._frames, tf.float32)
222 | resized_flows = processors.resize_smallest(flows, 180, True)
223 | resized_flows_expected = 2.0 * processors.resize_smallest(flows, 180, False)
224 |
225 | self.assertAllEqual(resized_flows, resized_flows_expected)
226 |
227 | def test_random_flip_left_right(self):
228 | flipped_frames = processors.random_flip_left_right(self._frames)
229 | flipped = np.fliplr(self._np_frames[0, :, :, 0])
230 | flipped = flipped[np.newaxis, :, :, np.newaxis]
231 | flipped = np.broadcast_to(flipped, (6, 90, 120, 3))
232 | self.assertTrue((flipped_frames == self._np_frames).numpy().all() or (
233 | flipped_frames == flipped).numpy().all())
234 |
235 | def test_random_flip_left_right_with_flow(self):
236 | flows = tf.cast(self._frames, tf.float32)
237 | flipped_flows = processors.random_flip_left_right(flows, is_flow=True)
238 | flipped = np.fliplr(self._np_frames[0, :, :, 0])
239 | flipped = flipped[np.newaxis, :, :, np.newaxis]
240 | flipped = np.broadcast_to(flipped, (6, 90, 120, 3))
241 | flipped_flow = flipped.astype(np.float32)
242 | flipped_flow[:, :, :, 0] *= -1.0
243 | self.assertTrue(
244 | (flipped_flows == self._np_frames.astype(np.float32)).numpy().all() or (
245 | flipped_flows == flipped_flow).numpy().all())
246 |
247 | def test_random_flip_left_right_with_state(self):
248 | state = {}
249 | flipped_frames_1 = processors.random_flip_left_right(
250 | self._frames, state=state)
251 | flipped_frames_2 = processors.random_flip_left_right(
252 | self._frames, state=state)
253 | self.assertAllEqual(flipped_frames_1, flipped_frames_2)
254 |
255 | def test_normalize_image(self):
256 | normalized_images_1 = processors.normalize_image(
257 | self._frames, False, tf.float32)
258 | normalized_images_2 = processors.normalize_image(
259 | self._frames, True, tf.float32)
260 | self.assertAllClose(normalized_images_1, self._np_frames / 255)
261 | self.assertAllClose(normalized_images_2, self._np_frames * 2 / 255 - 1.0)
262 |
263 | def test_scale_jitter_augm(self):
264 | no_jitter_images = processors.scale_jitter_augm(self._frames, 0.8, 1.0, 0.0)
265 | jitter_images = processors.scale_jitter_augm(
266 | self._frames, 2.0, 2.00001, 1.0)
267 | self.assertAllEqual(no_jitter_images.shape, (6, 90, 120, 3))
268 | self.assertAllEqual(jitter_images.shape, (6, 180, 240, 3))
269 |
270 | def test_scale_jitter_augm_with_state(self):
271 | state = {}
272 | jitter_image_1 = processors.scale_jitter_augm(
273 | self._frames, 0.8, 1.2, 1.0, state=state)
274 | jitter_image_2 = processors.scale_jitter_augm(
275 | self._frames, 0.8, 1.2, 1.0, state=state)
276 | self.assertAllEqual(jitter_image_1, jitter_image_2)
277 |
278 | def test_scale_jitter_augm_with_flow(self):
279 | state = {}
280 | flows = tf.cast(self._frames, tf.float32)
281 | jitter_flows = processors.scale_jitter_augm(
282 | flows, 0.8, 1.2, 1.0, state=state, is_flow=True)
283 | jitter_flows_expected = processors.scale_jitter_augm(
284 | flows, 0.8, 1.2, 1.0, state=state)
285 | h_s, w_s, _ = state['scale_jitter_augm_info']
286 | jitter_flows_expected *= tf.stack([h_s, w_s, 1.0])[None, None, None, :]
287 | self.assertAllClose(jitter_flows, jitter_flows_expected)
288 |
289 | def test_color_default_augment(self):
290 | normalized_images = processors.normalize_image(
291 | self._frames, False, tf.float32)
292 | no_augmented_images = processors.color_default_augm(
293 | normalized_images, False, 0.0, 0.0)
294 | color_augmented_images = processors.color_default_augm(
295 | normalized_images, False, 1.0, 0.0)
296 | color_dropped_images = processors.color_default_augm(
297 | normalized_images, False, 0.0, 1.0)
298 | self.assertAllEqual(no_augmented_images.shape, normalized_images.shape)
299 | self.assertAllEqual(color_augmented_images.shape, normalized_images.shape)
300 | self.assertAllEqual(color_dropped_images.shape, normalized_images.shape)
301 |
302 | self.assertAllEqual(normalized_images, no_augmented_images)
303 | self.assertNotAllEqual(normalized_images, color_augmented_images)
304 | self.assertNotAllEqual(normalized_images, color_dropped_images)
305 |
306 | self.assertAllEqual(color_dropped_images[:, :, :, 0],
307 | color_dropped_images[:, :, :, 1])
308 | self.assertAllEqual(color_dropped_images[:, :, :, 0],
309 | color_dropped_images[:, :, :, 2])
310 |
311 | def test_space_to_depth(self):
312 | output_frames_1 = processors.space_to_depth(self._frames, 2, 3)
313 | output_frames_2 = processors.space_to_depth(self._frames, 3, 2)
314 | output_frames_3 = processors.space_to_depth(
315 | self._frames, spatial_block_size=2)
316 | self.assertAllEqual(output_frames_1.shape, (3, 30, 40, 54))
317 | self.assertAllEqual(output_frames_2.shape, (2, 45, 60, 36))
318 | self.assertAllEqual(output_frames_3.shape, (6, 45, 60, 12))
319 |
320 | def test_crop_or_pad_words(self):
321 | input_words_indices = tf.expand_dims(tf.range(10, dtype=tf.int32), axis=0)
322 |
323 | output_words_indices_1 = processors.crop_or_pad_words(
324 | input_words_indices, 5)
325 | output_words_indices_2 = processors.crop_or_pad_words(
326 | input_words_indices, 15)
327 | self.assertAllEqual(output_words_indices_1, [list(range(5))])
328 | self.assertAllEqual(output_words_indices_2,
329 | [[i for i in range(10)] + [0] * 5])
330 |
331 | def test_tokenize(self):
332 | tokenizer = tokenizers.WordTokenizer(
333 | _VOCAB_PATH) # OSS: removed internal filename loading.
334 | tokenizer.initialize()
335 | input_features = {'text': tf.constant(['hello world', 'hello', 'world'])}
336 |
337 | output_features = processors.tokenize(input_features, tokenizer, 'text',
338 | 'indices', False, False, 4, True)
339 | self.assertAllEqual(output_features['text'],
340 | ['hello world', 'hello', 'world'])
341 | self.assertAllEqual(output_features['indices'],
342 | [[4, 5, 0, 0], [4, 0, 0, 0], [5, 0, 0, 0]])
343 |
344 |
345 | class PostprocessTest(tf.test.TestCase):
346 |
347 | def test_batched_video_transpose(self):
348 | input_tensor = tf.constant([[[1, 2], [3, 4], [5, 6]]])
349 | output_tensor = processors.batched_video_transpose(input_tensor, (0, 2, 1))
350 |
351 | self.assertAllEqual(output_tensor, [[[1, 3, 5], [2, 4, 6]]])
352 |
353 | def test_batched_space_to_depth(self):
354 | input_frames = tf.zeros((8, 30, 150, 210, 3))
355 |
356 | output_frames_1 = processors.batched_space_to_depth(input_frames, 2, 3)
357 | output_frames_2 = processors.batched_space_to_depth(input_frames, 3, 2)
358 | output_frames_3 = processors.batched_space_to_depth(
359 | input_frames, spatial_block_size=2)
360 |
361 | self.assertAllEqual(output_frames_1.shape, (8, 15, 50, 70, 54))
362 | self.assertAllEqual(output_frames_2.shape, (8, 10, 75, 105, 36))
363 | self.assertAllEqual(output_frames_3.shape, (8, 30, 75, 105, 12))
364 |
365 |
366 | if __name__ == '__main__':
367 | tf.test.main()
368 |
--------------------------------------------------------------------------------
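
The expectations in `SampleTest.test_sample_linspace_sequence` follow from evenly spacing the clip offsets over the valid range: for a length-100 sequence, 7 clips of 10 steps give offsets `linspace(0, 100 - 10, 7) = [0, 15, ..., 90]`, hence the `15 * i + j` pattern in the test. A tiny check of that arithmetic, assuming this is indeed how the offsets are computed:

```python
import itertools
import numpy as np

offsets = np.linspace(0, 100 - 10, num=7).astype(int)  # [0, 15, 30, ..., 90]
expected = [int(o) + j for o, j in itertools.product(offsets, range(10))]
assert expected == [15 * i + j for i, j in itertools.product(range(7), range(10))]
```
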
/dmvr/sources.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Sources for reading and decoding raw binary data files."""
16 |
17 | import abc
18 | from typing import Optional, Union
19 |
20 | import tensorflow as tf
21 |
22 |
23 | class Source(abc.ABC):
24 | """Base class for sources.
25 |
26 | Sources are objects reading from binary files and generating an initial
27 | `tf.data.Dataset` with the serialized examples. Deserializing the examples is
28 |   not the responsibility of the `Source` (it should be done by the parser).
29 |
30 | For each different type of storage (e.g. TFRecords, image files, text files),
31 | a subclass can be implemented.
32 | """
33 |
34 | @abc.abstractmethod
35 | def load_and_decode_shard(
36 | self,
37 | shard: Union[str, tf.Tensor] # Shape () and type `tf.string`.
38 | ) -> tf.data.Dataset:
39 | """Decodes a single raw input file into a `tf.data.Dataset`.
40 |
41 | Args:
42 | shard: Path to a single file with encoded data.
43 |
44 | Returns:
45 | A `tf.data.Dataset` object containing a key (this can be a file name,
46 | index, empty or any other useful bits) and a raw example (both encoded as
47 | bytes). Current supported types of examples are `tf.train.Example` and
48 | `tf.train.SequenceExample` (see `builders.BaseParserBuilder`).
49 | """
50 |
51 |
52 | class TFRecordsSource(Source):
53 | """Source for TFRecords data format."""
54 |
55 | def __init__(self, compression_type: Optional[str] = None):
56 | self._compression_type = compression_type
57 |
58 | def load_and_decode_shard(
59 | self,
60 | shard: Union[str, tf.Tensor] # Shape () and type `tf.string`.
61 | ) -> tf.data.Dataset:
62 | ds = tf.data.TFRecordDataset(shard, compression_type=self._compression_type)
63 | # TFRecords do not provide an index or key per example. Use shard path as
64 | # key, since it can be useful later for retrieval.
65 | key = shard.encode('utf-8') if isinstance(shard, str) else shard
66 | ds = ds.map(lambda example: (key, example))
67 | return ds
68 |
--------------------------------------------------------------------------------
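
A short, hedged sketch of consuming a `TFRecordsSource` directly with `tf.data`; the shard paths are placeholders, and in the full pipeline the dataset factory is expected to drive this rather than user code.

```python
import tensorflow as tf

from dmvr import sources

source = sources.TFRecordsSource()

# Hypothetical shard paths; each shard is a TFRecord file of serialized
# tf.train.SequenceExample protos.
shards = tf.data.Dataset.from_tensor_slices([
    '/data/train-00000-of-00002.tfrecord',
    '/data/train-00001-of-00002.tfrecord',
])

# Each shard expands into (key, serialized example) pairs; interleave mixes
# examples coming from the different shards.
dataset = shards.interleave(
    source.load_and_decode_shard,
    cycle_length=2,
    num_parallel_calls=tf.data.AUTOTUNE)

for key, raw_example in dataset.take(1):
  # `raw_example` is still serialized bytes: deserializing is the parser's job.
  print(key.numpy(), len(raw_example.numpy()))
```
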
/dmvr/sources_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for sources."""
16 |
17 | import os
18 |
19 | from dmvr import sources
20 | import tensorflow as tf
21 |
22 |
23 | class TFRecordsSourceTest(tf.test.TestCase):
24 |
25 | def setUp(self):
26 | super().setUp()
27 |
28 | self._shard = os.path.join(self.get_temp_dir(), 'shard')
29 | # Generate a TFRecord single shard with one serialized SequenceExample
30 | # in the format ('sequence', [[0], [1], ..., [99]]).
31 | with tf.io.TFRecordWriter(self._shard) as builder:
32 | self._seq_example = tf.train.SequenceExample()
33 | for i in range(100):
34 | self._seq_example.feature_lists.feature_list.get_or_create(
35 | 'sequence').feature.add().int64_list.value[:] = [i]
36 | builder.write(self._seq_example.SerializeToString())
37 |
38 | def test_load_and_decode(self):
39 | source = sources.TFRecordsSource()
40 | ds = source.load_and_decode_shard(self._shard)
41 | it = iter(ds)
42 |
43 | data = next(it)
44 | self.assertEqual(data[0], self._shard.encode('utf-8'))
45 | self.assertEqual(data[1], self._seq_example.SerializeToString())
46 |
47 | with self.assertRaises(StopIteration) as _:
48 | data = next(it)
49 |
50 | def test_input_as_tensor(self):
51 | source = sources.TFRecordsSource()
52 | ds = source.load_and_decode_shard(tf.constant(self._shard))
53 | it = iter(ds)
54 |
55 | data = next(it)
56 | self.assertEqual(data[0], self._shard.encode('utf-8'))
57 | self.assertEqual(data[1], self._seq_example.SerializeToString())
58 |
59 |
60 | if __name__ == '__main__':
61 | tf.test.main()
62 |
--------------------------------------------------------------------------------
/dmvr/testdata/sample.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dmvr/77ccedaa084d29239eaeafddb0b2e83843b613a1/dmvr/testdata/sample.jpeg
--------------------------------------------------------------------------------
/dmvr/testdata/tokenizers/bert_word_vocab.txt:
--------------------------------------------------------------------------------
1 | [CLS]
2 | [PAD]
3 | [UNK]
4 | [SEP]
5 | hello
6 | world
7 |
--------------------------------------------------------------------------------
/dmvr/testdata/tokenizers/spiece.model.1000.model:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-deepmind/dmvr/77ccedaa084d29239eaeafddb0b2e83843b613a1/dmvr/testdata/tokenizers/spiece.model.1000.model
--------------------------------------------------------------------------------
/dmvr/testdata/tokenizers/word_vocab.txt:
--------------------------------------------------------------------------------
1 | 0 <pad>
2 | 1 <bos>
3 | 2 <eos>
4 | 3 <unk>
5 | 4 hello
6 | 5 world
7 |
--------------------------------------------------------------------------------
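
This is the vocabulary the tests load through `WordTokenizer` (defined in `dmvr/tokenizers.py` below). A minimal usage sketch, assuming the repository root as working directory so the relative path resolves:

```python
from dmvr import tokenizers

tokenizer = tokenizers.WordTokenizer('dmvr/testdata/tokenizers/word_vocab.txt')
tokenizer.initialize()

# 'hello' and 'world' map to indices 4 and 5; index 0 is the pad token,
# matching the expectations in processors_test.py and modalities_test.py.
indices = tokenizer.string_tensor_to_indices(['hello world'], max_num_tokens=4)
print(indices.numpy())                      # [[4 5 0 0]]
print(tokenizer.indices_to_string([4, 5]))  # 'hello world'
```
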
/dmvr/tokenizers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """A simple tokenizer interface with basic implementations."""
15 |
16 | import abc
17 | from typing import Optional, Sequence, Union
18 |
19 | import clip.simple_tokenizer
20 | import tensorflow as tf
21 | import tensorflow_text
22 |
23 | import sentencepiece as spm
24 |
25 |
26 | class TextTokenizer(abc.ABC):
27 | """Base class for text tokenizers."""
28 |
29 | def initialize(self):
30 | """Initializes tensorflow tables and models."""
31 | return
32 |
33 | @abc.abstractmethod
34 | def string_tensor_to_indices(self,
35 | string_tensor: Union[tf.Tensor, Sequence[str]],
36 | prepend_bos: bool = False,
37 | append_eos: bool = False,
38 | max_num_tokens: Optional[int] = 32) -> tf.Tensor:
39 | """Tokenizes input text, mapping a tensor of strings to a tensor of ints.
40 |
41 | Args:
42 | string_tensor: Input string tensor of shape [num_texts].
43 | prepend_bos: Whether to prepend the BOS (beginning of sentence) token to
44 | the output tokens.
45 | append_eos: Whether to append the EOS (end of sentence) token to the
46 | output tokens.
47 | max_num_tokens: Maximum number of tokens to return per caption. If
48 | provided, the tokens will be padded / cut at the given size. If not, a
49 | tensor of unknown size will be returned.
50 |
51 | Returns:
52 | A `tf.int32` tensor of shape [num_texts, `max_num_tokens`] if
53 | `max_num_tokens` is provided or [num_texts, max_num_tokens_in_batch]
54 | otherwise.
55 | """
56 |
57 | @abc.abstractmethod
58 | def indices_to_string(self, indices: Sequence[int]) -> str:
59 | """Detokenizes, mapping a python sequence of indices to a string."""
60 |
61 | @property
62 | @abc.abstractmethod
63 | def vocab_size(self) -> int:
64 | """Returns the vocabulary size."""
65 |
66 | @property
67 | @abc.abstractmethod
68 | def pad_token(self) -> int:
69 | """Returns index of the PAD token."""
70 |
71 | @property
72 | @abc.abstractmethod
73 | def bos_token(self) -> int:
74 | """Returns index of the BOS token."""
75 |
76 | @property
77 | @abc.abstractmethod
78 | def eos_token(self) -> int:
79 | """Returns index of the EOS token."""
80 |
81 | @property
82 | @abc.abstractmethod
83 | def unk_token(self) -> int:
84 | """Returns index of the UNK token."""
85 |
86 |
87 | class SentencePieceTokenizer(TextTokenizer):
88 | """SentencePiece tokenizer from a pre-trained SentencePiece model.
89 |
90 | Pre-trained models are provided in multiple repositories around the web. See
91 |   https://github.com/google/sentencepiece for info on how to train new models
92 |   on a specific corpus.
93 | """
94 |
95 | def __init__(self, model_path: str):
96 | """Initializes the `SentencePieceTokenizer`.
97 |
98 | Args:
99 | model_path: Path to the '.model' file.
100 | """
101 | self._model_path = model_path
102 | self._sp_model = spm.SentencePieceProcessor()
103 | self._sp_model.Load(model_path)
104 |
105 | self._vocab_size = self._sp_model.GetPieceSize()
106 | self._bos_token = self._sp_model.bos_id()
107 | self._eos_token = self._sp_model.eos_id()
108 | self._pad_token = self._sp_model.pad_id()
109 | self._unk_token = self._sp_model.unk_id()
110 |
111 | self._tf_sp_model = None
112 |
113 | def initialize(self):
114 | with tf.io.gfile.GFile(self._model_path, 'rb') as f:
115 | self._tf_sp_model = tensorflow_text.SentencepieceTokenizer(
116 | model=f.read(), out_type=tf.int32, add_bos=True, add_eos=True)
117 |
118 | def string_tensor_to_indices(self,
119 | string_tensor: Union[tf.Tensor, Sequence[str]],
120 | prepend_bos: bool = False,
121 | append_eos: bool = False,
122 | max_num_tokens: Optional[int] = 32) -> tf.Tensor:
123 | if self._tf_sp_model is None:
124 | raise RuntimeError('Model was not initialized. Call `initialize` method.')
125 |
126 | tokenized = self._tf_sp_model.tokenize(string_tensor)
127 | tokenized = tokenized if prepend_bos else tokenized[..., 1:]
128 | tokenized = tokenized if append_eos else tokenized[..., :-1]
129 |
130 | # Pad to `max_num_tokens`.
131 | shape = None if max_num_tokens is None else [None, max_num_tokens]
132 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape)
133 | return tokenized
134 |
135 | def indices_to_string(self, indices: Sequence[int]) -> str:
136 | return self._sp_model.DecodeIds(indices)
137 |
138 | def string_to_indices(self,
139 | string: str,
140 | prepend_bos: bool = False,
141 | append_eos: bool = False,
142 | max_num_tokens: Optional[int] = 32) -> Sequence[int]:
143 | """Tokenizes, mapping a python string to a sequence of indices."""
144 | tokenized = self._sp_model.EncodeAsIds(string)
145 | tokenized = [self._bos_token] * prepend_bos + tokenized
146 | tokenized += [self._eos_token] * append_eos
147 | if max_num_tokens:
148 | tokenized = tokenized[:max_num_tokens]
149 | num_tokens = len(tokenized)
150 | tokenized = tokenized + [self._pad_token] * (max_num_tokens - num_tokens)
151 | return tokenized
152 |
153 | @property
154 | def vocab_size(self):
155 | return self._vocab_size
156 |
157 | @property
158 | def pad_token(self):
159 | return self._pad_token
160 |
161 | @property
162 | def bos_token(self):
163 | return self._bos_token
164 |
165 | @property
166 | def eos_token(self):
167 | return self._eos_token
168 |
169 | @property
170 | def unk_token(self):
171 | return self._unk_token
172 |
173 |
174 | class WordTokenizer(TextTokenizer):
175 | """Vocabulary based word tokenizer."""
176 |
177 |   PAD = '<pad>'
178 |   BOS = '<bos>'
179 |   EOS = '<eos>'
180 |   UNK = '<unk>'
181 |
182 | def __init__(self, vocabulary_path: str):
183 | """Initializes the `WordTokenizer`.
184 |
185 | Args:
186 | vocabulary_path: A path to a vocabulary file. The vocabulary is a simple
187 | text file where each line is of the form: 'idx_word word' or simply
188 | 'word' (the line index will be used). The vocabulary should at least
189 |         contain the words: '<pad>', '<bos>', '<eos>' and '<unk>'.
190 | """
191 | # Parse the vocabulary. The expected format is either one word per line (and
192 | # the index for that word will be the line index) or an index and a word,
193 | # split by space.
194 | idx2word = {}
195 | with tf.io.gfile.GFile(vocabulary_path) as f:
196 | for line_idx, line in enumerate(f):
197 | line = line.strip().split(' ')
198 |
199 | if len(line) not in [1, 2]:
200 | raise ValueError(f'Line {line_idx} of vocabulary file, with contents '
201 | f'\'{line}\' is malformed')
202 |
203 | idx, word = line if len(line) == 2 else (line_idx, line[0])
204 | idx = int(idx)
205 |
206 | if idx in idx2word:
207 | raise ValueError(
208 | f'Vocabulary contains two words with same index {idx}.')
209 | if word != word.lower():
210 | raise ValueError(f'Word {word} with index {idx} is not lower case.')
211 |
212 | idx2word[idx] = word
213 |
214 | # Validate.
215 | if len(idx2word) != len(set(idx2word.values())):
216 | raise ValueError('Words in vocabulary are not unique.')
217 | basic_tokens = {self.PAD, self.BOS, self.EOS, self.UNK}
218 | if basic_tokens & set(idx2word.values()) != basic_tokens:
219 | raise ValueError(
220 | f'Vocabulary does not contain all basic tokens {basic_tokens}.')
221 |
222 | self._idx2word = idx2word
223 | self._word2idx = {v: k for k, v in idx2word.items()}
224 |
225 | self._vocab_size = len(idx2word)
226 | self._pad_token = self._word2idx[self.PAD]
227 | self._bos_token = self._word2idx[self.BOS]
228 | self._eos_token = self._word2idx[self.EOS]
229 | self._unk_token = self._word2idx[self.UNK]
230 |
231 | self._tf_word2idx = None
232 | self._tf_whitespace_tokenizer = None
233 |
234 | def initialize(self):
235 | ids_tensor = tf.constant([i for w, i in self._word2idx.items()],
236 | dtype=tf.int32)
237 | words_tensor = tf.constant([w for w, i in self._word2idx.items()],
238 | dtype=tf.string)
239 | self._tf_whitespace_tokenizer = tensorflow_text.WhitespaceTokenizer()
240 | self._tf_word2idx = tf.lookup.StaticHashTable(
241 | tf.lookup.KeyValueTensorInitializer(words_tensor, ids_tensor),
242 | self._unk_token)
243 |
244 | def string_tensor_to_indices(self,
245 | string_tensor: Union[tf.Tensor, Sequence[str]],
246 | prepend_bos: bool = False,
247 | append_eos: bool = False,
248 | max_num_tokens: Optional[int] = 32) -> tf.Tensor:
249 | if self._tf_word2idx is None or self._tf_whitespace_tokenizer is None:
250 | raise RuntimeError('Model was not initialized. Call `initialize` method.')
251 |
252 | # Remove punctuation.
253 | string_tensor = tf.strings.regex_replace(string_tensor, '[[:punct:]]', '')
254 | # Lower case.
255 | string_tensor = tf.strings.lower(string_tensor)
256 | if prepend_bos:
257 | string_tensor = self.BOS.encode('utf-8') + b' ' + string_tensor
258 | if append_eos:
259 | string_tensor += b' ' + self.EOS.encode('utf-8')
260 |
261 | # Separate words by whitespace.
262 | tokenized = self._tf_whitespace_tokenizer.tokenize(string_tensor)
263 | # Map word to indices.
264 | tokenized = self._tf_word2idx.lookup(tokenized)
265 | # Pad to `max_num_tokens`.
266 | shape = None if max_num_tokens is None else [None, max_num_tokens]
267 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape)
268 | return tokenized
269 |
270 | def indices_to_string(self, indices: Sequence[int]) -> str:
271 | # Cut at `EOS` or `PAD`.
272 | idx_list_cut = []
273 | for token_id in indices:
274 | if token_id in [self._pad_token, self._eos_token]:
275 | break
276 | idx_list_cut.append(token_id)
277 |
278 | # Decode back to string.
279 | words_list = [self._idx2word[idx] for idx in idx_list_cut]
280 | return ' '.join(words_list)
281 |
282 | def string_to_indices(self,
283 | string: str,
284 | prepend_bos: bool = False,
285 | append_eos: bool = False,
286 | max_num_tokens: Optional[int] = 32) -> Sequence[int]:
287 | """Tokenizes, mapping a python string to a sequence of indices."""
288 | string = string.translate(
289 | str.maketrans('', '', '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~'))
290 | string = string.lower()
291 | words = string.split(' ')
292 | tokenized = [self._word2idx.get(w, self._unk_token) for w in words]
293 | tokenized = [self._bos_token] * prepend_bos + tokenized
294 | tokenized += [self._eos_token] * append_eos
295 | if max_num_tokens:
296 | tokenized = tokenized[:max_num_tokens]
297 | num_tokens = len(tokenized)
298 | tokenized = tokenized + [self._pad_token] * (max_num_tokens - num_tokens)
299 | return tokenized
300 |
301 | @property
302 | def vocab_size(self):
303 | return self._vocab_size
304 |
305 | @property
306 | def pad_token(self):
307 | return self._pad_token
308 |
309 | @property
310 | def bos_token(self):
311 | return self._bos_token
312 |
313 | @property
314 | def eos_token(self):
315 | return self._eos_token
316 |
317 | @property
318 | def unk_token(self):
319 | return self._unk_token
320 |
321 |
322 | class BertTokenizer(TextTokenizer):
323 | """BERT tokenizer.
324 |
325 | Standard BERT vocabularies can be found on TF Hub.
326 | """
327 |
328 | PAD = '[PAD]'
329 | CLS = '[CLS]'
330 | SEP = '[SEP]'
331 | BOS = CLS
332 | EOS = SEP
333 | UNK = '[UNK]'
334 |
335 | def __init__(self, vocabulary_path: str):
336 | """Initializes the `BertTokenizer`.
337 |
338 | Args:
339 | vocabulary_path: A path to a vocabulary file. The vocabulary is a simple
340 | text file where each line is of the form: 'token'. The vocabulary should
341 | at least contain the words: '[PAD]', '[CLS]', '[SEP]' and '[UNK]'.
342 | """
343 | # Parse the vocabulary.
344 | idx2word = {}
345 | self._vocabulary_path = vocabulary_path
346 | with tf.io.gfile.GFile(vocabulary_path) as f:
347 | for idx, line in enumerate(f):
348 | word = line.strip()
349 | idx2word[idx] = word
350 |
351 | # Validate.
352 | if len(idx2word) != len(set(idx2word.values())):
353 | raise ValueError('Words in vocabulary are not unique.')
354 | basic_tokens = {self.PAD, self.BOS, self.EOS, self.UNK}
355 | if basic_tokens & set(idx2word.values()) != basic_tokens:
356 | raise ValueError(
357 | f'Vocabulary does not contain all basic tokens {basic_tokens}.')
358 |
359 | self._idx2word = idx2word
360 | self._word2idx = {v: k for k, v in idx2word.items()}
361 |
362 | self._vocab_size = len(idx2word)
363 | self._pad_token = self._word2idx[self.PAD]
364 | self._bos_token = self._word2idx[self.BOS]
365 | self._eos_token = self._word2idx[self.EOS]
366 | self._unk_token = self._word2idx[self.UNK]
367 |
368 | self._tf_tokenizer = None
369 |
370 | def initialize(self):
371 | self._tf_tokenizer = tensorflow_text.BertTokenizer(
372 | self._vocabulary_path,
373 | token_out_type=tf.int32,
374 | unknown_token=self.UNK,
375 | lower_case=True)
376 |
377 | def string_tensor_to_indices(self,
378 | string_tensor: Union[tf.Tensor, Sequence[str]],
379 | prepend_bos: bool = False,
380 | append_eos: bool = False,
381 | max_num_tokens: Optional[int] = 32) -> tf.Tensor:
382 | if self._tf_tokenizer is None:
383 | raise RuntimeError('Model was not initialized. Call `initialize` method.')
384 |
385 | batch_size = tf.shape(input=string_tensor)[0]
386 | tokenized = self._tf_tokenizer.tokenize(string_tensor)
387 | tokenized = tokenized.merge_dims(-2, -1)
388 |
389 | if append_eos:
390 | eos_tensor = tf.ragged.constant([self._eos_token])
391 | eos_tensor = tf.tile(eos_tensor, [batch_size])
392 | eos_tensor = tf.expand_dims(eos_tensor, axis=1)
393 | tokenized = tf.concat([tokenized, eos_tensor], axis=1)
394 | if prepend_bos:
395 | bos_tensor = tf.ragged.constant([self._bos_token])
396 | bos_tensor = tf.tile(bos_tensor, [batch_size])
397 | bos_tensor = tf.expand_dims(bos_tensor, axis=1)
398 | tokenized = tf.concat([bos_tensor, tokenized], axis=1)
399 |
400 | # Pad to `max_num_tokens`.
401 | shape = None if max_num_tokens is None else [None, max_num_tokens]
402 | tokenized = tokenized.to_tensor(default_value=self._pad_token, shape=shape)
403 | return tokenized
404 |
405 | def indices_to_string(self, indices: Sequence[int]) -> str:
406 | # Cut at `EOS` or `PAD`.
407 | idx_list_cut = []
408 | for token_id in indices:
409 | if token_id in [self._pad_token, self._eos_token]:
410 | break
411 | idx_list_cut.append(token_id)
412 |
413 | # Decode back to string.
414 | word_iter = (self._idx2word[idx] for idx in idx_list_cut)
415 | return ' '.join(word_iter).replace(' ##', '')
416 |
417 | @property
418 | def vocab_size(self):
419 | return self._vocab_size
420 |
421 | @property
422 | def pad_token(self):
423 | return self._pad_token
424 |
425 | @property
426 | def bos_token(self):
427 | return self._bos_token
428 |
429 | @property
430 | def eos_token(self):
431 | return self._eos_token
432 |
433 | @property
434 | def unk_token(self):
435 | return self._unk_token
436 |
437 | @property
438 | def cls_token(self):
439 | return self._bos_token
440 |
441 | @property
442 | def sep_token(self):
443 | return self._eos_token
444 |
445 |
446 | class ClipTokenizer(TextTokenizer):
447 | """CLIP tokenizer."""
448 |
449 | BOS = '<|startoftext|>'
450 | EOS = '<|endoftext|>'
451 | UNK = EOS
452 |
453 | def __init__(
454 | self,
455 | vocabulary_path: Optional[str] = None,
456 | ) -> None:
457 | """Initializes the `ClipTokenizer`.
458 |
459 | Args:
460 | vocabulary_path: A path to a CLIP-style vocabulary file.
461 | """
462 | self._tokenizer = clip.simple_tokenizer.SimpleTokenizer(vocabulary_path)
463 |
464 | self._vocab_size = len(self._tokenizer.encoder)
465 | self._pad_token = 0
466 | self._bos_token = self._tokenizer.encoder[self.BOS]
467 | self._eos_token = self._tokenizer.encoder[self.EOS]
468 | self._unk_token = self._tokenizer.encoder[self.UNK]
469 |
470 | self._initialized = False
471 |
472 | def initialize(self) -> None:
473 | self._initialized = True
474 |
475 | def _clip_tokenize(self, texts: Union[tf.Tensor,
476 | Sequence[str]]) -> tf.RaggedTensor:
477 | if isinstance(texts, tf.Tensor):
478 | texts = [text.decode('utf-8') for text in texts._numpy().tolist()] # pylint: disable=protected-access
479 | return tf.ragged.constant([self._tokenizer.encode(text) for text in texts],
480 | dtype=tf.int32)
481 |
482 | def string_tensor_to_indices(self,
483 | string_tensor: Union[tf.Tensor, Sequence[str]],
484 | prepend_bos: bool = False,
485 | append_eos: bool = False,
486 | max_num_tokens: Optional[int] = 77) -> tf.Tensor:
487 | if not self._initialized: # To satisfy the tests.
488 | raise RuntimeError('Model was not initialized. Call `initialize` method.')
489 |
490 | batch_size = tf.shape(input=string_tensor)[0]
491 |
492 | tokenized = tf.py_function(
493 | func=self._clip_tokenize,
494 | inp=[string_tensor],
495 | Tout=tf.RaggedTensorSpec([None, None], dtype=tf.int32))
496 |
497 | if append_eos:
498 | eos_tensor = tf.ragged.constant([self._eos_token])
499 | eos_tensor = tf.tile(eos_tensor, [batch_size])
500 | eos_tensor = tf.expand_dims(eos_tensor, axis=1)
501 | tokenized = tf.concat([tokenized, eos_tensor], axis=1)
502 | if prepend_bos:
503 | bos_tensor = tf.ragged.constant([self._bos_token])
504 | bos_tensor = tf.tile(bos_tensor, [batch_size])
505 | bos_tensor = tf.expand_dims(bos_tensor, axis=1)
506 | tokenized = tf.concat([bos_tensor, tokenized], axis=1)
507 |
508 | # Pad to `max_num_tokens`.
509 | shape = None if max_num_tokens is None else [None, max_num_tokens]
510 | return tokenized.to_tensor(default_value=self._pad_token, shape=shape)
511 |
512 | def indices_to_string(self, indices: Sequence[int]) -> str:
513 | text = self._tokenizer.decode(i for i in indices if i != self._pad_token)
514 | start_pos = len(self.BOS) if text.startswith(self.BOS) else 0
515 | end_pos = text.index(self.EOS) if self.EOS in text else None
516 | return text[start_pos:end_pos].strip()
517 |
518 | @property
519 | def vocab_size(self) -> int:
520 | return self._vocab_size
521 |
522 | @property
523 | def pad_token(self) -> int:
524 | return self._pad_token
525 |
526 | @property
527 | def bos_token(self) -> int:
528 | return self._bos_token
529 |
530 | @property
531 | def eos_token(self) -> int:
532 | return self._eos_token
533 |
534 | @property
535 | def unk_token(self) -> int:
536 | return self._unk_token
537 |
--------------------------------------------------------------------------------
/dmvr/tokenizers_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for tokenizers."""
16 |
17 | from __future__ import annotations
18 |
19 | from collections.abc import Sequence
20 | import os
21 | from typing import Type, TypeVar
22 |
23 | import clip.simple_tokenizer
24 | from dmvr import tokenizers
25 | from parameterized import parameterized
26 | import tensorflow as tf
27 |
28 | # Removed: Internal pyglib dependencies
29 |
30 | _TESTDATA_DIR = os.path.join(os.path.dirname(__file__), 'testdata')
31 | _MOCK_DATA = os.path.join(_TESTDATA_DIR, 'tokenizers')
32 |
33 | _FILENAMES = {
34 | tokenizers.SentencePieceTokenizer: 'spiece.model.1000.model',
35 | tokenizers.WordTokenizer: 'word_vocab.txt',
36 | tokenizers.BertTokenizer: 'bert_word_vocab.txt',
37 | tokenizers.ClipTokenizer: clip.simple_tokenizer.default_bpe(),
38 | }
39 |
40 | T = TypeVar('T', bound=tokenizers.TextTokenizer)
41 |
42 |
43 | def _get_tokenizer(cls: Type[T]) -> T:
44 | filename = _FILENAMES[cls]
45 | path = os.path.join(_MOCK_DATA, filename) # OSS: removed internal filename loading.
46 | return cls(path)
47 |
48 |
49 | def _tokenize_with_original_clip(
50 | texts: str | Sequence[str],
51 | context_length: int = 77) -> Sequence[Sequence[int]]:
52 | # Code adapted from `clip.tokenize` because it's not importable (only
53 | # `clip.simple_tokenizer` is).
54 |
55 | if isinstance(texts, str):
56 | texts = [texts]
57 |
58 | tokenizer = clip.simple_tokenizer.SimpleTokenizer()
59 | sot_token = tokenizer.encoder['<|startoftext|>']
60 | eot_token = tokenizer.encoder['<|endoftext|>']
61 | all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token]
62 | for text in texts]
63 | result = []
64 |
65 | for i, tokens in enumerate(all_tokens):
66 | if len(tokens) > context_length:
67 | raise RuntimeError(f'Input {texts[i]} is too long for context length'
68 | f' {context_length}')
69 | result.append(tokens + [0] * (context_length - len(tokens)))
70 |
71 | return result
72 |
73 |
74 | def _decode_with_original_clip(tokens_ids: Sequence[int]) -> str:
75 | tokenizer = clip.simple_tokenizer.SimpleTokenizer()
76 | text = tokenizer.decode(tokens_ids)
77 |
78 | eos = '<|endoftext|>'
79 | return text[:text.index(eos) + len(eos)]
80 |
81 |
82 | class TokenizerTest(tf.test.TestCase):
83 |
84 | @parameterized.expand(
85 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,),
86 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,)))
87 | def test_tokenizer(self, cls):
88 | tokenizer = _get_tokenizer(cls)
89 | tokenizer.initialize()
90 | input_string = ['hello world']
91 |
92 | tokenized = tokenizer.string_tensor_to_indices(
93 | input_string, max_num_tokens=42)
94 | self.assertEqual(tokenized.dtype, tf.int32)
95 |
96 | tokenized = tokenized.numpy().tolist()[0]
97 | self.assertLen(tokenized, 42)
98 | self.assertEqual(tokenized[-1], tokenizer.pad_token)
99 |
100 | detokenized = tokenizer.indices_to_string(tokenized)
101 | self.assertEqual(detokenized, 'hello world')
102 |
103 | @parameterized.expand(
104 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,),
105 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,)))
106 | def test_bos_eos(self, cls):
107 | tokenizer = _get_tokenizer(cls)
108 | tokenizer.initialize()
109 | input_string = ['hello world']
110 |
111 | tokenized = tokenizer.string_tensor_to_indices(
112 | input_string, prepend_bos=True, append_eos=True)
113 | tokenized = tokenized.numpy().tolist()[0]
114 | self.assertEqual(tokenized[0], tokenizer.bos_token)
115 |
116 | if tokenizer.pad_token != tokenizer.eos_token:
117 | tokenized = [t for t in tokenized if t != tokenizer.pad_token]
118 | self.assertEqual(tokenized[-1], tokenizer.eos_token)
119 |
120 | @parameterized.expand(
121 | ((tokenizers.WordTokenizer,), (tokenizers.SentencePieceTokenizer,),
122 | (tokenizers.BertTokenizer,), (tokenizers.ClipTokenizer,)))
123 | def test_not_initialized(self, cls):
124 | tokenizer = _get_tokenizer(cls)
125 | input_string = ['hello world']
126 |
127 | with self.assertRaises(RuntimeError):
128 | tokenizer.string_tensor_to_indices(input_string)
129 |
130 | @parameterized.expand((
131 | (tokenizers.WordTokenizer,),
132 | (tokenizers.SentencePieceTokenizer,),
133 | ))
134 | def test_string_to_indices(self, cls):
135 | tokenizer = _get_tokenizer(cls)
136 | tokenizer.initialize()
137 | input_string = 'hello world'
138 | tokenized = tokenizer.string_to_indices(
139 | input_string, prepend_bos=True, append_eos=True, max_num_tokens=42)
140 | self.assertEqual(type(tokenized), list)
141 | self.assertEqual(tokenized[0], tokenizer.bos_token)
142 | tokenized = [t for t in tokenized if t != tokenizer.pad_token]
143 | self.assertEqual(tokenized[-1], tokenizer.eos_token)
144 |
145 | detokenized = tokenizer.indices_to_string(tokenized[1:-1])
146 | self.assertEqual(detokenized, 'hello world')
147 |
148 | def test_clip_tokenizer(self):
149 | tokenizer = _get_tokenizer(tokenizers.ClipTokenizer)
150 | tokenizer.initialize()
151 | input_string = ['This is a test.', 'pushups']
152 | actual_tokenized_tf = tokenizer.string_tensor_to_indices(
153 | input_string, prepend_bos=True, append_eos=True, max_num_tokens=77)
154 |
155 | expected_tokenized = _tokenize_with_original_clip(input_string)
156 |
157 | actual_tokenized1 = actual_tokenized_tf.numpy().tolist()[0]
158 | expected_tokenized1 = expected_tokenized[0]
159 | self.assertEqual(actual_tokenized1, expected_tokenized1)
160 |
161 | actual_decoded = tokenizer.indices_to_string(actual_tokenized1)
162 | self.assertEqual(actual_decoded, 'this is a test .')
163 |
164 | actual_tokenized2 = actual_tokenized_tf.numpy().tolist()[1]
165 | expected_tokenized2 = expected_tokenized[1]
166 | self.assertEqual(actual_tokenized2, expected_tokenized2)
167 |
168 | actual_decoded = tokenizer.indices_to_string(actual_tokenized2)
169 | self.assertEqual(actual_decoded, input_string[1])
170 |
171 |
172 | if __name__ == '__main__':
173 | tf.test.main()
174 |
--------------------------------------------------------------------------------
/dmvr/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utils."""
16 |
17 | from typing import Optional, Sequence
18 |
19 | import tensorflow as tf
20 |
21 |
22 | # ----------------------------------------------------------------------
23 | # ------------------------ Experimental utils. -------------------------
24 | # ----------------------------------------------------------------------
25 |
26 |
27 | def combine_datasets(datasets: Sequence[tf.data.Dataset],
28 | batch_size: int = 1,
29 | weights: Optional[Sequence[float]] = None,
30 | seed: Optional[int] = None) -> tf.data.Dataset:
31 | """Combines multiple datasets into a single one.
32 |
33 | THIS IS AN EXPERIMENTAL FEATURE AND MIGHT BE REMOVED AT ANY TIME.
34 |
35 | This function combines multiple datasets into a single one by sampling
36 | elements from each one with the given probabilities. All input datasets must
37 | have the same structure and Tensor shapes.
38 |
39 | Args:
40 | datasets: A list of batched datasets. All datasets should have the same
41 | structure and Tensor shapes.
42 | batch_size: Batch size of the resulting dataset.
43 | weights: A list of the same length as datasets of floats where `weights[i]`
44 | represents the probability with which an element should be sampled from
45 | `datasets[i]`. If `None`, defaults to a uniform distribution across
46 | datasets.
47 | seed: A deterministic seed to use when sampling.
48 |
49 | Returns:
50 | A dataset that interleaves elements from datasets at random, according to
51 | weights if provided, otherwise with uniform probability. The resulting
52 | dataset is batched.
53 | """
54 | datasets = [ds.unbatch() for ds in datasets]
55 | combined_ds = tf.data.experimental.sample_from_datasets(
56 | datasets, weights, seed)
57 | return combined_ds.batch(batch_size, drop_remainder=True)
58 |
--------------------------------------------------------------------------------
/dmvr/utils_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for utils."""
16 |
17 | from dmvr import utils
18 | import tensorflow as tf
19 |
20 |
21 | class UtilsTest(tf.test.TestCase):
22 |
23 | def test_combine_datasets(self):
24 | ds_0 = tf.data.Dataset.from_tensor_slices({
25 | 'feature_0': [[[0] * 10] * 10] * 5,
26 | 'feature_1': [[0] * 10] * 5,
27 | })
28 | ds_1 = tf.data.Dataset.from_tensor_slices({
29 | 'feature_0': [[[1] * 10] * 10] * 5,
30 | 'feature_1': [[1] * 10] * 5,
31 | })
32 | ds_2 = tf.data.Dataset.from_tensor_slices({
33 | 'feature_0': [[[2] * 10] * 10] * 5,
34 | 'feature_1': [[2] * 10] * 5,
35 | })
36 |
37 | # Dataset uniformly sampling from all 3 datasets.
38 | ds_uniform = utils.combine_datasets([ds_0, ds_1, ds_2], 7)
39 | data_uniform = next(iter(ds_uniform))
40 |
41 | # Dataset sampling from ds_1 and ds_2.
42 | ds_no_1 = utils.combine_datasets([ds_0, ds_1, ds_2], 7, [0.5, 0, 0.5])
43 | data_no_1 = next(iter(ds_no_1))
44 |
45 | self.assertSetEqual(set(data_uniform.keys()),
46 | set(['feature_0', 'feature_1']))
47 | self.assertAllEqual(data_uniform['feature_0'].shape, (7, 10))
48 | self.assertAllEqual(data_uniform['feature_1'].shape, (7,))
49 |
50 | self.assertSetEqual(set(data_no_1.keys()),
51 | set(['feature_0', 'feature_1']))
52 | self.assertAllEqual(data_no_1['feature_0'].shape, (7, 10))
53 | self.assertAllEqual(data_no_1['feature_1'].shape, (7,))
54 |
55 | self.assertAllInSet(data_uniform['feature_0'], (0, 1, 2))
56 | self.assertAllInSet(data_uniform['feature_1'], (0, 1, 2))
57 | self.assertAllInSet(data_no_1['feature_0'], (0, 2))
58 | self.assertAllInSet(data_no_1['feature_1'], (0, 2))
59 |
60 |
61 | if __name__ == '__main__':
62 | tf.test.main()
63 |
--------------------------------------------------------------------------------
/dmvr/video_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Basic constructors for video datasets."""
15 |
16 | import abc
17 | from typing import Any, List, Optional, Type, TypeVar
18 |
19 | from absl import logging
20 | from dmvr import builders
21 | from dmvr import sources
22 | import tensorflow as tf
23 |
24 | # Types.
25 | T = TypeVar('T', bound=builders.BaseParserBuilder)
26 | NestedStructure = Any
27 |
28 |
29 | class BaseVideoDatasetFactory(abc.ABC):
30 | """Base class to build final `tf.data.Dataset` objects from files.
31 |
32 | Glossary:
33 |
34 | - A source is an object reading binary files in disk (e.g. TFRecords, image
35 | files) and outputting serialized examples (e.g. `tf.train.SequenceExample`).
36 | - A parser is an object reading serialized examples (e.g.
37 | `tf.train.SequenceExample`) and outputting a `builders.FeaturesDict`.
38 | - A processor is an object transforming features dictionary.
39 | - The data processing pipeline is organised in phases. A phase is a unit of
40 | the data processing graph and will have one parser or processor.
41 | - Builders are helpers designed to allow the user to easily customize the data
42 | processing graph by adding functions to each phase.
43 |
44 | Principle:
45 |
46 | All datasets created with this factory follow the same abstraction:
47 | a `parse_fn`, a `sample_fn`, a `decode_fn`, a `preprocess_fn` and a
48 | `postprocess_fn` are used to control the flow of dataset creation besides
49 | normal dataset operations. These functions are created from builders, allowing
50 | the user to build a graph of data processing operations. In detail, the
51 | following steps are followed when creating a dataset:
52 | - Read shards from file system using the given `source.Source`.
53 | - Apply `parse_fn` to output values of the `source` (as bytes) to build a
54 | dictionary of raw features. The parse function should only parse the useful
55 | bytes of the serialized input example (e.g. `tf.train.SequenceExample`) and
56 | put the features in a `builders.FeaturesDict` format. `parser_builder` can
57 | be used to easily add more features / modalities.
58 | - Apply `sample_fn` to sequence features contained in the dictionary in
59 | order to select the desired elements of the sequence, e.g. sample a subset
60 | of frames from the entire stored video. `sampler_builder` can be used to
61 | modify or add sampling options.
62 | - Apply `decode_fn` to convert raw formats to the final format. E.g. decode
63 | JPEG string `tf.Tensor` to a `tf.Tensor` of `uint8`. `decoder_builder` can
64 | be used.
65 | - Apply `preprocess_fn`. E.g. crop images, process audio and text.
66 | `preprocessor_builder` can be used.
67 | - Batch, shuffle, prefetch and do other basic operations with the dataset.
68 | - Apply `postprocess_fn` to batched examples. E.g. transpose batches.
69 | `postprocessor_builder` can be used.
70 |
71 | After each one of the data processing functions, a filter is applied in order
72 | to keep only desirable elements in the dataset. These filters can be
73 | customized by using the `filter_builder`.
74 |
75 | A conventional use of this factory consists of implementing a subclass for a
76 | specific dataset, overriding the `_build` method where all common processing
77 | of the specific dataset can be added using the builders.
78 |
79 | The client of the dataset is able to create a factory, configure it, possibly
80 | add custom extra processing steps and use it to make a dataset.
81 |
82 | Usage:
83 |
84 | ```python
85 | class KineticsFactory(BaseVideoDatasetFactory):
86 |
87 | def __init__(self, subset: str):
88 | shards = ['path/to/kinetics/tfrecords/records-00001-of-00500.tfrecord',
89 | ...]
90 | shards = filter_by_subset(shards, subset)
91 | super().__init__(shards)
92 |
93 | def _build(self, frame_height: int, frame_width: int, frame_count: int):
94 | self.parser_builder.parse_feature(
95 | image_seq_example_feature_name,
96 | tf.io.FixedLenSequenceFeature((), dtype=tf.string),
97 | builders.IMAGE_FEATURE_NAME)
98 | self.sampler_builder.add_fn(
99 | lambda x: sample_sequence_fn(x, frame_count),
100 | builders.IMAGE_FEATURE_NAME)
101 | self.decoder_builder.add_fn(decode_frames_fn, builders.IMAGE_FEATURE_NAME)
102 | self.preprocessor_builder.add_fn(
103 | lambda x: resize_frames(x, frame_height, frame_width),
104 | builders.IMAGE_FEATURE_NAME)
105 | # Other processing functions adding text and label.
106 |
107 | # Dataset client code:
108 | factory = KineticsFactory(subset='test').configure(
109 | frame_height=224, frame_width=224, frame_count=8)
110 |
111 | # Add extra custom preprocess functions:
112 | def my_custom_text_tokenizer(text: tf.Tensor) -> tf.Tensor:
113 | # Tokenize text string.
114 | return tokenized_tensor
115 |
116 | def my_custom_add_word_indices(
117 | features_dict: builders.FeaturesDict) -> builders.FeaturesDict:
118 | tokenized_text = features_dict[builders.TEXT_FEATURE_NAME]
119 | features_dict[builders.TEXT_INDICES_FEATURE_NAME] = text_to_indices(
120 | tokenized_text)
121 | return features_dict
122 |
123 | (factory.preprocessor_builder.add_fn(my_custom_text_tokenizer,
124 | builders.TEXT_FEATURE_NAME)
125 | .add_fn(my_custom_add_word_indices))
126 |
127 | # Add filter:
128 | def keep_only_label_zero(features_dict: builders.FeaturesDict) -> tf.Tensor:
129 | return tf.equal(features_dict[builders.LABEL_INDEX_FEATURE_NAME], 0)
130 | factory.filter_builder.add_filter_fn(
131 | keep_only_label_zero, builders.Phase.PARSE)
132 |
133 | # Create dataset:
134 | ds = factory.make_dataset(batch_size=16)
135 | ```
136 |
137 | The factory exposes the process functions builders to the client, allowing
138 | simple modifications to the functions. Common process functions, as crop,
139 | resize, etc. should be implemented in common modules.
140 |
141 | See builders documentation for more details.
142 | """
143 |
144 | def __init__(self,
145 | shards: List[str],
146 | parser_builder_class: Type[T] = builders
147 | .SequenceExampleParserBuilder,
148 | source: sources.Source = sources.TFRecordsSource()):
149 | """Initializes the `BaseVideoDatasetFactory`.
150 |
151 | Args:
152 | shards: List of paths to shards containing the data files. Each one of the
153 | paths will be passed to the `source`, that will read the data and output
154 | examples (that will be fed into the parse function generated by the
155 | `parser_builder_class`). Therefore, `shards`, `parser_builder_class` and
156 | `source` have to be consistent.
157 | parser_builder_class: A parser builder class able to parse examples of the
158 | types contained in `shards` files.
159 | source: Source to be used to load raw binary files and decode them into
160 | examples (encoded as bytes).
161 | """
162 |
163 | self._shards = shards
164 | self._source = source
165 |
166 | # Initialize all function builders.
167 | self.parser_builder = parser_builder_class()
168 | self.sampler_builder = builders.SamplerBuilder()
169 | self.decoder_builder = builders.DecoderBuilder()
170 | self.preprocessor_builder = builders.PreprocessorBuilder()
171 | self.postprocessor_builder = builders.PostprocessorBuilder()
172 |
173 | # Initialize filters.
174 | self.filter_builder = builders.FilterBuilder()
175 |
176 | # Default tune parameters.
177 | self._shuffle_buffer = 256
178 | self._num_parser_threads = 16
179 | self._num_process_threads = tf.data.experimental.AUTOTUNE
180 | self._num_postprocess_threads = 4
181 | self._parser_buffer_size = 64
182 | self._postprocess_buffer_size = 1
183 | self._prefetch_buffer_size = 8
184 | self._cycle_length = None
185 | self._num_parallel_calls_interleave = tf.data.experimental.AUTOTUNE
186 | self._block_length = None
187 | self._seed = None
188 | self._duplicate_proto = None
189 |
190 | self._is_configured = False
191 |
192 | def configure(self, *args, **kwargs) -> 'BaseVideoDatasetFactory':
193 | """Configures all parse and process functions of this factory.
194 |
195 | This function should be called exactly once per factory instance and will
196 | delegate the builders' configuration to the `_build` method.
197 |
198 | Args:
199 | *args: Positional arguments passed to the `_build` function.
200 | **kwargs: Keyword arguments passed to the `_build` function.
201 |
202 | Returns:
203 | This instance of the factory.
204 |
205 | Raises:
206 | ValueError: Method has already been called.
207 | """
208 | if self._is_configured:
209 | raise ValueError(
210 | '`configure` has already been called. The method should be called '
211 | 'only once to avoid duplicated process functions.')
212 | self._is_configured = True
213 | self._build(*args, **kwargs)
214 | return self
215 |
216 | def tune(self,
217 | shuffle_buffer: Optional[int] = None,
218 | num_parser_threads: Optional[int] = None,
219 | num_process_threads: Optional[int] = None,
220 | num_postprocess_threads: Optional[int] = None,
221 | parser_buffer_size: Optional[int] = None,
222 | postprocess_buffer_size: Optional[int] = None,
223 | prefetch_buffer_size: Optional[int] = None,
224 | cycle_length: Optional[int] = None,
225 | num_parallel_calls_interleave: Optional[int] = None,
226 | block_length: Optional[int] = None,
227 | seed: Optional[int] = None,
228 | duplicate_proto: Optional[int] = None):
229 | """Changes the dataset creation parameters.
230 |
231 | This method should be used to change the default parameters used to create
232 | the dataset in order to improve speed, memory usage or other aspects. Only the
233 | given parameters will be changed; the others will remain the same.
234 |
235 | Args:
236 | shuffle_buffer: The buffer size for shuffle operation. This affects the
237 | randomness of the output. It must be specified if `shuffle` is `True`.
238 | num_parser_threads: Number of threads to use for the parsing operation.
239 | `tf.data.experimental.AUTOTUNE` can be used to auto-tune.
240 | num_process_threads: Number of threads to use for map operations in
241 | sample, decode and preprocess. `tf.data.experimental.AUTOTUNE` can be
242 | used to auto-tune.
243 | num_postprocess_threads: Number of threads to use for map operations in
244 | postprocess. `tf.data.experimental.AUTOTUNE` can be used to auto-tune.
245 | parser_buffer_size: Buffer size of the sample, decode and preprocess
246 | operation.
247 | postprocess_buffer_size: Buffer size of the postprocess operation.
248 | prefetch_buffer_size: Size of the final prefetch buffer.
249 | cycle_length: The number of shards that will be processed concurrently.
250 | `tf.data.experimental.AUTOTUNE` can be used to auto-tune.
251 | num_parallel_calls_interleave: The number of parallel calls to the
252 | interleave method. `tf.data.experimental.AUTOTUNE` can be used to
253 | auto-tune.
254 | block_length: The number of consecutive elements to produce from each
255 | shard.
256 | seed: Random seed of the shuffle operations.
257 | duplicate_proto: Number of duplicates to make for each loaded proto.
258 | Typically different augmentations will be applied for each copy, so
259 | this can reduce disk reads without harming training performance.
260 | This is applied after the post read function, but before the shuffle
261 | buffer.
262 |
263 | Returns:
264 | This instance of the factory.
265 | """
266 | self._shuffle_buffer = shuffle_buffer or self._shuffle_buffer
267 | self._num_parser_threads = num_parser_threads or self._num_parser_threads
268 | self._num_process_threads = num_process_threads or self._num_process_threads
269 | self._num_postprocess_threads = (
270 | num_postprocess_threads or self._num_postprocess_threads)
271 | self._parser_buffer_size = parser_buffer_size or self._parser_buffer_size
272 | self._postprocess_buffer_size = (
273 | postprocess_buffer_size or self._postprocess_buffer_size)
274 | self._prefetch_buffer_size = (
275 | prefetch_buffer_size or self._prefetch_buffer_size)
276 | self._cycle_length = cycle_length or self._cycle_length
277 | self._num_parallel_calls_interleave = (
278 | num_parallel_calls_interleave or self._num_parallel_calls_interleave)
279 | self._block_length = block_length or self._block_length
280 | self._seed = seed or self._seed
281 | self._duplicate_proto = duplicate_proto or self._duplicate_proto
282 |
283 | return self
284 |
285 | # ----------------------------------------------------------------------
286 | # ---------- Methods that must be implemented by child class. ----------
287 | # ----------------------------------------------------------------------
288 |
289 | @abc.abstractmethod
290 | def _build(self, *args, **kwargs) -> None:
291 | """Builds the data processing graph."""
292 |
293 | # ----------------------------------------------------------------------
294 | # -------- Methods that should only be overridden if necessary. --------
295 | # ----------------------------------------------------------------------
296 |
297 | def make_dataset(
298 | self,
299 | shuffle: bool = True,
300 | num_epochs: Optional[int] = None,
301 | batch_size: Optional[int] = 16,
302 | padded_batch: bool = False,
303 | padded_batch_shapes: NestedStructure = None,
304 | drop_remainder: bool = True,
305 | keep_key: bool = False,
306 | cache: bool = False,
307 | override_preprocess_fn: Optional[builders.Processor] = None,
308 | **experimental_kwargs
309 | ) -> tf.data.Dataset:
310 | """Creates a `tf.data.Dataset` instance of the given dataset.
311 |
312 | Args:
313 | shuffle: Whether output data is shuffled.
314 | num_epochs: Number of epochs to cycle through before stopping. If `None`,
315 | this will read samples indefinitely.
316 | batch_size: If an int, an extra leading batch dimension will be present
317 | for all features. If `None`, then no batching is done and no extra batch
318 | dimension is added.
319 | padded_batch: Whether to use `padded_batch` instead of `batch` method.
320 | Padded batch pads a batch of examples to a given output shape. It pads
321 | all examples to the longest one in that batch. This could be used for
322 | sequence data.
323 | padded_batch_shapes: `padded_shapes` to be passed to `padded_batch`.
324 | drop_remainder: Whether to drop any remainder after the last full-size
325 | batch. If `True`, the batch dimension of the resulting op is known;
326 | otherwise, the batch dimension may be `None` in cases where `num_epochs`
327 | is finite and `batch_size` > 1, since the final remainder batch may be
328 | smaller than the usual batch size.
329 | keep_key: Whether to keep the `builders.Source` key as a feature in the
330 | final dictionary. The key for the key in the dictionary is
331 | `builders.KEY_FEATURE_NAME`.
332 | cache: Whether to cache the dataset in RAM. Note that this should only
333 | be used if the dataset can fit in RAM, as otherwise it will lead to an
334 | out-of-memory error.
335 | override_preprocess_fn: Function to use instead of built preprocess_fn.
336 | **experimental_kwargs: Other arguments used for experimental features.
337 | These can be removed at any time without prior notice.
338 |
339 | Returns:
340 | An instance of the dataset.
341 |
342 | Raises:
343 | ValueError: Factory has not been configured.
344 | ValueError: `shuffle_buffer` is `None` when dataset is shuffled.
345 | ValueError: `batch_size` is not `None`, `padded_batch` is `False` and
346 | `padded_batch_shapes` is not `None`.
347 | """
348 |
349 | if not self._is_configured:
350 | raise ValueError('Factory has not been configured. Call `configure` '
351 | 'method before `make_dataset`.')
352 |
353 | # Build functions or use its overrides.
354 | parse_fn = self.parser_builder.build()
355 | sample_fn = self.sampler_builder.build()
356 | decode_fn = self.decoder_builder.build()
357 | preprocess_fn = override_preprocess_fn or self.preprocessor_builder.build()
358 | postprocess_fn = self.postprocessor_builder.build()
359 |
360 | # Filter functions.
361 | filter_fn_post_read = self.filter_builder.build(builders.Phase.READ)
362 | filter_fn_post_parse = self.filter_builder.build(builders.Phase.PARSE)
363 | filter_fn_post_sample = self.filter_builder.build(builders.Phase.SAMPLE)
364 | filter_fn_post_decode = self.filter_builder.build(builders.Phase.DECODE)
365 | filter_fn_post_preprocess = self.filter_builder.build(
366 | builders.Phase.PREPROCESS)
367 | filter_fn_post_postprocess = self.filter_builder.build(
368 | builders.Phase.POSTPROCESS)
369 |
370 | if shuffle and self._shuffle_buffer is None:
371 | raise ValueError(
372 | '`shuffle_buffer` cannot be `None` if dataset is shuffled.')
373 |
374 | def parse_example(key: tf.Tensor,
375 | raw_example: tf.Tensor) -> builders.FeaturesDict:
376 | """Decodes bytes of example and parse it into a features dictionary."""
377 | output = parse_fn(raw_example)
378 | # Potentially parse the key.
379 | if keep_key:
380 | output[builders.KEY_FEATURE_NAME] = key
381 | return output
382 |
383 | ds = tf.data.Dataset.from_tensor_slices(self._shards)
384 | if shuffle:
385 | # Shuffling the shards and not only the examples later is important.
386 | ds = ds.shuffle(len(self._shards), seed=self._seed)
387 |
388 | ds = ds.interleave(
389 | self._source.load_and_decode_shard,
390 | cycle_length=self._cycle_length,
391 | block_length=self._block_length,
392 | num_parallel_calls=self._num_parallel_calls_interleave,
393 | deterministic=not shuffle)
394 |
395 | # At this point, the features dictionary is not yet created. We artificially
396 | # create one with the key only to make the interface uniform.
397 | ds = ds.filter(
398 | lambda key, _: filter_fn_post_read({builders.KEY_FEATURE_NAME: key}))
399 |
400 | if self._duplicate_proto is not None:
401 |
402 | def duplicate_fn(x, y):
403 | return (tf.stack([x] * self._duplicate_proto),
404 | tf.stack([y] * self._duplicate_proto))
405 |
406 | ds = ds.map(duplicate_fn)
407 | ds = ds.unbatch()
408 |
409 | if not cache:
410 | ds = ds.repeat(num_epochs)
411 | if shuffle:
412 | ds = ds.shuffle(self._shuffle_buffer, seed=self._seed)
413 |
414 | # Parse.
415 | ds = ds.map(
416 | parse_example,
417 | num_parallel_calls=self._num_parser_threads,
418 | deterministic=not shuffle)
419 | ds = ds.filter(filter_fn_post_parse)
420 |
421 | if cache:
422 | # We cache the dataset after the parsing operation. This means that we
423 | # cache the raw protos before any random operations happen. This can avoid
424 | # IO issues when the dataset fits in RAM. Note that this is the optimal
425 | # place to cache the data (caching before would have no effect as that
426 | # would only be caching a list of files, caching after would not be
427 | # possible due to the random operations that need to happen after the
428 | # `ds.repeat` operation, making it impossible to cache as the dataset
429 | # would be unbounded).
430 | ds = ds.cache()
431 | ds = ds.repeat(num_epochs)
432 | if shuffle:
433 | ds = ds.shuffle(self._shuffle_buffer, seed=self._seed)
434 | else:
435 | ds = ds.prefetch(self._parser_buffer_size)
436 |
437 | # Sample.
438 | ds = ds.map(
439 | sample_fn,
440 | num_parallel_calls=self._num_process_threads,
441 | deterministic=not shuffle)
442 | ds = ds.filter(filter_fn_post_sample)
443 |
444 | # Decode.
445 | ds = ds.map(
446 | decode_fn,
447 | num_parallel_calls=self._num_process_threads,
448 | deterministic=not shuffle)
449 | ds = ds.filter(filter_fn_post_decode)
450 |
451 | # Preprocess.
452 | ds = ds.map(
453 | preprocess_fn,
454 | num_parallel_calls=self._num_process_threads,
455 | deterministic=not shuffle)
456 | ds = ds.filter(filter_fn_post_preprocess)
457 |
458 | if experimental_kwargs.get('unbatch_after_preprocessing', False):
459 | ds = ds.unbatch()
460 |
461 | if experimental_kwargs.get('ignore_processing_errors', False):
462 | ds = ds.apply(tf.data.experimental.ignore_errors())
463 |
464 | if batch_size is not None:
465 | if padded_batch:
466 | ds = ds.padded_batch(
467 | batch_size=batch_size,
468 | padded_shapes=padded_batch_shapes,
469 | drop_remainder=drop_remainder)
470 | else:
471 | if padded_batch_shapes is not None:
472 | raise ValueError(
472 | 'If `padded_batch` is `False`, `padded_batch_shapes` must be `None`, '
474 | f'but is {padded_batch_shapes}.')
475 | ds = ds.batch(batch_size, drop_remainder=drop_remainder)
476 |
477 | # Postprocess.
478 | ds = ds.prefetch(self._postprocess_buffer_size)
479 | ds = ds.map(
480 | postprocess_fn,
481 | num_parallel_calls=self._num_postprocess_threads,
482 | deterministic=not shuffle)
483 | ds = ds.filter(filter_fn_post_postprocess)
484 |
485 | ds = ds.prefetch(self._prefetch_buffer_size)
486 |
487 | logging.info('Dataset created successfully')
488 |
489 | return ds
490 |
--------------------------------------------------------------------------------
/dmvr/video_dataset_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for video_dataset."""
16 |
17 | import os
18 | from typing import List, Union
19 |
20 | from dmvr import builders
21 | from dmvr import sources
22 | from dmvr import video_dataset
23 | from parameterized import parameterized
24 | import tensorflow as tf
25 |
26 |
27 | class _TestTFRecordsSource(sources.Source):
28 |
29 | def load_and_decode_shard(self,
30 | shard: Union[str, tf.Tensor]) -> tf.data.Dataset:
31 | ds = tf.data.TFRecordDataset(shard)
32 | ds = ds.map(lambda example: (b'test_key', example))
33 | return ds
34 |
35 |
36 | class _TestVideoDatasetFactory(
37 | video_dataset.BaseVideoDatasetFactory):
38 |
39 | def __init__(self, shards: List[str]):
40 | super().__init__(shards, builders.SequenceExampleParserBuilder,
41 | _TestTFRecordsSource())
42 |
43 | def _build(self,
44 | sample_offset: int = 0,
45 | multiply_by_2: bool = False,
46 | reduce_max: bool = False,
47 | keep_idx: bool = False):
48 | self.parser_builder.parse_feature(
49 | 'sequence', tf.io.FixedLenSequenceFeature((), dtype=tf.int64))
50 | if keep_idx:
51 | self.parser_builder.parse_feature(
52 | 'idx', tf.io.FixedLenFeature((), dtype=tf.int64), is_context=True)
53 |
54 | self.sampler_builder.add_fn(
55 | lambda x: x[sample_offset:(sample_offset + 50)], 'sequence')
56 |
57 | self.decoder_builder.add_fn(
58 | lambda x: tf.cast(x, tf.uint8), 'sequence')
59 |
60 | if multiply_by_2:
61 | self.preprocessor_builder.add_fn(lambda x: 2 * x, 'sequence')
62 |
63 | if reduce_max:
64 | self.postprocessor_builder.add_fn(
65 | lambda x: tf.reduce_max(input_tensor=x, axis=1), 'sequence')
66 |
67 |
68 | class BaseVideoDatasetFactoryTest(tf.test.TestCase):
69 |
70 | def setUp(self):
71 | super().setUp()
72 |
73 | shards = []
74 | tmp_dir = self.get_temp_dir()
75 | # Generate TFRecords of 5 shards with serialized SequenceExamples in the
76 | # format ('sequence', [[0], [1], ..., [99]]) plus the shard and element
77 | # indices.
78 | for shard_idx in range(5):
79 | shard = os.path.join(tmp_dir,
80 | 'example-{:05}-of-00005.tfrecord'.format(shard_idx))
81 | shards.append(shard)
82 |
83 | # Create fake `tf.train.SequenceExample`.
84 | seq_example = tf.train.SequenceExample()
85 | for i in range(100):
86 | seq_example.feature_lists.feature_list.get_or_create(
87 | 'sequence').feature.add().int64_list.value[:] = [i]
88 |
89 | with tf.io.TFRecordWriter(shard) as builder:
90 | for idx in range(10):
91 | seq_example.context.feature.get_or_create(
92 | 'idx').int64_list.value[:] = [shard_idx * 10 + idx]
93 | builder.write(seq_example.SerializeToString())
94 |
95 | self._factory = _TestVideoDatasetFactory(shards)
96 |
97 | def test_basic(self):
98 | ds = self._factory.configure().make_dataset(batch_size=2)
99 |
100 | data = next(iter(ds))
101 | self.assertSetEqual(set(data.keys()), set(['sequence']))
102 | self.assertAllEqual(data['sequence'], [list(range(50))] * 2)
103 |
104 | def test_configure(self):
105 | ds = self._factory.configure(10, True, True).make_dataset(batch_size=2)
106 |
107 | data = next(iter(ds))
108 | self.assertSetEqual(set(data.keys()), set(['sequence']))
109 | self.assertAllEqual(data['sequence'], [59 * 2] * 2)
110 |
111 | def test_configure_exception(self):
112 | with self.assertRaises(ValueError) as _:
113 | self._factory.make_dataset(batch_size=2)
114 |
115 | with self.assertRaises(ValueError) as _:
116 | self._factory.configure().configure()
117 |
118 | def test_keep_key(self):
119 | ds = self._factory.configure().make_dataset(batch_size=2, keep_key=True)
120 |
121 | data = next(iter(ds))
122 | self.assertSetEqual(set(data.keys()),
123 | set(['sequence', builders.KEY_FEATURE_NAME]))
124 | self.assertAllEqual(data[builders.KEY_FEATURE_NAME].shape, (2,))
125 | self.assertEqual(data[builders.KEY_FEATURE_NAME][0].numpy(), b'test_key')
126 | self.assertEqual(data[builders.KEY_FEATURE_NAME][1].numpy(), b'test_key')
127 |
128 | def test_override_preprocess_fn(self):
129 | # Data shouldn't be multiplied by 2.
130 | ds = self._factory.configure(multiply_by_2=True).make_dataset(
131 | batch_size=2, override_preprocess_fn=lambda x: x)
132 |
133 | data = next(iter(ds))
134 | self.assertSetEqual(set(data.keys()), set(['sequence']))
135 | self.assertAllEqual(data['sequence'], [list(range(50))] * 2)
136 |
137 | def test_no_shuffle(self):
138 | # Set block_length to guarantee reading all examples from the first shard.
139 | ds = self._factory.configure(keep_idx=True).tune(
140 | block_length=5).make_dataset(shuffle=False, batch_size=5)
141 |
142 | data = next(iter(ds))
143 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx']))
144 | self.assertAllEqual(data['idx'], [0, 1, 2, 3, 4])
145 |
146 | def test_filter_read(self):
147 | self._factory.filter_builder.add_filter_fn(
148 | lambda fd: tf.not_equal(fd[builders.KEY_FEATURE_NAME], 'test_key'),
149 | builders.Phase.READ)
150 | ds = self._factory.configure().make_dataset(batch_size=10, keep_key=True)
151 |
152 | with self.assertRaises(StopIteration) as _:
153 | next(iter(ds))
154 |
155 | @parameterized.expand(
156 | ((builders.Phase.PARSE,), (builders.Phase.SAMPLE,),
157 | (builders.Phase.DECODE,), (builders.Phase.PREPROCESS,)))
158 | def test_filter(self, phase):
159 |
160 | def keep_even_idx(features_dict):
161 | idx = features_dict['idx']
162 | return tf.equal(idx % 2, 0)
163 |
164 | self._factory.filter_builder.add_filter_fn(keep_even_idx, phase)
165 | # Set block_length to guarantee reading examples in key order.
166 | ds = self._factory.configure(keep_idx=True).tune(
167 | block_length=10).make_dataset(shuffle=False, batch_size=10)
168 |
169 | data = next(iter(ds))
170 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx']))
171 | self.assertAllEqual(data['idx'], range(0, 20, 2))
172 |
173 | def test_filter_postprocess(self):
174 | self._factory.filter_builder.add_filter_fn(
175 | lambda fd: tf.not_equal(fd['idx'][0], 0), # Filter first batch.
176 | builders.Phase.POSTPROCESS)
177 | # Set block_length to guarantee reading examples in key order.
178 | ds = self._factory.configure(keep_idx=True).tune(
179 | block_length=10).make_dataset(shuffle=False, batch_size=10)
180 |
181 | data = next(iter(ds))
182 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx']))
183 | self.assertAllEqual(data['idx'], range(10, 20))
184 |
185 | def test_ignore_processing_errors(self):
186 |
187 | def fail_decode(idx):
188 | # Fail for all odd indices.
189 | error = tf.assert_equal(idx % 2, tf.zeros((), dtype=tf.int64))
190 | with tf.control_dependencies([error]):
191 | return idx
192 |
193 | self._factory.decoder_builder.add_fn(fail_decode, 'idx')
194 | # Set block_length to guarantee reading examples in key order.
195 | ds = self._factory.configure(keep_idx=True).tune(
196 | block_length=10).make_dataset(
197 | shuffle=False, batch_size=10, ignore_processing_errors=True)
198 |
199 | data = next(iter(ds))
200 | self.assertSetEqual(set(data.keys()), set(['sequence', 'idx']))
201 | self.assertAllEqual(data['idx'], range(0, 20, 2))
202 |
203 |
204 | if __name__ == '__main__':
205 | tf.test.main()
206 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # End-to-end walkthrough: from raw data to using the readers for learning
2 |
3 | This document walks you through all the steps that go from raw data (a list of
4 | mp4 files), to a format that is compatible with DMVR, to writing a reader to
5 | finally use it in an ML application.
6 |
7 |
8 | ## Requirements
9 |
10 | To run the code, you will need to install the following dependencies:
11 |
12 | - python3
13 | - numpy
14 | - absl-py
15 | - pandas
16 | - Tensorflow
17 | - [ffmpeg](https://johnvansickle.com/ffmpeg/)
18 | - [ffmpeg-python](https://github.com/kkroening/ffmpeg-python) [Make sure you pip install ffmpeg-python and not python-ffmpeg]
19 | - unrar [Only for the HMDB-51 dataset generation example]
20 | - [scikit-learn](https://scikit-learn.org/) [Only for training linear model on HMDB]
21 |
22 | Please make sure the ffmpeg binaries (downloadable
23 | [here](https://johnvansickle.com/ffmpeg/)) are visible from the *PATH*
24 | environment variable and to install the ffmpeg-python Python wrapper (and not
25 | python-ffmpeg, which is different). Installing ffmpeg-python with pip can be
26 | done in one line with:
27 |
28 | ```sh
29 | pip install ffmpeg-python
30 | ```
31 |
32 | ## Creating and reading your own DMVR dataset using open-source tools
33 |
34 | First, we will describe how to generate your own DMVR dataset as tfrecord files
35 | from your own videos using open-source tools.
36 |
37 | Then, we provide a step-by-step example of how to generate the popular
38 | [HMDB-51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/)
39 | action recognition video dataset in the DMVR format.
40 |
41 | ### Generating your own tfrecord files
42 |
43 | #### Creating the input CSV for the generation
44 |
45 | To generate a DMVR-compatible video dataset using this tool, all you need is to
46 | provide a CSV file with the paths of the videos you want to process, together
47 | with additional metadata such as the start/end timestamps, a label or a text
48 | caption. As an example, we are going to download two videos (Creative Commons
49 | license):
50 |
51 | ```sh
52 | wget https://cdn.spacetelescope.org/archives/videos/medium_podcast/heic1608c.mp4 \
53 | -O /tmp/heic1608c.mp4
54 | wget https://upload.wikimedia.org/wikipedia/commons/1/18/BRKM_Javeline_Throw.webm \
55 | -O /tmp/BRKM_Javeline_Throw.webm
56 | ```
57 |
58 | We can create the following csv with the downloaded videos to process:
59 |
60 | ```sh
61 | video_path,start,end,label,caption
62 | /tmp/heic1608c.mp4,1.5,6.0,space,the view of the space from a telescope
63 | /tmp/BRKM_Javeline_Throw.webm,0.0,3.0,javeline_throw,someone is throwing a javeline
64 | ```
65 |
66 | where a more precise description of each column is given below:
67 |
68 | | Column name | Description | Optional |
69 | | ----------- | -------------------------------------------------- | -------- |
70 | | video_path  | the path of the video to process                                             | No       |
71 | | start       | the clip start time (in seconds)                                             | No       |
72 | | end         | the clip end time (in seconds)                                               | No       |
73 | | label       | A label annotated with the clip (e.g. for classification)                    | Yes      |
74 | | caption     | A free-form text annotated with the clip (e.g. for captioning or retrieval)  | Yes      |
77 |
78 | Run this following line to create the csv:
79 |
80 | ```sh
81 | echo -e "video_path,start,end,label,caption\n/tmp/heic1608c.mp4,1.5,3.0,space,hubble\n/tmp/BRKM_Javeline_Throw.webm,0.0,3.0,javeline_throw,someone is throwing a javeline" > /tmp/input.csv
82 | ```
83 |
84 | #### Generating the tfrecords data using the CSV
85 |
86 | Now that we have created a CSV file with the videos we wish to process, we can
87 | generate the tfrecords using the provided code. This can be done by running the
88 | following commands:
89 |
90 | ```sh
91 | mkdir /tmp/generated_dataset
92 | python generate_from_file.py \
93 | --csv_path=/tmp/input.csv \
94 | --output_path=/tmp/generated_dataset
95 | ```
96 |
97 | where a description of the arguments is given below:
98 |
99 | Arguments | Description
100 | ------------ | -------------------------------------------------------
101 | csv_path | The path of the input CSV with all the video paths.
102 | output_path | The output path for the generated tfrecords.
103 | num_shards | The number of tfrecord shards to create (default=1)
104 | decode_audio | Decode and store audio in the tfrecords (default=False)
105 | shuffle_csv | Whether or not to shuffle the input csv (default=False)
106 |
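As an illustration of the optional arguments, a run that shards the output into 4 tfrecord files and also stores the audio track might look like the following (flag names are taken from the table above; the exact boolean flag syntax is an assumption and may differ slightly):

```sh
python generate_from_file.py \
  --csv_path=/tmp/input.csv \
  --output_path=/tmp/generated_dataset \
  --num_shards=4 \
  --decode_audio=true \
  --shuffle_csv=true
```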
107 | Congratulations! You have created a DMVR-compatible dataset from your own files!
108 |
109 | ### Example: Step-by-step generation of HMDB-51 in the DMVR format
110 |
111 | As another example, we provide step-by-step instructions for generating the
112 | [HMDB-51](https://serre-lab.clps.brown.edu/resource/hmdb-a-large-human-motion-database/)
113 | video dataset in the DMVR format.
114 |
115 | #### Creating the HMDB-51 input CSV for the generation pipeline
116 |
117 | First, you need to download the original splits from the official
118 | [link](http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar).
119 |
120 | ```sh
121 | wget http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar
122 | ```
123 |
124 | If you have not
125 | [installed unrar](https://www.tecmint.com/how-to-open-extract-and-create-rar-files-in-linux/)
126 | yet, please install it to extract the rar, and then run:
127 |
128 | ```sh
129 | unrar x test_train_splits.rar
130 | rm test_train_splits.rar
131 | ```
132 |
133 | Create the HMDB CSV file using our provided script:
134 |
135 | ```sh
136 | mkdir hmdb_csv
137 | python generate_hmdb_csv.py \
138 | --input_path=testTrainMulti_7030_splits \
139 | --output_path=hmdb_csv
140 | ```
141 |
142 | This will generate six CSV files in the *hmdb_csv* folder: train_1.csv,
143 | test_1.csv, train_2.csv, test_2.csv, train_3.csv and test_3.csv, corresponding
144 | to the three train/test splits.
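
To sanity-check the output, you can peek at the first lines of one of the generated splits (the header is fixed by `generate_hmdb_csv.py`; the example row below is purely illustrative):

```sh
head -2 hmdb_csv/train_1.csv
# video_path,start,end,label
# brush_hair/<video_id>.avi,0,20,brush_hair
```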
145 |
146 | #### Generating the tfrecords data from the generated HMDB-51 CSV
147 |
148 | Now that you have generated the HMDB-51 csv, you will need to download and
149 | extract the videos from the official website and store them in a newly created
150 | *hmdb_videos* directory:
151 |
152 | ```sh
153 | mkdir hmdb_videos
154 | wget https://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/hmdb51_org.rar \
155 | -P hmdb_videos
156 | cd hmdb_videos
157 | unrar x hmdb51_org.rar
158 | rm hmdb51_org.rar
159 | for video_rar in *.rar; do unrar x "$video_rar"; done
160 | rm *.rar
161 | cd ..
162 | ```
163 |
164 | You can now run the generation pipeline on any of the generated CSV splits. For
165 | example, you can process the train set of the first split with the
166 | following command:
167 |
168 | ```sh
169 | python generate_from_file.py \
170 | --csv_path=hmdb_csv/train_1.csv \
171 | --video_root_path=hmdb_videos \
172 | --output_path=/path/to/hmdb_shards
173 | ```
174 |
175 | This will generate the tfrecords in the DMVR format for the HMDB-51 split 1
176 | train set, split into roughly *sqrt(num_clips)* shards, where *num_clips* is the
177 | number of video clips in that train set.
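
You can quickly check that the shard names match what the example reader expects; with the default automatic sharding, the HMDB-51 train splits produce 59 shards and the test splits 39, which is exactly what `hmdb.py` hard-codes in `_NUM_SHARDS`:

```sh
ls /path/to/hmdb_shards | head -2
# Expected output (assuming the automatic number of shards):
# train_1-00000-of-00059
# train_1-00001-of-00059
```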
178 |
179 | ## Writing a DMVR reader
180 |
181 | See `hmdb.py` for an example reader for the data created above.
182 |
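As a minimal usage sketch (assuming the shards were written to `/path/to/hmdb_shards` and that the snippet is run from the `examples` directory so that `hmdb.py` is importable), the reader can be used to build a `tf.data` pipeline as follows:

```python
from dmvr import builders
import hmdb

# Build a training dataset from the generated shards (paths are placeholders).
factory = hmdb.HMDB51Factory('/path/to/hmdb_shards', subset='train', split=1)
ds = factory.configure(is_training=True, num_frames=32).make_dataset(
    shuffle=True, num_epochs=1, batch_size=2)

for batch in ds.take(1):
  print(batch[builders.IMAGE_FEATURE_NAME].shape)        # (2, 32, 224, 224, 3)
  print(batch[builders.LABEL_INDEX_FEATURE_NAME].shape)  # (2, 51), one-hot labels
```
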
183 | ## Training a linear classifier on top of existing features
184 |
185 | The script `linear_mmv_hmdb.py` evaluates the linear classification
186 | performance of the recently introduced
187 | [MMV networks](https://arxiv.org/abs/2006.16228) on HMDB51.
188 |
189 | To run the script simply do:
190 |
191 | ```shell
192 | python linear_mmv_hmdb.py \
193 | --data_path=/path/to/hmdb_shards \
194 | --model_name=s3d \
195 | --hmdb51_split=1
196 | ```
197 |
198 | The script supports three different models and should reproduce the following
199 | results (as reported in the paper):
200 |
201 | Visual Backbone | Results on Linear HMDB51 (avg over 3 splits)
202 | --------------- | --------------------------------------------
203 | [S3D-G](https://tfhub.dev/deepmind/mmv/s3d/1) (`s3d`) | 62.6
204 | [Resnet-50 TSM](https://tfhub.dev/deepmind/mmv/tsm-resnet50/1) (`tsm-resnet50`) | 66.7
205 | [Resnet-50 TSMx2](https://tfhub.dev/deepmind/mmv/tsm-resnet50x2/1) (`tsm-resnet50x2`) | 67.1
206 |
207 |
208 | ### References
209 |
210 | ```bibtex
211 | @inproceedings{alayrac2020self,
212 | title={{S}elf-{S}upervised {M}ulti{M}odal {V}ersatile {N}etworks},
213 | author={Alayrac, Jean-Baptiste and Recasens, Adri{\`a} and Schneider, Rosalia and Arandjelovi{\'c}, Relja and Ramapuram, Jason and De Fauw, Jeffrey and Smaira, Lucas and Dieleman, Sander and Zisserman, Andrew},
214 | booktitle={NeurIPS},
215 | year={2020}
216 | }
217 | ```
218 |
--------------------------------------------------------------------------------
/examples/generate_from_file.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """Python script to generate TFRecords of SequenceExample from raw videos."""
15 |
16 | import contextlib
17 | import math
18 | import os
19 | from typing import Dict, Optional, Sequence
20 |
21 | from absl import app
22 | from absl import flags
23 | import ffmpeg
24 | import numpy as np
25 | import pandas as pd
26 | import tensorflow as tf
27 |
28 | flags.DEFINE_string("csv_path", None, "Input csv")
29 | flags.DEFINE_string("output_path", None, "Tfrecords output path.")
30 | flags.DEFINE_string("video_root_path", None,
31 | "Root directory containing the raw videos.")
32 | flags.DEFINE_integer(
33 |     "num_shards", -1, "Number of shards to output, -1 means "
34 | "it will automatically adapt to the sqrt(num_examples).")
35 | flags.DEFINE_bool("decode_audio", False, "Whether or not to decode the audio")
36 | flags.DEFINE_bool("shuffle_csv", False, "Whether or not to shuffle the csv.")
37 | FLAGS = flags.FLAGS
38 |
39 |
40 | _JPEG_HEADER = b"\xff\xd8"
41 |
42 |
43 | @contextlib.contextmanager
44 | def _close_on_exit(writers):
45 | """Call close on all writers on exit."""
46 | try:
47 | yield writers
48 | finally:
49 | for writer in writers:
50 | writer.close()
51 |
52 |
53 | def add_float_list(key: str, values: Sequence[float],
54 | sequence: tf.train.SequenceExample):
55 | sequence.feature_lists.feature_list[key].feature.add(
56 | ).float_list.value[:] = values
57 |
58 |
59 | def add_bytes_list(key: str, values: Sequence[bytes],
60 | sequence: tf.train.SequenceExample):
61 | sequence.feature_lists.feature_list[key].feature.add(
62 | ).bytes_list.value[:] = values
63 |
64 |
65 | def add_int_list(key: str, values: Sequence[int],
66 | sequence: tf.train.SequenceExample):
67 | sequence.feature_lists.feature_list[key].feature.add(
68 | ).int64_list.value[:] = values
69 |
70 |
71 | def set_context_int_list(key: str, value: Sequence[int],
72 | sequence: tf.train.SequenceExample):
73 | sequence.context.feature[key].int64_list.value[:] = value
74 |
75 |
76 | def set_context_bytes(key: str, value: bytes,
77 | sequence: tf.train.SequenceExample):
78 | sequence.context.feature[key].bytes_list.value[:] = (value,)
79 |
80 |
81 | def set_context_float(key: str, value: float,
82 | sequence: tf.train.SequenceExample):
83 | sequence.context.feature[key].float_list.value[:] = (value,)
84 |
85 |
86 | def set_context_int(key: str, value: int, sequence: tf.train.SequenceExample):
87 | sequence.context.feature[key].int64_list.value[:] = (value,)
88 |
89 |
90 | def extract_frames(video_path: str,
91 | start: float,
92 | end: float,
93 | fps: int = 10,
94 | min_resize: int = 256):
95 | """Extract list of jpeg bytes from video_path using ffmpeg."""
96 | new_width = "(iw/min(iw,ih))*{}".format(min_resize)
97 | cmd = (
98 | ffmpeg
99 | .input(video_path)
100 | .trim(start=start, end=end)
101 | .filter("fps", fps=fps)
102 | .filter("scale", new_width, -1)
103 | .output("pipe:", format="image2pipe")
104 | )
105 | jpeg_bytes, _ = cmd.run(capture_stdout=True, quiet=True)
106 | jpeg_bytes = jpeg_bytes.split(_JPEG_HEADER)[1:]
107 | jpeg_bytes = map(lambda x: _JPEG_HEADER + x, jpeg_bytes)
108 | return list(jpeg_bytes)
109 |
110 |
111 | def extract_audio(video_path: str,
112 | start: float,
113 | end: float,
114 | sampling_rate: int = 48000):
115 | """Extract raw mono audio float list from video_path with ffmpeg."""
116 | cmd = (
117 | ffmpeg
118 | .input(video_path, ss=start, t=end-start)
119 |       .output("pipe:", ac=1, ar=sampling_rate, format="f32le")
120 | )
121 | audio, _ = cmd.run(capture_stdout=True, quiet=True)
122 | audio = np.frombuffer(audio, np.float32)
123 | return list(audio)
124 |
125 |
126 | def generate_sequence_example(video_path: str,
127 | start: float,
128 | end: float,
129 | label_name: Optional[str] = None,
130 | caption: Optional[str] = None,
131 | label_map: Optional[Dict[str, int]] = None):
132 | """Generate a sequence example."""
133 | if FLAGS.video_root_path:
134 | video_path = os.path.join(FLAGS.video_root_path, video_path)
135 | imgs_encoded = extract_frames(video_path, start, end)
136 |
137 | # Initiate the sequence example.
138 | seq_example = tf.train.SequenceExample()
139 |
140 | # Add the label list as text and indices.
141 | if label_name:
142 | set_context_int("clip/label/index", label_map[label_name], seq_example)
143 | set_context_bytes("clip/label/text", label_name.encode(), seq_example)
144 | if caption:
145 | set_context_bytes("caption/string", caption.encode(), seq_example)
146 | # Add the frames as one feature per frame.
147 | for img_encoded in imgs_encoded:
148 | add_bytes_list("image/encoded", [img_encoded], seq_example)
149 |
150 | # Add audio.
151 | if FLAGS.decode_audio:
152 | audio = extract_audio(video_path, start, end)
153 | add_float_list("WAVEFORM/feature/floats", audio, seq_example)
154 |
155 | # Add other metadata.
156 | set_context_bytes("video/filename", video_path.encode(), seq_example)
157 |   # Add start and end timestamps in microseconds.
158 | set_context_int("clip/start/timestamp", int(1000000 * start), seq_example)
159 | set_context_int("clip/end/timestamp", int(1000000 * end), seq_example)
160 | return seq_example
161 |
162 |
163 | def main(argv):
164 | del argv
165 | # reads the input csv.
166 | input_csv = pd.read_csv(FLAGS.csv_path)
167 | if FLAGS.num_shards == -1:
168 | num_shards = int(math.sqrt(len(input_csv)))
169 | else:
170 | num_shards = FLAGS.num_shards
171 | # Set up the TFRecordWriters.
172 | basename = os.path.splitext(os.path.basename(FLAGS.csv_path))[0]
173 | shard_names = [
174 | os.path.join(FLAGS.output_path, f"{basename}-{i:05d}-of-{num_shards:05d}")
175 | for i in range(num_shards)
176 | ]
177 | writers = [tf.io.TFRecordWriter(shard_name) for shard_name in shard_names]
178 |
179 | if "label" in input_csv:
180 | unique_labels = list(set(input_csv["label"].values))
181 | l_map = {unique_labels[i]: i for i in range(len(unique_labels))}
182 | else:
183 | l_map = None
184 |
185 | if FLAGS.shuffle_csv:
186 | input_csv = input_csv.sample(frac=1)
187 | with _close_on_exit(writers) as writers:
188 | for i in range(len(input_csv)):
189 | print(
190 | "Processing example %d of %d (%d%%) \r" %
191 | (i, len(input_csv), i * 100 / len(input_csv)),
192 | end="")
193 | v = input_csv["video_path"].values[i]
194 | s = input_csv["start"].values[i]
195 | e = input_csv["end"].values[i]
196 | l = input_csv["label"].values[i] if "label" in input_csv else None
197 | c = input_csv["caption"].values[i] if "caption" in input_csv else None
198 | seq_ex = generate_sequence_example(
199 | v, s, e, label_name=l, caption=c, label_map=l_map)
200 | writers[i % len(writers)].write(seq_ex.SerializeToString())
201 |
202 |
203 | if __name__ == "__main__":
204 | app.run(main)
205 |
--------------------------------------------------------------------------------
/examples/generate_hmdb_csv.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | r"""Script to generate csvs for HMDB.
15 |
16 | You would need to download the official splits at:
17 |
18 | http://serre-lab.clps.brown.edu/wp-content/uploads/2013/10/test_train_splits.rar
19 |
20 | and unrar the archive on your machine, e.g. /path/to/hmdb/
21 |
22 | Usage:
23 | ```
24 | python generate_hmdb_csv.py \
25 | --input_path=/path/to/hmdb/testTrainMulti_7030_splits \
26 | --output_path=/path/to/hmdb
27 | ```
28 | """
29 |
30 | import collections
31 | import csv
32 | import glob
33 | import os
34 |
35 | from absl import app
36 | from absl import flags
37 |
38 |
39 | flags.DEFINE_string(
40 | 'input_path', None, 'Path containing the metadata from HMDB51.')
41 | flags.DEFINE_string(
42 |     'output_path', None, 'Output path where the generated CSV files are written.')
43 |
44 | FLAGS = flags.FLAGS
45 |
46 | InputVideo = collections.namedtuple(
47 | 'InputRow',
48 | ('video_id', 'split', 'subset', 'label_name'))
49 |
50 | OutputRow = collections.namedtuple(
51 | 'OutputRow',
52 | ('video_id', 'start_sec', 'end_sec', 'label_name', 'label_id'))
53 |
54 |
55 | def main(argv):
56 | del argv
57 | all_files = glob.glob(os.path.join(FLAGS.input_path, '*txt'))
58 |
59 | all_rows = []
60 | label_names = set()
61 |
62 | # Read the files.
63 | for file_name in all_files:
64 | base_name = os.path.basename(file_name)
65 | base_name_split = base_name.split('_')
66 | label_name = ' '.join(base_name_split[:-2])
67 | label_name = label_name.replace(' ', '_')
68 | label_names.add(label_name)
69 | split = int(base_name[-5])
70 | with open(file_name, 'r') as f:
71 | lines = [x.strip().split(' ') for x in f.readlines()]
72 |
73 | for (video_id, ind) in lines:
74 | if ind == '1':
75 | all_rows.append(
76 | InputVideo(video_id, split, 'train', label_name))
77 | elif ind == '2':
78 | all_rows.append(
79 | InputVideo(video_id, split, 'test', label_name))
80 |
81 | # Sort the label names.
82 | label_names = list(label_names)
83 | label_names.sort()
84 |
85 | all_csvs = {
86 | 'train_1': [],
87 | 'train_2': [],
88 | 'train_3': [],
89 | 'test_1': [],
90 | 'test_2': [],
91 | 'test_3': [],
92 | }
93 |
94 | # Generate the csvs rows.
95 | for row in all_rows:
96 | csv_name = f'{row.subset}_{row.split}'
97 | all_csvs[csv_name].append(OutputRow(
98 | video_id=f'{row.label_name}/{row.video_id}',
99 | start_sec=0,
100 | end_sec=20,
101 | label_name=row.label_name,
102 | label_id=label_names.index(row.label_name)
103 | ))
104 |
105 | # Write the csvs.
106 | for csv_name in all_csvs:
107 | output_path = os.path.join(FLAGS.output_path, f'{csv_name}.csv')
108 | print(f'Writing outputs to CSV file {output_path}')
109 | with open(output_path, 'w') as f:
110 | writer = csv.writer(f, delimiter=',')
111 | writer.writerow(
112 | ['video_path', 'start', 'end', 'label'])
113 |
114 | for row in all_csvs[csv_name]:
115 | writer.writerow([
116 | row.video_id, row.start_sec, row.end_sec, row.label_name])
117 |
118 | if __name__ == '__main__':
119 | app.run(main)
120 |
--------------------------------------------------------------------------------
/examples/hmdb.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | """HMDB51 video dataset."""
15 |
16 | import os
17 | from typing import Optional
18 |
19 | from dmvr import modalities
20 | from dmvr import video_dataset
21 |
22 |
23 | class HMDB51Factory(video_dataset.BaseVideoDatasetFactory):
24 | """HMDB51 reader."""
25 |
26 | _SUBSETS = ('train', 'test')
27 | _SPLITS = (1, 2, 3)
28 | _NUM_CLASSES = 51
29 |
30 | _NUM_SHARDS = {'train': 59, 'test': 39}
31 |
32 | def __init__(
33 | self,
34 | base_dir: str,
35 | subset: str = 'train',
36 | split: int = 1):
37 | """Constructor of HMDB51Factory."""
38 |
39 | if subset not in HMDB51Factory._SUBSETS:
40 | raise ValueError('Invalid subset "{}". The available subsets are: {}'
41 | .format(subset, HMDB51Factory._SUBSETS))
42 |
43 | if split not in HMDB51Factory._SPLITS:
44 | raise ValueError('Invalid split "{}". The available splits are: {}'
45 | .format(split, HMDB51Factory._SPLITS))
46 |
47 | num_shards = self._NUM_SHARDS[subset]
48 | shards = [f'{subset}_{split}-{i:05d}-of-{num_shards:05d}'
49 | for i in range(num_shards)]
50 | super().__init__(shards=[os.path.join(base_dir, s) for s in shards])
51 |
52 | def _build(self,
53 | is_training: Optional[bool] = True,
54 | # Video related parameters.
55 | num_frames: int = 32,
56 | stride: int = 1,
57 | num_test_clips: int = 1,
58 | min_resize: int = 256,
59 | crop_size: int = 224,
60 | zero_centering_image: bool = False,
61 | # Label related parameters.
62 | one_hot_label: bool = True,
63 | add_label_name: bool = False):
64 | """Default build for this dataset.
65 |
66 | Args:
67 | is_training: Whether or not in training mode.
68 | num_frames: Number of frames per subclip. For single images, use 1.
69 | stride: Temporal stride to sample frames.
70 | num_test_clips: Number of test clips (1 by default). If more than 1, this
71 | will sample multiple linearly spaced clips within each video at test
72 | time. If 1, then a single clip in the middle of the video is sampled.
73 |         The clips are aggregated in the batch dimension.
74 | min_resize: Frames are resized so that `min(height, width)` is
75 | `min_resize`.
76 | crop_size: Final size of the frame after cropping the resized frames. Both
77 | height and width are the same.
78 | zero_centering_image: If `True`, frames are normalized to values in
79 | [-1, 1]. If `False`, values in [0, 1].
80 | one_hot_label: Return labels as one hot tensors.
81 | add_label_name: Also return the name of the label.
82 | """
83 | modalities.add_image(
84 | parser_builder=self.parser_builder,
85 | sampler_builder=self.sampler_builder,
86 | decoder_builder=self.decoder_builder,
87 | preprocessor_builder=self.preprocessor_builder,
88 | postprocessor_builder=self.postprocessor_builder,
89 | is_training=is_training,
90 | num_frames=num_frames, stride=stride,
91 | num_test_clips=num_test_clips,
92 | min_resize=min_resize, crop_size=crop_size,
93 | zero_centering_image=zero_centering_image)
94 |
95 | modalities.add_label(
96 | parser_builder=self.parser_builder,
97 | decoder_builder=self.decoder_builder,
98 | preprocessor_builder=self.preprocessor_builder,
99 | one_hot_label=one_hot_label,
100 | num_classes=HMDB51Factory._NUM_CLASSES,
101 | add_label_name=add_label_name)
102 |
--------------------------------------------------------------------------------
/examples/linear_mmv_hmdb.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """HMDB51 linear evaluation of MMV models."""
16 |
17 | from absl import app
18 | from absl import flags
19 | from dmvr import builders
20 | import hmdb
21 | import numpy as np
22 | from sklearn import preprocessing
23 | from sklearn import svm
24 | import tensorflow.compat.v2 as tf
25 | import tensorflow_datasets as tfds
26 | import tensorflow_hub as hub
27 |
28 |
29 | flags.DEFINE_enum('model_name', 's3d',
30 | ['s3d', 'tsm-resnet50', 'tsm-resnet50x2'],
31 | 'Which MMV backbone to load.')
32 | flags.DEFINE_string('data_path', '/path/to/hmdb/', 'Path where shards live.')
33 | flags.DEFINE_integer('eval_batch_size', 1,
34 | 'The batch size for evaluation.')
35 | flags.DEFINE_integer('train_batch_size', 16,
36 | 'The batch size for training.')
37 | flags.DEFINE_integer('num_train_epochs', 10,
38 | 'How many epochs to collect features during training.')
39 | flags.DEFINE_integer('num_test_clips', 10,
40 | 'How many clips to average on during test.')
41 | flags.DEFINE_integer('min_resize', 224,
42 | 'Min value to resize images to during preprocessing.')
43 | flags.DEFINE_integer('crop_size', 200,
44 |                      'Size of the square crop applied during preprocessing.')
45 | flags.DEFINE_integer('num_frames', 32, 'Number of video frames.')
46 | flags.DEFINE_integer('stride', 1, 'Stride for video frames.')
47 | flags.DEFINE_integer('hmdb51_split', 1, 'Which split of hmdb51 to use.')
48 |
49 |
50 | FLAGS = flags.FLAGS
51 |
52 | _MODELS2REG = {'s3d': 0.0003,
53 | 'tsm-resnet50': 0.0001,
54 | 'tsm-resnet50x2': 0.0003}
55 |
56 |
57 | def compute_accuracy_metrics(pred: np.ndarray, gt: np.ndarray,
58 | prefix: str = ''):
59 | order_pred = np.argsort(pred, axis=1)
60 | assert len(gt.shape) == len(order_pred.shape) == 2
61 | top1_pred = order_pred[:, -1:]
62 | top5_pred = order_pred[:, -5:]
63 | top1_acc = np.mean(top1_pred == gt)
64 | top5_acc = np.mean(np.max(top5_pred == gt, 1))
65 | return {prefix + 'top1': top1_acc,
66 | prefix + 'top5': top5_acc}
67 |
68 |
69 | def main(argv):
70 | del argv
71 |
72 | # Load the model.
73 | sklearn_reg = _MODELS2REG[FLAGS.model_name]
74 | module = hub.load(f'https://tfhub.dev/deepmind/mmv/{FLAGS.model_name}/1')
75 |
76 | def get_features(input_frames: np.ndarray):
77 | vision_output = module.signatures['video'](
78 | tf.constant(tf.cast(input_frames, dtype=tf.float32)))
79 | return vision_output['before_head'].numpy()
80 |
81 | def collect_features_and_labels(ds: tf.data.Dataset, subset: str):
82 | """Collect features and labels."""
83 | features = []
84 | labels = []
85 | print(f'Computing features on {subset}')
86 | examples = iter(tfds.as_numpy(ds))
87 | num_examples = 0
88 | for ex in examples:
89 | vid_representation = get_features(ex[builders.IMAGE_FEATURE_NAME])
90 | labels.append(ex[builders.LABEL_INDEX_FEATURE_NAME])
91 | features.append(vid_representation)
92 | num_examples += ex[builders.LABEL_INDEX_FEATURE_NAME].shape[0]
93 | if num_examples % 100 == 0:
94 | print(f'Processed {num_examples} examples.')
95 | labels = np.concatenate(labels, axis=0)
96 | features = np.concatenate(features, axis=0)
97 |     print(f'Finished collecting {subset} features of shape {features.shape}')
98 | return features, labels
99 |
100 | # Generate the training and testing datasets.
101 | conf_kwargs = dict(
102 | num_frames=FLAGS.num_frames,
103 | stride=FLAGS.stride,
104 | min_resize=FLAGS.min_resize,
105 | crop_size=FLAGS.crop_size,
106 | one_hot_label=False)
107 |
108 | train_ds = hmdb.HMDB51Factory(
109 | FLAGS.data_path, subset='train', split=FLAGS.hmdb51_split).configure(
110 | is_training=True, **conf_kwargs).make_dataset(
111 | shuffle=True,
112 | num_epochs=FLAGS.num_train_epochs,
113 | batch_size=FLAGS.train_batch_size)
114 |
115 | test_ds = hmdb.HMDB51Factory(
116 | FLAGS.data_path, subset='test', split=FLAGS.hmdb51_split).configure(
117 | is_training=False, num_test_clips=FLAGS.num_test_clips,
118 | **conf_kwargs).make_dataset(shuffle=False,
119 | num_epochs=1,
120 | batch_size=FLAGS.eval_batch_size)
121 |
122 | # Collect features and labels.
123 | train_features, train_labels = collect_features_and_labels(train_ds, 'train')
124 | test_features, test_labels = collect_features_and_labels(test_ds, 'test')
125 |
126 | # Train classifier
127 | print('Training linear classifier!')
128 | classifier = svm.LinearSVC(C=sklearn_reg)
129 | scaler = preprocessing.StandardScaler().fit(train_features)
130 | train_features = scaler.transform(train_features)
131 | classifier.fit(train_features, train_labels.ravel())
132 |   print('Training done!')
133 |
134 | # Evaluation.
135 | test_features = scaler.transform(test_features)
136 | print('Running inference on train')
137 | pred_train = classifier.decision_function(train_features)
138 | print('Running inference on test')
139 | pred_test = classifier.decision_function(test_features)
140 | if FLAGS.num_test_clips > 1:
141 | pred_test = np.reshape(
142 | pred_test, (test_labels.shape[0], -1, pred_test.shape[1]))
143 | pred_test = pred_test.mean(axis=1)
144 |
145 | # Compute accuracies.
146 | metrics = compute_accuracy_metrics(pred_train, train_labels, prefix='train_')
147 | metrics.update(
148 | compute_accuracy_metrics(pred_test, test_labels, prefix='test_'))
149 | print(metrics)
150 |
151 | if __name__ == '__main__':
152 | app.run(main)
153 |
--------------------------------------------------------------------------------
/requirements-test.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | parameterized
3 | tensorflow>=2.0.0
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py
2 | sentencepiece
3 | tensorflow>=2.0.0
4 | tensorflow_text>=2.0.0
5 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # https://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup for pip package."""
16 |
17 | import setuptools
18 |
19 | _VERSION = '0.0.1'
20 |
21 |
22 | def _parse_requirements(requirements_txt_path):
23 | parse_line = lambda l: l.split('#')[0].strip()
24 | with open(requirements_txt_path) as f:
25 | return [parse_line(l) for l in f]
26 |
27 |
28 | setuptools.setup(
29 | name='dmvr',
30 | version=_VERSION,
31 | url='https://github.com/deepmind/dmvr',
32 | license='Apache 2.0',
33 | author='DeepMind',
34 | description=(
35 | 'DMVR is a library for reading and processing multimodal datasets.'),
36 | long_description=open('README.md').read(),
37 | long_description_content_type='text/markdown',
38 | author_email='dmvr-dev-os@google.com',
39 | # Contained modules and scripts.
40 | packages=setuptools.find_namespace_packages(exclude=['*_test.py']),
41 | install_requires=_parse_requirements('requirements.txt'),
42 | tests_require=_parse_requirements('requirements-test.txt'),
43 |     python_requires='>=3.6',
44 | include_package_data=True,
45 | zip_safe=False,
46 | # PyPI package information.
47 | classifiers=[
48 | 'Intended Audience :: Developers',
49 | 'Intended Audience :: Education',
50 | 'Intended Audience :: Science/Research',
51 | 'License :: OSI Approved :: Apache Software License',
52 | 'Programming Language :: Python :: 3',
53 | 'Programming Language :: Python :: 3.6',
54 | 'Programming Language :: Python :: 3.7',
55 | 'Topic :: Scientific/Engineering :: Mathematics',
56 | 'Topic :: Software Development :: Libraries :: Python Modules',
57 | 'Topic :: Software Development :: Libraries',
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright 2021 DeepMind Technologies Limited.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | # Pip installs the relevant dependencies and runs DMVR tests.
18 |
19 | set -e
20 | set -x
21 |
22 | python3 -m pip install --upgrade pip
23 | python3 -m pip install virtualenv
24 | virtualenv -p python3 .
25 | source bin/activate
26 | python3 --version
27 |
28 | # Run setup.py, install dependencies first to use pip install.
29 | python3 -m pip install -r requirements.txt
30 | python3 setup.py install
31 |
32 | # Python test dependencies.
33 | python3 -m pip install -r requirements-test.txt
34 |
35 | # Run all tests.
36 | python3 -m unittest discover -s 'dmvr' -p '*_test.py'
37 |
38 | deactivate
39 |
--------------------------------------------------------------------------------