├── README.md
├── mtl-data-loading
│   ├── basic_dataloader_example.py
│   ├── balanced_batch_scheduler_dataloader_example.py
│   ├── batch_scheduler_dataloader_example.py
│   ├── basic_dataset_example.py
│   ├── multi_task_batch_scheduler.py
│   └── balanced_sampler.py
└── .gitignore

/README.md:
--------------------------------------------------------------------------------
# code-for-posts
### code scripts for blog posts I published:

1. **Unbalanced data loading for multi-task learning in PyTorch:**
   Code: [mtl-data-loading](https://github.com/bomri/code-for-posts/tree/master/mtl-data-loading)
   Link: [Towards Data Science Post](https://medium.com/p/unbalanced-data-loading-for-multi-task-learning-in-pytorch-e030ad5033b?source=email-486b68bc632a--writer.postDistributed&sk=1b6abef7a845cb72faa8304aaccfc281)
--------------------------------------------------------------------------------

/mtl-data-loading/basic_dataloader_example.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.dataset import ConcatDataset
from basic_dataset_example import MyFirstDataset, MySecondDataset


first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])

batch_size = 8

# basic dataloader
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

for inputs in dataloader:
    print(inputs)
--------------------------------------------------------------------------------

/mtl-data-loading/balanced_batch_scheduler_dataloader_example.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.dataset import ConcatDataset
from balanced_sampler import BalancedBatchSchedulerSampler
from basic_dataset_example import MyFirstDataset, MySecondDataset

first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset])

batch_size = 8

# dataloader with BalancedBatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BalancedBatchSchedulerSampler(dataset=concat_dataset,
                                                                               batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)
--------------------------------------------------------------------------------

/mtl-data-loading/batch_scheduler_dataloader_example.py:
--------------------------------------------------------------------------------
import torch
from torch.utils.data.dataset import ConcatDataset
from basic_dataset_example import MyFirstDataset, MySecondDataset, MyThirdDataset
from multi_task_batch_scheduler import BatchSchedulerSampler

first_dataset = MyFirstDataset()
second_dataset = MySecondDataset()
third_dataset = MyThirdDataset()
concat_dataset = ConcatDataset([first_dataset, second_dataset, third_dataset])

batch_size = 8

# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    print(inputs)
--------------------------------------------------------------------------------
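
The three example scripts above only print raw batches. Below is a minimal sanity-check sketch (not one of the repository files; the "identify the task by value" trick is an assumption that only works for the dummy datasets defined further down, where MyFirstDataset yields values of magnitude 1 and MySecondDataset values of magnitude 5). It should show that with BatchSchedulerSampler every mini-batch comes from a single task, whereas the plain shuffled ConcatDataset loader typically mixes tasks within a batch.

import torch
from torch.utils.data.dataset import ConcatDataset
from basic_dataset_example import MyFirstDataset, MySecondDataset
from multi_task_batch_scheduler import BatchSchedulerSampler

concat_dataset = ConcatDataset([MyFirstDataset(), MySecondDataset()])
batch_size = 8
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                         sampler=BatchSchedulerSampler(dataset=concat_dataset,
                                                                       batch_size=batch_size),
                                         batch_size=batch_size,
                                         shuffle=False)

for inputs in dataloader:
    # map each sample back to its source task by magnitude (works for the dummy datasets only)
    tasks = {"first" if abs(v) == 1 else "second" for v in inputs.tolist()}
    print(tasks)  # expected: exactly one task name per mini-batch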
/mtl-data-loading/basic_dataset_example.py:
--------------------------------------------------------------------------------
import torch


class MyFirstDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset: 5 samples of -1 and 5 samples of +1
        self.samples = torch.cat((-torch.ones(5), torch.ones(5)))

    def __getitem__(self, index):
        # change this to your sample-fetching logic
        return self.samples[index]

    def __len__(self):
        # change this to return the number of samples in your dataset
        return self.samples.shape[0]


class MySecondDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset: 50 samples of +5 and only 5 samples of -5 (unbalanced)
        self.samples = torch.cat((torch.ones(50) * 5, torch.ones(5) * -5))

    def __getitem__(self, index):
        # change this to your sample-fetching logic
        return self.samples[index]

    def __len__(self):
        # change this to return the number of samples in your dataset
        return self.samples.shape[0]


class MyThirdDataset(torch.utils.data.Dataset):
    def __init__(self):
        # dummy dataset: 20 samples of +10 and 10 samples of -10
        self.samples = torch.cat((torch.ones(20) * 10, torch.ones(10) * -10))

    def __getitem__(self, index):
        # change this to your sample-fetching logic
        return self.samples[index]

    def __len__(self):
        # change this to return the number of samples in your dataset
        return self.samples.shape[0]
--------------------------------------------------------------------------------

/.gitignore:
--------------------------------------------------------------------------------
# python files
sampler.py

*.idea

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
--------------------------------------------------------------------------------
/mtl-data-loading/multi_task_batch_scheduler.py:
--------------------------------------------------------------------------------
import math
import torch
from torch.utils.data.sampler import RandomSampler


class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over the tasks in a round-robin fashion and provide a random batch from a single task
    in each mini-batch.
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])

    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = iter(sampler)
            sampler_iterators.append(cur_sampler_iterator)

        # offset that maps a local index of each sub-dataset to its index in the ConcatDataset
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # we want to see every sample of the largest dataset in an epoch, which forces us
        # to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indices into the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # reached the end of this iterator - restart it and keep drawing samples
                        # until "epoch_samples" is reached
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
--------------------------------------------------------------------------------
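
A quick worked example of the epoch arithmetic above (a sketch, not a repository file; the dataset sizes come from the dummy datasets in basic_dataset_example.py): with sub-dataset sizes 10, 55 and 30 and batch_size = 8, largest_dataset_size is 55 and ceil(55 / 8) = 7, so the sampler yields 8 * 7 * 3 = 168 indices per epoch, i.e. 21 mini-batches of 8, cycling task 1, task 2, task 3 and resampling the two smaller datasets as needed. The length reported by __len__ should match the number of indices actually produced by __iter__:

import torch
from torch.utils.data.dataset import ConcatDataset
from basic_dataset_example import MyFirstDataset, MySecondDataset, MyThirdDataset
from multi_task_batch_scheduler import BatchSchedulerSampler

concat_dataset = ConcatDataset([MyFirstDataset(), MySecondDataset(), MyThirdDataset()])
sampler = BatchSchedulerSampler(dataset=concat_dataset, batch_size=8)
indices = list(iter(sampler))
print(len(sampler), len(indices))  # both should print 168 for these dummy datasets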
/mtl-data-loading/balanced_sampler.py:
--------------------------------------------------------------------------------
import math
import torch
from torch.utils.data import RandomSampler
# sampler.py is not part of this repository (it is listed in .gitignore); see the docstring below
from sampler import ImbalancedDatasetSampler


class ExampleImbalancedDatasetSampler(ImbalancedDatasetSampler):
    """
    ImbalancedDatasetSampler is taken from:
    https://github.com/ufoym/imbalanced-dataset-sampler/blob/master/torchsampler/imbalanced.py
    To demonstrate ImbalancedDatasetSampler in this example, _get_label is overridden to fit
    the dummy datasets defined in basic_dataset_example.py.
    """
    def _get_label(self, dataset, idx):
        return dataset.samples[idx].item()


class BalancedBatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    Iterate over the tasks in a round-robin fashion and provide a balanced random batch from a
    single task in each mini-batch.
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset) for cur_dataset in dataset.datasets])

    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            if dataset_idx == 0:
                # the first (balanced) dataset keeps a plain RandomSampler
                sampler = RandomSampler(cur_dataset)
            else:
                # the second, unbalanced dataset gets the label-balancing sampler
                sampler = ExampleImbalancedDatasetSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = iter(sampler)
            sampler_iterators.append(cur_sampler_iterator)

        # offset that maps a local index of each sub-dataset to its index in the ConcatDataset
        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # we want to see every sample of the largest dataset in an epoch, which forces us
        # to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indices into the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # reached the end of this iterator - restart it and keep drawing samples
                        # until "epoch_samples" is reached
                        sampler_iterators[i] = iter(samplers_list[i])
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = next(cur_batch_sampler)
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)
--------------------------------------------------------------------------------
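
sampler.py is deliberately left out of the repository (see .gitignore), so ImbalancedDatasetSampler has to be saved locally from the URL in the docstring above. Assuming that file is in place and behaves as described there, the following sketch (not a repository file) compares the two schedulers on the two-dataset setup: over an epoch, BalancedBatchSchedulerSampler should draw the rare label of MySecondDataset (-5) roughly as often as the common one (+5), while the plain BatchSchedulerSampler should keep the original roughly 10:1 imbalance.

import collections
import torch
from torch.utils.data.dataset import ConcatDataset
from basic_dataset_example import MyFirstDataset, MySecondDataset
from multi_task_batch_scheduler import BatchSchedulerSampler
from balanced_sampler import BalancedBatchSchedulerSampler

concat_dataset = ConcatDataset([MyFirstDataset(), MySecondDataset()])
batch_size = 8

for sampler_cls in (BatchSchedulerSampler, BalancedBatchSchedulerSampler):
    dataloader = torch.utils.data.DataLoader(dataset=concat_dataset,
                                             sampler=sampler_cls(dataset=concat_dataset,
                                                                 batch_size=batch_size),
                                             batch_size=batch_size,
                                             shuffle=False)
    counts = collections.Counter()
    for inputs in dataloader:
        counts.update(inputs.tolist())
    # counts of 5.0 vs -5.0 should be far closer with BalancedBatchSchedulerSampler
    print(sampler_cls.__name__, dict(counts))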