├── .gitignore ├── LICENSE ├── README.md ├── train_1gpu.py ├── train_ddp.py ├── train_ddp_mixed_presicion.py ├── train_dp.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | data/* 6 | data*/ 7 | cifar* 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AI Summer 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-ddp 2 | code for the ddp tutorial 3 | 4 | 5 | # 1 gpu 6 | Accuracy of the network on the 10000 test images: 27 % 7 | Total elapsed time: 69.03 seconds, Train 1 epoch 13.08 seconds 8 | 9 | 10 | # DDP 11 | ``` 12 | python -m torch.distributed.launch --nproc_per_node=4 train_ddp.py 13 | ``` 14 | ## ddp 4gpus 15 | Accuracy of the network on the 10000 test images: 14 % 16 | Total elapsed time: 70.23 seconds, Train 1 epoch 6.11 seconds 17 | 18 | ## ddp 2gpus 19 | Accuracy of the network on the 10000 test images: 19 % 20 | Total elapsed time: 97.03 seconds, Train 1 epoch 9.79 seconds 21 | 22 | 23 | ## mixed precision ddp 4gpus 24 | Accuracy of the network on the 10000 test images: 15 % 25 | Total elapsed time: 70.61 seconds, Train 1 epoch 6.52 seconds 26 | -------------------------------------------------------------------------------- /train_1gpu.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mostly based on the official pytorch tutorial 3 | Link: https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html 4 | Modified for educational purposes. 
# train_1gpu.py
#
# Single-GPU CIFAR-10 baseline. Mostly based on the official PyTorch tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Modified for educational purposes.
# Nikolas, AI Summer 2022
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time


def create_data_loader_cifar10():
    """Build the CIFAR-10 train and test data loaders.

    The dataset is downloaded into ``./data`` when it is not already there.

    Returns:
        tuple: ``(trainloader, testloader)``, both with a batch size of 256.
        The train loader shuffles and augments; the test loader does neither.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # Evaluation must be reproducible, so the test set gets no random
    # augmentation (the original flipped test images at random).
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    batch_size = 256

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=10, pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=10)
    return trainloader, testloader


def train(net, trainloader):
    """Train ``net`` in place for one epoch with SGD and cross-entropy loss.

    Args:
        net: model whose parameters already live on the current CUDA device.
        trainloader: iterable yielding ``(inputs, labels)`` batches.
    """
    print("Start training...")
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for inputs, labels in trainloader:
            images, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')

    print('Finished Training')


def test(net, PATH, testloader):
    """Load the checkpoint at ``PATH`` into ``net`` and print its test accuracy.

    Args:
        net: model on the current CUDA device; its weights are overwritten.
        PATH: path of a state dict written by ``torch.save``.
        testloader: iterable yielding ``(images, labels)`` batches.
    """
    net.load_state_dict(torch.load(PATH))

    # Evaluate with BatchNorm running statistics instead of per-batch ones,
    # then restore whatever mode the caller had.
    was_training = net.training
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            # the class with the highest score is the prediction
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)

    acc = 100 * correct // total if total else 0
    print(f'Accuracy of the network on the 10000 test images: {acc} %')


def main():
    """Train ResNet-50 for one epoch on one GPU, save it, test it, report timings."""
    start = time.time()

    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()
    # No arguments: randomly initialised weights, without the deprecated
    # positional ``pretrained`` flag.
    net = torchvision.models.resnet50().cuda()
    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()
    # save
    torch.save(net.state_dict(), PATH)
    # test
    test(net, PATH, testloader)

    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, "
          f"Train 1 epoch {seconds_train:.2f} seconds")


if __name__ == '__main__':
    main()

# (end of train_1gpu.py; /train_ddp.py follows in this dump)
# train_ddp.py
#
# DistributedDataParallel CIFAR-10 training. Mostly based on the official
# PyTorch tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Modified for educational purposes.
# Nikolas, AI Summer
import os
gpu_list = "0,1,2,3"
# Must be set before torch initialises CUDA, hence before the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import time
import torchvision

from utils import setup_for_distributed, save_on_master, is_main_process


def create_data_loader_cifar10():
    """Build per-rank CIFAR-10 train and test loaders.

    Must be called after the process group is initialised: it uses a barrier
    and ``DistributedSampler``, which both need it.

    Returns:
        tuple: ``(trainloader, testloader)``. Each loader yields only this
        rank's shard of the dataset, in batches of 256.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # Evaluation must be reproducible: no random augmentation on the test set.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    batch_size = 256

    # Let rank 0 download first so the other ranks do not race with it while
    # it is writing ./data; they find the files already present afterwards.
    if not is_main_process():
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    if is_main_process():
        dist.barrier()

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=trainset, shuffle=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=16, pin_memory=True)

    # Shuffling the test shard serves no purpose.
    test_sampler = torch.utils.data.distributed.DistributedSampler(dataset=testset, shuffle=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, sampler=test_sampler, num_workers=16)
    return trainloader, testloader


