├── yamle ├── utils │ ├── __init__.py │ ├── specific │ │ ├── __init__.py │ │ └── mimo_experiments │ │ │ └── __init__.py │ ├── regularizer_utils.py │ ├── export_utils.py │ ├── pruning_utils.py │ ├── tracing_utils.py │ ├── cli_utils.py │ └── running_utils.py ├── third_party │ ├── __init__.py │ ├── imagenet_c │ │ ├── frost │ │ │ ├── frost1.png │ │ │ ├── frost2.png │ │ │ ├── frost3.png │ │ │ ├── frost4.jpg │ │ │ ├── frost5.jpg │ │ │ └── frost6.jpg │ │ ├── tabular_corruptions.py │ │ └── extra.py │ ├── medmnist │ │ └── __init__.py │ └── tinyimagenet │ │ └── __init__.py ├── evaluation │ └── metrics │ │ ├── __init__.py │ │ └── algorithmic │ │ └── segmentation.py ├── pruning │ ├── unstructured │ │ ├── __init__.py │ │ └── magnitude.py │ ├── __init__.py │ └── pruner.py ├── quantization │ ├── models │ │ ├── __init__.py │ │ ├── specific │ │ │ ├── __init__.py │ │ │ └── mcdropout.py │ │ └── operations.py │ ├── __init__.py │ ├── static │ │ └── __init__.py │ └── qat │ │ └── __init__.py ├── __init__.py ├── trainers │ ├── __init__.py │ ├── calibration.py │ └── ensemble.py ├── losses │ ├── __init__.py │ ├── evidential_regression.py │ └── contrastive.py ├── models │ ├── __init__.py │ ├── specific │ │ ├── temperature_scaling.py │ │ ├── laplace.py │ │ ├── evidential_regression.py │ │ ├── ensemble.py │ │ ├── mimmo.py │ │ └── sgld.py │ ├── model.py │ └── gp.py ├── regularizers │ ├── __init__.py │ ├── gradient.py │ ├── weight.py │ ├── regularizer.py │ ├── model.py │ └── feature.py ├── cli │ ├── analyze_tune.py │ ├── retest.py │ └── train.py ├── methods │ ├── temperature_scaling.py │ ├── __init__.py │ ├── sngp.py │ ├── sgld.py │ ├── be.py │ └── delta_uq.py ├── defaults.py └── data │ ├── dataset_wrappers.py │ └── __init__.py ├── docs ├── requirements.txt ├── index.rst ├── Makefile ├── conf.py ├── extending_yamle │ ├── index.rst │ ├── model.rst │ └── datamodule.rst └── getting_started │ └── index.rst ├── .readthedocs.yaml ├── requirements.txt ├── .github └── ISSUE_TEMPLATE │ ├── 
feature_request.md │ └── bug_report.md ├── setup.py ├── .gitignore └── CODE_OF_CONDUCT.md /yamle/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/utils/specific/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/evaluation/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/pruning/unstructured/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/quantization/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /yamle/quantization/models/specific/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/utils/specific/mimo_experiments/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost1.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost1.png -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost2.png -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost3.png -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost4.jpg -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost5.jpg -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/frost/frost6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/martinferianc/yamle/HEAD/yamle/third_party/imagenet_c/frost/frost6.jpg -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | docutils==0.20.1 2 | Sphinx==7.2.6 3 | 
class QuantizableAdd(nn.Module):
    """Element-wise addition performed through a ``FloatFunctional`` object.

    The sum is delegated to ``torch.ao.nn.quantized.FloatFunctional.add`` rather
    than to the ``+`` operator, so the addition lives in a named submodule
    (``_add``) of whichever model uses it, e.g. for residual connections.
    """

    def __init__(self) -> None:
        # Zero-argument super(), matching the other modules in this code base.
        super().__init__()
        self._add = torch.ao.nn.quantized.FloatFunctional()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the sum of ``x`` and ``y`` computed by the wrapped functional.

        Args:
            x (torch.Tensor): First operand.
            y (torch.Tensor): Second operand.

        Returns:
            torch.Tensor: The result of ``FloatFunctional.add(x, y)``.
        """
        return self._add.add(x, y)
-------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | autograd==1.5 3 | pytorch-lightning==2.0.8 4 | pandas==2.1.0 5 | scienceplots>=2.1.0 6 | torchdata==0.6.0 7 | torchvision==0.15.1 8 | torchtext==0.15.1 9 | torchmetrics==1.0.0 10 | scikit-learn==1.1.3 11 | scikit-optimize==0.9.0 12 | scikit-image==0.20.0 13 | opencv-python==4.7.0.72 14 | medmnist==2.2.3 15 | h5py==3.7.0 16 | ptflops==0.7.1.2 17 | torchinfo>=1.8.0 18 | tensorboard>=2.17.0 19 | einops>=0.6.1 20 | fvcore>=0.1.5.post20221221 21 | rich>=13.5.2 22 | onnx==1.15.0 23 | backpack-for-pytorch==1.6.0 24 | paretoset==1.2.3 25 | natsort==8.4.0 26 | backpack-for-pytorch==1.6.0 27 | paretoset==1.2.3 28 | natsort==8.4.0 29 | gpytorch==1.11.0 30 | syne-tune[basic]==0.10.0 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 
# --- yamle/pruning/__init__.py ---
def pruner_factory(pruner_type: Optional[str] = None) -> Type[Callable]:
    """Look up the pruner class registered under ``pruner_type``.

    Args:
        pruner_type (Optional[str]): Key into ``AVAILABLE_PRUNERS``. Defaults to ``None``.

    Returns:
        Type[Callable]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``pruner_type`` is not a key of ``AVAILABLE_PRUNERS``.
    """
    if pruner_type not in AVAILABLE_PRUNERS:
        raise ValueError(f"Unknown pruner type {pruner_type}.")
    return AVAILABLE_PRUNERS[pruner_type]


# --- yamle/trainers/__init__.py ---
def trainer_factory(trainer_type: str) -> "Type[BaseTrainer]":
    """Look up the trainer class registered under ``trainer_type``.

    Args:
        trainer_type (str): Key into ``AVAILABLE_TRAINERS``.

    Returns:
        Type[BaseTrainer]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``trainer_type`` is not a key of ``AVAILABLE_TRAINERS``.
    """
    if trainer_type not in AVAILABLE_TRAINERS:
        raise ValueError(f"Unknown trainer type {trainer_type}.")
    return AVAILABLE_TRAINERS[trainer_type]


# --- yamle/quantization/__init__.py ---
def quantizer_factory(quantizer_type: Optional[str] = None) -> Type[Callable]:
    """Look up the quantizer class registered under ``quantizer_type``.

    Args:
        quantizer_type (Optional[str]): Key into ``AVAILABLE_QUANTIZERS``. Defaults to ``None``.

    Returns:
        Type[Callable]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``quantizer_type`` is not a key of ``AVAILABLE_QUANTIZERS``.
    """
    if quantizer_type not in AVAILABLE_QUANTIZERS:
        # The message used to say "pruner" (copied from pruner_factory).
        raise ValueError(f"Unknown quantizer type {quantizer_type}.")
    return AVAILABLE_QUANTIZERS[quantizer_type]
# --- yamle/utils/regularizer_utils.py ---
def _set_regularizer_disabled(parameters, disabled: bool) -> None:
    """Store ``disabled`` under ``DISABLED_REGULARIZER_KEY`` on one or many parameters."""
    # Local import: `Iterator` is not among this module's top-level imports.
    from collections.abc import Iterator

    # `nn.Parameter` is tested first so a single parameter is never iterated.
    if isinstance(parameters, nn.Parameter):
        setattr(parameters, DISABLED_REGULARIZER_KEY, disabled)
    elif isinstance(parameters, (list, tuple, Iterator)):
        # Iterators cover generators such as the one returned by `model.parameters()`.
        for param in parameters:
            setattr(param, DISABLED_REGULARIZER_KEY, disabled)
    else:
        raise ValueError(
            f"The parameters should be either a list of parameters or a single parameter. Got {type(parameters)}."
        )


def disable_regularizer(parameters: Union[List[nn.Parameter], nn.Parameter]) -> None:
    """Flag the given parameters as having their regularizer disabled.

    Sets the attribute named by ``DISABLED_REGULARIZER_KEY`` to ``True`` on each parameter.

    Args:
        parameters: A single ``nn.Parameter``, or a list, tuple or iterator of them.

    Raises:
        ValueError: If ``parameters`` is none of the accepted kinds.
    """
    _set_regularizer_disabled(parameters, True)


def enable_regularizer(parameters: Union[List[nn.Parameter], nn.Parameter]) -> None:
    """Flag the given parameters as having their regularizer enabled.

    Sets the attribute named by ``DISABLED_REGULARIZER_KEY`` to ``False`` on each parameter.

    Args:
        parameters: A single ``nn.Parameter``, or a list, tuple or iterator of them.

    Raises:
        ValueError: If ``parameters`` is none of the accepted kinds.
    """
    _set_regularizer_disabled(parameters, False)


# --- yamle/losses/__init__.py ---
def loss_factory(loss_type: Optional[str] = None) -> Type[Callable]:
    """Look up the loss class registered under ``loss_type``.

    Args:
        loss_type (Optional[str]): Key into ``AVAILABLE_LOSSES``. Defaults to ``None``.

    Returns:
        Type[Callable]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``loss_type`` is not a key of ``AVAILABLE_LOSSES``.
    """
    if loss_type not in AVAILABLE_LOSSES:
        raise ValueError(f"Unknown loss type {loss_type}.")
    return AVAILABLE_LOSSES[loss_type]


# --- yamle/models/__init__.py ---
def model_factory(model_type: Optional[str] = None) -> Type[nn.Module]:
    """Look up the model class registered under ``model_type``.

    Args:
        model_type (Optional[str]): Key into ``AVAILABLE_MODELS``. Defaults to ``None``.

    Returns:
        Type[nn.Module]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``model_type`` is not a key of ``AVAILABLE_MODELS``.
    """
    if model_type not in AVAILABLE_MODELS:
        raise ValueError(f"Unknown model type {model_type}.")
    return AVAILABLE_MODELS[model_type]


# --- yamle/utils/export_utils.py ---
def export_onnx(model: nn.Module, path: str) -> None:
    """Export ``model`` to an ONNX file at ``path`` and run the ONNX checker on it.

    The model is switched to eval mode, called once on a random input to count
    its outputs, exported with a dynamic first axis on the input and on every
    output, then reloaded from ``path`` and validated.

    Args:
        model (nn.Module): The model to export.
        path (str): The path to save the model.
    """
    model.eval()
    # First entry of the traced shapes, with its leading dimension forced to 1.
    # NOTE(review): presumably the leading dimension is the batch axis - confirm
    # against `get_input_shape_from_model`.
    input_shape = list(get_input_shape_from_model(model)[0])
    input_shape[0] = 1
    x = torch.randn(*input_shape)
    # A model without parameters used to raise StopIteration here; in that case
    # the dummy input simply stays on the default device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        x = x.to(first_param.device)

    # Make a pass to count the number of outputs
    outputs = model(x)
    num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1

    output_names = [f"output_{i}" for i in range(num_outputs)]
    dynamic_axes = {"input": {0: "batch_size"}}
    dynamic_axes.update({name: {0: "batch_size"} for name in output_names})

    logging.info("Exporting model to ONNX.")
    torch.onnx.export(
        model,
        x,
        path,
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=["input"],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
    )

    # Perform sanity check
    logging.info("Checking ONNX model.")
    onnx_model = onnx.load(path)
    onnx.checker.check_model(onnx_model)
    logging.info("ONNX model is valid.")
class TemperatureScaler(nn.Module):
    """Divide the incoming logits by a temperature.

    Args:
        temperature (float): The initial temperature. Default: 1.0.
        mode (str): When the scaling is applied. Default: 'both'. Can select from 'train', 'eval', 'both'.
        train (bool): Whether the temperature is trainable or not. Default: False.
    """

    def __init__(
        self, temperature: float = 1.0, mode: str = "both", train: bool = False
    ) -> None:
        super().__init__()
        assert mode in [
            "train",
            "eval",
            "both",
        ], "The mode must be one of 'train', 'eval', 'both'."
        self._mode = mode
        # `float(...)` keeps the parameter floating-point even when an integer
        # temperature (e.g. `2`) is passed; integer tensors cannot require
        # gradients, so `train=True` used to fail for such inputs.
        self._T = nn.Parameter(torch.tensor(float(temperature)), requires_grad=train)

    def _is_active(self) -> bool:
        """Return whether scaling applies for the current mode and training flag."""
        if self._mode == "both":
            return True
        if self._mode == "train":
            return self.training
        # Remaining valid mode is "eval".
        return self._mode == "eval" and not self.training

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by the clamped temperature, or ``x`` unchanged when inactive.

        The temperature is clamped to ``[TINY_EPSILON, 1e8]`` before dividing.
        """
        if self._is_active():
            return x / torch.clamp(self._T, min=TINY_EPSILON, max=1e8)
        return x

    def extra_repr(self) -> str:
        """Append the current temperature and mode to the module's representation."""
        return super().extra_repr() + f"T={self._T.item():.2f}, mode={self._mode}"
# --- yamle/regularizers/__init__.py ---
def regularizer_factory(regularizer_type: Optional[str] = None) -> Type[Callable]:
    """Look up the regularizer class registered under ``regularizer_type``.

    Args:
        regularizer_type (Optional[str]): Key into ``AVAILABLE_REGULARIZERS``. Defaults to ``None``.

    Returns:
        Type[Callable]: The registered class itself; nothing is instantiated here.

    Raises:
        ValueError: If ``regularizer_type`` is not a key of ``AVAILABLE_REGULARIZERS``.
    """
    if regularizer_type not in AVAILABLE_REGULARIZERS:
        raise ValueError(f"Unknown regularizer type {regularizer_type}.")
    return AVAILABLE_REGULARIZERS[regularizer_type]


# --- setup.py ---
def read(fname):
    """Return the text content of ``fname``, resolved relative to this file's directory.

    Args:
        fname (str): File name (or relative path) next to this script.

    Returns:
        str: The whole file content.
    """
    path = os.path.join(os.path.dirname(__file__), fname)
    # The `with` block closes the handle deterministically (the previous
    # `open(...).read()` left that to the garbage collector), and the explicit
    # encoding keeps the result independent of the build machine's locale.
    with open(path, encoding="utf-8") as handle:
        return handle.read()
