├── rtutils
│   ├── __init__.py
│   ├── misc.py
│   ├── drive2.py
│   ├── coco.py
│   ├── drive.py
│   ├── DataLoader2.py
│   ├── sampler.py
│   ├── DataLoader.py
│   └── pl_patch.py
├── .gitignore
├── jupyter_connect_cmd
├── requirements.txt
├── rtutils_cli
│   └── gdrive_wrapper.py
├── ipy
├── setup.py
├── LICENSE
├── .github
│   └── workflows
│       ├── python-publish.yml
│       └── python-package.yml
├── README.md
├── tests
│   └── test_sampler.py
└── examples
    └── pl_deterministic_sampler.py

/rtutils/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*pyc
__pycache__
build
dist
*egg*
--------------------------------------------------------------------------------
/jupyter_connect_cmd:
--------------------------------------------------------------------------------
ssh -N -L 8880:128.135.8.27:18661 rluo@slurm.ttic.edu
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
oauth2client
PyDrive
torch==1.8.1
mmcv

--------------------------------------------------------------------------------
/rtutils/misc.py:
--------------------------------------------------------------------------------
import os
import pylab

def save_current_fig(fn):
    # Note: the figure directory is hard-coded for the author's machines.
    save_path = os.path.join('/home-nfs/rluo/rluo/figures', fn)
    fig = pylab.gcf()
    if not os.path.exists(os.path.dirname(save_path)):
        os.makedirs(os.path.dirname(save_path))
    fig.savefig(save_path)
--------------------------------------------------------------------------------
/rtutils_cli/gdrive_wrapper.py:
--------------------------------------------------------------------------------
import sys
import os
import re

def simplify(url):
    # Split on slashes/backslashes and keep the 33-character segment,
    # which is the Google Drive file id.
    url = re.split(r"[/\\]", url)
    return [_ for _ in url if len(_) == 33][0]


def main():
    args = sys.argv[1:]
    args = [simplify(x) if 'http' in x else x for x in args]
    # echo the command, then run it
    print('gdrive ' + ' '.join(args))
    os.system('gdrive ' + ' '.join(args))
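
# A minimal, runnable sanity check for simplify(). The id below is synthetic
# (a made-up 33-character string), purely to illustrate the extraction.
if __name__ == '__main__':
    fake_id = '1' + 'A' * 32  # Drive file ids are typically 33 characters long
    url = 'https://drive.google.com/file/d/%s/view?usp=sharing' % fake_id
    assert simplify(url) == fake_id
    print('simplify OK:', simplify(url))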
--------------------------------------------------------------------------------
/ipy:
--------------------------------------------------------------------------------
#!/bin/bash

unset XDG_RUNTIME_DIR
ip=$(/sbin/ip route get 8.8.8.8 | awk '{print $NF;exit}')
port=$((10000 + $RANDOM % 20000))

echo Use the following ssh command in your laptop shell to tunnel the notebook server:

if [ -z "$1" ]
then
    toport=8880
else
    toport=$1
fi

CMD="ssh -N -L $toport:$ip:$port $USER@$HOSTNAME"
echo "$CMD"
echo "$CMD" > $HOME/utils/jupyter_connect_cmd

#jupyter-lab --no-browser --ip=$ip --port=$port --log-level='ERROR'
jupyter-notebook --no-browser --ip=$ip --port=$port --log-level='ERROR'
--------------------------------------------------------------------------------
/rtutils/drive2.py:
--------------------------------------------------------------------------------
from __future__ import print_function

from googleapiclient import discovery
from httplib2 import Http
from oauth2client import file, client
from oauth2client.tools import argparser, run_flow
import os

SCOPES = 'https://www.googleapis.com/auth/drive.metadata.readonly'
store = file.Storage(os.path.expanduser("~") + '/.gdrive_credential.json')
creds = store.get()
if not creds or creds.invalid:
    flow = client.flow_from_clientsecrets(os.path.expanduser("~") + '/.rtutils_credentials.json', SCOPES)
    args = argparser.parse_args()
    args.noauth_local_webserver = True
    creds = run_flow(flow, store, args)
DRIVE = discovery.build('drive', 'v3', http=creds.authorize(Http()))

files = DRIVE.files().list().execute().get('files', [])
for f in files:
    print(f['name'], f['mimeType'])
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

#with open("README.md", "r") as fh:
#    long_description = fh.read()

setuptools.setup(
    name="rtutils",
    version="0.3",
    author="Ruotian Luo",
    author_email="rluo@ttic.edu",
    entry_points={
        'console_scripts': [
            'rtdrive=rtutils.drive:main',
            'gd=rtutils_cli.gdrive_wrapper:main'
        ]
    },
    # description="A small example package",
    # long_description=long_description,
    # long_description_content_type="text/markdown",
    # url="https://github.com/pypa/sampleproject",
    packages=setuptools.find_packages(),
    # classifiers=[
    #     "Programming Language :: Python :: 3",
    #     "License :: OSI Approved :: MIT License",
    #     "Operating System :: OS Independent",
    # ],
    # python_requires='>=3.6',
)
--------------------------------------------------------------------------------
/rtutils/coco.py:
--------------------------------------------------------------------------------
import os
import mmcv

# Note: the COCO paths below are hard-coded for the author's machines.

def get_coco_filename(cocoid):
    cocoid = int(cocoid)
    if os.path.isfile('/share/data/vdata/coco/images/train2017/%012d.jpg' % (cocoid)):
        return '/share/data/vdata/coco/images/train2017/%012d.jpg' % (cocoid)
    else:
        return '/share/data/vdata/coco/images/val2017/%012d.jpg' % (cocoid)

coco_image_infos = None

def get_coco_image_infos():
    coco_image_infos = mmcv.load('/share/data/vdata/coco/annotations/captions_train2017.json')['images'] + mmcv.load('/share/data/vdata/coco/annotations/captions_val2017.json')['images']
    coco_image_infos = {_['id']: _ for _ in coco_image_infos}
    return coco_image_infos


def get_coco_url(cocoid):
    global coco_image_infos
    if coco_image_infos is None:
        coco_image_infos = get_coco_image_infos()
    cocoid = int(cocoid)
    return coco_image_infos[cocoid]['coco_url']
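
# Usage sketch (works only where the paths above resolve; the id below is an
# arbitrary COCO image id used for illustration):
#
#   from rtutils.coco import get_coco_filename, get_coco_url
#   print(get_coco_filename(391895))  # local path under train2017/ or val2017/
#   print(get_coco_url(391895))       # the 'coco_url' field from the annotations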
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2018 The Python Packaging Authority

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
--------------------------------------------------------------------------------
/.github/workflows/python-package.yml:
--------------------------------------------------------------------------------
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: Python package

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  build:

    runs-on: ubuntu-latest
    strategy:
      fail-fast: false
      matrix:
        python-version: [3.7, 3.8, 3.9]

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python ${{ matrix.python-version }}
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python-version }}
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        python -m pip install flake8 pytest
        if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
        python -m pip install ./
    - name: Lint with flake8
      run: |
        # stop the build if there are Python syntax errors or undefined names
        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
    - name: Test with pytest
      run: |
        pytest
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RT's utility functions

To install: `pip install rtutils`

Note: some functions are designed to work only for me and my machines, but I make the generally useful ones as general as possible.

## Stateless resume from a checkpoint saved in the middle of an epoch

The change is essentially one line for the import and one line for the use.

Import the patch module implemented in this repo:
```
import rtutils.pl_patch
```

Use the imported function to modify the pytorch_lightning trainer:
```
trainer = pl.Trainer(...)
# after defining the trainer
rtutils.pl_patch.patch_everything(trainer)
```

Also, see examples/pl_deterministic_sampler.py for the other changes that need to be made.

Some requirements:
- pytorch_lightning version >= 1.4.0 for rtutils >= 0.3; pytorch_lightning version >= 1.3 and < 1.4 for rtutils < 0.3
- Using ModelCheckpoint
- Using ddp training
- replace_sampler_ddp is True (the default); and you are not manually setting any train loader sampler.
- Fixed batch_size and fixed dataset. (It would still run if the batch_size or dataset changes, but it will not be a "correct" resume.)
- It won't work with your existing checkpoints, because existing checkpoints don't have total_batch_idx saved.

Expected behavior:
- If you ctrl-C the training, the trainer will save a "last" checkpoint at the current iteration (you can try this in examples/pl_deterministic_sampler.py now); or you can write your own callback that saves a checkpoint in the middle of an epoch (e.g. every k iterations; see the sketch at the end of this section).
- Next time you resume from the "last" checkpoint, training will resume from the exact iteration.

What is resumed, and to what level:
- Only dataset indices are resumed. This guarantees that you will not see the same data twice in one epoch.
- RNG states are not resumed. So the dropout or the random data augmentation in the workers will differ between a resumed run and an uninterrupted one. (If you don't understand, it's fine. I don't really think it matters.)
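
If you want periodic mid-epoch checkpoints instead of relying on ctrl-C, there is a helper in `rtutils.pl_patch` for that. A minimal sketch (the interval 100 is just an example value):

```
import rtutils.pl_patch

trainer = pl.Trainer(...)
rtutils.pl_patch.patch_everything(trainer)
# additionally save a "last" checkpoint every 100 training iterations
rtutils.pl_patch.patch_on_save_checkpoint_every(trainer, checkpoint_every=100)
```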

## To use drive

Visit this [link](https://developers.google.com/drive/api/v3/quickstart/python).

Follow it to get the credentials.json.

Copy that file to your home directory and rename it to `.rtutils_credentials.json`.
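
After the first-time authentication, the console scripts installed by this package can be used roughly like this (a sketch; the angle-bracket ids are placeholders for real Drive ids):

```
# upload a file or a folder (optionally into a remote folder)
rtdrive upload ./results <remote_folder_id>

# download a file by id into a directory
rtdrive download <file_id> ./downloads
```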
--------------------------------------------------------------------------------
/tests/test_sampler.py:
--------------------------------------------------------------------------------
from rtutils.sampler import DeterministicDistributedSampler, InfiniteDistributedSampler
from torch.utils.data import DataLoader
import pytest


@pytest.mark.parametrize("seed",
    [0, 100]
)
@pytest.mark.parametrize("global_batch_size, batch_size",
    [[0, 25], [100, 0]]
)
@pytest.mark.parametrize("num_replicas",
    [4]
)
@pytest.mark.parametrize("rank",
    [0, 1, 2, 3]
)
@pytest.mark.parametrize("dataset_size",
    [1000, 1050]
)
def test_deterministic_sampler(seed, global_batch_size, batch_size, num_replicas, rank, dataset_size):
    dataset = list(range(dataset_size))
    sampler = DeterministicDistributedSampler(dataset, num_replicas, rank, shuffle=True, global_batch_size=global_batch_size, batch_size=batch_size, seed=seed)
    dl = DataLoader(dataset, batch_size=batch_size or global_batch_size // num_replicas, sampler=sampler, collate_fn=lambda x: x)

    tmp = []
    for epoch in range(10):
        for x in dl:
            tmp.append(x)

    # resume_froms = [5, 15, 30, 44, 50]
    resume_froms = [0, 10]

    for resume_from in resume_froms:
        sampler = DeterministicDistributedSampler(dataset, num_replicas, rank, shuffle=True, global_batch_size=global_batch_size, batch_size=batch_size, seed=seed)
        sampler.set_epoch(resume_from)
        print(resume_from, len(sampler))
        dl = DataLoader(dataset, batch_size=batch_size or global_batch_size // num_replicas, sampler=sampler, collate_fn=lambda x: x)
        tmp1 = []
        for epoch in range(10):
            cnt = 0
            for x in dl:
                cnt += 1
                tmp1.append(x)
            if epoch == 0:
                print(cnt)
        for x, y in zip(tmp[resume_from:], tmp1):
            assert x == y


def test_infinite():
    dataset = list(range(100))
    sampler = InfiniteDistributedSampler(dataset, 4, 0, shuffle=False)
    dl = DataLoader(dataset, batch_size=22, sampler=sampler, collate_fn=lambda x: x)
    for x in dl:
        print(x)
        break


@pytest.mark.parametrize("global_batch_size, batch_size",
    [[0, 4], [16, 0]]
)
def test_infinite_deterministic(global_batch_size, batch_size):
    dataset = list(range(1000))
    sampler = InfiniteDistributedSampler(dataset, 4, 0, shuffle=True, deterministic=True, global_batch_size=global_batch_size, batch_size=batch_size)
    dl = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda x: x)

    tmp = []
    for x, _ in zip(dl, range(10)):
        print(x)
        tmp.append(x)

    print('-' * 100)
    sampler = InfiniteDistributedSampler(dataset, 4, 0, shuffle=True, deterministic=True, global_batch_size=global_batch_size, batch_size=batch_size)
    sampler.set_epoch(2)
    dl = DataLoader(dataset, batch_size=4, sampler=sampler, collate_fn=lambda x: x)
    for x, i in zip(dl, range(8)):
        print(x)
        assert x == tmp[i + 2]
--------------------------------------------------------------------------------
/examples/pl_deterministic_sampler.py:
--------------------------------------------------------------------------------
# Example of pl_deterministic_sampler.

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer

# Note:
# "change" marks unnecessary changes, just for demonstration in this file.
# "CHANGE" marks necessary changes.


# This file is modified from https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pl_examples/bug_report_model.py

################################################################
# To use my deterministic sampler and pytorch lightning integration,
# there are a few things you need to know.
# First, you have to use my fork of pytorch-lightning, which includes some changes.
# There are no breaking changes, and it should work like the original version in other cases.
# Second, this sampler assumes a fixed batch_size and a fixed number of iterations in one epoch.
################################################################

import rtutils.pl_patch  # CHANGE
import pytorch_lightning as pl  # change

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        # change
        # sleep to make training slower so you have time to interrupt; do not follow this in practice.
        import time; time.sleep(1)
        # change end
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        # change
        # sleep to make training slower so you have time to interrupt; do not follow this in practice.
        import time; time.sleep(0.5)
        # change end
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


def run():
    train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    val_data = DataLoader(RandomDataset(32, 64), batch_size=2)
    test_data = DataLoader(RandomDataset(32, 64), batch_size=2)

    # change
    # You have to have some ModelCheckpoint to save checkpoints. (If you don't save, what is there to resume, right?)
    # Otherwise, nothing to change.
    callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=os.getcwd(),
            save_last=True,
        ),
    ]
    # change end

    model = BoringModel()
    trainer = Trainer(
        callbacks=callbacks,
        default_root_dir=os.getcwd(),
        accelerator='ddp',  # CHANGE; the deterministic sampler only works for ddp training.
        gpus=1,  # change
        resume_from_checkpoint='last.ckpt' if os.path.exists('last.ckpt') else None,  # CHANGE; resume from a mid-epoch checkpoint
        # limit_train_batches=1,  # change
        # limit_val_batches=1,  # change
        num_sanity_val_steps=0,
        max_epochs=2,
        weights_summary=None,
    )
    rtutils.pl_patch.patch_everything(trainer)  # CHANGE; patch multiple things. If you want to know more, check out the source code for this function.
    trainer.fit(model, train_dataloader=train_data, val_dataloaders=val_data)
    # trainer.test(model, test_dataloaders=test_data)  # change


if __name__ == '__main__':
    run()
--------------------------------------------------------------------------------
/rtutils/drive.py:
--------------------------------------------------------------------------------
import argparse

from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
import os


class MyDrive:
    def __init__(self):
        gauth = GoogleAuth()
        cred_file = os.path.expanduser("~") + '/.gdrive_credential.json'
        if os.path.exists(cred_file):
            gauth.LoadCredentialsFile(cred_file)
        else:
            gauth.LoadClientConfigFile(os.path.expanduser("~") + '/.rtutils_credentials.json')
            gauth.CommandLineAuth()
            gauth.SaveCredentialsFile(cred_file)

        self.drive = GoogleDrive(gauth)

    def create_folder(self, folder, remote_folder_id):
        newFolder = self.drive.CreateFile({'title': folder, "parents": [{"kind": "drive#fileLink", "id": remote_folder_id}],
                                           "mimeType": "application/vnd.google-apps.folder"})
        newFolder.Upload()
        return newFolder['id']

    def upload_folder(self, root, folder, remote_folder_id):
        # Create folder
        newFolder_id = self.create_folder(folder, remote_folder_id)
        new_root = os.path.join(root, folder)
        for item in os.listdir(new_root):
            if os.path.isfile(os.path.join(new_root, item)):
                # Upload files in this folder
                self.upload_file(new_root, item, newFolder_id)
            else:
                # Upload folders in this folder
                self.upload_folder(new_root, item, newFolder_id)
        return newFolder_id

    def upload_file(self, root, filename, remote_folder_id=None):
        if remote_folder_id is None:
            newFile = self.drive.CreateFile({"title": filename})
        else:
            newFile = self.drive.CreateFile({"title": filename,
                                             "parents": [{"kind": "drive#fileLink", "id": remote_folder_id}]})
        newFile.SetContentFile(os.path.join(root, filename))
        newFile.Upload()

    def download_file(self, download_id, download_dir):
        newFile = self.drive.CreateFile({'id': download_id})
        newFile.FetchMetadata(fetch_all=True)
        newFile.GetContentFile(os.path.join(download_dir, newFile['title']))
        return newFile

def main():
    drive = MyDrive()
    parser = argparse.ArgumentParser(description='Drive')
    subparsers = parser.add_subparsers(dest="command", required=True)
    parser_upload = subparsers.add_parser('upload', help='upload file or folder')
    parser_download = subparsers.add_parser('download', help='download file')

    parser_upload.add_argument('upload_path', type=str, help='')
    parser_upload.add_argument('remote_folder_id', type=str, nargs='?', default=None, help='')

    parser_download.add_argument('download_id', type=str, help='')
    parser_download.add_argument('download_dir', type=str, nargs='?', default='./', help='')

    args = parser.parse_args()
    # print(args)

    if args.command == 'upload':
        path = args.upload_path
        root = os.path.dirname(path)
        name = os.path.basename(path)
        if os.path.isdir(path):
            drive.upload_folder(root, name, args.remote_folder_id)
        else:
            drive.upload_file(root, name, args.remote_folder_id)
        print('Upload done')
    elif args.command == 'download':
        file = drive.download_file(args.download_id, args.download_dir)
        print('Download finished %s' % os.path.join(args.download_dir, file["title"]))

if __name__ == '__main__':
    main()
    # drive = MyDrive()
    # drive.upload_folder('../content/', 'weekly_2020-01-10', '1DagXOiUK-oqBQ7lN734X1ZdxAiRnsFC1')
--------------------------------------------------------------------------------
/rtutils/DataLoader2.py:
--------------------------------------------------------------------------------
import torch
import itertools
import time
import torch.distributed as dist

class ResumableDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._need_to_resume = False
        self._rng_state = None
        self._iter = None

    def get_sampler_iter(self):
        # Build a cheap single-process iterator just to get at the underlying
        # _sampler_iter, then restore num_workers.
        tmp_num_workers = self.num_workers
        self.num_workers = 0
        dataiter = super().__iter__()
        __sampler_iter, _sampler_iter = itertools.tee(dataiter._sampler_iter)
        next(__sampler_iter)  # Make sure the _sampler_iter is instantiated
        self.num_workers = tmp_num_workers
        return _sampler_iter

    def __iter__(self):

        if self._need_to_resume:
            # Run through the saved iterations, putting the sampler in the state it was halted in
            self._sampler_iter, _sampler_iter = itertools.tee(iter(self._saved_sampler))
            for i in range(self._saved_iter):
                next(_sampler_iter)
        else:
            _sampler_iter = self.get_sampler_iter()
            self._sampler_iter, _sampler_iter = itertools.tee(_sampler_iter)  # Save the full iter

        self.get_sampler_to_save()  # for state_dict

        dataiter = super().__iter__()
        dataiter._sampler_iter = _sampler_iter
        if self.num_workers > 0:
            # Wait for prefetching
            time.sleep(5)
            # drain the prefetched batches
            for i in range(dataiter._send_idx - dataiter._rcvd_idx):
                next(dataiter)

        self._iter = dataiter  # Save it in case
        self._need_to_resume = False  # Only resume at the first iter(loader) after load_state_dict
        return self._iter

    def get_sampler_to_save(self):
        # All the indices:
        self._sampler_to_save = []
        for idxs in self._sampler_iter:
            if dist.is_available() and type(self.sampler) is torch.utils.data.distributed.DistributedSampler:
                idxs_list = [torch.zeros(len(idxs)).long().to(torch.cuda.current_device()) for k in range(dist.get_world_size())]
                dist.all_gather(idxs_list, torch.tensor(idxs).long().to(torch.cuda.current_device()), async_op=False)  # can also use detectron2.utils.comm.all_gather
                self._sampler_to_save.append([_.tolist() for _ in idxs_list])
            else:
                self._sampler_to_save.append(idxs)

    def state_dict(self):
        if self._iter is None:
            return None
        # The number of batches left in this epoch
        num_rest = 0
        self._iter._sampler_iter, _sampler_iter = itertools.tee(self._iter._sampler_iter)
        for idxs in _sampler_iter:
            num_rest += 1

        # the number of batches prefetched by the loader
        if self.num_workers > 0:
            num_prefetched = self._iter._send_idx - self._iter._rcvd_idx
        else:
            num_prefetched = 0
        self._saved_iter = len(self) - num_rest - num_prefetched

        return {'saved_iter': self._saved_iter,
                'sampler': self._sampler_to_save}

    def load_state_dict(self, state_dict):
        if state_dict is None:
            return
        self._need_to_resume = True
        self._saved_iter = state_dict['saved_iter']
        self._saved_sampler = state_dict['sampler']
        if dist.is_available() and type(self.sampler) is torch.utils.data.distributed.DistributedSampler:
            self._saved_sampler = [_[dist.get_rank()] for _ in self._saved_sampler]



def test1():
    train_loader = ResumableDataLoader(list(range(1000)),  # testDataset(),
                                       batch_size=10,
                                       shuffle=True,
                                       num_workers=2)

    dataloader_iter = iter(train_loader)

    batches = []
    for i, data in enumerate(dataloader_iter):
        print(i, data)
        batches.append(data)
        if i == 3:
            print('Save the state_dict after the third iteration')
            state_dict = train_loader.state_dict()
        if i == 20:
            print('End up here')
            break

    train_loader.load_state_dict(state_dict)
    print('-----')
    print('Resume from the fourth iteration')
    for i, data in enumerate(train_loader):
        print(i + 4, data)
        if i + 4 >= len(batches):
            print(i + 4)
            break
        assert (batches[i + 4] == data).all()

    print('--------------')

def test2():

    # test 2: edge case, when the dataiter has gone to the end. Make sure there is no error.
    train_loader = ResumableDataLoader(list(range(10)),  # testDataset(),
                                       batch_size=10,
                                       shuffle=True,
                                       num_workers=2)

    dataloader_iter = iter(train_loader)
    for i, data in enumerate(dataloader_iter):
        print(data)
    state_dict = train_loader.state_dict()
    train_loader.load_state_dict(state_dict)
    for i, data in enumerate(train_loader):
        print(data)

def main_worker(rank, ngpus, port):

    dist.init_process_group(
        world_size=ngpus, rank=rank,
        backend='nccl', init_method='tcp://127.0.0.1:%d' % port,
    )
    torch.cuda.set_device(rank)

    train_dataset = list(range(10000))
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)

    train_loader = ResumableDataLoader(
        train_dataset, batch_size=10, shuffle=(train_sampler is None),
        num_workers=4, pin_memory=True, sampler=train_sampler)

    dataloader_iter = iter(train_loader)

    batches = []
    for i, data in enumerate(dataloader_iter):
        print(rank, i, data)
        batches.append(data)
        if i == 3:
            print('Save the state_dict after the third iteration')
            state_dict = train_loader.state_dict()
        if i == 20:
            print('End up here')
            break

    train_loader.load_state_dict(state_dict)
    print('-----')
    print('Resume from the fourth iteration')
    for i, data in enumerate(train_loader):
        print(rank, i + 4, data)
        if i + 4 >= len(batches):
            print(i + 4)
            break
        assert (batches[i + 4] == data).all()

def test3():
    # distributed:
    import torch.distributed as dist
    import torch.utils.data.distributed
    import torch.multiprocessing as mp

    port = 11000
    mp.spawn(main_worker, nprocs=2, args=(2, port))
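
def test_state_dict_roundtrip():
    # A sketch (not in the original tests) showing that the loader state can
    # travel inside a checkpoint file via torch.save/torch.load; the file name
    # 'loader_ckpt.pth' is arbitrary.
    train_loader = ResumableDataLoader(list(range(100)),
                                       batch_size=10,
                                       shuffle=True,
                                       num_workers=0)
    it = iter(train_loader)
    next(it)  # consume one batch, so there is something to resume from
    torch.save({'loader': train_loader.state_dict()}, 'loader_ckpt.pth')

    train_loader.load_state_dict(torch.load('loader_ckpt.pth')['loader'])
    for data in train_loader:  # resumes from the second batch
        print(data)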

if __name__ == '__main__':
    test3()
--------------------------------------------------------------------------------
/rtutils/sampler.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.distributed import DistributedSampler
import math
import warnings
import itertools


class DeterministicDistributedSampler(DistributedSampler):

    def __init__(self, *args, **kwargs):
        """
        We always start the generator with a fixed seed, and then, to
        restart from a certain iteration, we just manually drain the
        iterator. This makes the dataloader properly resumable.

        Big note: we assume set_epoch takes in the iteration!!!

        Args:
            global_batch_size: we need global_batch_size because we need
                to know how many iterations there are in each epoch, so
                that we drain the iterator correctly.
                We also manually apply drop_last.
            batch_size: local batch size. Since we know the world size,
                we can derive the global_batch_size.
        """
        self.global_batch_size = kwargs.pop('global_batch_size', 0)
        self.batch_size = kwargs.pop('batch_size', 0)

        super().__init__(*args, **kwargs)

        warnings.warn(
            'You are using a customized distributed sampler. Make sure you are feeding '
            'the iteration number to the set_epoch function; in addition, batch_size has to be fixed!'
        )

        # initialize batch_size and global_batch_size.
        if self.batch_size != 0:
            assert self.global_batch_size == 0, 'don\'t set both batch_size and global_batch_size'
            self.global_batch_size = self.batch_size * self.num_replicas
        elif self.global_batch_size != 0:
            assert self.batch_size == 0, 'don\'t set both batch_size and global_batch_size'
            self.batch_size = self.global_batch_size // self.num_replicas
        else:
            assert False, 'batch_size or global_batch_size should be specified.'

        # Define the generator. There will be one generator from the beginning to the end.
        self.generator = torch.Generator()
        self.generator.manual_seed(self.seed)

        self._drained = False  # if False, we will drain the iterator according to the current self.epoch

    def __len__(self):
        if not self._drained:
            iterations_in_one_epoch = self.num_samples // self.batch_size
            done_iteration_in_this_epoch = self.epoch % iterations_in_one_epoch
            return (iterations_in_one_epoch - done_iteration_in_this_epoch) * self.batch_size
        else:
            return self.num_samples // self.batch_size * self.batch_size

    # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/samplers/distributed_sampler.py#L12
    def __iter__(self):
        indices = self._indices()
        # drain the sampler
        if not self._drained:
            done_epochs = (self.epoch * self.batch_size) // len(indices)
            for i in range(done_epochs):
                indices = self._indices()
            done_iteration_in_this_epoch = (self.epoch * self.batch_size) % len(indices)
            self._drained = True
            return iter(indices[done_iteration_in_this_epoch:])
        else:
            return iter(indices)

    def _indices(self):
        """
        This function is almost a copy of the original pytorch implementation.
        First change: use one generator.
        Second change: directly apply drop_last here, because we want each
        epoch to contain a fixed number of full batches.
        """
        if self.shuffle:
            g = self.generator
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = list(range(len(self.dataset)))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        # manual drop_last
        indices = indices[:self.num_samples // self.batch_size * self.batch_size]

        return indices


class InfiniteDistributedSampler(DistributedSampler):

    def __init__(self, *args, **kwargs):
        """
        Args:
            global_batch_size: since the infinite indices wrap around, the
                same image could otherwise appear twice in one batch.
                We apply drop_last here in the sampler.
            deterministic: we always start the generator with seed 0,
                and then, to restart from a certain iteration, we just
                manually drain the iterator. This makes the dataloader
                properly resumable.
        """
        self.global_batch_size = kwargs.pop('global_batch_size', 0)
        self.batch_size = kwargs.pop('batch_size', 0)

        self.deterministic = kwargs.pop('deterministic', False)

        super().__init__(*args, **kwargs)

        # initialize batch_size and global_batch_size.
        if self.batch_size != 0:
            assert self.global_batch_size == 0, 'don\'t set both batch_size and global_batch_size'
            self.global_batch_size = self.batch_size * self.num_replicas
        elif self.global_batch_size != 0:
            assert self.batch_size == 0, 'don\'t set both batch_size and global_batch_size'
            self.batch_size = self.global_batch_size // self.num_replicas
        else:
            assert not self.deterministic, 'You have to specify batch_size or global_batch_size if deterministic is True'

    # Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/samplers/distributed_sampler.py#L12
    def __iter__(self):
        start = self.rank
        if self.deterministic:
            for _i, idx in enumerate(itertools.islice(self._infinite_indices(), start, None, self.num_replicas)):
                if _i >= self.epoch * self.batch_size:  # make sure the "epoch" is actually the iteration
                    yield idx
        else:
            yield from itertools.islice(self._infinite_indices(), start, None, self.num_replicas)

    def _infinite_indices(self):
        g = torch.Generator()
        if self.deterministic:
            g.manual_seed(0)
        else:
            g.manual_seed(self.epoch)
        while True:
            if self.shuffle:
                indices = torch.randperm(len(self.dataset), generator=g).tolist()
            else:
                indices = list(range(len(self.dataset)))

            # add extra samples to make it evenly divisible
            indices += indices[:(self.total_size - len(indices))]
            assert len(indices) == self.total_size

            if self.global_batch_size != 0:
                # do what drop_last does: make it divisible by global_batch_size
                indices = indices[:self.total_size // self.global_batch_size * self.global_batch_size]

            yield from indices
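
if __name__ == '__main__':
    # A small self-contained sketch (not part of the library API): draw a few
    # batches, then rebuild the sampler and resume from iteration 3 via
    # set_epoch, which for this sampler takes the *iteration* number.
    from torch.utils.data import DataLoader

    dataset = list(range(100))

    def make_loader():
        sampler = DeterministicDistributedSampler(
            dataset, num_replicas=1, rank=0, shuffle=True, batch_size=10, seed=0)
        return DataLoader(dataset, batch_size=10, sampler=sampler,
                          collate_fn=lambda batch: batch)

    batches = [b for _, b in zip(range(5), iter(make_loader()))]

    loader = make_loader()
    loader.sampler.set_epoch(3)  # resume from the 4th iteration
    resumed = next(iter(loader))
    assert resumed == batches[3]
    print('resumed batch:', resumed)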
--------------------------------------------------------------------------------
/rtutils/DataLoader.py:
--------------------------------------------------------------------------------
import torch
import itertools
import torch.distributed as dist

class ResumableDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._need_to_resume = False
        self._rng_state = None
        self._iter = None

    def get_sampler_iter(self):
        # Build a cheap single-process iterator just to get at the underlying
        # _sampler_iter, then restore num_workers.
        tmp_num_workers = self.num_workers
        self.num_workers = 0
        dataiter = super().__iter__()
        __sampler_iter, _sampler_iter = itertools.tee(dataiter._sampler_iter)
        next(__sampler_iter)  # Make sure the _sampler_iter is instantiated
        self.num_workers = tmp_num_workers
        return _sampler_iter


    def __iter__(self):

        if self._need_to_resume:
            # Run through the saved iterations, putting the sampler in the state it was halted in
            self._sampler_iter, _sampler_iter = itertools.tee(iter(self._saved_sampler))
            for i in range(self._saved_iter):
                next(_sampler_iter)
        else:
            _sampler_iter = self.get_sampler_iter()
            self._sampler_iter, _sampler_iter = itertools.tee(_sampler_iter)  # Save the full iter

        self.get_sampler_to_save()  # for state_dict


        dataiter = super().__iter__()
        dataiter._sampler_iter = _sampler_iter
        if self.num_workers > 0:
            # Wait for prefetching
            import time
            time.sleep(5)
            # drain the prefetched batches
            for i in range(dataiter._send_idx - dataiter._rcvd_idx):
                next(dataiter)

        self._iter = dataiter  # Save it in case
        self._need_to_resume = False  # Only resume at the first iter(loader) after load_state_dict
        return self._iter

    def get_sampler_to_save(self):
        # All the indices:
        self._sampler_to_save = []
        for idxs in self._sampler_iter:
            if dist.is_available() and type(self.sampler) is torch.utils.data.distributed.DistributedSampler:
                # dtype must match the gathered tensor (indices are int64)
                idxs_list = [torch.zeros(len(idxs), dtype=torch.long) for k in range(dist.get_world_size())]
                dist.all_gather(idxs_list, torch.tensor(idxs), async_op=False)
                self._sampler_to_save.append([_.tolist() for _ in idxs_list])
            else:
                self._sampler_to_save.append(idxs)


    def state_dict(self):
        if self._iter is None:
            return None
        # The number of batches left in this epoch
        num_rest = 0
        self._iter._sampler_iter, _sampler_iter = itertools.tee(self._iter._sampler_iter)
        for idxs in _sampler_iter:
            num_rest += 1

        # the number of batches prefetched by the loader
        if self.num_workers > 0:
            num_prefetched = self._iter._send_idx - self._iter._rcvd_idx
        else:
            num_prefetched = 0
        self._saved_iter = len(self) - num_rest - num_prefetched

        return {'saved_iter': self._saved_iter,
                'sampler': self._sampler_to_save}

    def load_state_dict(self, state_dict):
        if state_dict is None:
            return
        self._need_to_resume = True
        self._saved_iter = state_dict['saved_iter']
        self._saved_sampler = state_dict['sampler']
        if dist.is_available() and type(self.sampler) is torch.utils.data.distributed.DistributedSampler:
            self._saved_sampler = [_[dist.get_rank()] for _ in self._saved_sampler]



if __name__ == '__main__':

    train_loader = ResumableDataLoader(list(range(1000)),  # testDataset(),
                                       batch_size=10,
                                       shuffle=True,
                                       num_workers=2)


    dataloader_iter = iter(train_loader)

    batches = []
    for i, data in enumerate(dataloader_iter):
        print(i, data)
        batches.append(data)
        if i == 3:
            print('Save the state_dict after the third iteration')
            state_dict = train_loader.state_dict()
        if i == 20:
            print('End up here')
            break

    train_loader.load_state_dict(state_dict)
    print('-----')
    print('Resume from the fourth iteration')
    for i, data in enumerate(train_loader):
        print(i + 4, data)
        if i + 4 >= len(batches):
            print(i + 4)
            break
        assert (batches[i + 4] == data).all()

    print('--------------')


    # test 2: edge case, when the dataiter has gone to the end. Make sure there is no error.

    train_loader = ResumableDataLoader(list(range(10)),  # testDataset(),
                                       batch_size=10,
                                       shuffle=True,
                                       num_workers=2)

    dataloader_iter = iter(train_loader)
    for i, data in enumerate(dataloader_iter):
        print(data)
    state_dict = train_loader.state_dict()
    train_loader.load_state_dict(state_dict)
    for i, data in enumerate(train_loader):
        print(data)
--------------------------------------------------------------------------------
/rtutils/pl_patch.py:
--------------------------------------------------------------------------------
import types
import os
from rtutils.sampler import DeterministicDistributedSampler
import warnings

import pytorch_lightning as pl
from packaging.version import Version
assert Version(pl.__version__) > Version('1.3.8'), 'v0.3 only works for pytorch-lightning>=1.4.0; either downgrade to v0.2.2 or update pytorch-lightning'


def _get_distributed_sampler(self, dataloader, shuffle, mode):
    # modified from https://github.com/PyTorchLightning/pytorch-lightning/blob/HEAD/pytorch_lightning/trainer/data_loading.py?q=replace_ddp_sampler#L217
    from pytorch_lightning.overrides.distributed import UnrepeatedDistributedSampler
    from pytorch_lightning.trainer.states import RunningStage
    kwargs = self.distributed_sampler_kwargs
    kwargs["shuffle"] = shuffle and not self.overfit_batches
    kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
    cls = UnrepeatedDistributedSampler if mode == RunningStage.PREDICTING else DeterministicDistributedSampler
    if cls == DeterministicDistributedSampler:
        kwargs["batch_size"] = dataloader.batch_size
    sampler = cls(dataloader.dataset, **kwargs)

    # rt note here: we used to set_epoch here because we wanted the progress bar to start from
    # 0/{remaining batches} instead of 0/{full batches}. However, we now have a better way to handle
    # the progress bar. I just leave a note here to remind myself.
    # # If we want to know whether it is the end of an epoch, we cannot set_epoch here, because we rely on a fixed number of batches each epoch.
    # # sampler.set_epoch(self.total_batch_idx)  # we set_epoch here so that the progress bar will know the correct size. In fact pl will set_epoch later and override this. This is just a spoiler for the trainer to know how many batches are left for this epoch.
    return sampler


def patch_pl_trainer_with_deterministic_sampler(trainer):
    """
    When pytorch lightning replaces the distributed sampler, replace it with my deterministic sampler.
    """
    assert trainer.accelerator_connector.is_distributed, 'Your trainer is not distributed. Cannot replace.'
    assert trainer.accelerator_connector.replace_sampler_ddp, 'Make sure you set replace_sampler_ddp to True'
    # TODO: we also need to make sure the dataloader is not already given some DistributedSampler, otherwise replace_sampler will not be called either.
    trainer._get_distributed_sampler = types.MethodType(_get_distributed_sampler, trainer)


class ProgressBarPatch(pl.callbacks.ProgressBar):
    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        # This manually sets the progress bar to the middle according to total_batch_idx.
        # Assumes a fixed number of batches in each epoch.
        self.main_progress_bar.update(trainer.fit_loop.total_batch_idx % trainer.num_training_batches)


def patch_progressbar(trainer):
    """
    Make the progress bar start from the middle when resuming from a mid-epoch checkpoint.
    """
    # This on_train_epoch_start will run after the progress bar's own on_train_epoch_start.
    # It manually sets the progress bar to the middle according to total_batch_idx.
    # Assumes a fixed number of batches in each epoch.
    for callback in trainer.callbacks:
        if isinstance(callback, pl.callbacks.progress.ProgressBar):
            old_on_train_epoch_start = callback.on_train_epoch_start
            def on_train_epoch_start(self, trainer, pl_module):
                old_on_train_epoch_start(trainer, pl_module)
                self.main_progress_bar.update(trainer.fit_loop.total_batch_idx % trainer.num_training_batches)
            callback.on_train_epoch_start = types.MethodType(on_train_epoch_start, callback)


class KeyboardInterruptModelCheckpoint(pl.callbacks.ModelCheckpoint):
    def on_keyboard_interrupt(self, trainer, pl_module):
        # Save the model on keyboard interrupt
        filepath = os.path.join(self.dirpath, self.CHECKPOINT_NAME_LAST + '.ckpt')
        self._save_model(trainer, filepath=filepath)


def patch_model_checkpoint(trainer):
    """
    Save a last checkpoint when a keyboard interrupt occurs.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, pl.callbacks.ModelCheckpoint):
            old_on_keyboard_interrupt = callback.on_keyboard_interrupt
            def on_keyboard_interrupt(self, trainer, pl_module):
                old_on_keyboard_interrupt(trainer, pl_module)
                KeyboardInterruptModelCheckpoint.on_keyboard_interrupt(self, trainer, pl_module)
            callback.on_keyboard_interrupt = types.MethodType(on_keyboard_interrupt, callback)


def patch_on_save_checkpoint_every(trainer, checkpoint_every):
    """
    Save a last checkpoint every {checkpoint_every} iterations.
    """
    for callback in trainer.callbacks:
        if isinstance(callback, pl.callbacks.ModelCheckpoint):
            old_on_train_batch_end = callback.on_train_batch_end
            def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
                old_on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
                if (trainer.global_step + 1) % checkpoint_every == 0:
                    KeyboardInterruptModelCheckpoint.on_keyboard_interrupt(self, trainer, pl_module)
            callback.on_train_batch_end = types.MethodType(on_train_batch_end, callback)


def patch_checkpoint_connector(trainer):
    # save and load total_batch_idx
    old_restore_training_state = trainer.checkpoint_connector.restore_training_state
    def restore_training_state(self):
        old_restore_training_state()
        if self._loaded_checkpoint:
            self.trainer.fit_loop.epoch_loop.total_batch_idx = self._loaded_checkpoint.get('total_batch_idx', 0)
    old_dump_checkpoint = trainer.checkpoint_connector.dump_checkpoint
    def dump_checkpoint(self, weights_only: bool = False) -> dict:
        checkpoint = old_dump_checkpoint(weights_only)
        checkpoint['total_batch_idx'] = self.trainer.fit_loop.total_batch_idx
        # note: this won't work when limit_train_batches is set.
        if not self.trainer.fit_loop.total_batch_idx % self.trainer.num_training_batches == 0:
            # ended in the middle of an epoch
            checkpoint['epoch'] -= 1  # so that when resuming, the current_epoch will be the same as when saving.
        return checkpoint
    trainer.checkpoint_connector.restore_training_state = types.MethodType(restore_training_state, trainer.checkpoint_connector)
    trainer.checkpoint_connector.dump_checkpoint = types.MethodType(dump_checkpoint, trainer.checkpoint_connector)


class SetEpochCallback(pl.callbacks.Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        trainer.train_dataloader.sampler.set_epoch(trainer.fit_loop.total_batch_idx)


def patch_set_epoch(trainer):
    """
    Instead of calling set_epoch with the current_epoch, call set_epoch with total_batch_idx.
    This is intended to work with my DeterministicDistributedSampler.
    """
    old_on_advance_start = trainer.fit_loop.on_advance_start
    def on_advance_start(self):
        old_on_advance_start()
        self.trainer.train_dataloader.sampler.set_epoch(self.trainer.fit_loop.total_batch_idx)
    trainer.fit_loop.on_advance_start = types.MethodType(on_advance_start, trainer.fit_loop)


def patch_data_connector(trainer):
    """
    Make the initial batch_idx from the train_loader enumerator respect mid-epoch resuming,
    so that check_val_fx can work properly.
    """
    # set the initial batch_idx of the enumeration according to the loader size.
    old_get_profiled_train_dataloader = trainer.data_connector.get_profiled_train_dataloader
    # reuse the old function
    def get_profiled_train_dataloader(self, train_dataloader):
        # We feed batch_idx because the model may resume from a mid-epoch checkpoint.
        # The length of train_dataloader has been modified by set_epoch at this point.
        start_batch_idx = self.trainer.fit_loop.total_batch_idx % self.trainer.num_training_batches
        old_profiled_dl = old_get_profiled_train_dataloader(train_dataloader)
        def profiled_dl():
            # we discard the old batch_idx and use the new one.
            for batch_idx, (_, batch) in enumerate(old_profiled_dl, start_batch_idx):
                yield batch_idx, batch
        return profiled_dl()
    trainer.data_connector.get_profiled_train_dataloader = types.MethodType(get_profiled_train_dataloader, trainer.data_connector)


def patch_everything(trainer):
    warnings.warn('patch_everything will replace the trainer\'s (and its members\') functions with something else. In general it should be fine, but make sure you are comfortable with this.')
    patch_model_checkpoint(trainer)
    patch_set_epoch(trainer)
    patch_progressbar(trainer)
    patch_pl_trainer_with_deterministic_sampler(trainer)
    patch_checkpoint_connector(trainer)
    patch_data_connector(trainer)


def patch_everything_safer(trainer):
    """
    Compared to patch_everything, we remove the patches that can be implemented by callbacks.
    This needs to work with all_callbacks() and KeyboardInterruptModelCheckpoint to function properly.

    For example:
        callbacks = [
            KeyboardInterruptModelCheckpoint(...),
            *rtutils.pl_patch.all_callbacks(),
        ]
        trainer = pl.Trainer(
            ...,
            callbacks = callbacks,
            ...
        )
        rtutils.pl_patch.patch_everything_safer(trainer)
    """
    warnings.warn('patch_everything_safer will replace the trainer\'s (and its members\') functions with something else. In general it should be fine, but make sure you are comfortable with this.')
    warnings.warn('To fully function, you also want to include all the callbacks from pl_patch.all_callbacks() and use KeyboardInterruptModelCheckpoint to create the ModelCheckpoint instance.')
    patch_pl_trainer_with_deterministic_sampler(trainer)
    patch_checkpoint_connector(trainer)
    patch_data_connector(trainer)


def all_callbacks():
    return [
        ProgressBarPatch(),
        SetEpochCallback(),
    ]
--------------------------------------------------------------------------------