├── .gitignore ├── 01-2_pytorch-fabric.py ├── 01_pytorch-vit.py ├── 02_mixed-precision.py ├── 03_bfloat16.py ├── 04_lower-batchsize.py ├── 05_gradient-accum.py ├── 06_sgd-with-scheduler.py ├── 07_01_init-module.py ├── 07_02_init-module.py ├── 07_03_init-module.py ├── 08-10-vit32 ├── 08_baseline.py ├── 08a_fsdp-defaults.py ├── 08b_fsdp-custom.py ├── 08c_fsdp-size-wrap.py ├── 09_fsdp-act-checkp.py ├── 10_fsdp-with-cpu-offload.py ├── 10b_fsdp-with-cpu-offload-no-act-check.py └── local_utilities.py ├── 08_baseline.py ├── 08a_fsdp-defaults.py ├── 08b_fsdp-custom.py ├── 08c_fsdp-size-wrap.py ├── 09_fsdp-act-checkp.py ├── 10_fsdp-with-cpu-offload.py ├── 10b_fsdp-with-cpu-offload-no-act-check.py ├── 11_delay-allocation.py ├── 12_fsdp-overlap.py ├── LICENSE.txt ├── README.md ├── bonus_bigbird-after.py ├── bonus_bigbird-before.py ├── bonus_distilbert-after.py ├── bonus_distilbert-before.py ├── figures └── overview.png ├── local_utilities.py ├── logs.md └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | data/ 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # pdm 108 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
109 | #pdm.lock 110 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 111 | # in version control. 112 | # https://pdm.fming.dev/#use-with-ide 113 | .pdm.toml 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    """Finetune ``model`` for ``num_epochs`` epochs, printing loss and accuracy.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Model prepared via ``fabric.setup``; updated in place.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader, val_loader :
        Dataloaders prepared via ``fabric.setup_dataloaders`` (batches are
        already placed on ``fabric.device``).
    fabric : lightning.Fabric
        Provides the backward pass and the device for the metrics.
    """
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            ### FORWARD AND BACK PROP
            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            fabric.backward(loss)

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # Track training accuracy from the logits already computed this
            # step. No forward pass happens here, so the original per-batch
            # model.eval()/model.train() toggling had no effect and is removed.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        ### VALIDATION
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

            for features, targets in val_loader:
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")
        # Two plain statements — the original's `train_acc.reset(), val_acc.reset()`
        # built an accidental throwaway tuple.
        train_acc.reset()
        val_acc.reset()
def train(num_epochs, model, optimizer, train_loader, val_loader, device):
    """Finetune ``model`` for ``num_epochs`` epochs, printing loss and accuracy.

    Plain-PyTorch variant: each batch is moved to ``device`` explicitly and
    ``loss.backward()`` is used directly (no Fabric).

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Model already on ``device``; updated in place.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader, val_loader : torch.utils.data.DataLoader
        CIFAR-10 train/validation loaders.
    device : str or torch.device
        Target device for batches and metrics.
    """
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            features = features.to(device)
            targets = targets.to(device)

            ### FORWARD AND BACK PROP
            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            loss.backward()

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # Reuse this step's logits for the running accuracy. There is no
            # forward pass here, so the original's per-batch
            # model.eval()/model.train() flip was a no-op and is removed.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        ### VALIDATION
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

            for features, targets in val_loader:
                features = features.to(device)
                targets = targets.to(device)
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")
        # Two plain statements instead of the original accidental tuple
        # expression `train_acc.reset(), val_acc.reset()`.
        train_acc.reset()
        val_acc.reset()
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    """Finetune ``model`` under Fabric mixed precision, printing loss/accuracy.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Model prepared via ``fabric.setup``; updated in place.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader, val_loader :
        Dataloaders prepared via ``fabric.setup_dataloaders``.
    fabric : lightning.Fabric
        Handles the (scaled) backward pass and supplies the metric device.
    """
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            ### FORWARD AND BACK PROP
            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            fabric.backward(loss)

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # Accuracy is computed from logits that already exist; no forward
            # pass occurs, so the original per-batch eval/train mode toggling
            # had no effect and is dropped.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        ### VALIDATION
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

            for features, targets in val_loader:
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")
        # Separate statements — the original built an accidental tuple here.
        train_acc.reset()
        val_acc.reset()
val_loader, test_loader = fabric.setup_dataloaders( 89 | train_loader, val_loader, test_loader) 90 | 91 | 92 | ######################################### 93 | ### 2 Initializing the Model 94 | ######################################### 95 | 96 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 97 | 98 | # replace output layer 99 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 100 | 101 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 102 | model, optimizer = fabric.setup(model, optimizer) 103 | 104 | ######################################### 105 | ### 3 Finetuning 106 | ######################################### 107 | 108 | start = time.time() 109 | train( 110 | num_epochs=1, 111 | model=model, 112 | optimizer=optimizer, 113 | train_loader=train_loader, 114 | val_loader=val_loader, 115 | fabric=fabric 116 | ) 117 | 118 | end = time.time() 119 | elapsed = end-start 120 | print(f"Time elapsed {elapsed/60:.2f} min") 121 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 122 | 123 | ######################################### 124 | ### 4 Evaluation 125 | ######################################### 126 | 127 | with torch.no_grad(): 128 | model.eval() 129 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 130 | 131 | for (features, targets) in test_loader: 132 | outputs = model(features) 133 | predicted_labels = torch.argmax(outputs, 1) 134 | test_acc.update(predicted_labels, targets) 135 | 136 | print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /03_bfloat16.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import lightning as L 4 | from lightning import Fabric 5 | import torch 6 | import torch.nn.functional as F 7 | import torchmetrics 8 | from torchvision import transforms 9 | from torchvision.models import vit_l_16 10 | from 
torchvision.models import ViT_L_16_Weights 11 | from watermark import watermark 12 | 13 | from local_utilities import get_dataloaders_cifar10 14 | 15 | 16 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 17 | 18 | for epoch in range(num_epochs): 19 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 20 | 21 | model.train() 22 | for batch_idx, (features, targets) in enumerate(train_loader): 23 | model.train() 24 | 25 | ### FORWARD AND BACK PROP 26 | logits = model(features) 27 | loss = F.cross_entropy(logits, targets) 28 | 29 | optimizer.zero_grad() 30 | fabric.backward(loss) 31 | 32 | ### UPDATE MODEL PARAMETERS 33 | optimizer.step() 34 | 35 | ### LOGGING 36 | if not batch_idx % 300: 37 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | predicted_labels = torch.argmax(logits, 1) 42 | train_acc.update(predicted_labels, targets) 43 | 44 | ### MORE LOGGING 45 | model.eval() 46 | with torch.no_grad(): 47 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 48 | 49 | for (features, targets) in val_loader: 50 | outputs = model(features) 51 | predicted_labels = torch.argmax(outputs, 1) 52 | val_acc.update(predicted_labels, targets) 53 | 54 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 55 | train_acc.reset(), val_acc.reset() 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | print(watermark(packages="torch,lightning", python=True)) 61 | print("Torch CUDA available?", torch.cuda.is_available()) 62 | 63 | fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-true") 64 | fabric.launch() 65 | 66 | torch.set_float32_matmul_precision('high') 67 | 68 | L.seed_everything(123) 69 | 70 | ########################## 71 | ### 1 Loading the Dataset 72 | 
########################## 73 | train_transforms = transforms.Compose([transforms.Resize((224, 224)), 74 | #transforms.RandomCrop((224, 224)), 75 | transforms.ToTensor()]) 76 | 77 | test_transforms = transforms.Compose([transforms.Resize((224, 224)), 78 | #transforms.CenterCrop((224, 224)), 79 | transforms.ToTensor()]) 80 | 81 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 82 | batch_size=64, 83 | num_workers=3, 84 | train_transforms=train_transforms, 85 | test_transforms=test_transforms, 86 | validation_fraction=0.1) 87 | 88 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 89 | train_loader, val_loader, test_loader) 90 | 91 | 92 | ######################################### 93 | ### 2 Initializing the Model 94 | ######################################### 95 | 96 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 97 | 98 | # replace output layer 99 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 100 | 101 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 102 | model, optimizer = fabric.setup(model, optimizer) 103 | 104 | ######################################### 105 | ### 3 Finetuning 106 | ######################################### 107 | 108 | start = time.time() 109 | train( 110 | num_epochs=1, 111 | model=model, 112 | optimizer=optimizer, 113 | train_loader=train_loader, 114 | val_loader=val_loader, 115 | fabric=fabric 116 | ) 117 | 118 | end = time.time() 119 | elapsed = end-start 120 | print(f"Time elapsed {elapsed/60:.2f} min") 121 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 122 | 123 | ######################################### 124 | ### 4 Evaluation 125 | ######################################### 126 | 127 | with torch.no_grad(): 128 | model.eval() 129 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 130 | 131 | for (features, targets) in test_loader: 132 | outputs = model(features) 133 | predicted_labels = 
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric):
    """Finetune ``model`` (smaller-batch experiment), printing loss/accuracy.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Model prepared via ``fabric.setup``; updated in place.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader, val_loader :
        Dataloaders prepared via ``fabric.setup_dataloaders``.
    fabric : lightning.Fabric
        Handles the backward pass and supplies the metric device.
    """
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            ### FORWARD AND BACK PROP
            logits = model(features)
            loss = F.cross_entropy(logits, targets)

            optimizer.zero_grad()
            fabric.backward(loss)

            ### UPDATE MODEL PARAMETERS
            optimizer.step()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # The accuracy update reuses this step's logits — no forward pass —
            # so the original's per-batch eval/train mode flip did nothing and
            # is removed.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        ### VALIDATION
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

            for features, targets in val_loader:
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")
        # Plain statements — the original's trailing comma made a useless tuple.
        train_acc.reset()
        val_acc.reset()
def train(num_epochs, model, optimizer, train_loader, val_loader, fabric, accumulation_steps):
    """Finetune ``model`` with gradient accumulation, printing loss/accuracy.

    Parameters
    ----------
    num_epochs : int
        Number of passes over ``train_loader``.
    model : torch.nn.Module
        Model prepared via ``fabric.setup``; updated in place.
    optimizer : torch.optim.Optimizer
        Optimizer over ``model``'s parameters.
    train_loader, val_loader :
        Dataloaders prepared via ``fabric.setup_dataloaders``; each batch is a
        micro-batch of size (effective batch size / ``accumulation_steps``).
    fabric : lightning.Fabric
        Handles the backward pass and supplies the metric device.
    accumulation_steps : int
        Number of micro-batches whose gradients are summed before one
        optimizer step; the loss is pre-scaled by 1/``accumulation_steps``.
    """
    for epoch in range(num_epochs):
        train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

        model.train()
        for batch_idx, (features, targets) in enumerate(train_loader):

            ### FORWARD AND BACK PROP
            # Scale so the summed micro-batch gradients average out.
            logits = model(features)
            loss = F.cross_entropy(logits, targets) / accumulation_steps

            fabric.backward(loss)

            ### UPDATE MODEL PARAMETERS
            # BUGFIX: the original tested `batch_idx % accumulation_steps == 0`,
            # which stepped after only the first micro-batch (batch_idx == 0)
            # and silently dropped the final partial accumulation of the epoch.
            # Step after every full group of `accumulation_steps` micro-batches
            # and flush whatever remains on the last batch.
            if (batch_idx + 1) % accumulation_steps == 0 or (batch_idx + 1) == len(train_loader):
                optimizer.step()
                optimizer.zero_grad()

            ### LOGGING
            if not batch_idx % 300:
                print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}")

            # Reuse this step's logits for the running accuracy; no forward
            # pass happens here, so no eval/train mode toggling is needed.
            with torch.no_grad():
                train_acc.update(torch.argmax(logits, 1), targets)

        ### VALIDATION
        model.eval()
        with torch.no_grad():
            val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device)

            for features, targets in val_loader:
                outputs = model(features)
                val_acc.update(torch.argmax(outputs, 1), targets)

        print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%")
        train_acc.reset()
        val_acc.reset()
test_loader) 96 | 97 | 98 | ######################################### 99 | ### 2 Initializing the Model 100 | ######################################### 101 | 102 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 103 | 104 | # replace output layer 105 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 106 | 107 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 108 | model, optimizer = fabric.setup(model, optimizer) 109 | 110 | ######################################### 111 | ### 3 Finetuning 112 | ######################################### 113 | 114 | start = time.time() 115 | train( 116 | num_epochs=1, 117 | model=model, 118 | optimizer=optimizer, 119 | train_loader=train_loader, 120 | val_loader=val_loader, 121 | fabric=fabric, 122 | accumulation_steps=ACCUMULATION_STEPS 123 | ) 124 | 125 | end = time.time() 126 | elapsed = end-start 127 | print(f"Time elapsed {elapsed/60:.2f} min") 128 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 129 | 130 | ######################################### 131 | ### 4 Evaluation 132 | ######################################### 133 | 134 | with torch.no_grad(): 135 | model.eval() 136 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 137 | 138 | for (features, targets) in test_loader: 139 | outputs = model(features) 140 | predicted_labels = torch.argmax(outputs, 1) 141 | test_acc.update(predicted_labels, targets) 142 | 143 | print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /06_sgd-with-scheduler.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import lightning as L 4 | from lightning import Fabric 5 | import torch 6 | import torch.nn.functional as F 7 | import torchmetrics 8 | from torchvision import transforms 9 | from torchvision.models import vit_l_16 10 | from torchvision.models import 
ViT_L_16_Weights 11 | from watermark import watermark 12 | 13 | from local_utilities import get_dataloaders_cifar10 14 | 15 | 16 | def train(num_epochs, model, optimizer, scheduler, train_loader, val_loader, fabric, accumulation_steps): 17 | 18 | for epoch in range(num_epochs): 19 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 20 | 21 | model.train() 22 | for batch_idx, (features, targets) in enumerate(train_loader): 23 | model.train() 24 | 25 | ### FORWARD AND BACK PROP 26 | logits = model(features) 27 | loss = F.cross_entropy(logits, targets) / accumulation_steps # NEW 28 | 29 | fabric.backward(loss) 30 | 31 | 32 | ### UPDATE MODEL PARAMETERS 33 | if batch_idx % accumulation_steps == 0: # NEW 34 | optimizer.step() 35 | scheduler.step() 36 | optimizer.zero_grad() 37 | 38 | ### LOGGING 39 | if not batch_idx % 300: 40 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 41 | 42 | model.eval() 43 | with torch.no_grad(): 44 | predicted_labels = torch.argmax(logits, 1) 45 | train_acc.update(predicted_labels, targets) 46 | 47 | ### MORE LOGGING 48 | model.eval() 49 | with torch.no_grad(): 50 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 51 | 52 | for (features, targets) in val_loader: 53 | outputs = model(features) 54 | predicted_labels = torch.argmax(outputs, 1) 55 | val_acc.update(predicted_labels, targets) 56 | 57 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 58 | train_acc.reset(), val_acc.reset() 59 | 60 | 61 | if __name__ == "__main__": 62 | 63 | print(watermark(packages="torch,lightning", python=True)) 64 | print("Torch CUDA available?", torch.cuda.is_available()) 65 | 66 | fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-true") 67 | fabric.launch() 68 | 69 | torch.set_float32_matmul_precision('high') 70 | 71 | 
L.seed_everything(123) 72 | 73 | ########################## 74 | ### 1 Loading the Dataset 75 | ########################## 76 | train_transforms = transforms.Compose([transforms.Resize((224, 224)), 77 | #transforms.RandomCrop((224, 224)), 78 | transforms.ToTensor()]) 79 | 80 | test_transforms = transforms.Compose([transforms.Resize((224, 224)), 81 | #transforms.CenterCrop((224, 224)), 82 | transforms.ToTensor()]) 83 | 84 | NUM_EPOCHS = 1 85 | BATCHSIZE = 16 86 | ACCUMULATION_STEPS = 4 87 | MICROBATCHSIZE = int(BATCHSIZE / ACCUMULATION_STEPS) 88 | 89 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 90 | batch_size=MICROBATCHSIZE, 91 | num_workers=3, 92 | train_transforms=train_transforms, 93 | test_transforms=test_transforms, 94 | validation_fraction=0.1) 95 | 96 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 97 | train_loader, val_loader, test_loader) 98 | 99 | 100 | ######################################### 101 | ### 2 Initializing the Model 102 | ######################################### 103 | 104 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 105 | 106 | # replace output layer 107 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 108 | 109 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) 110 | 111 | num_steps = NUM_EPOCHS * len(train_loader) 112 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps) 113 | model, optimizer = fabric.setup(model, optimizer) 114 | 115 | ######################################### 116 | ### 3 Finetuning 117 | ######################################### 118 | 119 | start = time.time() 120 | train( 121 | num_epochs=NUM_EPOCHS, 122 | model=model, 123 | optimizer=optimizer, 124 | scheduler=scheduler, 125 | train_loader=train_loader, 126 | val_loader=val_loader, 127 | fabric=fabric, 128 | accumulation_steps=ACCUMULATION_STEPS 129 | ) 130 | 131 | end = time.time() 132 | elapsed = end-start 133 | print(f"Time elapsed {elapsed/60:.2f} 
min") 134 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 135 | 136 | ######################################### 137 | ### 4 Evaluation 138 | ######################################### 139 | 140 | with torch.no_grad(): 141 | model.eval() 142 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 143 | 144 | for (features, targets) in test_loader: 145 | outputs = model(features) 146 | predicted_labels = torch.argmax(outputs, 1) 147 | test_acc.update(predicted_labels, targets) 148 | 149 | print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /07_01_init-module.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import psutil 3 | 4 | import torch 5 | from lightning.fabric import Fabric 6 | 7 | from torchvision.models import vit_l_16 8 | from torchvision.models import ViT_L_16_Weights 9 | from watermark import watermark 10 | 11 | 12 | if __name__ == "__main__": 13 | print(watermark(packages="torch,lightning", python=True)) 14 | print("Torch CUDA available?", torch.cuda.is_available()) 15 | 16 | 17 | print("Without Fabric") 18 | 19 | cpu_memory_before = psutil.Process().memory_info().rss / (1024**3) 20 | 21 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 22 | model.to(torch.device("cuda")).to(torch.float16) 23 | 24 | cpu_ram_after = psutil.Process().memory_info().rss / (1024**3) 25 | 26 | print(f"CPU Memory used: {cpu_ram_after - cpu_memory_before / 1e9:.02f} GB") 27 | print(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /07_02_init-module.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import psutil 3 | 4 | import torch 5 | from lightning.fabric import Fabric 6 | 7 | from torchvision.models import vit_l_16 8 | from 
torchvision.models import ViT_L_16_Weights 9 | from watermark import watermark 10 | 11 | 12 | if __name__ == "__main__": 13 | print(watermark(packages="torch,lightning", python=True)) 14 | print("Torch CUDA available?", torch.cuda.is_available()) 15 | 16 | 17 | print("Without init_module") 18 | 19 | cpu_memory_before = psutil.Process().memory_info().rss / (1024**3) 20 | 21 | fabric = Fabric(accelerator="cuda", devices=1, precision="16-true") 22 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 23 | model = fabric.setup(model) 24 | 25 | cpu_ram_after = psutil.Process().memory_info().rss / (1024**3) 26 | 27 | print(f"CPU Memory used: {cpu_ram_after - cpu_memory_before / 1e9:.02f} GB") 28 | print(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /07_03_init-module.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import psutil 3 | 4 | import torch 5 | from lightning.fabric import Fabric 6 | 7 | from torchvision.models import vit_l_16 8 | from torchvision.models import ViT_L_16_Weights 9 | from watermark import watermark 10 | 11 | 12 | if __name__ == "__main__": 13 | print(watermark(packages="torch,lightning", python=True)) 14 | print("Torch CUDA available?", torch.cuda.is_available()) 15 | 16 | 17 | print("With init_module") 18 | 19 | cpu_memory_before = psutil.Process().memory_info().rss / (1024**3) 20 | 21 | fabric = Fabric(accelerator="cuda", devices=1, precision="16-true") 22 | fabric.launch() 23 | 24 | with fabric.init_module(): 25 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 26 | 27 | cpu_ram_after = psutil.Process().memory_info().rss / (1024**3) 28 | 29 | print(f"CPU Memory used: {cpu_ram_after - cpu_memory_before / 1e9:.02f} GB") 30 | print(f"GPU Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 31 | 
-------------------------------------------------------------------------------- /08-10-vit32/08_baseline.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import lightning as L 4 | from lightning import Fabric 5 | import torch 6 | import torch.nn.functional as F 7 | import torchmetrics 8 | from torchvision import transforms 9 | from torchvision.models import vit_h_14 10 | from torchvision.models import ViT_H_14_Weights 11 | from watermark import watermark 12 | 13 | from local_utilities import get_dataloaders_cifar10 14 | 15 | 16 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 17 | 18 | for epoch in range(num_epochs): 19 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 20 | 21 | model.train() 22 | for batch_idx, (features, targets) in enumerate(train_loader): 23 | model.train() 24 | 25 | ### FORWARD AND BACK PROP 26 | logits = model(features) 27 | loss = F.cross_entropy(logits, targets) 28 | 29 | optimizer.zero_grad() 30 | fabric.backward(loss) 31 | 32 | ### UPDATE MODEL PARAMETERS 33 | optimizer.step() 34 | 35 | ### LOGGING 36 | if not batch_idx % 300: 37 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | predicted_labels = torch.argmax(logits, 1) 42 | train_acc.update(predicted_labels, targets) 43 | 44 | ### MORE LOGGING 45 | model.eval() 46 | with torch.no_grad(): 47 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 48 | 49 | for (features, targets) in val_loader: 50 | outputs = model(features) 51 | predicted_labels = torch.argmax(outputs, 1) 52 | val_acc.update(predicted_labels, targets) 53 | 54 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 55 | train_acc.reset(), val_acc.reset() 56 | 57 | 58 
| if __name__ == "__main__": 59 | 60 | print(watermark(packages="torch,lightning", python=True)) 61 | print("Torch CUDA available?", torch.cuda.is_available()) 62 | 63 | fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed") 64 | fabric.launch() 65 | 66 | L.seed_everything(123) 67 | 68 | ########################## 69 | ### 1 Loading the Dataset 70 | ########################## 71 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 72 | #transforms.RandomCrop((224, 224)), 73 | transforms.ToTensor()]) 74 | 75 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 76 | #transforms.CenterCrop((224, 224)), 77 | transforms.ToTensor()]) 78 | 79 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 80 | batch_size=4, 81 | num_workers=1, 82 | train_transforms=train_transforms, 83 | test_transforms=test_transforms, 84 | validation_fraction=0.1) 85 | 86 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 87 | train_loader, val_loader, test_loader) 88 | 89 | 90 | ######################################### 91 | ### 2 Initializing the Model 92 | ######################################### 93 | 94 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 95 | 96 | # replace output layer 97 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 98 | 99 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 100 | model, optimizer = fabric.setup(model, optimizer) 101 | 102 | ######################################### 103 | ### 3 Finetuning 104 | ######################################### 105 | 106 | start = time.time() 107 | train( 108 | num_epochs=1, 109 | model=model, 110 | optimizer=optimizer, 111 | train_loader=train_loader, 112 | val_loader=val_loader, 113 | fabric=fabric 114 | ) 115 | 116 | end = time.time() 117 | elapsed = end-start 118 | print(f"Time elapsed {elapsed/60:.2f} min") 119 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 120 | 121 | 
######################################### 122 | ### 4 Evaluation 123 | ######################################### 124 | 125 | with torch.no_grad(): 126 | model.eval() 127 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 128 | 129 | for (features, targets) in test_loader: 130 | outputs = model(features) 131 | predicted_labels = torch.argmax(outputs, 1) 132 | test_acc.update(predicted_labels, targets) 133 | 134 | print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08-10-vit32/08a_fsdp-defaults.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: 
{epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | fabric = Fabric(accelerator="cuda", devices=4, strategy="fsdp", precision="16-mixed") 65 | fabric.launch() 66 | 67 | L.seed_everything(123) 68 | fabric.print(watermark(packages="torch,lightning", python=True)) 69 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 70 | 71 | 72 | ########################## 73 | ### 1 Loading the Dataset 74 | ########################## 75 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 76 | #transforms.RandomCrop((224, 224)), 77 | transforms.ToTensor()]) 78 | 79 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 80 | #transforms.CenterCrop((224, 224)), 81 | transforms.ToTensor()]) 82 | 83 | with fabric.rank_zero_first(): 84 | # the above prevents race conditions when multiple processes 85 | # try to download and write the dataset to disk. 
86 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 87 | batch_size=4, 88 | num_workers=1, 89 | train_transforms=train_transforms, 90 | test_transforms=test_transforms, 91 | validation_fraction=0.1) 92 | 93 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 94 | train_loader, val_loader, test_loader) 95 | 96 | 97 | ######################################### 98 | ### 2 Initializing the Model 99 | ######################################### 100 | 101 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 102 | 103 | # replace output layer 104 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 105 | 106 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 107 | model, optimizer = fabric.setup(model, optimizer) 108 | 109 | ######################################### 110 | ### 3 Finetuning 111 | ######################################### 112 | 113 | start = time.time() 114 | train( 115 | num_epochs=1, 116 | model=model, 117 | optimizer=optimizer, 118 | train_loader=train_loader, 119 | val_loader=val_loader, 120 | fabric=fabric 121 | ) 122 | 123 | end = time.time() 124 | elapsed = end-start 125 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 126 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 127 | 128 | ######################################### 129 | ### 4 Evaluation 130 | ######################################### 131 | 132 | with torch.no_grad(): 133 | model.eval() 134 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 135 | 136 | for (features, targets) in test_loader: 137 | outputs = model(features) 138 | predicted_labels = torch.argmax(outputs, 1) 139 | test_acc.update(predicted_labels, targets) 140 | 141 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08-10-vit32/08b_fsdp-custom.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | with fabric.rank_zero_first(): 86 | # the above prevents race conditions when multiple processes 87 | # try to download and write the dataset to disk. 
88 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 89 | batch_size=4, 90 | num_workers=1, 91 | train_transforms=train_transforms, 92 | test_transforms=test_transforms, 93 | validation_fraction=0.1) 94 | 95 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 96 | train_loader, val_loader, test_loader) 97 | 98 | 99 | ######################################### 100 | ### 2 Initializing the Model 101 | ######################################### 102 | 103 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 104 | 105 | # replace output layer 106 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 107 | 108 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 109 | model, optimizer = fabric.setup(model, optimizer) 110 | 111 | ######################################### 112 | ### 3 Finetuning 113 | ######################################### 114 | 115 | start = time.time() 116 | train( 117 | num_epochs=1, 118 | model=model, 119 | optimizer=optimizer, 120 | train_loader=train_loader, 121 | val_loader=val_loader, 122 | fabric=fabric 123 | ) 124 | 125 | end = time.time() 126 | elapsed = end-start 127 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 128 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 129 | 130 | ######################################### 131 | ### 4 Evaluation 132 | ######################################### 133 | 134 | with torch.no_grad(): 135 | model.eval() 136 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 137 | 138 | for (features, targets) in test_loader: 139 | outputs = model(features) 140 | predicted_labels = torch.argmax(outputs, 1) 141 | test_acc.update(predicted_labels, targets) 142 | 143 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08-10-vit32/08c_fsdp-size-wrap.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(size_based_auto_wrap_policy, module={EncoderBlock}, min_num_params=2_000_000) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | with fabric.rank_zero_first(): 87 | # the above prevents race conditions when multiple processes 88 | # try to download and write the dataset to disk. 
89 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 90 | batch_size=4, 91 | num_workers=1, 92 | train_transforms=train_transforms, 93 | test_transforms=test_transforms, 94 | validation_fraction=0.1) 95 | 96 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 97 | train_loader, val_loader, test_loader) 98 | 99 | 100 | ######################################### 101 | ### 2 Initializing the Model 102 | ######################################### 103 | 104 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 105 | 106 | # replace output layer 107 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 108 | 109 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 110 | model, optimizer = fabric.setup(model, optimizer) 111 | 112 | ######################################### 113 | ### 3 Finetuning 114 | ######################################### 115 | 116 | start = time.time() 117 | train( 118 | num_epochs=1, 119 | model=model, 120 | optimizer=optimizer, 121 | train_loader=train_loader, 122 | val_loader=val_loader, 123 | fabric=fabric 124 | ) 125 | 126 | end = time.time() 127 | elapsed = end-start 128 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 129 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 130 | 131 | ######################################### 132 | ### 4 Evaluation 133 | ######################################### 134 | 135 | with torch.no_grad(): 136 | model.eval() 137 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 138 | 139 | for (features, targets) in test_loader: 140 | outputs = model(features) 141 | predicted_labels = torch.argmax(outputs, 1) 142 | test_acc.update(predicted_labels, targets) 143 | 144 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") 145 | -------------------------------------------------------------------------------- /08-10-vit32/09_fsdp-act-checkp.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=EncoderBlock) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | with fabric.rank_zero_first(): 87 | # the above prevents race conditions when multiple processes 88 | # try to download and write the dataset to disk. 
89 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 90 | batch_size=4, 91 | num_workers=1, 92 | train_transforms=train_transforms, 93 | test_transforms=test_transforms, 94 | validation_fraction=0.1) 95 | 96 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 97 | train_loader, val_loader, test_loader) 98 | 99 | 100 | ######################################### 101 | ### 2 Initializing the Model 102 | ######################################### 103 | 104 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 105 | 106 | # replace output layer 107 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 108 | 109 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 110 | model, optimizer = fabric.setup(model, optimizer) 111 | 112 | ######################################### 113 | ### 3 Finetuning 114 | ######################################### 115 | 116 | start = time.time() 117 | train( 118 | num_epochs=1, 119 | model=model, 120 | optimizer=optimizer, 121 | train_loader=train_loader, 122 | val_loader=val_loader, 123 | fabric=fabric 124 | ) 125 | 126 | end = time.time() 127 | elapsed = end-start 128 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 129 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 130 | 131 | ######################################### 132 | ### 4 Evaluation 133 | ######################################### 134 | 135 | with torch.no_grad(): 136 | model.eval() 137 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 138 | 139 | for (features, targets) in test_loader: 140 | outputs = model(features) 141 | predicted_labels = torch.argmax(outputs, 1) 142 | test_acc.update(predicted_labels, targets) 143 | 144 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08-10-vit32/10_fsdp-with-cpu-offload.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | activation_checkpointing=EncoderBlock, 68 | cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | with fabric.rank_zero_first(): 91 | # the above prevents race conditions when multiple processes 92 | # try to download and write the dataset to disk. 
93 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 94 | batch_size=4, 95 | num_workers=1, 96 | train_transforms=train_transforms, 97 | test_transforms=test_transforms, 98 | validation_fraction=0.1) 99 | 100 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 101 | train_loader, val_loader, test_loader) 102 | 103 | 104 | ######################################### 105 | ### 2 Initializing the Model 106 | ######################################### 107 | 108 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 109 | 110 | # replace output layer 111 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 112 | 113 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 114 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 115 | 116 | ######################################### 117 | ### 3 Finetuning 118 | ######################################### 119 | 120 | start = time.time() 121 | train( 122 | num_epochs=1, 123 | model=model, 124 | optimizer=optimizer, 125 | train_loader=train_loader, 126 | val_loader=val_loader, 127 | fabric=fabric 128 | ) 129 | 130 | end = time.time() 131 | elapsed = end-start 132 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 133 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 134 | 135 | ######################################### 136 | ### 4 Evaluation 137 | ######################################### 138 | 139 | with torch.no_grad(): 140 | model.eval() 141 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 142 | 143 | for (features, targets) in test_loader: 144 | outputs = model(features) 145 | predicted_labels = torch.argmax(outputs, 1) 146 | test_acc.update(predicted_labels, targets) 147 | 148 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- 
/08-10-vit32/10b_fsdp-with-cpu-offload-no-act-check.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | 
fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | #activation_checkpointing=EncoderBlock, 68 | cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | with fabric.rank_zero_first(): 91 | # the above prevents race conditions when multiple processes 92 | # try to download and write the dataset to disk. 
93 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 94 | batch_size=4, 95 | num_workers=1, 96 | train_transforms=train_transforms, 97 | test_transforms=test_transforms, 98 | validation_fraction=0.1) 99 | 100 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 101 | train_loader, val_loader, test_loader) 102 | 103 | 104 | ######################################### 105 | ### 2 Initializing the Model 106 | ######################################### 107 | 108 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 109 | 110 | # replace output layer 111 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 112 | 113 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 114 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 115 | 116 | ######################################### 117 | ### 3 Finetuning 118 | ######################################### 119 | 120 | start = time.time() 121 | train( 122 | num_epochs=1, 123 | model=model, 124 | optimizer=optimizer, 125 | train_loader=train_loader, 126 | val_loader=val_loader, 127 | fabric=fabric 128 | ) 129 | 130 | end = time.time() 131 | elapsed = end-start 132 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 133 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 134 | 135 | ######################################### 136 | ### 4 Evaluation 137 | ######################################### 138 | 139 | with torch.no_grad(): 140 | model.eval() 141 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 142 | 143 | for (features, targets) in test_loader: 144 | outputs = model(features) 145 | predicted_labels = torch.argmax(outputs, 1) 146 | test_acc.update(predicted_labels, targets) 147 | 148 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08-10-vit32/local_utilities.py: 
-------------------------------------------------------------------------------- 1 | # Imports for ViT finetuning 2 | 3 | import torch 4 | from torch.utils.data import sampler 5 | from torchvision import datasets 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SubsetRandomSampler 8 | from torchvision import transforms 9 | 10 | # Import for LLM finetuning 11 | import os 12 | import sys 13 | import tarfile 14 | import time 15 | 16 | import numpy as np 17 | import pandas as pd 18 | from packaging import version 19 | from torch.utils.data import Dataset 20 | from tqdm import tqdm 21 | import urllib 22 | 23 | ############################ 24 | ##### VIT finetuning dataset 25 | ############################ 26 | 27 | 28 | def get_dataloaders_cifar10(batch_size, num_workers=0, 29 | validation_fraction=None, 30 | train_transforms=None, 31 | test_transforms=None): 32 | 33 | if train_transforms is None: 34 | train_transforms = transforms.ToTensor() 35 | 36 | if test_transforms is None: 37 | test_transforms = transforms.ToTensor() 38 | 39 | train_dataset = datasets.CIFAR10(root='data', 40 | train=True, 41 | transform=train_transforms, 42 | download=True) 43 | 44 | valid_dataset = datasets.CIFAR10(root='data', 45 | train=True, 46 | transform=test_transforms) 47 | 48 | test_dataset = datasets.CIFAR10(root='data', 49 | train=False, 50 | transform=test_transforms) 51 | 52 | if validation_fraction is not None: 53 | num = int(validation_fraction * 50000) 54 | train_indices = torch.arange(0, 50000 - num) 55 | valid_indices = torch.arange(50000 - num, 50000) 56 | 57 | train_sampler = SubsetRandomSampler(train_indices) 58 | valid_sampler = SubsetRandomSampler(valid_indices) 59 | 60 | valid_loader = DataLoader(dataset=valid_dataset, 61 | batch_size=batch_size, 62 | num_workers=num_workers, 63 | sampler=valid_sampler) 64 | 65 | train_loader = DataLoader(dataset=train_dataset, 66 | batch_size=batch_size, 67 | num_workers=num_workers, 68 | drop_last=True, 69 | 
sampler=train_sampler) 70 | 71 | else: 72 | train_loader = DataLoader(dataset=train_dataset, 73 | batch_size=batch_size, 74 | num_workers=num_workers, 75 | drop_last=True, 76 | shuffle=True) 77 | 78 | test_loader = DataLoader(dataset=test_dataset, 79 | batch_size=batch_size, 80 | num_workers=num_workers, 81 | shuffle=False) 82 | 83 | if validation_fraction is None: 84 | return train_loader, test_loader 85 | else: 86 | return train_loader, valid_loader, test_loader 87 | 88 | ############################ 89 | ##### LLM finetuning dataset 90 | ############################ 91 | 92 | import os 93 | import sys 94 | import tarfile 95 | import time 96 | 97 | import numpy as np 98 | import pandas as pd 99 | from packaging import version 100 | from torch.utils.data import Dataset 101 | from tqdm import tqdm 102 | import urllib 103 | 104 | 105 | def reporthook(count, block_size, total_size): 106 | global start_time 107 | if count == 0: 108 | start_time = time.time() 109 | return 110 | duration = time.time() - start_time 111 | progress_size = int(count * block_size) 112 | speed = progress_size / (1024.0**2 * duration) 113 | percent = count * block_size * 100.0 / total_size 114 | 115 | sys.stdout.write( 116 | f"\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB " 117 | f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed" 118 | ) 119 | sys.stdout.flush() 120 | 121 | 122 | def download_dataset(): 123 | source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" 124 | target = "aclImdb_v1.tar.gz" 125 | 126 | if os.path.exists(target): 127 | os.remove(target) 128 | 129 | if not os.path.isdir("aclImdb") and not os.path.isfile("aclImdb_v1.tar.gz"): 130 | urllib.request.urlretrieve(source, target, reporthook) 131 | 132 | if not os.path.isdir("aclImdb"): 133 | 134 | with tarfile.open(target, "r:gz") as tar: 135 | tar.extractall() 136 | 137 | 138 | def load_dataset_into_to_dataframe(): 139 | basepath = "aclImdb" 140 | 141 | labels = {"pos": 1, "neg": 0} 142 | 143 | df = 
pd.DataFrame() 144 | 145 | with tqdm(total=50000) as pbar: 146 | for s in ("test", "train"): 147 | for l in ("pos", "neg"): 148 | path = os.path.join(basepath, s, l) 149 | for file in sorted(os.listdir(path)): 150 | with open(os.path.join(path, file), "r", encoding="utf-8") as infile: 151 | txt = infile.read() 152 | 153 | if version.parse(pd.__version__) >= version.parse("1.3.2"): 154 | x = pd.DataFrame( 155 | [[txt, labels[l]]], columns=["review", "sentiment"] 156 | ) 157 | df = pd.concat([df, x], ignore_index=False) 158 | 159 | else: 160 | df = df.append([[txt, labels[l]]], ignore_index=True) 161 | pbar.update() 162 | df.columns = ["text", "label"] 163 | 164 | np.random.seed(0) 165 | df = df.reindex(np.random.permutation(df.index)) 166 | 167 | print("Class distribution:") 168 | np.bincount(df["label"].values) 169 | 170 | return df 171 | 172 | 173 | def partition_dataset(df): 174 | df_shuffled = df.sample(frac=1, random_state=1).reset_index() 175 | 176 | df_train = df_shuffled.iloc[:35_000] 177 | df_val = df_shuffled.iloc[35_000:40_000] 178 | df_test = df_shuffled.iloc[40_000:] 179 | 180 | df_train.to_csv("train.csv", index=False, encoding="utf-8") 181 | df_val.to_csv("val.csv", index=False, encoding="utf-8") 182 | df_test.to_csv("test.csv", index=False, encoding="utf-8") 183 | 184 | 185 | class IMDBDataset(Dataset): 186 | def __init__(self, dataset_dict, partition_key="train"): 187 | self.partition = dataset_dict[partition_key] 188 | 189 | def __getitem__(self, index): 190 | return self.partition[index] 191 | 192 | def __len__(self): 193 | return self.partition.num_rows -------------------------------------------------------------------------------- /08_baseline.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import lightning as L 4 | from lightning import Fabric 5 | import torch 6 | import torch.nn.functional as F 7 | import torchmetrics 8 | from torchvision import transforms 9 | from 
torchvision.models import vit_h_14 10 | from torchvision.models import ViT_H_14_Weights 11 | from watermark import watermark 12 | 13 | from local_utilities import get_dataloaders_cifar10 14 | 15 | 16 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 17 | 18 | for epoch in range(num_epochs): 19 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 20 | 21 | model.train() 22 | for batch_idx, (features, targets) in enumerate(train_loader): 23 | model.train() 24 | 25 | ### FORWARD AND BACK PROP 26 | logits = model(features) 27 | loss = F.cross_entropy(logits, targets) 28 | 29 | optimizer.zero_grad() 30 | fabric.backward(loss) 31 | 32 | ### UPDATE MODEL PARAMETERS 33 | optimizer.step() 34 | 35 | ### LOGGING 36 | if not batch_idx % 300: 37 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 38 | 39 | model.eval() 40 | with torch.no_grad(): 41 | predicted_labels = torch.argmax(logits, 1) 42 | train_acc.update(predicted_labels, targets) 43 | 44 | ### MORE LOGGING 45 | model.eval() 46 | with torch.no_grad(): 47 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 48 | 49 | for (features, targets) in val_loader: 50 | outputs = model(features) 51 | predicted_labels = torch.argmax(outputs, 1) 52 | val_acc.update(predicted_labels, targets) 53 | 54 | print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 55 | train_acc.reset(), val_acc.reset() 56 | 57 | 58 | if __name__ == "__main__": 59 | 60 | print(watermark(packages="torch,lightning", python=True)) 61 | print("Torch CUDA available?", torch.cuda.is_available()) 62 | 63 | fabric = Fabric(accelerator="cuda", devices=1, precision="16-mixed") 64 | fabric.launch() 65 | 66 | L.seed_everything(123) 67 | 68 | ########################## 69 | ### 1 Loading the Dataset 70 | ########################## 
71 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 72 | #transforms.RandomCrop((224, 224)), 73 | transforms.ToTensor()]) 74 | 75 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 76 | #transforms.CenterCrop((224, 224)), 77 | transforms.ToTensor()]) 78 | 79 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 80 | batch_size=4, 81 | num_workers=1, 82 | train_transforms=train_transforms, 83 | test_transforms=test_transforms, 84 | validation_fraction=0.1) 85 | 86 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 87 | train_loader, val_loader, test_loader) 88 | 89 | 90 | ######################################### 91 | ### 2 Initializing the Model 92 | ######################################### 93 | 94 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 95 | 96 | # replace output layer 97 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 98 | 99 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 100 | model, optimizer = fabric.setup(model, optimizer) 101 | 102 | ######################################### 103 | ### 3 Finetuning 104 | ######################################### 105 | 106 | start = time.time() 107 | train( 108 | num_epochs=1, 109 | model=model, 110 | optimizer=optimizer, 111 | train_loader=train_loader, 112 | val_loader=val_loader, 113 | fabric=fabric 114 | ) 115 | 116 | end = time.time() 117 | elapsed = end-start 118 | print(f"Time elapsed {elapsed/60:.2f} min") 119 | print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 120 | 121 | ######################################### 122 | ### 4 Evaluation 123 | ######################################### 124 | 125 | with torch.no_grad(): 126 | model.eval() 127 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 128 | 129 | for (features, targets) in test_loader: 130 | outputs = model(features) 131 | predicted_labels = torch.argmax(outputs, 1) 132 
| test_acc.update(predicted_labels, targets) 133 | 134 | print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08a_fsdp-defaults.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, 
targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | fabric = Fabric(accelerator="cuda", devices=4, strategy="fsdp", precision="16-mixed") 65 | fabric.launch() 66 | 67 | L.seed_everything(123) 68 | fabric.print(watermark(packages="torch,lightning", python=True)) 69 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 70 | 71 | 72 | ########################## 73 | ### 1 Loading the Dataset 74 | ########################## 75 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 76 | #transforms.RandomCrop((224, 224)), 77 | transforms.ToTensor()]) 78 | 79 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 80 | #transforms.CenterCrop((224, 224)), 81 | transforms.ToTensor()]) 82 | 83 | with fabric.rank_zero_first(): 84 | # the above prevents race conditions when multiple processes 85 | # try to download and write the dataset to disk. 
86 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 87 | batch_size=4, 88 | num_workers=1, 89 | train_transforms=train_transforms, 90 | test_transforms=test_transforms, 91 | validation_fraction=0.1) 92 | 93 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 94 | train_loader, val_loader, test_loader) 95 | 96 | 97 | ######################################### 98 | ### 2 Initializing the Model 99 | ######################################### 100 | 101 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 102 | 103 | # replace output layer 104 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 105 | 106 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 107 | model, optimizer = fabric.setup(model, optimizer) 108 | 109 | ######################################### 110 | ### 3 Finetuning 111 | ######################################### 112 | 113 | start = time.time() 114 | train( 115 | num_epochs=1, 116 | model=model, 117 | optimizer=optimizer, 118 | train_loader=train_loader, 119 | val_loader=val_loader, 120 | fabric=fabric 121 | ) 122 | 123 | end = time.time() 124 | elapsed = end-start 125 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 126 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 127 | 128 | ######################################### 129 | ### 4 Evaluation 130 | ######################################### 131 | 132 | with torch.no_grad(): 133 | model.eval() 134 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 135 | 136 | for (features, targets) in test_loader: 137 | outputs = model(features) 138 | predicted_labels = torch.argmax(outputs, 1) 139 | test_acc.update(predicted_labels, targets) 140 | 141 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08b_fsdp-custom.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | with fabric.rank_zero_first(): 86 | # the above prevents race conditions when multiple processes 87 | # try to download and write the dataset to disk. 
88 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 89 | batch_size=4, 90 | num_workers=1, 91 | train_transforms=train_transforms, 92 | test_transforms=test_transforms, 93 | validation_fraction=0.1) 94 | 95 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 96 | train_loader, val_loader, test_loader) 97 | 98 | 99 | ######################################### 100 | ### 2 Initializing the Model 101 | ######################################### 102 | 103 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 104 | 105 | # replace output layer 106 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 107 | 108 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 109 | model, optimizer = fabric.setup(model, optimizer) 110 | 111 | ######################################### 112 | ### 3 Finetuning 113 | ######################################### 114 | 115 | start = time.time() 116 | train( 117 | num_epochs=1, 118 | model=model, 119 | optimizer=optimizer, 120 | train_loader=train_loader, 121 | val_loader=val_loader, 122 | fabric=fabric 123 | ) 124 | 125 | end = time.time() 126 | elapsed = end-start 127 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 128 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 129 | 130 | ######################################### 131 | ### 4 Evaluation 132 | ######################################### 133 | 134 | with torch.no_grad(): 135 | model.eval() 136 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 137 | 138 | for (features, targets) in test_loader: 139 | outputs = model(features) 140 | predicted_labels = torch.argmax(outputs, 1) 141 | test_acc.update(predicted_labels, targets) 142 | 143 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /08c_fsdp-size-wrap.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(size_based_auto_wrap_policy, module={EncoderBlock}, min_num_params=2_000_000) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | with fabric.rank_zero_first(): 87 | # the above prevents race conditions when multiple processes 88 | # try to download and write the dataset to disk. 
89 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 90 | batch_size=4, 91 | num_workers=1, 92 | train_transforms=train_transforms, 93 | test_transforms=test_transforms, 94 | validation_fraction=0.1) 95 | 96 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 97 | train_loader, val_loader, test_loader) 98 | 99 | 100 | ######################################### 101 | ### 2 Initializing the Model 102 | ######################################### 103 | 104 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 105 | 106 | # replace output layer 107 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 108 | 109 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 110 | model, optimizer = fabric.setup(model, optimizer) 111 | 112 | ######################################### 113 | ### 3 Finetuning 114 | ######################################### 115 | 116 | start = time.time() 117 | train( 118 | num_epochs=1, 119 | model=model, 120 | optimizer=optimizer, 121 | train_loader=train_loader, 122 | val_loader=val_loader, 123 | fabric=fabric 124 | ) 125 | 126 | end = time.time() 127 | elapsed = end-start 128 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 129 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 130 | 131 | ######################################### 132 | ### 4 Evaluation 133 | ######################################### 134 | 135 | with torch.no_grad(): 136 | model.eval() 137 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 138 | 139 | for (features, targets) in test_loader: 140 | outputs = model(features) 141 | predicted_labels = torch.argmax(outputs, 1) 142 | test_acc.update(predicted_labels, targets) 143 | 144 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") 145 | -------------------------------------------------------------------------------- /09_fsdp-act-checkp.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=EncoderBlock) 66 | 67 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 68 | fabric.launch() 69 | 70 | L.seed_everything(123) 71 | fabric.print(watermark(packages="torch,lightning", python=True)) 72 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 73 | 74 | 75 | ########################## 76 | ### 1 Loading the Dataset 77 | ########################## 78 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 79 | #transforms.RandomCrop((224, 224)), 80 | transforms.ToTensor()]) 81 | 82 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.CenterCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | with fabric.rank_zero_first(): 87 | # the above prevents race conditions when multiple processes 88 | # try to download and write the dataset to disk. 
89 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 90 | batch_size=4, 91 | num_workers=1, 92 | train_transforms=train_transforms, 93 | test_transforms=test_transforms, 94 | validation_fraction=0.1) 95 | 96 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 97 | train_loader, val_loader, test_loader) 98 | 99 | 100 | ######################################### 101 | ### 2 Initializing the Model 102 | ######################################### 103 | 104 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 105 | 106 | # replace output layer 107 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 108 | 109 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 110 | model, optimizer = fabric.setup(model, optimizer) 111 | 112 | ######################################### 113 | ### 3 Finetuning 114 | ######################################### 115 | 116 | start = time.time() 117 | train( 118 | num_epochs=1, 119 | model=model, 120 | optimizer=optimizer, 121 | train_loader=train_loader, 122 | val_loader=val_loader, 123 | fabric=fabric 124 | ) 125 | 126 | end = time.time() 127 | elapsed = end-start 128 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 129 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 130 | 131 | ######################################### 132 | ### 4 Evaluation 133 | ######################################### 134 | 135 | with torch.no_grad(): 136 | model.eval() 137 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 138 | 139 | for (features, targets) in test_loader: 140 | outputs = model(features) 141 | predicted_labels = torch.argmax(outputs, 1) 142 | test_acc.update(predicted_labels, targets) 143 | 144 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /10_fsdp-with-cpu-offload.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | activation_checkpointing=EncoderBlock, 68 | cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | with fabric.rank_zero_first(): 91 | # the above prevents race conditions when multiple processes 92 | # try to download and write the dataset to disk. 
93 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 94 | batch_size=4, 95 | num_workers=1, 96 | train_transforms=train_transforms, 97 | test_transforms=test_transforms, 98 | validation_fraction=0.1) 99 | 100 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 101 | train_loader, val_loader, test_loader) 102 | 103 | 104 | ######################################### 105 | ### 2 Initializing the Model 106 | ######################################### 107 | 108 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 109 | 110 | # replace output layer 111 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 112 | 113 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 114 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 115 | 116 | ######################################### 117 | ### 3 Finetuning 118 | ######################################### 119 | 120 | start = time.time() 121 | train( 122 | num_epochs=1, 123 | model=model, 124 | optimizer=optimizer, 125 | train_loader=train_loader, 126 | val_loader=val_loader, 127 | fabric=fabric 128 | ) 129 | 130 | end = time.time() 131 | elapsed = end-start 132 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 133 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 134 | 135 | ######################################### 136 | ### 4 Evaluation 137 | ######################################### 138 | 139 | with torch.no_grad(): 140 | model.eval() 141 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 142 | 143 | for (features, targets) in test_loader: 144 | outputs = model(features) 145 | predicted_labels = torch.argmax(outputs, 1) 146 | test_acc.update(predicted_labels, targets) 147 | 148 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /10b_fsdp-with-cpu-offload-no-act-check.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_h_14 13 | from torchvision.models import ViT_H_14_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | #activation_checkpointing=EncoderBlock, 68 | cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="16-mixed") 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | train_transforms = transforms.Compose([transforms.Resize((518, 518)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((518, 518)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | with fabric.rank_zero_first(): 91 | # the above prevents race conditions when multiple processes 92 | # try to download and write the dataset to disk. 
93 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 94 | batch_size=4, 95 | num_workers=1, 96 | train_transforms=train_transforms, 97 | test_transforms=test_transforms, 98 | validation_fraction=0.1) 99 | 100 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 101 | train_loader, val_loader, test_loader) 102 | 103 | 104 | ######################################### 105 | ### 2 Initializing the Model 106 | ######################################### 107 | 108 | model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1) 109 | 110 | # replace output layer 111 | model.heads.head = torch.nn.Linear(in_features=1280, out_features=10) 112 | 113 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 114 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 115 | 116 | ######################################### 117 | ### 3 Finetuning 118 | ######################################### 119 | 120 | start = time.time() 121 | train( 122 | num_epochs=1, 123 | model=model, 124 | optimizer=optimizer, 125 | train_loader=train_loader, 126 | val_loader=val_loader, 127 | fabric=fabric 128 | ) 129 | 130 | end = time.time() 131 | elapsed = end-start 132 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 133 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 134 | 135 | ######################################### 136 | ### 4 Evaluation 137 | ######################################### 138 | 139 | with torch.no_grad(): 140 | model.eval() 141 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 142 | 143 | for (features, targets) in test_loader: 144 | outputs = model(features) 145 | predicted_labels = torch.argmax(outputs, 1) 146 | test_acc.update(predicted_labels, targets) 147 | 148 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /11_delay-allocation.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 10 | import torchmetrics 11 | from torchvision import transforms 12 | from torchvision.models import vit_l_16 13 | from torchvision.models import ViT_L_16_Weights 14 | from torchvision.models.vision_transformer import EncoderBlock 15 | from watermark import watermark 16 | 17 | from local_utilities import get_dataloaders_cifar10 18 | 19 | 20 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 21 | 22 | for epoch in range(num_epochs): 23 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 24 | 25 | model.train() 26 | for batch_idx, (features, targets) in enumerate(train_loader): 27 | model.train() 28 | 29 | ### FORWARD AND BACK PROP 30 | logits = model(features) 31 | loss = F.cross_entropy(logits, targets) 32 | 33 | optimizer.zero_grad() 34 | fabric.backward(loss) 35 | 36 | ### UPDATE MODEL PARAMETERS 37 | optimizer.step() 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | 
Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | activation_checkpointing=EncoderBlock, 68 | cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy) 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | train_transforms = transforms.Compose([transforms.Resize((224, 224)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((224, 224)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 91 | batch_size=64, 92 | num_workers=1, 93 | train_transforms=train_transforms, 94 | test_transforms=test_transforms, 95 | validation_fraction=0.1) 96 | 97 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 98 | train_loader, val_loader, test_loader) 99 | 100 | 101 | ######################################### 102 | ### 2 Initializing the Model 103 | ######################################### 104 | 105 | with fabric.init_module(empty_init=False): 106 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 107 | # replace output layer 108 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 109 | 110 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 111 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 112 | 113 | ######################################### 114 
| ### 3 Finetuning 115 | ######################################### 116 | 117 | start = time.time() 118 | train( 119 | num_epochs=1, 120 | model=model, 121 | optimizer=optimizer, 122 | train_loader=train_loader, 123 | val_loader=val_loader, 124 | fabric=fabric 125 | ) 126 | 127 | end = time.time() 128 | elapsed = end-start 129 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 130 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 131 | 132 | ######################################### 133 | ### 4 Evaluation 134 | ######################################### 135 | 136 | with torch.no_grad(): 137 | model.eval() 138 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 139 | 140 | for (features, targets) in test_loader: 141 | outputs = model(features) 142 | predicted_labels = torch.argmax(outputs, 1) 143 | test_acc.update(predicted_labels, targets) 144 | 145 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /12_fsdp-overlap.py: -------------------------------------------------------------------------------- 1 | import time 2 | from functools import partial 3 | 4 | import lightning as L 5 | from lightning import Fabric 6 | from lightning.fabric.strategies import FSDPStrategy 7 | from lightning.fabric.strategies.fsdp import fsdp_overlap_step_with_backward 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 12 | import torchmetrics 13 | from torchvision import transforms 14 | from torchvision.models import vit_l_16 15 | from torchvision.models import ViT_L_16_Weights 16 | from torchvision.models.vision_transformer import EncoderBlock 17 | from watermark import watermark 18 | 19 | from local_utilities import get_dataloaders_cifar10 20 | 21 | 22 | def train(num_epochs, model, optimizers, train_loader, val_loader, fabric): 23 | 24 | for epoch 
in range(num_epochs): 25 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 26 | 27 | model.train() 28 | for batch_idx, (features, targets) in enumerate(train_loader): 29 | model.train() 30 | 31 | ### FORWARD AND BACK PROP 32 | logits = model(features) 33 | loss = F.cross_entropy(logits, targets) 34 | 35 | ### UPDATE MODEL PARAMETERS 36 | with fsdp_overlap_step_with_backward(optimizers, model): 37 | fabric.backward(loss) 38 | 39 | ### LOGGING 40 | if not batch_idx % 50: 41 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {loss:.4f}") 42 | 43 | model.eval() 44 | with torch.no_grad(): 45 | predicted_labels = torch.argmax(logits, 1) 46 | train_acc.update(predicted_labels, targets) 47 | 48 | ### MORE LOGGING 49 | model.eval() 50 | with torch.no_grad(): 51 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 52 | 53 | for (features, targets) in val_loader: 54 | outputs = model(features) 55 | predicted_labels = torch.argmax(outputs, 1) 56 | val_acc.update(predicted_labels, targets) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EncoderBlock}) 65 | strategy = FSDPStrategy( 66 | auto_wrap_policy=auto_wrap_policy, 67 | activation_checkpointing=EncoderBlock, 68 | #cpu_offload=True 69 | ) 70 | 71 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy) 72 | fabric.launch() 73 | 74 | L.seed_everything(123) 75 | fabric.print(watermark(packages="torch,lightning", python=True)) 76 | fabric.print("Torch CUDA available?", torch.cuda.is_available()) 77 | 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | 
train_transforms = transforms.Compose([transforms.Resize((224, 224)), 83 | #transforms.RandomCrop((224, 224)), 84 | transforms.ToTensor()]) 85 | 86 | test_transforms = transforms.Compose([transforms.Resize((224, 224)), 87 | #transforms.CenterCrop((224, 224)), 88 | transforms.ToTensor()]) 89 | 90 | train_loader, val_loader, test_loader = get_dataloaders_cifar10( 91 | batch_size=64, 92 | num_workers=1, 93 | train_transforms=train_transforms, 94 | test_transforms=test_transforms, 95 | validation_fraction=0.1) 96 | 97 | train_loader, val_loader, test_loader = fabric.setup_dataloaders( 98 | train_loader, val_loader, test_loader) 99 | 100 | 101 | ######################################### 102 | ### 2 Initializing the Model 103 | ######################################### 104 | 105 | with fabric.init_module(empty_init=False): 106 | model = vit_l_16(weights=ViT_L_16_Weights.IMAGENET1K_V1) 107 | # replace output layer 108 | model.heads.head = torch.nn.Linear(in_features=1024, out_features=10) 109 | 110 | optimizers = [torch.optim.Adam([p], lr=5e-5) for p in model.parameters()] 111 | 112 | model = fabric.setup_module(model) 113 | optimizers = fabric.setup_optimizers(*optimizers) 114 | 115 | ######################################### 116 | ### 3 Finetuning 117 | ######################################### 118 | 119 | start = time.time() 120 | train( 121 | num_epochs=1, 122 | model=model, 123 | optimizers=optimizers, 124 | train_loader=train_loader, 125 | val_loader=val_loader, 126 | fabric=fabric 127 | ) 128 | 129 | end = time.time() 130 | elapsed = end-start 131 | fabric.print(f"Time elapsed {elapsed/60:.2f} min") 132 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") 133 | 134 | ######################################### 135 | ### 4 Evaluation 136 | ######################################### 137 | 138 | with torch.no_grad(): 139 | model.eval() 140 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(fabric.device) 141 | 142 | for 
(features, targets) in test_loader: 143 | outputs = model(features) 144 | predicted_labels = torch.argmax(outputs, 1) 145 | test_acc.update(predicted_labels, targets) 146 | 147 | fabric.print(f"Test accuracy {test_acc.compute()*100:.2f}%") -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Sebastian Raschka 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Optimizing PyTorch Memory Usage 2 | 3 | 4 | This code repository contains the code used for my "Optimizing Memory Usage for Training LLMs and Vision Transformers in PyTorch" blog post. 5 | 6 | You can install the dependencies via 7 | 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | The scripts are all standalone scripts and can be run by executing 14 | 15 | ```bash 16 | python 01_pytorch-vit.py 17 | ``` 18 | 19 | and so forth. The only requirement is to have the [`local_utilities.py`](local_utilities.py) in the same folder as the script as it contains some data loading utilities that are reused across all the scripts. 20 | 21 | I tracked the script outputs in the [`logs.md`](logs.md) file. 
22 | 23 | 24 | 25 | ![overview](figures/overview.png) 26 | -------------------------------------------------------------------------------- /bonus_bigbird-after.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as op 4 | import time 5 | from functools import partial 6 | 7 | 8 | from datasets import load_dataset 9 | from lightning import Fabric 10 | from lightning.fabric.strategies import FSDPStrategy 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | 15 | import torchmetrics 16 | from transformers import AutoTokenizer 17 | from transformers import AutoModelForSequenceClassification 18 | from watermark import watermark 19 | 20 | from local_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset 21 | from local_utilities import IMDBDataset 22 | 23 | 24 | def tokenize_text(batch): 25 | return tokenizer(batch["text"], truncation=True, padding=True) 26 | 27 | 28 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric, accumulation_steps): 29 | 30 | for epoch in range(num_epochs): 31 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 32 | 33 | model.train() 34 | for batch_idx, batch in enumerate(train_loader): 35 | model.train() 36 | 37 | ### FORWARD AND BACK PROP 38 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 39 | outputs["loss"] /= accumulation_steps 40 | fabric.backward(outputs["loss"]) 41 | 42 | ### UPDATE MODEL PARAMETERS 43 | if batch_idx % accumulation_steps == 0: # NEW 44 | optimizer.step() 45 | optimizer.zero_grad() 46 | 47 | ### LOGGING 48 | if not batch_idx % 300: 49 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}") 50 | 51 | model.eval() 52 | with torch.no_grad(): 53 | predicted_labels = torch.argmax(outputs["logits"], 1) 54 | 
train_acc.update(predicted_labels, batch["label"]) 55 | 56 | ### MORE LOGGING 57 | model.eval() 58 | with torch.no_grad(): 59 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 60 | for batch in val_loader: 61 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 62 | predicted_labels = torch.argmax(outputs["logits"], 1) 63 | val_acc.update(predicted_labels, batch["label"]) 64 | 65 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 66 | train_acc.reset(), val_acc.reset() 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | print(watermark(packages="torch,lightning,transformers", python=True)) 72 | print("Torch CUDA available?", torch.cuda.is_available()) 73 | device = "cuda" if torch.cuda.is_available() else "cpu" 74 | 75 | torch.manual_seed(123) 76 | 77 | ########################## 78 | ### 1 Loading the Dataset 79 | ########################## 80 | download_dataset() 81 | df = load_dataset_into_to_dataframe() 82 | if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")): 83 | partition_dataset(df) 84 | 85 | imdb_dataset = load_dataset( 86 | "csv", 87 | data_files={ 88 | "train": "train.csv", 89 | "validation": "val.csv", 90 | "test": "test.csv", 91 | }, 92 | ) 93 | 94 | ######################################### 95 | ### 2 Tokenization and Numericalization 96 | ######################################### 97 | 98 | tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") 99 | print("Tokenizer input max length:", tokenizer.model_max_length, flush=True) 100 | print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True) 101 | 102 | print("Tokenizing ...", flush=True) 103 | imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None) 104 | del imdb_dataset 105 | imdb_tokenized.set_format("torch", columns=["input_ids", 
"attention_mask", "label"]) 106 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 107 | 108 | ######################################### 109 | ### 3 Set Up DataLoaders 110 | ######################################### 111 | 112 | BATCHSIZE = 12 113 | ACCUMULATION_STEPS = 12 114 | MICROBATCHSIZE = int(BATCHSIZE / ACCUMULATION_STEPS) 115 | 116 | train_dataset = IMDBDataset(imdb_tokenized, partition_key="train") 117 | val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation") 118 | test_dataset = IMDBDataset(imdb_tokenized, partition_key="test") 119 | 120 | train_loader = DataLoader( 121 | dataset=train_dataset, 122 | batch_size=MICROBATCHSIZE, 123 | shuffle=True, 124 | num_workers=1, 125 | drop_last=True, 126 | ) 127 | 128 | val_loader = DataLoader( 129 | dataset=val_dataset, 130 | batch_size=MICROBATCHSIZE, 131 | num_workers=1, 132 | drop_last=True, 133 | ) 134 | 135 | test_loader = DataLoader( 136 | dataset=test_dataset, 137 | batch_size=MICROBATCHSIZE, 138 | num_workers=1, 139 | drop_last=True, 140 | ) 141 | 142 | 143 | ######################################### 144 | ### 4 Initializing the Model 145 | ######################################### 146 | 147 | torch.set_float32_matmul_precision('medium') 148 | 149 | strategy = FSDPStrategy( 150 | cpu_offload=True 151 | ) 152 | 153 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="bf16-true") 154 | fabric.launch() 155 | 156 | with fabric.init_module(empty_init=False): 157 | model = AutoModelForSequenceClassification.from_pretrained( 158 | "google/bigbird-roberta-base", num_labels=2) 159 | 160 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 161 | 162 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 163 | train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader) 164 | fabric.barrier() 165 | 166 | ######################################### 167 | ### 5 Finetuning 168 | 
######################################### 169 | 170 | start = time.time() 171 | train( 172 | num_epochs=1, 173 | model=model, 174 | optimizer=optimizer, 175 | train_loader=train_loader, 176 | val_loader=val_loader, 177 | fabric=fabric, 178 | accumulation_steps=ACCUMULATION_STEPS 179 | ) 180 | 181 | end = time.time() 182 | elapsed = end-start 183 | 184 | with torch.no_grad(): 185 | model.eval() 186 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 187 | for batch in test_loader: 188 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 189 | predicted_labels = torch.argmax(outputs["logits"], 1) 190 | test_acc.update(predicted_labels, batch["label"]) 191 | 192 | fabric.print(f"Test accuracy: {test_acc.compute()*100:.2f}%") 193 | fabric.print(f"Total training time: {elapsed/60:.2f} min") 194 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /bonus_bigbird-before.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as op 4 | import time 5 | 6 | from datasets import load_dataset 7 | from lightning import Fabric 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchmetrics 11 | from transformers import AutoTokenizer 12 | from transformers import AutoModelForSequenceClassification 13 | from watermark import watermark 14 | 15 | from local_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset 16 | from local_utilities import IMDBDataset 17 | 18 | 19 | def tokenize_text(batch): 20 | return tokenizer(batch["text"], truncation=True, padding=True) 21 | 22 | 23 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 24 | 25 | for epoch in range(num_epochs): 26 | train_acc = torchmetrics.Accuracy(task="multiclass", 
num_classes=2).to(fabric.device) 27 | 28 | model.train() 29 | for batch_idx, batch in enumerate(train_loader): 30 | model.train() 31 | 32 | ### FORWARD AND BACK PROP 33 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 34 | optimizer.zero_grad() 35 | fabric.backward(outputs["loss"]) 36 | 37 | ### UPDATE MODEL PARAMETERS 38 | optimizer.step() 39 | 40 | ### LOGGING 41 | if not batch_idx % 300: 42 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}") 43 | 44 | model.eval() 45 | with torch.no_grad(): 46 | predicted_labels = torch.argmax(outputs["logits"], 1) 47 | train_acc.update(predicted_labels, batch["label"]) 48 | 49 | ### MORE LOGGING 50 | model.eval() 51 | with torch.no_grad(): 52 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 53 | for batch in val_loader: 54 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 55 | predicted_labels = torch.argmax(outputs["logits"], 1) 56 | val_acc.update(predicted_labels, batch["label"]) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | print(watermark(packages="torch,lightning,transformers", python=True)) 65 | print("Torch CUDA available?", torch.cuda.is_available()) 66 | device = "cuda" if torch.cuda.is_available() else "cpu" 67 | 68 | torch.manual_seed(123) 69 | 70 | ########################## 71 | ### 1 Loading the Dataset 72 | ########################## 73 | download_dataset() 74 | df = load_dataset_into_to_dataframe() 75 | if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")): 76 | partition_dataset(df) 77 | 78 | imdb_dataset = load_dataset( 79 | "csv", 80 | data_files={ 81 | "train": 
"train.csv", 82 | "validation": "val.csv", 83 | "test": "test.csv", 84 | }, 85 | ) 86 | 87 | ######################################### 88 | ### 2 Tokenization and Numericalization 89 | ######################################### 90 | 91 | tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base") 92 | print("Tokenizer input max length:", tokenizer.model_max_length, flush=True) 93 | print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True) 94 | 95 | print("Tokenizing ...", flush=True) 96 | imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None) 97 | del imdb_dataset 98 | imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"]) 99 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 100 | 101 | ######################################### 102 | ### 3 Set Up DataLoaders 103 | ######################################### 104 | 105 | train_dataset = IMDBDataset(imdb_tokenized, partition_key="train") 106 | val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation") 107 | test_dataset = IMDBDataset(imdb_tokenized, partition_key="test") 108 | 109 | train_loader = DataLoader( 110 | dataset=train_dataset, 111 | batch_size=12, 112 | shuffle=True, 113 | num_workers=4, 114 | drop_last=True, 115 | ) 116 | 117 | val_loader = DataLoader( 118 | dataset=val_dataset, 119 | batch_size=12, 120 | num_workers=4, 121 | drop_last=True, 122 | ) 123 | 124 | test_loader = DataLoader( 125 | dataset=test_dataset, 126 | batch_size=12, 127 | num_workers=2, 128 | drop_last=True, 129 | ) 130 | 131 | 132 | ######################################### 133 | ### 4 Initializing the Model 134 | ######################################### 135 | 136 | torch.set_float32_matmul_precision('high') 137 | 138 | fabric = Fabric(accelerator="cuda", devices=4, strategy="ddp") 139 | fabric.launch() 140 | 141 | model = AutoModelForSequenceClassification.from_pretrained( 142 | "google/bigbird-roberta-base", num_labels=2) 143 | 144 | optimizer = 
torch.optim.Adam(model.parameters(), lr=5e-5) 145 | 146 | model, optimizer = fabric.setup(model, optimizer) 147 | train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader) 148 | fabric.barrier() 149 | 150 | ######################################### 151 | ### 5 Finetuning 152 | ######################################### 153 | 154 | start = time.time() 155 | train( 156 | num_epochs=1, 157 | model=model, 158 | optimizer=optimizer, 159 | train_loader=train_loader, 160 | val_loader=val_loader, 161 | fabric=fabric 162 | ) 163 | 164 | end = time.time() 165 | elapsed = end-start 166 | 167 | with torch.no_grad(): 168 | model.eval() 169 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 170 | for batch in test_loader: 171 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 172 | predicted_labels = torch.argmax(outputs["logits"], 1) 173 | test_acc.update(predicted_labels, batch["label"]) 174 | 175 | fabric.print(f"Test accuracy: {test_acc.compute()*100:.2f}%") 176 | fabric.print(f"Total training time: {elapsed/60:.2f} min") 177 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /bonus_distilbert-after.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as op 4 | import time 5 | from functools import partial 6 | 7 | 8 | from datasets import load_dataset 9 | from lightning import Fabric 10 | from lightning.fabric.strategies import FSDPStrategy 11 | 12 | import torch 13 | from torch.utils.data import DataLoader 14 | from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy 15 | 16 | import torchmetrics 17 | from transformers import AutoTokenizer 18 | from transformers import AutoModelForSequenceClassification 19 | from 
transformers.models.distilbert.modeling_distilbert import TransformerBlock 20 | from watermark import watermark 21 | 22 | from local_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset 23 | from local_utilities import IMDBDataset 24 | 25 | 26 | def tokenize_text(batch): 27 | return tokenizer(batch["text"], truncation=True, padding=True) 28 | 29 | 30 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric, accumulation_steps): 31 | 32 | for epoch in range(num_epochs): 33 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 34 | 35 | model.train() 36 | for batch_idx, batch in enumerate(train_loader): 37 | model.train() 38 | 39 | ### FORWARD AND BACK PROP 40 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 41 | outputs["loss"] /= accumulation_steps 42 | fabric.backward(outputs["loss"]) 43 | 44 | ### UPDATE MODEL PARAMETERS 45 | if batch_idx % accumulation_steps == 0: # NEW 46 | optimizer.step() 47 | optimizer.zero_grad() 48 | 49 | ### LOGGING 50 | if not batch_idx % 300: 51 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | Loss: {outputs['loss']:.4f}") 52 | 53 | model.eval() 54 | with torch.no_grad(): 55 | predicted_labels = torch.argmax(outputs["logits"], 1) 56 | train_acc.update(predicted_labels, batch["label"]) 57 | 58 | ### MORE LOGGING 59 | model.eval() 60 | with torch.no_grad(): 61 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 62 | for batch in val_loader: 63 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 64 | predicted_labels = torch.argmax(outputs["logits"], 1) 65 | val_acc.update(predicted_labels, batch["label"]) 66 | 67 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 68 | 
train_acc.reset(), val_acc.reset() 69 | 70 | 71 | if __name__ == "__main__": 72 | 73 | print(watermark(packages="torch,lightning,transformers", python=True)) 74 | print("Torch CUDA available?", torch.cuda.is_available()) 75 | device = "cuda" if torch.cuda.is_available() else "cpu" 76 | 77 | torch.manual_seed(123) 78 | 79 | ########################## 80 | ### 1 Loading the Dataset 81 | ########################## 82 | download_dataset() 83 | df = load_dataset_into_to_dataframe() 84 | if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")): 85 | partition_dataset(df) 86 | 87 | imdb_dataset = load_dataset( 88 | "csv", 89 | data_files={ 90 | "train": "train.csv", 91 | "validation": "val.csv", 92 | "test": "test.csv", 93 | }, 94 | ) 95 | 96 | ######################################### 97 | ### 2 Tokenization and Numericalization 98 | ######################################### 99 | 100 | tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") 101 | print("Tokenizer input max length:", tokenizer.model_max_length, flush=True) 102 | print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True) 103 | 104 | print("Tokenizing ...", flush=True) 105 | imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None) 106 | del imdb_dataset 107 | imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"]) 108 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 109 | 110 | ######################################### 111 | ### 3 Set Up DataLoaders 112 | ######################################### 113 | 114 | BATCHSIZE = 12 115 | ACCUMULATION_STEPS = 4 116 | MICROBATCHSIZE = int(BATCHSIZE / ACCUMULATION_STEPS) 117 | 118 | train_dataset = IMDBDataset(imdb_tokenized, partition_key="train") 119 | val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation") 120 | test_dataset = IMDBDataset(imdb_tokenized, partition_key="test") 121 | 122 | train_loader = DataLoader( 123 | dataset=train_dataset, 124 | 
batch_size=MICROBATCHSIZE, 125 | shuffle=True, 126 | num_workers=1, 127 | drop_last=True, 128 | ) 129 | 130 | val_loader = DataLoader( 131 | dataset=val_dataset, 132 | batch_size=MICROBATCHSIZE, 133 | num_workers=1, 134 | drop_last=True, 135 | ) 136 | 137 | test_loader = DataLoader( 138 | dataset=test_dataset, 139 | batch_size=MICROBATCHSIZE, 140 | num_workers=1, 141 | drop_last=True, 142 | ) 143 | 144 | 145 | ######################################### 146 | ### 4 Initializing the Model 147 | ######################################### 148 | 149 | torch.set_float32_matmul_precision('medium') 150 | 151 | auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={TransformerBlock}) 152 | strategy = FSDPStrategy( 153 | auto_wrap_policy=auto_wrap_policy, 154 | cpu_offload=True 155 | ) 156 | 157 | fabric = Fabric(accelerator="cuda", devices=4, strategy=strategy, precision="bf16-true") 158 | fabric.launch() 159 | 160 | with fabric.init_module(empty_init=False): 161 | model = AutoModelForSequenceClassification.from_pretrained( 162 | "distilbert-base-uncased", num_labels=2) 163 | 164 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 165 | 166 | model, optimizer = fabric.setup(model, optimizer, move_to_device=False) 167 | train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader) 168 | fabric.barrier() 169 | 170 | ######################################### 171 | ### 5 Finetuning 172 | ######################################### 173 | 174 | start = time.time() 175 | train( 176 | num_epochs=1, 177 | model=model, 178 | optimizer=optimizer, 179 | train_loader=train_loader, 180 | val_loader=val_loader, 181 | fabric=fabric, 182 | accumulation_steps=ACCUMULATION_STEPS 183 | ) 184 | 185 | end = time.time() 186 | elapsed = end-start 187 | 188 | with torch.no_grad(): 189 | model.eval() 190 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 191 | for batch in test_loader: 192 | 
outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 193 | predicted_labels = torch.argmax(outputs["logits"], 1) 194 | test_acc.update(predicted_labels, batch["label"]) 195 | 196 | fabric.print(f"Test accuracy: {test_acc.compute()*100:.2f}%") 197 | fabric.print(f"Total training time: {elapsed/60:.2f} min") 198 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /bonus_distilbert-before.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import os.path as op 4 | import time 5 | 6 | from datasets import load_dataset 7 | from lightning import Fabric 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torchmetrics 11 | from transformers import AutoTokenizer 12 | from transformers import AutoModelForSequenceClassification 13 | from watermark import watermark 14 | 15 | from local_utilities import download_dataset, load_dataset_into_to_dataframe, partition_dataset 16 | from local_utilities import IMDBDataset 17 | 18 | 19 | def tokenize_text(batch): 20 | return tokenizer(batch["text"], truncation=True, padding=True) 21 | 22 | 23 | def train(num_epochs, model, optimizer, train_loader, val_loader, fabric): 24 | 25 | for epoch in range(num_epochs): 26 | train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 27 | 28 | model.train() 29 | for batch_idx, batch in enumerate(train_loader): 30 | model.train() 31 | 32 | ### FORWARD AND BACK PROP 33 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 34 | optimizer.zero_grad() 35 | fabric.backward(outputs["loss"]) 36 | 37 | ### UPDATE MODEL PARAMETERS 38 | optimizer.step() 39 | 40 | ### LOGGING 41 | if not batch_idx % 300: 42 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Batch {batch_idx:04d}/{len(train_loader):04d} | 
Loss: {outputs['loss']:.4f}") 43 | 44 | model.eval() 45 | with torch.no_grad(): 46 | predicted_labels = torch.argmax(outputs["logits"], 1) 47 | train_acc.update(predicted_labels, batch["label"]) 48 | 49 | ### MORE LOGGING 50 | model.eval() 51 | with torch.no_grad(): 52 | val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 53 | for batch in val_loader: 54 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 55 | predicted_labels = torch.argmax(outputs["logits"], 1) 56 | val_acc.update(predicted_labels, batch["label"]) 57 | 58 | fabric.print(f"Epoch: {epoch+1:04d}/{num_epochs:04d} | Train acc.: {train_acc.compute()*100:.2f}% | Val acc.: {val_acc.compute()*100:.2f}%") 59 | train_acc.reset(), val_acc.reset() 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | print(watermark(packages="torch,lightning,transformers", python=True)) 65 | print("Torch CUDA available?", torch.cuda.is_available()) 66 | device = "cuda" if torch.cuda.is_available() else "cpu" 67 | 68 | torch.manual_seed(123) 69 | 70 | ########################## 71 | ### 1 Loading the Dataset 72 | ########################## 73 | download_dataset() 74 | df = load_dataset_into_to_dataframe() 75 | if not (op.exists("train.csv") and op.exists("val.csv") and op.exists("test.csv")): 76 | partition_dataset(df) 77 | 78 | imdb_dataset = load_dataset( 79 | "csv", 80 | data_files={ 81 | "train": "train.csv", 82 | "validation": "val.csv", 83 | "test": "test.csv", 84 | }, 85 | ) 86 | 87 | ######################################### 88 | ### 2 Tokenization and Numericalization 89 | ######################################### 90 | 91 | tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased") 92 | print("Tokenizer input max length:", tokenizer.model_max_length, flush=True) 93 | print("Tokenizer vocabulary size:", tokenizer.vocab_size, flush=True) 94 | 95 | print("Tokenizing ...", flush=True) 96 | imdb_tokenized = 
imdb_dataset.map(tokenize_text, batched=True, batch_size=None) 97 | del imdb_dataset 98 | imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"]) 99 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 100 | 101 | ######################################### 102 | ### 3 Set Up DataLoaders 103 | ######################################### 104 | 105 | train_dataset = IMDBDataset(imdb_tokenized, partition_key="train") 106 | val_dataset = IMDBDataset(imdb_tokenized, partition_key="validation") 107 | test_dataset = IMDBDataset(imdb_tokenized, partition_key="test") 108 | 109 | train_loader = DataLoader( 110 | dataset=train_dataset, 111 | batch_size=12, 112 | shuffle=True, 113 | num_workers=4, 114 | drop_last=True, 115 | ) 116 | 117 | val_loader = DataLoader( 118 | dataset=val_dataset, 119 | batch_size=12, 120 | num_workers=4, 121 | drop_last=True, 122 | ) 123 | 124 | test_loader = DataLoader( 125 | dataset=test_dataset, 126 | batch_size=12, 127 | num_workers=2, 128 | drop_last=True, 129 | ) 130 | 131 | 132 | ######################################### 133 | ### 4 Initializing the Model 134 | ######################################### 135 | 136 | torch.set_float32_matmul_precision('high') 137 | 138 | fabric = Fabric(accelerator="cuda", devices=4, strategy="deepspeed_stage_2", precision="16-mixed") 139 | fabric.launch() 140 | 141 | model = AutoModelForSequenceClassification.from_pretrained( 142 | "distilbert-base-uncased", num_labels=2) 143 | 144 | optimizer = torch.optim.Adam(model.parameters(), lr=5e-5) 145 | 146 | model, optimizer = fabric.setup(model, optimizer) 147 | train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader) 148 | fabric.barrier() 149 | 150 | ######################################### 151 | ### 5 Finetuning 152 | ######################################### 153 | 154 | start = time.time() 155 | train( 156 | num_epochs=1, 157 | model=model, 158 | optimizer=optimizer, 159 | 
train_loader=train_loader, 160 | val_loader=val_loader, 161 | fabric=fabric 162 | ) 163 | 164 | end = time.time() 165 | elapsed = end-start 166 | 167 | with torch.no_grad(): 168 | model.eval() 169 | test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=2).to(fabric.device) 170 | for batch in test_loader: 171 | outputs = model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["label"]) 172 | predicted_labels = torch.argmax(outputs["logits"], 1) 173 | test_acc.update(predicted_labels, batch["label"]) 174 | 175 | fabric.print(f"Test accuracy: {test_acc.compute()*100:.2f}%") 176 | fabric.print(f"Total training time: {elapsed/60:.2f} min") 177 | fabric.print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB") -------------------------------------------------------------------------------- /figures/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rasbt/pytorch-memory-optim/f6eddcd28c40770a097ffd9f18797a043457a2d0/figures/overview.png -------------------------------------------------------------------------------- /local_utilities.py: -------------------------------------------------------------------------------- 1 | # Imports for ViT finetuning 2 | 3 | import torch 4 | from torch.utils.data import sampler 5 | from torchvision import datasets 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data import SubsetRandomSampler 8 | from torchvision import transforms 9 | 10 | # Import for LLM finetuning 11 | import os 12 | import sys 13 | import tarfile 14 | import time 15 | 16 | import numpy as np 17 | import pandas as pd 18 | from packaging import version 19 | from torch.utils.data import Dataset 20 | from tqdm import tqdm 21 | import urllib 22 | 23 | ############################ 24 | ##### VIT finetuning dataset 25 | ############################ 26 | 27 | 28 | def get_dataloaders_cifar10(batch_size, num_workers=0, 29 | 
validation_fraction=None, 30 | train_transforms=None, 31 | test_transforms=None): 32 | 33 | if train_transforms is None: 34 | train_transforms = transforms.ToTensor() 35 | 36 | if test_transforms is None: 37 | test_transforms = transforms.ToTensor() 38 | 39 | train_dataset = datasets.CIFAR10(root='data', 40 | train=True, 41 | transform=train_transforms, 42 | download=True) 43 | 44 | valid_dataset = datasets.CIFAR10(root='data', 45 | train=True, 46 | transform=test_transforms) 47 | 48 | test_dataset = datasets.CIFAR10(root='data', 49 | train=False, 50 | transform=test_transforms) 51 | 52 | if validation_fraction is not None: 53 | num = int(validation_fraction * 50000) 54 | train_indices = torch.arange(0, 50000 - num) 55 | valid_indices = torch.arange(50000 - num, 50000) 56 | 57 | train_sampler = SubsetRandomSampler(train_indices) 58 | valid_sampler = SubsetRandomSampler(valid_indices) 59 | 60 | valid_loader = DataLoader(dataset=valid_dataset, 61 | batch_size=batch_size, 62 | num_workers=num_workers, 63 | sampler=valid_sampler) 64 | 65 | train_loader = DataLoader(dataset=train_dataset, 66 | batch_size=batch_size, 67 | num_workers=num_workers, 68 | drop_last=True, 69 | sampler=train_sampler) 70 | 71 | else: 72 | train_loader = DataLoader(dataset=train_dataset, 73 | batch_size=batch_size, 74 | num_workers=num_workers, 75 | drop_last=True, 76 | shuffle=True) 77 | 78 | test_loader = DataLoader(dataset=test_dataset, 79 | batch_size=batch_size, 80 | num_workers=num_workers, 81 | shuffle=False) 82 | 83 | if validation_fraction is None: 84 | return train_loader, test_loader 85 | else: 86 | return train_loader, valid_loader, test_loader 87 | 88 | ############################ 89 | ##### LLM finetuning dataset 90 | ############################ 91 | 92 | import os 93 | import sys 94 | import tarfile 95 | import time 96 | 97 | import numpy as np 98 | import pandas as pd 99 | from packaging import version 100 | from torch.utils.data import Dataset 101 | from tqdm import tqdm 102 | 
import urllib 103 | 104 | 105 | def reporthook(count, block_size, total_size): 106 | global start_time 107 | if count == 0: 108 | start_time = time.time() 109 | return 110 | duration = time.time() - start_time 111 | progress_size = int(count * block_size) 112 | speed = progress_size / (1024.0**2 * duration) 113 | percent = count * block_size * 100.0 / total_size 114 | 115 | sys.stdout.write( 116 | f"\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB " 117 | f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed" 118 | ) 119 | sys.stdout.flush() 120 | 121 | 122 | def download_dataset(): 123 | source = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz" 124 | target = "aclImdb_v1.tar.gz" 125 | 126 | if os.path.exists(target): 127 | os.remove(target) 128 | 129 | if not os.path.isdir("aclImdb") and not os.path.isfile("aclImdb_v1.tar.gz"): 130 | urllib.request.urlretrieve(source, target, reporthook) 131 | 132 | if not os.path.isdir("aclImdb"): 133 | 134 | with tarfile.open(target, "r:gz") as tar: 135 | tar.extractall() 136 | 137 | 138 | def load_dataset_into_to_dataframe(): 139 | basepath = "aclImdb" 140 | 141 | labels = {"pos": 1, "neg": 0} 142 | 143 | df = pd.DataFrame() 144 | 145 | with tqdm(total=50000) as pbar: 146 | for s in ("test", "train"): 147 | for l in ("pos", "neg"): 148 | path = os.path.join(basepath, s, l) 149 | for file in sorted(os.listdir(path)): 150 | with open(os.path.join(path, file), "r", encoding="utf-8") as infile: 151 | txt = infile.read() 152 | 153 | if version.parse(pd.__version__) >= version.parse("1.3.2"): 154 | x = pd.DataFrame( 155 | [[txt, labels[l]]], columns=["review", "sentiment"] 156 | ) 157 | df = pd.concat([df, x], ignore_index=False) 158 | 159 | else: 160 | df = df.append([[txt, labels[l]]], ignore_index=True) 161 | pbar.update() 162 | df.columns = ["text", "label"] 163 | 164 | np.random.seed(0) 165 | df = df.reindex(np.random.permutation(df.index)) 166 | 167 | print("Class distribution:") 168 | 
np.bincount(df["label"].values) 169 | 170 | return df 171 | 172 | 173 | def partition_dataset(df): 174 | df_shuffled = df.sample(frac=1, random_state=1).reset_index() 175 | 176 | df_train = df_shuffled.iloc[:35_000] 177 | df_val = df_shuffled.iloc[35_000:40_000] 178 | df_test = df_shuffled.iloc[40_000:] 179 | 180 | df_train.to_csv("train.csv", index=False, encoding="utf-8") 181 | df_val.to_csv("val.csv", index=False, encoding="utf-8") 182 | df_test.to_csv("test.csv", index=False, encoding="utf-8") 183 | 184 | 185 | class IMDBDataset(Dataset): 186 | def __init__(self, dataset_dict, partition_key="train"): 187 | self.partition = dataset_dict[partition_key] 188 | 189 | def __getitem__(self, index): 190 | return self.partition[index] 191 | 192 | def __len__(self): 193 | return self.partition.num_rows -------------------------------------------------------------------------------- /logs.md: -------------------------------------------------------------------------------- 1 | 2 | # 01_pytorch-vit.py 3 | 4 | - regular PyTorch script 5 | - batch size 32 6 | 7 | ``` 8 | torch : 2.1.0.dev20230623+cu118 9 | lightning: 2.1.0.dev0 10 | 11 | Torch CUDA available? 
True 12 | Global seed set to 123 13 | Files already downloaded and verified 14 | Epoch: 0001/0001 | Batch 0000/0703 | Loss: 2.4105 15 | Epoch: 0001/0001 | Batch 0300/0703 | Loss: 0.0640 16 | Epoch: 0001/0001 | Batch 0600/0703 | Loss: 0.0680 17 | Epoch: 0001/0001 | Train acc.: 94.44% | Val acc.: 96.00% 18 | Time elapsed 17.94 min 19 | Memory used: 26.79 GB 20 | Test accuracy 95.85% 21 | ``` 22 | 23 | # 01-2_pytorch-fabric.py 24 | 25 | - same as above but using Fabric 26 | - same results as expected 27 | 28 | ``` 29 | Epoch: 0001/0001 | Batch 0000/0703 | Loss: 2.4105 30 | Epoch: 0001/0001 | Batch 0300/0703 | Loss: 0.1754 31 | Epoch: 0001/0001 | Batch 0600/0703 | Loss: 0.2308 32 | Epoch: 0001/0001 | Train acc.: 94.44% | Val acc.: 96.06% 33 | Time elapsed 17.88 min 34 | Memory used: 26.84 GB 35 | Test accuracy 96.06% 36 | ``` 37 | 38 | # 02_mixed-precision.py 39 | 40 | - like above, but enables `"16-mixed"` training. 41 | 42 | ``` 43 | Epoch: 0001/0001 | Batch 0000/0703 | Loss: 2.4105 44 | Epoch: 0001/0001 | Batch 0300/0703 | Loss: 0.1088 45 | Epoch: 0001/0001 | Batch 0600/0703 | Loss: 0.1302 46 | Epoch: 0001/0001 | Train acc.: 94.56% | Val acc.: 96.02% 47 | Time elapsed 3.45 min 48 | Memory used: 18.21 GB 49 | Test accuracy 95.71% 50 | ``` 51 | 52 | # 03_bfloat16.py 53 | 54 | - like above, but enables `"bf16-true"` training. 
55 | 56 | ``` 57 | Epoch: 0001/0001 | Batch 0000/0703 | Loss: 2.4101 58 | Epoch: 0001/0001 | Batch 0300/0703 | Loss: 0.1149 59 | Epoch: 0001/0001 | Batch 0600/0703 | Loss: 0.0269 60 | Epoch: 0001/0001 | Train acc.: 95.62% | Val acc.: 96.94% 61 | Time elapsed 2.88 min 62 | Memory used: 13.82 GB 63 | Test accuracy 96.92% 64 | ``` 65 | 66 | # 04_lower-batchsize.py 67 | 68 | - batch size 16 (instead of 64 before) 69 | 70 | ``` 71 | Epoch: 0001/0001 | Batch 0000/2812 | Loss: 2.4567 72 | Epoch: 0001/0001 | Batch 0300/2812 | Loss: 0.2379 73 | Epoch: 0001/0001 | Batch 0600/2812 | Loss: 0.0248 74 | Epoch: 0001/0001 | Batch 0900/2812 | Loss: 0.0716 75 | Epoch: 0001/0001 | Batch 1200/2812 | Loss: 0.0398 76 | Epoch: 0001/0001 | Batch 1500/2812 | Loss: 0.0177 77 | Epoch: 0001/0001 | Batch 1800/2812 | Loss: 0.0273 78 | Epoch: 0001/0001 | Batch 2100/2812 | Loss: 0.1532 79 | Epoch: 0001/0001 | Batch 2400/2812 | Loss: 0.0085 80 | Epoch: 0001/0001 | Batch 2700/2812 | Loss: 0.0031 81 | Epoch: 0001/0001 | Train acc.: 95.21% | Val acc.: 97.20% 82 | Time elapsed 3.96 min 83 | Memory used: 5.69 GB 84 | Test accuracy 97.34% 85 | ``` 86 | 87 | # 05_gradient-accum.py 88 | 89 | ``` 90 | Epoch: 0001/0001 | Batch 0000/11250 | Loss: 0.6012 91 | Epoch: 0001/0001 | Batch 0300/11250 | Loss: 0.0044 92 | Epoch: 0001/0001 | Batch 0600/11250 | Loss: 0.0032 93 | Epoch: 0001/0001 | Batch 0900/11250 | Loss: 0.0155 94 | Epoch: 0001/0001 | Batch 1200/11250 | Loss: 0.0021 95 | Epoch: 0001/0001 | Batch 1500/11250 | Loss: 0.0658 96 | Epoch: 0001/0001 | Batch 1800/11250 | Loss: 0.0016 97 | Epoch: 0001/0001 | Batch 2100/11250 | Loss: 0.0359 98 | Epoch: 0001/0001 | Batch 2400/11250 | Loss: 0.0106 99 | Epoch: 0001/0001 | Batch 2700/11250 | Loss: 0.0100 100 | Epoch: 0001/0001 | Batch 3000/11250 | Loss: 0.2942 101 | Epoch: 0001/0001 | Batch 3300/11250 | Loss: 0.0020 102 | Epoch: 0001/0001 | Batch 3600/11250 | Loss: 0.0222 103 | Epoch: 0001/0001 | Batch 3900/11250 | Loss: 0.0075 104 | Epoch: 0001/0001 | Batch 
4200/11250 | Loss: 0.1245 105 | Epoch: 0001/0001 | Batch 4500/11250 | Loss: 0.0032 106 | Epoch: 0001/0001 | Batch 4800/11250 | Loss: 0.0266 107 | Epoch: 0001/0001 | Batch 5100/11250 | Loss: 0.0039 108 | Epoch: 0001/0001 | Batch 5400/11250 | Loss: 0.0014 109 | Epoch: 0001/0001 | Batch 5700/11250 | Loss: 0.0171 110 | Epoch: 0001/0001 | Batch 6000/11250 | Loss: 0.0009 111 | Epoch: 0001/0001 | Batch 6300/11250 | Loss: 0.0021 112 | Epoch: 0001/0001 | Batch 6600/11250 | Loss: 0.0655 113 | Epoch: 0001/0001 | Batch 6900/11250 | Loss: 0.0004 114 | Epoch: 0001/0001 | Batch 7200/11250 | Loss: 0.0003 115 | Epoch: 0001/0001 | Batch 7500/11250 | Loss: 0.0004 116 | Epoch: 0001/0001 | Batch 7800/11250 | Loss: 0.0011 117 | Epoch: 0001/0001 | Batch 8100/11250 | Loss: 0.0106 118 | Epoch: 0001/0001 | Batch 8400/11250 | Loss: 0.0321 119 | Epoch: 0001/0001 | Batch 8700/11250 | Loss: 0.0018 120 | Epoch: 0001/0001 | Batch 9000/11250 | Loss: 0.0004 121 | Epoch: 0001/0001 | Batch 9300/11250 | Loss: 0.0013 122 | Epoch: 0001/0001 | Batch 9600/11250 | Loss: 0.0001 123 | Epoch: 0001/0001 | Batch 9900/11250 | Loss: 0.0003 124 | Epoch: 0001/0001 | Batch 10200/11250 | Loss: 0.1277 125 | Epoch: 0001/0001 | Batch 10500/11250 | Loss: 0.0005 126 | Epoch: 0001/0001 | Batch 10800/11250 | Loss: 0.0007 127 | Epoch: 0001/0001 | Batch 11100/11250 | Loss: 0.0490 128 | Epoch: 0001/0001 | Train acc.: 95.46% | Val acc.: 97.24% 129 | Time elapsed 12.91 min 130 | Memory used: 3.91 GB 131 | Test accuracy 97.27% 132 | ``` 133 | 134 | # 06_sgd-with-scheduler.py 135 | 136 | - like above but with SGD + Cosine Decay Scheduler instead of ADAM 137 | 138 | ``` 139 | Epoch: 0001/0001 | Batch 0000/11250 | Loss: 0.6012 140 | Epoch: 0001/0001 | Batch 0300/11250 | Loss: 0.0472 141 | Epoch: 0001/0001 | Batch 0600/11250 | Loss: 0.0110 142 | Epoch: 0001/0001 | Batch 0900/11250 | Loss: 0.0215 143 | Epoch: 0001/0001 | Batch 1200/11250 | Loss: 0.0152 144 | Epoch: 0001/0001 | Batch 1500/11250 | Loss: 0.0134 145 | Epoch: 0001/0001 | 
Batch 1800/11250 | Loss: 0.0036 146 | Epoch: 0001/0001 | Batch 2100/11250 | Loss: 0.0223 147 | Epoch: 0001/0001 | Batch 2400/11250 | Loss: 0.0230 148 | Epoch: 0001/0001 | Batch 2700/11250 | Loss: 0.0293 149 | Epoch: 0001/0001 | Batch 3000/11250 | Loss: 0.0567 150 | Epoch: 0001/0001 | Batch 3300/11250 | Loss: 0.0009 151 | Epoch: 0001/0001 | Batch 3600/11250 | Loss: 0.0428 152 | Epoch: 0001/0001 | Batch 3900/11250 | Loss: 0.0081 153 | Epoch: 0001/0001 | Batch 4200/11250 | Loss: 0.0612 154 | Epoch: 0001/0001 | Batch 4500/11250 | Loss: 0.0429 155 | Epoch: 0001/0001 | Batch 4800/11250 | Loss: 0.0118 156 | Epoch: 0001/0001 | Batch 5100/11250 | Loss: 0.0041 157 | Epoch: 0001/0001 | Batch 5400/11250 | Loss: 0.0021 158 | Epoch: 0001/0001 | Batch 5700/11250 | Loss: 0.0076 159 | Epoch: 0001/0001 | Batch 6000/11250 | Loss: 0.0027 160 | Epoch: 0001/0001 | Batch 6300/11250 | Loss: 0.0012 161 | Epoch: 0001/0001 | Batch 6600/11250 | Loss: 0.0752 162 | Epoch: 0001/0001 | Batch 6900/11250 | Loss: 0.0061 163 | Epoch: 0001/0001 | Batch 7200/11250 | Loss: 0.0087 164 | Epoch: 0001/0001 | Batch 7500/11250 | Loss: 0.0004 165 | Epoch: 0001/0001 | Batch 7800/11250 | Loss: 0.0017 166 | Epoch: 0001/0001 | Batch 8100/11250 | Loss: 0.0259 167 | Epoch: 0001/0001 | Batch 8400/11250 | Loss: 0.0844 168 | Epoch: 0001/0001 | Batch 8700/11250 | Loss: 0.0777 169 | Epoch: 0001/0001 | Batch 9000/11250 | Loss: 0.0071 170 | Epoch: 0001/0001 | Batch 9300/11250 | Loss: 0.0035 171 | Epoch: 0001/0001 | Batch 9600/11250 | Loss: 0.0002 172 | Epoch: 0001/0001 | Batch 9900/11250 | Loss: 0.0017 173 | Epoch: 0001/0001 | Batch 10200/11250 | Loss: 0.2672 174 | Epoch: 0001/0001 | Batch 10500/11250 | Loss: 0.0072 175 | Epoch: 0001/0001 | Batch 10800/11250 | Loss: 0.0042 176 | Epoch: 0001/0001 | Batch 11100/11250 | Loss: 0.0061 177 | Epoch: 0001/0001 | Train acc.: 95.88% | Val acc.: 97.12% 178 | Time elapsed 12.89 min 179 | Memory used: 2.02 GB 180 | Test accuracy 96.91% 181 | ``` 182 | 183 | # 07_xx_init-module.py 184 | 
185 | 07-01_init_module.py: 186 | 187 | ``` 188 | Without Fabric 189 | CPU Memory used: 0.60 GB 190 | GPU Memory used: 1.24 GB 191 | ``` 192 | 193 | 07-03_init_module.py: 194 | 195 | ``` 196 | With init_module 197 | CPU Memory used: 0.70 GB 198 | GPU Memory used: 0.65 GB 199 | ``` 200 | 201 | 202 | # 08_fsdp-with-01-2.py 203 | 204 | - Fully Sharded Data Parallelism on 4 GPUs 205 | - compare to the `01-2_pytorch-fabric.py` baseline 206 | 207 | 208 | ``` 209 | Epoch: 0001/0001 | Batch 0000/0175 | Loss: 2.4957 210 | Epoch: 0001/0001 | Batch 0050/0175 | Loss: 0.1717 211 | Epoch: 0001/0001 | Batch 0100/0175 | Loss: 0.0793 212 | Epoch: 0001/0001 | Batch 0150/0175 | Loss: 0.1426 213 | Epoch: 0001/0001 | Train acc.: 94.74% | Val acc.: 97.28% 214 | Time elapsed 5.53 min 215 | Memory used: 6.59 GB 216 | Test accuracy 97.13% 217 | ``` 218 | 219 | 220 | 221 | 222 | # 09 FSDP CPU Offload 223 | 224 | - Similar to above but with CPU offloading. 225 | 226 | ``` 227 | Epoch: 0001/0001 | Batch 0000/0175 | Loss: 2.4957 228 | Epoch: 0001/0001 | Batch 0050/0175 | Loss: 0.1717 229 | Epoch: 0001/0001 | Batch 0100/0175 | Loss: 0.0794 230 | Epoch: 0001/0001 | Batch 0150/0175 | Loss: 0.1454 231 | Epoch: 0001/0001 | Train acc.: 94.75% | Val acc.: 97.24% 232 | Time elapsed 8.34 min 233 | Memory used: 6.03 GB 234 | Test accuracy 97.23% 235 | ``` 236 | 237 | ## 10_delay-allocation.py 238 | 239 | - like above but with `init_module` context 240 | 241 | ``` 242 | Epoch: 0001/0001 | Batch 0000/0175 | Loss: 2.4957 243 | Epoch: 0001/0001 | Batch 0050/0175 | Loss: 0.1717 244 | Epoch: 0001/0001 | Batch 0100/0175 | Loss: 0.0803 245 | Epoch: 0001/0001 | Batch 0150/0175 | Loss: 0.1496 246 | Epoch: 0001/0001 | Train acc.: 94.74% | Val acc.: 97.38% 247 | Time elapsed 8.48 min 248 | Memory used: 6.03 GB 249 | Test accuracy 97.22% 250 | ``` 251 | 252 | ## bonus: distilbert before 253 | 254 | ``` 255 | Epoch: 0001/0001 | Batch 0000/0729 | Loss: 0.7085 256 | Epoch: 0001/0001 | Batch 0300/0729 | Loss: 0.4883 257 
| Epoch: 0001/0001 | Batch 0600/0729 | Loss: 0.3101 258 | Epoch: 0001/0001 | Train acc.: 90.55% | Val acc.: 92.99% 259 | Test accuracy: 92.95% 260 | Total training time: 0.71 min 261 | Memory used: 3.99 GB 262 | ``` 263 | 264 | # bonus distilbert after 265 | 266 | ``` 267 | Epoch: 0001/0001 | Batch 0000/2916 | Loss: 0.1748 268 | Epoch: 0001/0001 | Batch 0300/2916 | Loss: 0.1943 269 | Epoch: 0001/0001 | Batch 0600/2916 | Loss: 0.0161 270 | Epoch: 0001/0001 | Batch 0900/2916 | Loss: 0.0294 271 | Epoch: 0001/0001 | Batch 1200/2916 | Loss: 0.0195 272 | Epoch: 0001/0001 | Batch 1500/2916 | Loss: 0.1030 273 | Epoch: 0001/0001 | Batch 1800/2916 | Loss: 0.0184 274 | Epoch: 0001/0001 | Batch 2100/2916 | Loss: 0.0093 275 | Epoch: 0001/0001 | Batch 2400/2916 | Loss: 0.0197 276 | Epoch: 0001/0001 | Batch 2700/2916 | Loss: 0.0417 277 | Epoch: 0001/0001 | Train acc.: 85.37% | Val acc.: 91.43% 278 | Test accuracy: 90.43% 279 | Total training time: 5.55 min 280 | Memory used: 1.15 GB 281 | ``` 282 | 283 | # bonus bigbird before 284 | 285 | N/A 286 | 287 | # bonus bigbird after 288 | 289 | 290 | ``` 291 | Epoch: 0001/0001 | Batch 0000/8750 | Loss: 0.0697 292 | Epoch: 0001/0001 | Batch 0300/8750 | Loss: 0.0469 293 | Epoch: 0001/0001 | Batch 0600/8750 | Loss: 0.0557 294 | Epoch: 0001/0001 | Batch 0900/8750 | Loss: 0.0181 295 | Epoch: 0001/0001 | Batch 1200/8750 | Loss: 0.0254 296 | Epoch: 0001/0001 | Batch 1500/8750 | Loss: 0.0125 297 | Epoch: 0001/0001 | Batch 1800/8750 | Loss: 0.0138 298 | Epoch: 0001/0001 | Batch 2100/8750 | Loss: 0.0050 299 | Epoch: 0001/0001 | Batch 2400/8750 | Loss: 0.0120 300 | Epoch: 0001/0001 | Batch 2700/8750 | Loss: 0.0048 301 | Epoch: 0001/0001 | Batch 3000/8750 | Loss: 0.0114 302 | Epoch: 0001/0001 | Batch 3300/8750 | Loss: 0.0221 303 | Epoch: 0001/0001 | Batch 3600/8750 | Loss: 0.0027 304 | Epoch: 0001/0001 | Batch 3900/8750 | Loss: 0.0035 305 | Epoch: 0001/0001 | Batch 4200/8750 | Loss: 0.0015 306 | Epoch: 0001/0001 | Batch 4500/8750 | Loss: 0.0014 307 
| Epoch: 0001/0001 | Batch 4800/8750 | Loss: 0.0082 308 | Epoch: 0001/0001 | Batch 5100/8750 | Loss: 0.0041 309 | Epoch: 0001/0001 | Batch 5400/8750 | Loss: 0.0023 310 | Epoch: 0001/0001 | Batch 5700/8750 | Loss: 0.0059 311 | Epoch: 0001/0001 | Batch 6000/8750 | Loss: 0.0036 312 | Epoch: 0001/0001 | Batch 6300/8750 | Loss: 0.0023 313 | Epoch: 0001/0001 | Batch 6600/8750 | Loss: 0.0011 314 | Epoch: 0001/0001 | Batch 6900/8750 | Loss: 0.1048 315 | Epoch: 0001/0001 | Batch 7200/8750 | Loss: 0.0050 316 | Epoch: 0001/0001 | Batch 7500/8750 | Loss: 0.0027 317 | Epoch: 0001/0001 | Batch 7800/8750 | Loss: 0.0042 318 | Epoch: 0001/0001 | Batch 8100/8750 | Loss: 0.0016 319 | Epoch: 0001/0001 | Batch 8400/8750 | Loss: 0.0022 320 | Epoch: 0001/0001 | Batch 8700/8750 | Loss: 0.0016 321 | Epoch: 0001/0001 | Train acc.: 88.69% | Val acc.: 93.28% 322 | Test accuracy: 93.10% 323 | Total training time: 75.94 min 324 | Memory used: 4.03 GB 325 | ``` -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy >= 1.24.3 2 | scipy >= 1.10.1 3 | pandas >= 2.0.2 4 | watermark >= 2.4.2 5 | torch >= 2.0.1 6 | torchvision >= 0.15.2 7 | torchmetrics >= 0.11.4 8 | transformers >= 4.30.2 9 | datasets >= 2.11.0 10 | lightning @ git+https://github.com/Lightning-AI/lightning@master 11 | --------------------------------------------------------------------------------