12 | NAME = "yamle" 13 | DESCRIPTION = "Yet Another Machine Learning Environment" 14 | URL = "https://github.com/martinferianc/yamle" 15 | EMAIL = "ferianc.martin@gmail.com" 16 | AUTHOR = "Martin Ferianc" 17 | REQUIRES_PYTHON = ">=3.9.0" 18 | VERSION = "0.0.1" 19 | LICENSE = "GNU GPLv3+" 20 | KEYWORDS = "machine learning, deep learning, python3, open source software" 21 | 22 | # What packages are required for this module to be executed? 23 | REQUIRED = [ 24 | "torch==2.0.0", 25 | "autograd==1.5", 26 | "pytorch-lightning==2.0.8", 27 | "pandas==2.1.0", 28 | "scienceplots>=2.1.0", 29 | "torchdata==0.6.0", 30 | "torchvision==0.15.1", 31 | "torchtext==0.15.1", 32 | "torchmetrics==1.0.0", 33 | "scikit-learn==1.1.3", 34 | "scikit-optimize==0.9.0", 35 | "scikit-image==0.20.0", 36 | "opencv-python==4.7.0.72", 37 | "medmnist==2.2.3", 38 | "h5py==3.7.0", 39 | "ptflops==0.7.1.2", 40 | "torchinfo>=1.8.0", 41 | "einops>=0.6.1", 42 | "fvcore>=0.1.5.post20221221", 43 | "rich>=13.5.2", 44 | "onnx==1.15.0", 45 | "backpack-for-pytorch==1.6.0", 46 | "paretoset==1.2.3", 47 | "natsort==8.4.0", 48 | "gpytorch==1.11.0", 49 | "syne-tune[basic]==0.10.0", 50 | "tensorboard>=2.17.0", 51 | ] 52 | EXTRAS = {} 53 | DEV = [] 54 | TEST = [] 55 | DOCS = [] 56 | EXAMPLES = [] 57 | BENCHMARKS = [] 58 | 59 | setup( 60 | name=NAME, 61 | version=VERSION, 62 | description=DESCRIPTION, 63 | author=AUTHOR, 64 | author_email=EMAIL, 65 | python_requires=REQUIRES_PYTHON, 66 | url=URL, 67 | keywords=KEYWORDS, 68 | packages=find_packages(include=("yamle", "yamle.*")), 69 | long_description=read("README.rst"), 70 | install_requires=REQUIRED, 71 | extras_require=EXTRAS, 72 | include_package_data=True, 73 | license=LICENSE, 74 | ) 75 | -------------------------------------------------------------------------------- /yamle/third_party/medmnist/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Optional 2 | 3 | import torchvision 4 | from PIL 
import Image 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class MedMNISTDatasetWrapper(Dataset): 9 | """This is a wrapper class for MedMNIST dataset. 10 | 11 | It enables the padding of the images to 32x32, which is required by the 12 | corruptions and might be required by the models. It also takes the first element 13 | of the target since they are in a numpy array of shape (1,). 14 | 15 | Args: 16 | dataset (Dataset): The MedMNIST dataset. 17 | pad_to_32 (bool): Whether to pad the images to 32x32. Defaults to True. 18 | target_normalization (bool): Whether to normalize the target to [0, 1]. Defaults to False. 19 | This is used for the ordinal regression task. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | dataset: Dataset, 25 | pad_to_32: bool = True, 26 | target_normalization: bool = False, 27 | ): 28 | self._dataset = dataset 29 | self._pad_to_32 = pad_to_32 30 | self._target_normalization = target_normalization 31 | self._min_max: Optional[Tuple[float, float]] = None 32 | if self._target_normalization: 33 | self._min_max = self._get_min_max() 34 | 35 | def __len__(self) -> int: 36 | return len(self._dataset) 37 | 38 | def _get_min_max(self) -> Tuple[float, float]: 39 | """This is a helper function to get the min and max of the target across the dataset.""" 40 | m = float("inf") 41 | M = float("-inf") 42 | for _, target in self._dataset: 43 | m = min(m, target[0]) 44 | M = max(M, target[0]) 45 | return m, M 46 | 47 | def __getitem__(self, idx: int) -> Tuple[Image.Image, int]: 48 | img, target = self._dataset[idx] 49 | if self._pad_to_32: 50 | img = torchvision.transforms.functional.pad(img, 2) # 28 -> 32 51 | target = target[0] # (1,) -> () 52 | if self._target_normalization: 53 | target = (target - self._min_max[0]) / (self._min_max[1] - self._min_max[0]) 54 | return img, target 55 | -------------------------------------------------------------------------------- /yamle/third_party/imagenet_c/tabular_corruptions.py:
-------------------------------------------------------------------------------- 1 | """ 2 | This file is an adaptation to the original ImageNet-C corruptions but for 3 | tabular data. 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | def additive_gaussian_noise(x: np.ndarray, severity: int = 1) -> np.ndarray: 10 | """Adds gaussian noise with mean 0 and std deviation c. 11 | 12 | The sampled noise is multiplied element-wise by the input itself, 13 | so the noise added to each feature is proportional to that feature's value. 14 | 15 | Args: 16 | x (np.ndarray): The input data of shape (N, D). 17 | severity (int): The severity of the corruption. Must be in [1, 5]. 18 | """ 19 | c = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] 20 | # Sample noise in the same shape as the input 21 | eta = np.random.normal(loc=0.0, scale=c, size=x.shape) 22 | eta = eta * x 23 | return x + eta 24 | 25 | 26 | def multiplicative_gaussian_noise(x: np.ndarray, severity: int = 1) -> np.ndarray: 27 | """Multiplies the input by a gaussian noise with mean 1 and std deviation c. 28 | 29 | The standard deviation c is selected by the severity level. 30 | """ 31 | c = [0.05, 0.1, 0.15, 0.2, 0.3][severity - 1] 32 | return x * np.random.normal(loc=1.0, scale=c, size=x.shape) 33 | 34 | 35 | def additive_uniform_noise(x: np.ndarray, severity: int = 1) -> np.ndarray: 36 | """Adds uniform noise with range [-c, c]. 37 | 38 | The sampled noise is multiplied element-wise by the input itself before it is added.
39 | """ 40 | c = [0.1, 0.2, 0.3, 0.4, 0.5][severity - 1] 41 | # Sample noise in the same shape as the input 42 | eta = np.random.uniform(low=-c, high=c, size=x.shape) 43 | eta = eta * x 44 | return x + eta 45 | 46 | 47 | def multiplicative_uniform_noise(x: np.ndarray, severity: int = 1) -> np.ndarray: 48 | """Multiplies the input by a uniform noise with range [1-c, 1+c].""" 49 | c = [0.05, 0.1, 0.15, 0.2, 0.3][severity - 1] 50 | return x * np.random.uniform(low=1.0 - c, high=1.0 + c, size=x.shape) 51 | 52 | 53 | def multiplicative_bernoulli_noise(x: np.ndarray, severity: int = 1) -> np.ndarray: 54 | """Multiplies the input by a bernoulli noise which drops features with probability c.""" 55 | c = [0.05, 0.1, 0.15, 0.2, 0.3][severity - 1] 56 | return x * np.random.binomial(n=1, p=1.0 - c, size=x.shape) 57 | -------------------------------------------------------------------------------- /yamle/models/specific/laplace.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.distributions import MultivariateNormal 7 | 8 | 9 | class LaplaceLinear(nn.Module): 10 | """This class implements the Laplace linear layer. 11 | 12 | Once the Hessian is computed it uses the K-FAC approximation to compute the approximation to the Hessian. 13 | 14 | Args: 15 | U (torch.Tensor): The first matrix to compute the factorised Hessian. 16 | V (torch.Tensor): The second matrix to compute the factorised Hessian. 17 | weight (torch.Tensor): The weight matrix of the linear layer. 18 | bias (Optional[torch.Tensor]): The bias vector of the linear layer. 
19 | """ 20 | 21 | def __init__( 22 | self, 23 | U: torch.Tensor, 24 | V: torch.Tensor, 25 | weight: torch.Tensor, 26 | bias: Optional[torch.Tensor] = None, 27 | ) -> None: 28 | super().__init__() 29 | # Initialize the buffers for the factorised Hessian 30 | self.register_buffer("_U", U) 31 | self.register_buffer("_V", V) 32 | # Initialize the weight and bias parameters, the weights are not trainable 33 | self.weight = nn.Parameter(weight, requires_grad=False) 34 | self.bias = ( 35 | nn.Parameter(bias, requires_grad=False) if bias is not None else None 36 | ) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | """This function computes the forward pass of the Laplace linear layer. 40 | 41 | Args: 42 | x (torch.Tensor): The input tensor. 43 | """ 44 | # Compute the forward pass 45 | mean = F.linear(x, self.weight, self.bias) 46 | if self.training: 47 | return mean 48 | else: 49 | # Compute the variance prediction 50 | variance = ( 51 | torch.mm(torch.mm(x, self._V), x.T).diag().reshape(-1, 1, 1) * self._U 52 | ) 53 | distribution = MultivariateNormal(mean, variance) 54 | return distribution.sample() 55 | 56 | def extra_repr(self) -> str: 57 | """This function returns the extra representation of the layer.""" 58 | return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, bias={self.bias is not None}" 59 | -------------------------------------------------------------------------------- /yamle/trainers/calibration.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Any 3 | 4 | from torch.utils.data import DataLoader 5 | 6 | from yamle.trainers.trainer import BaseTrainer 7 | 8 | 9 | class CalibrationTrainer(BaseTrainer): 10 | """This class defines a temperature trainer which first trains the model and then calibrates it. 11 | 12 | The training is on the training set and the calibration is on the calibration set. 
13 | 14 | Args: 15 | calibration_epochs (int): The number of epochs to calibrate the model. 16 | """ 17 | 18 | def __init__(self, calibration_epochs: int, *args: Any, **kwargs: Any) -> None: 19 | super().__init__(*args, **kwargs) 20 | self._calibration_epochs = calibration_epochs 21 | 22 | def fit(self, train_dataloader: DataLoader, validation_dataloader: DataLoader) -> float: 23 | """This method trains the method and then does the calibration. 24 | 25 | Args: 26 | train_dataloader (DataLoader): The dataloader to be used for training. 27 | validation_dataloader (DataLoader): The dataloader to be used for validation. 28 | """ 29 | training_time = super().fit(train_dataloader, validation_dataloader) 30 | calibration_dataloader = self._datamodule.calibration_dataloader() 31 | if not hasattr(self._method, "calibrate"): 32 | raise ValueError("Make sure that the method has a calibrate method.") 33 | self._method.calibrate() 34 | self._initialize_trainer(epochs=self._calibration_epochs) 35 | calibration_time = super().fit(calibration_dataloader, validation_dataloader) 36 | return training_time + calibration_time 37 | 38 | @staticmethod 39 | def add_specific_args( 40 | parent_parser: argparse.ArgumentParser, 41 | ) -> argparse.ArgumentParser: 42 | """This method adds trainer arguments to the given parser. 43 | 44 | Args: 45 | parent_parser (ArgumentParser): The parser to which the arguments should be added. 
46 | """ 47 | parser = super(CalibrationTrainer, CalibrationTrainer).add_specific_args( 48 | parent_parser 49 | ) 50 | parser.add_argument( 51 | "--trainer_calibration_epochs", 52 | type=int, 53 | default=10, 54 | help="The number of epochs to be used for training.", 55 | ) 56 | return parser 57 | -------------------------------------------------------------------------------- /yamle/cli/analyze_tune.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from argparse import ArgumentParser 4 | 5 | import yamle.utils.file_utils as utils 6 | from yamle.utils.tuning_utils import ( 7 | plot_different_runs_and_metrics, 8 | plot_different_runs_and_metric_config_combinations, 9 | plot_different_metrics_and_trial_id, 10 | ) 11 | 12 | logging = logging.getLogger("pytorch_lightning") 13 | 14 | 15 | def plot_analyze_tune(args: ArgumentParser) -> None: 16 | """This is a helper function which loads in the dataframe for a tuning experiment and plots the results 17 | and different metrics and statistics.""" 18 | 19 | # Create experiment structure 20 | experiment_name = f"{utils.current_time()}-analyze-tune" 21 | experiment_name += f"-{args.label}" if args.label is not None else "" 22 | 23 | # Create experiment directory 24 | save_path = os.path.join(args.save_path, experiment_name) 25 | save_path = utils.create_experiment_folder(save_path, "./src", cache_scripts=False) 26 | 27 | # Set the logger 28 | utils.config_logger(save_path) 29 | logging.info("Beginning Analyze Tune: %s", experiment_name) 30 | logging.info("Arguments: %s", args) 31 | logging.info("Command arguments to reproduce: %s", utils.argparse_to_command(args)) 32 | 33 | # Save the arguments 34 | utils.save_args(args, save_path) 35 | utils.save_args_dictionary(args, save_path) 36 | 37 | # Load in the dataframe 38 | df = utils.load_tuning_results(args.experiment) 39 | 40 | # Plot the different statistics and analysis 41 | 
plot_different_runs_and_metrics(df, save_path) 42 | plot_different_runs_and_metric_config_combinations(df, save_path) 43 | plot_different_metrics_and_trial_id(df, save_path) 44 | 45 | 46 | if __name__ == "__main__": 47 | parser = ArgumentParser() 48 | parser.add_argument( 49 | "--label", 50 | type=str, 51 | default=None, 52 | help="An optional label to be added to the experiment name.", 53 | ) 54 | parser.add_argument( 55 | "--save_path", 56 | type=str, 57 | default="experiments", 58 | help="The directory where the experiment results are stored.", 59 | ) 60 | parser.add_argument( 61 | "--experiment", 62 | type=str, 63 | default=None, 64 | help="The name of the experiment to be analyzed.", 65 | ) 66 | args = parser.parse_args() 67 | 68 | plot_analyze_tune(args) 69 | -------------------------------------------------------------------------------- /yamle/methods/temperature_scaling.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Tuple, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import argparse 6 | 7 | from yamle.methods.method import BaseMethod 8 | from yamle.models.specific.temperature_scaling import TemperatureScaler 9 | from yamle.defaults import CLASSIFICATION_KEY, SEGMENTATION_KEY 10 | 11 | 12 | class TemperatureMethod(BaseMethod): 13 | """This class is the extension of the base method for temperature scaling. 14 | 15 | Args: 16 | calibration_learning_rate (float): The learning rate for the calibration. 
17 | """ 18 | 19 | tasks = [CLASSIFICATION_KEY, SEGMENTATION_KEY] 20 | 21 | def __init__(self, calibration_learning_rate: float, *args: Any, **kwargs: Any) -> None: 22 | super().__init__(*args, **kwargs) 23 | self.model._output = nn.Sequential(self.model._output, TemperatureScaler()) 24 | 25 | self.hparams.calibration_learning_rate = calibration_learning_rate 26 | 27 | self.calibration = False 28 | 29 | def calibrate(self) -> None: 30 | """This method is used to trigger the calibration.""" 31 | self.calibration = True 32 | self.model._output[1]._T.requires_grad = True 33 | 34 | def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[Any]]: 35 | """This method is used to configure the optimizers for the model. 36 | if the model is not calibrated, then the temperature parameter is not updated. 37 | 38 | if `self.calibration` is True, then only the temperature parameter is updated. 39 | """ 40 | if not self.calibration: 41 | return super().configure_optimizers() 42 | else: 43 | return [ 44 | torch.optim.LBFGS( 45 | [self.model._output[1]._T], 46 | lr=self.hparams.calibration_learning_rate, 47 | line_search_fn="strong_wolfe", 48 | ) 49 | ], [] 50 | 51 | @staticmethod 52 | def add_specific_args( 53 | parent_parser: argparse.ArgumentParser, 54 | ) -> argparse.ArgumentParser: 55 | """This method is used to add the specific arguments for the class.""" 56 | parser = super(TemperatureMethod, TemperatureMethod).add_specific_args( 57 | parent_parser 58 | ) 59 | parser.add_argument( 60 | "--method_calibration_learning_rate", 61 | help="The learning rate for the calibration.", 62 | type=float, 63 | default=0.001, 64 | ) 65 | return parser 66 | -------------------------------------------------------------------------------- /yamle/pruning/unstructured/magnitude.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | import torch.nn as nn 4 | import argparse 5 | 6 | from 
yamle.pruning.pruner import BasePruner 7 | from yamle.utils.pruning_utils import ( 8 | get_all_prunable_weights, 9 | is_layer_prunable, 10 | is_parameter_prunable, 11 | ) 12 | 13 | 14 | class UnstructuredMagnitudePruner(BasePruner): 15 | """This is the base class for unstructured magnitude-based pruning. 16 | 17 | It will prune the weights with the lowest absolute magnitude. The threshold is determined 18 | by the pruning percentage. The pruning percentage is the percentageage of weights to prune. 19 | 20 | Args: 21 | pruning_percentage (float): The percentageage of weights to prune. 22 | """ 23 | 24 | def __init__(self, percentage: float, *args: Any, **kwargs: Any) -> None: 25 | super().__init__(*args, **kwargs) 26 | assert 0.0 <= percentage <= 1.0, "Pruning percentage must be between 0 and 1." 27 | self._percentage = percentage 28 | 29 | def __call__(self, m: nn.Module) -> float: 30 | """This method is used to prune the model.""" 31 | # Get all the weights in the model 32 | weights = get_all_prunable_weights(m) 33 | 34 | # Find the magnitude of the weight at a given percentile 35 | threshold = torch.abs(weights).kthvalue(int(self._percentage * len(weights)))[0] 36 | 37 | # Prune the weights 38 | for module in m.modules(): 39 | if is_layer_prunable(module): 40 | for p in module.parameters(): 41 | if is_parameter_prunable(p): 42 | # Create a mask to prune the weights, `True` means prune 43 | mask = torch.abs(p.data) < threshold 44 | self.prune_parameter(p, mask) 45 | 46 | return threshold.item() 47 | 48 | def __repr__(self) -> str: 49 | return f"{self.__class__.__name__}(percentage={self._percentage})" 50 | 51 | @staticmethod 52 | def add_specific_args( 53 | parent_parser: argparse.ArgumentParser, 54 | ) -> argparse.ArgumentParser: 55 | """This method is used to add the pruner specific arguments to the parent parser.""" 56 | parser = super( 57 | UnstructuredMagnitudePruner, UnstructuredMagnitudePruner 58 | ).add_specific_args(parent_parser) 59 | 
parser.add_argument( 60 | "--pruner_percentage", 61 | type=float, 62 | default=0.5, 63 | help="The percentageage of weights to prune.", 64 | ) 65 | return parser 66 | -------------------------------------------------------------------------------- /yamle/losses/evidential_regression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from: https://github.com/aamini/evidential-deep-learning/ 3 | """ 4 | 5 | from typing import Any, Optional, Tuple 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | from yamle.losses.loss import BaseLoss 12 | from yamle.defaults import TINY_EPSILON, REGRESSION_KEY, DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY 13 | 14 | 15 | class NIG_NLL(nn.Module): 16 | """Negative log-likelihood loss for Normal Inverse Gamma (NIG) distribution.""" 17 | 18 | def forward( 19 | self, 20 | y: torch.Tensor, 21 | gamma: torch.Tensor, 22 | v: torch.Tensor, 23 | alpha: torch.Tensor, 24 | beta: torch.Tensor, 25 | ) -> torch.Tensor: 26 | """Compute the loss function.""" 27 | twoBlambda = 2 * beta * (1 + v) 28 | 29 | nll = ( 30 | 0.5 * torch.log(torch.tensor(np.pi) / (v + TINY_EPSILON) + TINY_EPSILON) 31 | - alpha * torch.log(twoBlambda + TINY_EPSILON) 32 | + (alpha + 0.5) 33 | * torch.log(v * (y - gamma) ** 2 + twoBlambda + TINY_EPSILON) 34 | + torch.lgamma(alpha) 35 | - torch.lgamma(alpha + 0.5) 36 | ) 37 | 38 | return torch.mean(nll) 39 | 40 | 41 | class NIG_Reg(nn.Module): 42 | """Regularization loss for Normal Inverse Gamma (NIG) distribution.""" 43 | 44 | def forward( 45 | self, 46 | y: torch.Tensor, 47 | gamma: torch.Tensor, 48 | v: torch.Tensor, 49 | alpha: torch.Tensor, 50 | beta: torch.Tensor, 51 | ) -> torch.Tensor: 52 | """Compute the loss function.""" 53 | error = torch.abs(y - gamma) 54 | evi = 2 * v + (alpha) 55 | reg = error * evi 56 | return torch.mean(reg) 57 | 58 | 59 | class EvidentialRegressionLoss(BaseLoss): 60 | """Evidential regression loss for probabilistic 
regression.""" 61 | 62 | tasks = [REGRESSION_KEY, DEPTH_ESTIMATION_KEY, RECONSTRUCTION_KEY] 63 | 64 | def __init__(self, *args: Any, **kwargs: Any) -> None: 65 | super().__init__(*args, **kwargs) 66 | self._nll = NIG_NLL() 67 | self._reg = NIG_Reg() 68 | 69 | def __call__( 70 | self, 71 | y_hat: torch.Tensor, 72 | y: torch.Tensor, 73 | weights: Optional[torch.Tensor] = None, 74 | ) -> Tuple[torch.Tensor, torch.Tensor]: 75 | gamma, v, alpha, beta = torch.split(y_hat, 1, dim=-1) 76 | loss_nll = self._nll(y, gamma, v, alpha, beta) 77 | loss_reg = self._reg(y, gamma, v, alpha, beta) 78 | return loss_nll, loss_reg 79 | 80 | def __repr__(self) -> str: 81 | return "EvidentialRegression()" 82 | -------------------------------------------------------------------------------- /yamle/regularizers/gradient.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch.nn as nn 4 | from yamle.defaults import TINY_EPSILON 5 | from yamle.regularizers.regularizer import BaseRegularizer 6 | 7 | import torch 8 | import argparse 9 | 10 | 11 | class GradientNoiseRegularizer(BaseRegularizer): 12 | """This is a class for a gradient noise regularization. 13 | 14 | It adds noise sampled from a normal distribution with mean 0 to the gradient; the noise is scaled by `eta / (1 + epoch) ** gamma`. 15 | 16 | It follows the paper: https://arxiv.org/pdf/1511.06807.pdf 17 | 18 | Args: 19 | eta (float): The base scale of the noise, i.e. its standard deviation before the epoch-dependent decay. 20 | gamma (float): The exponent of `(1 + epoch)` by which the noise scale is divided. Must be between 0 and 1. 21 | """ 22 | 23 | def __init__(self, eta: float, gamma: float, *args: Any, **kwargs: Any) -> None: 24 | super().__init__(*args, **kwargs) 25 | assert ( 26 | eta >= 0 27 | ), "The standard deviation of the normal distribution must be non-negative." 28 | assert ( 29 | eta > 0 30 | ), "The standard deviation of the normal distribution must be positive."
31 | assert 0 <= gamma <= 1, f"The factor must be between 0 and 1, but got {gamma}." 32 | self._eta = eta 33 | self._gamma = gamma 34 | 35 | def _var(self, epoch: int) -> float: 36 | """Return the variance of the noise at a given epoch.""" 37 | return self._eta / ((1 + epoch) ** self._gamma + TINY_EPSILON) 38 | 39 | def on_after_backward( 40 | self, model: nn.Module, epoch: int, *args: Any, **kwargs: Any 41 | ) -> None: 42 | """Add noise to the gradients after the backward pass.""" 43 | var = self._var(epoch) 44 | for param in model.parameters(): 45 | if param.grad is not None: 46 | param.grad += torch.randn_like(param.grad) * var 47 | 48 | @staticmethod 49 | def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 50 | """This method is used to add specific arguments to the parser.""" 51 | parser = super( 52 | GradientNoiseRegularizer, GradientNoiseRegularizer 53 | ).add_specific_args(parser) 54 | parser.add_argument( 55 | "--regularizer_eta", 56 | type=float, 57 | default=0.1, 58 | help="The standard deviation of the normal distribution from which the noise is sampled.", 59 | ) 60 | parser.add_argument( 61 | "--regularizer_gamma", 62 | type=float, 63 | default=0.55, 64 | help="The factor by which the noise is multiplied.", 65 | ) 66 | return parser 67 | 68 | def __repr__(self) -> str: 69 | return f"GradientNoiseRegularizer(eta={self._eta}, gamma={self._gamma})" 70 | -------------------------------------------------------------------------------- /yamle/defaults.py: -------------------------------------------------------------------------------- 1 | TINY_EPSILON = 1e-8 2 | POSITIVE_INFINITY = 2147483647 3 | NEGATIVE_INFINITY = -POSITIVE_INFINITY 4 | MEMBERS_DIM = 1 5 | TARGET_KEY = "y" 6 | INPUT_KEY = "x" 7 | TARGET_PER_MEMBER_KEY = "y_permember" 8 | PREDICTION_KEY = "y_hat" 9 | MEAN_PREDICTION_KEY = "y_hat_mean" 10 | PREDICTION_PER_MEMBER_KEY = "y_hat_permember" 11 | AVERAGE_WEIGHTS_KEY = "average_weights" 12 | 13 | LOSS_KEY = "loss" 14 
| LOSS_REGULARIZER_KEY = f"{LOSS_KEY}_regularizer" 15 | LOSS_KL_KEY = f"{LOSS_KEY}_kl" 16 | 17 | REGRESSION_KEY = "regression" 18 | CLASSIFICATION_KEY = "classification" 19 | TEXT_CLASSIFICATION_KEY = "text_classification" 20 | SEGMENTATION_KEY = "segmentation" 21 | RECONSTRUCTION_KEY = "reconstruction" 22 | DEPTH_ESTIMATION_KEY = "depth_estimation" 23 | PRE_TRAINING_KEY = "pre_training" 24 | SUPPORTED_TASKS = [ 25 | REGRESSION_KEY, 26 | CLASSIFICATION_KEY, 27 | TEXT_CLASSIFICATION_KEY, 28 | SEGMENTATION_KEY, 29 | DEPTH_ESTIMATION_KEY, 30 | PRE_TRAINING_KEY, 31 | RECONSTRUCTION_KEY, 32 | ] 33 | 34 | 35 | # Define arguments that can be different when considering averaging the results 36 | ARGS_CAN_BE_DIFFERENT = [ 37 | "save_path", 38 | "seed", 39 | "trainer_accelerator", 40 | "trainer_mode", 41 | "st_checkpoint_dir", 42 | "trainer_devices", 43 | "load_path", 44 | "label", 45 | "datamodule_pin_memory", 46 | "trainer_no_validation_saving", 47 | ] 48 | 49 | TRAIN_KEY = "train" 50 | VALIDATION_KEY = "validation" 51 | TEST_KEY = "test" 52 | ALL_DATASETS_KEY = "all" 53 | CALIBRATION_KEY = "calibration" 54 | 55 | MIN_TENDENCY = "min" 56 | MAX_TENDENCY = "max" 57 | 58 | TRAIN_DATA_SPLIT_KEY = "split_" 59 | DISABLED_REGULARIZER_KEY = "_no_regularizer" 60 | 61 | # This is used to set the optimizer id if there are multiple optimizers 62 | # for which the parameters should be split 63 | OPTIMIZER_ID_KEY = "_optimizer_id" 64 | 65 | FROZEN_MASK_KEY = "_frozen_mask" 66 | FROZEN_DATA_KEY = "_frozen_data" 67 | 68 | DISABLED_OPTIMIZATION_KEY = "_disabled_optimization" 69 | DISABLED_DROPOUT_KEY = "_disabled_dropout" 70 | DISABLED_VI_KEY = "_disabled_vi" 71 | DISABLED_PRUNING_KEY = "_disabled_pruning" 72 | FORMER_DATA_PRUNING_KEY = "_former_data" 73 | MASK_PRUNING_KEY = "_pruning_mask" 74 | 75 | MODULE_INPUT_SHAPE_KEY = "_input_shape" 76 | MODULE_OUTPUT_SHAPE_KEY = "_output_shape" 77 | MODULE_HARDWARE_PROPERTIES_KEY = "_hardware_properties" 78 | MODULE_FLOPS_KEY = "__flops__" 79 |
MODULE_PARAMS_KEY = "__params__" 80 | MODULE_CUMULATIVE_FLOPS_KEY = "__cumulative_flops__" # These are flops which are accumulated from all the previous layers 81 | MODULE_CUMULATIVE_PARAMS_KEY = "__cumulative_params__" # These are params which are accumulated from all the previous layers 82 | MODULE_NAME_KEY = "_name" 83 | 84 | 85 | QUANTIZED_MODEL_KEY = "_quantized_model" 86 | FLOAT_MODEL_KEY = "_float_model" 87 | QUANTIZED_KEY = "_quantized" 88 | 89 | FIT_TIME_KEY = "fit_time" 90 | TEST_TIME_KEY = "test_time" 91 | PROFILING_TIME_KEY = "profiling_time" 92 | -------------------------------------------------------------------------------- /yamle/quantization/models/specific/mcdropout.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.ao.nn.quantized import FloatFunctional 5 | from torch.ao.quantization import QuantStub 6 | 7 | from yamle.models.specific.mcdropout import Dropout1d, Dropout2d, Dropout3d 8 | 9 | 10 | class QuantisedDropout1d(Dropout1d): 11 | """This is the dropout class but the probability is remembered in a `nn.Parameter`. 12 | 13 | Args: 14 | p (float): The probability of an element to be zeroed. 15 | inplace (bool): If set to `True`, will do this operation in-place. 16 | """ 17 | def __init__(self, *args: Any, **kwargs: Any) -> None: 18 | super().__init__(*args, **kwargs) 19 | self.quant = FloatFunctional() 20 | self.quant_stub = QuantStub() 21 | 22 | def forward(self, x: torch.Tensor) -> torch.Tensor: 23 | """This method is used to perform a forward pass through the dropout layer: a mask is sampled by applying dropout (always with `training=True`) to a tensor of ones, passed through the `QuantStub` and multiplied with the input.""" 24 | mask = F.dropout(torch.ones_like(x), p=self._p, training=True, inplace=self.inplace) 25 | mask = self.quant_stub(mask) 26 | return self.quant.mul(x, mask) 27 | 28 | 29 | class QuantisedDropout2d(Dropout2d): 30 | """This is the dropout class but the probability is remembered in a `nn.Parameter`.
31 | 32 | Args: 33 | p (float): The probability of an element to be zeroed. 34 | inplace (bool): If set to `True`, will do this operation in-place. 35 | """ 36 | def __init__(self, *args: Any, **kwargs: Any) -> None: 37 | super().__init__(*args, **kwargs) 38 | self.quant = FloatFunctional() 39 | self.quant_stub = QuantStub() 40 | 41 | def forward(self, x: torch.Tensor) -> torch.Tensor: 42 | """This method is used to perform a forward pass through the dropout layer: a mask is sampled by applying 2d dropout (always with `training=True`) to a tensor of ones, passed through the `QuantStub` and multiplied with the input.""" 43 | # Create a mask where filters will be completely zeroed out 44 | mask = F.dropout2d(torch.ones_like(x), p=self._p, training=True, inplace=self.inplace) 45 | mask = self.quant_stub(mask) 46 | return self.quant.mul(x, mask) 47 | 48 | 49 | class QuantisedDropout3d(Dropout3d): 50 | """This is the dropout class but the probability is remembered in a `nn.Parameter`. 51 | 52 | Args: 53 | p (float): The probability of an element to be zeroed. 54 | inplace (bool): If set to `True`, will do this operation in-place. 55 | """ 56 | def __init__(self, *args: Any, **kwargs: Any) -> None: 57 | super().__init__(*args, **kwargs) 58 | self.quant = FloatFunctional() 59 | self.quant_stub = QuantStub() 60 | 61 | def forward(self, x: torch.Tensor) -> torch.Tensor: 62 | """This method is used to perform a forward pass through the dropout layer: a mask is sampled by applying 3d dropout (always with `training=True`) to a tensor of ones, passed through the `QuantStub` and multiplied with the input.""" 63 | # Create a mask where filters will be completely zeroed out 64 | mask = F.dropout3d(torch.ones_like(x), p=self._p, training=True, inplace=self.inplace) 65 | mask = self.quant_stub(mask) 66 | return self.quant.mul(x, mask) 67 | -------------------------------------------------------------------------------- /yamle/cli/retest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import argparse 4 | from yamle.cli.test import evaluate 5 | 6 | import yamle.utils.file_utils as utils 7 | import os 8 | import json 9 | import shutil 10 | 11 | logging = logging.getLogger("pytorch_lightning") 12 | 13 | 14 | def
reevaluate( 15 | args: argparse.Namespace, 16 | ) -> None: 17 | """This function reruns an experiment with exactly the same arguments as the original experiment.""" 18 | experiment_args = utils.load_args(args.load_experiment) 19 | if args.new_path is not None: 20 | experiment_name = None 21 | experiment_args.save_path = args.new_path 22 | else: 23 | experiment_name = os.path.basename(os.path.normpath(args.load_experiment)) 24 | utils.config_logger(args.load_experiment) 25 | 26 | # Create a folder where to put the backup of the arguments and the results 27 | os.makedirs(os.path.join(experiment_args.save_path, "backup"), exist_ok=True) 28 | 29 | # Save the previous results 30 | shutil.copy( 31 | utils.results_file(args.load_experiment), 32 | utils.results_file(os.path.join(experiment_args.save_path, "backup")), 33 | ) 34 | # Save the previous log file 35 | shutil.copy( 36 | utils.log_file(args.load_experiment), 37 | utils.log_file(os.path.join(experiment_args.save_path, "backup")), 38 | ) 39 | 40 | logging.info("Rerunning experiment: %s", experiment_name) 41 | logging.info("Experiment arguments: %s", experiment_args) 42 | logging.info("Arguments: %s", args) 43 | logging.info("Command arguments to reproduce: %s", utils.argparse_to_command(args)) 44 | 45 | # Update the arguments with the new arguments 46 | # Create a an args object from the new arguments 47 | new_args = argparse.Namespace(**args.new_args) 48 | new_args = utils.parse_args(new_args) 49 | for key, value in vars(new_args).items(): 50 | if hasattr(experiment_args, key): 51 | logging.info("Updating argument %s to %s", key, value) 52 | setattr(experiment_args, key, value) 53 | else: 54 | raise ValueError( 55 | f"Argument {key} does not exist in the experiment arguments." 
"""Sphinx configuration for the yamle documentation."""
import datetime
import os
import shutil
import sys

# Make the package importable when Sphinx is started from the ``docs`` folder.
sys.path.insert(0, os.path.abspath(".."))

import yamle


def run_apidoc(app):
    """Generate doc stubs using sphinx-apidoc.

    Args:
        app: The Sphinx application object; only ``app.srcdir`` is read.
    """
    module_dir = os.path.join(app.srcdir, "../")
    output_dir = os.path.join(app.srcdir, "_apidoc")
    # Anchored to ``srcdir`` like ``module_dir`` so that the exclusion does not
    # depend on the directory Sphinx happens to be launched from.
    excludes = [os.path.join(app.srcdir, "../setup.py")]

    # Ensure that any stale apidoc files are cleaned up first.
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)

    cmd = [
        "--separate",
        "--module-first",
        "--doc-project=API Reference",
        "-o",
        output_dir,
        module_dir,
    ]
    cmd.extend(excludes)

    # Only the import is guarded: an ImportError raised while apidoc is
    # *running* must not be mistaken for "old Sphinx layout".
    try:
        from sphinx.ext import apidoc  # Sphinx >= 1.7
    except ImportError:
        from sphinx import apidoc  # Sphinx < 1.7

        # The legacy entry point expects the program name as first argument.
        cmd.insert(0, apidoc.__file__)

    apidoc.main(cmd)


def setup(app):
    """Register our sphinx-apidoc hook."""
    app.connect("builder-inited", run_apidoc)


# Sphinx configuration below.
project = "yamle"
version = yamle.__version__
release = yamle.__version__
# Sphinx reads the ``author`` setting; the previous ``athor`` spelling was ignored.
author = "Martin Ferianc"
copyright = f"2023-{datetime.datetime.now().year}, Martin"


extensions = [
    "sphinx.ext.autosectionlabel",
    "sphinx.ext.napoleon",
    "sphinx.ext.autodoc",
    "sphinx_autodoc_typehints",
    "sphinx.ext.doctest",
    "sphinx.ext.intersphinx",
    "sphinx.ext.todo",
    "sphinx.ext.viewcode",
    "sphinx.ext.coverage",
    "hoverxref.extension",
    "sphinx_copybutton",
    "sphinxext.opengraph",
    "sphinx_paramlinks",
]
coverage_show_missing_items = True

autosectionlabel_prefix_document = True

hoverxref_auto_ref = True
hoverxref_role_types = {"ref": "tooltip"}

source_suffix = [".rst", ".md"]

master_doc = "index"

autoclass_content = "class"
autodoc_member_order = "bysource"
default_role = "py:obj"

html_theme = "pydata_sphinx_theme"
html_sidebars = {"**": ["sidebar-nav-bs"]}
html_theme_options = {
    "primary_sidebar_end": [],
    "footer_start": ["copyright"],
    "footer_end": [],
    "icon_links": [
        {
            "name": "GitHub",
            "url": "https://github.com/martinferianc/yamle",
            "icon": "fa-brands fa-square-github",
            "type": "fontawesome",
        }
    ],
    "use_edit_page_button": True,
    "collapse_navigation": True,
}
html_context = {
    "github_user": "martinferianc",
    "github_repo": "yamle",
    "github_version": "main",
    "doc_path": "docs",
    "default_mode": "light",
}

htmlhelp_basename = "{}doc".format(project)

napoleon_use_rtype = False

rst_prolog = """
.. role:: python(code)
    :language: python
    :class: highlight
