├── .gitignore ├── README.md ├── docker └── pytorch.Dockerfile ├── pretrain.py ├── utils.py └── prune.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | saved_models 3 | .vscode 4 | __pycache__ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch Pruning 2 | 3 | ## Introduction 4 | 5 | PyTorch pruning example for ResNet. ResNet18 pre-trained on CIFAR-10 dataset maintains the same prediction accuracy with 50x compression after pruning. 6 | 7 | ## Usages 8 | 9 | ### Build Docker Image 10 | 11 | ``` 12 | $ docker build -f docker/pytorch.Dockerfile --no-cache --tag=pytorch:1.13.0 . 13 | ``` 14 | 15 | ### Run Docker Container 16 | 17 | ``` 18 | $ docker run -it --rm --gpus device=0 -v $(pwd):/mnt pytorch:1.13.0 19 | ``` 20 | 21 | ### Run Pre-Training 22 | 23 | ``` 24 | $ python pretrain.py 25 | ``` 26 | 27 | ### Run Pruning 28 | 29 | ``` 30 | $ python prune.py 31 | ``` 32 | 33 | ## References 34 | 35 | * [PyTorch Pruning](https://leimao.github.io/blog/PyTorch-Pruning/) 36 | * [PyTorch Pruning Tutorial](https://pytorch.org/tutorials/intermediate/pruning_tutorial.html) 37 | -------------------------------------------------------------------------------- /docker/pytorch.Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/cuda:11.7.1-cudnn8-devel-ubuntu22.04 2 | 3 | ENV DEBIAN_FRONTEND noninteractive 4 | 5 | # Install package dependencies 6 | RUN apt-get update && \ 7 | apt-get install -y --no-install-recommends \ 8 | build-essential \ 9 | autoconf \ 10 | automake \ 11 | libtool \ 12 | pkg-config \ 13 | ca-certificates \ 14 | wget \ 15 | git \ 16 | curl \ 17 | libjpeg-dev \ 18 | libpng-dev \ 19 | language-pack-en \ 20 | locales \ 21 | locales-all \ 22 | python3 \ 23 | python3-dev \ 24 | python3-pip \ 25 | python3-setuptools \ 26 | libprotobuf-dev \ 27 | protobuf-compiler \ 28 | zlib1g-dev \ 29 | swig \ 30 | vim \ 31 | gdb \ 32 | valgrind \ 33 | libsm6 \ 34 | libxext6 \ 35 | libxrender-dev && \ 36 | apt-get clean 37 | 38 | RUN cd /usr/local/bin && \ 39 | ln -s /usr/bin/python3 python && \ 40 | ln -s /usr/bin/pip3 pip && \ 41 | pip install --upgrade pip setuptools wheel 42 | 43 | # System locale 44 | # Important for UTF-8 45 | ENV LC_ALL en_US.UTF-8 46 | ENV LANG en_US.UTF-8 47 | ENV LANGUAGE en_US.UTF-8 48 | 49 | RUN pip install torch==1.13.0 torchvision==0.14.0 torchaudio==0.13.0 50 | RUN pip install scikit-learn==1.1.3 51 | -------------------------------------------------------------------------------- /pretrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from utils import set_random_seeds, create_model, prepare_dataloader, train_model, save_model, load_model, evaluate_model, create_classification_report 4 | 5 | 6 | def main(): 7 | 8 | random_seed = 0 9 | num_classes = 10 10 | l1_regularization_strength = 0 11 | l2_regularization_strength = 1e-4 12 | learning_rate = 1e-1 13 | num_epochs = 200 14 | cuda_device = torch.device("cuda:0") 15 | cpu_device = torch.device("cpu:0") 16 | 17 | model_dir = "saved_models" 18 | model_filename = "resnet18_cifar10.pt" 19 | model_filepath = os.path.join(model_dir, model_filename) 20 | 21 | set_random_seeds(random_seed=random_seed) 22 | 23 | # Create an untrained model. 24 | model = create_model(num_classes=num_classes) 25 | 26 | train_loader, test_loader, classes = prepare_dataloader( 27 | num_workers=8, train_batch_size=128, eval_batch_size=256) 28 | 29 | # Train model. 30 | print("Training Model...") 31 | model = train_model(model=model, 32 | train_loader=train_loader, 33 | test_loader=test_loader, 34 | device=cuda_device, 35 | l1_regularization_strength=l1_regularization_strength, 36 | l2_regularization_strength=l2_regularization_strength, 37 | learning_rate=learning_rate, 38 | num_epochs=num_epochs) 39 | 40 | # Save model. 41 | save_model(model=model, model_dir=model_dir, model_filename=model_filename) 42 | # Load a pretrained model. 43 | model = load_model(model=model, 44 | model_filepath=model_filepath, 45 | device=cuda_device) 46 | 47 | _, eval_accuracy = evaluate_model(model=model, 48 | test_loader=test_loader, 49 | device=cuda_device, 50 | criterion=None) 51 | 52 | classification_report = create_classification_report( 53 | model=model, test_loader=test_loader, device=cuda_device) 54 | 55 | print("Test Accuracy: {:.3f}".format(eval_accuracy)) 56 | print("Classification Report:") 57 | print(classification_report) 58 | 59 | 60 | if __name__ == "__main__": 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torchvision 8 | from torchvision import datasets, transforms 9 | 10 | import time 11 | import copy 12 | import numpy as np 13 | 14 | import sklearn.metrics 15 | 16 | 17 | def set_random_seeds(random_seed=0): 18 | 19 | torch.manual_seed(random_seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | np.random.seed(random_seed) 23 | random.seed(random_seed) 24 | 25 | 26 | def prepare_dataloader(num_workers=8, 27 | train_batch_size=128, 28 | eval_batch_size=256): 29 | 30 | train_transform = transforms.Compose([ 31 | transforms.RandomCrop(32, padding=4), 32 | transforms.RandomHorizontalFlip(), 33 | transforms.ToTensor(), 34 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 35 | std=(0.229, 0.224, 0.225)) 36 | ]) 37 | 38 | test_transform = transforms.Compose([ 39 | transforms.ToTensor(), 40 | transforms.Normalize(mean=(0.485, 0.456, 0.406), 41 | std=(0.229, 0.224, 0.225)) 42 | ]) 43 | 44 | train_set = torchvision.datasets.CIFAR10(root="data", 45 | train=True, 46 | download=True, 47 | transform=train_transform) 48 | 49 | test_set = torchvision.datasets.CIFAR10(root="data", 50 | train=False, 51 | download=True, 52 | transform=test_transform) 53 | 54 | train_sampler = torch.utils.data.RandomSampler(train_set) 55 | test_sampler = torch.utils.data.SequentialSampler(test_set) 56 | 57 | train_loader = torch.utils.data.DataLoader(dataset=train_set, 58 | batch_size=train_batch_size, 59 | sampler=train_sampler, 60 | num_workers=num_workers) 61 | 62 | test_loader = torch.utils.data.DataLoader(dataset=test_set, 63 | batch_size=eval_batch_size, 64 | sampler=test_sampler, 65 | num_workers=num_workers) 66 | 67 | classes = train_set.classes 68 | 69 | return train_loader, test_loader, classes 70 | 71 | 72 | def evaluate_model(model, test_loader, device, criterion=None): 73 | 74 | model.eval() 75 | model.to(device) 76 | 77 | running_loss = 0 78 | running_corrects = 0 79 | 80 | for inputs, labels in test_loader: 81 | 82 | inputs = inputs.to(device) 83 | labels = labels.to(device) 84 | 85 | outputs = model(inputs) 86 | _, preds = torch.max(outputs, 1) 87 | 88 | if criterion is not None: 89 | loss = criterion(outputs, labels).item() 90 | else: 91 | loss = 0 92 | 93 | # statistics 94 | running_loss += loss * inputs.size(0) 95 | running_corrects += torch.sum(preds == labels.data) 96 | 97 | eval_loss = running_loss / len(test_loader.dataset) 98 | eval_accuracy = running_corrects / len(test_loader.dataset) 99 | 100 | return eval_loss, eval_accuracy 101 | 102 | 103 | def create_classification_report(model, device, test_loader): 104 | 105 | model.eval() 106 | model.to(device) 107 | 108 | y_pred = [] 109 | y_true = [] 110 | 111 | with torch.no_grad(): 112 | for data in test_loader: 113 | y_true += data[1].numpy().tolist() 114 | images, _ = data[0].to(device), data[1].to(device) 115 | outputs = model(images) 116 | _, predicted = torch.max(outputs.data, 1) 117 | y_pred += predicted.cpu().numpy().tolist() 118 | 119 | classification_report = sklearn.metrics.classification_report( 120 | y_true=y_true, y_pred=y_pred) 121 | 122 | return classification_report 123 | 124 | 125 | def train_model(model, 126 | train_loader, 127 | test_loader, 128 | device, 129 | l1_regularization_strength=0, 130 | l2_regularization_strength=1e-4, 131 | learning_rate=1e-1, 132 | num_epochs=200): 133 | 134 | # The training configurations were not carefully selected. 135 | 136 | criterion = nn.CrossEntropyLoss() 137 | 138 | model.to(device) 139 | 140 | # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10. 141 | optimizer = optim.SGD(model.parameters(), 142 | lr=learning_rate, 143 | momentum=0.9, 144 | weight_decay=l2_regularization_strength) 145 | # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=500) 146 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, 147 | milestones=[100, 150], 148 | gamma=0.1, 149 | last_epoch=-1) 150 | # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False) 151 | 152 | # Evaluation 153 | model.eval() 154 | eval_loss, eval_accuracy = evaluate_model(model=model, 155 | test_loader=test_loader, 156 | device=device, 157 | criterion=criterion) 158 | print("Epoch: {:03d} Eval Loss: {:.3f} Eval Acc: {:.3f}".format( 159 | 0, eval_loss, eval_accuracy)) 160 | 161 | for epoch in range(num_epochs): 162 | 163 | # Training 164 | model.train() 165 | 166 | running_loss = 0 167 | running_corrects = 0 168 | 169 | for inputs, labels in train_loader: 170 | 171 | inputs = inputs.to(device) 172 | labels = labels.to(device) 173 | 174 | # zero the parameter gradients 175 | optimizer.zero_grad() 176 | 177 | # forward + backward + optimize 178 | outputs = model(inputs) 179 | _, preds = torch.max(outputs, 1) 180 | loss = criterion(outputs, labels) 181 | 182 | l1_reg = torch.tensor(0.).to(device) 183 | for module in model.modules(): 184 | mask = None 185 | weight = None 186 | for name, buffer in module.named_buffers(): 187 | if name == "weight_mask": 188 | mask = buffer 189 | for name, param in module.named_parameters(): 190 | if name == "weight_orig": 191 | weight = param 192 | # We usually only want to introduce sparsity to weights and prune weights. 193 | # Do the same for bias if necessary. 194 | if mask is not None and weight is not None: 195 | l1_reg += torch.norm(mask * weight, 1) 196 | 197 | loss += l1_regularization_strength * l1_reg 198 | 199 | loss.backward() 200 | optimizer.step() 201 | 202 | # statistics 203 | running_loss += loss.item() * inputs.size(0) 204 | running_corrects += torch.sum(preds == labels.data) 205 | 206 | train_loss = running_loss / len(train_loader.dataset) 207 | train_accuracy = running_corrects / len(train_loader.dataset) 208 | 209 | # Evaluation 210 | model.eval() 211 | eval_loss, eval_accuracy = evaluate_model(model=model, 212 | test_loader=test_loader, 213 | device=device, 214 | criterion=criterion) 215 | 216 | # Set learning rate scheduler 217 | scheduler.step() 218 | 219 | print( 220 | "Epoch: {:03d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}" 221 | .format(epoch + 1, train_loss, train_accuracy, eval_loss, 222 | eval_accuracy)) 223 | 224 | return model 225 | 226 | 227 | def save_model(model, model_dir, model_filename): 228 | 229 | if not os.path.exists(model_dir): 230 | os.makedirs(model_dir) 231 | model_filepath = os.path.join(model_dir, model_filename) 232 | torch.save(model.state_dict(), model_filepath) 233 | 234 | 235 | def load_model(model, model_filepath, device): 236 | 237 | model.load_state_dict(torch.load(model_filepath, map_location=device)) 238 | 239 | return model 240 | 241 | 242 | def create_model(num_classes=10, model_func=torchvision.models.resnet18): 243 | 244 | # The number of channels in ResNet18 is divisible by 8. 245 | # This is required for fast GEMM integer matrix multiplication. 246 | # model = torchvision.models.resnet18(pretrained=False) 247 | model = model_func(num_classes=num_classes, pretrained=False) 248 | 249 | # We would use the pretrained ResNet18 as a feature extractor. 250 | # for param in model.parameters(): 251 | # param.requires_grad = False 252 | 253 | # Modify the last FC layer 254 | # num_features = model.fc.in_features 255 | # model.fc = nn.Linear(num_features, 10) 256 | 257 | return model 258 | -------------------------------------------------------------------------------- /prune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import torch 4 | import torch.nn.utils.prune as prune 5 | from utils import set_random_seeds, create_model, prepare_dataloader, train_model, save_model, load_model, evaluate_model, create_classification_report 6 | 7 | 8 | def compute_final_pruning_rate(pruning_rate, num_iterations): 9 | """A function to compute the final pruning rate for iterative pruning. 10 | Note that this cannot be applied for global pruning rate if the pruning rate is heterogeneous among different layers. 11 | 12 | Args: 13 | pruning_rate (float): Pruning rate. 14 | num_iterations (int): Number of iterations. 15 | 16 | Returns: 17 | float: Final pruning rate. 18 | """ 19 | 20 | final_pruning_rate = 1 - (1 - pruning_rate)**num_iterations 21 | 22 | return final_pruning_rate 23 | 24 | 25 | def measure_module_sparsity(module, weight=True, bias=False, use_mask=False): 26 | 27 | num_zeros = 0 28 | num_elements = 0 29 | 30 | if use_mask == True: 31 | for buffer_name, buffer in module.named_buffers(): 32 | if "weight_mask" in buffer_name and weight == True: 33 | num_zeros += torch.sum(buffer == 0).item() 34 | num_elements += buffer.nelement() 35 | if "bias_mask" in buffer_name and bias == True: 36 | num_zeros += torch.sum(buffer == 0).item() 37 | num_elements += buffer.nelement() 38 | else: 39 | for param_name, param in module.named_parameters(): 40 | if "weight" in param_name and weight == True: 41 | num_zeros += torch.sum(param == 0).item() 42 | num_elements += param.nelement() 43 | if "bias" in param_name and bias == True: 44 | num_zeros += torch.sum(param == 0).item() 45 | num_elements += param.nelement() 46 | 47 | sparsity = num_zeros / num_elements 48 | 49 | return num_zeros, num_elements, sparsity 50 | 51 | 52 | def measure_global_sparsity(model, 53 | weight=True, 54 | bias=False, 55 | conv2d_use_mask=False, 56 | linear_use_mask=False): 57 | 58 | num_zeros = 0 59 | num_elements = 0 60 | 61 | for module_name, module in model.named_modules(): 62 | 63 | if isinstance(module, torch.nn.Conv2d): 64 | 65 | module_num_zeros, module_num_elements, _ = measure_module_sparsity( 66 | module, weight=weight, bias=bias, use_mask=conv2d_use_mask) 67 | num_zeros += module_num_zeros 68 | num_elements += module_num_elements 69 | 70 | elif isinstance(module, torch.nn.Linear): 71 | 72 | module_num_zeros, module_num_elements, _ = measure_module_sparsity( 73 | module, weight=weight, bias=bias, use_mask=linear_use_mask) 74 | num_zeros += module_num_zeros 75 | num_elements += module_num_elements 76 | 77 | sparsity = num_zeros / num_elements 78 | 79 | return num_zeros, num_elements, sparsity 80 | 81 | 82 | def iterative_pruning_finetuning(model, 83 | train_loader, 84 | test_loader, 85 | device, 86 | learning_rate, 87 | l1_regularization_strength, 88 | l2_regularization_strength, 89 | learning_rate_decay=0.1, 90 | conv2d_prune_amount=0.4, 91 | linear_prune_amount=0.2, 92 | num_iterations=10, 93 | num_epochs_per_iteration=10, 94 | model_filename_prefix="pruned_model", 95 | model_dir="saved_models", 96 | grouped_pruning=False): 97 | 98 | for i in range(num_iterations): 99 | 100 | print("Pruning and Finetuning {}/{}".format(i + 1, num_iterations)) 101 | 102 | print("Pruning...") 103 | 104 | if grouped_pruning == True: 105 | # Global pruning 106 | # I would rather call it grouped pruning. 107 | parameters_to_prune = [] 108 | for module_name, module in model.named_modules(): 109 | if isinstance(module, torch.nn.Conv2d): 110 | parameters_to_prune.append((module, "weight")) 111 | prune.global_unstructured( 112 | parameters_to_prune, 113 | pruning_method=prune.L1Unstructured, 114 | amount=conv2d_prune_amount, 115 | ) 116 | else: 117 | for module_name, module in model.named_modules(): 118 | if isinstance(module, torch.nn.Conv2d): 119 | prune.l1_unstructured(module, 120 | name="weight", 121 | amount=conv2d_prune_amount) 122 | elif isinstance(module, torch.nn.Linear): 123 | prune.l1_unstructured(module, 124 | name="weight", 125 | amount=linear_prune_amount) 126 | 127 | _, eval_accuracy = evaluate_model(model=model, 128 | test_loader=test_loader, 129 | device=device, 130 | criterion=None) 131 | 132 | classification_report = create_classification_report( 133 | model=model, test_loader=test_loader, device=device) 134 | 135 | num_zeros, num_elements, sparsity = measure_global_sparsity( 136 | model, 137 | weight=True, 138 | bias=False, 139 | conv2d_use_mask=True, 140 | linear_use_mask=False) 141 | 142 | print("Test Accuracy: {:.3f}".format(eval_accuracy)) 143 | print("Classification Report:") 144 | print(classification_report) 145 | print("Global Sparsity:") 146 | print("{:.2f}".format(sparsity)) 147 | 148 | # print(model.conv1._forward_pre_hooks) 149 | 150 | print("Fine-tuning...") 151 | 152 | train_model(model=model, 153 | train_loader=train_loader, 154 | test_loader=test_loader, 155 | device=device, 156 | l1_regularization_strength=l1_regularization_strength, 157 | l2_regularization_strength=l2_regularization_strength, 158 | learning_rate=learning_rate * (learning_rate_decay**i), 159 | num_epochs=num_epochs_per_iteration) 160 | 161 | _, eval_accuracy = evaluate_model(model=model, 162 | test_loader=test_loader, 163 | device=device, 164 | criterion=None) 165 | 166 | classification_report = create_classification_report( 167 | model=model, test_loader=test_loader, device=device) 168 | 169 | num_zeros, num_elements, sparsity = measure_global_sparsity( 170 | model, 171 | weight=True, 172 | bias=False, 173 | conv2d_use_mask=True, 174 | linear_use_mask=False) 175 | 176 | print("Test Accuracy: {:.3f}".format(eval_accuracy)) 177 | print("Classification Report:") 178 | print(classification_report) 179 | print("Global Sparsity:") 180 | print("{:.2f}".format(sparsity)) 181 | 182 | model_filename = "{}_{}.pt".format(model_filename_prefix, i + 1) 183 | model_filepath = os.path.join(model_dir, model_filename) 184 | save_model(model=model, 185 | model_dir=model_dir, 186 | model_filename=model_filename) 187 | model = load_model(model=model, 188 | model_filepath=model_filepath, 189 | device=device) 190 | 191 | return model 192 | 193 | 194 | def remove_parameters(model): 195 | 196 | for module_name, module in model.named_modules(): 197 | if isinstance(module, torch.nn.Conv2d): 198 | try: 199 | prune.remove(module, "weight") 200 | except: 201 | pass 202 | try: 203 | prune.remove(module, "bias") 204 | except: 205 | pass 206 | elif isinstance(module, torch.nn.Linear): 207 | try: 208 | prune.remove(module, "weight") 209 | except: 210 | pass 211 | try: 212 | prune.remove(module, "bias") 213 | except: 214 | pass 215 | 216 | return model 217 | 218 | 219 | def main(): 220 | 221 | num_classes = 10 222 | random_seed = 1 223 | l1_regularization_strength = 0 224 | l2_regularization_strength = 1e-4 225 | learning_rate = 1e-3 226 | learning_rate_decay = 1 227 | 228 | cuda_device = torch.device("cuda:0") 229 | cpu_device = torch.device("cpu:0") 230 | 231 | model_dir = "saved_models" 232 | model_filename = "resnet18_cifar10.pt" 233 | model_filename_prefix = "pruned_model" 234 | pruned_model_filename = "resnet18_pruned_cifar10.pt" 235 | model_filepath = os.path.join(model_dir, model_filename) 236 | pruned_model_filepath = os.path.join(model_dir, pruned_model_filename) 237 | 238 | set_random_seeds(random_seed=random_seed) 239 | 240 | # Create an untrained model. 241 | model = create_model(num_classes=num_classes) 242 | 243 | # Load a pretrained model. 244 | model = load_model(model=model, 245 | model_filepath=model_filepath, 246 | device=cuda_device) 247 | 248 | train_loader, test_loader, classes = prepare_dataloader( 249 | num_workers=8, train_batch_size=128, eval_batch_size=256) 250 | 251 | _, eval_accuracy = evaluate_model(model=model, 252 | test_loader=test_loader, 253 | device=cuda_device, 254 | criterion=None) 255 | 256 | classification_report = create_classification_report( 257 | model=model, test_loader=test_loader, device=cuda_device) 258 | 259 | num_zeros, num_elements, sparsity = measure_global_sparsity(model) 260 | 261 | print("Test Accuracy: {:.3f}".format(eval_accuracy)) 262 | print("Classification Report:") 263 | print(classification_report) 264 | print("Global Sparsity:") 265 | print("{:.2f}".format(sparsity)) 266 | 267 | print("Iterative Pruning + Fine-Tuning...") 268 | 269 | pruned_model = copy.deepcopy(model) 270 | 271 | # iterative_pruning_finetuning( 272 | # model=pruned_model, 273 | # train_loader=train_loader, 274 | # test_loader=test_loader, 275 | # device=cuda_device, 276 | # learning_rate=learning_rate, 277 | # learning_rate_decay=learning_rate_decay, 278 | # l1_regularization_strength=l1_regularization_strength, 279 | # l2_regularization_strength=l2_regularization_strength, 280 | # conv2d_prune_amount=0.3, 281 | # linear_prune_amount=0, 282 | # num_iterations=8, 283 | # num_epochs_per_iteration=50, 284 | # model_filename_prefix=model_filename_prefix, 285 | # model_dir=model_dir, 286 | # grouped_pruning=True) 287 | 288 | iterative_pruning_finetuning( 289 | model=pruned_model, 290 | train_loader=train_loader, 291 | test_loader=test_loader, 292 | device=cuda_device, 293 | learning_rate=learning_rate, 294 | learning_rate_decay=learning_rate_decay, 295 | l1_regularization_strength=l1_regularization_strength, 296 | l2_regularization_strength=l2_regularization_strength, 297 | conv2d_prune_amount=0.98, 298 | linear_prune_amount=0, 299 | num_iterations=1, 300 | num_epochs_per_iteration=200, 301 | model_filename_prefix=model_filename_prefix, 302 | model_dir=model_dir, 303 | grouped_pruning=True) 304 | 305 | # Apply mask to the parameters and remove the mask. 306 | remove_parameters(model=pruned_model) 307 | 308 | _, eval_accuracy = evaluate_model(model=pruned_model, 309 | test_loader=test_loader, 310 | device=cuda_device, 311 | criterion=None) 312 | 313 | classification_report = create_classification_report( 314 | model=pruned_model, test_loader=test_loader, device=cuda_device) 315 | 316 | num_zeros, num_elements, sparsity = measure_global_sparsity(pruned_model) 317 | 318 | print("Test Accuracy: {:.3f}".format(eval_accuracy)) 319 | print("Classification Report:") 320 | print(classification_report) 321 | print("Global Sparsity:") 322 | print("{:.2f}".format(sparsity)) 323 | 324 | save_model(model=model, model_dir=model_dir, model_filename=model_filename) 325 | 326 | 327 | if __name__ == "__main__": 328 | 329 | main() 330 | --------------------------------------------------------------------------------