├── .flake8
├── LICENSE
├── README.md
├── data
│   ├── dataset.py
│   ├── sampler.py
│   ├── server.py
│   └── utils.py
└── images
    ├── 1.jpg
    ├── 2.jpg
    ├── 3.jpg
    ├── 4.jpg
    └── 5.jpg

--------------------------------------------------------------------------------
/.flake8:
--------------------------------------------------------------------------------
# This is an example .flake8 config, used when developing *Black* itself.
# Keep in sync with setup.cfg which is used for source packages.

[flake8]
# W606: reserved keywords
ignore = E203, E266, W503, F405, F403, W606, E731, C901, E701
max-line-length = 120
max-complexity = 18
select = B,C,E,F,W,T4,B9
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner.
      For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Loading Large-Scale Datasets in PyTorch

## [Zhihu article](https://zhuanlan.zhihu.com/p/357809861)

## Problem statement

For a reasonably small dataset we can simply load everything into memory and never worry about capacity. For large-scale datasets (tens of millions of samples and up), the usual loading approach no longer holds up: memory becomes one of the bottlenecks, so some targeted optimization is needed.

### A quick look at how PyTorch loads data

In general only the following parts matter. The `dataset` is our concrete dataset implementation, inheriting from `torch.utils.data.Dataset`; its `__getitem__` returns the meta information for a given index, and the `sampler` decides the order in which indices are requested. For a detailed treatment of PyTorch data loading, see the official documentation or this article: [Dataset](https://zhuanlan.zhihu.com/p/337850513)

![1.jpg](./images/1.jpg)

### Implementation

* Background

Below is the simplest possible dataset implementation. An 8-GPU job runs one process per GPU, so 8 copies of `metas` sit in memory at once (counting DataLoader workers, the real footprint is even larger). As a rough example, 50 million meta lines of about 100 bytes each come to roughly 5 GB per process, i.e. about 40 GB across 8 ranks. Once the meta information gets large enough, memory can overflow.

* Basic example

```python
class BaseDataset(Dataset):
    def __init__(self, meta_file):
        super(BaseDataset, self).__init__()
        self.metas = self.parse(meta_file)

    def parse(self, meta_file):
        metas = []
        with open(meta_file) as f:
            for line in f.readlines():
                metas.append(line.strip())
        return metas

    def __getitem__(self, idx):
        return self.metas[idx]

    def __len__(self):
        # required by the samplers used below
        return len(self.metas)
```

* meta_file format

```shell
#filename label (classification task)
image1.jpg 1
image2.jpg 0
image3.jpg 3
```

* Training flow

The flow of the training data looks like this:

```python
dataset = BaseDataset("/path/to/meta")
sampler = DistributedSampler(dataset)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    sampler=sampler
)
model = build_model()
for index, batch in enumerate(dataloader):
    image, label = batch
    output = model(image)
    loss = criterion(output, label)
    loss.backward()
```

#### Solution 1

Centralize the metas: keep a single copy in one central store, which makes very large metas affordable, and let each dataset fetch its meta information from that store.

![2.jpg](./images/2.jpg)

* Example

```python
class ServerDataset(BaseDataset):
    def __init__(self, meta_file, server_ip, server_port, timeout=1000):
        super(ServerDataset, self).__init__(meta_file)
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout
        self.meta_num = self.get_meta_num()

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_meta_num(self):
        meta_num = requests.get('http://{}:{}/get_len'.format(
            self.server_ip, self.server_port), timeout=self.timeout).json()
        return int(meta_num)

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_meta(self, idx):
        meta = requests.get('http://{}:{}/get/{}'.format(
            self.server_ip, self.server_port, idx), timeout=self.timeout).json()
        return meta
```
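As written, `ServerDataset` exposes `get_meta` but never wires it into `__getitem__`, and its `super().__init__(meta_file)` call still parses the whole meta file locally. A minimal sketch of a fully server-backed variant (`PureServerDataset` is a hypothetical name for illustration, not part of this repo):

```python
class PureServerDataset(ServerDataset):
    # Hypothetical variant: every lookup goes to the server, no local copy.
    def __init__(self, server_ip, server_port, timeout=1000):
        # Deliberately skip BaseDataset.__init__ so self.metas is never built.
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout
        self.meta_num = self.get_meta_num()

    def __len__(self):
        return self.meta_num

    def __getitem__(self, idx):
        # one HTTP round trip per sample, served by the /get/<idx> route
        return self.get_meta(idx)
```

With this wiring the per-rank memory stays flat, at the cost of one request per sample, which is exactly what makes the QPS ceiling discussed below relevant.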
* Training flow

**Start the server**

```shell
python server.py --meta_file="/path/to/meta" --port="10080"
```

**Start training**

```python
dataset = ServerDataset("/path/to/meta", server_ip="10.10.10.10", server_port="10080")
sampler = DistributedSampler(dataset)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    sampler=sampler
)
model = build_model()
for index, batch in enumerate(dataloader):
    image, label = batch
    output = model(image)
    loss = criterion(output, label)
    loss.backward()
```

* Drawback

This approach works reasonably well while the query rate stays below roughly 1k QPS, but once the total training batch size becomes very large it hits an obvious bottleneck: read throughput is capped by the centralized store, so the method has limited applicability.


#### Solution 2

* Background

Going back to first principles: during distributed training each GPU really consumes only len(metas) // world_size samples. In the usual setup, for convenience of access, a sampler splits the indices across GPUs while every GPU still keeps the full meta list, which is exactly what causes the memory problem above.

* Approach

Our approach combines per-rank loading with dataset splitting to load the data dynamically. As shown below, at initialization each GPU loads only the metas assigned to it, which cuts overall memory usage by a factor of world_size. To shrink memory further, the dataset can additionally be split into mini_epoch groups that are read one group at a time. Used together, the two techniques reduce memory usage by a factor of world_size * mini_epoch, which is generally enough. A small sketch of the per-rank index arithmetic is given after the splitting diagram below.

*Actual flow chart*

![3.jpg](./images/3.jpg)

* Splitting flow

```python
'''
Meta splitting with mini_epoch = 2, world_size = 8

        mini_epoch_idx = 0                        mini_epoch_idx = 1
---- ---- ---- ---- ---- ---- ---- ---- | ---- ---- ---- ---- ---- ---- ---- ----
rk0  rk1  rk2  rk3  rk4  rk5  rk6  rk7  | rk0  rk1  rk2  rk3  rk4  rk5  rk6  rk7

Each load covers only len(metas) // (world_size * mini_epoch) entries,
so the memory footprint can be tuned as needed.
'''
```
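To make the arithmetic concrete, here is a minimal sketch of the slice a rank ends up with. It mirrors `RankDataset.get_rank_indices` in `data/dataset.py` but leaves out the group shuffling; the name `rank_slice` is ours, for illustration only:

```python
import math

def rank_slice(dataset_size, rank, world_size, mini_epoch, mini_epoch_idx):
    # how many samples one rank handles within one mini-epoch
    rank_num_samples = math.ceil(dataset_size / (world_size * mini_epoch))
    total_size = rank_num_samples * world_size * mini_epoch
    indices = list(range(dataset_size))
    indices += indices[:total_size - len(indices)]  # pad by wrapping around
    # the block of indices belonging to this mini-epoch ...
    begin = mini_epoch_idx * rank_num_samples * world_size
    block = indices[begin:begin + rank_num_samples * world_size]
    # ... and the slice of that block belonging to this rank
    offset = rank_num_samples * rank
    return block[offset:offset + rank_num_samples]

# dataset_size=20, world_size=8, mini_epoch=2 -> 2 indices per rank per mini-epoch
print(rank_slice(20, rank=3, world_size=8, mini_epoch=2, mini_epoch_idx=0))  # [6, 7]
```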
* Note

With an ordinary dataloader, randomness is normally the sampler's job. Here the meta information is already loaded per rank, so each rank's metas have to be re-assigned at the start of every epoch, and to preserve randomness the shuffle has to happen during that assignment. Below is the local-reading variant.

* Local reading example

![4.jpg](./images/4.jpg)

* Training flow

```python
for epoch_idx in range(num_epochs):
    reload_cfg = {"mini_epoch": 1, "seed": epoch_idx, "mini_epoch_idx": 0, "group": 1}
    dataset = RankDataset("/path/to/meta", is_test=False, reload_cfg=reload_cfg)
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        sampler=sampler
    )
```


* Server reading example

Local reading is often limited by the file system's read throughput; when the file system is slow, the whole loading step drags. A centralized reading variant is therefore provided for environments where the network is fast.

![5.jpg](./images/5.jpg)

**Start the server**

```shell
python server.py --meta_file="/path/to/meta" --port="10080"
```

**Start training**

```python
for epoch_idx in range(num_epochs):
    reload_cfg = {"mini_epoch": 1, "seed": epoch_idx, "mini_epoch_idx": 0, "group": 1}
    dataset = RankServerDataset("/path/to/meta", server_ip="10.10.10.10",
                                server_port="10080", is_test=False, reload_cfg=reload_cfg)
    sampler = RandomSampler(dataset)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=32,
        shuffle=False,
        num_workers=4,
        sampler=sampler
    )
```


**Important**

When splitting into mini_epochs, the dataloader has to be rebuilt for every mini_epoch; a sketch of the resulting double loop follows.
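A minimal sketch of that rebuild loop, assembled from the snippets above (`num_epochs`, `build_model` and `criterion` are assumed to exist in the training script):

```python
mini_epoch = 2
model = build_model()
for epoch_idx in range(num_epochs):
    for mini_epoch_idx in range(mini_epoch):
        # rebuild dataset / sampler / dataloader for every mini-epoch, so only
        # this slice of the metas is resident in memory while it trains
        reload_cfg = {"mini_epoch": mini_epoch, "seed": epoch_idx,
                      "mini_epoch_idx": mini_epoch_idx, "group": 1}
        dataset = RankDataset("/path/to/meta", is_test=False, reload_cfg=reload_cfg)
        sampler = RandomSampler(dataset)
        dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=False,
                                num_workers=4, sampler=sampler)
        for index, batch in enumerate(dataloader):
            image, label = batch
            output = model(image)
            loss = criterion(output, label)
            loss.backward()
```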
* Sampler

This is the sampler used after the per-rank split. It no longer needs to be rank-aware, since the metas are already partitioned by rank.

```python
class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(torch.randperm(len(self.dataset)).tolist())

    def __len__(self):
        return len(self.dataset)
```

**Below is the ordinary distributed sampler:**

```python
class DistributedSampler(Sampler):
    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = get_world_size()
        if rank is None:
            rank = get_rank()

        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.world_size))
        self.total_size = self.num_samples * self.world_size

    def set_epoch(self, epoch):
        # call once per epoch so each epoch shuffles differently
        self.epoch = epoch

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
```
--------------------------------------------------------------------------------
/data/dataset.py:
--------------------------------------------------------------------------------
# Copyright 2021 Sensetime Yongqiang Yao, Tianzi Xiao
from torch.utils.data import Dataset
import requests
from retrying import retry
import numpy as np
from .utils import get_rank, get_world_size, barrier
import math
import json


class BaseDataset(Dataset):
    def __init__(self, meta_file):
        super(BaseDataset, self).__init__()
        self.metas = self.parse(meta_file)

    def parse(self, meta_file):
        metas = []
        with open(meta_file) as f:
            for line in f.readlines():
                metas.append(line.strip())
        return metas

    def __getitem__(self, idx):
        return self.metas[idx]

    def __len__(self):
        # required by the samplers and the DataLoader
        return len(self.metas)


class ServerDataset(BaseDataset):
    def __init__(self, meta_file, server_ip, server_port, timeout=1000):
        super(ServerDataset, self).__init__(meta_file)
        self.server_ip = server_ip
        self.server_port = server_port
        self.timeout = timeout
        self.meta_num = self.get_meta_num()

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_meta_num(self):
        meta_num = requests.get('http://{}:{}/get_len'.format(
            self.server_ip, self.server_port), timeout=self.timeout).json()
        return int(meta_num)

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_meta(self, idx):
        meta = requests.get('http://{}:{}/get/{}'.format(
            self.server_ip, self.server_port, idx), timeout=self.timeout).json()
        return meta


class RankDataset(BaseDataset):
    '''
    Actual flow:
    get rank and world_size -> get the dataset length -> generate random
    indices from that length -> assign indices to each rank -> build the
    metas from those indices
    '''

    def __init__(self, meta_file, is_test=False, reload_cfg=None):
        self.world_size = get_world_size()
        self.rank = get_rank()
        if reload_cfg is None:
            reload_cfg = {}
        self.mini_epoch = reload_cfg.get('mini_epoch', 1)
        self.seed = reload_cfg.get('seed', 0)
        self.mini_epoch_idx = reload_cfg.get('mini_epoch_idx', 0)
        self.group = reload_cfg.get('group', 1)
        self.is_test = is_test
        super(RankDataset, self).__init__(meta_file)

    def count_dataset_size(self, file_name):
        # count newlines chunk by chunk instead of loading the whole file
        from itertools import (takewhile, repeat)
        buffer = 1024 * 1024 * 8
        with open(file_name) as f:
            buf_gen = takewhile(lambda x: x, (f.read(buffer)
                                              for _ in repeat(None)))
            return sum(buf.count('\n') for buf in buf_gen)

    def get_rank_indices(self, meta_file):
        dataset_size = self.count_dataset_size(meta_file)
        if self.is_test:
            return list(range(0, dataset_size)), dataset_size
        indices = self.get_group_random_indices(dataset_size)
        rank_num_samples = int(
            math.ceil(dataset_size * 1.0 / (self.world_size * self.mini_epoch)))
        total_size = rank_num_samples * self.world_size * self.mini_epoch
        indices += indices[:(total_size - len(indices))]
        offset = rank_num_samples * self.rank
        mini_epoch_offset_begin = self.mini_epoch_idx * rank_num_samples * self.world_size
        mini_epoch_offset_end = (self.mini_epoch_idx + 1) * \
            rank_num_samples * self.world_size
        rank_indices = indices[mini_epoch_offset_begin:
                               mini_epoch_offset_end][offset:offset + rank_num_samples]
        assert len(rank_indices) == rank_num_samples
        return rank_indices, rank_num_samples

    def get_group_random_indices(self, dataset_size):
        '''
        Generate the random indices group by group to avoid the memory blowup
        of shuffling everything at once (500M indices would use a lot of RAM):
        1. split into `group` groups, each seeded differently
        2. shuffle the groups themselves, to get closer to a global shuffle
        '''

        indices = []
        temp_indices = []
        mini_dataset_size = int(math.ceil(dataset_size * 1.0 / self.group))
        group = math.ceil(dataset_size * 1.0 / mini_dataset_size)
        last_size = dataset_size - mini_dataset_size * (group - 1)
        for i in range(group):
            if i <= group - 2:
                cur_size = mini_dataset_size
            else:
                cur_size = last_size
            np.random.seed(self.seed + i + 10000)
            _indices = np.random.permutation(cur_size).astype(np.int32)
            _indices += i * mini_dataset_size
            temp_indices.append(_indices.tolist())
        np.random.seed(self.seed + 10000)
        for i in np.random.permutation(group):
            indices.extend(temp_indices[i])
        return indices

    def read_from_buf(self, file_obj, line_sign):
        '''
        Read the file chunk by chunk instead of all at once to keep memory
        usage low.
        '''
        buf = ""
        buf_size = 1024 * 1024 * 8
        while True:
            chunk = file_obj.read(buf_size)
            if not chunk:
                break
            buf += chunk
            lines = buf.split(line_sign)
            for line in lines[:-1]:
                yield line
            # keep the trailing partial line for the next chunk
            buf = lines[-1]
        if buf:
            yield buf

    def _read(self, meta_file):
        data_lst = []
        rank_indices, rank_num_samples = self.get_rank_indices(meta_file)
        rank_indices = set(rank_indices)
        idx = 0
        with open(meta_file) as f:
            for line in self.read_from_buf(f, "\n"):
                if idx in rank_indices:
                    # each line is "filename label"; split from the right so
                    # multi-character labels survive
                    filename, label = line.rstrip().rsplit(' ', 1)
                    data_lst.append([filename, label, idx])
                idx += 1
        if len(rank_indices) != rank_num_samples:
            data_lst += data_lst[:(rank_num_samples - len(rank_indices))]
        return data_lst

    def parse(self, meta_file):
        '''
        parse meta_file
        return: metas
        '''
        metas = self._read(meta_file)
        return metas


class RankServerDataset(BaseDataset):
    '''
    Centralized (server) reading flow:
    get rank and world_size -> get reload_cfg (needed to rebuild the
    dataloader) -> get the dataset length -> ask the server to generate
    indices -> fetch this rank's dataset size from the server -> fetch the
    metas from the server in groups
    '''

    def __init__(self,
                 meta_file,
                 server_ip,
                 server_port,
                 is_test=False,
                 reload_cfg=None,
                 timeout=1000):
        self.world_size = get_world_size()
        self.rank = get_rank()
        self.timeout = timeout
        if reload_cfg is None:
            reload_cfg = {}
        self.mini_epoch = reload_cfg.get('mini_epoch', 1)
        self.seed = reload_cfg.get('seed', 0)
        self.mini_epoch_idx = reload_cfg.get('mini_epoch_idx', 0)
        self.group = reload_cfg.get('group', 10)
        self.cache_size = reload_cfg.get('cache_size', 100000)
        self.is_test = is_test
        self.server_ip = server_ip
        self.server_port = server_port
        super(RankServerDataset, self).__init__(meta_file)

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def prepare_rank_indices(self):
        key_info = {
            "mini_epoch": self.mini_epoch,
            "mini_epoch_idx": self.mini_epoch_idx,
            "world_size": self.world_size,
            "seed": self.seed,
            "group": self.group
        }
        status = requests.get('http://{}:{}/set_rank_indices/{}'.format(
            self.server_ip, self.server_port, json.dumps(key_info)), timeout=self.timeout).json()
        assert status == 1

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_rank_metas(self, cur_idx):
        key_info = {
            "begin": cur_idx[0],
            "end": cur_idx[1],
            "rank": self.rank,
            "mini_epoch_idx": self.mini_epoch_idx,
            "world_size": self.world_size,
        }
        metas = requests.get('http://{}:{}/get_rank_metas/{}'.format(
            self.server_ip, self.server_port, json.dumps(key_info)), timeout=self.timeout).json()
        return metas

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_dataset_size(self):
        size = requests.get(
            'http://{}:{}/get_len'.format(self.server_ip, self.server_port), timeout=self.timeout).json()
        return int(size)

    @retry(stop_max_delay=10, stop_max_attempt_number=10)
    def get_rank_size(self):
        size = requests.get('http://{}:{}/get_rank_size/{}'.format(
            self.server_ip, self.server_port, self.rank), timeout=self.timeout).json()
        return int(size)

    def _read(self, meta_file):
        data_lst = []
        if self.rank == 0:
            self.prepare_rank_indices()
        barrier()
        rank_size = self.get_rank_size()
        idx = 0
        # fetch this rank's metas from the server in cache_size chunks
        num = (rank_size + self.cache_size - 1) // self.cache_size
        for i in range(num):
            cur_idx = [i * self.cache_size, (i + 1) * self.cache_size]
            lines = self.get_rank_metas(cur_idx)
            for line in lines:
                filename, label = line.rstrip().rsplit(' ', 1)
                data_lst.append([filename, label, idx])
                idx += 1
        return data_lst

    def parse(self, meta_file):
        '''
        parse meta_file
        return: metas
        '''
        metas = self._read(meta_file)
        return metas
--------------------------------------------------------------------------------
/data/sampler.py:
--------------------------------------------------------------------------------
# Copyright 2021 Sensetime Yongqiang Yao, Tianzi Xiao
from .utils import get_rank, get_world_size
import torch
import math
from torch.utils.data.sampler import Sampler


class DistributedSampler(Sampler):
    def __init__(self, dataset, world_size=None, rank=None):
        if world_size is None:
            world_size = get_world_size()
        if rank is None:
            rank = get_rank()

        self.dataset = dataset
        self.world_size = world_size
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(
            math.ceil(len(self.dataset) * 1.0 / self.world_size))
        self.total_size = self.num_samples * self.world_size

    def set_epoch(self, epoch):
        # call once per epoch so each epoch shuffles differently
        self.epoch = epoch

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        offset = self.num_samples * self.rank
        indices = indices[offset:offset + self.num_samples]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples


class RandomSampler(Sampler):
    r"""Samples elements randomly, without replacement.

    Arguments:
        data_source (Dataset): dataset to sample from
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        return iter(torch.randperm(len(self.dataset)).tolist())

    def __len__(self):
        return len(self.dataset)
--------------------------------------------------------------------------------
/data/server.py:
--------------------------------------------------------------------------------
# Copyright 2021 Sensetime Yongqiang Yao, Tianzi Xiao
from __future__ import division

# Standard Library
import argparse
import json  # noqa F402
import os
import random  # noqa F402
import numpy as np

# Import from third library
from flask import Flask  # noqa F402
import math


# from gevent import monkey
# monkey.patch_all()

parser = argparse.ArgumentParser(description='Server')

parser.add_argument(
    '--meta_file',
    dest='meta_file',
    default='',
    help='meta file for training')
parser.add_argument(
    '--port',
    default=28889,
    type=int,
    help='server port')


app = Flask(__name__)

data_lst = []
indices = []
rank_num_samples = 0
is_init = False


@app.route('/get/<key>')
def get(key):
    meta = data_lst[int(key)]
    return json.dumps(meta)


@app.route('/set_rank_indices/<key>')
def set_rank_indices(key):
    decode_key = json.loads(key)
    seed = decode_key.get('seed', 0)
    group = decode_key.get('group', 10)
    world_size = decode_key.get("world_size", 8)
    mini_epoch = decode_key.get('mini_epoch', 1)
    mini_epoch_idx = decode_key.get('mini_epoch_idx', 0)
    global indices
    if len(indices) > 0 and mini_epoch_idx > 0:
        return json.dumps(1)
    get_group_random_indices(len(data_lst), group=group, seed=seed)
    prepare_all_rank_indices(world_size, mini_epoch, mini_epoch_idx)
    global is_init
    is_init = True
    return json.dumps(1)


@app.route('/get_rank_size/<key>')
def get_rank_size(key):
    # key (the rank) is unused: every rank gets the same number of samples
    global rank_num_samples
    return json.dumps(rank_num_samples)
@app.route('/get_rank_metas/<key>')
def get_rank_metas(key):
    global is_init, indices
    decode_key = json.loads(key)
    rank = decode_key.get('rank', 0)
    begin = decode_key.get('begin', 0)
    end = decode_key.get('end', 0)
    mini_epoch_idx = decode_key.get('mini_epoch_idx', 0)
    world_size = decode_key.get("world_size", 8)
    if not is_init:
        mini_epoch = decode_key.get('mini_epoch', 1)
        prepare_all_rank_indices(world_size, mini_epoch, mini_epoch_idx)
    metas = []
    begin_idx, end_idx = get_cur_index(
        rank, begin, end, mini_epoch_idx, world_size)
    for i in range(begin_idx, end_idx):
        metas.append(data_lst[indices[i]])
    return json.dumps(metas)


@app.route('/get_len')
def get_len():
    return json.dumps(len(data_lst))


def get_cur_index(rank, begin, end, mini_epoch_idx, world_size):
    global rank_num_samples
    mini_epoch_index = mini_epoch_idx * rank_num_samples * world_size
    rank_index = rank_num_samples * rank
    cur_begin_idx = max(0, begin + mini_epoch_index + rank_index)
    cur_end_idx = min(end + mini_epoch_index + rank_index,
                      mini_epoch_index + rank_index + rank_num_samples)
    return cur_begin_idx, cur_end_idx


def get_group_random_indices(dataset_size, group=10, seed=1):
    '''
    1. random dataset_size / group for each sub group
    2. random group
    3. extend final indices
    '''
    mini_dataset_size = int(math.ceil(dataset_size * 1.0 / group))
    group = math.ceil(dataset_size * 1.0 / mini_dataset_size)
    last_size = dataset_size - mini_dataset_size * (group - 1)
    global indices
    indices = []
    temp_indices = []
    for i in range(group):
        if i <= group - 2:
            cur_size = mini_dataset_size
        else:
            cur_size = last_size
        np.random.seed(i + int(seed) + 10000)
        _indices = np.random.permutation(cur_size).astype(np.int32)
        _indices += i * mini_dataset_size
        temp_indices.append(_indices.tolist())
    np.random.seed(int(seed) + 10000)
    for i in np.random.permutation(group):
        indices.extend(temp_indices[i])
    return indices


def prepare_all_rank_indices(world_size, mini_epoch, mini_epoch_idx):
    global indices, rank_num_samples
    dataset_size = len(data_lst)
    rank_num_samples = int(
        math.ceil(dataset_size * 1.0 / (world_size * mini_epoch)))
    total_size = rank_num_samples * world_size * mini_epoch
    # pad so the indices divide evenly across ranks and mini-epochs
    indices = indices + indices[:(total_size - len(indices))]


def get_meta(meta_file):
    with open(meta_file) as f:
        for line in f.readlines():
            data_lst.append(line.strip())


if __name__ == '__main__':
    args = parser.parse_args()
    get_meta(args.meta_file)
    # print the network interfaces so users can find the server IP
    os.system('ifconfig')
    port = args.port
    app.run('0.0.0.0', port, threaded=True)
--------------------------------------------------------------------------------
/data/utils.py:
--------------------------------------------------------------------------------
# Copyright 2021 Sensetime Yongqiang Yao, Tianzi Xiao
import torch
import torch.distributed as dist


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def barrier():
    if get_world_size() > 1:
        # a small all_reduce forces every rank to synchronize
        x = torch.cuda.IntTensor([1])
        dist.all_reduce(x)
        x.cpu()
--------------------------------------------------------------------------------
/images/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/rank_dataset/5ef08382e14d255174892619ce686da5e7ccbd68/images/1.jpg
--------------------------------------------------------------------------------
/images/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/rank_dataset/5ef08382e14d255174892619ce686da5e7ccbd68/images/2.jpg
--------------------------------------------------------------------------------
/images/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/rank_dataset/5ef08382e14d255174892619ce686da5e7ccbd68/images/3.jpg
--------------------------------------------------------------------------------
/images/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/rank_dataset/5ef08382e14d255174892619ce686da5e7ccbd68/images/4.jpg
--------------------------------------------------------------------------------
/images/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ModelTC/rank_dataset/5ef08382e14d255174892619ce686da5e7ccbd68/images/5.jpg
--------------------------------------------------------------------------------