├── remote_dataloader ├── __init__.py ├── common.py ├── worker.py └── loader.py ├── .envrc ├── requirements.txt ├── README.md ├── example.py ├── LICENSE └── .gitignore /remote_dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.envrc: -------------------------------------------------------------------------------- 1 | source ~/.bash_profile 2 | pyenv activate remote-dl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyzmq==18.1.0 2 | argparse==1.4.0 3 | tqdm 4 | dill==0.3.1.1 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # remote-dataloader 2 | 3 | DataLoader processed in multiple remote computation machines for heavy data processing. 4 | 5 | ## Architecture 6 | 7 | ## Usage 8 | 9 | ### RemoteDataLoader 10 | 11 | ```python 12 | total_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 13 | loader = RemoteDataLoader(total_trainset, batch_size=32, timeout=5) 14 | ```` 15 | 16 | ### Example.py 17 | 18 | ```example.py``` contains a simple example to process cifar10 images using remote nodes. 19 | 20 | ```bash 21 | $ python example.py # run server(dataloader) 22 | $ python remote_dataloader/worker.py --server {master_ip}:1958 # run multiple workers 23 | $ python remote_dataloader/worker.py --server {master_ip}:1958 24 | $ ... 25 | ``` 26 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import torchvision as torchvision 2 | from torchvision.transforms import transforms 3 | from tqdm import tqdm 4 | 5 | from remote_dataloader.loader import RemoteDataLoader 6 | 7 | if __name__ == '__main__': 8 | transform_train = transforms.Compose([ 9 | transforms.RandomCrop(32, padding=4), 10 | transforms.RandomHorizontalFlip(), 11 | transforms.ToTensor(), 12 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 13 | ]) 14 | 15 | total_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) 16 | loader = RemoteDataLoader(total_trainset, batch_size=32, timeout=5) 17 | 18 | for epoch in range(5): 19 | for img, lb in tqdm(loader): 20 | pass 21 | -------------------------------------------------------------------------------- /remote_dataloader/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import dill as pickle 3 | import random 4 | import string 5 | 6 | import logging 7 | 8 | formatter = logging.Formatter('[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s') 9 | 10 | 11 | def get_logger(name, level=logging.DEBUG): 12 | logger = logging.getLogger(name) 13 | logger.handlers.clear() 14 | logger.setLevel(level) 15 | ch = logging.StreamHandler() 16 | ch.setLevel(level) 17 | ch.setFormatter(formatter) 18 | logger.addHandler(ch) 19 | return logger 20 | 21 | 22 | CODE_INIT = 'init' 23 | CODE_POLL = 'poll' 24 | 25 | 26 | def random_string(string_length=10): 27 | """ 28 | Generate a random string of fixed length 29 | ref : https://pynative.com/python-generate-random-string/ 30 | """ 31 | letters = string.ascii_lowercase 32 | return ''.join(random.choice(letters) for i in range(string_length)) 33 | 34 | 35 | def byte_message(myid, code, message): 36 | return pickle.dumps({ 37 | 'myid': myid, 38 | 'code': code, 39 | 'message': message 40 | }, protocol=pickle.HIGHEST_PROTOCOL) 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Ildoo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /remote_dataloader/worker.py: -------------------------------------------------------------------------------- 1 | import ast 2 | import json 3 | import dill as pickle 4 | import time 5 | 6 | import zmq 7 | import argparse 8 | import socket 9 | 10 | from remote_dataloader.common import random_string, byte_message, CODE_INIT, CODE_POLL 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser(description='Process some integers.') 15 | parser.add_argument('--server', type=str, help='server id:port. (eg. 0.0.0.0:1958)', required=True) 16 | args = parser.parse_args() 17 | 18 | # connect 19 | context = zmq.Context() 20 | myid = '%s-%s' % (socket.gethostname(), random_string()) 21 | print("Connecting to server=%s myid=%s" % (args.server, myid)) 22 | socket = context.socket(zmq.REQ) 23 | socket.connect("tcp://%s" % args.server) 24 | 25 | # request to initialization 26 | print("Request to initialize.") 27 | socket.send(byte_message(myid, CODE_INIT, ''), copy=False) 28 | fetcher = pickle.loads(socket.recv()) 29 | 30 | print("Initialized.") 31 | 32 | jobid = data = None 33 | while True: 34 | socket.send(byte_message(myid, CODE_POLL, (jobid, data)), copy=False) 35 | 36 | msg = pickle.loads(socket.recv()) 37 | jobid = msg['message'] # list converted to string(eg. "[id1, id2, ...]") 38 | if jobid is None: 39 | jobid = data = None 40 | time.sleep(1) 41 | else: 42 | data = fetcher.fetch(ast.literal_eval(jobid)) 43 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | data 106 | -------------------------------------------------------------------------------- /remote_dataloader/loader.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import dill as pickle 4 | import time 5 | 6 | import zmq 7 | from torch.utils.data.dataloader import _BaseDataLoaderIter, _DatasetKind, DataLoader 8 | from zmq.error import ZMQError 9 | 10 | from remote_dataloader.common import CODE_INIT, byte_message, CODE_POLL, get_logger 11 | 12 | _logger = get_logger('RemoteDataLoader', level=logging.WARNING) 13 | 14 | 15 | class _ZmqDataLoaderIter(_BaseDataLoaderIter): 16 | def __init__(self, loader): 17 | super(_ZmqDataLoaderIter, self).__init__(loader) 18 | self.dataset_fetcher = _DatasetKind.create_fetcher(self.dataset_kind, self.dataset, self.auto_collation, self.collate_fn, self.drop_last) 19 | self.pickled_fetcher = pickle.dumps(self.dataset_fetcher, protocol=pickle.HIGHEST_PROTOCOL) 20 | 21 | self.listen = loader.listen 22 | self.socket = loader.socket 23 | self.timeout = loader.timeout 24 | self.requested_queue = [] 25 | self.received_result = {} 26 | self.more_jobs = True 27 | self.return_cnt = 0 28 | 29 | def __iter__(self): 30 | self.requested_queue = [] 31 | self.received_result = {} 32 | self.more_jobs = True 33 | self.return_cnt = 0 34 | return super(_ZmqDataLoaderIter, self).__iter__(self) 35 | 36 | def __next__(self): 37 | while self.more_jobs or len(self.requested_queue) > 0: 38 | # Wait for next request from client 39 | try: 40 | message = self.socket.recv(zmq.NOBLOCK) 41 | except ZMQError: 42 | # check queue 43 | if len(self.requested_queue) > 0: 44 | jobid, request_t = self.requested_queue[0] 45 | if jobid in self.received_result: 46 | self.requested_queue.pop(0) 47 | data = self.received_result.pop(jobid) 48 | self.return_cnt += 1 49 | return data 50 | continue 51 | 52 | # process client's message 53 | msg = pickle.loads(message) 54 | if msg['code'] == CODE_INIT: 55 | cmd = self.pickled_fetcher 56 | elif msg['code'] == CODE_POLL: 57 | jobid, data = msg['message'] 58 | if jobid is not None: 59 | self.received_result[jobid] = data 60 | 61 | try: 62 | request_id, request_t = self.requested_queue[0] if len(self.requested_queue) > 0 else (-1, -1) 63 | if request_t > 0 and 0 < self.timeout < time.time() - request_t and request_id not in self.received_result: 64 | jobid, request_t = self.requested_queue[0] 65 | _logger.warning('task timeout, retry. socket=%s' % self.listen) 66 | self.requested_queue[0][1] = time.time() # override current time 67 | elif self.more_jobs: 68 | jobid = str(self._next_index()) 69 | self.requested_queue.append([jobid, time.time()]) 70 | else: 71 | jobid = None 72 | cmd = byte_message('server', CODE_POLL, jobid) 73 | except StopIteration: 74 | cmd = byte_message('server', CODE_POLL, None) 75 | self.more_jobs = False 76 | else: 77 | raise ValueError('cannot process message: %s' % msg) 78 | 79 | # Send reply back to client 80 | self.socket.send(cmd) 81 | raise StopIteration 82 | 83 | 84 | class RemoteDataLoader(DataLoader): 85 | def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, 86 | batch_sampler=None, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, 87 | listen='*:1958'): 88 | super(RemoteDataLoader, self).__init__(dataset, batch_size, shuffle=shuffle, sampler=sampler, 89 | batch_sampler=batch_sampler, collate_fn=collate_fn, pin_memory=pin_memory, 90 | drop_last=drop_last) 91 | _logger.info('RemoteDataLoader listen... %s' % listen) 92 | 93 | # socket 94 | self.timeout = timeout 95 | self.context = zmq.Context() 96 | self.socket = self.context.socket(zmq.REP) 97 | self.socket.bind("tcp://%s" % listen) 98 | self.listen = listen 99 | 100 | def __iter__(self): 101 | return _ZmqDataLoaderIter(self) 102 | --------------------------------------------------------------------------------