def train(net, trainloader):
    """Train the DDP-wrapped ``net`` in place for one epoch.

    Args:
        net: model on this rank's CUDA device.
        trainloader: loader whose ``sampler`` provides ``set_epoch``.
    """
    print("Start training...")
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        # Re-seed the sampler so every epoch uses a different shuffle.
        trainloader.sampler.set_epoch(epoch)
        running_loss = 0.0
        for inputs, labels in trainloader:
            images, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # This is the local rank's loss (only rank 0 actually prints).
        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')

    print('Finished Training')


def test(net, PATH, testloader):
    """Print the accuracy of ``net`` over the test set, summed across ranks.

    Args:
        net: model on this rank's CUDA device, already holding trained weights.
        PATH: unused; kept so the signature matches the other scripts. The
            checkpoint is not reloaded here.
        testloader: loader yielding this rank's shard of the test set.
    """
    # Evaluate with BatchNorm running statistics, then restore the mode.
    was_training = net.training
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            # the class with the highest score is the prediction
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)

    # Each rank only scored its own shard; add the counts of all ranks so the
    # printed figure covers the whole test set. DistributedSampler repeats a
    # few samples when the dataset size is not divisible by the world size.
    counts = torch.tensor([correct, total], dtype=torch.long, device='cuda')
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    correct, total = counts.tolist()

    acc = 100 * correct // total if total else 0
    print(f'Accuracy of the network on the 10000 test images: {acc} %')


def init_distributed():
    """Initialise the NCCL process group from the launcher's environment.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``, which are set by
    ``torch.distributed.launch`` / ``torchrun``. It also selects this
    process's GPU and silences ``print`` on every rank but 0.

    Raises:
        KeyError: when the script was not started by such a launcher.
    """
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist_url = "env://"  # default

    # only works with torch.distributed.launch // torch.run
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)

    # this will make all .cuda() calls work properly
    torch.cuda.set_device(local_rank)
    # synchronizes all the processes to reach this point before moving on
    dist.barrier()
    setup_for_distributed(rank == 0)


def main():
    """Train ResNet-50 with DDP for one epoch, save on rank 0, test, report timings."""
    start = time.time()

    init_distributed()

    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()
    net = torchvision.models.resnet50().cuda()

    # Convert BatchNorm to SyncBatchNorm.
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)

    local_rank = int(os.environ['LOCAL_RANK'])
    net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()
    # save: save_on_master performs the rank-0 check itself. The former
    # ``if is_main_process:`` tested the function object and was always true.
    save_on_master(net.state_dict(), PATH)
    dist.barrier()

    # test
    test(net, PATH, testloader)

    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, "
          f"Train 1 epoch {seconds_train:.2f} seconds")

    dist.destroy_process_group()


if __name__ == '__main__':
    main()

# (end of train_ddp.py; /train_ddp_mixed_presicion.py follows in this dump)
# train_ddp_mixed_presicion.py
#
# DistributedDataParallel CIFAR-10 training with automatic mixed precision.
# Mostly based on the official PyTorch tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Modified for educational purposes.
# Nikolas, AI Summer
import os
gpu_list = "0,1,2,3"
# Must be set before torch initialises CUDA, hence before the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import time
import torchvision

from utils import setup_for_distributed, save_on_master, is_main_process


