├── .gitignore ├── README.md ├── create_lmdb.py ├── dataset.py ├── dataset_test.py └── proto ├── tensor.proto ├── tensor_pb2.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.pyc 3 | .DS_Store 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch LMDB DataLoader 2 | Use LMDB with protobuf to efficiently read large datasets for PyTorch training. 3 | 4 | ## Getting Started 5 | 1. Install Python, protobuf, and the runtime dependencies (lmdb, numpy, torch). The easiest way to get protoc is through the grpcio-tools package: 6 | ```shell 7 | pip install grpcio grpcio-tools 8 | ``` 9 | 2. Generate the Python protobuf module (proto/tensor_pb2.py). 10 | ```shell 11 | python -m grpc_tools.protoc -I./proto --python_out=./proto ./proto/tensor.proto 12 | ``` 13 | 3. Create dummy training data. 14 | ```shell 15 | python create_lmdb.py --output_file train_lmdb 16 | ``` 17 | 4. Run the unit test. 18 | ```shell 19 | python dataset_test.py 20 | ``` 21 | 22 | ## Reference 23 | * https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977 24 | * https://github.com/pytorch/pytorch/blob/master/caffe2/python/examples/lmdb_create_example.py 25 | -------------------------------------------------------------------------------- /create_lmdb.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 ASLP@NPU. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: npuichigo@gmail.com (zhangyuchao) 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | from __future__ import unicode_literals 21 | 22 | import argparse 23 | import lmdb 24 | import numpy as np 25 | 26 | from proto import utils 27 | from proto import tensor_pb2 28 | 29 | 30 | def create_db(output_file): 31 | print(">>> Write database...") 32 | LMDB_MAP_SIZE = 1 << 40 # MODIFY 33 | print(LMDB_MAP_SIZE) 34 | env = lmdb.open(output_file, map_size=LMDB_MAP_SIZE) 35 | 36 | checksum = 0 37 | with env.begin(write=True) as txn: 38 | for j in range(0, 1024): 39 | # MODIFY: add your own data reader / creator 40 | width = 64 41 | height = 32 42 | img_data = np.random.rand(3, width, height).astype(np.float32) 43 | label = np.asarray(j % 10) 44 | 45 | # Create TensorProtos 46 | tensor_protos = tensor_pb2.TensorProtos() 47 | img_tensor = utils.numpy_array_to_tensor(img_data) 48 | tensor_protos.protos.extend([img_tensor]) 49 | 50 | label_tensor = utils.numpy_array_to_tensor(label) 51 | tensor_protos.protos.extend([label_tensor]) 52 | txn.put( 53 | '{}'.format(j).encode('ascii'), 54 | tensor_protos.SerializeToString() 55 | ) 56 | 57 | if (j % 16 == 0): 58 | print("Inserted {} rows".format(j)) 59 | 60 | 61 | def main(): 62 | parser = argparse.ArgumentParser( 63 | description="LMDB creation" 64 | ) 65 | parser.add_argument("--output_file", type=str, default=None, 66 | help="Path to write the database to", 67 | required=True) 68 | args = parser.parse_args() 69 | 70 | create_db(args.output_file) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 ASLP@NPU. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: npuichigo@gmail.com (zhangyuchao) 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | from torch.utils.data import Dataset 23 | 24 | from proto import tensor_pb2 25 | from proto import utils 26 | 27 | 28 | class LmdbDataset(Dataset): 29 | """Lmdb dataset.""" 30 | 31 | def __init__(self, lmdb_path): 32 | super(LmdbDataset, self).__init__() 33 | import lmdb 34 | self.env = lmdb.open(lmdb_path, max_readers=1, readonly=True, lock=False, 35 | readahead=False, meminit=False) 36 | with self.env.begin(write=False) as txn: 37 | self.length = txn.stat()['entries'] 38 | self.keys = [key for key, _ in txn.cursor()] 39 | 40 | def __getitem__(self, index): 41 | with self.env.begin(write=False) as txn: 42 | serialized_str = txn.get(self.keys[index]) 43 | tensor_protos = tensor_pb2.TensorProtos() 44 | tensor_protos.ParseFromString(serialized_str) 45 | img = utils.tensor_to_numpy_array(tensor_protos.protos[0]) 46 | label = utils.tensor_to_numpy_array(tensor_protos.protos[1]) 47 | return img, label 48 | 49 | def __len__(self): 50 | return self.length 51 | -------------------------------------------------------------------------------- /dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 
# NOTE(review): the two collapsed dump lines below hold the rest of
# dataset_test.py and the first part of proto/tensor.proto.
# dataset_test.py: LmdbDatasetTest.setUp wraps LmdbDataset("train_lmdb") -- a
# relative path, presumably the database produced by README step 3
# (create_lmdb.py --output_file train_lmdb) -- in a DataLoader with
# batch_size=32, shuffle=True, num_workers=4. testRead iterates every batch and
# only prints the batch index and the img/label shapes; it asserts nothing, so
# it fails only if loading raises. If train_lmdb is missing, setUp presumably
# errors out of lmdb.open rather than skipping -- TODO confirm.
# tensor.proto: TensorProtos is a repeated list of TensorProto; each
# TensorProto carries dims, a DataType tag and one typed repeated payload field
# (float_data, int32_data, byte_data, string_data, double_data, int64_data).
2019 ASLP@NPU. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: npuichigo@gmail.com (zhangyuchao) 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import unittest 22 | 23 | from torch.utils.data import DataLoader 24 | 25 | from dataset import LmdbDataset 26 | 27 | 28 | class LmdbDatasetTest(unittest.TestCase): 29 | def setUp(self): 30 | self.dataset = LmdbDataset("train_lmdb") 31 | self.dataloader = DataLoader(self.dataset, batch_size=32, shuffle=True, 32 | num_workers=4) 33 | 34 | def testRead(self): 35 | for i, data in enumerate(self.dataloader): 36 | img, label = data 37 | print(i, img.shape, label.shape) 38 | 39 | if __name__ == '__main__': 40 | unittest.main() 41 | -------------------------------------------------------------------------------- /proto/tensor.proto: -------------------------------------------------------------------------------- 1 | syntax = "proto3"; 2 | 3 | // TensorProtos stores multiple TensorProto objects in one single proto. 4 | message TensorProtos { 5 | repeated TensorProto protos = 1; 6 | } 7 | 8 | // TensorProto stores serialized Tensor objects. 9 | message TensorProto { 10 | // The dimensions in the tensor.
11 | repeated int64 dims = 1; 12 | 13 | // Data type 14 | enum DataType { 15 | UNDEFINED = 0; 16 | 17 | // Basic types 18 | FLOAT = 1; // float 19 | INT32 = 2; // int 20 | BYTE = 3; // byte, when deserialized, is going to be restored as uint8 21 | STRING = 4; // string 22 | 23 | // Less-commonly used data types 24 | BOOL = 5; // bool 25 | UINT8 = 6; // uint8_t 26 | INT8 = 7; // int8_t 27 | UINT16 = 8; // uint16_t 28 | INT16 = 9; // int16_t 29 | INT64 = 10; // int64_t 30 | FLOAT16 = 12; // at::Half 31 | DOUBLE = 13; // double 32 | } 33 | DataType data_type = 2; 34 | 35 | // For float 36 | repeated float float_data = 3; 37 | // For int32, uint8, int8, uint16, int16, bool, and float16 38 | // Note about float16: in storage we will basically convert float16 byte-wise 39 | // to unsigned short and then store them in the int32_data field. 40 | repeated int32 int32_data = 4; 41 | // For bytes 42 | bytes byte_data = 5; 43 | // For strings 44 | repeated bytes string_data = 6; 45 | // For double 46 | repeated double double_data = 7; 47 | // For int64 48 | repeated int64 int64_data = 8; 49 | } 50 | -------------------------------------------------------------------------------- /proto/tensor_pb2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Generated by the protocol buffer compiler. DO NOT EDIT!
# NOTE(review): the collapsed dump lines below are proto/tensor_pb2.py, protoc
# output for tensor.proto (its header says DO NOT EDIT), defining the message
# classes TensorProtos, Tensor and TensorProto.
# It looks stale relative to the tensor.proto in this dump: the serialized
# descriptor and _TENSOR also define a `Tensor` message (dims, data_type with
# its own 3-value DataType enum, float_data, and int32_data as sint32 --
# type=17) that tensor.proto does not declare. It also uses the old protoc
# output style that constructs _descriptor.FileDescriptor/Descriptor objects
# directly; newer protobuf runtimes are reported to reject that style
# ("Descriptors cannot be created directly") -- TODO confirm against the
# protobuf version in use, and regenerate with README step 2 rather than
# hand-editing.
3 | # source: tensor.proto 4 | 5 | import sys 6 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) 7 | from google.protobuf import descriptor as _descriptor 8 | from google.protobuf import message as _message 9 | from google.protobuf import reflection as _reflection 10 | from google.protobuf import symbol_database as _symbol_database 11 | # @@protoc_insertion_point(imports) 12 | 13 | _sym_db = _symbol_database.Default() 14 | 15 | 16 | 17 | 18 | DESCRIPTOR = _descriptor.FileDescriptor( 19 | name='tensor.proto', 20 | package='', 21 | syntax='proto3', 22 | serialized_options=None, 23 | serialized_pb=_b('\n\x0ctensor.proto\",\n\x0cTensorProtos\x12\x1c\n\x06protos\x18\x01 \x03(\x0b\x32\x0c.TensorProto\"\x94\x01\n\x06Tensor\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12#\n\tdata_type\x18\x02 \x01(\x0e\x32\x10.Tensor.DataType\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x11\"/\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\"\xe0\x02\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12(\n\tdata_type\x18\x02 \x01(\x0e\x32\x15.TensorProto.DataType\x12\x12\n\nfloat_data\x18\x03 \x03(\x02\x12\x12\n\nint32_data\x18\x04 \x03(\x05\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x13\n\x0b\x64ouble_data\x18\x07 \x03(\x01\x12\x12\n\nint64_data\x18\x08 \x03(\x03\"\x9f\x01\n\x08\x44\x61taType\x12\r\n\tUNDEFINED\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\t\n\x05INT32\x10\x02\x12\x08\n\x04\x42YTE\x10\x03\x12\n\n\x06STRING\x10\x04\x12\x08\n\x04\x42OOL\x10\x05\x12\t\n\x05UINT8\x10\x06\x12\x08\n\x04INT8\x10\x07\x12\n\n\x06UINT16\x10\x08\x12\t\n\x05INT16\x10\t\x12\t\n\x05INT64\x10\n\x12\x0b\n\x07\x46LOAT16\x10\x0c\x12\n\n\x06\x44OUBLE\x10\rb\x06proto3') 24 | ) 25 | 26 | 27 | 28 | _TENSOR_DATATYPE = _descriptor.EnumDescriptor( 29 | name='DataType', 30 | full_name='Tensor.DataType', 31 | filename=None, 32 | file=DESCRIPTOR, 33 | values=[ 34 |
_descriptor.EnumValueDescriptor( 35 | name='UNDEFINED', index=0, number=0, 36 | serialized_options=None, 37 | type=None), 38 | _descriptor.EnumValueDescriptor( 39 | name='FLOAT', index=1, number=1, 40 | serialized_options=None, 41 | type=None), 42 | _descriptor.EnumValueDescriptor( 43 | name='INT32', index=2, number=2, 44 | serialized_options=None, 45 | type=None), 46 | ], 47 | containing_type=None, 48 | serialized_options=None, 49 | serialized_start=164, 50 | serialized_end=211, 51 | ) 52 | _sym_db.RegisterEnumDescriptor(_TENSOR_DATATYPE) 53 | 54 | _TENSORPROTO_DATATYPE = _descriptor.EnumDescriptor( 55 | name='DataType', 56 | full_name='TensorProto.DataType', 57 | filename=None, 58 | file=DESCRIPTOR, 59 | values=[ 60 | _descriptor.EnumValueDescriptor( 61 | name='UNDEFINED', index=0, number=0, 62 | serialized_options=None, 63 | type=None), 64 | _descriptor.EnumValueDescriptor( 65 | name='FLOAT', index=1, number=1, 66 | serialized_options=None, 67 | type=None), 68 | _descriptor.EnumValueDescriptor( 69 | name='INT32', index=2, number=2, 70 | serialized_options=None, 71 | type=None), 72 | _descriptor.EnumValueDescriptor( 73 | name='BYTE', index=3, number=3, 74 | serialized_options=None, 75 | type=None), 76 | _descriptor.EnumValueDescriptor( 77 | name='STRING', index=4, number=4, 78 | serialized_options=None, 79 | type=None), 80 | _descriptor.EnumValueDescriptor( 81 | name='BOOL', index=5, number=5, 82 | serialized_options=None, 83 | type=None), 84 | _descriptor.EnumValueDescriptor( 85 | name='UINT8', index=6, number=6, 86 | serialized_options=None, 87 | type=None), 88 | _descriptor.EnumValueDescriptor( 89 | name='INT8', index=7, number=7, 90 | serialized_options=None, 91 | type=None), 92 | _descriptor.EnumValueDescriptor( 93 | name='UINT16', index=8, number=8, 94 | serialized_options=None, 95 | type=None), 96 | _descriptor.EnumValueDescriptor( 97 | name='INT16', index=9, number=9, 98 | serialized_options=None, 99 | type=None), 100 | _descriptor.EnumValueDescriptor(
101 | name='INT64', index=10, number=10, 102 | serialized_options=None, 103 | type=None), 104 | _descriptor.EnumValueDescriptor( 105 | name='FLOAT16', index=11, number=12, 106 | serialized_options=None, 107 | type=None), 108 | _descriptor.EnumValueDescriptor( 109 | name='DOUBLE', index=12, number=13, 110 | serialized_options=None, 111 | type=None), 112 | ], 113 | containing_type=None, 114 | serialized_options=None, 115 | serialized_start=407, 116 | serialized_end=566, 117 | ) 118 | _sym_db.RegisterEnumDescriptor(_TENSORPROTO_DATATYPE) 119 | 120 | 121 | _TENSORPROTOS = _descriptor.Descriptor( 122 | name='TensorProtos', 123 | full_name='TensorProtos', 124 | filename=None, 125 | file=DESCRIPTOR, 126 | containing_type=None, 127 | fields=[ 128 | _descriptor.FieldDescriptor( 129 | name='protos', full_name='TensorProtos.protos', index=0, 130 | number=1, type=11, cpp_type=10, label=3, 131 | has_default_value=False, default_value=[], 132 | message_type=None, enum_type=None, containing_type=None, 133 | is_extension=False, extension_scope=None, 134 | serialized_options=None, file=DESCRIPTOR), 135 | ], 136 | extensions=[ 137 | ], 138 | nested_types=[], 139 | enum_types=[ 140 | ], 141 | serialized_options=None, 142 | is_extendable=False, 143 | syntax='proto3', 144 | extension_ranges=[], 145 | oneofs=[ 146 | ], 147 | serialized_start=16, 148 | serialized_end=60, 149 | ) 150 | 151 | 152 | _TENSOR = _descriptor.Descriptor( 153 | name='Tensor', 154 | full_name='Tensor', 155 | filename=None, 156 | file=DESCRIPTOR, 157 | containing_type=None, 158 | fields=[ 159 | _descriptor.FieldDescriptor( 160 | name='dims', full_name='Tensor.dims', index=0, 161 | number=1, type=3, cpp_type=2, label=3, 162 | has_default_value=False, default_value=[], 163 | message_type=None, enum_type=None, containing_type=None, 164 | is_extension=False, extension_scope=None, 165 | serialized_options=None, file=DESCRIPTOR), 166 | _descriptor.FieldDescriptor( 167 | name='data_type', full_name='Tensor.data_type',
index=1, 168 | number=2, type=14, cpp_type=8, label=1, 169 | has_default_value=False, default_value=0, 170 | message_type=None, enum_type=None, containing_type=None, 171 | is_extension=False, extension_scope=None, 172 | serialized_options=None, file=DESCRIPTOR), 173 | _descriptor.FieldDescriptor( 174 | name='float_data', full_name='Tensor.float_data', index=2, 175 | number=3, type=2, cpp_type=6, label=3, 176 | has_default_value=False, default_value=[], 177 | message_type=None, enum_type=None, containing_type=None, 178 | is_extension=False, extension_scope=None, 179 | serialized_options=None, file=DESCRIPTOR), 180 | _descriptor.FieldDescriptor( 181 | name='int32_data', full_name='Tensor.int32_data', index=3, 182 | number=4, type=17, cpp_type=1, label=3, 183 | has_default_value=False, default_value=[], 184 | message_type=None, enum_type=None, containing_type=None, 185 | is_extension=False, extension_scope=None, 186 | serialized_options=None, file=DESCRIPTOR), 187 | ], 188 | extensions=[ 189 | ], 190 | nested_types=[], 191 | enum_types=[ 192 | _TENSOR_DATATYPE, 193 | ], 194 | serialized_options=None, 195 | is_extendable=False, 196 | syntax='proto3', 197 | extension_ranges=[], 198 | oneofs=[ 199 | ], 200 | serialized_start=63, 201 | serialized_end=211, 202 | ) 203 | 204 | 205 | _TENSORPROTO = _descriptor.Descriptor( 206 | name='TensorProto', 207 | full_name='TensorProto', 208 | filename=None, 209 | file=DESCRIPTOR, 210 | containing_type=None, 211 | fields=[ 212 | _descriptor.FieldDescriptor( 213 | name='dims', full_name='TensorProto.dims', index=0, 214 | number=1, type=3, cpp_type=2, label=3, 215 | has_default_value=False, default_value=[], 216 | message_type=None, enum_type=None, containing_type=None, 217 | is_extension=False, extension_scope=None, 218 | serialized_options=None, file=DESCRIPTOR), 219 | _descriptor.FieldDescriptor( 220 | name='data_type', full_name='TensorProto.data_type', index=1, 221 | number=2, type=14, cpp_type=8, label=1, 222 |
has_default_value=False, default_value=0, 223 | message_type=None, enum_type=None, containing_type=None, 224 | is_extension=False, extension_scope=None, 225 | serialized_options=None, file=DESCRIPTOR), 226 | _descriptor.FieldDescriptor( 227 | name='float_data', full_name='TensorProto.float_data', index=2, 228 | number=3, type=2, cpp_type=6, label=3, 229 | has_default_value=False, default_value=[], 230 | message_type=None, enum_type=None, containing_type=None, 231 | is_extension=False, extension_scope=None, 232 | serialized_options=None, file=DESCRIPTOR), 233 | _descriptor.FieldDescriptor( 234 | name='int32_data', full_name='TensorProto.int32_data', index=3, 235 | number=4, type=5, cpp_type=1, label=3, 236 | has_default_value=False, default_value=[], 237 | message_type=None, enum_type=None, containing_type=None, 238 | is_extension=False, extension_scope=None, 239 | serialized_options=None, file=DESCRIPTOR), 240 | _descriptor.FieldDescriptor( 241 | name='byte_data', full_name='TensorProto.byte_data', index=4, 242 | number=5, type=12, cpp_type=9, label=1, 243 | has_default_value=False, default_value=_b(""), 244 | message_type=None, enum_type=None, containing_type=None, 245 | is_extension=False, extension_scope=None, 246 | serialized_options=None, file=DESCRIPTOR), 247 | _descriptor.FieldDescriptor( 248 | name='string_data', full_name='TensorProto.string_data', index=5, 249 | number=6, type=12, cpp_type=9, label=3, 250 | has_default_value=False, default_value=[], 251 | message_type=None, enum_type=None, containing_type=None, 252 | is_extension=False, extension_scope=None, 253 | serialized_options=None, file=DESCRIPTOR), 254 | _descriptor.FieldDescriptor( 255 | name='double_data', full_name='TensorProto.double_data', index=6, 256 | number=7, type=1, cpp_type=5, label=3, 257 | has_default_value=False, default_value=[], 258 | message_type=None, enum_type=None, containing_type=None, 259 | is_extension=False, extension_scope=None, 260 | serialized_options=None,
file=DESCRIPTOR), 261 | _descriptor.FieldDescriptor( 262 | name='int64_data', full_name='TensorProto.int64_data', index=7, 263 | number=8, type=3, cpp_type=2, label=3, 264 | has_default_value=False, default_value=[], 265 | message_type=None, enum_type=None, containing_type=None, 266 | is_extension=False, extension_scope=None, 267 | serialized_options=None, file=DESCRIPTOR), 268 | ], 269 | extensions=[ 270 | ], 271 | nested_types=[], 272 | enum_types=[ 273 | _TENSORPROTO_DATATYPE, 274 | ], 275 | serialized_options=None, 276 | is_extendable=False, 277 | syntax='proto3', 278 | extension_ranges=[], 279 | oneofs=[ 280 | ], 281 | serialized_start=214, 282 | serialized_end=566, 283 | ) 284 | 285 | _TENSORPROTOS.fields_by_name['protos'].message_type = _TENSORPROTO 286 | _TENSOR.fields_by_name['data_type'].enum_type = _TENSOR_DATATYPE 287 | _TENSOR_DATATYPE.containing_type = _TENSOR 288 | _TENSORPROTO.fields_by_name['data_type'].enum_type = _TENSORPROTO_DATATYPE 289 | _TENSORPROTO_DATATYPE.containing_type = _TENSORPROTO 290 | DESCRIPTOR.message_types_by_name['TensorProtos'] = _TENSORPROTOS 291 | DESCRIPTOR.message_types_by_name['Tensor'] = _TENSOR 292 | DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO 293 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 294 | 295 | TensorProtos = _reflection.GeneratedProtocolMessageType('TensorProtos', (_message.Message,), dict( 296 | DESCRIPTOR = _TENSORPROTOS, 297 | __module__ = 'tensor_pb2' 298 | # @@protoc_insertion_point(class_scope:TensorProtos) 299 | )) 300 | _sym_db.RegisterMessage(TensorProtos) 301 | 302 | Tensor = _reflection.GeneratedProtocolMessageType('Tensor', (_message.Message,), dict( 303 | DESCRIPTOR = _TENSOR, 304 | __module__ = 'tensor_pb2' 305 | # @@protoc_insertion_point(class_scope:Tensor) 306 | )) 307 | _sym_db.RegisterMessage(Tensor) 308 | 309 | TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict( 310 | DESCRIPTOR = _TENSORPROTO, 311 | __module__ = 'tensor_pb2'
312 | # @@protoc_insertion_point(class_scope:TensorProto) 313 | )) 314 | _sym_db.RegisterMessage(TensorProto) 315 | 316 | 317 | # @@protoc_insertion_point(module_scope) 318 | -------------------------------------------------------------------------------- /proto/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 ASLP@NPU. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Author: npuichigo@gmail.com (zhangyuchao) 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | 23 | from proto import tensor_pb2 24 | 25 | 26 | def tensor_to_numpy_array(tensor): 27 | if tensor.data_type == tensor_pb2.TensorProto.FLOAT: 28 | return np.asarray( 29 | tensor.float_data, dtype=np.float32).reshape(tensor.dims) 30 | elif tensor.data_type == tensor_pb2.TensorProto.DOUBLE: 31 | return np.asarray( 32 | tensor.double_data, dtype=np.float64).reshape(tensor.dims) 33 | elif tensor.data_type == tensor_pb2.TensorProto.INT32: 34 | return np.asarray( 35 | tensor.int32_data, dtype=np.int).reshape(tensor.dims) # pb.INT32=>np.int use int32_data 36 | elif tensor.data_type == tensor_pb2.TensorProto.INT16: 37 | return np.asarray( 38 | tensor.int32_data, dtype=np.int16).reshape(tensor.dims) # pb.INT16=>np.int16 use int32_data 39 | elif tensor.data_type == 
tensor_pb2.TensorProto.UINT16: 40 | return np.asarray( 41 | tensor.int32_data, dtype=np.uint16).reshape(tensor.dims) # pb.UINT16=>np.uint16 use int32_data 42 | elif tensor.data_type == tensor_pb2.TensorProto.INT8: 43 | return np.asarray( 44 | tensor.int32_data, dtype=np.int8).reshape(tensor.dims) # pb.INT8=>np.int8 use int32_data 45 | elif tensor.data_type == tensor_pb2.TensorProto.UINT8: 46 | return np.asarray( 47 | tensor.int32_data, dtype=np.uint8).reshape(tensor.dims) # pb.UINT8=>np.uint8 use int32_data 48 | else: 49 | # TODO: complete the data type: bool, float16, byte, int64, string 50 | raise RuntimeError( 51 | "Tensor data type not supported yet: " + str(tensor.data_type)) 52 | 53 | 54 | def numpy_array_to_tensor(arr): 55 | tensor = tensor_pb2.TensorProto() 56 | tensor.dims.extend(arr.shape) 57 | if arr.dtype == np.float32: 58 | tensor.data_type = tensor_pb2.TensorProto.FLOAT 59 | tensor.float_data.extend(list(arr.flatten().astype(float))) 60 | elif arr.dtype == np.float64: 61 | tensor.data_type = tensor_pb2.TensorProto.DOUBLE 62 | tensor.double_data.extend(list(arr.flatten().astype(np.float64))) 63 | elif arr.dtype == np.int or arr.dtype == np.int32: 64 | tensor.data_type = tensor_pb2.TensorProto.INT32 65 | tensor.int32_data.extend(arr.flatten().astype(np.int).tolist()) 66 | elif arr.dtype == np.int16: 67 | tensor.data_type = tensor_pb2.TensorProto.INT16 68 | tensor.int32_data.extend(list(arr.flatten().astype(np.int16))) # np.int16=>pb.INT16 use int32_data 69 | elif arr.dtype == np.uint16: 70 | tensor.data_type = tensor_pb2.TensorProto.UINT16 71 | tensor.int32_data.extend(list(arr.flatten().astype(np.uint16))) # np.uint16=>pb.UNIT16 use int32_data 72 | elif arr.dtype == np.int8: 73 | tensor.data_type = tensor_pb2.TensorProto.INT8 74 | tensor.int32_data.extend(list(arr.flatten().astype(np.int8))) # np.int8=>pb.INT8 use int32_data 75 | elif arr.dtype == np.uint8: 76 | tensor.data_type = tensor_pb2.TensorProto.UINT8 77 | 
tensor.int32_data.extend(list(arr.flatten().astype(np.uint8))) # np.uint8=>pb.UNIT8 use int32_data 78 | else: 79 | # TODO: complete the data type: bool, float16, byte, int64, string 80 | raise RuntimeError( 81 | "Numpy data type not supported yet: " + str(arr.dtype)) 82 | return tensor 83 | --------------------------------------------------------------------------------