"""
17 | """ 18 | 19 | is_differentiable = False 20 | higher_is_better = False 21 | full_state_update = True 22 | 23 | def __init__( 24 | self, 25 | num_classes: int = 10, 26 | flatten: bool = False, 27 | ignore_indices: Optional[List[int]] = None, 28 | *args: Any, 29 | **kwargs: Any 30 | ) -> None: 31 | super().__init__(*args, **kwargs) 32 | self._num_classes = num_classes 33 | self._flatten = flatten 34 | self._ignore_indices = ignore_indices 35 | 36 | self.add_state( 37 | "confusion_matrix", 38 | default=torch.zeros((self._num_classes, self._num_classes)), 39 | dist_reduce_fx="sum", 40 | ) 41 | 42 | def _confusion_matrix(self, y_hat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 43 | """A helper function to compute the confusion matrix.""" 44 | mask = (y >= 0) & (y < self._num_classes) 45 | cm = torch.bincount( 46 | self._num_classes * y[mask].long() + y_hat[mask], 47 | minlength=self._num_classes**2, 48 | ).reshape(self._num_classes, self._num_classes) 49 | if self._ignore_indices is not None: 50 | cm[:, self._ignore_indices] = 0 51 | cm[self._ignore_indices, :] = 0 52 | return cm 53 | 54 | def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: 55 | """Update metric states with predictions and targets. 56 | 57 | Args: 58 | preds (torch.Tensor): Predictions from model (probabilities). 59 | target (torch.Tensor): Ground truth values. 60 | """ 61 | if self._flatten: 62 | preds = preds.permute(0, *(tuple(range(2, preds.ndim)) + (1,))).reshape( 63 | -1, preds.shape[1] 64 | ) 65 | target = target.flatten() 66 | preds = preds.argmax(dim=1) 67 | 68 | self.confusion_matrix += self._confusion_matrix(preds, target) 69 | 70 | def compute(self) -> torch.Tensor: 71 | """Compute the intersection over union.""" 72 | # The tiny epsilon is added to avoid division by zero. 
class NormalGammaLinear(nn.Linear):
    """This class defines the normal-gamma linear layer.

    It has 4 output features, returned in the order ``(mu, v, alpha, beta)``.
    ``mu`` is the raw linear output, ``v`` and ``beta`` are passed through a
    softplus, and ``alpha`` is ``softplus(.) + 1``.

    Args:
        in_features (int): The number of input features.
        bias (bool): Whether to use bias or not.
    """

    def __init__(self, in_features: int, bias: bool = True) -> None:
        super().__init__(in_features, 4, bias)

    def _evidence(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the softplus function to calculate evidence."""
        return F.softplus(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(mu, v, alpha, beta)`` concatenated along the last dimension."""
        output = super().forward(x)
        mu, logv, logalpha, logbeta = torch.split(output, 1, dim=-1)
        v = self._evidence(logv)
        alpha = self._evidence(logalpha) + 1.0
        beta = self._evidence(logbeta)
        return torch.cat([mu, v, alpha, beta], dim=-1)


class NormalGammaConv2d(nn.Conv2d):
    """This class defines the normal-gamma convolutional layer.

    It has 4 output channels, returned in the order ``(mu, v, alpha, beta)``.
    ``mu`` is the raw convolution output, ``v`` and ``beta`` are passed through
    a softplus, and ``alpha`` is ``softplus(.) + 1``.

    Args:
        in_channels (int): Number of input channels.
        kernel_size (Union[int, Tuple[int, int]]): Size of the convolutional kernel.
        stride (Union[int, Tuple[int, int]], optional): Stride for the convolution operation. Default: 1.
        padding (Union[int, Tuple[int, int]], optional): Zero-padding added to both sides of the input. Default: 0.
        dilation (Union[int, Tuple[int, int]], optional): Spacing between kernel elements. Default: 1.
        groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1.
        bias (bool, optional): Whether to use bias or not. Default: True.
    """

    def __init__(
        self,
        in_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            4,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def _evidence(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the softplus function to calculate evidence."""
        return F.softplus(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``(mu, v, alpha, beta)`` concatenated along the channel dimension."""
        output = super().forward(x)
        # Split into single-channel chunks. Splitting with a chunk size of
        # ``self.out_channels`` (= 4) yields one 4-channel chunk, which cannot
        # be unpacked into the four distribution parameters.
        mu, logv, logalpha, logbeta = torch.split(output, 1, dim=1)
        v = self._evidence(logv)
        alpha = self._evidence(logalpha) + 1.0
        beta = self._evidence(logbeta)
        return torch.cat([mu, v, alpha, beta], dim=1)
12 | 13 | By default it should have an input and output layer in `_input` and `_output` respectively. 14 | All the intermediate layers should be in `_layers`. 15 | The depth of the model should be in `_depth`. 16 | 17 | Args: 18 | inputs_dim (Tuple[int,...]): The input dimensions. 19 | outputs_dim (int): The output dimension. 20 | task (str): The task to perform. 21 | """ 22 | 23 | tasks = SUPPORTED_TASKS 24 | 25 | def __init__( 26 | self, 27 | inputs_dim: Tuple[int, ...], 28 | outputs_dim: int, 29 | task: str, 30 | seed: int, 31 | *args: Any, 32 | **kwargs: Any, 33 | ) -> None: 34 | super().__init__() 35 | self._inputs_dim = inputs_dim 36 | self._outputs_dim = outputs_dim 37 | assert ( 38 | task in self.tasks 39 | ), f"The task {task} is not supported. Supported tasks are {self.tasks}." 40 | self._task = task 41 | self._output: nn.Module = None 42 | self._input: nn.Module = None 43 | self._output_activation: nn.Module = None 44 | self._layers: Union[nn.ModuleList, nn.Sequential] = None 45 | 46 | self._added_method_specific_layers = False 47 | self._method: str = None 48 | self._method_kwargs: Dict[str, Any] = None 49 | self._depth: int = None 50 | self._seed = seed 51 | 52 | @abc.abstractmethod 53 | def forward(self, x: torch.Tensor) -> torch.Tensor: 54 | """This method is used to perform a forward pass of the model.""" 55 | raise NotImplementedError("The forward method must be implemented.") 56 | 57 | @abc.abstractmethod 58 | def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor: 59 | """This function is used to get the final layer output.""" 60 | raise NotImplementedError("The final_layer method must be implemented.") 61 | 62 | @classmethod 63 | def add_specific_args( 64 | cls, parent_parser: argparse.ArgumentParser 65 | ) -> argparse.ArgumentParser: 66 | """This method adds model arguments to the given parser.""" 67 | return argparse.ArgumentParser(parents=[parent_parser], add_help=False) 68 | 69 | def reset(self) -> None: 70 | """This 
def normalize_tin_validation_folder_structure(
    path: str,
    images_folder: str = "images",
    annotations_file: str = "val_annotations.txt",
) -> None:
    """Reorganise a TinyImageNet validation folder into one sub-folder per label.

    Every line of the annotation file is read as ``<image> <label> ...``; the
    image is moved from ``<path>/<images_folder>`` to ``<path>/<label>``.
    Afterwards the (now empty) image folder and the annotation file are removed.
    Calling the function again on an already reorganised folder is a no-op.

    Args:
        path (str): The validation folder.
        images_folder (str): Name of the folder with the images, relative to ``path``.
        annotations_file (str): Name of the annotation file, relative to ``path``.

    Raises:
        RuntimeError: If neither the image folder nor the annotation file exists
            and ``path`` is empty.
        AssertionError: If images are left in the image folder after all the
            annotations were processed.
    """
    images_folder = os.path.join(path, images_folder)
    annotations_file = os.path.join(path, annotations_file)

    # Neither the images nor the annotations are there any more: the folder was
    # already reorganised by a previous run, unless it is completely empty.
    if not os.path.exists(images_folder) and not os.path.exists(annotations_file):
        if not os.listdir(path):
            raise RuntimeError("Validation folder is empty.")
        return

    # ``os.sync`` only exists on Unix; skip the flush elsewhere.
    sync = getattr(os, "sync", None)

    # Parse the annotations
    with open(annotations_file) as f:
        for line in f:
            values = line.split()
            if len(values) < 2:
                # Blank or truncated line: nothing to move.
                continue
            img, label = values[0], values[1]
            label_folder = os.path.join(path, label)
            os.makedirs(label_folder, exist_ok=True)
            try:
                shutil.move(
                    os.path.join(images_folder, img), os.path.join(label_folder, img)
                )
            except FileNotFoundError:
                # Best effort: the image is missing or was moved by an earlier,
                # interrupted run.
                continue

    if sync is not None:
        sync()
    # The image folder can already be gone if a previous run was interrupted
    # right after removing it.
    if os.path.isdir(images_folder):
        leftovers = os.listdir(images_folder)
        if leftovers:
            # Raised explicitly (instead of ``assert``) so that the check guarding
            # the recursive delete below also runs under ``python -O``.
            raise AssertionError(
                f"{len(leftovers)} image(s) without annotation left in {images_folder}."
            )
        shutil.rmtree(images_folder)
    os.remove(annotations_file)
    if sync is not None:
        sync()
def _zero_on_model_device(model: nn.Module) -> torch.Tensor:
    """Return a scalar zero on the device of the first parameter of ``model``.

    Falls back to the default device when the model has no parameters.
    """
    first = next(model.parameters(), None)
    return torch.tensor(0.0, device=None if first is None else first.device)


class L1Regularizer(BaseRegularizer):
    """This is a class for L1 regularization."""

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Return the sum of the absolute values of the regularized parameters."""
        loss = _zero_on_model_device(model)
        for param in self.get_parameters(model):
            # Out-of-place so that the accumulator can be promoted to the
            # parameter dtype (an in-place add fails for float64 parameters).
            loss = loss + torch.sum(torch.abs(param))
        return loss

    def __repr__(self) -> str:
        return "L1()"


class L2Regularizer(BaseRegularizer):
    """This is a class for L2 regularization."""

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Return half the sum of the squared regularized parameters."""
        loss = _zero_on_model_device(model)
        for param in self.get_parameters(model):
            loss = loss + torch.sum(param**2)
        return loss * 0.5

    def __repr__(self) -> str:
        return "L2()"


class L1L2Regularizer(BaseRegularizer):
    """This is a class for combined L1 and L2 regularization."""

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """Return the sum of the L1 term and of the (halved) L2 term."""
        loss = _zero_on_model_device(model)
        for param in self.get_parameters(model):
            loss = loss + 0.5 * torch.sum(param**2)
            loss = loss + torch.sum(torch.abs(param))
        return loss

    def __repr__(self) -> str:
        return "L1L2()"


class WeightDecayRegularizer(BaseRegularizer):
    """This is a class for weight decay regularization.

    It is implemented in an inefficient manner to be compatible with any optimizer.

    During the ``__call__`` method, the weights at time ``t`` are cached.
    Then, during the ``on_after_training_step`` method, the weights, which were
    already updated by the optimizer, are further updated by weight decay.

    With ``w_t`` the cached weights and ``w'`` the weights after the optimizer
    step, the update applied here is:

        w_{t+1} = w' - weight * lr * w_t

    i.e. the decay strength ``weight`` is scaled by the learning rate ``lr``.
    """

    def __call__(self, model: nn.Module) -> torch.Tensor:
        """This method is used to cache all the weight values *before* the optimization step.

        This is done such that the weights can then be updated at the very end of the training batch.
        The returned loss is always zero.
        """
        for param in self.get_parameters(model):
            param._cached_weight = param.data.clone().detach()
        return _zero_on_model_device(model)

    def on_after_training_step(
        self, model: nn.Module, weight: float, lr: float, *args: Any, **kwargs: Any
    ) -> None:
        """Apply the decay ``-weight * lr * w_t`` using the weights cached in ``__call__``.

        Args:
            model (nn.Module): The model whose parameters are decayed.
            weight (float): The weight decay strength.
            lr (float): The current learning rate.
        """
        for param in self.get_parameters(model):
            cached = getattr(param, "_cached_weight", None)
            if cached is None:
                # Nothing was cached for this parameter (``__call__`` did not
                # see it), so there is nothing to decay against.
                continue
            param.data.add_(cached, alpha=-weight * lr)
            # Reset the cached weight
            del param._cached_weight

    def __repr__(self) -> str:
        return "WeightDecay()"
30 | Then the the fake quantization is applied to the model and the observer is disabled to simulate quantization. 31 | The original model is kept such that it can be recovered. 32 | """ 33 | self.save_original_model(method) 34 | self.prepare(trainer, method) 35 | 36 | trainer.calibrate() 37 | 38 | method.model.apply(torch.ao.quantization.enable_fake_quant) 39 | method.model.apply(torch.ao.quantization.disable_observer) 40 | 41 | self.save_quantized_model(method) 42 | logging.info("Model quantized.") 43 | logging.info(method.model) 44 | setattr(method, QUANTIZED_KEY, True) 45 | 46 | def prepare(self, trainer: Trainer, method: LightningModule) -> None: 47 | """This method is used to prepare the model for quantization.""" 48 | method.model.eval() 49 | self.replace_layers_for_quantization(method.model) 50 | method.model.qconfig = self.get_qconfig() 51 | method.model.train() 52 | torch.quantization.prepare_qat(method.model, inplace=True) 53 | method.model.apply(torch.ao.quantization.disable_fake_quant) 54 | logging.info("Model prepared for quantization.") 55 | logging.info(method.model) 56 | 57 | def get_qconfig(self) -> Any: 58 | """This method is used to get the quantization configuration. 59 | 60 | We use the number of activation and weight bits to create the quantization configuration. 
def _set_pruning_disabled(
    m: Union[nn.Parameter, List[nn.Parameter]], disabled: bool
) -> None:
    """Write the pruning flag on a single parameter or on every parameter of a list."""
    if isinstance(m, nn.Parameter):
        setattr(m, DISABLED_PRUNING_KEY, disabled)
    elif isinstance(m, list):
        for param in m:
            setattr(param, DISABLED_PRUNING_KEY, disabled)
    else:
        raise ValueError(
            f"The parameters should be either a list of parameters or a single parameter. Got {type(m)}."
        )


def enable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
    """Enable pruning for the given parameters.

    Args:
        m (Union[nn.Parameter, List[nn.Parameter]]): The parameters to enable pruning for.

    Raises:
        ValueError: If ``m`` is neither a parameter nor a list.
    """
    _set_pruning_disabled(m, False)


def disable_pruning(m: Union[nn.Parameter, List[nn.Parameter]]) -> None:
    """Disable pruning for the given parameters.

    Args:
        m (Union[nn.Parameter, List[nn.Parameter]]): The parameters to disable pruning for.

    Raises:
        ValueError: If ``m`` is neither a parameter nor a list.
    """
    _set_pruning_disabled(m, True)


def is_layer_prunable(layer: nn.Module) -> bool:
    """Check if a layer is prunable.

    A layer is prunable when it is an ``nn.Linear`` or an ``nn.Conv2d``,
    including subclasses of either.

    Args:
        layer (nn.Module): The layer to check.
    """
    # ``isinstance`` already accepts subclasses.
    return isinstance(layer, (nn.Linear, nn.Conv2d))


def is_parameter_prunable(param: Union[nn.Parameter, List[nn.Parameter]]) -> bool:
    """Check if a parameter is prunable.

    A parameter is prunable unless pruning was disabled for it; a list is
    prunable only when every parameter in it is.

    Args:
        param (Union[nn.Parameter, List[nn.Parameter]]): The parameter(s) to check.

    Raises:
        ValueError: If ``param`` is neither a parameter nor a list.
    """
    if isinstance(param, nn.Parameter):
        # A parameter that was never flagged is prunable.
        return not getattr(param, DISABLED_PRUNING_KEY, False)
    if isinstance(param, list):
        return all(is_parameter_prunable(p) for p in param)
    raise ValueError(
        f"The parameters should be either a list of parameters or a single parameter. Got {type(param)}."
    )


def get_all_prunable_weights(module: nn.Module) -> torch.Tensor:
    """Get all the prunable weights in the model.

    All the prunable parameters of the prunable layers are flattened and
    concatenated into a single 1-D Tensor. An empty Tensor is returned when
    there is nothing to prune.

    Args:
        module (nn.Module): The model to get the weights from.
    """
    weights = [
        # ``reshape`` also works for non-contiguous (e.g. channels-last) weights.
        p.data.reshape(-1)
        for m in module.modules()
        if is_layer_prunable(m)
        for p in m.parameters()
        if is_parameter_prunable(p)
    ]
    if not weights:
        # ``torch.cat`` refuses an empty list.
        return torch.empty(0)
    return torch.cat(weights)


def get_all_prunable_modules(module: nn.Module) -> List[nn.Module]:
    """Get all the prunable layers in the model.

    Args:
        module (nn.Module): The model to get the layers from.
    """
    return [m for m in module.modules() if is_layer_prunable(m)]
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
def _device_of(obj: Any) -> Any:
    """Return the device of a tensor or of a module's first parameter, else ``None``."""
    if isinstance(obj, torch.Tensor):
        return obj.device
    if isinstance(obj, nn.Module):
        first = next(obj.parameters(), None)
        return None if first is None else first.device
    return None


class BaseRegularizer(ABC):
    """This is a general class for regularizers applied to the model (L1, L2, etc.)."""

    def __call__(self, model: Union[nn.Module, torch.Tensor]) -> torch.Tensor:
        """This method is used to calculate the regularization loss.

        The base implementation returns a scalar zero on the device of
        ``model`` (the default device if it cannot be determined).
        """
        return torch.tensor(0.0, device=_device_of(model))

    def get_parameters(self, model: nn.Module) -> List[nn.Parameter]:
        """This method is used to get the parameters of the model that should be regularized.

        These are the trainable parameters for which the regularizer was not disabled.
        """
        return [
            param
            for param in model.parameters()
            if param.requires_grad
            and not getattr(param, DISABLED_REGULARIZER_KEY, False)
        ]

    def get_names(self, model: nn.Module) -> Tuple[List[str], List[str]]:
        """This method is used to get the names of the parameters of the model that should and should not be regularized.

        Returns:
            Tuple[List[str], List[str]]: The regularized names and the not regularized names.
        """
        regularized = []
        not_regularized = []
        for name, param in model.named_parameters():
            disabled = getattr(param, DISABLED_REGULARIZER_KEY, False)
            if param.requires_grad and not disabled:
                regularized.append(name)
            else:
                # Frozen parameters and explicitly excluded ones.
                not_regularized.append(name)
        return regularized, not_regularized

    def on_after_training_step(
        self, model: nn.Module, *args: Any, **kwargs: Any
    ) -> None:
        """This method is used to update the model after a given training step.

        It can be used to implement a weight decay strategy, e.g. update the weights after each training
        batch by decaying them with a given factor multiplied by the learning rate.
        """
        pass

    def on_after_backward(self, model: nn.Module, *args: Any, **kwargs: Any) -> None:
        """This method is used to update the model after the backward pass.

        It can be used to update the model after the backward pass, e.g. add noise to the gradients.
        """
        pass

    def on_after_train_epoch(self, model: nn.Module, *args: Any, **kwargs: Any) -> None:
        """This method is used to update the model after a given training epoch.

        It can be used to add noise to the model after each training epoch.
        """
        pass

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser."""
        parser = argparse.ArgumentParser(parents=[parent_parser], add_help=False)
        return parser

    def __repr__(self) -> str:
        return "Regularizer()"


class DummyRegularizer(BaseRegularizer):
    """This is a class for a dummy regularizer that does nothing."""

    def __call__(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """This method is used to calculate the regularization loss.

        It always returns a scalar zero. The zero is placed on the device of the
        first tensor or module found in the arguments, and on the default
        device when there is none.
        """
        device = None
        for arg in (*args, *kwargs.values()):
            device = _device_of(arg)
            if device is not None:
                break
        return torch.tensor(0.0, device=device)

    def __repr__(self) -> str:
        return "DummyRegularizer()"
yamle.methods.augmentation_classification import ( 6 | CutOutImageClassificationMethod, 7 | CutMixImageClassificationMethod, 8 | MixUpImageClassificationMethod, 9 | RandomErasingImageClassificationMethod, 10 | ) 11 | from yamle.methods.contrastive import SimCLRVisionMethod 12 | from yamle.methods.mimo import ( 13 | MIMOMethod, 14 | MixMoMethod, 15 | DataMUXMethod, 16 | UnMixMoMethod, 17 | MIMMOMethod, 18 | SAEMethod, 19 | MixVitMethod, 20 | ) 21 | from yamle.methods.pe import PEMethod 22 | from yamle.methods.mcdropout import ( 23 | MCDropoutMethod, 24 | MCDropConnectMethod, 25 | MCStandOutMethod, 26 | MCDropBlockMethod, 27 | MCStochasticDepthMethod, 28 | ) 29 | from yamle.methods.ensemble import ( 30 | EnsembleMethod, 31 | SnapsotEnsembleMethod, 32 | GradientBoostingEnsembleMethod, 33 | ) 34 | from yamle.methods.moe import ( 35 | MultiHeadEnsembleMethod, 36 | MixtureOfExpertsMethod, 37 | ) 38 | from yamle.methods.dun import DUNMethod 39 | from yamle.methods.early_exit import EarlyExitMethod 40 | from yamle.methods.sngp import SNGPMethod 41 | from yamle.methods.be import BEMethod 42 | from yamle.methods.temperature_scaling import TemperatureMethod 43 | from yamle.methods.rbnn import RBNNMethod 44 | from yamle.methods.svi import ( 45 | SVIRTMethod, 46 | SVILRTMethod, 47 | SVIFlipOutRTMethod, 48 | SVIFlipOutDropConnectMethod, 49 | SVILRTVDMethod, 50 | ) 51 | from yamle.methods.delta_uq import DeltaUQMethod 52 | from yamle.methods.gp import GPMethod 53 | from yamle.methods.evidential_regression import ( 54 | EvidentialRegressionMethod, 55 | ) 56 | from yamle.methods.sgld import SGLDMethod 57 | from yamle.methods.laplace import LaplaceMethod 58 | from yamle.methods.swag import SWAGMethod 59 | 60 | AVAILABLE_METHODS = { 61 | "base": BaseMethod, 62 | "simclrvision": SimCLRVisionMethod, 63 | "cutout": CutOutImageClassificationMethod, 64 | "cutmix": CutMixImageClassificationMethod, 65 | "mixup": MixUpImageClassificationMethod, 66 | "random_erasing": 
RandomErasingImageClassificationMethod, 67 | "mimo": MIMOMethod, 68 | "mimmo": MIMMOMethod, 69 | "sae": SAEMethod, 70 | "mixmo": MixMoMethod, 71 | "mixvit": MixVitMethod, 72 | "unmixmo": UnMixMoMethod, 73 | "datamux": DataMUXMethod, 74 | "pe": PEMethod, 75 | "svirt": SVIRTMethod, 76 | "svilrt": SVILRTMethod, 77 | "svilrtvd": SVILRTVDMethod, 78 | "sviflipout_gaussian": SVIFlipOutRTMethod, 79 | "sviflipout_dropconnect": SVIFlipOutDropConnectMethod, 80 | "mcdropout": MCDropoutMethod, 81 | "mcdropconnect": MCDropConnectMethod, 82 | "mcstandout": MCStandOutMethod, 83 | "mcdropblock": MCDropBlockMethod, 84 | "mcstochasticdepth": MCStochasticDepthMethod, 85 | "ensemble": EnsembleMethod, 86 | "snapshot_ensemble": SnapsotEnsembleMethod, 87 | "multi_head_ensemble": MultiHeadEnsembleMethod, 88 | "mixture_of_experts": MixtureOfExpertsMethod, 89 | "gradient_boosting_ensemble": GradientBoostingEnsembleMethod, 90 | "sgld": SGLDMethod, 91 | "dun": DUNMethod, 92 | "swag": SWAGMethod, 93 | "early_exit": EarlyExitMethod, 94 | "sngp": SNGPMethod, 95 | "be": BEMethod, 96 | "temperature": TemperatureMethod, 97 | "rbnn": RBNNMethod, 98 | "delta_uq": DeltaUQMethod, 99 | "gp": GPMethod, 100 | "laplace": LaplaceMethod, 101 | "evidential_regression": EvidentialRegressionMethod, 102 | } 103 | 104 | 105 | def method_factory(method_type: str) -> Type[BaseMethod]: 106 | """This function is used to create a method instance based on the method type. 107 | 108 | Args: 109 | method_type (str): The type of method to create. 
class SurrogateDataset(Dataset):
    """Dataset wrapper that applies transforms when an item is fetched.

    Wrapping a dataset this way lets the transforms be attached after the dataset has
    been split into training and validation parts.

    Args:
        dataset (Dataset): Dataset to wrap.
        transform (Optional[Callable]): Transformation applied to the data only.
        target_transform (Optional[Callable]): Transformation applied to the target only.
        joint_transform (Optional[Callable]): Transformation taking `(data, target)` and
            returning the transformed pair. It runs before the other two.
    """

    def __init__(
        self,
        dataset: Dataset,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        joint_transform: Optional[Callable] = None,
    ) -> None:
        self._dataset = dataset
        self._joint_transform = joint_transform
        self._transform = transform
        self._target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of items of the wrapped dataset."""
        return len(self._dataset)

    def __getitem__(self, index: int) -> Any:
        """Fetch item `index` and run it through the configured transforms."""
        sample, label = self._dataset[index]
        # Order matters: joint transform first, then the individual ones.
        if self._joint_transform is not None:
            sample, label = self._joint_transform(sample, label)
        sample = sample if self._transform is None else self._transform(sample)
        label = label if self._target_transform is None else self._target_transform(label)
        return sample, label
77 | min_angle (float): Minimum angle to rotate the image by. Defaults to 0 degrees. 78 | seed (int): Seed for the random number generator. Defaults to 42. 79 | """ 80 | 81 | def __init__( 82 | self, 83 | dataset: Dataset, 84 | max_angle: float = 90, 85 | min_angle: float = 0, 86 | seed: int = 42, 87 | ) -> None: 88 | self._dataset = dataset 89 | self._max_angle = max_angle 90 | self._min_angle = min_angle 91 | self._seed = seed 92 | self._generator = torch.Generator().manual_seed(self._seed) 93 | 94 | def __getitem__(self, index: int) -> Any: 95 | data, _ = self._dataset[index] 96 | angle = torch.randint( 97 | self._min_angle, self._max_angle, (1,), generator=self._generator 98 | ) 99 | # Scale the angle to between 0 and 1 through min-max scaling 100 | scaled_angle = (angle - self._min_angle) / (self._max_angle - self._min_angle) 101 | data = F.rotate(data, angle.item()) 102 | return data, scaled_angle 103 | 104 | def __len__(self) -> int: 105 | return len(self._dataset) 106 | -------------------------------------------------------------------------------- /docs/extending_yamle/index.rst: -------------------------------------------------------------------------------- 1 | .. _extending_yamle: 2 | 3 | ******************* 4 | Extending YAMLE 5 | ******************* 6 | 7 | This section covers the extension of YAMLE for :py:mod:`BaseModel `, :py:mod:`BaseMethod ` and :py:mod:`BaseDataModule `. 8 | 9 | The :py:mod:`BaseModel ` is the base class for all models in YAMLE. It provides the architecture of the model and its forward pass. It defines several components, such as the input and output layers, which can be modified by the :py:mod:`BaseMethod `. The goal is to write general and configurable implementations of a model that can be used across different datasets and tasks. For example, if defining a multi-layer perceptron, the :py:mod:`BaseModel ` should be configurable to different widths, depths, and activation functions. 
10 | 11 | Please see the :ref:`extending_model` section for more details on how to extend YAMLE for new models. 12 | 13 | The :py:mod:`BaseDataModule ` is the base class for all data modules in YAMLE. It is responsible for downloading, loading, and preprocessing data. It defines the task, e.g., classification or regression, to be solved by the :py:mod:`BaseMethod ` and :py:mod:`BaseModel ` and handles data splitting into training, validation, and test sets. It also defines the data input and output dimensions, which can be used to modify the :py:mod:`BaseModel ` by the :py:mod:`BaseMethod `. 14 | 15 | Please see the :ref:`extending_datamodule` section for more details on how to extend YAMLE for new data modules. 16 | 17 | The :py:mod:`BaseMethod ` is the base class for all methods in YAMLE. It defines the interface that can optionally change the model and specifies the training, validation, and test steps by reusing PyTorch Lightning's functionality. For instance, it can be used to implement a new training algorithm by overriding the :py:meth:`_training_step `, :py:meth:`_validation_step `, and :py:meth:`_test_step ` methods. Depending on the provided :py:mod:`BaseDataModule `, it automatically decides which algorithmic metrics are relevant and logs them at the end of each epoch through callbacks provided by PyTorch Lightning. The validation metrics are automatically passed to Syne Tune for hyperparameter optimization if desired. The :py:mod:`BaseMethod ` also considers the loss function, optimizer, and regularization during training, and can incorporate :py:mod:`BasePruner ` and :py:mod:`BaseQuantizer ` during evaluation. 18 | 19 | Please see the :ref:`extending_method` section for more details on how to extend YAMLE for new methods.
class Ensemble(nn.Module):
    """This class defines an ensemble of any models.

    During training, only the member selected by the `currently_trained_member` buffer
    is used by default and exposed through `parameters` / `named_parameters`.

    The copying of the model is done via `copy.deepcopy`.
    The initialization of the model ensemble is done through the `reset_parameters` method
    of every submodule that provides one.

    Args:
        model (nn.Module): The model to initially copy for the ensemble.
        num_members (int): The number of members in the ensemble.
    """

    def __init__(self, model: nn.Module, num_members: int) -> None:
        super(Ensemble, self).__init__()
        assert (
            num_members > 0
        ), "The number of members in the ensemble must be positive."
        self._models = nn.ModuleList([copy.deepcopy(model) for _ in range(num_members)])
        # Re-initialize every copy so that the members do not start from identical weights.
        for member in self._models:
            for layer in member.modules():
                if hasattr(layer, "reset_parameters"):
                    layer.reset_parameters()
                else:
                    # `warn` is a deprecated alias; `warning` is the supported spelling.
                    logging.warning(
                        f"Module {layer} does not have a `reset_parameters` method."
                    )
        self._num_members = num_members
        # Stored as a buffer so that the index is part of the module's state dict.
        self.register_buffer("currently_trained_member", torch.tensor(0))

    def forward(
        self, x: torch.Tensor, current_member: Optional[int] = None
    ) -> torch.Tensor:
        """This method is used to perform a forward pass through a single member.

        Args:
            x (torch.Tensor): The input passed to the selected member.
            current_member (Optional[int]): Index of the member to use. When `None`, the
                member stored in `currently_trained_member` is used.

        Returns:
            torch.Tensor: The output of the selected member.
        """
        if current_member is not None:
            assert (
                current_member >= 0 and current_member < self._num_members
            ), "The current member index is out of bounds."
            return self._models[current_member](x)
        return self._models[self.currently_trained_member.item()](x)

    def parameters(
        self, recurse: bool = True, index: Optional[int] = None
    ) -> Iterator[nn.Parameter]:
        """This method is used to get the parameters of the current model or all models.

        In training mode only one member is exposed: the one at `index` if given,
        otherwise the currently trained one. In evaluation mode the parameters of all
        members are returned and `index` is ignored.
        """
        if self.training:
            if index is not None:
                assert (
                    index >= 0 and index < self._num_members
                ), f"The index is out of bounds. It must be between 0 and {self._num_members - 1}."
                return self[index].parameters(recurse=recurse)
            return self[self.currently_trained_member.item()].parameters(
                recurse=recurse
            )
        else:
            return self._models.parameters(recurse=recurse)

    def named_parameters(
        self, prefix: str = "", recurse: bool = True, index: Optional[int] = None
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """This method is used to get the named parameters of the current model or all models.

        The member selection follows the same rules as `parameters`.
        """
        if self.training:
            if index is not None:
                assert (
                    index >= 0 and index < self._num_members
                ), f"The index is out of bounds. It must be between 0 and {self._num_members - 1}."
                return self[index].named_parameters(prefix=prefix, recurse=recurse)
            return self[self.currently_trained_member.item()].named_parameters(
                prefix=prefix, recurse=recurse
            )
        else:
            return self._models.named_parameters(prefix=prefix, recurse=recurse)

    def increment_current_member(self) -> None:
        """This method is used to increment the current member index in place."""
        self.currently_trained_member.data.add_(1)

    def reset(self) -> None:
        """This method is used to reset a model after an epoch. It is a no-op here."""
        pass

    def __getitem__(self, index: int) -> nn.Module:
        """This method is used to get the model at the given index."""
        return self._models[index]

    def __len__(self) -> int:
        """This method is used to get the number of models in the ensemble."""
        return len(self._models)
literalinclude:: ../../yamle/models/model.py 10 | :language: python 11 | :pyobject: BaseModel 12 | 13 | Each model which is added to YAMLE needs to inherit from the :py:mod:`BaseModel ` class. The :py:mod:`BaseModel ` class provides a number of methods which let the model interact with a method :py:mod:`BaseMethod ` and a datamodule :py:mod:`BaseDataModule `. 14 | 15 | Note that each model needs to be able to accept the :code:`inputs_dim`, :code:`outputs_dim` and :code:`task` arguments, which automatically decide the number of inputs and outputs for the model. The :code:`task` is a string which is used to determine the type of task the model is being used for. The task usually determines the output activation, for example softmax for classification and exponential applied to one of the outputs in regression to model the variance. 16 | 17 | It is expected that the very first learnable layer will be in the :py:attr:`_input ` attribute and the very last learnable layer will be in the :py:attr:`_output ` attribute. The output activation is expected to be in the :py:attr:`_output_activation ` attribute. This makes it possible to easily extract the model's input and output layers and the output activation if needed by the underlying :py:mod:`BaseMethod `. 18 | 19 | There are also other functions which can be used to define the exact behaviour when quantising the model, to reset the model each training epoch, or to add a method-specific layer to the model while keeping the backbone of the model the same. These are all optional and can be overridden if needed. 20 | 21 | The most important methods are the :py:meth:`forward ` or :py:meth:`final_layer ` which specify the forward pass of the model or the processing of the last hidden features with respect to the output layer and the output activation. 22 | 23 | A concrete example is a fully connected network with multiple hidden layers :py:mod:`FC `. 24 | 25 | ..
literalinclude:: ../../yamle/models/fc.py 26 | :language: python 27 | :pyobject: FCModel 28 | 29 | Notice the implementation of the :code:`_input` and :code:`_output` layers which automatically take into account the input and output dimensions passed down by the datamodule that has been chosen to run the experiment with. The :code:`_output_activation` is also automatically chosen based on the task. 30 | 31 | Notice that the :py:meth:`forward ` method takes in extra keyword arguments, e.g. to output the hidden representation of each hidden layer; this is used by certain specific methods along with the function to add extra layers for some specific methods. 32 | 33 | To specify the arguments of the model there is the function :py:meth:`add_specific_args ` which is used to add the arguments of the model to the :py:mod:`ArgumentParser ` of the experiment in the command line. This is used to specify the number of hidden layers, the activation or the width of the network. 34 | 35 | The model also uses some general layers such as :py:mod:`LinearNormActivation ` which is a linear layer followed by a normalisation layer and an activation layer. This class is also used in other models since it is quite general. If you feel that you will be using/implementing a general layer, place it in the :py:mod:`operations ` module. For a method-specific layer, place it in the :py:mod:`specific ` folder. 36 | 37 | The last step is to register the new model in the :py:mod:`__init__ ` file of the :py:mod:`models ` module. This is done by adding the model to the following list: 38 | 39 | ..
def enable_spectral_normalization(model: torch.nn.Module, coeff: float) -> None:
    """Wrap every leaf module of `model` with `spectral_norm`, in place.

    Children that themselves have children are descended into recursively; children
    without sub-children are replaced by `spectral_norm(child, coeff)`.

    Args:
        model (torch.nn.Module): The model to enable spectral normalization for.
        coeff (float): The coefficient forwarded to `spectral_norm`.
    """
    for child_name, submodule in model.named_children():
        is_container = len(list(submodule.children())) > 0
        if is_container:
            enable_spectral_normalization(submodule, coeff)
            continue
        setattr(model, child_name, spectral_norm(submodule, coeff))


class SNGPMethod(BaseMethod):
    """Method extension where prediction follows:
    Simple and Principled Uncertainty Estimation with Deterministic Deep Learning via Distance Awareness

    On construction, spectral normalization is enabled for the model through
    `enable_spectral_normalization`, and the linear `._output` layer is replaced by an
    `RFF` layer with the same input and output feature sizes.

    Args:
        m (float): The gamma for exponential moving average for updating the precision matrix.
        random_features (int): The number of random features to use in the RFF layer.
        mean_field_factor (float): The factor to use for the mean field approximation.
        coeff (float): The coefficient for the spectral normalization.
    """

    tasks = [CLASSIFICATION_KEY, SEGMENTATION_KEY]

    def __init__(
        self,
        m: float = 0.99,
        random_features: int = 512,
        mean_field_factor: float = 1.0,
        coeff: float = 1.0,
        **kwargs: Any
    ) -> None:
        super().__init__(**kwargs)
        enable_spectral_normalization(self.model, coeff=coeff)
        head = self.model._output
        assert isinstance(
            head, torch.nn.Linear
        ), "The output layer must be a linear layer"
        # Swap the linear head for a random-feature layer of matching dimensions.
        self.model._output = RFF(
            head.in_features,
            head.out_features,
            random_features,
            mean_field_factor,
            m,
        )

    def on_train_epoch_start(self) -> None:
        """Raise the `_final_epoch` flag of the output layer when the last epoch begins."""
        last_epoch = self.trainer.max_epochs - 1
        if self.current_epoch == last_epoch:
            self.model._output._final_epoch = True
        return super().on_train_epoch_start()

    def on_train_epoch_end(self) -> None:
        """After the flagged final epoch, clear the flag and compute the covariance."""
        output_layer = self.model._output
        if output_layer._final_epoch:
            output_layer._final_epoch = False
            output_layer.compute_covariance()
        return super().on_train_epoch_end()

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the specific arguments for the SNGP method."""
        parser = super(SNGPMethod, SNGPMethod).add_specific_args(parent_parser)
        options = (
            (
                "--method_m",
                float,
                0.99,
                "The gamma for exponential moving average for updating the precision matrix.",
            ),
            (
                "--method_random_features",
                int,
                512,
                "The number of random features to use in the RFF layer.",
            ),
            (
                "--method_mean_field_factor",
                float,
                1.0,
                "The factor to use for the mean field approximation.",
            ),
            (
                "--method_coeff",
                float,
                1.0,
                "The coefficient for the spectral normalization.",
            ),
        )
        for flag, value_type, default_value, description in options:
            parser.add_argument(
                flag, type=value_type, default=default_value, help=description
            )
        return parser
@torch.no_grad()
def get_sample_input_and_target(
    method: "LightningModule", batch_size: Optional[int] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build a sample `(input, target)` pair matching the method's declared shapes.

    Shapes are read from `method._inputs_dim` and `method._targets_dim`, dtypes from
    `method._inputs_dtype` and `method._outputs_dtype`. Both tensors are moved to the
    device of the first parameter of `method.model`.

    Args:
        method (LightningModule): The method providing the shape/dtype attributes and the model.
        batch_size (Optional[int]): If given, replaces the leading dimension of the input
            shape and is used as the leading dimension of the target.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The sample input and the sample target.

    Raises:
        ValueError: If the input or output dtype is neither `torch.float` nor `torch.long`.
    """

    def on_model_device(tensor: torch.Tensor) -> torch.Tensor:
        # Looked up lazily so that an unsupported dtype is reported before the
        # model parameters are touched.
        return tensor.to(next(method.model.parameters()).device)

    x_shape = method._inputs_dim
    if batch_size is not None:
        x_shape = (batch_size, *x_shape[1:])

    if method._inputs_dtype == torch.float:
        x = on_model_device(torch.ones(x_shape))
    elif method._inputs_dtype == torch.long:
        x = on_model_device(torch.randint(0, 1, x_shape))
    else:
        raise ValueError(f"Input dtype {method._inputs_dtype} is not supported.")

    if batch_size is None:
        batch_size = method._inputs_dim[0]
    target_dim = method._targets_dim
    if isinstance(target_dim, (tuple, list)):
        y_shape = (batch_size, *target_dim)
    else:
        y_shape = (batch_size, target_dim)

    if method._outputs_dtype == torch.float:
        y = on_model_device(torch.randn(y_shape))
    elif method._outputs_dtype == torch.long:
        y = on_model_device(torch.randint(0, 1, y_shape))
        if method._targets_dim == 1:
            # A single integer target per sample is flattened to shape `(batch_size,)`.
            y = y.view(-1)
    else:
        raise ValueError(f"Output dtype {method._outputs_dtype} is not supported.")
    return x, y
@torch.no_grad()
def trace_input_output_shapes(method: LightningModule) -> None:
    """This method is used to trace the input and output shapes of the model.

    A forward hook (`forward_shape_hook`) is attached to every module of `method.model`,
    a sample batch is pushed through `method.test_step`, and the hooks are removed again.
    Additionally, it will name all the modules in the model.

    The hooks are removed and the method is put back into training mode even when the
    traced step raises, so a failed trace does not leave stale hooks behind.

    Args:
        method (LightningModule): The method whose model should be traced.
    """
    method.eval()
    hooks = [
        module.register_forward_hook(forward_shape_hook)
        for module in method.model.modules()
    ]
    try:
        batch = get_sample_input_and_target(method)
        method.test_step(batch, batch_idx=0)
    finally:
        for hook in hooks:
            hook.remove()
        method.train()

    name_all_modules(method.model)
the shared arguments between training and evaluation to the given parser.""" 15 | parser.add_argument( 16 | "--label", 17 | type=str, 18 | default=None, 19 | help="An optional label to be added to the experiment name.", 20 | ) 21 | 22 | parser.add_argument( 23 | "--model", 24 | type=str, 25 | default="fc", 26 | choices=AVAILABLE_MODELS.keys(), 27 | help="The model to be used for training.", 28 | ) 29 | parser.add_argument( 30 | "--method", 31 | type=str, 32 | default="base", 33 | choices=AVAILABLE_METHODS.keys(), 34 | help="The method to be used for testing.", 35 | ) 36 | parser.add_argument( 37 | "--loss", 38 | type=str, 39 | default=None, 40 | choices=AVAILABLE_LOSSES.keys(), 41 | help="The loss to be used for training.", 42 | ) 43 | parser.add_argument( 44 | "--regularizer", 45 | type=str, 46 | default=None, 47 | choices=AVAILABLE_REGULARIZERS.keys(), 48 | help="The regularizer to be used for training.", 49 | ) 50 | parser.add_argument( 51 | "--datamodule", 52 | type=str, 53 | default="mnist", 54 | choices=AVAILABLE_DATAMODULES.keys(), 55 | help="The data to be used for training.", 56 | ) 57 | parser.add_argument( 58 | "--trainer", 59 | type=str, 60 | default="base", 61 | choices=AVAILABLE_TRAINERS.keys(), 62 | help="The trainer to be used for training.", 63 | ) 64 | parser.add_argument( 65 | "--pruner", 66 | type=str, 67 | default="none", 68 | choices=AVAILABLE_PRUNERS.keys(), 69 | help="The pruner to be used for evaluation.", 70 | ) 71 | parser.add_argument( 72 | "--quantizer", 73 | type=str, 74 | default="none", 75 | choices=AVAILABLE_QUANTIZERS.keys(), 76 | help="The quantizer to be used for evaluation.", 77 | ) 78 | parser.add_argument_group("Experiment") 79 | parser.add_argument( 80 | "--seed", type=int, default=42, help="The seed to be used for training." 
class RandomImageNoise:
    """This class creates an image made of random noise.

    With `uniform` noise each pixel is sampled between `minimum` and `maximum`; with
    `gaussian` noise it is sampled from a normal distribution with `mean` and `std`.
    The samples are clipped to `[0, 255]` and stored as 8-bit values.

    Args:
        size (Tuple[int, int, int]): The size of the image. The shape is `(channels, height, width)`.
        minimum (torch.Tensor): The minimum value of each pixel per channel.
        maximum (torch.Tensor): The maximum value of each pixel per channel.
        mean (torch.Tensor): The mean value of each pixel per channel.
        std (torch.Tensor): The standard deviation of each pixel per channel.
        noise (str): The type of noise to use. Can be one of `uniform` or `gaussian`. Default: `uniform`.
    """

    def __init__(
        self,
        size: Tuple[int, int, int],
        minimum: torch.Tensor,
        maximum: torch.Tensor,
        mean: torch.Tensor,
        std: torch.Tensor,
        noise: str = "uniform",
    ) -> None:
        # Fail at construction time instead of crashing on `None` inside `__call__`.
        assert noise in [
            "uniform",
            "gaussian",
        ], f"Unknown noise type {noise}. Must be one of `uniform` or `gaussian`."

        self._size = size
        # Reshape the per-channel statistics to `(channels, 1, 1)` so that they
        # broadcast against an image of shape `(channels, height, width)`.
        self._minimum = minimum.reshape(-1, 1, 1)
        self._maximum = maximum.reshape(-1, 1, 1)
        self._mean = mean.reshape(-1, 1, 1)
        self._std = std.reshape(-1, 1, 1)
        self._noise = noise

    def __call__(self, x: "Image.Image") -> "Image.Image":
        """Creates a noise image; the content of the input image is not used.

        Args:
            x (Image.Image): The input image.

        Returns:
            Image.Image: The output image.

        Raises:
            ValueError: If the noise type is neither `uniform` nor `gaussian`.
        """
        if self._noise == "uniform":
            noise = np.random.uniform(self._minimum, self._maximum, size=self._size)
        elif self._noise == "gaussian":
            noise = np.random.normal(self._mean, self._std, size=self._size)
        else:
            raise ValueError(
                f"Unknown noise type {self._noise}. Must be one of `uniform` or `gaussian`."
            )

        # Casting out-of-range floats to uint8 wraps around or is platform dependent,
        # so clip to the valid 8-bit range first.
        noise = np.clip(noise, 0, 255).astype(np.uint8)

        # Handle grayscale images (single channel)
        if self._size[0] == 1:
            noise = noise.squeeze(0)  # Remove channel dimension if grayscale

        # Handle color images
        if self._size[0] == 3:
            noise = noise.transpose(1, 2, 0)  # Transpose to (H, W, C)

        return Image.fromarray(noise)
class SGLDMethod(EnsembleMethod):
    """Method class for training with a Stochastic Gradient Langevin Dynamics optimizer.

    It uses the `Ensemble` model to wrap around the original model and then uses the base method
    to train the network via the SGLD (or pSGLD) optimizer.

    At predefined epochs, the weights of the member at index 0 are copied into the next
    member of the ensemble, so the members collectively hold weight samples taken during training.
    The last sample is always at the last epoch.

    Args:
        sampling_epochs (List[int], optional): Epochs at which to sample (default: [0, 1, 2, 3, 4, 5, 10, 20, 50, 100]).
            Its length must equal the number of ensemble members.
    """

    def __init__(
        self,
        sampling_epochs: List[int] = [0, 1, 2, 3, 4, 5, 10, 20, 50, 100],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)
        # Store a private copy: the default list object is shared between all calls,
        # and a caller-owned list should not be aliased either.
        self._sampling_epochs = list(sampling_epochs)
        assert self._num_members == len(
            self._sampling_epochs
        ), f"Number of sampling epochs ({len(self._sampling_epochs)}) must match number of ensemble members ({self._num_members})."
        assert self.hparams.optimizer in [
            "sgld",
            "psgld",
        ], f"Optimizer must be 'sgld' or 'psgld', not {self.hparams.optimizer}."

    def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
        """This method is used to perform a forward pass of the model.

        In validation it is done with respect to all models that have been trained.
        In training only the first member is used.
        """
        if self.training:
            # Deliberately skip `EnsembleMethod._predict` and use the implementation
            # that follows it in the MRO, pinned to member 0.
            return super(EnsembleMethod, self)._predict(
                x, current_member=0, **forward_kwargs
            )
        return super()._predict(x, **forward_kwargs)

    def get_parameters(self, recurse: bool = True) -> List[torch.nn.Parameter]:
        """A helper function to get the parameters of a single ensemble member.

        In this case, get always the first one.
        """
        return list(self.model.parameters(index=0, recurse=recurse))

    def _step(
        self,
        batch: List[torch.Tensor],
        batch_idx: int,
        optimizer_idx: Optional[int] = None,
        phase: str = TRAIN_KEY,
    ) -> Dict[str, Any]:
        """Run the parent step and rescale the loss by the training set size.

        Returns:
            Dict[str, Any]: The parent's output with `LOSS_KEY` multiplied by the
            number of training samples.
        """
        # NOTE(review): this previously called `super().step(...)`, which is not the
        # hook being overridden here; delegate to the parent's `_step` instead.
        output = super()._step(batch, batch_idx, optimizer_idx, phase)
        # A batch is an approximation of the whole dataset, so we need to scale the loss
        # by the size of the dataset (training set size)
        output[LOSS_KEY] = output[LOSS_KEY] * self._datamodule.train_dataset_size()
        return output

    def on_train_epoch_end(self) -> None:
        """This method is called at the end of each training epoch.

        If the current epoch is one of the sampling epochs (except the last one), the current
        member is incremented and the weights of member 0 are copied into it.
        """
        assert (
            self._sampling_epochs[-1] == self.trainer.max_epochs - 1
        ), f"Last sampling epoch ({self._sampling_epochs[-1]}) must match last epoch ({self.trainer.max_epochs - 1})."
        # The last sampling epoch is excluded: it corresponds to the model at index 0,
        # which needs no copy.
        if self.current_epoch in self._sampling_epochs[:-1]:
            self.increment_current_member()
            self.model[self.model.currently_trained_member.item()].load_state_dict(
                self.model[0].state_dict()
            )
        super().on_train_epoch_end()

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds arguments specific to this method to the parser."""
        parser = super(SGLDMethod, SGLDMethod).add_specific_args(parent_parser)
        parser.add_argument(
            "--method_sampling_epochs",
            type=str,
            default="[0,1,2,3,4,5,10,20,50,100]",
            help="Epochs at which to sample (default: [0,1,2,3,4,5,10,20,50,100]).",
        )
        return parser
-------------------------------------------------------------------------------- 1 | from typing import Any, Callable 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from yamle.models.model import BaseModel 7 | 8 | import logging 9 | 10 | logging = logging.getLogger("pytorch_lightning") 11 | 12 | 13 | class MIMMMOWrapper(nn.Module): 14 | """This is a wrapper for a MIMMO module which makes the predictions from any ``BaseModel``. 15 | 16 | This is to wrap the forward method which should return all the predictions from the model. 17 | 18 | Args: 19 | model (BaseModel): The model to wrap. 20 | evaluation_depth_weights_function (Callable): The function to use to compute the depth weights. 21 | """ 22 | 23 | def __init__( 24 | self, model: BaseModel, evaluation_depth_weights_function: Callable 25 | ) -> None: 26 | super().__init__() 27 | self.model = model 28 | self._evaluation_depth_weights_function = evaluation_depth_weights_function 29 | 30 | def forward(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor: 31 | """This method is used to perform a forward pass of the model.""" 32 | last_layer, stages = self.model(x, staged_output=True, **forward_kwargs) 33 | 34 | # Since the last layer uses the last hidden layer 35 | # we can remove it 36 | stages = stages[:-1] 37 | outputs = [] 38 | offset = 0 39 | for i, h in enumerate(stages): 40 | if not self.model._available_heads[i]: 41 | offset += 1 42 | continue 43 | h = self.model._reshaping_layers[i - offset](h) 44 | h = ( 45 | self.model.final_layer(h) 46 | if not self.model._additional_heads 47 | else self.model._output_activation(self.model._heads[i - offset](h)) 48 | ) 49 | # A single output has shape `(batch_size, num_members, predictions)` 50 | outputs.append(h) 51 | 52 | outputs.append(last_layer) 53 | # Note that the output shape is `(batch_size, depth, num_members, predictions)` 54 | return torch.stack(outputs, dim=1) 55 | 56 | @property 57 | def _input(self) -> nn.Module: 58 | """This property is 
used to get the input layer of the model.""" 59 | return self.model._input 60 | 61 | @property 62 | def _output(self) -> nn.Module: 63 | """This property is used to get the output layer of the model.""" 64 | return self.model._output 65 | 66 | @property 67 | def _heads(self) -> nn.ModuleList: 68 | """This property is used to get the heads of the model.""" 69 | return self.model._heads 70 | 71 | @property 72 | def _reshaping_layers(self) -> nn.ModuleList: 73 | """This property is used to get the reshaping layers of the model.""" 74 | return self.model._reshaping_layers 75 | 76 | @property 77 | def _prior_depth_weights(self) -> torch.Tensor: 78 | """This property is used to get the prior depth weights of the model.""" 79 | return self.model._prior_depth_weights 80 | 81 | @property 82 | def _depth_weights(self) -> torch.Tensor: 83 | """This property is used to get the depth weights of the model.""" 84 | return self.model._depth_weights 85 | 86 | @property 87 | def _available_heads(self) -> torch.Tensor: 88 | """This property is used to get the available heads of the model.""" 89 | return self.model._available_heads 90 | 91 | @property 92 | def _additional_heads(self) -> bool: 93 | """This property is used to get whether the model has additional heads.""" 94 | return self.model._additional_heads 95 | 96 | @property 97 | def _output_activation(self) -> nn.Module: 98 | """This property is used to get the output activation of the model.""" 99 | return self.model._output_activation 100 | 101 | @property 102 | def _depth(self) -> int: 103 | """This property is used to get the depth of the model.""" 104 | return self.model._depth 105 | 106 | def final_layer(self, x: torch.Tensor, **output_kwargs: Any) -> torch.Tensor: 107 | """This function is used to get the final layer output.""" 108 | return self.model.final_layer(x, **output_kwargs) 109 | 110 | def reset(self) -> None: 111 | """This method is used to reset the model e.g. 
class BaggingTrainer(BaseTrainer):
    """This class defines a bagging trainer which given a method and multiple data splits performs training and evaluation.

    The difference is that the data is split across training splits and the training is performed
    in parallel for each training split: all split dataloaders are handed to the trainer at once.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._train_splits = self._datamodule._train_splits

    def fit(self, results: Dict[str, Any] = None) -> float:
        """This method trains the method and the embedded model.

        Args:
            results (Dict[str, Any], optional): If given, the fit time is stored under
                `results[ALL_DATASETS_KEY][FIT_TIME_KEY]`. Defaults to `None` so that
                existing `fit()` calls keep working.

        Returns:
            float: The wall-clock time in seconds spent in the underlying trainer's `fit`.
        """
        validation_dataloader = self._datamodule.validation_dataloader()
        # `_train_splits` is treated as optional elsewhere in this file; without splits
        # fall back to a single unsplit dataloader instead of failing on `range(None)`.
        if self._train_splits is None:
            train_dataloaders = {
                f"{TRAIN_DATA_SPLIT_KEY}0": self._datamodule.train_dataloader(
                    split=None
                )
            }
        else:
            train_dataloaders = {
                f"{TRAIN_DATA_SPLIT_KEY}{i}": self._datamodule.train_dataloader(split=i)
                for i in range(self._train_splits)
            }
        start_time = time.time()
        self._trainer.fit(self._method, train_dataloaders, validation_dataloader)
        total_time = time.time() - start_time
        if results is not None:
            results[ALL_DATASETS_KEY][FIT_TIME_KEY] = total_time
        return total_time


class EnsembleTrainer(BaggingTrainer):
    """This class defines an ensemble trainer which given a method and data loaders performs training and evaluation.

    The difference is that the training is performed repeatedly for all ensemble members.
    If training data splits are provided, the training is performed with respect to the data splits
    which are separate for each ensemble member.

    If the training is defined as parallel, this trainer returns multiple train loaders to the method.
    Then it is the method's responsibility to train the ensemble members in parallel.

    Args:
        parallel (bool): Whether to train the ensemble members in parallel.
    """

    def __init__(self, parallel: bool = False, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._parallel = parallel

    def fit(self, results: Dict[str, Any]) -> float:
        """This method trains the method and the embedded model.

        It also measures the time taken for training and returns it.

        Args:
            results (Dict[str, Any]): If not `None`, the total fit time is stored under
                `results[ALL_DATASETS_KEY][FIT_TIME_KEY]`.

        Returns:
            float: The accumulated time in seconds spent inside the trainer's `fit` calls.
        """
        assert hasattr(
            self._method, "_num_members"
        ), "The method does not have a `num_members` attribute. Maybe not an ensemble model?"
        if self._train_splits is not None:
            # `_train_splits` is a count (it is compared to an int and used in `range`),
            # so it is formatted directly; calling `len()` on it would raise.
            assert (
                self._train_splits == self._method._num_members
            ), f"The number of training splits ({self._train_splits}) does not match the number of ensemble members ({self._method._num_members})."
        total_time = 0
        if not self._parallel:
            assert hasattr(
                self._method, "increment_current_member"
            ), "The method does not have a `increment_current_member` method. Maybe not an ensemble model?"
            validation_dataloader = self._datamodule.validation_dataloader()
            for i in range(self._method._num_members):
                logging.info(f"Training member {i+1} of {self._method._num_members}.")
                train_dataloader = self._datamodule.train_dataloader(
                    split=None if self._train_splits is None else i
                )
                start_time = time.time()
                self._trainer.fit(self._method, train_dataloader, validation_dataloader)
                end_time = time.time()
                # The trainer needs to be reinitialized after each training round.
                if i < self._method._num_members - 1:
                    self._method.increment_current_member()
                    self._initialize_trainer()
                total_time += end_time - start_time
            if results is not None:
                results[ALL_DATASETS_KEY][FIT_TIME_KEY] = total_time
        else:
            # The bagging trainer times the single parallel fit and records it in `results`.
            total_time = super().fit(results)
        return total_time

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method adds the specific arguments for the ensemble trainer."""
        parser = super(EnsembleTrainer, EnsembleTrainer).add_specific_args(
            parent_parser
        )
        parser.add_argument(
            "--trainer_parallel",
            type=int,
            choices=[0, 1],
            default=0,
            help="Whether to train the ensemble members in parallel.",
        )
        return parser
def extract_kwargs(args: argparse.Namespace, prefix: str) -> Dict[str, Any]:
    """Collect every argument whose name starts with `prefix`, with the prefix stripped.

    Args:
        args (argparse.Namespace): The arguments.
        prefix (str): The prefix selecting which arguments to keep.

    Returns:
        Dict[str, Any]: The selected arguments keyed by their name without the prefix.
    """
    prefix_length = len(prefix)
    return {
        name[prefix_length:]: value
        for name, value in vars(args).items()
        if name.startswith(prefix)
    }
def prepare_method_kwargs(
    args: argparse.Namespace, datamodule: BaseDataModule
) -> Dict[str, Any]:
    """Build the keyword arguments for the method.

    Starts from every `method_`-prefixed argument and then adds (or overrides) the
    entries derived from the datamodule and the remaining arguments.
    """
    method_kwargs = extract_kwargs(args, "method_")
    method_kwargs.update(
        {
            "seed": args.seed,
            "task": datamodule.task,
            "outputs_dim": datamodule.outputs_dim,
            "targets_dim": datamodule.targets_dim,
            "outputs_dtype": datamodule.outputs_dtype,
            # The batch size is prepended to the per-sample input dimensions.
            "inputs_dim": (args.datamodule_batch_size, *datamodule.inputs_dim),
            "inputs_dtype": datamodule.inputs_dtype,
            "datamodule": datamodule,
            "save_path": args.save_path,
            "metrics_kwargs": prepare_metrics_kwargs(args, datamodule),
            "model_kwargs": prepare_model_kwargs(args, datamodule),
        }
    )
    return method_kwargs
class ShrinkAndPerturbRegularizer(BaseRegularizer):
    """Shrink-and-perturb regularization applied to the model weights between epochs.

    At a given epoch frequency the weights are multiplied by the factor `l` and noise
    drawn from a normal distribution with mean 0 and standard deviation `std` is added.

    The `start_epoch` and `end_epoch` arguments bound the epochs within which the
    regularization is applied.

    It follows the paper: https://arxiv.org/pdf/1910.08475.pdf

    Args:
        l (float): The factor by which the weights are shrunk. Must lie in [0, 1].
        std (float): The standard deviation of the normal distribution from which the noise is sampled.
        start_epoch (int): The first epoch at which the regularization may be applied.
        end_epoch (int): The last epoch at which the regularization may be applied; -1 means no upper bound.
        epoch_frequency (int): The regularization is applied only at epochs divisible by this value.
    """

    def __init__(
        self,
        l: float,
        std: float,
        start_epoch: int,
        end_epoch: int,
        epoch_frequency: int,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        super().__init__(*args, **kwargs)

        # Validate every argument before storing any of them.
        assert 0 <= l <= 1, f"The shrink factor must be between 0 and 1, but got {l}."
        assert (
            std >= 0
        ), f"The standard deviation of the normal distribution must be non-negative, but got {std}."
        assert (
            start_epoch >= 0
        ), f"The start epoch must be non-negative, but got {start_epoch}."
        assert (
            end_epoch == -1 or end_epoch >= start_epoch
        ), f"The end epoch must be greater than or equal to the start epoch, but got {end_epoch} and {start_epoch}."
        assert (
            epoch_frequency > 0
        ), f"The epoch frequency must be greater than 0, but got {epoch_frequency}."

        self._l, self._std = l, std
        self._start_epoch, self._end_epoch = start_epoch, end_epoch
        self._epoch_frequency = epoch_frequency

    def on_after_train_epoch(
        self, model: nn.Module, epoch: int, *args: Any, **kwargs: Any
    ) -> None:
        """Shrink and perturb the trainable weights after a training epoch.

        Nothing happens at epoch 0, outside the `[start_epoch, end_epoch]` window, or at
        epochs that are not a multiple of the epoch frequency. Otherwise every parameter
        that requires gradients is replaced by `weights * l + noise`.
        """
        if epoch == 0 or epoch < self._start_epoch:
            return
        if self._end_epoch != -1 and epoch > self._end_epoch:
            return
        if epoch % self._epoch_frequency != 0:
            return

        for weights in model.parameters():
            if not weights.requires_grad:
                continue
            shrunk = weights.data * self._l
            perturbation = torch.randn_like(weights.data) * self._std
            weights.data = shrunk + perturbation

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser."""
        parser = super(
            ShrinkAndPerturbRegularizer, ShrinkAndPerturbRegularizer
        ).add_specific_args(parser)
        # (flag, type, default, help) for every option of this regularizer.
        option_specs = (
            (
                "--regularizer_l",
                float,
                0.1,
                "The factor by which the weights are shrunk.",
            ),
            (
                "--regularizer_std",
                float,
                0.1,
                "The standard deviation of the normal distribution from which the noise is sampled.",
            ),
            (
                "--regularizer_start_epoch",
                int,
                0,
                "The epoch at which the shrink and perturb regularization starts.",
            ),
            (
                "--regularizer_end_epoch",
                int,
                -1,
                "The epoch at which the shrink and perturb regularization ends.",
            ),
            (
                "--regularizer_epoch_frequency",
                int,
                1,
                "The frequency at which the shrink and perturb regularization is applied.",
            ),
        )
        for flag, value_type, default_value, help_text in option_specs:
            parser.add_argument(
                flag, type=value_type, default=default_value, help=help_text
            )
        return parser
class InnerProductFeatureRegularizer(BaseRegularizer):
    """This is a class for inner product regularization.

    Given a tensor `x` which can be split in dimension `dim` into `n` tensors `x_1, ..., x_n`, the regularization loss is calculated as:

    `loss = sum_{i=1}^{n} sum_{j=i+1}^{n} x_i * x_j`
    `loss = loss / (n*(n-1)/2)`

    where `x_i * x_j` is the inner product over the flattened features, averaged over the batch.
    If there are fewer than two slices there are no pairs and the loss is zero.

    Args:
        dim (int): The dimension over which split the tensor to then calculate the inner product as a cartesian product.
    """

    def __init__(self, dim: int = 1, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._dim = dim

    def _split_and_reshape_tensor_on_dim(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Split `x` into size-1 slices along `dim` and flatten each to `(batch_size, -1)`."""
        batch_size = x.shape[0]
        # `reshape` instead of `view`: slices taken along an inner dimension are
        # non-contiguous and `view` raises when the remaining dimensions cannot be merged.
        return [
            piece.squeeze(dim=self._dim).reshape(batch_size, -1)
            for piece in torch.split(x, 1, dim=self._dim)
        ]

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean pairwise inner product between the slices of `x`."""
        slices = self._split_and_reshape_tensor_on_dim(x)
        num_pairs = len(slices) * (len(slices) - 1) // 2
        loss = x.new_zeros(())
        if num_pairs == 0:
            # A single slice has no partner; avoid dividing by zero.
            return loss
        for i, first in enumerate(slices):
            for second in slices[i + 1 :]:
                loss = loss + torch.sum(first * second, dim=1).mean()
        return loss / num_pairs

    @staticmethod
    def add_specific_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """This method is used to add specific arguments to the parser."""
        parser = super(
            InnerProductFeatureRegularizer, InnerProductFeatureRegularizer
        ).add_specific_args(parser)
        parser.add_argument(
            "--regularizer_dim",
            type=int,
            default=1,
            help="The dimension over which split the tensor to then calculate the inner product as a cartesian product.",
        )
        return parser

    def __repr__(self) -> str:
        return f"InnerProductFeatureRegularizer(dim={self._dim})"
class CorrelationFeatureRegularizer(CosineSimilarityFeatureRegularizer):
    """This is a class for correlation regularization.

    Correlation is the cosine similarity between centered versions of x and y.
    Unlike the cosine, the correlation is invariant to both scale and location changes of x and y.
    """

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return the mean pairwise correlation between the slices of `x` along `dim`.

        Each slice is flattened to `(batch_size, features)` and centered over its
        features; the loss is the cosine similarity of every pair of centered slices,
        averaged over the batch and over the pairs. With fewer than two slices the
        loss is zero.
        """
        centered = [
            piece - piece.mean(dim=1, keepdim=True)
            for piece in self._split_and_reshape_tensor_on_dim(x)
        ]
        # The centered slices are always 2-D, so they are compared directly here.
        # Re-stacking them on `self._dim` and re-splitting in the parent only works
        # when that dimension is valid for a 3-D tensor and is not the batch axis.
        num_pairs = len(centered) * (len(centered) - 1) // 2
        loss = x.new_zeros(())
        if num_pairs == 0:
            return loss
        for i, first in enumerate(centered):
            for second in centered[i + 1 :]:
                loss = loss + torch.cosine_similarity(first, second, dim=1).mean()
        return loss / num_pairs

    def __repr__(self) -> str:
        return f"CorrelationFeatureRegularizer(dim={self._dim})"
code-block:: bash 13 | 14 | git clone https://github.com/martinferianc/yamle.git 15 | cd yamle 16 | pip install -e . 17 | 18 | Afterwards you can try to run the example script: 19 | 20 | .. code-block:: bash 21 | 22 | python yamle/cli/train.py --method base --trainer_devices "[0]" --datamodule mnist --datamodule_batch_size 256 --method_optimizer adam --method_learning_rate 3e-4 --regularizer l2 --method_regularizer_weight 1e-5 --loss crossentropy --save_path ./experiments --trainer_epochs 3 --model_hidden_dim 32 --model_depth 3 --datamodule_validation_portion 0.1 --save_path ./experiments --model fc --datamodule_pad_to_32 1 23 | 24 | This script trains a simple fully connected network :py:mod:`FC ` on the :py:mod:`MNIST ` dataset. It uses L2 regularization defined by :py:mod:`L2 ` and cross-entropy loss defined by :py:mod:`CrossEntropyLoss `. The model is trained for 3 epochs and the validation set is 10% of the training set. The model is saved to the :code:`./experiments` directory. All of this is grouped together through a trainer class :py:mod:`Trainer ` which executes the training, validation or testing loops. The metrics are logged automatically and the base algorithmic metrics are supplied by the function: :py:meth:`metrics_factory `. The logging is done through a PyTorch Lightning callback in the :py:mod:`LoggingCallback `. 25 | 26 | In general, YAMLE operates through the CLI where the user specifies the configuration of the experiment. The configuration is then parsed and the experiment is run. The configuration is specified through the command line arguments. The arguments are grouped into several categories. 
The most important ones are: 27 | 28 | * :code:`--method` which specifies the method and its parameters 29 | * :code:`--model` which specifies the model and its parameters 30 | * :code:`--loss` which specifies the loss and its parameters 31 | * :code:`--regularizer` which specifies the regularizer and its parameters 32 | * :code:`--datamodule` which specifies the datamodule and its parameters 33 | * :code:`--trainer` which specifies the trainer and its parameters 34 | 35 | When adding a new method, datamodule, model, regularizer etc. you will be able to define your own arguments. 36 | 37 | Once a model has been trained, it can be evaluated using: 38 | 39 | .. code-block:: bash 40 | 41 | python yamle/cli/evaluate.py --method base --trainer_devices "[0]" --datamodule mnist --datamodule_batch_size 256 --loss crossentropy --save_path ./experiments --model_hidden_dim 32 --model_depth 3 --datamodule_validation_portion 0.1 --save_path ./experiments --model fc --datamodule_pad_to_32 1 --load_path ./experiments/2023-10-23-13-11-33-546652-train-fc-mnist-base 42 | 43 | This script evaluates the model trained in the previous step. The evaluation is done on any data split specified by the datamodule. 44 | 45 | The last main feature of YAMLE is hyperparameter optimisation. It is done through the syne-tune library. The hyperparameters and their ranges are specified in a config file, e.g.: 46 | 47 | .. 
code-block:: python 48 | 49 | # config.py 50 | from syne_tune.config_space import randint, rand 51 | 52 | def config_space() -> Dict[str, Any]: 53 | return { 54 | "model_hidden_dim": randint(16, 128), 55 | "model_depth": randint(1, 5), 56 | "method": "base", 57 | "method_learning_rate": 3e-4, 58 | "method_optimizer": "adam", 59 | "method_regularizer_weight": 1e-5, 60 | "regularizer": "l2", 61 | "loss": "crossentropy", 62 | "datamodule": "mnist", 63 | "datamodule_batch_size": 256, 64 | "datamodule_validation_portion": 0.1, 65 | "datamodule_pad_to_32": 1, 66 | "trainer_epochs": 3, 67 | "save_path": "./experiments", 68 | } 69 | 70 | The config file is then passed to the hyperparameter optimisation script: 71 | 72 | .. code-block:: bash 73 | 74 | python yamle/cli/tune.py --config_file config.py --optimizer "Grid Search" --save_path ./experiments --max_wallclock_time 420 --optimization_metric "validation_accuracy" 75 | 76 | The script will run the hyperparameter optimisation and save the best model to the :code:`./experiments` directory. We encourage you to look into the tune script to see how the hyperparameter optimisation is done. 77 | 78 | In order to generate documentation from the docstrings, run: 79 | 80 | .. code-block:: bash 81 | 82 | cd docs 83 | make html 84 | 85 | The documentation will be generated in the :code:`docs/build/html` directory. 
86 | 87 | -------------------------------------------------------------------------------- /yamle/data/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Type 2 | from yamle.data.datamodule import BaseDataModule 3 | from yamle.data.classification import ( 4 | ToyTwoMoonsClassificationDataModule, 5 | ToyTwoCirclesClassificationDataModule, 6 | TorchvisionClassificationDataModuleMNIST, 7 | TorchvisionClassificationDataModuleCIFAR10, 8 | TorchvisionClassificationDataModuleCIFAR5, 9 | TorchvisionClassificationDataModuleCIFAR3, 10 | TorchvisionClassificationDataModuleCIFAR100, 11 | TorchvisionClassificationDataModuleFashionMNIST, 12 | TinyImageNetClassificationDataModule, 13 | TorchvisionClassificationDataModuleSVHN, 14 | BreastCancerUCIClassificationDataModule, 15 | AdultIncomeUCIClassificationDataModule, 16 | CarEvaluationUCIClassificationDataModule, 17 | CreditUCIClassificationDataModule, 18 | DermatologyUCIClassificationDataModule, 19 | PneumoniaMNISTClassificationDataModule, 20 | DermaMNISTClassificationDataModule, 21 | BreastMNISTClassificationDataModule, 22 | BloodMNISTClassificationDataModule, 23 | ECG5000ClassificationDataModule, 24 | ) 25 | from yamle.data.regression import ( 26 | ToyRegressionDataModule, 27 | ConcreteUCIRegressionDataModule, 28 | EnergyUCIRegressionDataModule, 29 | BostonUCIRegressionDataModule, 30 | TemperatureTimeSeriesDataModule, 31 | WineQualityUCIRegressionDataModule, 32 | YachtUCIRegressionDataModule, 33 | AbaloneUCIRegressionDataModule, 34 | TelemonitoringUCIRegressionDataModule, 35 | RetinaMNISTDataModule, 36 | WikiFaceRegressionDataModule, 37 | TorchvisionRotationRegressionDataModuleMNIST, 38 | TorchvisionRotationRegressionDataModuleCIFAR10, 39 | TorchvisionRotationRegressionDataModuleFashionMNIST, 40 | TorchvisionRotationRegressionDataModuleSVHN, 41 | TorchvisionRotationRegressionDataModuleCIFAR100, 42 | TinyImageNetRotationRegressionDataModule, 43 | ) 44 | from 
yamle.data.segmentation import TorchvisionSegmentationDataModuleCityscapes 45 | from yamle.data.text import ( 46 | TorchtextClassificationModelWikiText2, 47 | TorchtextClassificationModelWikiText103, 48 | TorchtextClassificationModelIMDB, 49 | Shakespeare, 50 | ) 51 | from yamle.data.depth import NYUv2DataModule 52 | from yamle.data.reconstruction import ECG5000ReconstructionDataModule 53 | 54 | AVAILABLE_DATAMODULES = { 55 | "mnist": TorchvisionClassificationDataModuleMNIST, 56 | "cifar3": TorchvisionClassificationDataModuleCIFAR3, 57 | "cifar5": TorchvisionClassificationDataModuleCIFAR5, 58 | "cifar10": TorchvisionClassificationDataModuleCIFAR10, 59 | "cifar100": TorchvisionClassificationDataModuleCIFAR100, 60 | "svhn": TorchvisionClassificationDataModuleSVHN, 61 | "fashionmnist": TorchvisionClassificationDataModuleFashionMNIST, 62 | "tinyimagenet": TinyImageNetClassificationDataModule, 63 | "wiki_face": WikiFaceRegressionDataModule, 64 | "pneumoniamnist": PneumoniaMNISTClassificationDataModule, 65 | "breastmnist": BreastMNISTClassificationDataModule, 66 | "retinamnist": RetinaMNISTDataModule, 67 | "dermamnist": DermaMNISTClassificationDataModule, 68 | "bloodmnist": BloodMNISTClassificationDataModule, 69 | "toyregression": ToyRegressionDataModule, 70 | "toymoons": ToyTwoMoonsClassificationDataModule, 71 | "toycircles": ToyTwoCirclesClassificationDataModule, 72 | "ecg5000classification": ECG5000ClassificationDataModule, 73 | "ecg5000reconstruction": ECG5000ReconstructionDataModule, 74 | "cityscapes": TorchvisionSegmentationDataModuleCityscapes, 75 | "wikitext2": TorchtextClassificationModelWikiText2, 76 | "wikitext103": TorchtextClassificationModelWikiText103, 77 | "imdb": TorchtextClassificationModelIMDB, 78 | "shakespeare": Shakespeare, 79 | "concrete": ConcreteUCIRegressionDataModule, 80 | "energy": EnergyUCIRegressionDataModule, 81 | "boston": BostonUCIRegressionDataModule, 82 | "wine": WineQualityUCIRegressionDataModule, 83 | "yacht": 
def replace_with_be(model: nn.Module, num_members: int) -> None:
    """Swap every `nn.Linear` / `nn.Conv2d` in `model` for its BatchEnsemble counterpart.

    The module tree is walked recursively. Each direct child that is an
    `nn.Linear` is replaced in place by a `LinearBE`, and each `nn.Conv2d` by a
    `Conv2dBE`; both receive the configuration, `weight` and `bias` of the layer
    they replace. Any other child is descended into.

    Args:
        model (nn.Module): The model whose layers are replaced (modified in place).
        num_members (int): The number of members in the ensemble.
    """
    for child_name, submodule in model.named_children():
        if isinstance(submodule, nn.Linear):
            replacement = LinearBE(
                in_features=submodule.in_features,
                out_features=submodule.out_features,
                bias=submodule.bias,
                num_members=num_members,
                weight=submodule.weight,
            )
        elif isinstance(submodule, nn.Conv2d):
            replacement = Conv2dBE(
                in_channels=submodule.in_channels,
                out_channels=submodule.out_channels,
                kernel_size=submodule.kernel_size,
                stride=submodule.stride,
                padding=submodule.padding,
                dilation=submodule.dilation,
                groups=submodule.groups,
                num_members=num_members,
                weight=submodule.weight,
                bias=submodule.bias,
            )
        else:
            # Not a supported leaf layer: look for replaceable layers further down.
            replace_with_be(submodule, num_members)
            continue
        setattr(model, child_name, replacement)
71 | 72 | If the model is in evaluation mode it replicates the inputs `num_members` times and 73 | concatenates them into the batch dimension. 74 | 75 | Args: 76 | x (torch.Tensor): The input to the model. 77 | **forward_kwargs (Any): The keyword arguments to be passed to the forward pass of the model. 78 | """ 79 | if self.evaluation: 80 | x = repeat_inputs(x, self._num_members) 81 | 82 | output = self.model(x, **forward_kwargs) 83 | 84 | if self.evaluation: 85 | output = output.reshape(-1, self._num_members, *output.shape[1:]) 86 | elif unsqueeze: 87 | output = output.unsqueeze(1) 88 | return output 89 | 90 | def _validation_test_step( 91 | self, batch: List[torch.Tensor], batch_idx: int 92 | ) -> Dict[str, Any]: 93 | """This method is used to perform a single validation/test step. 94 | 95 | It assumes that the batch has a shape `(batch_size, num_features)`. 96 | It assumes that the output of the model has a shape `(batch_size, n_samples, num_classes)`. 97 | """ 98 | x, y = batch 99 | y_hat_permember = self._predict(x) 100 | # Repeat the labels num_members times 101 | y_permember = torch.stack([y] * self._num_members, dim=1) 102 | loss = self._loss_per_member(y_hat_permember, y_permember) 103 | y_hat = average_predictions(y_hat_permember, self._task) 104 | output = { 105 | LOSS_KEY: loss, 106 | TARGET_KEY: y.detach(), 107 | PREDICTION_KEY: y_hat_permember.detach(), 108 | MEAN_PREDICTION_KEY: y_hat.detach(), 109 | TARGET_PER_MEMBER_KEY: y_permember.detach(), 110 | PREDICTION_PER_MEMBER_KEY: y_hat_permember.detach(), 111 | AVERAGE_WEIGHTS_KEY: None, 112 | } 113 | return output 114 | 115 | def _validation_step( 116 | self, batch: List[torch.Tensor], batch_idx: int 117 | ) -> Dict[str, Any]: 118 | """This method is used to perform a single validation step.""" 119 | return self._validation_test_step(batch, batch_idx) 120 | 121 | def _test_step(self, batch: List[torch.Tensor], batch_idx: int) -> Dict[str, Any]: 122 | """This method is used to perform a single test 
class GPModel(ApproximateGP):
    """Variational (approximate) Gaussian Process model.

    The model combines a Cholesky variational distribution over the inducing
    points with a `VariationalStrategy` whose inducing locations are learned.
    For the classification task the strategy is additionally wrapped in an
    `LMCVariationalStrategy` that produces `num_outputs` tasks from
    `num_latent` latent functions.

    Args:
        prior_mean (str): The prior mean function. Can be 'zero' or 'constant'.
        prior_covariance (str): The prior covariance function. Can be 'rbf', 'matern32', 'matern52'.
        inducing_points (torch.Tensor): The inducing points.
        num_latent (int): The latent dimension, used as the batch shape of the
            variational distribution, the mean and the kernels.
        num_outputs (int): The number of outputs. Only used for classification.
        task (str): The task to perform. When it equals `CLASSIFICATION_KEY` the
            variational strategy is wrapped in an `LMCVariationalStrategy`.
    """

    def __init__(
        self,
        prior_mean: str,
        prior_covariance: str,
        inducing_points: torch.Tensor,
        num_latent: int,
        num_outputs: int,
        task: str,
    ) -> None:
        batch_shape = torch.Size([num_latent])

        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=inducing_points.size(0),
            batch_shape=batch_shape,
        )
        # The strategy needs a reference to the model, so it is created before
        # `super().__init__` and handed over to it afterwards.
        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True,
        )
        if task == CLASSIFICATION_KEY:
            variational_strategy = LMCVariationalStrategy(
                variational_strategy,
                num_tasks=num_outputs,
                num_latents=num_latent,
                latent_dim=-1,
            )

        super().__init__(variational_strategy)
        assert prior_mean in ["zero", "constant"]
        assert prior_covariance in ["matern32", "matern52", "rbf"]

        # The `else` branches below still guard against unsupported values when
        # assertions are disabled (`python -O`).
        if prior_mean == "zero":
            self._prior_mean = ZeroMean(batch_shape=batch_shape)
        elif prior_mean == "constant":
            self._prior_mean = ConstantMean(batch_shape=batch_shape)
        else:
            raise ValueError(f"The prior mean function {prior_mean} is not supported.")

        if prior_covariance == "rbf":
            self._prior_covariance = ScaleKernel(
                RBFKernel(batch_shape=batch_shape),
                batch_shape=batch_shape,
            )
        elif prior_covariance == "matern32":
            self._prior_covariance = ScaleKernel(
                MaternKernel(nu=1.5, batch_shape=batch_shape),
                batch_shape=batch_shape,
            )
        elif prior_covariance == "matern52":
            self._prior_covariance = ScaleKernel(
                MaternKernel(nu=2.5, batch_shape=batch_shape),
                batch_shape=batch_shape,
            )
        else:
            raise ValueError(
                f"The prior covariance function {prior_covariance} is not supported."
            )

        self._task = task

    def forward(
        self,
        x: torch.Tensor,
        staged_output: bool = False,
        # NOTE: the two dict defaults are kept for signature compatibility; they
        # are never read or mutated in this method.
        input_kwargs: Dict[str, Any] = {},
        output_kwargs: Dict[str, Any] = {},
    ) -> MultivariateNormal:
        """This function is used to perform the forward pass of the model.

        Args:
            x (torch.Tensor): The input tensor.
            staged_output (bool): Whether to return the intermediate outputs.
                Not supported by this model, must be `False`.
            input_kwargs (Dict[str, Any]): The kwargs for the input layer. Not used in this model.
            output_kwargs (Dict[str, Any]): The kwargs for the output layer. Not used in this model.

        Returns:
            MultivariateNormal: The distribution built from the prior mean and
            the prior covariance evaluated at `x`.
        """
        assert not staged_output, "The staged output is not supported for this model."
        mean = self._prior_mean(x)
        covariance = self._prior_covariance(x)
        return MultivariateNormal(mean, covariance)

    def final_layer(self, x: torch.Tensor) -> torch.Tensor:
        """This function is used to get the final layer output.

        It is a no-op placeholder for this model and returns `None`.
        """
        pass

    def add_method_specific_layers(self, method: str) -> None:
        """This method is used to add method specific layers to the model.

        It is a no-op for this model.

        Args:
            method (str): The method to use.
        """
        pass

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """This method is used to add the model specific arguments to the parent parser.

        The model defines no arguments of its own. The parser is returned
        unchanged so that callers which rebind their parser to the result,
        e.g. `parser = model.add_specific_args(parser)`, keep a valid parser.
        """
        return parent_parser

    def reset(self) -> None:
        """This function is used to reset the model after each epoch.

        It is a no-op for this model.
        """
        pass
36 | self._similarity = similarity 37 | self._temperature = temperature 38 | 39 | def _cosine_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 40 | """Computes the cosine similarity between two tensors.""" 41 | return F.cosine_similarity(x, y, dim=-1, eps=TINY_EPSILON) 42 | 43 | def _dot_similarity(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 44 | """Computes the dot similarity between two tensors.""" 45 | return torch.matmul(x, y.transpose(-1, -2)) 46 | 47 | def _loss( 48 | self, 49 | y_hat: torch.Tensor, 50 | y_hat_positive: torch.Tensor, 51 | y_hat_negative: torch.Tensor, 52 | ) -> torch.Tensor: 53 | """Computes the NCE loss. 54 | 55 | The `y_hat` tensor contains the default predictions for `x` samples. The shape is `(batch_size, num_classes)`. 56 | The `y_hat_positive` tensor contains the predictions for the positive samples. The shape is `(batch_size, num_classes)`. 57 | The `y_hat_negative` tensor contains the predictions for the negative samples. The shape is `(batch_size, K, num_classes)`. 58 | """ 59 | assert ( 60 | y_hat.shape == y_hat_positive.shape 61 | ), f"The shapes of the predictions do not match. Got {y_hat.shape}, {y_hat_positive.shape}." 62 | assert ( 63 | y_hat.shape[0] == y_hat_negative.shape[0] 64 | ), f"The batch sizes of the predictions do not match. Got {y_hat.shape[0]}, {y_hat_negative.shape[0]}." 65 | 66 | if self._similarity == "cosine": 67 | similarity_fn = self._cosine_similarity 68 | elif self._similarity == "dot": 69 | similarity_fn = self._dot_similarity 70 | else: 71 | raise NotImplementedError( 72 | f"Similarity function {self._similarity} is not implemented." 
73 | ) 74 | 75 | similarity_positive = ( 76 | similarity_fn(y_hat, y_hat_positive) / self._temperature 77 | ).exp() 78 | similarity_negative = ( 79 | similarity_fn( 80 | y_hat.unsqueeze(1).repeat(1, y_hat_negative.shape[1], 1), y_hat_negative 81 | ) 82 | / self._temperature 83 | ).exp() 84 | 85 | loss = -torch.log( 86 | similarity_positive 87 | / ( 88 | similarity_positive 89 | + torch.sum(similarity_negative, dim=1) 90 | + TINY_EPSILON 91 | ) 92 | + TINY_EPSILON 93 | ) 94 | return loss 95 | 96 | def __call__( 97 | self, 98 | y_hat: torch.Tensor, 99 | y_hat_positive: torch.Tensor, 100 | y_hat_negative: torch.Tensor, 101 | weights: Optional[torch.Tensor] = None, 102 | ) -> torch.Tensor: 103 | """This method is used to compute the NCE loss.""" 104 | num_members = y_hat_positive.shape[1] 105 | loss = 0.0 106 | for i in range(num_members): 107 | sample_loss = self._loss( 108 | y_hat[:, i], y_hat_positive[:, i], y_hat_negative[:, i] 109 | ) 110 | loss += self._process_sample_loss(sample_loss, i, weights) 111 | 112 | return self._process_member_loss(loss, num_members) 113 | 114 | def __repr__(self) -> str: 115 | return f"NoiseContrastiveEstimatorLoss(reduction_per_sample={self._reduction_per_sample}, reduction_per_member={self._reduction_per_member}, reduction_per_feature={self._reduction_per_feature}, similarity={self._similarity})" 116 | 117 | @staticmethod 118 | def add_specific_args( 119 | parent_parser: argparse.ArgumentParser, 120 | ) -> argparse.ArgumentParser: 121 | parser = super( 122 | NoiseContrastiveEstimatorLoss, NoiseContrastiveEstimatorLoss 123 | ).add_specific_args(parent_parser) 124 | parser.add_argument( 125 | "--loss_temperature", 126 | type=float, 127 | default=1.0, 128 | help="The temperature to use for the softmax.", 129 | ) 130 | parser.add_argument( 131 | "--loss_similarity", 132 | type=str, 133 | choices=["cosine", "dot"], 134 | default="cosine", 135 | help="The similarity function to use.", 136 | ) 137 | return parser 138 | 
def train(args: ArgumentParser) -> None:
    """Run a complete training experiment described by `args`.

    The function seeds the run, creates the experiment folder and logger, builds
    the datamodule, model, loss, regularizer, method and trainer through their
    factories, optionally loads model weights from `args.load_path`, measures
    the model complexity, fits the trainer, and finally saves and evaluates the
    experiment unless disabled through `args.no_saving` / `args.no_evaluation`.

    Args:
        args: The parsed command line arguments. Despite the annotation they are
            accessed as attributes (e.g. `args.seed`), like a namespace.
            `args.save_path` is overwritten with the experiment folder and
            `args.load_path` with the save path before evaluation.

    Raises:
        ValueError: If evaluation is requested while saving is disabled.
        SystemExit: With code 0 if the trainer reports that it was interrupted.
    """
    # Fail fast: evaluation reloads the saved experiment, so the two flags are
    # incompatible. Checking up front avoids finding this out only after the
    # whole training run has finished.
    if not args.no_evaluation and args.no_saving:
        raise ValueError(
            "Cannot evaluate a model without saving it first. --no_saving must be set to False."
        )

    # Set seed
    pl.seed_everything(args.seed, workers=True)

    # Create experiment structure
    experiment_name = utils.get_experiment_name(args, mode="train")

    # Create experiment directory
    save_path = os.path.join(args.save_path, experiment_name)
    save_path = utils.create_experiment_folder(save_path, "./yamle")
    args.save_path = save_path

    # Set the logger
    utils.config_logger(args.save_path)
    logging.info("Beginning experiment: %s", experiment_name)
    logging.info("Arguments: %s", args)
    logging.info("Command arguments to reproduce: %s", utils.argparse_to_command(args))

    # Prepare the datamodule and its arguments
    datamodule_kwargs = prepare_datamodule_kwargs(args)
    datamodule = data_factory(args.datamodule)(**datamodule_kwargs)
    datamodule.prepare_data()
    datamodule.setup()

    # Prepare the model and its arguments
    model_kwargs = prepare_model_kwargs(args, datamodule)
    model = model_factory(args.model)(**model_kwargs)
    logging.info("Model: %s", model)

    # Prepare the loss and its arguments
    loss_kwargs = prepare_loss_kwargs(args, datamodule)
    loss = loss_factory(args.loss)(**loss_kwargs)
    logging.info("Loss: %s", loss)

    # Prepare the regularizer and its arguments
    regularizer_kwargs = prepare_regularizer_kwargs(args)
    regularizer = regularizer_factory(args.regularizer)(**regularizer_kwargs)
    logging.info("Regularizer: %s", regularizer)

    # Prepare the method and its arguments
    method_kwargs = prepare_method_kwargs(args, datamodule)
    method = method_factory(args.method)(
        model=model, loss=loss, regularizer=regularizer, **method_kwargs
    )
    logging.info("Method: %s", method)
    logging.info("Tracing input and output shapes")
    trace_input_output_shapes(method)

    # Create trainer
    trainer_kwargs = prepare_trainer_kwargs(args, datamodule)
    trainer = trainer_factory(args.trainer)(**trainer_kwargs, method=method)
    logging.info("Trainer: %s", trainer)

    # Train model
    results = {}
    if args.load_path is not None:
        logging.info(
            f"Loading model from {args.load_path}. Note that, only the model is loaded, not the method."
        )
        method.on_before_model_load()
        method.model = utils.load_model(method.model, utils.model_file(args.load_path))
        method.on_after_model_load()

    # Get the model complexity
    model_complexity(method.model, method, trainer.devices, results=results)

    trainer.fit(results)

    # If trainer has been interrupted e.g. by Ctrl+C terminate the experiment
    if trainer.interrupted:
        logging.info("Training interrupted. Terminating experiment.")
        # The `exit` builtin is injected by the `site` module for interactive
        # use and is not guaranteed to exist; raising SystemExit is equivalent.
        raise SystemExit(0)

    # Save all the data
    # The model is saved with respect to the method which may have done some
    # processing on it
    if not args.no_saving:
        utils.save_experiment(save_path, args, method, results, overwrite=True)

    # Evaluate model on the default dataset
    if not args.no_evaluation:
        args.load_path = args.save_path
        evaluate(args, experiment_name=experiment_name)
participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | ferianc.martin@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 
129 | -------------------------------------------------------------------------------- /docs/extending_yamle/datamodule.rst: -------------------------------------------------------------------------------- 1 | .. _extending_datamodule: 2 | 3 | ************************ 4 | Extending DataModule 5 | ************************ 6 | 7 | In this tutorial we will demonstrate how to extend the :py:mod:`BaseDataModule ` class to create a custom DataModule. 8 | 9 | We will be adding or looking at how to add the MNIST dataset to YAMLE through a custom DataModule. MNIST is a dataset of handwritten digits, which is a popular dataset for testing image classification models. The dataset is available through the `torchvision `_ package. 10 | 11 | To start an implementation of any datamodule we recommend looking at the :py:mod:`BaseDataModule ` class. It has many arguments which can be used to customize the datamodule. 12 | 13 | .. literalinclude:: ../../yamle/data/datamodule.py 14 | :language: python 15 | :lines: 36-60 16 | 17 | This class already contains a lot of useful functionality, e.g. automatic splitting of the dataset into training, validation and calibration portions through the :py:meth:`setup ` method. 18 | 19 | .. literalinclude:: ../../yamle/data/datamodule.py 20 | :language: python 21 | :pyobject: BaseDataModule.setup 22 | 23 | Note that the :py:meth:`setup ` method wraps the datasets into a `SurrogateDataset ` which is a wrapper around the `torch.utils.data.Dataset `_ class. This wrapper allows manual control of the data or the target transformations. 24 | 25 | The transformations are generally managed through a :py:meth:`get_transform ` method which is called for each dataset split: training, validation, calibration and testing. 26 | 27 | Then there is the :py:meth:`prepare_data ` method which is used to download the dataset. This method is only called once per machine and not per GPU. 
This is important to know if you want to download the dataset multiple times. The :py:meth:`prepare_data ` method is called before the :py:meth:`setup ` method. 28 | 29 | Now let's start with the implementation of the MNIST datamodule. In fact, many of the torchvision datasets can be processed in a similar way, hence we will create two classes: one for general torchvision classification datasets and one concretely for MNIST. 30 | 31 | The torchvision classification datamodule is implemented in :py:mod:`TorchvisionClassificationDataModule `. 32 | 33 | .. literalinclude:: ../../yamle/data/classification.py 34 | :language: python 35 | :pyobject: TorchvisionClassificationDataModule 36 | 37 | It inherits from a :py:mod:`VisionClassificationDataModule ` which implements useful methods for debugging and plotting of the predictions or the applied augmentations. 38 | 39 | Any datamodule also allows specification of custom arguments, e.g. the :code:`datamodule_pad_to_32` argument through :py:meth:`add_specific_args `. 40 | 41 | .. literalinclude:: ../../yamle/data/classification.py 42 | :language: python 43 | :pyobject: TorchvisionClassificationDataModule.add_specific_args 44 | 45 | Note the :code:`datamodule_` prefix which is used to avoid name clashes with other arguments and separate the datamodule arguments from any other arguments. 46 | 47 | The module can accept custom arguments such as :code:`pad_to_32` which can pad the image to a size of 32x32 pixels. This is useful if you want to use a model which requires a certain input size or if you want to apply out-of-distribution augmentations common in the field of out-of-distribution detection. Notice that, in practice, the user only needs to fill in the :py:meth:`prepare_data ` method which downloads the training or the test datasets and places them at the :py:attr:`_data_dir ` location.
class DeltaUQMethod(MCSamplingMethod):
    """Delta-UQ extension of the Monte Carlo sampling method.

    The core of the method is applying anchors during training and inference:
    the network input is `[x - anchor, anchor]`, concatenated along dimension 1.
    During training the anchors are a random reshuffle of the batch `x` itself.

    Note that this method requires special treatment in case of data
    augmentation: the anchors at test time should come from the training set
    and not from the test set. Hence some of the training samples are cached
    and reused as anchors whenever the method is not in training mode.
    """

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._replace_input_layer()
        # Stays `None` until the first training batch has been cached.
        self._anchors_cache: torch.Tensor = None
        self._anchors_cache_max_size = 1000

    def _replace_input_layer(self) -> None:
        """Replace the first layer with one whose input dimension is exactly doubled.

        The doubled input accommodates the `[x - anchor, anchor]` concatenation
        built in `_predict`. The replacement is a freshly constructed layer, so
        the weights of the previous input layer are not carried over.

        Raises:
            ValueError: If `self.model._input` is neither a `torch.nn.Linear`
                nor a `torch.nn.Conv2d`.
        """
        old = self.model._input
        if isinstance(old, torch.nn.Linear):
            self.model._input = torch.nn.Linear(
                in_features=old.in_features * 2,
                out_features=old.out_features,
                bias=old.bias is not None,
            )
        elif isinstance(old, torch.nn.Conv2d):
            self.model._input = torch.nn.Conv2d(
                in_channels=old.in_channels * 2,
                out_channels=old.out_channels,
                kernel_size=old.kernel_size,
                stride=old.stride,
                padding=old.padding,
                dilation=old.dilation,
                groups=old.groups,
                bias=old.bias is not None,
                # Keep the padding behaviour of the layer that is being replaced.
                padding_mode=old.padding_mode,
            )
        else:
            raise ValueError(
                "The first layer of the model should be either a `torch.nn.Linear` or a "
                "`torch.nn.Conv2d`."
            )

    def state_dict(self) -> Dict[str, Any]:
        """Return the state of the method, including the anchors cache.

        The cache is stored on the CPU. It is stored as `None` when no training
        step has populated it yet, so the state can be saved before training.
        """
        state_dict = super().state_dict()
        state_dict["anchors_cache"] = (
            None if self._anchors_cache is None else self._anchors_cache.cpu()
        )
        state_dict["anchors_cache_max_size"] = self._anchors_cache_max_size
        return state_dict

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load the state of the method, including the anchors cache.

        States that do not contain the cache entries are accepted: the cache
        then stays empty and the current maximum size is kept.
        """
        super().load_state_dict(state_dict)
        self._anchors_cache = state_dict.get("anchors_cache")
        self._anchors_cache_max_size = state_dict.get(
            "anchors_cache_max_size", self._anchors_cache_max_size
        )

    def _predict(self, x: torch.Tensor, **forward_kwargs: Any) -> torch.Tensor:
        """Perform one forward pass per member, each with freshly drawn anchors.

        The number of passes is `self.training_num_members` in training mode and
        `self._num_members` otherwise. Anchors are a random permutation of `x`
        while training or when no cache exists; otherwise they are sampled from
        the cache of training samples. The per-pass outputs are concatenated
        along dimension 1.
        """
        outputs = []
        num_members = self.training_num_members if self.training else self._num_members
        for _ in range(num_members):
            if self._anchors_cache is None or self.training:
                anchors = x[torch.randperm(x.size(0))]
            else:
                anchors = self._sample_anchors_from_cache(x.size(0)).to(x.device)
            new_x = torch.cat([x - anchors, anchors], dim=1)
            # Deliberately bypass `MCSamplingMethod._predict` and call the next
            # implementation in the MRO: the sampling loop is handled here.
            outputs.append(
                super(MCSamplingMethod, self)._predict(new_x, **forward_kwargs)
            )
        return torch.cat(outputs, dim=1)

    def _sample_anchors_from_cache(self, N: int) -> torch.Tensor:
        """Sample `N` anchors from the cache, potentially with repetition.

        Raises:
            ValueError: If the cache has not been initialized yet.
        """
        if self._anchors_cache is None:
            raise ValueError(
                "The anchors cache is not initialized. Please run the training first."
            )
        indices = torch.randint(low=0, high=len(self._anchors_cache), size=(N,))
        return self._anchors_cache[indices]

    def _training_step(self, batch: List[torch.Tensor], batch_idx: int) -> torch.Tensor:
        """Perform a training step and update the anchors cache.

        The first batch initializes the cache. During epoch 0 further batches
        are appended while the cache holds fewer than `_anchors_cache_max_size`
        samples (the last append may overshoot the limit by part of a batch).
        Afterwards, randomly chosen cache entries are overwritten with samples
        of the current batch.
        """
        x, _ = batch
        new_anchors = x.clone().detach()
        if self._anchors_cache is None:
            self._anchors_cache = new_anchors
        else:
            # A cache restored through `load_state_dict` lives on the CPU;
            # keep new samples on the cache's device so both branches work.
            new_anchors = new_anchors.to(self._anchors_cache.device)
            if (
                len(self._anchors_cache) < self._anchors_cache_max_size
                and self.current_epoch == 0
            ):
                self._anchors_cache = torch.cat([self._anchors_cache, new_anchors])
            else:
                # The batch can be larger than the cache; only as many entries
                # as the cache holds can be replaced.
                num_replaced = min(new_anchors.size(0), self._anchors_cache.size(0))
                indices = torch.randperm(self._anchors_cache.size(0))[:num_replaced]
                self._anchors_cache[indices] = new_anchors[:num_replaced]
        return super()._training_step(batch, batch_idx)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add the arguments specific to this class; none beyond the parent's."""
        parser = super(DeltaUQMethod, DeltaUQMethod).add_specific_args(parent_parser)
        return parser
class QATQuantizer(BaseQuantizer):
    """This is the quantization-aware training (QAT) quantizer class.

    It performs quantization-aware training on the model.
    In contrast to the static quantizer, the quantizer uses the calibration
    (validation) dataset to fine-tune the model.
    It uses the same optimiser as the one used for training the model.

    Args:
        learning_rate (float): The learning rate to use for the fine-tuning.
        epochs (int): The number of epochs to use for the fine-tuning.
    """

    def __init__(
        self, learning_rate: float, epochs: int, *args: Any, **kwargs: Any
    ) -> None:
        super().__init__(*args, **kwargs)
        self._learning_rate = learning_rate
        self._epochs = epochs

    def __call__(self, trainer: Trainer, method: LightningModule) -> None:
        """Quantize the model held by `method` through fine-tuning.

        A copy of the model is saved before quantization.
        First the model is prepared for quantization.
        Then the trainer is queried to fine-tune the model.
        Then the observers are disabled and the quantized model is saved.
        The original model is kept such that it can be recovered.
        """
        self.save_original_model(method)
        self.prepare(trainer, method)
        # NOTE(review): `fine_tune` is not visible here; presumably provided by
        # the project's trainer - confirm against the trainer implementation.
        trainer.fine_tune(self._epochs)

        # Stop the observers from updating the quantization statistics any further.
        method.model.apply(torch.ao.quantization.disable_observer)

        self.save_quantized_model(method)
        logging.info("Model quantized.")
        logging.info(method.model)
        setattr(method, QUANTIZED_KEY, True)
        self.cleanup(method, trainer)

    def prepare(self, trainer: Trainer, method: LightningModule) -> None:
        """Prepare the model for quantization-aware training.

        The layers are replaced and the qconfig is attached while the model is
        in evaluation mode; the model is then switched to training mode before
        `prepare_qat` is applied in place. Afterwards the original learning
        rate and number of epochs are cached and replaced with the fine-tuning
        values; `cleanup` restores them.
        """
        method.model.eval()
        self.replace_layers_for_quantization(method.model)
        method.model.qconfig = self.get_qconfig()
        method.model.train()
        # Same function as the legacy `torch.quantization.prepare_qat`, reached
        # through the `torch.ao.quantization` namespace used by the rest of the file.
        torch.ao.quantization.prepare_qat(method.model, inplace=True)
        logging.info("Model prepared for quantization.")
        logging.info(method.model)

        # Cache the original hyperparameters
        self._original_hyperparameters = {
            "learning_rate": method.hparams.learning_rate,
            "epochs": trainer._epochs,
        }

        # Replace the hyperparameters
        method.hparams.learning_rate = self._learning_rate
        trainer._epochs = self._epochs

    def cleanup(
        self, method: LightningModule, trainer: Trainer, *args: Any, **kwargs: Any
    ) -> None:
        """Clean up after quantization and restore the cached hyperparameters.

        Must be called after `prepare`, which creates the cache that is
        restored and deleted here.
        """
        super().cleanup(*args, **kwargs)

        # Recover the original hyperparameters
        method.hparams.learning_rate = self._original_hyperparameters["learning_rate"]
        trainer._epochs = self._original_hyperparameters["epochs"]

        del self._original_hyperparameters

    def get_qconfig(self) -> Any:
        """Build the quantization configuration from the configured bit widths.

        Activations are fake-quantized to the unsigned range
        `[0, 2**activation_bits - 1]` with a per-tensor affine scheme; weights
        to the signed range `[-2**weight_bits / 2, 2**weight_bits / 2 - 1]`
        with a per-tensor symmetric scheme.
        """
        # presumably both attributes are set by `BaseQuantizer.__init__` - TODO confirm
        activation_bits = self._activation_bits
        weight_bits = self._weight_bits

        activation_fq = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=0,
            quant_max=int(2**activation_bits - 1),
            dtype=torch.quint8,
            qscheme=torch.per_tensor_affine,
            reduce_range=False,  # Since this is in simulation, we don't want to reduce the range
        )
        weight_fq = FakeQuantize.with_args(
            observer=MovingAverageMinMaxObserver,
            quant_min=-int((2**weight_bits) / 2),
            quant_max=int((2**weight_bits) / 2 - 1),
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
            reduce_range=False,  # Since this is in simulation, we don't want to reduce the range
        )
        return torch.ao.quantization.QConfig(activation=activation_fq, weight=weight_fq)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Add the QAT-specific arguments (learning rate, epochs) to the parser."""
        parser = super(QATQuantizer, QATQuantizer).add_specific_args(parent_parser)
        parser.add_argument(
            "--quantizer_learning_rate",
            type=float,
            default=1e-3,
            help="The learning rate to use for the fine-tuning.",
        )
        parser.add_argument(
            "--quantizer_epochs",
            type=int,
            default=1,
            help="The number of epochs to use for the fine-tuning.",
        )
        return parser
class BasePruner(ABC):
    """This is the base class for all prune methods.

    The pruner's call method will be used to prune the model.
    Additionally there is a function to analyse the model and print a summary
    of the pruning per each named parameter.
    """

    @abstractmethod
    def __call__(self, m: nn.Module) -> Any:
        """Prune the model `m`; implemented by the concrete pruners."""
        pass

    @staticmethod
    def prune_parameter(p: nn.Parameter, mask: torch.Tensor) -> None:
        """Prune a parameter in place based on a mask.

        The mask's `True` values will be pruned. `False` values will be kept.
        A CPU copy of the mask and of the parameter's former data is attached
        to the parameter so that the pruning can be inspected and undone with
        `recover_parameter`.

        This function and this function only should be used to prune parameters.
        """
        assert (
            p.shape == mask.shape
        ), f"Parameter shape {p.shape} and mask shape {mask.shape} do not match."
        if hasattr(p, MASK_PRUNING_KEY):
            logging.warning(
                "The parameter already has a mask. The new mask will be used."
            )

        setattr(p, MASK_PRUNING_KEY, mask.detach().cpu().clone())
        # Snapshot the values before they are zeroed so they can be restored.
        setattr(p, FORMER_DATA_PRUNING_KEY, p.data.detach().cpu().clone())
        p.data[mask] = 0.0

    @staticmethod
    def recover_parameter(p: nn.Parameter) -> None:
        """Restore a parameter's data from the snapshot taken when it was pruned.

        Only a warning is logged when no snapshot is attached to the parameter.
        """
        if hasattr(p, FORMER_DATA_PRUNING_KEY):
            p.data = getattr(p, FORMER_DATA_PRUNING_KEY).to(p.device)
            # Delete the data but keep the mask such that there is a record of the pruning
            delattr(p, FORMER_DATA_PRUNING_KEY)
        else:
            logging.warning("No former data was found for this parameter.")

    def recover(self, m: nn.Module) -> None:
        """Restore every prunable parameter of every prunable layer in `m`."""
        for module in m.modules():
            if is_layer_prunable(module):
                for p in module.parameters():
                    if is_parameter_prunable(p):
                        self.recover_parameter(p)

    @staticmethod
    def add_specific_args(
        parent_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Return a new parser, derived from the parent, for pruner specific arguments."""
        return argparse.ArgumentParser(parents=[parent_parser], add_help=False)

    @staticmethod
    def _percentage(part: int, whole: int) -> float:
        """Return `part` as a percentage of `whole`; 0.0 when `whole` is zero."""
        return 100.0 * part / whole if whole else 0.0

    def summary(self, module: nn.Module) -> int:
        """Log a summary of the pruning per parameter.

        The pruned weights are recognised as the weights with a value of 0.

        The format of the table is: the first column is the name of the
        parameter, the second column is the total number of weights in the
        parameter, the third column is the number of pruned weights in the
        parameter, the fourth column is the percentage of pruned weights in the
        parameter. The last row holds the totals over the whole model.
        A model without any prunable weights yields a total row of zeros.

        Returns:
            The total number of pruned weights.
        """
        table = []
        total_pruned = 0
        total_weights = 0
        non_pruned_parameters = []
        for name, m in module.named_modules():
            if not is_layer_prunable(m):
                # If the module is a leaf module add the name to the non-pruned parameters.
                if len(list(m.children())) == 0:
                    non_pruned_parameters.append(name)
                continue
            for parameter_name, p in m.named_parameters():
                full_name = f"{name}.{parameter_name}"
                if not is_parameter_prunable(p):
                    non_pruned_parameters.append(full_name)
                    continue
                weights = p.numel()
                pruned = torch.sum(p.data == 0).item()
                total_weights += weights
                total_pruned += pruned
                table.append(
                    [full_name, weights, pruned, self._percentage(pruned, weights)]
                )
        table.append(
            [
                "Total",
                total_weights,
                total_pruned,
                self._percentage(total_pruned, total_weights),
            ]
        )
        logging.info(
            tabulate(
                table,
                headers=["Parameter", "Total", "Pruned", "Pruned [%]"],
                tablefmt="github",
            )
        )
        logging.info(
            f"Non-pruned parameters/layers: {', '.join(non_pruned_parameters)}"
        )
        return total_pruned

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class DummyPruner(BasePruner):
    """This is a dummy pruner class which does not do any pruning."""

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Accept any arguments and do nothing."""
        pass

    def summary(self, m: nn.Module) -> int:
        """Log that no pruning was performed and report zero pruned weights."""
        logging.info("No pruning was performed.")
        return 0

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"
class SGLD(Optimizer):
    """Stochastic Gradient Langevin Dynamics optimizer.

    Adapted from: https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/master/src/Stochastic_Gradient_Langevin_Dynamics/optimizers.py

    Options for momentum and Nesterov acceleration were added.

    Args:
        params (Iterable[nn.Parameter]): Parameters to optimize.
        lr (float, optional): Learning rate (default: 1e-2).
        momentum (float, optional): Momentum factor in [0, 1] (default: 0.9).
        nestrov (bool, optional): Use Nesterov acceleration (default: False).
            The spelling of the argument is kept for backward compatibility.

    Raises:
        ValueError: If `lr` is negative or `momentum` is outside [0, 1].
    """

    def __init__(
        self,
        params: Iterable[nn.Parameter],
        lr: float = 1e-2,
        momentum: float = 0.9,
        nestrov: bool = False,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")

        if not 0.0 <= momentum <= 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")

        defaults = dict(lr=lr, momentum=momentum, nestrov=nestrov)
        super(SGLD, self).__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Perform a single optimization step through the public optimizer API.

        `torch.optim.Optimizer.step` has to be overridden by subclasses; this
        delegates to `_step`, which is kept so that existing callers of the
        private name keep working.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by `closure`, or `None` if no closure was given.
        """
        return self._step(closure)

    def _step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Apply one SGLD update to every parameter that has a gradient.

        The update is `p <- p - lr / 2 * d_p - noise` where `noise` is drawn
        from a zero-mean normal with standard deviation `sqrt(lr + TINY_EPSILON)`
        and `d_p` is the gradient, optionally replaced by the momentum buffer
        or its Nesterov look-ahead.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            momentum = group["momentum"]
            nestrov = group["nestrov"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad.data

                if momentum != 0:
                    param_state = self.state[p]
                    if "momentum_buffer" not in param_state:
                        # The first step seeds the buffer with the raw gradient.
                        buf = param_state["momentum_buffer"] = d_p.clone().detach()
                    else:
                        buf = param_state["momentum_buffer"]
                        buf.mul_(momentum).add_(d_p, alpha=1)

                    if nestrov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                # Generate noise according to eq. 4 in the paper
                noise = torch.empty_like(p.data).normal_(
                    mean=0, std=math.sqrt(group["lr"] + TINY_EPSILON)
                )
                # Update the parameters according to eq. 4 in the paper
                p.data.add_(d_p, alpha=-group["lr"] * 0.5)
                p.data.add_(noise, alpha=-1)

        return loss


class pSGLD(Optimizer):
    """Preconditioned Stochastic Gradient Langevin Dynamics optimizer.

    Adapted from: https://github.com/JavierAntoran/Bayesian-Neural-Networks/blob/master/src/Stochastic_Gradient_Langevin_Dynamics/optimizers.py

    Args:
        params (Iterable[nn.Parameter]): Parameters to optimize.
        lr (float, optional): Learning rate (default: 1e-2).
        alpha (float, optional): Alpha value for preconditioner, in [0, 1] (default: 0.99).
        centered (bool, optional): Use centered version of pSGLD (default: True).

    Raises:
        ValueError: If `lr` is negative or `alpha` is outside [0, 1].
    """

    def __init__(
        self,
        params: Iterable[nn.Parameter],
        lr: float = 1e-2,
        alpha: float = 0.99,
        centered: bool = True,
    ) -> None:
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Invalid alpha value: {alpha}")

        defaults = dict(
            lr=lr,
            alpha=alpha,
            centered=centered,
        )
        super(pSGLD, self).__init__(params, defaults)

    def step(self, closure: Optional[Callable] = None) -> Optional[float]:
        """Apply one preconditioned SGLD update to every parameter with a gradient.

        A running average of the squared gradients (and, when `centered`, of
        the gradients themselves) is kept per parameter and used to scale both
        the gradient and the injected Gaussian noise. The gradient stored in
        `p.grad` is left unmodified.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss.

        Returns:
            The loss returned by `closure`, or `None` if no closure was given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            alpha = group["alpha"]
            centered = group["centered"]
            for p in group["params"]:
                if p.grad is None:
                    continue

                d_p = p.grad.data
                state = self.state[p]
                if "square_avg" not in state:
                    state["square_avg"] = torch.zeros_like(p.data)
                square_avg = state["square_avg"]
                square_avg.mul_(alpha).addcmul_(d_p, d_p, value=1 - alpha)
                if centered:
                    if "grad_avg" not in state:
                        state["grad_avg"] = torch.zeros_like(p.data)
                    grad_avg = state["grad_avg"]
                    grad_avg.mul_(alpha).add_(d_p, alpha=1 - alpha)
                    # Running second moment minus squared running mean.
                    avg = (
                        square_avg.addcmul(grad_avg, grad_avg, value=-1)
                        .add_(TINY_EPSILON)
                        .sqrt_()
                    )
                else:
                    avg = square_avg.sqrt().add_(TINY_EPSILON)

                noise = torch.empty_like(p.data).normal_(mean=0, std=1.0) / math.sqrt(
                    group["lr"] + TINY_EPSILON
                )
                # Out-of-place division: `div_` would overwrite `p.grad` itself.
                p.data.add_(
                    0.5 * d_p.div(avg) + noise / torch.sqrt(avg), alpha=-group["lr"]
                )

        return loss