def create_data_loader_cifar10():
    """Build per-rank CIFAR-10 train and test loaders.

    Must be called after the process group is initialised: it uses a barrier
    and ``DistributedSampler``, which both need it.

    Returns:
        tuple: ``(trainloader, testloader)``. Each loader yields only this
        rank's shard of the dataset, in batches of 256.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # Evaluation must be reproducible: no random augmentation on the test set.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    batch_size = 256

    # Let rank 0 download first so the other ranks do not race with it while
    # it is writing ./data; they find the files already present afterwards.
    if not is_main_process():
        dist.barrier()
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    if is_main_process():
        dist.barrier()

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset=trainset, shuffle=True)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              sampler=train_sampler, num_workers=10, pin_memory=True)

    # Shuffling the test shard serves no purpose.
    test_sampler = torch.utils.data.distributed.DistributedSampler(dataset=testset, shuffle=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, sampler=test_sampler, num_workers=10)
    return trainloader, testloader


def train(net, trainloader):
    """Train the DDP-wrapped ``net`` in place for one epoch using AMP.

    The forward pass runs under autocast and the loss is scaled by a
    ``GradScaler`` before the backward pass.

    Args:
        net: model on this rank's CUDA device.
        trainloader: loader whose ``sampler`` provides ``set_epoch``.
    """
    print("Start training...")
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    fp16_scaler = torch.cuda.amp.GradScaler(enabled=True)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        # Re-seed the sampler so every epoch uses a different shuffle.
        trainloader.sampler.set_epoch(epoch)
        running_loss = 0.0
        for inputs, labels in trainloader:
            images, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            with torch.cuda.amp.autocast():
                outputs = net(images)
                loss = criterion(outputs, labels)

            # mixed precision training
            # backward + optimizer step
            fp16_scaler.scale(loss).backward()
            fp16_scaler.step(optimizer)
            fp16_scaler.update()

            running_loss += loss.item()

        # This is the local rank's loss (only rank 0 actually prints).
        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')
    print('Finished Training')


def test(net, PATH, testloader):
    """Reload the checkpoint on every rank and print the global test accuracy.

    Args:
        net: DDP-wrapped model on this rank's CUDA device; weights are
            overwritten from ``PATH``.
        PATH: state dict saved from the same DDP-wrapped model.
        testloader: loader yielding this rank's shard of the test set.
    """
    # Every rank loads (as before: the old ``if is_main_process:`` tested the
    # function object and was always true). map_location='cpu' stops each
    # process from deserialising the tensors onto the GPU that saved them;
    # load_state_dict then copies into this rank's own parameters.
    net.load_state_dict(torch.load(PATH, map_location='cpu'))
    dist.barrier()

    # Evaluate with BatchNorm running statistics, then restore the mode.
    was_training = net.training
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            # the class with the highest score is the prediction
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)

    # Each rank only scored its own shard; add the counts of all ranks so the
    # printed figure covers the whole test set. DistributedSampler repeats a
    # few samples when the dataset size is not divisible by the world size.
    counts = torch.tensor([correct, total], dtype=torch.long, device='cuda')
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    correct, total = counts.tolist()

    acc = 100 * correct // total if total else 0
    print(f'Accuracy of the network on the 10000 test images: {acc} %')


def init_distributed():
    """Initialise the NCCL process group from the launcher's environment.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``, which are set by
    ``torch.distributed.launch`` / ``torchrun``. It also selects this
    process's GPU and silences ``print`` on every rank but 0.

    Raises:
        KeyError: when the script was not started by such a launcher.
    """
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    dist_url = "env://"  # default
    # only works with torch.distributed.launch // torch.run
    rank = int(os.environ["RANK"])
    world_size = int(os.environ['WORLD_SIZE'])
    local_rank = int(os.environ['LOCAL_RANK'])

    dist.init_process_group(
        backend="nccl",
        init_method=dist_url,
        world_size=world_size,
        rank=rank)

    # this will make all .cuda() calls work properly
    torch.cuda.set_device(local_rank)
    # synchronizes all the processes to reach this point before moving on
    dist.barrier()
    setup_for_distributed(rank == 0)


def main():
    """Train ResNet-50 with DDP + AMP for one epoch, save on rank 0, test, report timings."""
    start = time.time()

    init_distributed()
    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()
    net = torchvision.models.resnet50().cuda()

    # Convert BatchNorm to SyncBatchNorm.
    net = nn.SyncBatchNorm.convert_sync_batchnorm(net)

    local_rank = int(os.environ['LOCAL_RANK'])
    net = nn.parallel.DistributedDataParallel(net, device_ids=[local_rank])

    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()
    # save: save_on_master performs the rank-0 check itself. The barrier makes
    # sure the file exists before any rank tries to read it in test().
    save_on_master(net.state_dict(), PATH)
    dist.barrier()

    # test
    test(net, PATH, testloader)

    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, "
          f"Train 1 epoch {seconds_train:.2f} seconds")

    dist.destroy_process_group()


if __name__ == '__main__':
    main()

# (end of train_ddp_mixed_presicion.py; /train_dp.py follows in this dump)
# train_dp.py
#
# nn.DataParallel CIFAR-10 training. Mostly based on the official PyTorch
# tutorial:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Modified for educational purposes.
# Nikolas, AI Summer
import os
# Must be set before torch initialises CUDA, hence before the torch import.
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1,2,3"
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import torchvision


def create_data_loader_cifar10():
    """Build the CIFAR-10 train and test data loaders.

    The dataset is downloaded into ``./data`` when it is not already there.

    Returns:
        tuple: ``(trainloader, testloader)``, both with a batch size of
        ``256 * 4``; DataParallel splits each batch across the visible GPUs.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
    # Evaluation must be reproducible: no random augmentation on the test set.
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    batch_size = 256*4

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=train_transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=20, pin_memory=True)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                           download=True, transform=test_transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=20)
    return trainloader, testloader


