├── .flake8
├── .gitignore
├── README.md
├── __init__.py
├── callbacks.py
├── datamodules
│   ├── __init__.py
│   ├── cifar10.py
│   ├── fer.py
│   └── mnist.py
├── nn
│   ├── __init__.py
│   ├── conv.py
│   ├── fc.py
│   ├── modules.py
│   ├── prune.py
│   ├── quant.py
│   └── shapes.py
├── requirements-dev.txt
├── requirements.txt
└── utils.py

/.flake8:
--------------------------------------------------------------------------------
1 | [flake8]
2 | docstring-convention = google
3 | extend-ignore = D1
4 | max-line-length = 102
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints/
3 | wandb/
4 | lightning_logs/
5 | data/
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | These utilities are intended for use with W&B educational materials
2 | and are not guaranteed to have a stable API outside of that context.
3 |
4 | # Structure
5 |
6 | ### `datamodules`
7 |
8 | This module includes `pl.LightningDataModule`s for
9 | a variety of datasets.
10 |
11 | The API is only moderately standardized.
12 | It might be further standardized
13 | if/when we add "dataset callbacks".
14 |
15 | ### `nn`
16 |
17 | `lu.nn` includes `pl.LightningModule`s and `torch.nn.Module`s.
18 | The core module is the `LoggedLitModule`,
19 | which abstracts logging, metrics, and training/validation/test steps
20 | in a manner that's suitable for many basic DNNs.
21 |
22 | ### `callbacks`
23 |
24 | This module includes `pl.Callback`s
25 | that log to Weights & Biases.
26 | The `lu.callbacks.WandbCallback` is designed to work
27 | with any `lu.nn.LoggedLitModule`
28 | and should be included in all educational Colabs.
29 |
30 | Others are specific to particular DNN problems,
31 | like image classification or autoencoding.
32 |
33 | ### `utils`
34 |
35 | This is a grab-bag of utilities,
36 | like a run name generator for use with the W&B YOLOv5 integration.
37 |
38 | # Installation
39 |
40 | These utilities are used in Colab and are "installed" via git.
41 |
42 | The following snippet installs the requirements that are not
43 | included in Colab:
44 |
45 | ```python
46 | %%capture
47 | !pip install pytorch-lightning==1.3.8 torchviz wandb
48 | !git clone https://github.com/wandb/lit_utils
49 | !cd "/content/lit_utils" && git pull
50 | ```
51 |
52 | For local development,
53 | also invoke `!pip install -r requirements-dev.txt`.
54 |
55 | # Usage
56 |
57 | The library is imported as `lu`:
58 | ```python
59 | import lit_utils as lu
60 | ```
61 |
62 | And in educational notebooks you should use
63 | ```python
64 | lu.utils.filter_warnings()
65 | ```
66 |
67 | to filter out `UserWarning`s and similar noise.
68 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from . 
import callbacks, datamodules, nn, utils 3 | -------------------------------------------------------------------------------- /callbacks.py: -------------------------------------------------------------------------------- 1 | """Lightning Callbacks for logging to Weights & Biases.""" 2 | import os 3 | from pathlib import Path 4 | import tempfile 5 | 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | import torch 9 | import wandb 10 | 11 | try: 12 | import torchviz 13 | has_torchviz = True 14 | except ImportError: 15 | has_torchviz = False 16 | 17 | 18 | class WandbCallback(): 19 | """Logs useful config and metric data to Weights & Biases.""" 20 | def __init__(self): 21 | self.cbs = [MetadataLogCallback(), ModelSizeLogCallback(), GraphLogCallback()] 22 | 23 | def __getattr__(self, item): 24 | return lambda *args, **kwargs: [getattr(cb, item)(*args, **kwargs) for cb in self.cbs] 25 | 26 | # the base class's on_save_checkpoint method crashes if called with *args, **kwargs 27 | def on_save_checkpoint(self, trainer, module, checkpoint): 28 | return None 29 | 30 | 31 | class FilterLogCallback(pl.Callback): 32 | """PyTorch Lightning Callback for logging the "filters" of a PyTorch Module. 33 | 34 | Filters are weights that touch input or output, and so are often interpretable. 35 | In particular, these weights are most often interpretable for networks that 36 | consume or produce images, because they can be viewed as images. 37 | 38 | This Logger selects the input and/or output filters (set by log_input and 39 | log_output boolean flags) for logging and sends them to Weights & Biases as 40 | images. 41 | """ 42 | def __init__(self, image_size=None, log_input=False, log_output=False): 43 | super().__init__() 44 | if image_size is not None and len(image_size) == 2: 45 | image_size = [1] + list(image_size) 46 | self.image_size = image_size 47 | self.log_input, self.log_output = log_input, log_output 48 | 49 | def on_train_epoch_end(self, trainer, pl_module): 50 | pl_module.eval() 51 | with torch.no_grad(): 52 | if self.log_input: 53 | input_filters = self.fetch_filters(pl_module, reversed=False, 54 | output_shape=self.image_size) 55 | self.log_filters(input_filters, "filters/input", trainer) 56 | 57 | if self.log_output: 58 | output_filters = self.fetch_filters(pl_module, reversed=True, 59 | output_shape=self.image_size) 60 | self.log_filters(output_filters, "filters/output", trainer) 61 | pl_module.train() 62 | 63 | def log_filters(self, filters, key, trainer): 64 | trainer.logger.experiment.log({ 65 | key: wandb.Image(filters.cpu()), 66 | "global_step": trainer.global_step 67 | }) 68 | 69 | def fetch_filters(self, module, reversed=False, output_shape=None): 70 | weights = get_weights(module) 71 | assert len(weights), "could not find any weights" 72 | 73 | if reversed: 74 | filter_weights = torch.transpose(weights[-1], -2, -1) 75 | else: 76 | filter_weights = weights[0] 77 | 78 | filters = self.extract_filters(filter_weights, output_shape=output_shape) 79 | 80 | return filters 81 | 82 | def extract_filters(self, filter_weights, output_shape=None): 83 | is_convolutional = filter_weights.ndim == 4 84 | if is_convolutional: 85 | channel_count = filter_weights.shape[1] 86 | if channel_count not in [1, 3]: 87 | raise ValueError("convolutional filters must have 1 (L) or 3 (RGB) channels, " + 88 | f"but had {channel_count}") 89 | return filter_weights 90 | else: 91 | if filter_weights.ndim != 2: 92 | raise ValueError("filter_weights must have 2 or 4 dimensions, " + 93 | f"but had 
{filter_weights.ndim}") 94 | if output_shape is None: 95 | raise ValueError("output_shape must be provided to get filters from linear layer") 96 | filter_weights = self.reshape_linear_weights(filter_weights, output_shape) 97 | return filter_weights 98 | 99 | @staticmethod 100 | def reshape_linear_weights(filter_weights, output_shape): 101 | if len(output_shape) < 2: 102 | raise ValueError("output_shape must be at least H x W") 103 | if np.prod(output_shape) != filter_weights.shape[1]: 104 | raise ValueError("shape of filter_weights did not match output_shape") 105 | return torch.reshape(filter_weights, [-1] + list(output_shape)) 106 | 107 | 108 | class ImageLogCallback(pl.Callback): 109 | """Logs the input and output images produced by a module to Weights & Biases. 110 | 111 | Useful in combination with, e.g., an autoencoder architecture, 112 | a convolutional GAN, or any image-to-image transformation network. 113 | """ 114 | def __init__(self, val_samples, num_samples=32): 115 | super().__init__() 116 | self.val_imgs, _ = val_samples 117 | self.val_imgs = self.val_imgs[:num_samples] 118 | 119 | def on_validation_epoch_end(self, trainer, pl_module): 120 | val_imgs = self.val_imgs.to(device=pl_module.device) 121 | 122 | outs = pl_module(val_imgs) 123 | 124 | mosaics = torch.cat([outs, val_imgs], dim=-2) 125 | caption = "Top: Output, Bottom: Input" 126 | trainer.logger.experiment.log({ 127 | "test/examples": [wandb.Image(mosaic, caption=caption) 128 | for mosaic in mosaics], 129 | "global_step": trainer.global_step 130 | }) 131 | 132 | 133 | class ModelSizeLogCallback(pl.Callback): 134 | """Logs information about model size to Weights & Biases.""" 135 | def __init__(self, count_nonzero=False): 136 | super().__init__() 137 | self.count_nonzero = count_nonzero 138 | 139 | def on_fit_end(self, trainer, module): 140 | summary = {} 141 | summary["size_mb"] = self.get_model_disksize(module) 142 | summary["nparams"] = count_params(module) 143 | if self.count_nonzero: 144 | summary["nonzero_params"] = count_params_nonzero(module) 145 | 146 | trainer.logger.experiment.summary.update(summary) 147 | 148 | @staticmethod 149 | def get_model_disksize(module, print_size=True): 150 | """Temporarily save model file to disk and return (and optionally print) model size in MB.""" 151 | with tempfile.NamedTemporaryFile() as f: 152 | torch.save(module.state_dict(), f) 153 | size_mb = os.path.getsize(f.name) / 1e6 154 | if print_size: 155 | print(f"Final Disk Size: {round(size_mb, 2)} MB") 156 | return size_mb 157 | 158 | def on_fit_start(self, trainer, module): 159 | print(f"Parameter Count: {count_params(module)}") 160 | 161 | 162 | class GraphLogCallback(pl.Callback): 163 | """Logs a compute graph to Weights & Biases.""" 164 | def __init__(self): 165 | super().__init__() 166 | self.graph_logged = False 167 | assert has_torchviz, "GraphLogCallback requires torchviz installation" 168 | 169 | def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx): 170 | if not self.graph_logged: 171 | try: 172 | self.log_graph(trainer, module, outputs["y_hats"]) 173 | except KeyError: 174 | pass 175 | self.graph_logged = True 176 | 177 | @staticmethod 178 | def log_graph(trainer, module, outputs): 179 | params_dict = dict(list(module.named_parameters())) 180 | graph = torchviz.make_dot(outputs, params=params_dict) 181 | graph.format = "png" 182 | fname = Path(trainer.logger.experiment.dir) / "graph" 183 | graph.render(fname) 184 | wandb.save(str(fname.with_suffix("." 
+ graph.format)), base_path=fname.parent) 185 | 186 | 187 | class SparsityLogCallback(pl.Callback): 188 | """PyTorch Lightning Callback for logging the sparsity of weight tensors in a PyTorch Module.""" 189 | def on_validation_epoch_end(self, trainer, module): 190 | self.log_sparsities(trainer, module) 191 | 192 | def get_sparsities(self, module): 193 | weights = get_weights(module) 194 | names = [".".join(name.split(".")[:-1]) for name, _ in module.named_parameters() 195 | if "weight" in name.split(".")[-1]] 196 | sparsities = [torch.sum(weight == 0) / weight.numel() for weight in weights] 197 | 198 | return {"sparsity/" + name: sparsity for name, sparsity in zip(names, sparsities)} 199 | 200 | def log_sparsities(self, trainer, module): 201 | sparsities = self.get_sparsities(module) 202 | sparsities["sparsity/total"] = 1 - fraction_nonzero(module) 203 | sparsities["global_step"] = trainer.global_step 204 | trainer.logger.experiment.log(sparsities) 205 | 206 | 207 | class MetadataLogCallback(pl.Callback): 208 | """Attempts to infer metadata about a module and log to Weights & Biases.""" 209 | 210 | def on_fit_start(self, trainer, module): 211 | wandb.run.config["batchnorm"] = self.detect_batchnorm(module) 212 | wandb.run.config["dropout"] = self.detect_dropout(module) 213 | wandb.run.config["loss_fn"] = self.detect_loss_fn(module) 214 | wandb.run.config["optimizer"] = self.detect_optimizer(module) 215 | 216 | def on_train_batch_start(self, trainer, module, batch, batch_idx, dataloader_idx): 217 | if "x_range" not in wandb.run.config.keys(): 218 | wandb.run.config["x_range"] = self.detect_x_range(batch) 219 | 220 | @staticmethod 221 | def detect_batchnorm(module): 222 | for module in module.modules(): 223 | if isinstance(module, (torch.nn.BatchNorm1d, torch.nn.BatchNorm2d)): 224 | return True 225 | return False 226 | 227 | @staticmethod 228 | def detect_dropout(module): 229 | dropout_cfg_dict = {"has_dropout": False} 230 | dropout_ct, dropout2d_ct = 0, 0 231 | for module in module.modules(): 232 | if isinstance(module, torch.nn.Dropout): 233 | if module.p > 0: 234 | dropout_cfg_dict["has_dropout"] = True 235 | dropout_cfg_dict[f"1d.{dropout_ct}.p"] = module.p 236 | dropout_ct += 1 237 | if isinstance(module, torch.nn.Dropout2d): 238 | if module.p > 0: 239 | dropout_cfg_dict["has_dropout"] = True 240 | dropout_cfg_dict[f"2d.{dropout2d_ct}.p"] = module.p 241 | dropout2d_ct += 1 242 | return dropout_cfg_dict 243 | 244 | @staticmethod 245 | def detect_loss_fn(module): 246 | try: 247 | classname = module.loss.__class__.__name__ 248 | if classname in ["method", "function"]: 249 | return "unknown" 250 | else: 251 | return classname 252 | except AttributeError: 253 | return 254 | 255 | @staticmethod 256 | def detect_optimizer(module): 257 | try: 258 | return module.optimizers().__class__.__name__ 259 | except AttributeError: 260 | return 261 | 262 | @staticmethod 263 | def detect_x_range(batch): 264 | with torch.no_grad(): 265 | xs = batch[0] 266 | x_range = [torch.min(xs), torch.max(xs)] 267 | return x_range 268 | 269 | 270 | class ImagePredLogCallback(pl.Callback): 271 | 272 | def __init__(self, max_images_to_log=32, labels=None, on_train=False): 273 | super().__init__() 274 | self.max_images_to_log = min(max(max_images_to_log, 0), 32) 275 | self.labels = labels 276 | self.on_train = on_train 277 | 278 | def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 279 | if batch_idx == 0 and dataloader_idx == 0: 280 | images_with_predictions = 
self.package_images_predictions(outputs, batch) 281 | trainer.logger.experiment.log({"validation/predictions": images_with_predictions}) 282 | 283 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 284 | if self.on_train and batch_idx == 0 and dataloader_idx == 0: 285 | images_with_predictions = self.package_images_predictions(outputs, batch) 286 | trainer.logger.experiment.log({"train/predictions": images_with_predictions}) 287 | 288 | def package_images_predictions(self, outputs, batch): 289 | xs, ys = batch 290 | xs, ys = xs[:self.max_images_to_log], ys[:self.max_images_to_log] 291 | preds = self.preds_from_y_hats(outputs["y_hats"][:self.max_images_to_log]) 292 | 293 | if self.labels is not None: 294 | preds = [self.labels[int(pred)] for pred in preds] 295 | ys = [self.labels[int(y)] for y in ys] 296 | 297 | images_with_predictions = [ 298 | wandb.Image(x, caption=f"Pred: {pred}, Target: {y}") 299 | for x, pred, y in zip(xs, preds, ys) 300 | ] 301 | 302 | return images_with_predictions 303 | 304 | @staticmethod 305 | def preds_from_y_hats(y_hats): 306 | if y_hats.shape[-1] == 1 or len(y_hats.shape) == 1: # handle binary classification case 307 | preds = torch.greater(y_hats, 0.5) 308 | preds = [bool(pred) for pred in preds] 309 | else: # assume we are in the typical one-hot case 310 | preds = torch.argmax(y_hats, 1) 311 | return preds 312 | 313 | 314 | def get_weights(module): 315 | allowed_filter_types = ( 316 | torch.nn.Linear, 317 | torch.nn.Conv2d, 318 | torch.nn.ConvTranspose2d 319 | ) 320 | ls = [m for m in module.modules() if isinstance(m, allowed_filter_types)] 321 | ws = list(map(get_masked_weights, ls)) 322 | return ws 323 | 324 | 325 | def get_masked_weights(layer): 326 | with torch.no_grad(): 327 | for name, parameter in layer.named_parameters(): 328 | if any(key in name.split(".")[-1] for key in ["weight", "weight_orig"]): 329 | break 330 | for name, mask in layer.named_buffers(): 331 | if "weight_mask" in name.split(".")[-1]: 332 | parameter = parameter * mask 333 | return parameter 334 | 335 | 336 | def count_params(module): 337 | return sum(p.numel() for p in module.parameters()) 338 | 339 | 340 | def count_params_nonzero(module): 341 | """Counts the total number of non-zero parameters in a module. 342 | 343 | For compatibility with networks with active torch.nn.utils.prune methods, 344 | checks for _mask tensors, which are applied during forward passes and so 345 | represent the actual sparsity of the networks. 
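    A minimal sketch of the intended use (hypothetical layer; assumes torch.nn.utils.prune
    has been applied to the module):

        layer = torch.nn.Linear(10, 10)
        torch.nn.utils.prune.l1_unstructured(layer, name="weight", amount=0.5)
        count_params_nonzero(layer)  # roughly 50 surviving weights plus the 10 bias entries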
346 |     """
347 |     suffix = "_mask"
348 |     masks = {name[:-len(suffix)]: mask_tensor for name, mask_tensor in module.named_buffers()
349 |              if name.endswith(suffix)}
350 |
351 |     nparams = 0
352 |     with torch.no_grad():
353 |         for name, tensor in module.named_parameters():
354 |             # pruned parameters are stored as "<name>_orig"; apply the matching mask before counting
355 |             key = name[:-len("_orig")] if name.endswith("_orig") else name
356 |             if key in masks:
357 |                 tensor = tensor * masks[key]
358 |             nparams += int(torch.sum(tensor != 0))
359 |
360 |     return nparams
361 |
362 |
363 | def fraction_nonzero(module):
364 |     """Gives the fraction of parameters that are non-zero in a module."""
365 |     return count_params_nonzero(module) / count_params(module)
366 |
--------------------------------------------------------------------------------
/datamodules/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | from .cifar10 import *
3 | from .fer import *
4 | from .mnist import *
5 |
--------------------------------------------------------------------------------
/datamodules/cifar10.py:
--------------------------------------------------------------------------------
1 | from math import floor
2 | import multiprocessing
3 | from pathlib import Path
4 |
5 | import pytorch_lightning as pl
6 | import torch
7 | from torch.utils.data import DataLoader
8 | import torchvision
9 | from torchvision import transforms
10 |
11 | DEFAULT_NUM_WORKERS = multiprocessing.cpu_count()
12 |
13 |
14 | class CIFAR10DataModule(pl.LightningDataModule):
15 |     """Dataloaders and setup for the CIFAR10 dataset.
16 |
17 |     Arguments:
18 |         batch_size: int. Size of batches in training, validation, and test.
19 |         train_size: int or float. If int, number of examples in training set;
20 |             if float, fraction of examples in training set.
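    Example (hypothetical values; assumes ./data is writable):

        dm = CIFAR10DataModule(batch_size=128, train_size=0.8)
        dm.prepare_data()        # downloads CIFAR10 into ./data
        dm.setup("fit")          # builds the train/val split
        train_loader = dm.train_dataloader()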
22 |     """
23 |     data_root = Path(".") / "data"
24 |     seed = 117
25 |     classes = ["airplane", "automobile", "bird", "cat", "deer",
26 |                "dog", "frog", "horse", "ship", "truck"]
27 |
28 |     def __init__(self, batch_size, train_size=0.8):
29 |         super().__init__()
30 |
31 |         self.train_size, self.batch_size = train_size, batch_size
32 |
33 |         self.transform = transforms.Compose(
34 |             [transforms.ToTensor(),
35 |              transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36 |
37 |     def prepare_data(self):
38 |         """Download the dataset."""
39 |         torchvision.datasets.CIFAR10(root=self.data_root, train=True,
40 |                                      download=True, transform=self.transform)
41 |
42 |         torchvision.datasets.CIFAR10(root=self.data_root, train=False,
43 |                                      download=True, transform=self.transform)
44 |
45 |     def setup(self, stage=None):
46 |         """Set up training and test data and perform our train/val split."""
47 |         if stage in (None, "fit"):
48 |             self.cifar10_fit = torchvision.datasets.CIFAR10(
49 |                 self.data_root, train=True, download=False, transform=self.transform)
50 |             total_size, *self.dims = self.cifar10_fit.data.shape
51 |             split_sizes = self.get_split_sizes(self.train_size, total_size)
52 |
53 |             split_generator = torch.Generator().manual_seed(self.seed)
54 |             self.training_data, self.validation_data = torch.utils.data.random_split(
55 |                 self.cifar10_fit, split_sizes, split_generator)
56 |
57 |         if stage in (None, "test"):
58 |             self.test_data = torchvision.datasets.CIFAR10(self.data_root, train=False,
59 |                                                            transform=self.transform)
60 |
61 |     def train_dataloader(self):
62 |         trainloader = DataLoader(self.training_data, batch_size=self.batch_size,
63 |                                  shuffle=True, num_workers=DEFAULT_NUM_WORKERS)
64 |         return trainloader
65 |
66 |     def val_dataloader(self):
67 |         valloader = DataLoader(self.validation_data, batch_size=2 * self.batch_size,
68 |                                shuffle=False, num_workers=DEFAULT_NUM_WORKERS)
69 |         return valloader
70 |
71 |     def test_dataloader(self):
72 |         testloader = DataLoader(self.test_data, batch_size=2 * self.batch_size,
73 |                                 shuffle=False, num_workers=DEFAULT_NUM_WORKERS)
74 |         return testloader
75 |
76 |     @staticmethod
77 |     def get_split_sizes(train_size, total_size):
78 |         if isinstance(train_size, float):
79 |             train_size = floor(total_size * train_size)
80 |
81 |         val_size = total_size - train_size
82 |
83 |         return train_size, val_size
84 |
--------------------------------------------------------------------------------
/datamodules/fer.py:
--------------------------------------------------------------------------------
1 | """Lightning DataModules and associated utilities for the FER2013 dataset."""
2 | import multiprocessing
3 | import os
4 | from pathlib import Path
5 | import subprocess
6 |
7 | import numpy as np
8 | import pandas as pd
9 | import pytorch_lightning as pl
10 | import torch
11 | from torch.utils.data import DataLoader
12 |
13 | DEFAULT_NUM_WORKERS = multiprocessing.cpu_count()
14 |
15 |
16 | class FERDataModule(pl.LightningDataModule):
17 |     """DataModule for downloading and preparing the FER2013 dataset."""
18 |
19 |     tar_url = "https://www.dropbox.com/s/opuvvdv3uligypx/fer2013.tar"
20 |     data_root = Path(".") / "data"
21 |     local_path = data_root / "fer2013"
22 |     csv_file = local_path / "fer2013.csv"
23 |     width, height = 48, 48
24 |     classes = ["anger", "disgust", "fear", "happiness",
25 |                "sadness", "surprise", "neutrality"]
26 |
27 |     def __init__(self, batch_size=64, validation_size=0.2, num_workers=DEFAULT_NUM_WORKERS):
28 |         super().__init__()
29 |         self.batch_size = batch_size
30 |         self.val_batch_size = 2 * self.batch_size
31 |
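        # validation batches can be larger than training batches because no gradients
        # are stored during validation; the factor of 2 here is a heuristic, not a requirement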
self.num_workers = num_workers 32 | self.validation_size = validation_size 33 | 34 | def setup(self): 35 | # download the data from the internet 36 | os.makedirs(self.local_path, exist_ok=True) 37 | self.download_data() 38 | 39 | def prepare_data(self): 40 | # read data from a .csv file 41 | faces, emotions = self.read_data() 42 | 43 | # normalize it 44 | faces = torch.divide(faces, 255.) 45 | 46 | # split it into training and validation 47 | num_validation = int(len(faces) * self.validation_size) 48 | 49 | self.training_data = torch.utils.data.TensorDataset( 50 | faces[:-num_validation], emotions[:-num_validation]) 51 | self.validation_data = torch.utils.data.TensorDataset( 52 | faces[-num_validation:], emotions[-num_validation:]) 53 | 54 | # record metadata 55 | self.num_total, self.num_classes = emotions.shape[0], len(self.classes) 56 | self.num_train = self.num_total - num_validation 57 | self.num_validation = num_validation 58 | 59 | def train_dataloader(self): 60 | """The DataLoaders returned by a DataModule produce data for a model. 61 | This DataLoader is used during training. 62 | """ 63 | return DataLoader(self.training_data, batch_size=self.batch_size, 64 | num_workers=self.num_workers, shuffle=True) 65 | 66 | def val_dataloader(self): 67 | """The DataLoaders returned by a DataModule produce data for a model. 68 | This DataLoader is used during validation, at the end of each epoch. 69 | """ 70 | return DataLoader(self.validation_data, batch_size=self.val_batch_size, 71 | num_workers=self.num_workers) 72 | 73 | def download_data(self): 74 | if not os.path.exists(self.csv_file): 75 | print("Downloading the face emotion dataset...") 76 | subprocess.check_output(f"curl -SL {self.tar_url} | tar xz", shell=True) 77 | subprocess.check_output(f"mv fer2013 -t {self.data_root}", shell=True) 78 | print("...done") 79 | 80 | def read_data(self): 81 | """Read the data from a .csv into torch Tensors.""" 82 | data = pd.read_csv(self.csv_file) 83 | pixels = data["pixels"].tolist() 84 | faces = [] 85 | for pixel_sequence in pixels: 86 | face = np.asarray(pixel_sequence.split( 87 | " "), dtype=np.uint8).reshape(1, self.width, self.height) 88 | faces.append(face.astype("float32")) 89 | 90 | faces = np.asarray(faces) 91 | emotions = data["emotion"].to_numpy() 92 | 93 | return torch.tensor(faces), torch.tensor(emotions) 94 | -------------------------------------------------------------------------------- /datamodules/mnist.py: -------------------------------------------------------------------------------- 1 | """Lightning DataModules and associated utilities for MNIST-style datasets. 2 | 3 | Based on code from torchvision MNIST datasets. 4 | 5 | All datasets are stored in memory as tensors and then converted to PIL 6 | before applying preprocessing. 7 | 8 | MNIST-style datasets that are not the handwritten digits dataset can inherit 9 | from the base class and only need to over-write the mirrors, resources, classes, 10 | and folder attributes/properties. 11 | 12 | Unlike most MNIST implementations but like most PNG images, uses 255 to represent white 13 | background and 0 to represent black foreground/strokes. 
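The palette flip is implemented by `reverse_palette` at the bottom of this file,
which maps each stored pixel value v to 255 - v, so the raw MNIST background value of 0
becomes white (255) and full-intensity strokes become black (0).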
14 | """ 15 | from math import floor 16 | import multiprocessing 17 | import os 18 | from pathlib import Path 19 | 20 | import numpy as np 21 | from PIL import Image 22 | import pytorch_lightning as pl 23 | import torch 24 | from torch.utils.data import DataLoader 25 | import torchvision 26 | 27 | DEFAULT_WORKERS = multiprocessing.cpu_count() # cpu_count is a good default worker count 28 | 29 | 30 | class AbstractMNISTDataModule(pl.LightningDataModule): 31 | """Abstract DataModule for MNIST-style datasets. 32 | 33 | Must be made concrete by defining a .setup method which sets attributes 34 | self.training_data and self.validation_data and self.test_data. 35 | """ 36 | data_root = Path(".") / "data" 37 | seed = 117 38 | 39 | def __init__(self, batch_size=64, validation_size=10000, num_workers=DEFAULT_WORKERS, 40 | transform=None, target_transform=None): 41 | super().__init__() 42 | self._dataset = None 43 | 44 | self.batch_size = batch_size 45 | self.validation_size = validation_size 46 | self.num_workers = DEFAULT_WORKERS 47 | 48 | self.transform = transform 49 | self.target_transform = target_transform 50 | 51 | def prepare_data(self): 52 | # download the data from the internet 53 | self.dataset(self.data_root, download=True) 54 | self.dataset(self.data_root, train=False, download=True) 55 | 56 | def setup(self, stage=None): 57 | if stage in (None, "fit"): 58 | self.mnist_fit = self.dataset( 59 | self.data_root, train=True, download=False, 60 | transform=self.transform, target_transform=self.target_transform) 61 | 62 | total_size, *self.dims = self.mnist_fit.data.shape 63 | split_sizes = [total_size - self.validation_size, self.validation_size] 64 | split_generator = torch.Generator().manual_seed(self.seed) 65 | self.training_data, self.validation_data = torch.utils.data.random_split( 66 | self.mnist_fit, split_sizes, split_generator) 67 | 68 | if stage in (None, "test"): 69 | self.mnist_test = self.dataset( 70 | self.data_root, train=False, download=False, 71 | transform=None, target_transform=None) 72 | self.test_data = self.mnist_test 73 | 74 | def train_dataloader(self): 75 | """The DataLoaders returned by a DataModule produce data for a model. 76 | 77 | This DataLoader is used during training. 78 | """ 79 | return DataLoader(self.training_data, batch_size=self.batch_size, 80 | num_workers=DEFAULT_WORKERS) 81 | 82 | def val_dataloader(self): 83 | """The DataLoaders returned by a DataModule produce data for a model. 84 | 85 | This DataLoader is used during validation, at the end of each epoch. 86 | """ 87 | return DataLoader(self.validation_data, batch_size=2 * self.batch_size, 88 | num_workers=DEFAULT_WORKERS, shuffle=False) 89 | 90 | def test_dataloader(self): 91 | """The DataLoaders returned by a DataModule produce data for a model. 92 | 93 | This DataLoader is used during testing, at the end of training. 
94 | """ 95 | return DataLoader(self.test_data, batch_size=2 * self.batch_size, 96 | num_workers=DEFAULT_WORKERS, shuffle=False) 97 | 98 | @property 99 | def dataset(self): 100 | if self._dataset is None: 101 | raise NotImplementedError("must provide an MNIST-style dataset class") 102 | return self._dataset 103 | 104 | @dataset.setter 105 | def dataset(self, cls): 106 | if not issubclass(cls, torchvision.datasets.MNIST): 107 | raise ValueError(f"dataset must be subclass of torchvision.datasets.MNIST, was {cls}") 108 | self._dataset = cls 109 | 110 | @staticmethod 111 | def get_split_sizes(train_size, total_size): 112 | if isinstance(train_size, float): 113 | train_size = floor(total_size * train_size) 114 | 115 | val_size = total_size - train_size 116 | 117 | return train_size, val_size 118 | 119 | 120 | class MNISTDataModule(AbstractMNISTDataModule): 121 | """DataModule for the MNIST handwritten digit classification task.""" 122 | 123 | def __init__(self, batch_size=64, validation_size=10000, transform=None, 124 | target_transform=None): 125 | super().__init__(batch_size=batch_size, validation_size=validation_size, 126 | transform=transform, target_transform=target_transform) 127 | self.dataset = ClassificationMNIST 128 | self.classes = self.dataset.classes 129 | 130 | 131 | class AutoEncoderMNISTDataModule(AbstractMNISTDataModule): 132 | """DataModule for an MNIST handwritten digit auto-encoding task.""" 133 | 134 | def __init__(self, batch_size=64, validation_size=10000, transform=None, 135 | target_transform=None): 136 | super().__init__(batch_size=batch_size, validation_size=validation_size, 137 | transform=transform, target_transform=target_transform) 138 | self.dataset = AutoEncoderMNIST 139 | self.classes = self.dataset.classes 140 | 141 | 142 | class FashionMNISTDataModule(AbstractMNISTDataModule): 143 | """DataModule for the MNIST handwritten digit classification task.""" 144 | 145 | def __init__(self, batch_size=64, validation_size=10000, transform=None, 146 | target_transform=None): 147 | super().__init__(batch_size=batch_size, validation_size=validation_size, 148 | transform=transform, target_transform=target_transform) 149 | self.dataset = ClassificationFashionMNIST 150 | self.classes = self.dataset.classes 151 | 152 | 153 | class AutoEncoderFashionMNISTDataModule(AbstractMNISTDataModule): 154 | """DataModule for an MNIST handwritten digit auto-encoding task.""" 155 | 156 | def __init__(self, batch_size=64, validation_size=10000, transform=None, 157 | target_transform=None): 158 | super().__init__(batch_size=batch_size, validation_size=validation_size, 159 | transform=transform, target_transform=target_transform) 160 | self.dataset = AutoEncoderFashionMNIST 161 | self.classes = self.dataset.classes 162 | 163 | 164 | class ClassificationMNIST(torchvision.datasets.MNIST): 165 | """Dataset for the MNIST handwritten digit recognition task. 166 | 167 | Modified from torchvision MNIST Dataset code. 168 | """ 169 | # remove slow mirror from default list 170 | mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors 171 | if not mirror.startswith("http://yann.lecun.com")] 172 | 173 | classes = [str(ii) for ii in range(10)] 174 | 175 | default_transform = torchvision.transforms.ToTensor() 176 | default_target_transform = torchvision.transforms.Compose([]) 177 | 178 | def __getitem__(self, index): 179 | """Gets the image and label at index. 
180 | 181 | Args: 182 | index (int): Index 183 | 184 | Returns: 185 | tuple: (image, target) where target is the index of the target class. 186 | """ 187 | img, target = self.data[index], int(self.targets[index]) 188 | img = reverse_palette(img) 189 | 190 | # doing this so that it is consistent with all other datasets 191 | # to return a PIL Image 192 | img = Image.fromarray(img.numpy(), mode="L") 193 | 194 | if self.transform is not None: 195 | img = self.transform(img) 196 | else: 197 | img = self.default_transform(img) 198 | 199 | if self.target_transform is not None: 200 | target = self.target_transform(target) 201 | else: 202 | target = self.default_target_transform(target) 203 | 204 | return img, target 205 | 206 | @property 207 | def raw_folder(self) -> str: 208 | return os.path.join(self.root, "MNIST", "raw") 209 | 210 | @property 211 | def processed_folder(self) -> str: 212 | return os.path.join(self.root, "MNIST", "processed") 213 | 214 | 215 | class AutoEncoderMNIST(torchvision.datasets.MNIST): 216 | """Dataset for the MNIST handwritten digit reconstruction task. 217 | 218 | Modified from torchvision MNIST Dataset code. 219 | """ 220 | # remove slow mirror from default list 221 | mirrors = [mirror for mirror in torchvision.datasets.MNIST.mirrors 222 | if not mirror.startswith("http://yann.lecun.com")] 223 | 224 | classes = [str(ii) for ii in range(10)] 225 | 226 | default_transform = torchvision.transforms.ToTensor() 227 | default_target_transform = torchvision.transforms.ToTensor() 228 | 229 | def __getitem__(self, index): 230 | """Gets the image at index. 231 | 232 | Args: 233 | index (int): Index 234 | 235 | Returns: 236 | tuple: (image, image) 237 | """ 238 | _img = self.data[index] 239 | _img = reverse_palette(_img) 240 | 241 | # doing this so that it is consistent with all other datasets 242 | # to return a PIL Image 243 | img = Image.fromarray(_img.numpy(), mode="L") 244 | target = Image.fromarray(_img.numpy(), mode="L") 245 | 246 | if self.transform is not None: 247 | img = self.transform(img) 248 | else: 249 | img = self.default_transform(img) 250 | 251 | if self.target_transform is not None: 252 | target = self.target_transform(target) 253 | else: 254 | target = self.default_target_transform(target) 255 | 256 | return img, target 257 | 258 | @property 259 | def raw_folder(self) -> str: 260 | return os.path.join(self.root, "MNIST", "raw") 261 | 262 | @property 263 | def processed_folder(self) -> str: 264 | return os.path.join(self.root, "MNIST", "processed") 265 | 266 | 267 | class ClassificationFashionMNIST(ClassificationMNIST): 268 | """Dataset for the MNIST fashion item recognition task.""" 269 | 270 | mirrors = [ 271 | "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" 272 | ] 273 | 274 | resources = [ 275 | ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), 276 | ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), 277 | ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), 278 | ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310") 279 | ] 280 | 281 | classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", 282 | "Shirt", "Sneaker", "Bag", "Ankle boot"] 283 | 284 | @property 285 | def raw_folder(self) -> str: 286 | return os.path.join(self.root, "FashionMNIST", "raw") 287 | 288 | @property 289 | def processed_folder(self) -> str: 290 | return os.path.join(self.root, "FashionMNIST", "processed") 291 | 292 | 293 | class AutoEncoderFashionMNIST(AutoEncoderMNIST): 294 | """Dataset 
for the MNIST fashion item reconstruction task.""" 295 | 296 | mirrors = [ 297 | "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/" 298 | ] 299 | 300 | resources = [ 301 | ("train-images-idx3-ubyte.gz", "8d4fb7e6c68d591d4c3dfef9ec88bf0d"), 302 | ("train-labels-idx1-ubyte.gz", "25c81989df183df01b3e8a0aad5dffbe"), 303 | ("t10k-images-idx3-ubyte.gz", "bef4ecab320f06d8554ea6380940ec79"), 304 | ("t10k-labels-idx1-ubyte.gz", "bb300cfdad3c16e7a12a480ee83cd310") 305 | ] 306 | 307 | classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat", "Sandal", 308 | "Shirt", "Sneaker", "Bag", "Ankle boot"] 309 | 310 | @property 311 | def raw_folder(self) -> str: 312 | return os.path.join(self.root, "FashionMNIST", "raw") 313 | 314 | @property 315 | def processed_folder(self) -> str: 316 | return os.path.join(self.root, "FashionMNIST", "processed") 317 | 318 | 319 | def reverse_palette(img): 320 | return np.abs(255 - img) 321 | -------------------------------------------------------------------------------- /nn/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from . import fc, conv, modules, prune, quant, shapes 3 | -------------------------------------------------------------------------------- /nn/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class Convolution2d(torch.nn.Module): 5 | """Applies 2D convolution to the inputs with torch.nn.Conv2d. 6 | 7 | Also can apply nonlinear activations and other components of convolutional layers. 8 | 9 | Args: 10 | in_channels: int. Number of channels in the input tensor. 11 | out_channels: int. Number of channels in the output tensor. 12 | kernel_size: int or tuple. Shape of convolutional kernel. If int, promoted to (int, int). 13 | activation: Callable or None. (Typically nonlinear) activation function for layer. 14 | If None, the identity function is applied. 15 | batchnorm: String or None. If not None, specifies whether to apply batchnorm 16 | before ("pre") or after ("post") the activation function. 17 | dropout: float or None. If not None, apply dropout with this probability. 18 | kwargs: keyword arguments for torch.nn.Conv2d. 19 | """ 20 | 21 | def __init__(self, in_channels, out_channels, kernel_size, 22 | activation=None, 23 | batchnorm=None, dropout=None, **kwargs): 24 | super().__init__() 25 | self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs) 26 | 27 | preactivation = [] 28 | if batchnorm == "pre": 29 | preactivation.append(torch.nn.BatchNorm2d(out_channels)) 30 | 31 | if activation is None: 32 | activation = torch.nn.Identity() 33 | 34 | if not callable(activation): 35 | raise ValueError(f"activation must be callable, was {activation}") 36 | 37 | postactivation = [] 38 | if dropout is not None: 39 | postactivation.append(torch.nn.Dropout2d(dropout)) 40 | if batchnorm == "post": 41 | postactivation.append(torch.nn.BatchNorm2d(out_channels)) 42 | 43 | self.preactivation = torch.nn.Sequential(*preactivation) 44 | self.activation = activation 45 | self.postactivation = torch.nn.Sequential(*postactivation) 46 | 47 | def forward(self, x): 48 | return self.postactivation(self.activation(self.preactivation(self.conv(x)))) 49 | 50 | 51 | class ConvolutionTranspose2d(torch.nn.Module): 52 | """Applies 2D tranposed convolution to the inputs. 53 | 54 | Also can apply nonlinear activations and other components of convolutional layers. 55 | 56 | Args: 57 | in_channels: int. 
Number of channels in the input tensor. 58 | out_channels: int. Number of channels in the output tensor. 59 | kernel_size: int or tuple. Shape of convolutional kernel. If int, promoted to (int, int). 60 | activation: Callable or None. (Typically nonlinear) activation function for layer. 61 | If None, the identity function is applied. 62 | batchnorm: String or None. If not None, specifies whether to apply batchnorm 63 | before ("pre") or after ("post") the activation function. 64 | dropout: float or None. If not None, apply dropout with this probability. 65 | kwargs: keyword arguments for torch.nn.ConvTranspose2d. 66 | """ 67 | 68 | def __init__(self, in_channels, out_channels, kernel_size, 69 | activation=None, batchnorm=None, dropout=None, **kwargs): 70 | super().__init__() 71 | self.conv = torch.nn.ConvTranspose2d(in_channels, out_channels, kernel_size, **kwargs) 72 | 73 | preactivation = [] 74 | if batchnorm == "pre": 75 | preactivation.append(torch.nn.BatchNorm2d(out_channels)) 76 | 77 | if activation is None: 78 | activation = torch.nn.Identity() 79 | 80 | if not callable(activation): 81 | raise ValueError(f"activation must be callable, was {activation}") 82 | 83 | postactivation = [] 84 | if dropout is not None: 85 | postactivation.append(torch.nn.Dropout2d(dropout)) 86 | if batchnorm == "post": 87 | postactivation.append(torch.nn.BatchNorm2d(out_channels)) 88 | 89 | self.preactivation = torch.nn.Sequential(*preactivation) 90 | self.activation = activation 91 | self.postactivation = torch.nn.Sequential(*postactivation) 92 | 93 | def forward(self, x): 94 | return self.postactivation(self.activation(self.preactivation(self.conv(x)))) 95 | -------------------------------------------------------------------------------- /nn/fc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class FullyConnected(torch.nn.Module): 5 | """Applies a dense matrix to the inputs. 6 | 7 | Also known as a "dense", "linear", or "perceptron" layer. 8 | 9 | Args: 10 | in_features: int. Number of entries in the input feature vector. 11 | out_features: int. Number of entries in the output feature vector. 12 | activation: Callable or None. (Typically nonlinear) activation function for layer. 13 | If None, the identity function is applied. 14 | batchnorm: String or None. If not None, specifies whether to apply batchnorm 15 | before ("pre") or after ("post") the activation function. 16 | dropout: float or None. If not None, adds dropout layer after the activation function 17 | using the provided value as the dropout probability. 
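    Example (hypothetical sizes and settings):

        layer = FullyConnected(784, 128, activation=torch.nn.ReLU(),
                               batchnorm="post", dropout=0.1)
        ys = layer(torch.randn(32, 784))  # ys.shape == (32, 128)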
18 | """ 19 | 20 | def __init__(self, in_features, out_features, activation=None, 21 | batchnorm=None, dropout=None): 22 | super().__init__() 23 | self.linear = torch.nn.Linear(in_features, out_features) 24 | 25 | preactivation = [] 26 | if batchnorm == "pre": 27 | preactivation.append(torch.nn.BatchNorm1d(out_features)) 28 | 29 | if activation is None: 30 | activation = torch.nn.Identity() 31 | 32 | if not callable(activation): 33 | raise ValueError(f"activation must be callable, was {activation}") 34 | 35 | postactivation = [] 36 | if dropout is not None: 37 | postactivation.append(torch.nn.Dropout(dropout)) 38 | if batchnorm == "post": 39 | postactivation.append(torch.nn.BatchNorm1d(out_features)) 40 | 41 | self.preactivation = torch.nn.Sequential(*preactivation) 42 | self.activation = activation 43 | self.postactivation = torch.nn.Sequential(*postactivation) 44 | 45 | def forward(self, x): 46 | return self.postactivation(self.activation(self.preactivation(self.linear(x)))) 47 | -------------------------------------------------------------------------------- /nn/modules.py: -------------------------------------------------------------------------------- 1 | """Basic Lightning Modules plus Weights & Biases features.""" 2 | import pytorch_lightning as pl 3 | import torch 4 | import torchmetrics 5 | 6 | 7 | class LoggedLitModule(pl.LightningModule): 8 | """LightningModule plus wandb features and simple training/val steps. 9 | 10 | By default, assumes that your training loop involves inputs (xs) 11 | fed to .forward to produce outputs (y_hats) 12 | that are compared to targets (ys) 13 | by self.loss and by metrics, 14 | where each batch == (xs, ys). 15 | This loss is fed to self.optimizer. 16 | 17 | If this is not true, overwrite _train_forward 18 | and optionally _val_forward and _test_forward. 
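    A minimal subclass sketch (hypothetical model; the attribute names follow the
    contract described above):

        class LitMLP(LoggedLitModule):
            def __init__(self):
                super().__init__()
                self.model = torch.nn.Sequential(torch.nn.Flatten(),
                                                 torch.nn.Linear(28 * 28, 10))
                self.loss = torch.nn.CrossEntropyLoss()
                self.optimizer = torch.optim.Adam
                self.optimizer_params = {"lr": 3e-4}

            def forward(self, xs):
                return self.model(xs)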
19 | """ 20 | 21 | def __init__(self): 22 | super().__init__() 23 | 24 | self.training_metrics = torch.nn.ModuleList([]) 25 | self.validation_metrics = torch.nn.ModuleList([]) 26 | self.test_metrics = torch.nn.ModuleList([]) 27 | 28 | def training_step(self, xys, idx): 29 | xs, ys = xys 30 | y_hats = self._train_forward(xs) 31 | loss = self.loss(y_hats, ys) 32 | 33 | logging_scalars = {"loss": loss} 34 | for metric in self.training_metrics: 35 | self.log_metric(metric, logging_scalars, y_hats, ys) 36 | 37 | self.do_logging(xs, ys, idx, y_hats, logging_scalars) 38 | 39 | return {"loss": loss, "y_hats": y_hats} 40 | 41 | def validation_step(self, xys, idx): 42 | xs, ys = xys 43 | y_hats = self._val_forward(xs) 44 | loss = self.loss(y_hats, ys) 45 | 46 | logging_scalars = {"loss": loss} 47 | for metric in self.validation_metrics: 48 | self.log_metric(metric, logging_scalars, y_hats, ys) 49 | 50 | self.do_logging(xs, ys, idx, y_hats, logging_scalars, step="validation") 51 | 52 | return {"loss": loss, "y_hats": y_hats} 53 | 54 | def test_step(self, xys, idx): 55 | xs, ys = xys 56 | y_hats = self._test_forward(xs) 57 | loss = self.loss(y_hats, ys) 58 | 59 | logging_scalars = {"loss": loss} 60 | for metric in self.test_metrics: 61 | self.log_metric(metric, logging_scalars, y_hats, ys) 62 | 63 | self.do_logging(xs, ys, idx, y_hats, logging_scalars, step="test") 64 | 65 | return {"loss": loss, "y_hats": y_hats} 66 | 67 | def do_logging(self, xs, ys, idx, y_hats, scalars, step="training"): 68 | self.log_dict( 69 | {step + "/" + name: value for name, value in scalars.items()}) 70 | 71 | def on_pretrain_routine_start(self): 72 | print(self) 73 | 74 | def log_metric(self, metric, logging_scalars, y_hats, ys): 75 | metric_str = metric.__class__.__name__.lower() 76 | value = metric(y_hats, ys) 77 | logging_scalars[metric_str] = value 78 | 79 | def _train_forward(self, xs): 80 | """Overwrite this method when module.forward doesn't produce y_hats.""" 81 | return self.forward(xs) 82 | 83 | def _val_forward(self, xs): 84 | """Overwrite this method when training and val forward differ.""" 85 | return self._train_forward(xs) 86 | 87 | def _test_forward(self, xs): 88 | """Overwrite this method when val and test forward differ.""" 89 | return self._val_forward(xs) 90 | 91 | def configure_optimizers(self): 92 | return self.optimizer(self.parameters(), **self.optimizer_params) 93 | 94 | def optimizer(self, *args, **kwargs): 95 | error_msg = ("To use LoggedLitModule, you must set self.optimizer to a torch-style Optimizer" 96 | + "and set self.optimizer_params to a dictionary of keyword arguments.") 97 | raise NotImplementedError(error_msg) 98 | 99 | 100 | class LoggedImageClassifierModule(LoggedLitModule): 101 | """LightningModule for image classification with Weights and Biases logging.""" 102 | def __init__(self): 103 | 104 | super().__init__() 105 | 106 | self.train_acc = torchmetrics.Accuracy() 107 | self.valid_acc = torchmetrics.Accuracy() 108 | self.test_acc = torchmetrics.Accuracy() 109 | 110 | self.training_metrics.append(self.train_acc) 111 | self.validation_metrics.append(self.valid_acc) 112 | self.test_metrics.append(self.test_acc) 113 | 114 | def log_metric(self, metric, logging_scalars, y_hats, ys): 115 | metric_str = metric.__class__.__name__.lower() 116 | if metric_str == "accuracy": 117 | float_types = [torch.float32, torch.float16, torch.float64] 118 | if ys.dtype in float_types: # handle binary classification with MSELoss case 119 | ys = ys.int() # accuracy expects ints, MSELoss expects floats 120 
| value = metric(y_hats, ys) 121 | logging_scalars[metric_str] = value 122 | -------------------------------------------------------------------------------- /nn/prune.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with model pruning.""" 2 | import copy 3 | import math 4 | 5 | import torch 6 | 7 | 8 | def make_prune_config(base_prune_config, network=None, n_epochs=None): 9 | """Builds a config dictionary for a pl.callbacks.ModelPruning callback. 10 | 11 | Aside from the keyword arguments to that class, this dictionary 12 | may contain the keys "target_sparsity" 13 | 14 | target_sparsity is combined with n_epochs to determine the value of the 15 | "amount" keyword argument to ModelPruning, which specifies how much pruning to 16 | do on each epoch. 17 | 18 | The key "parameters" can be None, "conv", or "linear". It is used to fetch the 19 | parameters which are to be pruned from the provided network. See 20 | get_parameters_to_prune for details. Note that None corresponds to pruning 21 | all parameters. 22 | """ 23 | prune_config = copy.copy(base_prune_config) 24 | if "target_sparsity" in prune_config.keys(): 25 | target = prune_config.pop("target_sparsity") 26 | if n_epochs is None: 27 | raise ValueError("when specifying target sparsity, must provide number of epochs") 28 | prune_config["amount"] = compute_iterative_prune(target, n_epochs) 29 | 30 | if "amount" not in prune_config.keys(): 31 | raise ValueError("must specify stepwise pruning amount or target in base_prune_config") 32 | 33 | if "parameters" in prune_config.keys(): 34 | parameters = prune_config.pop("parameters") 35 | if parameters is not None: 36 | if network is None: 37 | raise ValueError("when specifying parameters, must provide network") 38 | prune_config["parameters_to_prune"] = get_parameters_to_prune(parameters, network) 39 | 40 | if "parameters_to_prune" not in prune_config.keys(): 41 | raise ValueError("must specify which parameters_to_prune in base_prune_config, " 42 | "use None for global pruning") 43 | 44 | return prune_config 45 | 46 | 47 | def get_parameters_to_prune(parameters, network): 48 | """Return the weights of network matching the parameters value. 49 | 50 | Parameters must be one of "conv" or "linear", or None, 51 | in which case None is also returned. 52 | """ 53 | if parameters == "conv": 54 | return [(layer, "weight") for layer in network.modules() 55 | if isinstance(layer, torch.nn.Conv2d)] 56 | elif parameters == "linear": 57 | return [(layer, "weight") for layer in network.modules() 58 | if isinstance(layer, torch.nn.Linear)] 59 | elif parameters is None: 60 | return 61 | else: 62 | raise ValueError(f"could not understand parameters value: {parameters}") 63 | 64 | 65 | def compute_iterative_prune(target_sparsity, n_epochs): 66 | return 1 - math.pow(1 - target_sparsity, 1 / n_epochs) 67 | -------------------------------------------------------------------------------- /nn/quant.py: -------------------------------------------------------------------------------- 1 | """Utilities for working with quantized networks.""" 2 | import torch 3 | 4 | 5 | def run_static_quantization(network, xs, qconfig="fbgemm"): 6 | """Return a quantized version of supplied network. 7 | 8 | Runs forward pass of network with xs, so make sure they're on 9 | the same device. Returns a copy of the network, so watch memory consumption. 10 | 11 | Note that this uses torch.quantization, rather than PyTorchLightning. 
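The converted network runs int8 kernels that, for the torch version pinned in
requirements.txt, execute on CPU only, so keep both network and xs on the CPU.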
12 | 13 | Args: 14 | network: torch.nn.Module, network to be quantized. 15 | xs: torch.Tensor, valid inputs for network.forward. 16 | qconfig: string, "fbgemm" to quantize for server/x86, "qnnpack" for mobile/ARM 17 | """ 18 | # set up quantization 19 | network.qconfig = torch.quantization.get_default_qconfig(qconfig) 20 | network.eval() 21 | 22 | # attach methods for collecting activation statistics to set quantization bounds 23 | qnetwork = torch.quantization.prepare(network) 24 | 25 | # run inputs through network, collect stats 26 | qnetwork.forward(xs) 27 | 28 | # convert network to uint8 using quantization statistics 29 | qnetwork = torch.quantization.convert(qnetwork) 30 | 31 | return qnetwork 32 | -------------------------------------------------------------------------------- /nn/shapes.py: -------------------------------------------------------------------------------- 1 | from math import floor 2 | 3 | import torch 4 | 5 | 6 | def sequential_output_shape(self, shape): 7 | """Computes the output shape of a torch.nn.Sequential. 8 | 9 | Optimistically assumes any layer without method does not change shape. 10 | """ 11 | for element in self: 12 | for cls, method in output_shape_methods.items(): 13 | if isinstance(element, cls): 14 | shape = method(element, shape) 15 | break 16 | 17 | return shape 18 | 19 | 20 | def sequential_feature_dim(self): 21 | """Computes the feature dimension of a torch.nn.Sequential. 22 | 23 | Returns None if feature dimension cannot be determined. 24 | """ 25 | feature_dim = None 26 | for element in reversed(self): 27 | for cls, method in feature_dim_methods.items(): 28 | if isinstance(element, cls): 29 | feature_dim = method(element) 30 | if feature_dim is not None: 31 | return feature_dim 32 | 33 | 34 | def conv2d_output_shape(module, h_w): 35 | """Computes the output shape of 2d convolutional operators.""" 36 | # grab operator properties 37 | props = module.kernel_size, module.stride, module.padding, module.dilation 38 | # diagonalize into tuples as needed 39 | props = [tuple((p, p)) if not isinstance(p, tuple) else p for p in props] 40 | # "transpose" operator properties -- list indices are height/width rather than property id 41 | props = list(zip(*props)) 42 | 43 | h = conv1d_output_shape(h_w[0], *props[0]) # calculate h from height parameters of props 44 | w = conv1d_output_shape(h_w[1], *props[1]) # calculate w from width parameters of props 45 | 46 | assert (h > 0) & (w > 0), "Invalid parameters" 47 | 48 | return h, w 49 | 50 | 51 | def conv1d_output_shape(lngth, kernel_size, stride, padding, dilation): 52 | """Computes the change in dimensions for a 1d convolutional operator.""" 53 | return floor( ((lngth + (2 * padding) - (dilation * (kernel_size - 1)) - 1) / stride) + 1) # noqa 54 | 55 | 56 | def convtranspose2d_output_shape(*args, **kwargs): 57 | raise NotImplementedError 58 | 59 | 60 | output_shape_methods = { # order is important here; torch.nn.Module must be last 61 | torch.nn.Sequential: sequential_output_shape, 62 | torch.nn.Conv2d: conv2d_output_shape, 63 | torch.nn.MaxPool2d: conv2d_output_shape, 64 | torch.nn.Linear: lambda module, shape: module.out_features, 65 | torch.nn.AdaptiveAvgPool2d: lambda module, shape: module.output_size, 66 | torch.nn.Module: lambda module, shape: shape, 67 | } 68 | 69 | feature_dim_methods = { 70 | torch.nn.Sequential: sequential_feature_dim, 71 | torch.nn.Conv2d: lambda module: module.out_channels, 72 | torch.nn.ConvTranspose2d: lambda module: module.out_channels, 73 | torch.nn.Linear: lambda module: 
module.out_features,
74 | }
75 |
--------------------------------------------------------------------------------
/requirements-dev.txt:
--------------------------------------------------------------------------------
1 | pandas==1.1.5
2 | pytorch_lightning==1.3.8
3 | torch==1.9.0
4 | torchmetrics==0.4.1
5 | torchvision==0.10.0
6 | wandb
7 |
8 | flake8
9 | flake8-docstrings
10 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pandas==1.1.5
2 | pytorch_lightning==1.3.8
3 | torch==1.9.0
4 | torchmetrics==0.4.1
5 | torchvision==0.10.0
6 | wandb
7 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """Assorted utilities for Lightning plus Weights & Biases."""
2 | import random
3 | import string
4 | import warnings
5 |
6 | from pytorch_lightning.utilities.warnings import LightningDeprecationWarning
7 |
8 |
9 | try:
10 |     from wonderwords import RandomWord
11 |     no_wonderwords = False
12 | except ImportError:
13 |     no_wonderwords = True
14 |
15 |
16 | if no_wonderwords:
17 |     chars = string.ascii_lowercase
18 |
19 |     def make_random_name():
20 |         return "".join([random.choice(chars) for ii in range(10)])
21 |
22 | else:
23 |     r = RandomWord()
24 |
25 |     def make_random_name():
26 |         name = "-".join(
27 |             [r.word(word_min_length=3, word_max_length=7, include_parts_of_speech=["adjective"]),
28 |              r.word(word_min_length=5, word_max_length=7, include_parts_of_speech=["noun"])])
29 |         return name
30 |
31 |
32 | def filter_warnings():
33 |     """Filters warnings that students do not need to see."""
34 |     warnings.simplefilter("ignore", category=UserWarning)
35 |     warnings.simplefilter("ignore", category=LightningDeprecationWarning)
36 |
--------------------------------------------------------------------------------
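A hypothetical notebook-style sketch of the helpers in `utils.py`, combined with `wandb.init`
(the project name below is made up):

```python
import wandb

import lit_utils as lu

# silence UserWarning and LightningDeprecationWarning noise in educational notebooks
lu.utils.filter_warnings()

# readable, randomly generated run names, e.g. for the W&B YOLOv5 integration
wandb.init(project="lit-utils-demo", name=lu.utils.make_random_name())
```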