├── write2tar.py
├── README.md
└── base_dataset_tar.py

/write2tar.py:
--------------------------------------------------------------------------------
# coding=utf-8
import os
import sys

import webdataset as wds


if __name__ == "__main__":
    record_filename = sys.argv[1]
    output_dir = sys.argv[2]
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    index = 0
    num_per_dir = 2000

    # Shards are capped at ~6 GB or num_per_dir samples, whichever comes first.
    pattern = os.path.join(output_dir, "eng_zh-%06d.tar")
    sink = wds.ShardWriter(pattern, maxsize=int(6e9), maxcount=int(num_per_dir))

    with open(record_filename, 'r') as f:
        for line in f:
            # Each input line is: text, followed by tab-separated embedding values.
            items = line.strip('\n').split('\t')
            text = items[0]
            embedding = "\t".join(items[1:])
            xkey = "%07d" % index
            sample = {
                "__key__": xkey,
                "text": text,
                "embedding": embedding
            }
            # Write the sample to the sharded tar archives.
            sink.write(sample)
            index += 1

    sink.close()


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# webdataset wrapper interface
A dataset for large-scale data loading in PyTorch
## Overview
With TensorFlow, large-scale data can be stored and read as tfrecord files, whereas PyTorch can only read many small files, or pack the data into pickle files and read it through an IterableDataset.
In such scenarios, webdataset is a good tool. Here I wrap it so that the existing interface can be reused by subclassing.

## How to use
- Interface description:
  - Args:
    - tar_pattern: format path/video-tfr-{000000..000003}.tar
    - length: number of samples in the dataset
    - local_rank: rank of the current node/process (the shards are split by this rank)
    - world_size: number of GPU cards on a single machine
    - keys: subset of the keys stored in the tar files that need to be parsed
    - decode_key_funcs: decode operation for each key, e.g. decoding/decompression
    - process_key_funcs: processing operation for each key
    - shuffle_buffer: size of the buffer used to shuffle samples


### Example
#### Generate tar files
Write the data into tar shards; see the example in write2tar.py. A minimal read-back check is sketched below.
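The following is a minimal sketch for sanity-checking the resulting shards. The `shards/` directory and the shard name `eng_zh-000000.tar` are assumptions following the naming pattern in write2tar.py; the values come back as raw bytes until decoded.

```python
import webdataset as wds

# Hypothetical shard path, following the eng_zh-%06d.tar pattern from write2tar.py.
dataset = wds.WebDataset("shards/eng_zh-000000.tar")

for i, sample in enumerate(dataset):
    # Each sample is a dict keyed by file extension; values are raw bytes.
    text = sample["text"].decode("utf-8")
    embedding = sample["embedding"].decode("utf-8").split("\t")
    print(sample["__key__"], text[:20], len(embedding))
    if i >= 2:
        break
```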
#### Subclass and implement the parsing logic
The example here reads back the text and the corresponding word vectors stored in the tar files; this requires the following steps:
1. Build the decode functions
2. Build the pipe_fn function, which is mainly used to arrange the output format

```python
import random
import re

import braceexpand
import numpy as np

from base_dataset_tar import BaseTarDataset


class TextPairDistilTarDataset(BaseTarDataset):
    def __init__(self, filename, split, ratio, length, local_rank=0, world_size=2):
        identity_func = lambda x: x
        # Parse the tab-separated embedding string back into a float32 vector.
        embedd_process_fn = lambda x: np.array(
            [float(item) for item in str(x, encoding='utf-8').split("\t")], dtype=np.float32
        )

        keys = ["text", "embedding"]
        decode_key_funcs = [identity_func, embedd_process_fn]
        process_key_funcs = [identity_func, identity_func]

        # Expand the brace pattern into a shard list and split it into train/val shards.
        data_pattern = list(braceexpand.braceexpand(filename))
        indexes = list(range(len(data_pattern)))
        random.shuffle(indexes)

        selected_indexes = indexes[:int(ratio * len(indexes))] if split == 'train' else indexes[int(ratio * len(indexes)):]
        tar_pattern = [data_pattern[i] for i in selected_indexes]
        print("[INFO]: {} mode, current file list:".format(split), len(tar_pattern))
        super().__init__(
            tar_pattern=tar_pattern,
            length=length,
            local_rank=local_rank,
            world_size=world_size,
            keys=keys,
            decode_key_funcs=decode_key_funcs,
            process_key_funcs=process_key_funcs,
            shuffle_buffer=2000,
        )
        self.length = length

    def __len__(self):
        return self.length

    def find_chinese(self, file):
        # Keep only Chinese characters.
        pattern = re.compile(r'[^\u4e00-\u9fa5]')
        chinese = re.sub(pattern, '', file)
        return chinese

    def pipe_fn(self, data):
        for sample in data:
            sample = list(sample)
            text = str(sample[0], encoding='utf-8')
            text = self.find_chinese(text)
            if len(text) == 0:
                text = "我的"
            length = len(text)
            # Randomly crop overly long texts to at most 256 characters.
            if len(text) > 256:
                start_index = np.random.randint(0, length - 256)
                text = text[start_index: start_index + 256]
            embedding = sample[1]
            yield {'text': text, 'embedding': embedding}
```
#### Use it in a DataLoader
```
dataset_class = tar_class_dataset_init()
dataloader = torch.utils.data.DataLoader(dataset_class.dataset, batch_size=xx, num_workers=xx)
```
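`tar_class_dataset_init` above is only a placeholder; a minimal sketch of what such a helper could look like for the `TextPairDistilTarDataset` defined earlier is shown below. The shard pattern, ratio, length, batch size, and worker count are illustrative assumptions, not values shipped with this repository.

```python
import torch

def tar_class_dataset_init():
    # Hypothetical helper; all argument values below are assumptions for illustration.
    return TextPairDistilTarDataset(
        filename="shards/eng_zh-{000000..000003}.tar",
        split="train",
        ratio=0.9,
        length=100000,
        local_rank=0,
        world_size=2,
    )

dataset_class = tar_class_dataset_init()
dataloader = torch.utils.data.DataLoader(
    dataset_class.dataset,  # the underlying webdataset pipeline
    batch_size=32,
    num_workers=4,
)
```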
**TIP: webdataset is an IterableDataset, so the data length must be specified, otherwise an error will be raised.**

--------------------------------------------------------------------------------
/base_dataset_tar.py:
--------------------------------------------------------------------------------
import torch
import sys
import warnings
import webdataset as wds

show_splits = False


def split_by_node(urls, dist_rank, dist_size):
    """Split urls for each node.

    This uses the rank and world size. Note that it is invoked in each worker,
    so the results need to be consistent between multiple invocations."""

    if dist_rank >= 0 and dist_size > 0:
        result = urls[dist_rank::dist_size]
        if show_splits:
            print(
                f"split_by_node {dist_rank}/{dist_size} len={len(result)}",
                file=sys.stderr,
            )
        return result
    else:
        return urls


def split_by_worker(urls):
    """Split urls for each worker."""

    urls = [url for url in urls]
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        wid = worker_info.id
        num_workers = worker_info.num_workers
        if wid == 0 and len(urls) < num_workers:
            warnings.warn(f"num_workers {num_workers} > num_shards {len(urls)}")
        result = urls[wid::num_workers]
        if show_splits:
            print(
                f"split_by_worker {wid}/{num_workers} len={len(result)}",
                file=sys.stderr,
            )
        return result
    else:
        return urls


class BaseTarDataset(object):
    def __init__(
        self,
        tar_pattern,
        length=None,
        local_rank=0,
        world_size=2,
        keys=[],
        decode_key_funcs=[],
        process_key_funcs=[],
        shuffle_buffer=2000,
        **kwargs,
    ):
        """
        Args:
            tar_pattern: format path/video-tfr-{000000..000003}.tar
            length: number of samples in the dataset
            local_rank: rank of the current node/process (the shards are split by this rank)
            world_size: number of GPU cards on a single machine
            keys: subset of the keys stored in the tar files that need to be parsed
            decode_key_funcs: decode operation for each key, e.g. decoding/decompression
            process_key_funcs: processing operation for each key
            shuffle_buffer: size of the buffer used to shuffle samples
        """
        # Register one decoder per key, dispatched on the file extension inside the tar.
        decode_list = []
        for k, kf in zip(keys, decode_key_funcs):
            decode_list.append(wds.handle_extension(".{}".format(k), kf))

        tuple_str = " ".join(keys)

        self.dataset = (
            wds.WebDataset(
                tar_pattern,
                length=length,
                nodesplitter=lambda x: split_by_node(x, local_rank, world_size),
                splitter=split_by_worker,
            )
            .shuffle(shuffle_buffer)
            .decode(*decode_list)
            .to_tuple(tuple_str)
            .map_tuple(*process_key_funcs)
            .pipe(self.pipe_fn)
        )

        self.keys = keys

    def pipe_fn(self, data):
        """Default pipe function; subclasses override it to arrange the output format."""
        for sample in data:
            sample = list(sample)
            sample[0] = str(sample[0], encoding='utf-8')
            yield sample
--------------------------------------------------------------------------------