├── .gitignore ├── README.md ├── fastai_v1 ├── OOMLog.txt ├── SingleCoreLog.txt ├── tpu_distributed_fastai.py └── tpu_single_core_fastai.py └── fastai_v2 ├── BENCHMARKS.md ├── fastai2_GPU_Food101.ipynb ├── fastai2_gpu.py ├── fastai_multiprocessing_dl.py ├── run_tpu.sh ├── test_tpu_distributed_dl.py ├── tpu_distributed_dl.py └── tpu_distributed_fastai2.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # fastai_tpu 2 | TPU support for the fastai library 3 | 4 | See [this fastai forum thread](https://forums.fast.ai/t/fastai-v2-tpu-support/75421) for more information. 
5 | -------------------------------------------------------------------------------- /fastai_v1/SingleCoreLog.txt: -------------------------------------------------------------------------------- 1 | 2019-12-05 01:45:02.609962: I 1890 torch_xla/csrc/tensor_util.cpp:28] Using BF16 data type for floating point values 2 | hello 3 | epoch train_loss valid_loss accuracy time 4 | -------------------------------------------------------------------------------- /fastai_v1/tpu_distributed_fastai.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import torch_xla 5 | import torch_xla.distributed.data_parallel as dp 6 | import torch_xla.utils.utils as xu 7 | import torch_xla.core.xla_model as xm 8 | import torch_xla.distributed.parallel_loader as pl 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | import torch 11 | 12 | import fastai 13 | from fastai import * 14 | from fastai.core import * 15 | from fastai.torch_core import * 16 | from fastai.vision import * 17 | from fastai.basic_train import * 18 | 19 | def len_parallelloader(self): 20 | return len(self._loader._loader) 21 | pl.PerDeviceLoader.__len__ = len_parallelloader 22 | 23 | 24 | class TPUDistributed(LearnerCallback): 25 | def __init__(self, learn:Learner): 26 | super().__init__(learn) 27 | self.device = xm.xla_device() 28 | 29 | def _change_dl(self,dl, shuffle): 30 | old_dl = dl 31 | sampler = torch.utils.data.distributed.DistributedSampler( 32 | dl.dataset, 33 | num_replicas=xm.xrt_world_size(), 34 | rank=xm.get_ordinal(), 35 | shuffle=shuffle 36 | ) 37 | new_dl = dl.new(shuffle=False, sampler=sampler) 38 | return old_dl,new_dl,sampler 39 | 40 | 41 | def on_train_begin(self, **kwargs:Any)->None: 42 | self.learn.model = self.learn.model.to(self.device) 43 | self.learn.opt.lr = self.learn.opt.lr*xm.xrt_world_size() 44 | 45 | shuffle = self.data.train_dl.init_kwargs['shuffle'] if hasattr(self.data.train_dl, 'init_kwargs') else True 46 | self.old_sampler_train_dl,self.data.train_dl,self.train_sampler = self._change_dl(self.data.train_dl, shuffle) 47 | if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None: 48 | self.old_sampler_valid_dl,self.data.valid_dl,self.valid_sampler = self._change_dl(self.data.valid_dl, shuffle) 49 | def on_epoch_begin(self,**kwargs:Any)->None: 50 | self.old_train_dl = self.data.train_dl 51 | self.learn.data.train_dl = pl.ParallelLoader(self.old_train_dl, [self.device]).per_device_loader(self.device) 52 | self.learn.data.train_dl.dataset = None #self.old_train_dl.dataset 53 | if hasattr(self.data, 'valid_dl') and self.data.valid_dl is not None: 54 | self.old_valid_dl = self.learn.data.valid_dl 55 | self.learn.data.valid_dl = pl.ParallelLoader(self.old_valid_dl, [self.device]).per_device_loader(self.device) 56 | 57 | self.learn.data.valid_dl.dataset = self.old_valid_dl.dataset 58 | self.learn.data.valid_dl.dl = self.learn.data.valid_dl._loader._loader 59 | 60 | def on_backward_end(self, **kwargs:Any)->None: 61 | xm.optimizer_step(self.learn.opt) 62 | return {'skip_step': True} 63 | 64 | def on_epoch_end(self,**kwargs:Any)->None: 65 | self.learn.data.train_dl = self.old_train_dl 66 | self.learn.data.valid_dl = self.old_valid_dl 67 | 68 | def on_train_end(self,**kwargs:Any)->None: 69 | self.learn.data.train_dl = self.old_sampler_train_dl 70 | self.learn.data.valid_dl = self.old_sampler_valid_dl 71 | 72 | 73 | def _to_tpu_distributed(learn:Learner) -> Learner: 74 | #Learner.fit = _fit_tpu 75 | 
learn.callback_fns.append(TPUDistributed) 76 | return learn 77 | 78 | 79 | Learner.to_tpu_distributed = _to_tpu_distributed 80 | 81 | 82 | path = untar_data(URLs.FOOD) 83 | def filelist2df(path): 84 | df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name']) 85 | df['name'] = df['label'].astype(str) + "/" + df['name'].astype(str) + ".jpg" 86 | return df 87 | 88 | train_path = path/'train.txt' 89 | test_path = path/'test.txt' 90 | 91 | def train_loop(index): 92 | train_df = filelist2df(train_path) 93 | test_df = filelist2df(test_path) 94 | 95 | 96 | data = (ImageList.from_df(df=train_df, path=path/'images', cols=1) 97 | .random_split_by_pct(0.2) 98 | .label_from_df(cols=0) 99 | .transform(get_transforms(),size=224) 100 | .databunch(bs=256, num_workers=4) 101 | .normalize(imagenet_stats)) 102 | learn = cnn_learner(data, models.resnet152, metrics=accuracy).to_tpu_distributed() 103 | print('hello') 104 | learn.fit(4) 105 | 106 | if __name__ == "__main__": 107 | xmp.spawn(train_loop,args=()) 108 | 109 | -------------------------------------------------------------------------------- /fastai_v1/tpu_single_core_fastai.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import torch_xla 5 | import torch_xla.distributed.data_parallel as dp 6 | import torch_xla.utils.utils as xu 7 | import torch_xla.core.xla_model as xm 8 | import torch_xla.distributed.parallel_loader as pl 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | import torch 11 | 12 | import fastai 13 | from fastai import * 14 | from fastai.core import * 15 | from fastai.torch_core import * 16 | from fastai.vision import * 17 | from fastai.basic_train import * 18 | 19 | class TPUSingleCore(LearnerCallback): 20 | def __init__(self, learn:Learner): 21 | super().__init__(learn) 22 | self.device = xm.xla_device() 23 | 24 | def on_train_begin(self, **kwargs:Any)->None: 25 | self.learn.model = self.learn.model.to(self.device) 26 | 27 | def on_batch_begin(self, last_input, last_target, train, **kwargs): 28 | return {'last_input': last_input.to(self.device), 'last_target': last_target.to(self.device)} 29 | 30 | def on_backward_end(self, **kwargs:Any)->None: 31 | xm.optimizer_step(self.learn.opt) 32 | return {'skip_step': True} 33 | 34 | 35 | def _to_tpu_single(learn:Learner) -> Learner: 36 | #Learner.fit = _fit_tpu 37 | learn.callback_fns.append(TPUSingleCore) 38 | return learn 39 | 40 | 41 | Learner.to_tpu_distributed = _to_tpu_single 42 | 43 | 44 | path = untar_data(URLs.FOOD) 45 | def filelist2df(path): 46 | df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name']) 47 | df['name'] = df['label'].astype(str) + "/" + df['name'].astype(str) + ".jpg" 48 | return df 49 | 50 | train_path = path/'train.txt' 51 | test_path = path/'test.txt' 52 | 53 | def train_loop(): 54 | train_df = filelist2df(train_path) 55 | test_df = filelist2df(test_path) 56 | 57 | 58 | data = (ImageList.from_df(df=train_df, path=path/'images', cols=1) 59 | .random_split_by_pct(0.2) 60 | .label_from_df(cols=0) 61 | .transform(get_transforms(), size=224) 62 | .databunch(bs=512, num_workers=16) 63 | .normalize(imagenet_stats)) 64 | learn = cnn_learner(data, models.resnet152, metrics=accuracy).to_tpu_distributed() 65 | print('hello') 66 | learn.fit(3) 67 | 68 | train_loop() 69 | 70 | -------------------------------------------------------------------------------- /fastai_v2/BENCHMARKS.md: 
-------------------------------------------------------------------------------- 1 | # Benchmarks 2 | 3 | ## GPU benchmarks 4 | 5 | | GPU type | number of GPUs | CPU type | batch size | `num_threads`| Time 6 | --- | --- | --- | --- | --- | --- | 7 | | P100 | 1 | n1-standard-16 | 64 |None?|10:05| 8 | 9 | ## TPU v3-8 benchmarks 10 | 11 | | CPU type | `num_threads`| Time 12 | | --- | --- | --- | 13 | | n1-standard-16 | | 14 | -------------------------------------------------------------------------------- /fastai_v2/fastai2_gpu.py: -------------------------------------------------------------------------------- 1 | from fastai.basics import * 2 | from fastai.callback.all import * 3 | from fastai.vision.all import * 4 | from torchvision import models 5 | 6 | def filelist2df(path): 7 | df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name']) 8 | df['name'] = df['label'].astype(str) + "/" + df['name'].astype(str) + ".jpg" 9 | return df 10 | 11 | path = untar_data(URLs.FOOD) 12 | train_path = path/'train.txt' 13 | test_path = path/'test.txt' 14 | 15 | def train_loop(): 16 | train_df = filelist2df(train_path) 17 | test_df = filelist2df(test_path) 18 | food = DataBlock(blocks=(ImageBlock, CategoryBlock), 19 | get_x = ColReader(1,pref=path/'images'), 20 | splitter = RandomSplitter(), 21 | get_y = ColReader(cols=0), 22 | item_tfms=Resize(224), 23 | batch_tfms=aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.) 24 | ) 25 | dls = food.dataloaders(train_df.values,bs=64) 26 | learn = cnn_learner(dls, models.resnet152, metrics=accuracy) 27 | learn.fit(3) 28 | 29 | if __name__ == "__main__": 30 | train_loop() -------------------------------------------------------------------------------- /fastai_v2/fastai_multiprocessing_dl.py: -------------------------------------------------------------------------------- 1 | from torch.multiprocessing import Pool, set_start_method 2 | from fastai.vision.all import * 3 | 4 | 5 | def filelist2df(path): 6 | df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name']) 7 | df['name'] = df['label'].astype(str) + "/" + df['name'].astype(str) + ".jpg" 8 | return df 9 | 10 | path = untar_data(URLs.FOOD) 11 | train_path = path/'train.txt' 12 | test_path = path/'test.txt' 13 | 14 | def load_data(index): 15 | train_df = filelist2df(train_path) 16 | test_df = filelist2df(test_path) 17 | food = DataBlock(blocks=(ImageBlock, CategoryBlock), get_x = ColReader(1,pref=path/'images'), splitter = RandomSplitter(), get_y = ColReader(cols=0), item_tfms=Resize(224)) 18 | dls = food.dataloaders(train_df.values,bs=64) 19 | 20 | 21 | if __name__ == '__main__': 22 | set_start_method('spawn', force=True) 23 | try: 24 | pool = Pool(8) 25 | pool.map(load_data, [1,2,3,4,5,6,7,8]) 26 | except KeyboardInterrupt: 27 | exit() 28 | finally: 29 | pool.terminate() 30 | pool.join() 31 | -------------------------------------------------------------------------------- /fastai_v2/run_tpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Setup environment and run fastai training code on TPU 3 | TPU_IP_ADDRESS="10.9.54.194" 4 | export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470" 5 | python tpu_distributed_fastai2.py 6 | -------------------------------------------------------------------------------- /fastai_v2/test_tpu_distributed_dl.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import 
torch_xla 5 | import torch_xla.distributed.data_parallel as dp 6 | import torch_xla.utils.utils as xu 7 | import torch_xla.core.xla_model as xm 8 | import torch_xla.distributed.parallel_loader as pl 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | import torch 11 | 12 | import fastai 13 | from fastai.callback.all import * 14 | from fastai.vision.all import * 15 | from fastai.distributed import * 16 | from fastai.data.load import _FakeLoader 17 | 18 | def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2 (*args, **kwargs)) 19 | def _fa_rebuild_qtensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs)) 20 | def _fa_rebuild_xla_tensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_xla_tensor(*args, **kwargs)) 21 | 22 | @patch 23 | def __reduce_ex__(self:TensorBase, proto): 24 | torch.utils.hooks.warn_if_has_hooks(self) 25 | if self.device.type == 'xla': 26 | args = (type(self), self.cpu().numpy(), self.dtype, str(self.device), self.requires_grad) 27 | return (_fa_rebuild_xla_tensor, args) 28 | 29 | args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride()) 30 | if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point()) 31 | f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor 32 | return (f, args + (self.requires_grad, OrderedDict())) 33 | 34 | def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple) 35 | 36 | class TPUDistributedDL(TfmdDL): 37 | "A `TfmdDL` which splits a batch into equal size pieces for each TPU core" 38 | def __init__(self,dl,rank,world_size): 39 | store_attr() 40 | self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \ 41 | attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl) 42 | self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers) 43 | self.SERIAL_EXEC = xmp.MpSerialExecutor() 44 | 45 | def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test 46 | def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size 47 | def get_idxs(self): 48 | idxs = self.SERIAL_EXEC.run(self.dl.get_idxs) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent) 49 | self.n = len(idxs) # we assumed n was dl.n but we really care about number of idxs 50 | # add extra samples to make it evenly divisible 51 | self.n_padded = _round_to_multiple(self.n,self.world_size) 52 | idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n 53 | # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors 54 | return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size] 55 | 56 | def before_iter(self): 57 | self.i = 0 58 | self.dl.before_iter() 59 | 60 | def randomize(self): self.dl.randomize() 61 | def after_batch(self,b): 62 | self.i += find_bs(b) 63 | return self.dl.after_batch(b) 64 | 65 | def after_iter(self): self.dl.after_iter() 66 | def create_batches(self,samps): return self.dl.create_batches(samps) 67 | def to_detach(self,b, cpu=True, gather=True): 68 | b = self._to_detach(b, cpu, gather) 69 | def _inner(b): 70 | if b.ndim>0: 71 | # for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering 72 | n = 
sum([min(0,max(-len(b)//self.world_size, 73 | self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)]) 74 | b = b[:n or None] 75 | return b 76 | return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b 77 | 78 | 79 | 80 | 81 | 82 | def train_loop(index): 83 | dl = TfmdDL(list(range(50)), bs=12, num_workers=2) 84 | distributed_dl = pl.ParallelLoader(TPUDistributedDL(dl, xm.get_ordinal(), xm.xrt_world_size()), [xm.xla_device()]).per_device_loader(xm.xla_device()) 85 | print(xm.get_ordinal(), next(iter(distributed_dl))) 86 | print(xm.get_ordinal(), list(distributed_dl)) 87 | 88 | 89 | if __name__ == "__main__": 90 | xmp.spawn(train_loop,nprocs=8,args=()) 91 | -------------------------------------------------------------------------------- /fastai_v2/tpu_distributed_dl.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import torch_xla 5 | import torch_xla.distributed.data_parallel as dp 6 | import torch_xla.utils.utils as xu 7 | import torch_xla.core.xla_model as xm 8 | import torch_xla.distributed.parallel_loader as pl 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | import torch 11 | 12 | import fastai 13 | from fastai.callback.all import * 14 | from fastai.vision.all import * 15 | from fastai.distributed import * 16 | from fastai.data.load import _FakeLoader 17 | 18 | def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple) 19 | 20 | class TPUDistributedDL(TfmdDL): 21 | "A `TfmdDL` which splits a batch into equal size pieces for each TPU core" 22 | def __init__(self,dl,rank,world_size): 23 | store_attr() 24 | self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \ 25 | attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl) 26 | self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers) 27 | self.SERIAL_EXEC = xmp.MpSerialExecutor() 28 | 29 | def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test 30 | def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size 31 | def get_idxs(self): 32 | idxs = self.SERIAL_EXEC.run(self.dl.get_idxs) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent) 33 | self.n = len(idxs) # we assumed n was dl.n but we really care about number of idxs 34 | # add extra samples to make it evenly divisible 35 | self.n_padded = _round_to_multiple(self.n,self.world_size) 36 | idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n 37 | # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors 38 | return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size] 39 | 40 | def before_iter(self): 41 | self.i = 0 42 | self.dl.before_iter() 43 | 44 | def randomize(self): self.dl.randomize() 45 | def after_batch(self,b): 46 | self.i += find_bs(b) 47 | return self.dl.after_batch(b) 48 | 49 | def after_iter(self): self.dl.after_iter() 50 | def create_batches(self,samps): return self.dl.create_batches(samps) 51 | def to_detach(self,b, cpu=True, gather=True): 52 | b = self._to_detach(b, cpu, gather) 53 | def _inner(b): 54 | if b.ndim>0: 55 | # for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals 
after gathering 56 | n = sum([min(0,max(-len(b)//self.world_size, 57 | self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)]) 58 | b = b[:n or None] 59 | return b 60 | return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b 61 | -------------------------------------------------------------------------------- /fastai_v2/tpu_distributed_fastai2.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | warnings.filterwarnings('ignore') 3 | 4 | import torch_xla 5 | import torch_xla.distributed.data_parallel as dp 6 | import torch_xla.utils.utils as xu 7 | import torch_xla.core.xla_model as xm 8 | import torch_xla.distributed.parallel_loader as pl 9 | import torch_xla.distributed.xla_multiprocessing as xmp 10 | import torch 11 | 12 | import fastai 13 | from fastai.callback.all import * 14 | from fastai.vision.all import * 15 | from fastai.distributed import * 16 | from fastai.data.load import _FakeLoader 17 | 18 | def _fa_rebuild_tensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_tensor_v2 (*args, **kwargs)) 19 | def _fa_rebuild_qtensor (cls, *args, **kwargs): return cls(torch._utils._rebuild_qtensor (*args, **kwargs)) 20 | def _fa_rebuild_xla_tensor(cls, *args, **kwargs): return cls(torch._utils._rebuild_xla_tensor(*args, **kwargs)) 21 | 22 | @patch 23 | def __reduce_ex__(self:TensorBase, proto): 24 | torch.utils.hooks.warn_if_has_hooks(self) 25 | if self.device.type == 'xla': 26 | args = (type(self), self.cpu().numpy(), self.dtype, str(self.device), self.requires_grad) 27 | return (_fa_rebuild_xla_tensor, args) 28 | 29 | args = (type(self), self.storage(), self.storage_offset(), tuple(self.size()), self.stride()) 30 | if self.is_quantized: args = args + (self.q_scale(), self.q_zero_point()) 31 | f = _fa_rebuild_qtensor if self.is_quantized else _fa_rebuild_tensor 32 | return (f, args + (self.requires_grad, OrderedDict())) 33 | 34 | @patch 35 | def __getstate__(self: Optimizer): 36 | optim_dict = self.__dict__.copy() 37 | modified_dict = {**optim_dict, 'param_groups': self.param_groups} #this change needed since PyTorch XLA wants it! 
38 | return modified_dict 39 | 40 | @patch 41 | def __setstate__(self: Optimizer,state): 42 | print('setstate Optimizer dict: ', self.__dict__.keys()) 43 | del state['param_groups'] 44 | self.__dict__.update(state) 45 | 46 | @patch 47 | def set_epoch(self: pl.PerDeviceLoader,epoch): 48 | self._loader._loader.set_epoch(epoch) 49 | 50 | def _round_to_multiple(number,multiple): return int(math.ceil(number/multiple)*multiple) 51 | 52 | class TPUDistributedDL(TfmdDL): 53 | "A `TfmdDL` which splits a batch into equal size pieces for each TPU core" 54 | def __init__(self,dl,rank,world_size): 55 | store_attr() 56 | self.bs,self.device,self.num_workers,self.drop_last,self.dataset,self.offs,fake = \ 57 | attrgetter('bs','device','num_workers','drop_last','dataset','offs','fake_l')(dl) 58 | self.fake_l = _FakeLoader(self, fake.pin_memory, fake.num_workers, fake.timeout, persistent_workers=fake.persistent_workers) 59 | self.SERIAL_EXEC = xmp.MpSerialExecutor() 60 | 61 | def _to_detach(self,b,cpu=True,gather=True): return to_detach(b,cpu,gather) # member func so we can override for test 62 | def __len__(self): return _round_to_multiple(len(self.dl),self.world_size)//self.world_size 63 | def get_idxs(self): 64 | idxs = self.SERIAL_EXEC.run(self.dl.get_idxs) # compute get_idxs in all ranks (we'll only use rank 0 but size must be consistent) 65 | self.n = len(idxs) # we assumed n was dl.n but we really care about number of idxs 66 | # add extra samples to make it evenly divisible 67 | self.n_padded = _round_to_multiple(self.n,self.world_size) 68 | idxs += (idxs * (self.n_padded//self.n))[:self.n_padded-self.n] # idx needs to be repeated when n_padded>>n 69 | # slice padded idxs so that each rank gets self.n_padded//self.world_size tensors 70 | return idxs[self.rank*self.n_padded//self.world_size:(self.rank+1)*self.n_padded//self.world_size] 71 | 72 | def before_iter(self): 73 | self.i = 0 74 | self.dl.before_iter() 75 | 76 | def randomize(self): self.dl.randomize() 77 | def after_batch(self,b): 78 | self.i += find_bs(b) 79 | return self.dl.after_batch(b) 80 | 81 | def after_iter(self): self.dl.after_iter() 82 | def create_batches(self,samps): return self.dl.create_batches(samps) 83 | def to_detach(self,b, cpu=True, gather=True): 84 | b = self._to_detach(b, cpu, gather) 85 | def _inner(b): 86 | if b.ndim>0: 87 | # for each rank, compute overflow of read idxs vs self.n and accumulate them to unpad totals after gathering 88 | n = sum([min(0,max(-len(b)//self.world_size, 89 | self.n-(self.i+r*self.n_padded//self.world_size))) for r in range(self.world_size)]) 90 | b = b[:n or None] 91 | return b 92 | return apply(_inner,b) if gather and all(hasattr(self,o) for o in ('i','n','n_padded')) else b 93 | 94 | 95 | 96 | # Much of the below code is inspired by the GPU distributed callback 97 | class TPUDistributed(Callback): 98 | def __init__(self): 99 | self.device = xm.xla_device() 100 | 101 | def _wrap_dl(self, dl): 102 | if isinstance(dl, pl.PerDeviceLoader): 103 | return dl 104 | else: 105 | #dl = dl.to(self.device) 106 | dl.fake_l.num_workers=0 # For some reason, needed for it to work (something on fastai's end). 
Need to investigate further 107 | distributed_dl = TPUDistributedDL(dl, xm.get_ordinal(), xm.xrt_world_size()) # Use existing distributed functionality 108 | return pl.ParallelLoader(distributed_dl, [self.device]).per_device_loader(self.device) 109 | 110 | def before_fit(self): 111 | xm.master_print('begin fit') 112 | self.learn.model = self.learn.model.to(self.device) 113 | for h in self.opt.hypers: h['lr'] *= xm.xrt_world_size() 114 | self.old_dls = list(self.dls) 115 | print('wrapping dls') 116 | self.learn.dls.loaders = [self._wrap_dl(dl) for dl in self.dls] 117 | # for dl in self.dls: dl.set_epoch(self.epoch) 118 | 119 | #def before_epoch(self): 120 | # for dl in self.dls: dl.set_epoch(self.epoch) 121 | 122 | def before_train(self): 123 | self.learn.dl = self._wrap_dl(self.learn.dl) 124 | 125 | def before_batch(self): 126 | self.learn.xb = [xb_item.to(self.device) for xb_item in self.xb] 127 | self.learn.yb = [yb_item.to(self.device) for yb_item in self.yb] 128 | def after_backward(self): 129 | xm.optimizer_step(self.learn.opt) 130 | self.learn.opt.zero_grad() 131 | raise CancelBatchException() # skip fastai's own optimizer step; xm.optimizer_step has already reduced gradients and stepped 132 | 133 | 134 | def before_validate(self): self.learn.dl = self._wrap_dl(self.learn.dl) 135 | 136 | def after_fit(self): 137 | self.learn.dls.loaders = self.old_dls 138 | 139 | @patch 140 | def to_tpu_distributed(self:Learner): 141 | self.add_cbs([TPUDistributed()]) 142 | return self 143 | 144 | 145 | 146 | def filelist2df(path): 147 | df = pd.read_csv(path, delimiter='/', header=None, names=['label', 'name']) 148 | df['name'] = df['label'].astype(str) + "/" + df['name'].astype(str) + ".jpg" 149 | return df 150 | 151 | path = untar_data(URLs.FOOD) 152 | train_path = path/'train.txt' 153 | test_path = path/'test.txt' 154 | 155 | 156 | def train_loop(index): 157 | print('index: ',index) 158 | train_df = filelist2df(train_path) 159 | test_df = filelist2df(test_path) 160 | food = DataBlock(blocks=(ImageBlock, CategoryBlock), 161 | get_x = ColReader(1,pref=path/'images'), 162 | splitter = RandomSplitter(), 163 | get_y = ColReader(cols=0), 164 | item_tfms=Resize(224) 165 | # batch_tfms=aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.) <-- ignore batch (on-TPU) tfms for now 166 | ) 167 | dls = food.dataloaders(train_df.values,bs=256) 168 | learn = cnn_learner(dls, resnet152, metrics=accuracy).to_tpu_distributed() 169 | learn.fit(3) 170 | 171 | if __name__ == "__main__": 172 | xmp.spawn(train_loop,nprocs=8,args=()) 173 | --------------------------------------------------------------------------------
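
A small worked example of the index padding and sharding that `TPUDistributedDL.get_idxs` performs. This is not part of the repository; the sizes (50 samples, 8 cores) are hypothetical but mirror the `TfmdDL(list(range(50)), bs=12)` used in `test_tpu_distributed_dl.py`, and the index list is left unshuffled for clarity:

```python
# Sketch of the padding/sharding arithmetic in TPUDistributedDL.get_idxs (hypothetical sizes).
import math

def _round_to_multiple(number, multiple): return int(math.ceil(number / multiple) * multiple)

n, world_size = 50, 8                            # 50 samples, one process per TPU core
n_padded = _round_to_multiple(n, world_size)     # 56: pad so every core gets the same count
idxs = list(range(n))                            # the real class uses the wrapped dl's (possibly shuffled) idxs
idxs += (idxs * (n_padded // n))[:n_padded - n]  # repeat the first 6 indices as padding
per_rank = n_padded // world_size                # 7 indices per core
shards = [idxs[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]

assert all(len(s) == per_rank for s in shards)   # every core iterates the same number of samples
assert shards[-1] == [49, 0, 1, 2, 3, 4, 5]      # the last core sees the 6 padded duplicates
# TPUDistributedDL.to_detach later trims these duplicates when gathering results across cores,
# so the padded samples do not leak into validation totals.
```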