def train(net, trainloader):
    """Train ``net`` in place for one epoch with SGD and cross-entropy loss.

    Args:
        net: model on CUDA, optionally wrapped in ``nn.DataParallel``.
        trainloader: iterable yielding ``(images, labels)`` batches.
    """
    print("Start training...")
    net.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    epochs = 1
    num_of_batches = len(trainloader)
    for epoch in range(epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for images, labels in trainloader:
            # Move the inputs as well as the labels. DataParallel would
            # scatter CPU inputs on its own, but when a single GPU is visible
            # the model is not wrapped and CPU inputs make the forward pass fail.
            images, labels = images.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f'[Epoch {epoch + 1}/{epochs}] loss: {running_loss / num_of_batches:.3f}')

    print('Finished Training')


def test(net, PATH, testloader):
    """Load the checkpoint at ``PATH`` into ``net`` and print its test accuracy.

    Args:
        net: model on CUDA, wrapped the same way as when ``PATH`` was saved.
        PATH: path of a state dict written by ``torch.save``.
        testloader: iterable yielding ``(images, labels)`` batches.
    """
    net.load_state_dict(torch.load(PATH))

    # Evaluate with BatchNorm running statistics, then restore the mode.
    was_training = net.training
    net.eval()

    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.cuda(), labels.cuda()
            # the class with the highest score is the prediction
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    net.train(was_training)

    acc = 100 * correct // total if total else 0
    print(f'Accuracy of the network on the 10000 test images: {acc} %')


def main():
    """Train ResNet-50 with DataParallel for one epoch, save, test, report timings."""
    start = time.time()

    PATH = './cifar_net.pth'
    trainloader, testloader = create_data_loader_cifar10()

    net = torchvision.models.resnet50()
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # Batch size should be divisible by number of GPUs
        net = nn.DataParallel(net)

    net.cuda()

    start_train = time.time()
    train(net, trainloader)
    end_train = time.time()
    # save
    torch.save(net.state_dict(), PATH)
    # test
    test(net, PATH, testloader)

    end = time.time()
    seconds = (end - start)
    seconds_train = (end_train - start_train)
    print(f"Total elapsed time: {seconds:.2f} seconds, "
          f"Train 1 epoch {seconds_train:.2f} seconds")


if __name__ == '__main__':
    main()

# (end of train_dp.py; /utils.py follows in this dump)
# utils.py
#
# Mostly copy-paste from torchvision references or other public repos like DETR:
# https://github.com/facebookresearch/detr/blob/master/util/misc.py

import torch.distributed as dist
import torch


def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and initialised."""
    return dist.is_available() and dist.is_initialized()


def get_world_size():
    """Return the number of processes in the group, or 1 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Return this process's rank, or 0 outside distributed mode."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """Return True on rank 0.

    This is a function: call it (``is_main_process()``). The bare name is
    always truthy.
    """
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    """Forward the arguments to ``torch.save`` on rank 0; do nothing elsewhere."""
    if is_main_process():
        torch.save(*args, **kwargs)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process.

    It replaces ``builtins.print`` with a wrapper that only prints when
    ``is_master`` is true or when the caller passes ``force=True``.

    Args:
        is_master: whether the calling process is allowed to print.
    """
    import builtins as __builtin__
    # Always wrap the genuine print. If an earlier call already installed a
    # wrapper, reuse the function it remembered; otherwise a muted wrapper
    # would be wrapped again and printing could never be switched back on.
    builtin_print = getattr(__builtin__.print, '_builtin_print', __builtin__.print)

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    print._builtin_print = builtin_print
    __builtin__.print = print