├── test ├── __init__.py ├── images │ ├── 1.jpeg │ ├── 2.jpeg │ ├── 3.jpeg │ ├── 4.jpeg │ └── 5.jpeg ├── gif │ └── rotate_earth.gif ├── example.py ├── test_img_seq.py ├── test_h5record.py ├── test_swrm.py ├── test_modality.py ├── test_schema.py └── test_to_mem.py ├── h5record ├── __init__.py ├── dataset.py └── attributes.py ├── LICENSE ├── NOTES.md ├── setup.py ├── .gitignore └── README.md /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /h5record/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import * 2 | from .attributes import * -------------------------------------------------------------------------------- /test/images/1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/images/1.jpeg -------------------------------------------------------------------------------- /test/images/2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/images/2.jpeg -------------------------------------------------------------------------------- /test/images/3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/images/3.jpeg -------------------------------------------------------------------------------- /test/images/4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/images/4.jpeg -------------------------------------------------------------------------------- /test/images/5.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/images/5.jpeg -------------------------------------------------------------------------------- /test/gif/rotate_earth.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/theblackcat102/H5Record/HEAD/test/gif/rotate_earth.gif -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 theblackcat102 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /test/example.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.datasets import MNIST 3 | import numpy as np 4 | from tqdm import tqdm # for progress tracking 5 | 6 | 7 | from h5record import H5Dataset, Integer, Image 8 | 9 | def mnist_generator(): 10 | dataset = MNIST(root='mnist', download=True) 11 | for data in tqdm(dataset): 12 | image, label = data 13 | np_img = np.array(image).reshape(1, 28, 28) 14 | # NOTE: the key must be the same as the schema name 15 | yield {'img': np_img, 'label': label } 16 | 17 | 18 | schema = ( 19 | Integer('label'), 20 | Image(name='img', h=28, w=28, c=1) # MNIST images are single channel 21 | ) 22 | 23 | # it's recommended to provide the dataset size, 24 | # as this enables chunking and faster index access 25 | dataset = H5Dataset(schema, 'mnist.h5', 26 | mnist_generator(), 27 | data_length=60000, chunk_size=300, 28 | multiprocess=True) 29 | 30 | 31 | print('Data size ', len(dataset)) 32 | from torch.utils.data import DataLoader 33 | 34 | dataloader = DataLoader(dataset, batch_size=128, 35 | shuffle=True, num_workers=4) 36 | 37 | for batch in tqdm(dataloader): 38 | imgs = batch['img'] 39 | labels = batch['label'] 40 | print(imgs.shape, labels.shape) -------------------------------------------------------------------------------- /NOTES.md: -------------------------------------------------------------------------------- 1 | # Note 2 | 3 | ### Materials 4 | 5 | Useful materials about HDF5 6 | 7 | [HDF5 tech note](https://support.hdfgroup.org/ftp/HDF5/documentation/doc1.8/TechNotes/TechNote-HDF5-ImprovingIOPerformanceCompressedDatasets.pdf) 8 | 9 | [Compression benchmark](https://www.hdfgroup.org/2018/06/hdf5-or-how-i-learned-to-love-data-compression-and-partial-i-o/) 10 | 11 | 12 | Other materials and discussion 13 | 14 | 1. [A discussion of industrial large dataset solutions in PyTorch](https://github.com/pytorch/pytorch/issues/20822) 15 | 16 | 17 | 2. [H5Record reddit post](https://www.reddit.com/r/MachineLearning/comments/nsq3ai/p_h5records_store_large_datasets_in_one_single/) 18 | 19 | 20 | 21 | ## Comparison between LMDB and HDF5 22 | 23 | Data obtained from [w86763777 script](https://github.com/w86763777/LMDBvsHDF5) 24 | 25 | | Format | Write | Read | Size | 26 | |---|---|---|---| 27 | | HDF5 | 4.32 secs | 1.20 secs | 496K | 28 | | LMDB | 1.68 secs | 0.10 secs | 224M | 29 | 30 | * Benchmarked on 103 images with a total size of 5.4M; images were resized in the LMDB file 31 | 32 | Overall LMDB provides a 2.6x improvement in write speed and 12x in read speed (results averaged over 10 read/write sessions, benchmarked on a 2017 MacBook with an Intel Core i5). 33 | 34 | Maybe H5Record should include LMDB as an additional backend choice, since it supports significantly faster loading of binary files.
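
For reference, below is a minimal sketch of what an LMDB backend could look like, assuming the `lmdb` package and a simple layout of one pickled sample per zero-padded integer key. Nothing here exists in H5Record yet; `LMDBDataset` and `write_lmdb` are hypothetical names used only for illustration.

```python
import pickle
import lmdb  # pip install lmdb
from torch.utils.data import Dataset


def write_lmdb(path, sample_iter, map_size=1 << 30):
    # map_size is the maximum database size in bytes and has to be chosen up front
    env = lmdb.open(path, map_size=map_size)
    with env.begin(write=True) as txn:
        for i, sample in enumerate(sample_iter):
            # one pickled sample per zero-padded integer key
            txn.put(f'{i:08d}'.encode(), pickle.dumps(sample))
    env.close()


class LMDBDataset(Dataset):
    def __init__(self, path):
        # readonly=True and lock=False are the usual settings for multi-process reading
        self.env = lmdb.open(path, readonly=True, lock=False, readahead=False)
        with self.env.begin() as txn:
            self.length = txn.stat()['entries']

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        with self.env.begin() as txn:
            buf = txn.get(f'{idx:08d}'.encode())
        return pickle.loads(buf)
```

Whether this actually beats the HDF5 path for a given workload would need the same benchmark as above.
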
35 | 36 | 37 | ### TODO 38 | 39 | - [ ] Test combinations of different data modalities 40 | 41 | - [ ] Do more tuning and experiments on different driver settings 42 | 43 | - [ ] Performance benchmark: 44 | 45 | - [ ] Performance comparison between zip in multiple workers ( I suspect there's some improvement to be done here ) 46 | 47 | - [ ] In memory (dataset[:]) access vs no compression 48 | 49 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | ''' 2 | setup.py - a setup script 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | Authors: 13 | @theblackcat 14 | ''' 15 | from setuptools import setup, Extension 16 | 17 | try: 18 | from setuptools import setup 19 | except ImportError: 20 | from distutils.core import setup 21 | 22 | with open('LICENSE', 'r') as f: 23 | license_ = f.read() 24 | 25 | with open('README.md', 'r') as f: 26 | readme = f.read() 27 | 28 | 29 | test_requirements = [ 30 | 'Pillow', 31 | ] 32 | 33 | setup( 34 | name='h5record', 35 | version='1.0.4', 36 | description='Large data storage for pytorch', 37 | long_description=readme, 38 | long_description_content_type="text/markdown", 39 | author='theblackcat', 40 | author_email='zhirui09400@gmail.com', 41 | url='https://github.com/theblackcat102/h5record', 42 | keywords='data processing', 43 | packages=['h5record'], 44 | install_requires=[ 45 | 'torch', 46 | 'h5py', 47 | 'numpy' 48 | ], 49 | tests_require=test_requirements, 50 | license='MIT License', 51 | classifiers=[ 52 | 'Development Status :: 5 - Production/Stable', 53 | 'Intended Audience :: Developers', 54 | 'License :: OSI Approved :: Apache Software License', 55 | 'Operating System :: OS Independent', 56 | 'Programming Language :: Python :: 3', 57 | 'Programming Language :: Python :: 3.6', 58 | 'Programming Language :: Python :: 3.7', 59 | 'Programming Language :: Python :: 3.8', 60 | ] 61 | ) 62 | -------------------------------------------------------------------------------- /test/test_img_seq.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | class TestDataModality(unittest.TestCase): 7 | 8 | 9 | # def test_pow(self): 10 | # import h5py 11 | # import numpy as np 12 | # if os.path.exists('temp.h5'): 13 | # os.remove('temp.h5') 14 | # f = h5py.File('temp.h5','w') 15 | # float32_t = h5py.special_dtype(vlen=np.dtype('float32')) 16 | # evolutionary_ = f.create_dataset('evolutionary', shape=(1, 3, 3, ), maxshape=(None, 3,3, ), dtype=float32_t) 17 | # a = np.random.randn(1, 3, 3, 4) 18 | # b = np.random.randn(1, 3, 3, 6) 19 | 20 | # evolutionary_[0] = a 21 | 22 | # evolutionary_.resize(3, axis=0) 23 | # evolutionary_[1] = b 24 | 25 | # f = h5py.File('temp.h5','r') 26 | # print( np.vstack(f['evolutionary'][0]).shape ) 27 | 28 | # assert np.stack(f['evolutionary'][0], axis=0) == (3, 32, 32, 4) 29 | # assert f['evolutionary'][1] == 
(3, 32, 32, 6) 30 | 31 | 32 | def test_gif_based_schema(self): 33 | from h5record.dataset import H5Dataset 34 | from h5record.attributes import ImageSequence 35 | gif_attr = ImageSequence(name='gif', h=32, w=32) 36 | schema = [ 37 | gif_attr 38 | ] 39 | data_size = 1 40 | 41 | gif_paths = ['test/gif/rotate_earth.gif']*data_size 42 | 43 | def pair_iter(): 44 | for (gif_path) in gif_paths: 45 | 46 | yield { 47 | 'gif': gif_attr.read_gif(gif_path), 48 | } 49 | if os.path.exists('gif_dataset.h5'): 50 | os.remove('gif_dataset.h5') 51 | 52 | dataset = H5Dataset(schema, './gif_dataset.h5', pair_iter()) 53 | for idx in range(data_size): 54 | # Currently this returns matrix of 3 x 32 x 32 x np.array 55 | # which is treated as numpy.object_ 56 | 57 | # solution include : 58 | # 1. costly reshape 59 | # 2. flatten the 4D matrix as one large 1D matrix which is suitable variable array 60 | gif = dataset[idx]['gif'] 61 | assert gif.shape == (3, 32, 32, 44) 62 | 63 | assert len(dataset) == data_size 64 | 65 | os.remove('gif_dataset.h5') 66 | 67 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .vscode 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /test/test_h5record.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | class TestDataModality(unittest.TestCase): 5 | 6 | 7 | def test_string_pair(self): 8 | from h5record.dataset import H5Dataset 9 | from h5record.attributes import String, Integer 10 | schema = { 11 | 'sentence1': String(name='sentence1'), 12 | 'sentence2': String(name='sentence2'), 13 | 'label': Integer(name='label') 14 | } 15 | 16 | data = [ 17 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0], 18 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 19 | ['Reading from an HDF5 file which you will probably be','about to overwrite! Override this error only if you know what youre doing ', 1], 20 | ['Lorem ipsum dolor sit amet, consectetur adipiscing elit','sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.', 1], 21 | ['Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris','nisi ut aliquip ex ea commodo consequat.', 0], 22 | ['Duis aute irure dolor in reprehenderit in voluptate velit','esse cillum dolore eu fugiat nulla pariatur.', 0], 23 | ['Excepteur sint occaecat cupidatat non proident, ','sunt in culpa qui officia deserunt mollit anim id est laborum.', 0], 24 | 25 | ] 26 | 27 | def pair_iter(): 28 | for row in data: 29 | yield { 30 | 'sentence1': row[0], 31 | 'sentence2': row[1], 32 | 'label': row[2] 33 | } 34 | if os.path.exists('question_pair.h5'): 35 | os.remove('question_pair.h5') 36 | 37 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter()) 38 | 39 | assert len(dataset) == len(data) 40 | 41 | for idx in range(len(data)): 42 | row = dataset[idx] 43 | sent1, sent2, label = data[idx] 44 | assert sent1 == row['sentence1'] 45 | assert sent2 == row['sentence2'] 46 | assert label == row['label'] 47 | 48 | os.remove('question_pair.h5') 49 | 50 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter(), 51 | data_length=len(data), chunk_size=4) 52 | 53 | for idx in range(len(data)): 54 | row = dataset[idx] 55 | sent1, sent2, label = data[idx] 56 | assert sent1 == row['sentence1'] 57 | assert sent2 == row['sentence2'] 58 | assert label == row['label'] 59 | 60 | os.remove('question_pair.h5') 61 | 62 | -------------------------------------------------------------------------------- /test/test_swrm.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | class TestSWRM(unittest.TestCase): 7 | 8 | 9 | # def test_pow(self): 10 | # import h5py 11 | # import numpy as np 12 | # if os.path.exists('temp.h5'): 13 | # os.remove('temp.h5') 14 | # f = h5py.File('temp.h5','w') 15 | # float32_t = h5py.special_dtype(vlen=np.dtype('float32')) 16 | # evolutionary_ = 
f.create_dataset('evolutionary', shape=(1, 3, 3, ), maxshape=(None, 3,3, ), dtype=float32_t) 17 | # a = np.random.randn(1, 3, 3, 4) 18 | # b = np.random.randn(1, 3, 3, 6) 19 | 20 | # evolutionary_[0] = a 21 | 22 | # evolutionary_.resize(3, axis=0) 23 | # evolutionary_[1] = b 24 | 25 | # f = h5py.File('temp.h5','r') 26 | # print( np.vstack(f['evolutionary'][0]).shape ) 27 | 28 | # assert np.stack(f['evolutionary'][0], axis=0) == (3, 32, 32, 4) 29 | # assert f['evolutionary'][1] == (3, 32, 32, 6) 30 | 31 | 32 | def test_multiprocessing_flag(self): 33 | from h5record.dataset import H5Dataset 34 | from h5record.attributes import ImageSequence 35 | gif_attr = ImageSequence(name='gif', h=32, w=32) 36 | schema = [ 37 | gif_attr 38 | ] 39 | data_size = 1 40 | 41 | gif_paths = ['test/gif/rotate_earth.gif']*data_size 42 | 43 | def pair_iter(): 44 | for (gif_path) in gif_paths: 45 | 46 | yield { 47 | 'gif': gif_attr.read_gif(gif_path), 48 | } 49 | if os.path.exists('gif_dataset.h5'): 50 | os.remove('gif_dataset.h5') 51 | 52 | dataset = H5Dataset(schema, './gif_dataset.h5', pair_iter()) 53 | for idx in range(data_size): 54 | # Currently this returns matrix of 3 x 32 x 32 x np.array 55 | # which is treated as numpy.object_ 56 | 57 | # solution include : 58 | # 1. costly reshape 59 | # 2. flatten the 4D matrix as one large 1D matrix which is suitable variable array 60 | gif = dataset[idx]['gif'] 61 | assert gif.shape == (3, 32, 32, 44) 62 | 63 | 64 | 65 | dataset = H5Dataset(schema, './gif_dataset.h5', pair_iter(), multiprocess=True) 66 | for idx in range(data_size): 67 | # Currently this returns matrix of 3 x 32 x 32 x np.array 68 | # which is treated as numpy.object_ 69 | 70 | # solution include : 71 | # 1. costly reshape 72 | # 2. flatten the 4D matrix as one large 1D matrix which is suitable variable array 73 | gif = dataset[idx]['gif'] 74 | assert gif.shape == (3, 32, 32, 44) 75 | 76 | assert len(dataset) == data_size 77 | 78 | os.remove('gif_dataset.h5') 79 | 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # H5Record 2 | [![codecov](https://codecov.io/gh/theblackcat102/H5Record/branch/main/graph/badge.svg?token=WJQ0DEPL38)](https://codecov.io/gh/theblackcat102/H5Record) [![PyPI version](https://badge.fury.io/py/h5record.svg)](https://badge.fury.io/py/h5record) 3 | 4 | Large dataset ( > 100G, <= 1T) storage format for Pytorch (wip) 5 | 6 | Support python 3 7 | 8 | ``` 9 | pip install h5record 10 | ``` 11 | 12 | 13 | ## Why? 14 | 15 | * Writing large dataset is still a wild west in pytorch. 
Approaches seen in the wild include: 16 | 17 | - a large directory with lots of small files: slow I/O, since many small files are fetched and deserialized frequently 18 | - a database approach: depends on the database engine used; multi-process reads are usually not supported 19 | - the above methods scale non-linearly in terms of data to storage size 20 | 21 | * TFRecord solves the above problems well ( multiprocess fetch, (de)compression, fast serialization via protobuf ) 22 | 23 | * However, the TFRecord port does not support dataset size evaluation ( used frequently by DataLoader ) and offers no index-level access ( important for data evaluation or verification ) 24 | 25 | H5Record aims to tackle these problems by packing the dataset into a single [HDF5](https://support.hdfgroup.org/HDF5/doc/TechNotes/BigDataSmMach.html) file with an easy-to-use interface through predefined attribute types ( String, Image, Sequence, Integer ). 26 | 27 | Some advantages of using H5Record: 28 | 29 | * Supports multi-process reads 30 | 31 | * Relatively simple to use, with low technical debt 32 | 33 | * Supports compression/decompression on the fly 34 | 35 | * Quick loading into memory if required (a sketch combining these options appears at the end of this README) 36 | 37 | ### Simple usage 38 | 39 | ``` 40 | pip install h5record 41 | ``` 42 | 43 | 44 | 1. Sentence Similarity 45 | 46 | ```python 47 | from h5record import H5Dataset, Float, String 48 | 49 | schema = ( 50 | String(name='sentence1'), 51 | String(name='sentence2'), 52 | Float(name='label') 53 | ) 54 | data = [ 55 | ['Sent 1.', 'Sent 2', 0.1], 56 | ['Sent 3', 'Sent 4', 0.2], 57 | ] 58 | 59 | def pair_iter(): 60 | for row in data: 61 | yield { 62 | 'sentence1': row[0], 63 | 'sentence2': row[1], 64 | 'label': row[2] 65 | } 66 | 67 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter()) 68 | for idx in range(len(dataset)): 69 | print(dataset[idx]) 70 | 71 | ``` 72 | 73 | 74 | ## Note 75 | 76 | Due to in-progress development, this package should be used with care on storage formatted as FAT or FAT-32 77 | 78 | ## Comparison between different compression algorithms 79 | 80 | No chunking was used 81 | 82 | | Compression Type | File size | Read speed (rows/second) | 83 | |---|---|---| 84 | | no compression | 2.0G | 2084.55 it/s | 85 | | lzf | 1.7G | 1496.14 it/s | 86 | | gzip | 1.1G | 843.78 it/s | 87 | 88 | Benchmarked on an i7-9700 with a 1TB NVMe SSD 89 | 90 | 91 | 92 | If you are interested in learning more, feel free to check out the [notes](NOTES.md) as well!
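
### Tuning options (sketch)

The constructor options below all come from `H5Dataset` as used in the tests (`data_length`, `chunk_size`, `compression`, `multiprocess`, `to_memory`); how they are combined here with a multi-worker `DataLoader` is only a sketch, reusing `schema`, `data` and `pair_iter` from the example above.

```python
from torch.utils.data import DataLoader

# chunking is only enabled when the total length is known up front
dataset = H5Dataset(schema, './question_pair.h5', pair_iter(),
                    data_length=len(data), chunk_size=300,
                    compression='lzf',   # or 'gzip' for smaller files (see table above)
                    multiprocess=True)   # read through an atomic file wrapper

# alternatively, to_memory=True pulls every column into RAM after writing

loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
for batch in loader:
    print(batch['sentence1'], batch['label'])
```
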
93 | 94 | 95 | -------------------------------------------------------------------------------- /h5record/dataset.py: -------------------------------------------------------------------------------- 1 | import h5py as h5 2 | import numpy as np 3 | from torch.utils.data.dataset import Dataset 4 | import os 5 | from .attributes import ( 6 | String, ImageSequence 7 | ) 8 | 9 | class AtomicFile: 10 | ''' 11 | Wrapper file for h5 in case multiprocess writes is needed 12 | ''' 13 | def __init__(self, path): 14 | self.fd = os.open(path, os.O_RDONLY) 15 | self.pos = 0 16 | 17 | def seek(self, pos, whence=0): 18 | if whence == 0: 19 | self.pos = pos 20 | elif whence == 1: 21 | self.pos += pos 22 | else: 23 | self.pos = os.lseek(self.fd, pos, whence) 24 | return self.pos 25 | 26 | def tell(self): 27 | return self.pos 28 | 29 | def read(self, size): 30 | b = os.pread(self.fd, size, self.pos) 31 | self.pos += len(b) 32 | return b 33 | 34 | class H5Dataset(Dataset): 35 | 36 | def __init__(self, schema, save_filename, data_iter=None, 37 | data_length=None, chunk_size=300, compression=None, 38 | transform=None, append_mode=False, verbose=0, 39 | to_memory=False, multiprocess=False): 40 | 41 | ''' 42 | Note: 43 | * data length must be known value otherwise chunk size will not be enabled 44 | * chunk size affects reading speed, usually a size of 100-500 is suitable value 45 | * compression algorithm affects reading speed, so if storage is not your concern is recommended not to enable 46 | multiprocess: 47 | if such error occur "OSError: Can't read data (address of object past end of allocation)" 48 | set it to True 49 | ''' 50 | 51 | # normalized schema design to dictionary 52 | if isinstance(schema, list) or isinstance(schema, tuple): 53 | schema = { s.name: s for s in schema } 54 | self.schema = schema 55 | self.save_filename = save_filename 56 | self.data_length = data_length # dataset maximum size 57 | self.transform = transform # transform function before returned by index access 58 | self.append_mode = append_mode # force append ? 
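        # NOTE: chunk_size is only kept when data_length is known (see the class
        # docstring above), so providing data_length up front is what enables
        # chunked, faster index access; e.g. data_length=60000, chunk_size=300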
59 | 60 | self.chunk_size = None if data_length is None else chunk_size 61 | assert compression in [None, 'lzf', 'gzip', 'szip'] 62 | self.compression = compression 63 | if not os.path.exists(self.save_filename): 64 | self.preprocess(data_iter) 65 | 66 | if multiprocess: # this is a backup method to ensure multiprocessing support on old file 67 | self.reader = h5.File(AtomicFile(self.save_filename), 'r') 68 | else: 69 | self.reader = h5.File(self.save_filename, 'r', swmr=True) 70 | 71 | first_key = list(self.schema.keys())[0] 72 | self.num_entries = self.reader[first_key].shape[0] 73 | 74 | 75 | if to_memory: 76 | # warning this may use all your memory 77 | temp = {} 78 | for key in self.schema.keys(): 79 | temp[key] = self.reader[key][:] 80 | self.reader = temp 81 | 82 | def preprocess(self, data_iter): 83 | idx = 0 84 | for data in data_iter: 85 | if idx == 0: 86 | with h5.File(self.save_filename, 'w', libver='latest') as fout: 87 | fout.swmr_mode = True 88 | for key, value in data.items(): 89 | attribute = self.schema[key] 90 | attribute.init_attributes(fout, value, 91 | self.compression, self.data_length) 92 | else: 93 | with h5.File(self.save_filename, 'a', libver='latest') as fout: 94 | fout.swmr_mode = True 95 | for key, value in data.items(): 96 | attribute = self.schema[key] 97 | value = attribute.transform(value) 98 | attribute.append(fout, value) 99 | 100 | idx += 1 101 | 102 | def __len__(self): 103 | return self.num_entries 104 | 105 | 106 | def __getitem__(self, idx): 107 | data = {} 108 | for key in self.schema.keys(): 109 | raw_output = self.reader[key][idx] 110 | attribute = self.schema[key] 111 | if isinstance(attribute, String): 112 | data[key] = raw_output[0].decode(attribute.encoding ) 113 | elif isinstance(attribute, ImageSequence): 114 | # heavy reshaping is needed as variable length dimension (last dimension) 115 | # is always treated as np.array 116 | # rendering high dimension shape becomes a np.object matrix 117 | # 118 | data[key] = raw_output[0].reshape( 119 | attribute.img_channel, attribute.w, attribute.h, -1 ) 120 | else: 121 | data[key] = raw_output 122 | 123 | if self.transform is not None: 124 | return self.transform(data) 125 | 126 | return data 127 | 128 | 129 | -------------------------------------------------------------------------------- /test/test_modality.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | 4 | class TestDataModality(unittest.TestCase): 5 | 6 | 7 | def test_string_pair(self): 8 | from h5record.dataset import H5Dataset 9 | from h5record.attributes import String, Integer 10 | schema = { 11 | 'sentence1': String(name='sentence1'), 12 | 'sentence2': String(name='sentence2'), 13 | 'label': Integer(name='label') 14 | } 15 | 16 | data = [ 17 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0], 18 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 19 | ['Reading from an HDF5 file which you will probably be','about to overwrite! 
Override this error only if you know what youre doing ', 1], 20 | ] 21 | 22 | def pair_iter(): 23 | for row in data: 24 | yield { 25 | 'sentence1': row[0], 26 | 'sentence2': row[1], 27 | 'label': row[2] 28 | } 29 | if os.path.exists('question_pair.h5'): 30 | os.remove('question_pair.h5') 31 | 32 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter()) 33 | 34 | assert len(dataset) == len(data) 35 | 36 | for idx in range(len(data)): 37 | row = dataset[idx] 38 | sent1, sent2, label = data[idx] 39 | assert sent1 == row['sentence1'] 40 | assert sent2 == row['sentence2'] 41 | assert label == row['label'] 42 | 43 | os.remove('question_pair.h5') 44 | 45 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter(), 46 | data_length=len(data), chunk_size=4) 47 | assert len(dataset) == len(data) 48 | 49 | for idx in range(len(data)): 50 | row = dataset[idx] 51 | sent1, sent2, label = data[idx] 52 | assert sent1 == row['sentence1'] 53 | assert sent2 == row['sentence2'] 54 | assert label == row['label'] 55 | 56 | os.remove('question_pair.h5') 57 | 58 | 59 | def test_unicode(self): 60 | from h5record.dataset import H5Dataset 61 | from h5record.attributes import String, Integer 62 | schema = { 63 | 'sentence1': String(name='sentence1'), 64 | 'sentence2': String(name='sentence2'), 65 | 'label': Integer(name='label') 66 | } 67 | 68 | data = [ 69 | [ '小明去學校', '結果已經下課了', 0] 70 | ] 71 | 72 | def pair_iter(): 73 | for row in data: 74 | yield { 75 | 'sentence1': row[0], 76 | 'sentence2': row[1], 77 | 'label': row[2] 78 | } 79 | if os.path.exists('question_pair.h5'): 80 | os.remove('question_pair.h5') 81 | 82 | dataset = H5Dataset(schema, './question_pair.h5', pair_iter()) 83 | assert len(dataset) == len(data) 84 | 85 | for idx in range(len(data)): 86 | row = dataset[idx] 87 | sent1, sent2, label = data[idx] 88 | 89 | assert sent1 == row['sentence1'] 90 | assert sent2 == row['sentence2'] 91 | assert label == row['label'] 92 | 93 | os.remove('question_pair.h5') 94 | 95 | 96 | def test_image_string_pair(self): 97 | from h5record.dataset import H5Dataset 98 | from h5record.attributes import String, Image 99 | 100 | schema = { 101 | 'image': Image(name='image', h=32, w=32), 102 | 'caption': String(name='caption') 103 | } 104 | 105 | captions = [ 106 | 'Lenna profile', 107 | 'Lenna back patch', 108 | 'Lenna lower patch', 109 | 'meme image', 110 | 'greyscale image' 111 | ] 112 | 113 | image_paths = [ 114 | 'test/images/1.jpeg', 115 | 'test/images/2.jpeg', 116 | 'test/images/3.jpeg', 117 | 'test/images/4.jpeg', 118 | 'test/images/5.jpeg', 119 | ] 120 | def pair_iter(): 121 | for (caption, img_path) in zip(captions, image_paths): 122 | yield { 123 | 'image': schema['image'].read_image(img_path), 124 | 'caption': caption 125 | } 126 | if os.path.exists('img_caption.h5'): 127 | os.remove('img_caption.h5') 128 | 129 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 130 | data_length=len(image_paths), chunk_size=4) 131 | assert len(dataset) == len(captions) 132 | 133 | for idx in range(len(image_paths)): 134 | data = dataset[idx] 135 | assert data['caption'] == captions[idx] 136 | 137 | os.remove('img_caption.h5') 138 | 139 | 140 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 141 | compression='lzf') 142 | 143 | for idx in range(len(image_paths)): 144 | data = dataset[idx] 145 | assert data['caption'] == captions[idx] 146 | 147 | os.remove('img_caption.h5') 148 | -------------------------------------------------------------------------------- /test/test_schema.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from h5record.dataset import H5Dataset 4 | 5 | class TestDataModality(unittest.TestCase): 6 | 7 | 8 | def test_list_based_schema(self): 9 | from h5record.attributes import String, Integer, Image 10 | 11 | img_attr = Image(name='image', h=32, w=32) 12 | 13 | schema = [ 14 | img_attr, 15 | Integer(name='label'), 16 | String(name='sentence1'), 17 | String(name='sentence2'), 18 | ] 19 | image_paths = [ 20 | 'test/images/1.jpeg', 21 | 'test/images/2.jpeg', 22 | 'test/images/3.jpeg', 23 | ] 24 | 25 | data = [ 26 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0 ], 27 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 28 | ['Reading from an HDF5 file which you will probably be','about to overwrite! Override this error only if you know what youre doing ', 1], 29 | ] 30 | 31 | def pair_iter(): 32 | for (row, image_path) in zip(data, image_paths): 33 | yield { 34 | 'sentence1': row[0], 35 | 'sentence2': row[1], 36 | 'image': img_attr.read_image(image_path), 37 | 'label': row[2] 38 | } 39 | if os.path.exists('question_image_pair.h5'): 40 | os.remove('question_image_pair.h5') 41 | 42 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter()) 43 | for idx in range(len(data)): 44 | row = dataset[idx] 45 | sent1, sent2, label = data[idx] 46 | assert sent1 == row['sentence1'] 47 | assert sent2 == row['sentence2'] 48 | assert label == row['label'] 49 | 50 | os.remove('question_image_pair.h5') 51 | 52 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter(), 53 | data_length=len(data), chunk_size=4) 54 | 55 | for idx in range(len(data)): 56 | row = dataset[idx] 57 | sent1, sent2, label = data[idx] 58 | assert sent1 == row['sentence1'] 59 | assert sent2 == row['sentence2'] 60 | assert label == row['label'] 61 | 62 | os.remove('question_image_pair.h5') 63 | 64 | def test_tuple_based_schema(self): 65 | from h5record.attributes import String, Integer, Image 66 | 67 | img_attr = Image(name='image', h=32, w=32) 68 | 69 | schema = ( 70 | String(name='sentence1'), 71 | String(name='sentence2'), 72 | img_attr, 73 | Integer(name='label') 74 | ) 75 | image_paths = [ 76 | 'test/images/1.jpeg', 77 | 'test/images/2.jpeg', 78 | 'test/images/3.jpeg', 79 | ] 80 | 81 | data = [ 82 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0 ], 83 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 84 | ['Reading from an HDF5 file which you will probably be','about to overwrite! 
Override this error only if you know what youre doing ', 1], 85 | ] 86 | 87 | def pair_iter(): 88 | for (row, image_path) in zip(data, image_paths): 89 | yield { 90 | 'sentence1': row[0], 91 | 'sentence2': row[1], 92 | 'image': img_attr.read_image(image_path), 93 | 'label': row[2] 94 | } 95 | if os.path.exists('question_image_pair.h5'): 96 | os.remove('question_image_pair.h5') 97 | 98 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter()) 99 | for idx in range(len(data)): 100 | row = dataset[idx] 101 | sent1, sent2, label = data[idx] 102 | assert sent1 == row['sentence1'] 103 | assert sent2 == row['sentence2'] 104 | assert label == row['label'] 105 | 106 | os.remove('question_image_pair.h5') 107 | 108 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter(), 109 | data_length=len(data), chunk_size=4) 110 | 111 | for idx in range(len(data)): 112 | row = dataset[idx] 113 | sent1, sent2, label = data[idx] 114 | assert sent1 == row['sentence1'] 115 | assert sent2 == row['sentence2'] 116 | assert label == row['label'] 117 | 118 | os.remove('question_image_pair.h5') 119 | 120 | 121 | def test_dict_based_schema(self): 122 | from h5record.attributes import String, Image 123 | 124 | schema = { 125 | 'image': Image(name='image', h=32, w=32), 126 | 'caption': String(name='caption') 127 | } 128 | 129 | captions = [ 130 | 'Lenna profile', 131 | 'Lenna back patch', 132 | 'Lenna lower patch', 133 | 'meme image', 134 | 'greyscale image' 135 | ] 136 | image_path = [ 137 | 'test/images/1.jpeg', 138 | 'test/images/2.jpeg', 139 | 'test/images/3.jpeg', 140 | 'test/images/4.jpeg', 141 | 'test/images/5.jpeg', 142 | ] 143 | def pair_iter(): 144 | for (caption, img_path) in zip(captions, image_path): 145 | yield { 146 | 'image': schema['image'].read_image(image_path[0]), 147 | 'caption': caption 148 | } 149 | if os.path.exists('img_caption.h5'): 150 | os.remove('img_caption.h5') 151 | 152 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 153 | data_length=len(image_path), chunk_size=4) 154 | 155 | for idx in range(len(image_path)): 156 | data = dataset[idx] 157 | assert data['caption'] == captions[idx] 158 | 159 | os.remove('img_caption.h5') 160 | 161 | 162 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 163 | compression='lzf') 164 | 165 | for idx in range(len(image_path)): 166 | data = dataset[idx] 167 | assert data['caption'] == captions[idx] 168 | 169 | os.remove('img_caption.h5') 170 | 171 | 172 | 173 | def test_seq_schema(self): 174 | ''' 175 | Suitable for tokenized sequences 176 | ''' 177 | from h5record.attributes import Sequence, FloatSequence, Float16Sequence 178 | import numpy as np 179 | 180 | schema = ( 181 | Sequence(name='seq1'), 182 | Sequence(name='seq2') 183 | ) 184 | 185 | data = [ 186 | [ np.array([ 0, 1, 2, 3 ]), np.array([ 1, 2, 3 ]) ], 187 | [ np.array([ 0, 1, 2, 3, 4, 5 ]), np.array([ 1, 2, 3, 3 ]) ], 188 | [ np.array([ 0, 1, 2 ]), np.array([ 1, 2, -1 ]) ], 189 | ] 190 | def pair_iter(): 191 | for (seq1, seq2) in data: 192 | yield { 193 | 'seq1': seq1, 194 | 'seq2': seq2 195 | } 196 | if os.path.exists('tokens.h5'): 197 | os.remove('tokens.h5') 198 | 199 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 200 | data_length=len(data), chunk_size=4) 201 | 202 | for idx in range(len(data)): 203 | row = dataset[idx] 204 | assert (row['seq1'][0] == data[idx][0]).all() 205 | 206 | os.remove('tokens.h5') 207 | 208 | data = [ 209 | [ np.array([ 0.1, 1, 2, 3 ]), np.array([ 1, 2.13, 3 ]) ], 210 | [ np.array([ 0.11, 1, 2, 3, 4, 5 ]), 
np.array([ 1, 2, 3.33333333, 3 ]) ], 211 | [ np.array([ 3.14159, 1, 2 ]), np.array([ 1.988, 2, -1 ]) ], 212 | ] 213 | schema = ( 214 | FloatSequence(name='seq1'), 215 | FloatSequence(name='seq2') 216 | ) 217 | if os.path.exists('tokens.h5'): 218 | os.remove('tokens.h5') 219 | 220 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 221 | data_length=len(data), chunk_size=4) 222 | 223 | for idx in range(len(data)): 224 | row = dataset[idx] 225 | assert (row['seq1'][0] - data[idx][0]).sum() < 1e-6 226 | assert (row['seq2'][0] - data[idx][1]).sum() < 1e-6 227 | 228 | os.remove('tokens.h5') 229 | 230 | schema = ( 231 | Float16Sequence(name='seq1'), 232 | Float16Sequence(name='seq2') 233 | ) 234 | if os.path.exists('tokens.h5'): 235 | os.remove('tokens.h5') 236 | 237 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 238 | data_length=len(data), chunk_size=4) 239 | 240 | for idx in range(len(data)): 241 | row = dataset[idx] 242 | assert (row['seq1'][0] - data[idx][0]).mean() < 1e-3 243 | assert (row['seq2'][0] - data[idx][1]).mean() < 1e-3 244 | 245 | os.remove('tokens.h5') 246 | 247 | -------------------------------------------------------------------------------- /h5record/attributes.py: -------------------------------------------------------------------------------- 1 | import h5py as h5 2 | import numpy as np 3 | 4 | try: 5 | from PIL import Image as PImage 6 | from PIL import ImageSequence as PImageSequence 7 | except ImportError as e: 8 | pass 9 | 10 | class Attribute(): 11 | 12 | def append(self, h5, data): 13 | raise NotImplementedError 14 | 15 | def transform(self, data): 16 | raise NotImplementedError 17 | 18 | 19 | def init_attributes(self, fout, value, compression, data_length): 20 | value = self.transform(value) 21 | max_shape = list(self.max_shape) 22 | max_shape[0] = data_length 23 | max_shape = tuple(max_shape) 24 | 25 | shape = value.shape 26 | fout.create_dataset(self.name, data=value, shape=shape, 27 | maxshape=max_shape, 28 | dtype=self.dtype, 29 | compression=compression ) 30 | 31 | 32 | class Integer(Attribute): 33 | ''' 34 | One dimensional data shape 35 | ''' 36 | dtype = 'int64' 37 | def __init__(self, name='label'): 38 | self.name = name 39 | self.shape = (None, ) 40 | self.max_shape = (None, ) 41 | 42 | def append(self, h5, data): 43 | h5[self.name].resize( h5[self.name].shape[0]+data.shape[0], axis=0) 44 | h5[self.name][-data.shape[0]:] = data 45 | return h5 46 | 47 | def transform(self, data): 48 | return np.array([data]) 49 | 50 | class Float(Attribute): 51 | ''' 52 | One dimensional data shape 53 | ''' 54 | dtype = 'float32' 55 | def __init__(self, name='label'): 56 | self.name = name 57 | self.shape = (None, ) 58 | self.max_shape = (None, ) 59 | 60 | def append(self, h5, data): 61 | h5[self.name].resize( h5[self.name].shape[0]+data.shape[0], axis=0) 62 | h5[self.name][-data.shape[0]:] = data 63 | return h5 64 | 65 | def transform(self, data): 66 | return np.array([data]) 67 | 68 | class Float16(Float): 69 | ''' 70 | One dimensional data shape 71 | ''' 72 | dtype = 'float16' 73 | 74 | 75 | class Image(Attribute): 76 | 77 | dtype = 'uint8' 78 | def __init__(self, h, w, c=3, name='image'): 79 | self.c = c 80 | self.h = h 81 | self.w = w 82 | self.name = name 83 | 84 | self.shape = (None, self.c, self.h, self.w) 85 | self.max_shape = (None, self.c, self.h, self.w) 86 | 87 | def read_image(self, filename, resize=True): 88 | # read image by file path and return numpy object 89 | img = PImage.open(filename) 90 | if resize: 91 | img = img.resize((self.w, 
self.h)) 92 | np_img = np.array(img) 93 | if len(np_img.shape) == 2: 94 | np_img = np.expand_dims(np_img, -1) 95 | np_img = np.repeat(np_img, 3, axis=-1) 96 | 97 | return np.transpose(np_img, (2, 0, 1)) 98 | 99 | def append(self, h5, data): 100 | h5[self.name].resize( h5[self.name].shape[0]+data.shape[0], axis=0) 101 | h5[self.name][-data.shape[0]:] = data 102 | return h5 103 | 104 | def transform(self, data): 105 | data = np.array(data) 106 | if len(data.shape) == 3: 107 | # ensure data shape is B x C x H x W 108 | data = np.expand_dims(data, axis=0) 109 | return data 110 | 111 | class ImageSequence(Attribute): 112 | dtype = h5.special_dtype(vlen=np.dtype('uint8')) 113 | img_channel = 3 114 | 115 | def __init__(self, h, w, c=3, name='img_seq'): 116 | self.c = c 117 | self.h = h 118 | self.w = w 119 | self.name = name 120 | 121 | self.shape = (1, 1, ) 122 | self.max_shape = (None, 1, ) 123 | 124 | def read_gif(self, filename, resize=True): 125 | # read image by file path and return numpy object 126 | # Channel x H x W x Length 127 | gif = PImage.open(filename) 128 | np_images = [] 129 | for frame in PImageSequence.Iterator(gif): 130 | if resize: 131 | frame = frame.resize((self.w, self.h)) 132 | np_img = np.array(frame) 133 | if len(np_img.shape) == 2: 134 | np_img = np.expand_dims(np_img, -1) 135 | np_img = np.repeat(np_img, self.img_channel, axis=-1) 136 | np_images.append(np_img) 137 | gif.close() 138 | np_images = np.stack(np_images) 139 | 140 | return np.transpose(np_images.astype(self.dtype), (3, 1, 2, 0)).flatten() 141 | 142 | 143 | def transform(self, data): 144 | data = np.array(data) 145 | if len(data.shape) == 1: 146 | # make sure its 1 x C x H x W x sequence length 147 | data = np.array([data.flatten() ], dtype=self.dtype) 148 | return data 149 | 150 | def append(self, h5, data): 151 | h5[self.name].resize( h5[self.name].shape[0]+data.shape[0], axis=0) 152 | h5[self.name][-1] = data 153 | return h5 154 | 155 | def init_attributes(self, fout, value, compression, data_length): 156 | max_shape = self.max_shape 157 | max_shape = list(self.max_shape) 158 | max_shape[0] = data_length 159 | max_shape = tuple(max_shape) 160 | 161 | dset = fout.create_dataset(self.name, 162 | shape=self.shape, 163 | maxshape=max_shape, 164 | dtype=self.dtype, 165 | compression=compression ) 166 | dset[0] = value 167 | 168 | 169 | class Sequence(Attribute): 170 | 171 | dtype = h5.special_dtype(vlen=np.dtype('int32')) 172 | 173 | def __init__(self, name='sequence', sub_attributes=None): 174 | self.name = name 175 | self.shape = (1, 1, ) 176 | self.sub_attributes = sub_attributes 177 | self.max_shape = (None, 1, ) 178 | 179 | 180 | def append(self, h5, data): 181 | if isinstance(data, dict): 182 | for key in self.sub_attributes: 183 | np_seq = data[key] 184 | # np_seq = np.array([np_seq], dtype=self.dtype) 185 | h5[self.name + '_'+key].resize( h5[self.name + '_'+key].shape[0]+np_seq.shape[0], axis=0) 186 | h5[self.name + '_'+key][-np_seq.shape[0]:] = np_seq 187 | elif isinstance(data, np.ndarray): 188 | h5[self.name].resize( h5[self.name].shape[0]+data.shape[0], axis=0) 189 | h5[self.name][-data.shape[0]:] = data 190 | else: 191 | raise ValueError("invalid data type: {}".format(type(data))) 192 | 193 | return h5 194 | 195 | 196 | def transform(self, data): 197 | if isinstance(data, dict): 198 | assert self.sub_attributes is not None, "sub attributes not defined" 199 | 200 | for key in self.sub_attributes: 201 | np_seq = data[key] 202 | if len(np_seq.shape) == 2: 203 | np_seq = np.array([np_seq], 
dtype=self.dtype) 204 | data[key] = np_seq 205 | return data 206 | elif isinstance(data, np.ndarray): 207 | if len(data.shape) == 1: 208 | # make sure its B x 1 x sequence length 209 | data = np.array([data], dtype=self.dtype) 210 | return data 211 | else: 212 | raise ValueError("invalid data type: {}".format(type(data))) 213 | 214 | def init_attributes(self, fout, value, compression, data_length): 215 | max_shape = self.max_shape 216 | max_shape = list(self.max_shape) 217 | max_shape[0] = data_length 218 | max_shape = tuple(max_shape) 219 | 220 | dset = fout.create_dataset(self.name, 221 | shape=self.shape, 222 | maxshape=max_shape, 223 | dtype=self.dtype, 224 | compression=compression ) 225 | dset[0] = value 226 | 227 | # hard to define how small float should be 228 | class FloatSequence(Sequence): 229 | dtype = h5.special_dtype(vlen=np.dtype('float32')) 230 | 231 | class Float16Sequence(Sequence): 232 | dtype = h5.special_dtype(vlen=np.dtype('float16')) 233 | 234 | 235 | class String(Attribute): 236 | 237 | encoding = 'utf-8' 238 | dtype = h5.string_dtype(encoding='utf-8') 239 | 240 | def __init__(self, name='string'): 241 | self.name = name 242 | self.max_shape = (None, 1) 243 | self.shape = None 244 | 245 | def append(self, h5, data): 246 | buf_size = len(data) 247 | h5[self.name].resize((h5[self.name].shape[0]+buf_size), axis=0) 248 | h5[self.name][-buf_size:] = data 249 | return h5 250 | 251 | def transform(self, data): 252 | assert isinstance(data, str) 253 | return [data] 254 | 255 | def init_attributes(self, fout, value, compression, data_length): 256 | value = self.transform(value) 257 | max_shape = list(self.max_shape) 258 | max_shape[0] = data_length 259 | max_shape = tuple(max_shape) 260 | 261 | shape = (len(value), 1) 262 | 263 | fout.create_dataset(self.name, data=value, shape=shape, 264 | maxshape=max_shape, 265 | dtype=self.dtype, 266 | compression=compression ) 267 | -------------------------------------------------------------------------------- /test/test_to_mem.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from h5record.dataset import H5Dataset 4 | 5 | class TestDataModality(unittest.TestCase): 6 | 7 | 8 | def test_list_based_schema(self): 9 | from h5record.attributes import String, Integer, Image 10 | 11 | img_attr = Image(name='image', h=32, w=32) 12 | 13 | schema = [ 14 | img_attr, 15 | Integer(name='label'), 16 | String(name='sentence1'), 17 | String(name='sentence2'), 18 | ] 19 | image_paths = [ 20 | 'test/images/1.jpeg', 21 | 'test/images/2.jpeg', 22 | 'test/images/3.jpeg', 23 | ] 24 | 25 | data = [ 26 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0 ], 27 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 28 | ['Reading from an HDF5 file which you will probably be','about to overwrite! 
Override this error only if you know what youre doing ', 1], 29 | ] 30 | 31 | def pair_iter(): 32 | for (row, image_path) in zip(data, image_paths): 33 | yield { 34 | 'sentence1': row[0], 35 | 'sentence2': row[1], 36 | 'image': img_attr.read_image(image_path), 37 | 'label': row[2] 38 | } 39 | if os.path.exists('question_image_pair.h5'): 40 | os.remove('question_image_pair.h5') 41 | 42 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter(), 43 | to_memory=True) 44 | for idx in range(len(data)): 45 | row = dataset[idx] 46 | sent1, sent2, label = data[idx] 47 | assert sent1 == row['sentence1'] 48 | assert sent2 == row['sentence2'] 49 | assert label == row['label'] 50 | 51 | os.remove('question_image_pair.h5') 52 | 53 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter(), 54 | data_length=len(data), chunk_size=4) 55 | 56 | for idx in range(len(data)): 57 | row = dataset[idx] 58 | sent1, sent2, label = data[idx] 59 | assert sent1 == row['sentence1'] 60 | assert sent2 == row['sentence2'] 61 | assert label == row['label'] 62 | 63 | os.remove('question_image_pair.h5') 64 | 65 | def test_tuple_based_schema(self): 66 | from h5record.attributes import String, Integer, Image 67 | 68 | img_attr = Image(name='image', h=32, w=32) 69 | 70 | schema = ( 71 | String(name='sentence1'), 72 | String(name='sentence2'), 73 | img_attr, 74 | Integer(name='label') 75 | ) 76 | image_paths = [ 77 | 'test/images/1.jpeg', 78 | 'test/images/2.jpeg', 79 | 'test/images/3.jpeg', 80 | ] 81 | 82 | data = [ 83 | ['HDF5 supports chunking and compression.', 'You may want to experiment', 0 ], 84 | ['Starting to load file into an HDF5 file with chunk size','and compression is gzip', 1 ], 85 | ['Reading from an HDF5 file which you will probably be','about to overwrite! 
Override this error only if you know what youre doing ', 1], 86 | ] 87 | 88 | def pair_iter(): 89 | for (row, image_path) in zip(data, image_paths): 90 | yield { 91 | 'sentence1': row[0], 92 | 'sentence2': row[1], 93 | 'image': img_attr.read_image(image_path), 94 | 'label': row[2] 95 | } 96 | if os.path.exists('question_image_pair.h5'): 97 | os.remove('question_image_pair.h5') 98 | 99 | dataset = H5Dataset(schema, 100 | './question_image_pair.h5', pair_iter(), 101 | to_memory=True) 102 | for idx in range(len(data)): 103 | row = dataset[idx] 104 | sent1, sent2, label = data[idx] 105 | assert sent1 == row['sentence1'] 106 | assert sent2 == row['sentence2'] 107 | assert label == row['label'] 108 | 109 | os.remove('question_image_pair.h5') 110 | 111 | dataset = H5Dataset(schema, './question_image_pair.h5', pair_iter(), 112 | data_length=len(data), chunk_size=4) 113 | 114 | for idx in range(len(data)): 115 | row = dataset[idx] 116 | sent1, sent2, label = data[idx] 117 | assert sent1 == row['sentence1'] 118 | assert sent2 == row['sentence2'] 119 | assert label == row['label'] 120 | 121 | os.remove('question_image_pair.h5') 122 | 123 | 124 | def test_dict_based_schema(self): 125 | from h5record.attributes import String, Image 126 | 127 | schema = { 128 | 'image': Image(name='image', h=32, w=32), 129 | 'caption': String(name='caption') 130 | } 131 | 132 | captions = [ 133 | 'Lenna profile', 134 | 'Lenna back patch', 135 | 'Lenna lower patch', 136 | 'meme image', 137 | 'greyscale image' 138 | ] 139 | image_path = [ 140 | 'test/images/1.jpeg', 141 | 'test/images/2.jpeg', 142 | 'test/images/3.jpeg', 143 | 'test/images/4.jpeg', 144 | 'test/images/5.jpeg', 145 | ] 146 | def pair_iter(): 147 | for (caption, img_path) in zip(captions, image_path): 148 | yield { 149 | 'image': schema['image'].read_image(image_path[0]), 150 | 'caption': caption 151 | } 152 | if os.path.exists('img_caption.h5'): 153 | os.remove('img_caption.h5') 154 | 155 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 156 | data_length=len(image_path), chunk_size=4, to_memory=True) 157 | 158 | for idx in range(len(image_path)): 159 | data = dataset[idx] 160 | assert data['caption'] == captions[idx] 161 | 162 | os.remove('img_caption.h5') 163 | 164 | 165 | dataset = H5Dataset(schema, './img_caption.h5', pair_iter(), 166 | compression='lzf') 167 | 168 | for idx in range(len(image_path)): 169 | data = dataset[idx] 170 | assert data['caption'] == captions[idx] 171 | 172 | os.remove('img_caption.h5') 173 | 174 | 175 | 176 | def test_seq_schema(self): 177 | ''' 178 | Suitable for tokenized sequences 179 | ''' 180 | from h5record.attributes import Sequence, FloatSequence, Float16Sequence 181 | import numpy as np 182 | 183 | schema = ( 184 | Sequence(name='seq1'), 185 | Sequence(name='seq2') 186 | ) 187 | 188 | data = [ 189 | [ np.array([ 0, 1, 2, 3 ]), np.array([ 1, 2, 3 ]) ], 190 | [ np.array([ 0, 1, 2, 3, 4, 5 ]), np.array([ 1, 2, 3, 3 ]) ], 191 | [ np.array([ 0, 1, 2 ]), np.array([ 1, 2, -1 ]) ], 192 | ] 193 | def pair_iter(): 194 | for (seq1, seq2) in data: 195 | yield { 196 | 'seq1': seq1, 197 | 'seq2': seq2 198 | } 199 | if os.path.exists('tokens.h5'): 200 | os.remove('tokens.h5') 201 | 202 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 203 | data_length=len(data), chunk_size=4, to_memory=True) 204 | 205 | for idx in range(len(data)): 206 | row = dataset[idx] 207 | assert (row['seq1'][0] == data[idx][0]).all() 208 | 209 | os.remove('tokens.h5') 210 | 211 | data = [ 212 | [ np.array([ 0.1, 1, 2, 3 ]), np.array([ 1, 
2.13, 3 ]) ], 213 | [ np.array([ 0.11, 1, 2, 3, 4, 5 ]), np.array([ 1, 2, 3.33333333, 3 ]) ], 214 | [ np.array([ 3.14159, 1, 2 ]), np.array([ 1.988, 2, -1 ]) ], 215 | ] 216 | schema = ( 217 | FloatSequence(name='seq1'), 218 | FloatSequence(name='seq2') 219 | ) 220 | if os.path.exists('tokens.h5'): 221 | os.remove('tokens.h5') 222 | 223 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 224 | data_length=len(data), chunk_size=4) 225 | 226 | for idx in range(len(data)): 227 | row = dataset[idx] 228 | assert (row['seq1'][0] - data[idx][0]).sum() < 1e-6 229 | assert (row['seq2'][0] - data[idx][1]).sum() < 1e-6 230 | 231 | os.remove('tokens.h5') 232 | 233 | schema = ( 234 | Float16Sequence(name='seq1'), 235 | Float16Sequence(name='seq2') 236 | ) 237 | if os.path.exists('tokens.h5'): 238 | os.remove('tokens.h5') 239 | 240 | dataset = H5Dataset(schema, './tokens.h5', pair_iter(), 241 | data_length=len(data), chunk_size=4) 242 | 243 | for idx in range(len(data)): 244 | row = dataset[idx] 245 | assert (row['seq1'][0] - data[idx][0]).mean() < 1e-3 246 | assert (row['seq2'][0] - data[idx][1]).mean() < 1e-3 247 | 248 | os.remove('tokens.h5') 249 | 250 | --------------------------------------------------------------------------------