├── cifar10
│   ├── __init__.py
│   ├── path_config.py
│   ├── transforms.py
│   └── dataloader.py
├── cifar100
│   ├── __init__.py
│   ├── path_config.py
│   ├── transforms.py
│   └── dataloader.py
├── models
│   ├── manifolds
│   │   ├── math
│   │   │   ├── __init__.py
│   │   │   ├── variance.py
│   │   │   ├── midpoint.py
│   │   │   ├── linreg.py
│   │   │   ├── diffgeom.py
│   │   │   ├── frechet_mean.py
│   │   │   └── diffgeom_autograd.py
│   │   ├── __init__.py
│   │   └── poincare_disk.py
│   ├── nn
│   │   ├── __init__.py
│   │   └── modules
│   │       ├── __init__.py
│   │       ├── linear.py
│   │       ├── batchnorm.py
│   │       └── convolution.py
│   ├── resnets
│   │   ├── __init__.py
│   │   ├── parse_model_from_name.py
│   │   ├── euclidean.py
│   │   ├── euclidean_w_hyp_class.py
│   │   └── hyperbolic.py
│   └── optimizers.py
├── config.py
├── requirements.txt
├── example_config.ini
├── gradcam.sh
├── train.sh
├── adversarial_attacks.sh
├── ood_detection.sh
├── README.md
├── adversarial_attacks.py
├── .gitignore
├── ood_utils
│   ├── svhn_loader.py
│   └── display_results.py
├── train.py
├── gradcam.py
├── ood_detection.py
└── LICENSE

/cifar10/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/cifar100/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/manifolds/math/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/models/nn/__init__.py:
--------------------------------------------------------------------------------
1 | from .modules import *
2 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from configparser import ConfigParser
2 | 
3 | config = ConfigParser()
4 | config.read("config.ini")
5 | 
--------------------------------------------------------------------------------
/models/manifolds/__init__.py:
--------------------------------------------------------------------------------
1 | from .math import *
2 | from .poincare_disk import PoincareBall, poincareball_factory
3 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | tensorboard
4 | cleverhans
5 | geoopt
6 | grad-cam
7 | numpy
8 | timm
9 | scipy
10 | scikit-learn
11 | tqdm
12 | 
--------------------------------------------------------------------------------
/models/resnets/__init__.py:
--------------------------------------------------------------------------------
1 | from .euclidean import *
2 | from .hyperbolic import *
3 | from .parse_model_from_name import *
4 | 
--------------------------------------------------------------------------------
/models/nn/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .batchnorm import PoincareBatchNorm, PoincareBatchNorm2d
2 | from .convolution import PoincareConvolution2d
3 | from .linear import PoincareLinear
4 | 
--------------------------------------------------------------------------------
/example_config.ini:
--------------------------------------------------------------------------------
1 | [DATASETS]
2 | Cifar10 = /PATH/TO/CIFAR10
3 | Cifar100 = /PATH/TO/CIFAR100
4 | Places365 = /PATH/TO/PLACES365
5 | SVHN = /PATH/TO/SVHN
6 | Textures = /PATH/TO/TEXTURES
7 | 
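As a quick sanity check, the following sketch (an illustrative addition, not a file in the repository) reads a `config.ini` created from `example_config.ini` and prints each configured dataset path together with whether it exists on disk; note that `ConfigParser` lowercases option names by default:

```
from configparser import ConfigParser
import os

config = ConfigParser()
config.read("config.ini")
for name, path in config["DATASETS"].items():
    print(f"{name}: {path} (exists: {os.path.isdir(path)})")
```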
-------------------------------------------------------------------------------- /gradcam.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate poincare_resnet 4 | 5 | python -m gradcam \ 6 | 8-16-32-resnet-32 \ 7 | cifar10 \ 8 | -e 9 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate poincare_resnet 4 | 5 | python -m train \ 6 | euclidean-8-16-32-resnet-20 \ 7 | cifar100 \ 8 | -e 100 \ 9 | -s \ 10 | --opt=adam \ 11 | --lr=0.001 \ 12 | --weight-decay=1e-4 13 | -------------------------------------------------------------------------------- /cifar10/path_config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | 4 | cfd = os.path.dirname(os.path.realpath(__file__)) 5 | ini_path = os.path.join(os.path.dirname(cfd), "config.ini") 6 | 7 | config = configparser.ConfigParser() 8 | config.read(ini_path) 9 | data_dir = config["DATASETS"]["Cifar10"] 10 | -------------------------------------------------------------------------------- /cifar100/path_config.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import os 3 | 4 | cfd = os.path.dirname(os.path.realpath(__file__)) 5 | ini_path = os.path.join(os.path.dirname(cfd), "config.ini") 6 | 7 | config = configparser.ConfigParser() 8 | config.read(ini_path) 9 | data_dir = config["DATASETS"]["Cifar100"] 10 | -------------------------------------------------------------------------------- /cifar10/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | 4 | def get_standard_transform(train: bool = True) -> transforms.Compose: 5 | transform = [ 6 | transforms.ToTensor(), 7 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 8 | ] 9 | 10 | if train: 11 | transform.extend( 12 | [ 13 | transforms.RandomCrop(size=32, padding=4), 14 | transforms.RandomHorizontalFlip(), 15 | ] 16 | ) 17 | 18 | return transforms.Compose(transform) 19 | -------------------------------------------------------------------------------- /cifar100/transforms.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as transforms 2 | 3 | 4 | def get_standard_transform(train: bool = True) -> transforms.Compose: 5 | transform = [ 6 | transforms.ToTensor(), 7 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 8 | ] 9 | 10 | if train: 11 | transform.extend( 12 | [ 13 | transforms.RandomCrop(size=32, padding=4), 14 | transforms.RandomHorizontalFlip(), 15 | ] 16 | ) 17 | 18 | return transforms.Compose(transform) 19 | -------------------------------------------------------------------------------- /adversarial_attacks.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | source ~/miniconda3/etc/profile.d/conda.sh 3 | conda activate poincare_resnet 4 | 5 | declare -a models=( 6 | "hyperbolic-8-16-32-resnet-32" 7 | "euclidean-8-16-32-resnet-32" 8 | "euclideanwhypclass-8-16-32-resnet-32" 9 | ) 10 | 11 | declare -a epsilons=( 12 | "0.00314" 13 | "0.00627" 14 | "0.00941" 15 | "0.01255" 16 | ) 17 | 18 | for model in "${models[@]}"; do 
19 |     echo $model
20 |     for epsilon in "${epsilons[@]}"; do
21 |         echo $epsilon
22 |         python -m adversarial_attacks \
23 |             $model cifar10 \
24 |             -e $epsilon \
25 |             --batch-size 128
26 |     done
27 | done
28 | 
--------------------------------------------------------------------------------
/ood_detection.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | source ~/miniconda3/etc/profile.d/conda.sh
3 | conda activate poincare_resnet
4 | 
5 | declare -a datasets=(
6 |     "cifar10"
7 |     "cifar100"
8 | )
9 | 
10 | declare -a models=(
11 |     "hyperbolic-8-16-32-resnet-20"
12 |     "euclidean-8-16-32-resnet-20"
13 |     "euclideanwhypclass-8-16-32-resnet-20"
14 |     "hyperbolic-8-16-32-resnet-32"
15 |     "euclidean-8-16-32-resnet-32"
16 |     "euclideanwhypclass-8-16-32-resnet-32"
17 | )
18 | 
19 | for dataset in "${datasets[@]}"; do
20 |     echo $dataset
21 |     for model in "${models[@]}"; do
22 |         echo $model
23 |         python -m ood_detection \
24 |             --model $model \
25 |             --dataset $dataset \
26 |             --num_to_avg 10
27 |     done
28 | done
29 | 
--------------------------------------------------------------------------------
/models/optimizers.py:
--------------------------------------------------------------------------------
1 | import geoopt
2 | import torch
3 | import torch.nn as nn
4 | 
5 | optimizer_dict = {
6 |     "sgd": geoopt.optim.RiemannianSGD,
7 |     "adam": geoopt.optim.RiemannianAdam,
8 | }
9 | 
10 | 
11 | allowed_opt_kwargs = {
12 |     "sgd": [
13 |         "lr",
14 |         "momentum",
15 |         "weight_decay",
16 |     ],
17 |     "adam": [
18 |         "lr",
19 |         "weight_decay",
20 |     ],
21 | }
22 | 
23 | 
24 | def parse_optimizer_kwargs(args: dict) -> tuple[str, dict]:
25 |     opt = args["opt"]
26 |     opt_kwargs = {}
27 |     for key in args.keys():
28 |         if key in allowed_opt_kwargs[opt]:
29 |             opt_kwargs[key] = args[key]
30 | 
31 |     return opt, opt_kwargs
32 | 
33 | 
34 | def initialize_optimizer(
35 |     model: nn.Module,
36 |     args,
37 | ) -> torch.optim.Optimizer:
38 |     # Parse optimizer specification and configuration
39 |     opt, config = parse_optimizer_kwargs(vars(args))
40 | 
41 |     # Create optimizer
42 |     optimizer = optimizer_dict[opt]
43 |     return optimizer(model.parameters(), **config)
44 | 
--------------------------------------------------------------------------------
/models/resnets/parse_model_from_name.py:
--------------------------------------------------------------------------------
1 | from .euclidean import EuclideanResNet
2 | from .euclidean_w_hyp_class import EuclideanResNetWHypClass
3 | from .hyperbolic import HyperbolicResNet
4 | 
5 | 
6 | def parse_model_from_name(
7 |     model_name: str,
8 |     classes: int,
9 | ) -> EuclideanResNet | EuclideanResNetWHypClass | HyperbolicResNet:
10 |     keys = model_name.split("-")
11 |     model_type = keys[0]
12 |     channel_dims = [int(k) for k in keys[1:4]]
13 |     # Each group gets (depth - 2) // 6 residual blocks: subtract the 2 layers
14 |     # outside the groups; each block then adds 2 layers to each of the 3 groups.
15 |     depths = 3 * [(int(keys[-1]) - 2) // 6]
16 | 
17 |     if model_type == "euclidean":
18 |         model_class = EuclideanResNet
19 |     elif model_type == "euclideanwhypclass":
20 |         model_class = EuclideanResNetWHypClass
21 |     elif model_type == "hyperbolic":
22 |         model_class = HyperbolicResNet
23 |     else:
24 |         raise ValueError(f"Unknown model type: {model_type}")
25 | 
26 |     return model_class(
27 |         classes=classes,
28 |         channel_dims=channel_dims,
29 |         depths=depths,
30 |     )
31 | 
--------------------------------------------------------------------------------
/models/manifolds/math/variance.py:
--------------------------------------------------------------------------------
1 | from typing
import Optional 2 | 3 | import torch 4 | 5 | from .diffgeom import dg_dist 6 | from .diffgeom_autograd import ag_dist 7 | 8 | 9 | def frechet_variance( 10 | x: torch.Tensor, 11 | mu: torch.Tensor, 12 | c: torch.Tensor, 13 | dim: int = -1, 14 | w: Optional[torch.Tensor] = None, 15 | custom_autograd: bool = True, 16 | ) -> torch.Tensor: 17 | """ 18 | Args 19 | ---- 20 | x (tensor): points of shape [..., points, dim] 21 | mu (tensor): mean of shape [..., dim] 22 | w (tensor): weights of shape [..., points] 23 | 24 | where the ... of the three variables line up 25 | 26 | Returns 27 | ------- 28 | tensor of shape [...] 29 | """ 30 | if custom_autograd: 31 | distance: torch.Tensor = ag_dist(x=x, y=mu, c=c, dim=dim) 32 | else: 33 | distance: torch.Tensor = dg_dist(x=x, y=mu, c=c, dim=dim) 34 | distance = distance.pow(2) 35 | 36 | if w is None: 37 | return distance.mean(dim=dim) 38 | else: 39 | return (distance * w).sum(dim=dim) 40 | -------------------------------------------------------------------------------- /cifar10/dataloader.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.data import DataLoader 3 | 4 | from .path_config import data_dir 5 | from .transforms import get_standard_transform 6 | 7 | 8 | class Cifar10DataLoaderFactory: 9 | train_transform = get_standard_transform(train=True) 10 | test_transform = get_standard_transform(train=False) 11 | 12 | train_set = torchvision.datasets.CIFAR10( 13 | root=data_dir, 14 | train=True, 15 | download=True, 16 | transform=train_transform, 17 | ) 18 | 19 | test_set = torchvision.datasets.CIFAR10( 20 | root=data_dir, 21 | train=False, 22 | download=True, 23 | transform=test_transform, 24 | ) 25 | 26 | @classmethod 27 | def create_train_loaders(cls, batch_size: int): 28 | train_loader = DataLoader( 29 | dataset=cls.train_set, 30 | batch_size=batch_size, 31 | shuffle=True, 32 | ) 33 | 34 | test_loader = DataLoader( 35 | dataset=cls.test_set, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | ) 39 | 40 | return train_loader, test_loader 41 | -------------------------------------------------------------------------------- /cifar100/dataloader.py: -------------------------------------------------------------------------------- 1 | import torchvision 2 | from torch.utils.data import DataLoader 3 | 4 | from .path_config import data_dir 5 | from .transforms import get_standard_transform 6 | 7 | 8 | class Cifar100DataLoaderFactory: 9 | train_transform = get_standard_transform(train=True) 10 | test_transform = get_standard_transform(train=False) 11 | 12 | train_set = torchvision.datasets.CIFAR100( 13 | root=data_dir, 14 | train=True, 15 | download=True, 16 | transform=train_transform, 17 | ) 18 | 19 | test_set = torchvision.datasets.CIFAR100( 20 | root=data_dir, 21 | train=False, 22 | download=True, 23 | transform=test_transform, 24 | ) 25 | 26 | @classmethod 27 | def create_train_loaders(cls, batch_size: int): 28 | train_loader = DataLoader( 29 | dataset=cls.train_set, 30 | batch_size=batch_size, 31 | shuffle=True, 32 | ) 33 | 34 | test_loader = DataLoader( 35 | dataset=cls.test_set, 36 | batch_size=batch_size, 37 | shuffle=False, 38 | ) 39 | 40 | return train_loader, test_loader 41 | -------------------------------------------------------------------------------- /models/nn/modules/linear.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ...manifolds import PoincareBall 5 | 6 | 7 | class 
PoincareLinear(nn.Module):
8 |     """Poincare fully connected linear layer"""
9 | 
10 |     def __init__(
11 |         self,
12 |         in_features: int,
13 |         out_features: int,
14 |         ball: PoincareBall,
15 |         bias: bool = True,
16 |         id_init: bool = True,
17 |     ) -> None:
18 |         super(PoincareLinear, self).__init__()
19 |         self.in_features = in_features
20 |         self.out_features = out_features
21 |         self.ball = ball
22 |         self.has_bias = bias
23 |         self.id_init = id_init
24 | 
25 |         self.z = nn.Parameter(torch.empty(in_features, out_features))
26 |         if self.has_bias:
27 |             self.bias = nn.Parameter(torch.empty(out_features))
28 |         self.reset_parameters()
29 | 
30 |     def reset_parameters(self) -> None:
31 |         if self.id_init:
32 |             self.z = nn.Parameter(
33 |                 1 / 2 * torch.eye(self.in_features, self.out_features)
34 |             )
35 |         else:
36 |             nn.init.normal_(
37 |                 self.z, mean=0, std=(2 * self.in_features * self.out_features) ** -0.5
38 |             )
39 |         if self.has_bias:
40 |             nn.init.zeros_(self.bias)
41 | 
42 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
43 |         x = self.ball.expmap0(x, dim=-1)
44 |         y = self.ball.fully_connected(
45 |             x=x,
46 |             z=self.z,
47 |             bias=self.bias if self.has_bias else None,  # self.bias only exists when bias=True
48 |         )
49 |         return self.ball.logmap0(y, dim=-1)
50 | 
--------------------------------------------------------------------------------
/models/manifolds/math/midpoint.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def poincare_to_klein(x: torch.Tensor, c: torch.Tensor, dim: int = -1) -> torch.Tensor:
5 |     return 2 / (1 + c * x.pow(2).sum(dim=dim, keepdim=True)) * x
6 | 
7 | 
8 | def klein_to_poincare(x: torch.Tensor, c: torch.Tensor, dim: int = -1) -> torch.Tensor:
9 |     gamma = 1 / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).sqrt().clamp_min(1e-15)
10 |     return gamma / (1 + gamma) * x
11 | 
12 | 
13 | def klein_midpoint(
14 |     x: torch.Tensor,
15 |     c: torch.Tensor,
16 |     vec_dim: int = -1,
17 |     batch_dim: int = 0,
18 | ) -> torch.Tensor:
19 |     gamma = 1 / (1 - c * x.pow(2).sum(dim=vec_dim, keepdim=True)).sqrt().clamp_min(
20 |         1e-15
21 |     )
22 |     numerator = (gamma * x).sum(dim=batch_dim, keepdim=True)
23 |     denominator = gamma.sum(dim=batch_dim, keepdim=True)
24 |     return numerator / denominator
25 | 
26 | 
27 | def poincare_klein_midpoint(
28 |     x: torch.Tensor,
29 |     c: torch.Tensor,
30 |     vec_dim: int = -1,
31 |     batch_dim: int = 0,
32 | ) -> torch.Tensor:
33 |     x = poincare_to_klein(x, c, vec_dim)
34 |     m = klein_midpoint(x, c, vec_dim, batch_dim)
35 |     return klein_to_poincare(m, c, vec_dim)
36 | 
37 | 
38 | def poincare_midpoint(
39 |     x: torch.Tensor,
40 |     c: torch.Tensor,
41 |     vec_dim: int = -1,
42 |     batch_dim: int = 0,
43 | ):
44 |     gamma_sq = 1 / (1 - c * x.pow(2).sum(dim=vec_dim, keepdim=True)).clamp_min(1e-15)
45 |     numerator = (gamma_sq * x).sum(dim=batch_dim, keepdim=True)
46 |     denominator = gamma_sq.sum(dim=batch_dim, keepdim=True) - x.size(batch_dim) / 2
47 |     m = numerator / denominator
48 |     gamma_m = 1 / (1 - c * m.pow(2).sum(dim=vec_dim, keepdim=True)).sqrt().clamp_min(
49 |         1e-15
50 |     )
51 |     return gamma_m / (1 + gamma_m) * m
52 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Poincaré ResNet
2 | Repository containing the code for the [Poincaré ResNet paper](https://arxiv.org/abs/2303.14027).
3 | 
4 | # Installation
5 | This repository requires `python >= 3.10`.
To install the required packages, run
6 | ```
7 | pip install -r requirements.txt
8 | ```
9 | The repository uses CIFAR-10 and CIFAR-100 for training and Places365, SVHN and Textures for out-of-distribution (OOD) detection. Your root directory (where this README is located) must contain a `config.ini` file with the paths to these datasets. An example ini file can be found in `example_config.ini`. Instructions for downloading the required datasets can be found at:
10 | 
11 | - [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
12 | - [Places365](http://places2.csail.mit.edu/download.html)
13 | - [SVHN](http://ufldl.stanford.edu/housenumbers/)
14 | - [Textures](https://www.robots.ox.ac.uk/~vgg/data/dtd/)
15 | 
16 | Note that the last three of these datasets are only required for OOD detection. If you are only interested in other components of this repository, feel free to ignore the paths to these datasets in the ini file. If you are interested in OOD detection but do not have access to one or more of these datasets, you can remove the corresponding sections from the `ood_detection.py` file.
17 | 
18 | # Training
19 | To train a model, use the CLI tool in `train.py`. The naming convention for the models is as follows:
20 | ```
21 | <model_type>-<channels_1>-<channels_2>-<channels_3>-resnet-<depth>
22 | ```
23 | where `depth = 3 * 2 * block_depth + 2` (three groups of residual blocks, two layers per block, plus the two layers outside the groups). As an example,
24 | ```
25 | hyperbolic-8-16-32-resnet-32
26 | ```
27 | leads to a hyperbolic ResNet with channel sizes (8, 16, 32) and with block sizes (5, 5, 5).
28 | 
29 | For simplicity, the `train.sh` script contains an example of a call to the train tool with some sensible arguments.
30 | 
31 | # Robustness experiments
32 | Each of the robustness experiments has its own CLI tool: `ood_detection.py`, `adversarial_attacks.py` and `gradcam.py`.
33 | 
34 | These experiments require models to have already been trained and stored by using the train tool mentioned above with the `-s` flag (which stores the weights in the `./weights` directory).
35 | 
36 | Examples of how to run these experiments with sensible arguments are shown in the corresponding shell scripts.
37 | 
--------------------------------------------------------------------------------
/models/manifolds/math/linreg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def poincare_mlr(
5 |     x: torch.Tensor,
6 |     z: torch.Tensor,
7 |     r: torch.Tensor,
8 |     c: torch.Tensor,
9 | ) -> torch.Tensor:
10 |     """
11 |     The Poincare multinomial logistic regression (MLR) operation.
12 | 
13 |     Parameters
14 |     ----------
15 |     x : tensor
16 |         contains input values
17 |     z : tensor
18 |         contains the hyperbolic vectors describing the hyperplane orientations
19 |     r : tensor
20 |         contains the hyperplane offsets
21 |     c : tensor
22 |         curvature of the Poincare disk
23 | 
24 |     Returns
25 |     -------
26 |     tensor
27 |         signed distances of input w.r.t.
the hyperplanes, denoted by v_k(x) in
28 |         the HNN++ paper
29 |     """
30 |     # Compute some variables
31 |     c_sqrt = c.sqrt()
32 |     # Conformal factor lambda_x = 2 / (1 - c * ||x||^2), as in diffgeom.py
33 |     lam = 2 / (1 - c * x.pow(2).sum(dim=-1, keepdim=True)).clamp_min(1e-15)
34 |     z_norm = z.norm(dim=0).clamp_min(1e-15)
35 | 
36 |     # Computation can be simplified if there is no offset
37 |     if r is not None:
38 |         two_csqrt_r = 2.0 * c_sqrt * r
39 |         return (
40 |             2
41 |             * z_norm
42 |             / c_sqrt
43 |             * torch.asinh(
44 |                 c_sqrt * lam / z_norm * torch.matmul(x, z) * two_csqrt_r.cosh()
45 |                 - (lam - 1) * two_csqrt_r.sinh()
46 |             )
47 |         )
48 |     else:
49 |         return (
50 |             2
51 |             * z_norm
52 |             / c_sqrt
53 |             * torch.asinh(c_sqrt * lam / z_norm * torch.matmul(x, z))
54 |         )
55 | 
56 | 
57 | def poincare_fully_connected(
58 |     x: torch.Tensor,
59 |     z: torch.Tensor,
60 |     bias: torch.Tensor,
61 |     c: torch.Tensor,
62 | ) -> torch.Tensor:
63 |     """
64 |     The Poincare fully connected layer operation.
65 | 
66 |     Parameters
67 |     ----------
68 |     x : tensor
69 |         contains the layer inputs
70 |     z : tensor
71 |         contains the hyperbolic vectors describing the hyperplane orientations
72 |     bias : tensor
73 |         contains the biases (hyperplane offsets)
74 |     c : tensor
75 |         curvature of the Poincare disk
76 | 
77 |     Returns
78 |     -------
79 |     tensor
80 |         Poincare FC transformed hyperbolic tensor, commonly denoted by y
81 |     """
82 |     c_sqrt = c.sqrt()
83 | 
84 |     # Perform MLR to compute v(x)
85 |     x = poincare_mlr(x=x, z=z, r=bias, c=c)
86 | 
87 |     # Compute the w vector
88 |     x = (c_sqrt * x).sinh() / c_sqrt
89 | 
90 |     # Compute y
91 |     return x / (1 + (1 + c * x.pow(2).sum(dim=-1, keepdim=True)).sqrt())
92 | 
--------------------------------------------------------------------------------
/models/nn/modules/batchnorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | from ...manifolds import PoincareBall
5 | from ...manifolds.math.frechet_mean import frechet_mean
6 | from ...manifolds.math.midpoint import poincare_midpoint
7 | from ...manifolds.math.variance import frechet_variance
8 | 
9 | 
10 | class PoincareBatchNorm(nn.Module):
11 |     """
12 |     Basic implementation of batch normalization in the Poincare ball model.
13 | 
14 |     Based on:
15 |     https://arxiv.org/abs/2003.00335
16 |     """
17 | 
18 |     def __init__(
19 |         self,
20 |         features: int,
21 |         ball: PoincareBall,
22 |         use_midpoint: bool = True,
23 |     ) -> None:
24 |         super(PoincareBatchNorm, self).__init__()
25 |         self.features = features
26 |         self.ball = ball
27 |         self.use_midpoint = use_midpoint
28 | 
29 |         self.mean = nn.Parameter(torch.zeros(features))
30 |         self.var = nn.Parameter(torch.tensor(1.0))
31 | 
32 |         # statistics
33 |         self.register_buffer("running_mean", torch.zeros(1, features))
34 |         self.register_buffer("running_var", torch.tensor(1.0))
35 |         self.updates = 0
36 | 
37 |     def forward(self, x, momentum=0.9):
38 |         x = self.ball.expmap0(x, dim=-1)
39 |         mean_on_ball = self.ball.expmap0(self.mean, dim=-1)
40 |         if self.use_midpoint:
41 |             input_mean = poincare_midpoint(x, self.ball.c, vec_dim=-1, batch_dim=0)
42 |         else:
43 |             input_mean = frechet_mean(x, self.ball)
44 |         input_var = frechet_variance(x, input_mean, self.ball.c, dim=-1)
45 | 
46 |         input_logm = self.ball.transp(
47 |             x=input_mean,
48 |             y=mean_on_ball,
49 |             v=self.ball.logmap(input_mean, x),
50 |         )
51 | 
52 |         input_logm = (self.var / (input_var + 1e-6)).sqrt() * input_logm
53 | 
54 |         output = self.ball.expmap(mean_on_ball.unsqueeze(-2), input_logm)
55 | 
56 |         self.updates += 1
57 | 
58 |         output = self.ball.logmap0(output, dim=-1)
59 |         # Warn about numerical issues in the output
60 |         if output.isnan().any():
61 |             print("PoincareBatchNorm produced NaN outputs")
62 |         return output
63 | 
64 | 
65 | class PoincareBatchNorm2d(nn.Module):
66 |     """
67 |     2D implementation of batch normalization in the Poincare ball model.
68 | 
69 |     Based on:
70 |     https://arxiv.org/abs/2003.00335
71 |     """
72 | 
73 |     def __init__(
74 |         self,
75 |         features: int,
76 |         ball: PoincareBall,
77 |         use_midpoint: bool = True,
78 |     ) -> None:
79 |         super(PoincareBatchNorm2d, self).__init__()
80 |         self.features = features
81 |         self.ball = ball
82 |         self.use_midpoint = use_midpoint
83 | 
84 |         self.norm = PoincareBatchNorm(
85 |             features=features,
86 |             ball=ball,
87 |             use_midpoint=use_midpoint,
88 |         )
89 | 
90 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
91 |         # Store input dimensions
92 |         batch_size, height, width = x.size(0), x.size(2), x.size(3)
93 | 
94 |         # Swap batch and channel dimensions and flatten everything but channel dimension
95 |         x = x.permute(0, 2, 3, 1).flatten(start_dim=0, end_dim=2)
96 | 
97 |         # Apply batchnorm
98 |         x = self.norm(x)
99 | 
100 |         # Reshape to original dimensions
101 |         x = x.reshape(batch_size, height, width, self.features).permute(0, 3, 1, 2)
102 | 
103 |         return x
104 | 
--------------------------------------------------------------------------------
/adversarial_attacks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | 
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from cleverhans.torch.attacks.fast_gradient_method import fast_gradient_method
8 | from timm import utils
9 | from tqdm import tqdm
10 | 
11 | from cifar10.dataloader import Cifar10DataLoaderFactory
12 | from cifar100.dataloader import Cifar100DataLoaderFactory
13 | from models.resnets import parse_model_from_name
14 | 
15 | parser = argparse.ArgumentParser(description="PyTorch adversarial attack evaluation")
16 | 
17 | parser.add_argument("model", type=str, help="Model name")
18 | parser.add_argument("dataset", type=str, choices=["cifar10", "cifar100"], help="Dataset name")
19 | parser.add_argument("-e", "--epsilon", type=float, required=True)
20 | parser.add_argument("-b", "--batch-size", type=int,
default=128)
21 | 
22 | 
23 | def create_metrics_dict():
24 |     return {
25 |         "losses": utils.AverageMeter(),
26 |         "top1": utils.AverageMeter(),
27 |         "top5": utils.AverageMeter(),
28 |     }
29 | 
30 | 
31 | def update_metrics_dict(metrics, input, loss, acc1, acc5):
32 |     metrics["losses"].update(loss.data.item(), input.size(0))
33 |     metrics["top1"].update(acc1.item(), input.size(0))
34 |     metrics["top5"].update(acc5.item(), input.size(0))
35 |     return metrics
36 | 
37 | 
38 | def main() -> None:
39 |     args = parser.parse_args()
40 | 
41 |     if args.dataset == "cifar10":
42 |         _, test_loader = Cifar10DataLoaderFactory.create_train_loaders(
43 |             batch_size=args.batch_size
44 |         )
45 |         classes = 10
46 |     elif args.dataset == "cifar100":
47 |         _, test_loader = Cifar100DataLoaderFactory.create_train_loaders(
48 |             batch_size=args.batch_size
49 |         )
50 |         classes = 100
51 | 
52 |     model = parse_model_from_name(model_name=args.model, classes=classes).cuda()
53 |     weights_path = os.path.join(
54 |         os.path.dirname(os.path.abspath(__file__)),
55 |         "weights",
56 |         args.dataset,
57 |         args.model,
58 |         f"{args.model}_weights.pth",
59 |     )
60 |     state_dict = torch.load(weights_path)
61 |     model.load_state_dict(state_dict)
62 |     model.eval()
63 | 
64 |     loss_fn = nn.CrossEntropyLoss()
65 | 
66 |     metrics = {attack: create_metrics_dict() for attack in ["clean", "fgm"]}
67 | 
68 |     for input, target in tqdm(test_loader):
69 |         input, target = input.cuda(), target.cuda()
70 |         input_fgm = fast_gradient_method(
71 |             model_fn=model, x=input, eps=args.epsilon, norm=np.inf
72 |         )
73 | 
74 |         output = model(input)
75 |         output_fgm = model(input_fgm)
76 | 
77 |         loss = loss_fn(output, target)
78 |         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
79 |         metrics["clean"] = update_metrics_dict(
80 |             metrics["clean"], input, loss, acc1, acc5
81 |         )
82 | 
83 |         loss_fgm = loss_fn(output_fgm, target)
84 |         acc1_fgm, acc5_fgm = utils.accuracy(output_fgm, target, topk=(1, 5))
85 |         metrics["fgm"] = update_metrics_dict(
86 |             metrics["fgm"], input_fgm, loss_fgm, acc1_fgm, acc5_fgm
87 |         )
88 | 
89 |         del loss, loss_fgm
90 | 
91 |     for attack in metrics.keys():
92 |         print(
93 |             f"Metrics for {attack}: "
94 |             f"Loss: {metrics[attack]['losses'].avg:>7.4f} "
95 |             f"Acc@1: {metrics[attack]['top1'].avg:>7.4f} "
96 |             f"Acc@5: {metrics[attack]['top5'].avg:>7.4f}"
97 |         )
98 | 
99 | 
100 | if __name__ == "__main__":
101 |     main()
102 | 
--------------------------------------------------------------------------------
/models/manifolds/math/diffgeom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def dg_mobius_add(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1):
5 |     x2 = x.pow(2).sum(dim=dim, keepdim=True)
6 |     y2 = y.pow(2).sum(dim=dim, keepdim=True)
7 |     xy = (x * y).sum(dim=dim, keepdim=True)
8 |     num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
9 |     denom = 1 + 2 * c * xy + c**2 * x2 * y2
10 |     return num / denom.clamp_min(1e-15)
11 | 
12 | 
13 | def dg_project(x: torch.Tensor, c: torch.Tensor, dim: int = -1, eps: float = -1.0):
14 |     if eps < 0:
15 |         if x.dtype == torch.float32:
16 |             eps = 4e-3
17 |         else:
18 |             eps = 1e-5
19 |     maxnorm = (1 - eps) / ((c + 1e-15) ** 0.5)
20 |     maxnorm = torch.where(c.gt(0), maxnorm, c.new_full((), 1e15))
21 |     norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15)
22 |     cond = norm > maxnorm
23 |     projected = x / norm * maxnorm
24 |     return torch.where(cond, projected, x)
25 | 
26 | 
27 | def dg_expmap0(v: torch.Tensor, c: torch.Tensor, dim: int = -1):
28 |     v_norm_c_sqrt =
v.norm(dim=dim, keepdim=True).clamp_min(1e-15) * c.sqrt() 29 | return dg_project(torch.tanh(v_norm_c_sqrt) * v / v_norm_c_sqrt, c) 30 | 31 | 32 | def dg_logmap0(y: torch.Tensor, c: torch.Tensor, dim: int = -1): 33 | y_norm_c_sqrt = y.norm(dim=dim, keepdim=True).clamp_min(1e-15) * c.sqrt() 34 | return torch.atanh(y_norm_c_sqrt) * y / y_norm_c_sqrt 35 | 36 | 37 | def dg_expmap(x: torch.Tensor, v: torch.Tensor, c: torch.Tensor, dim: int = -1): 38 | v_norm = v.norm(dim=dim, keepdim=True).clamp_min(1e-15) 39 | lambda_x = 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15) 40 | c_sqrt = c.sqrt() 41 | second_term = torch.tanh(c_sqrt * lambda_x * v_norm / 2) * v / (c_sqrt * v_norm) 42 | return dg_project(dg_mobius_add(x, second_term, c, dim=dim), c, dim=dim) 43 | 44 | 45 | def dg_logmap(x: torch.Tensor, y: torch.Tensor, c: torch.Tensor, dim: int = -1): 46 | min_x_y = dg_mobius_add(-x, y, c, dim=dim) 47 | min_x_y_norm = min_x_y.norm(dim=dim, keepdim=True).clamp_min(1e-15) 48 | lambda_x = 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15) 49 | c_sqrt = c.sqrt() 50 | return ( 51 | 2 52 | / (c_sqrt * lambda_x) 53 | * torch.atanh(c_sqrt * min_x_y_norm) 54 | * min_x_y 55 | / min_x_y_norm 56 | ) 57 | 58 | 59 | def dg_gyration( 60 | u: torch.Tensor, 61 | v: torch.Tensor, 62 | w: torch.Tensor, 63 | c: torch.Tensor, 64 | dim: int = -1, 65 | ): 66 | u2 = u.pow(2).sum(dim=dim, keepdim=True) 67 | v2 = v.pow(2).sum(dim=dim, keepdim=True) 68 | uv = (u * v).sum(dim=dim, keepdim=True) 69 | uw = (u * w).sum(dim=dim, keepdim=True) 70 | vw = (v * w).sum(dim=dim, keepdim=True) 71 | K2 = c**2 72 | a = -K2 * uw * v2 + c * vw + 2 * K2 * uv * vw 73 | b = -K2 * vw * u2 - c * uw 74 | d = 1 + 2 * c * uv + K2 * u2 * v2 75 | return w + 2 * (a * u + b * v) / d.clamp_min(1e-15) 76 | 77 | 78 | def dg_transp( 79 | x: torch.Tensor, 80 | y: torch.Tensor, 81 | v: torch.Tensor, 82 | c: torch.Tensor, 83 | dim: int = -1, 84 | ): 85 | lambda_x = 2 / (1 - c * x.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15) 86 | lambda_y = 2 / (1 - c * y.pow(2).sum(dim=dim, keepdim=True)).clamp_min(1e-15) 87 | return dg_gyration(y, -x, v, c, dim=dim) * lambda_x / lambda_y 88 | 89 | 90 | def dg_dist( 91 | x: torch.Tensor, 92 | y: torch.Tensor, 93 | c: torch.Tensor, 94 | dim: int = -1, 95 | keepdim: bool = False, 96 | ) -> torch.Tensor: 97 | return ( 98 | 2 99 | / c.sqrt() 100 | * ( 101 | c.sqrt() * dg_mobius_add(-x, y, c, dim=dim).norm(dim=dim, keepdim=keepdim) 102 | ).atanh() 103 | ) 104 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific 2 | /runs 3 | /weights 4 | */data 5 | /gradcam 6 | config.ini 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | cover/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | .pybuilder/ 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | # For a library or package, you might want to ignore these files since the code is 94 | # intended to run in multiple environments; otherwise, check them in: 95 | # .python-version 96 | 97 | # pipenv 98 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 99 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 100 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 101 | # install all needed dependencies. 102 | #Pipfile.lock 103 | 104 | # poetry 105 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 106 | # This is especially recommended for binary packages to ensure reproducibility, and is more 107 | # commonly ignored for libraries. 108 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 109 | #poetry.lock 110 | 111 | # pdm 112 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 113 | #pdm.lock 114 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 115 | # in version control. 116 | # https://pdm.fming.dev/#use-with-ide 117 | .pdm.toml 118 | 119 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 120 | __pypackages__/ 121 | 122 | # Celery stuff 123 | celerybeat-schedule 124 | celerybeat.pid 125 | 126 | # SageMath parsed files 127 | *.sage.py 128 | 129 | # Environments 130 | .env 131 | .venv 132 | env/ 133 | venv/ 134 | ENV/ 135 | env.bak/ 136 | venv.bak/ 137 | 138 | # Spyder project settings 139 | .spyderproject 140 | .spyproject 141 | 142 | # Rope project settings 143 | .ropeproject 144 | 145 | # mkdocs documentation 146 | /site 147 | 148 | # mypy 149 | .mypy_cache/ 150 | .dmypy.json 151 | dmypy.json 152 | 153 | # Pyre type checker 154 | .pyre/ 155 | 156 | # pytype static type analyzer 157 | .pytype/ 158 | 159 | # Cython debug symbols 160 | cython_debug/ 161 | 162 | # PyCharm 163 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 164 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 165 | # and can be added to the global gitignore or merged into this file. For a more nuclear 166 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
167 | #.idea/
--------------------------------------------------------------------------------
/models/nn/modules/convolution.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | 
3 | import torch
4 | import torch.nn as nn
5 | from scipy.special import beta
6 | 
7 | from ...manifolds import PoincareBall
8 | 
9 | 
10 | class PoincareConvolution2d(nn.Module):
11 |     """Poincare 2 dimensional convolution layer"""
12 | 
13 |     def __init__(
14 |         self,
15 |         in_channels: int,
16 |         out_channels: int,
17 |         kernel_dims: Tuple[int, int],
18 |         ball: PoincareBall,
19 |         bias: bool = True,
20 |         stride: int = 1,
21 |         padding: int = 0,
22 |         id_init: bool = True,
23 |     ) -> None:
24 |         # Store layer parameters
25 |         super(PoincareConvolution2d, self).__init__()
26 |         self.in_channels = in_channels
27 |         self.out_channels = out_channels
28 |         self.kernel_dims = kernel_dims
29 |         self.kernel_size = kernel_dims[0] * kernel_dims[1]
30 |         self.ball = ball
31 |         self.stride = stride
32 |         self.padding = padding
33 |         self.id_init = id_init
34 | 
35 |         # Unfolding layer
36 |         self.unfold = nn.Unfold(
37 |             kernel_size=kernel_dims,
38 |             padding=padding,
39 |             stride=stride,
40 |         )
41 | 
42 |         # Create weights
43 |         self.has_bias = bias
44 |         if bias:
45 |             self.bias = nn.Parameter(torch.empty(out_channels))
46 |         self.weights = nn.Parameter(
47 |             torch.empty(self.kernel_size * in_channels, out_channels)
48 |         )
49 | 
50 |         # Initialize weights
51 |         self.reset_parameters()
52 | 
53 |         # Create beta's for concatenating receptive field features
54 |         self.beta_ni = beta(self.in_channels / 2, 1 / 2)
55 |         self.beta_n = beta(self.in_channels * self.kernel_size / 2, 1 / 2)
56 | 
57 |     def reset_parameters(self):
58 |         # Identity initialization (1/2 factor to counter 2 inside the distance formula)
59 |         if self.id_init:
60 |             self.weights = nn.Parameter(
61 |                 1
62 |                 / 2
63 |                 * torch.eye(self.kernel_size * self.in_channels, self.out_channels)
64 |             )
65 |         else:
66 |             nn.init.normal_(
67 |                 self.weights,
68 |                 mean=0,
69 |                 std=(2 * self.in_channels * self.kernel_size * self.out_channels)
70 |                 ** -0.5,
71 |             )
72 |         if self.has_bias:
73 |             nn.init.zeros_(self.bias)
74 | 
75 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
76 |         """
77 |         Forward pass of the 2 dimensional convolution layer
78 | 
79 |         Parameters
80 |         ----------
81 |         x : tensor (batchsize, input channels, height, width)
82 |             contains the layer inputs
83 | 
84 |         Returns
85 |         -------
86 |         tensor (batchsize, output channels, height, width)
87 |         """
88 |         batch_size, height, width = x.size(0), x.size(2), x.size(3)
89 |         out_height = (
90 |             height - self.kernel_dims[0] + 1 + 2 * self.padding
91 |         ) // self.stride
92 |         out_width = (width - self.kernel_dims[1] + 1 + 2 * self.padding) // self.stride
93 | 
94 |         # Scalar transform for concatenation
95 |         x = x * self.beta_n / self.beta_ni
96 | 
97 |         # Apply sliding window to input to obtain features of each frame
98 |         x = self.unfold(x)
99 |         x = x.transpose(1, 2)
100 | 
101 |         # Project the receptive field features back onto the Poincare ball
102 |         x = self.ball.expmap0(x, dim=-1)
103 | 
104 |         # Apply the Poincare fully connected operation
105 |         x = self.ball.fully_connected(
106 |             x=x,
107 |             z=self.weights,
108 |             bias=self.bias if self.has_bias else None,  # self.bias only exists when bias=True
109 |         )
110 | 
111 |         # Convert y back to the proper shape
112 |         x = x.transpose(1, 2).reshape(
113 |             batch_size, self.out_channels, out_height, out_width
114 |         )
115 | 
116 |         # Map the output back to the tangent space at the origin (channel dimension)
117 |         return self.ball.logmap0(x, dim=1)
118 | 
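As a quick reference, here is a minimal shape-check sketch for `PoincareConvolution2d` (an illustrative addition, not part of the repository; it assumes the repository root is on `PYTHONPATH`, and the `poincareball_factory` keyword arguments mirror the call in `models/resnets/euclidean_w_hyp_class.py`):

```
import torch

from models.manifolds import poincareball_factory
from models.nn.modules.convolution import PoincareConvolution2d

ball = poincareball_factory(c=1.0, custom_autograd=True, learnable=False)
conv = PoincareConvolution2d(
    in_channels=8,
    out_channels=16,
    kernel_dims=(3, 3),
    ball=ball,
    stride=1,
    padding=1,
)

x = torch.randn(2, 8, 32, 32)  # tangent-space input: (batch, channels, height, width)
y = conv(x)
print(y.shape)  # expected: torch.Size([2, 16, 32, 32])
```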
--------------------------------------------------------------------------------
/models/resnets/euclidean.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | def _conv3x3(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
8 |     return nn.Conv2d(
9 |         in_channels,
10 |         out_channels,
11 |         kernel_size=3,
12 |         stride=stride,
13 |         padding=1,
14 |         bias=True,
15 |     )
16 | 
17 | 
18 | class ResidualBlock(nn.Module):
19 |     """The basic residual building block
20 | 
21 |     Each 3x3 convolution is followed by batch normalization and a ReLU,
22 |     matching the post-activation block structure of the original ResNet.
23 |     """
24 | 
25 |     def __init__(
26 |         self,
27 |         in_channels: int,
28 |         out_channels: int,
29 |         stride: int = 1,
30 |         downsample: Optional[nn.Sequential] = None,
31 |         inplace: bool = True,
32 |     ) -> None:
33 |         super(ResidualBlock, self).__init__()
34 |         self.in_channels = in_channels
35 |         self.out_channels = out_channels
36 |         self.stride = stride
37 | 
38 |         self.relu = nn.ReLU(inplace=inplace)
39 |         self.conv1 = _conv3x3(
40 |             in_channels=in_channels,
41 |             out_channels=out_channels,
42 |             stride=stride,
43 |         )
44 |         self.bn1 = nn.BatchNorm2d(out_channels)
45 |         self.conv2 = _conv3x3(
46 |             in_channels=out_channels,
47 |             out_channels=out_channels,
48 |         )
49 |         self.bn2 = nn.BatchNorm2d(out_channels)
50 |         self.downsample = downsample
51 | 
52 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
53 |         residual = x
54 |         x = self.conv1(x)
55 |         x = self.bn1(x)
56 |         x = self.relu(x)
57 |         x = self.conv2(x)
58 |         x = self.bn2(x)
59 | 
60 |         if self.downsample is not None:
61 |             residual = self.downsample(residual)
62 | 
63 |         x = x + residual
64 |         x = self.relu(x)
65 | 
66 |         return x
67 | 
68 | 
69 | class EuclideanResNet(nn.Module):
70 |     """Residual Networks
71 | 
72 |     Implementation of Residual Networks as described in: https://arxiv.org/pdf/1512.03385.pdf
73 |     """
74 | 
75 |     def __init__(
76 |         self,
77 |         classes: int,
78 |         channel_dims: list[int],
79 |         depths: list[int],
80 |     ) -> None:
81 |         super(EuclideanResNet, self).__init__()
82 |         self.classes = classes
83 |         self.channel_dims = channel_dims
84 |         self.depths = depths
85 | 
86 |         self.relu = nn.ReLU(inplace=True)
87 |         self.conv = _conv3x3(
88 |             in_channels=3,
89 |             out_channels=channel_dims[0],
90 |         )
91 |         self.bn = nn.BatchNorm2d(channel_dims[0])
92 | 
93 |         self.group1 = self._make_group(
94 |             in_channels=channel_dims[0],
95 |             out_channels=channel_dims[0],
96 |             depth=depths[0],
97 |         )
98 | 
99 |         self.group2 = self._make_group(
100 |             in_channels=channel_dims[0],
101 |             out_channels=channel_dims[1],
102 |             depth=depths[1],
103 |             stride=2,
104 |         )
105 | 
106 |         self.group3 = self._make_group(
107 |             in_channels=channel_dims[1],
108 |             out_channels=channel_dims[2],
109 |             depth=depths[2],
110 |             stride=2,
111 |         )
112 | 
113 |         self.avg_pool = nn.AvgPool2d(8)
114 |         self.fc = nn.Linear(channel_dims[2], classes)
115 | 
116 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
117 |         x = self.conv(x)
118 |         x = self.bn(x)
119 |         x = self.relu(x)
120 |         x = self.group1(x)
121 |         x = self.group2(x)
122 |         x = self.group3(x)
123 |         x = self.avg_pool(x)
124 |         x = self.fc(x.flatten(1))  # flatten keeps the batch dim (squeeze() would drop it for batch size 1)
125 |         return x
126 | 
127 |     def _make_group(
128 |         self,
129 |         in_channels: int,
130 |         out_channels: int,
131 |         depth: int,
132 |         stride: int = 1,
133 |     ) -> nn.Sequential:
134 |         downsample = None
135 |         if stride != 1:
136 |             downsample = nn.Conv2d(
137 | 
in_channels=in_channels,
138 |                 out_channels=out_channels,
139 |                 kernel_size=1,
140 |                 stride=stride,
141 |                 padding=0,
142 |                 bias=True,
143 |             )
144 | 
145 |         layers = [
146 |             ResidualBlock(
147 |                 in_channels=in_channels,
148 |                 out_channels=out_channels,
149 |                 stride=stride,
150 |                 downsample=downsample,
151 |             )
152 |         ]
153 | 
154 |         for _ in range(1, depth):
155 |             layers.append(
156 |                 ResidualBlock(in_channels=out_channels, out_channels=out_channels)
157 |             )
158 | 
159 |         return nn.Sequential(*layers)
160 | 
--------------------------------------------------------------------------------
/models/resnets/euclidean_w_hyp_class.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from ..manifolds import poincareball_factory
7 | from ..nn import PoincareLinear
8 | 
9 | 
10 | def _conv3x3(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
11 |     return nn.Conv2d(
12 |         in_channels,
13 |         out_channels,
14 |         kernel_size=3,
15 |         stride=stride,
16 |         padding=1,
17 |         bias=True,
18 |     )
19 | 
20 | 
21 | class ResidualBlock(nn.Module):
22 |     """The basic residual building block
23 | 
24 |     Each 3x3 convolution is followed by batch normalization and a ReLU,
25 |     matching the post-activation block structure of the original ResNet.
26 |     """
27 | 
28 |     def __init__(
29 |         self,
30 |         in_channels: int,
31 |         out_channels: int,
32 |         stride: int = 1,
33 |         downsample: Optional[nn.Sequential] = None,
34 |         inplace: bool = True,
35 |     ) -> None:
36 |         super(ResidualBlock, self).__init__()
37 |         self.in_channels = in_channels
38 |         self.out_channels = out_channels
39 |         self.stride = stride
40 | 
41 |         self.relu = nn.ReLU(inplace=inplace)
42 |         self.conv1 = _conv3x3(
43 |             in_channels=in_channels,
44 |             out_channels=out_channels,
45 |             stride=stride,
46 |         )
47 |         self.bn1 = nn.BatchNorm2d(out_channels)
48 |         self.conv2 = _conv3x3(
49 |             in_channels=out_channels,
50 |             out_channels=out_channels,
51 |         )
52 |         self.bn2 = nn.BatchNorm2d(out_channels)
53 |         self.downsample = downsample
54 | 
55 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
56 |         residual = x
57 |         x = self.conv1(x)
58 |         x = self.bn1(x)
59 |         x = self.relu(x)
60 |         x = self.conv2(x)
61 |         x = self.bn2(x)
62 | 
63 |         if self.downsample is not None:
64 |             residual = self.downsample(residual)
65 | 
66 |         x = x + residual
67 |         x = self.relu(x)
68 | 
69 |         return x
70 | 
71 | 
72 | class EuclideanResNetWHypClass(nn.Module):
73 |     """Euclidean ResNet with a hyperbolic classification head
74 | 
75 |     Residual network (https://arxiv.org/pdf/1512.03385.pdf) whose final fully
76 |     connected layer is a PoincareLinear classifier.
77 |     """
78 | 
79 |     def __init__(
80 |         self,
81 |         classes: int,
82 |         channel_dims: list[int],
83 |         depths: list[int],
84 |         init_c: float = 1,
85 |         custom_autograd: bool = True,
86 |         learnable: bool = True,
87 |     ) -> None:
88 |         super(EuclideanResNetWHypClass, self).__init__()
89 |         self.classes = classes
90 |         self.channel_dims = channel_dims
91 |         self.depths = depths
92 |         self.ball = poincareball_factory(
93 |             c=init_c, custom_autograd=custom_autograd, learnable=learnable
94 |         )
95 | 
96 |         self.relu = nn.ReLU(inplace=True)
97 |         self.conv = _conv3x3(
98 |             in_channels=3,
99 |             out_channels=channel_dims[0],
100 |         )
101 |         self.bn = nn.BatchNorm2d(channel_dims[0])
102 | 
103 |         self.group1 = self._make_group(
104 |             in_channels=channel_dims[0],
105 |             out_channels=channel_dims[0],
106 |             depth=depths[0],
107 |         )
108 | 
109 |         self.group2 = self._make_group(
110 |             in_channels=channel_dims[0],
111 |             out_channels=channel_dims[1],
112 |             depth=depths[1],
113 |             stride=2,
114 |         )
115 | 
116 |         self.group3 = self._make_group(
117 |             in_channels=channel_dims[1],
118 |             out_channels=channel_dims[2],
119 |             depth=depths[2],
120 |             stride=2,
121 |         )
122 | 
123 |         self.avg_pool = nn.AvgPool2d(8)
124 |         self.fc = PoincareLinear(
125 |             in_features=channel_dims[2],
126 |             out_features=self.classes,
127 |             ball=self.ball,
128 |         )
129 | 
130 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
131 |         x = self.conv(x)
132 |         x = self.bn(x)
133 |         x = self.relu(x)
134 |         x = self.group1(x)
135 |         x = self.group2(x)
136 |         x = self.group3(x)
137 |         x = self.avg_pool(x)
138 |         x = self.fc(x.flatten(1))  # flatten keeps the batch dim (squeeze() would drop it for batch size 1)
139 |         return x
140 | 
141 |     def _make_group(
142 |         self,
143 |         in_channels: int,
144 |         out_channels: int,
145 |         depth: int,
146 |         stride: int = 1,
147 |     ) -> nn.Sequential:
148 |         downsample = None
149 |         if stride != 1:
150 |             downsample = nn.Conv2d(
151 |                 in_channels=in_channels,
152 |                 out_channels=out_channels,
153 |                 kernel_size=1,
154 |                 stride=stride,
155 |                 padding=0,
156 |                 bias=True,
157 |             )
158 | 
159 |         layers = [
160 |             ResidualBlock(
161 |                 in_channels=in_channels,
162 |                 out_channels=out_channels,
163 |                 stride=stride,
164 |                 downsample=downsample,
165 |             )
166 |         ]
167 | 
168 |         for _ in range(1, depth):
169 |             layers.append(
170 |                 ResidualBlock(in_channels=out_channels, out_channels=out_channels)
171 |             )
172 | 
173 |         return nn.Sequential(*layers)
174 | 
--------------------------------------------------------------------------------
/ood_utils/svhn_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | 
4 | import numpy as np
5 | import torch.utils.data as data
6 | from PIL import Image
7 | from torchvision.datasets.utils import check_integrity, download_url
8 | 
9 | 
10 | class SVHN(data.Dataset):
11 |     url = ""
12 |     filename = ""
13 |     file_md5 = ""
14 | 
15 |     split_list = {
16 |         "train": [
17 |             "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
18 |             "train_32x32.mat",
19 |             "e26dedcc434d2e4c54c9b2d4a06d8373",
20 |         ],
21 |         "test": [
22 |             "http://ufldl.stanford.edu/housenumbers/test_32x32.mat",
23 |             "test_32x32.mat",
24 |             "eb5a983be6a315427106f1b164d9cef3",
25 |         ],
26 |         "extra": [
27 |             "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
28 |             "extra_32x32.mat",
29 |             "a93ce644f1a588dc4d68dda5feec44a7",
30 |         ],
31 |         "train_and_extra": [
32 |             [
33 |                 "http://ufldl.stanford.edu/housenumbers/train_32x32.mat",
34 |                 "train_32x32.mat",
35 |                 "e26dedcc434d2e4c54c9b2d4a06d8373",
36 |             ],
37 |             [
38 |                 "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat",
39 |                 "extra_32x32.mat",
40 |                 "a93ce644f1a588dc4d68dda5feec44a7",
41 |             ],
42 |         ],
43 |     }
44 | 
45 |     def __init__(
46 |         self, root, split="train", transform=None, target_transform=None, download=False
47 |     ):
48 |         self.root = root
49 |         self.transform = transform
50 |         self.target_transform = target_transform
51 |         self.split = split  # training set or test set or extra set
52 | 
53 |         if self.split not in self.split_list:
54 |             raise ValueError(
55 |                 'Wrong split entered! Please use split="train" '
56 |                 'or split="extra" or split="test" '
57 |                 'or split="train_and_extra" '
58 |             )
59 | 
60 |         if self.split == "train_and_extra":
61 |             self.url = self.split_list[split][0][0]
62 |             self.filename = self.split_list[split][0][1]
63 |             self.file_md5 = self.split_list[split][0][2]
64 |         else:
65 |             self.url = self.split_list[split][0]
66 |             self.filename = self.split_list[split][1]
67 |             self.file_md5 = self.split_list[split][2]
68 | 
69 |         if download:
70 |             self.download()
71 | 
72 |         # import here rather than at top of file because this is
73 |         # an optional dependency for torchvision
74 |         import scipy.io as sio
75 | 
76 |         # reading(loading) mat file as array
77 |         loaded_mat = sio.loadmat(os.path.join(root, self.filename))
78 | 
79 |         if self.split == "test":
80 |             self.data = loaded_mat["X"]
81 |             self.targets = loaded_mat["y"]
82 |             # Note label 10 == 0 so modulo operator required
83 |             self.targets = (
84 |                 self.targets % 10
85 |             ).squeeze()  # convert to zero-based indexing
86 |             self.data = np.transpose(self.data, (3, 2, 0, 1))
87 |         else:
88 |             self.data = loaded_mat["X"]
89 |             self.targets = loaded_mat["y"]
90 | 
91 |             if self.split == "train_and_extra":
92 |                 extra_filename = self.split_list[split][1][1]
93 |                 loaded_mat = sio.loadmat(os.path.join(root, extra_filename))
94 |                 self.data = np.concatenate([self.data, loaded_mat["X"]], axis=3)
95 |                 self.targets = np.vstack((self.targets, loaded_mat["y"]))
96 |             # Note label 10 == 0 so modulo operator required
97 |             self.targets = (
98 |                 self.targets % 10
99 |             ).squeeze()  # convert to zero-based indexing
100 |             self.data = np.transpose(self.data, (3, 2, 0, 1))
101 | 
102 |     def __getitem__(self, index):
103 |         # data and targets are stored identically for every split
104 |         img, target = self.data[index], self.targets[index]
105 | 
106 |         # doing this so that it is consistent with all other datasets
107 |         # to return a PIL Image
108 |         img = Image.fromarray(np.transpose(img, (1, 2, 0)))
109 | 
110 |         if self.transform is not None:
111 |             img = self.transform(img)
112 | 
113 |         if self.target_transform is not None:
114 |             target = self.target_transform(target)
115 | 
116 |         return img, target
117 | 
118 |     def __len__(self):
119 |         return len(self.data)
120 | 
121 |     def _check_integrity(self):
122 |         root = self.root
123 |         if self.split == "train_and_extra":
124 |             md5 = self.split_list[self.split][0][2]
125 |             fpath = os.path.join(root, self.filename)
126 |             train_integrity = check_integrity(fpath, md5)
127 |             extra_filename = self.split_list[self.split][1][1]
128 |             md5 = self.split_list[self.split][1][2]
129 |             fpath = os.path.join(root, extra_filename)
130 |             return check_integrity(fpath, md5) and train_integrity
131 |         else:
132 |             md5 = self.split_list[self.split][2]
133 |             fpath = os.path.join(root, self.filename)
134 |             return check_integrity(fpath, md5)
135 | 
136 |     def download(self):
137 |         if self.split == "train_and_extra":
138 |             md5 = self.split_list[self.split][0][2]
139 |             download_url(self.url, self.root, self.filename, md5)
140 |             extra_filename = self.split_list[self.split][1][1]
141 |             md5 = self.split_list[self.split][1][2]
142 |             download_url(self.url, self.root, extra_filename, md5)
143 |         else:
144 |             md5 = self.split_list[self.split][2]
145 |             download_url(self.url, self.root, self.filename, md5)
146 | 
--------------------------------------------------------------------------------
/ood_utils/display_results.py:
--------------------------------------------------------------------------------
1 | 
import numpy as np 2 | import sklearn.metrics as sk 3 | 4 | recall_level_default = 0.95 5 | 6 | 7 | def stable_cumsum(arr, rtol=1e-05, atol=1e-08): 8 | """Use high precision for cumsum and check that final value matches sum 9 | Parameters 10 | ---------- 11 | arr : array-like 12 | To be cumulatively summed as flat 13 | rtol : float 14 | Relative tolerance, see ``np.allclose`` 15 | atol : float 16 | Absolute tolerance, see ``np.allclose`` 17 | """ 18 | out = np.cumsum(arr, dtype=np.float64) 19 | expected = np.sum(arr, dtype=np.float64) 20 | if not np.allclose(out[-1], expected, rtol=rtol, atol=atol): 21 | raise RuntimeError( 22 | "cumsum was found to be unstable: " 23 | "its last element does not correspond to sum" 24 | ) 25 | return out 26 | 27 | 28 | def fpr_and_fdr_at_recall( 29 | y_true, y_score, recall_level=recall_level_default, pos_label=None 30 | ): 31 | classes = np.unique(y_true) 32 | if pos_label is None and not ( 33 | np.array_equal(classes, [0, 1]) 34 | or np.array_equal(classes, [-1, 1]) 35 | or np.array_equal(classes, [0]) 36 | or np.array_equal(classes, [-1]) 37 | or np.array_equal(classes, [1]) 38 | ): 39 | raise ValueError("Data is not binary and pos_label is not specified") 40 | elif pos_label is None: 41 | pos_label = 1.0 42 | 43 | # make y_true a boolean vector 44 | y_true = y_true == pos_label 45 | 46 | # sort scores and corresponding truth values 47 | desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1] 48 | y_score = y_score[desc_score_indices] 49 | y_true = y_true[desc_score_indices] 50 | 51 | # y_score typically has many tied values. Here we extract 52 | # the indices associated with the distinct values. We also 53 | # concatenate a value for the end of the curve. 54 | distinct_value_indices = np.where(np.diff(y_score))[0] 55 | threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1] 56 | 57 | # accumulate the true positives with decreasing threshold 58 | tps = stable_cumsum(y_true)[threshold_idxs] 59 | fps = 1 + threshold_idxs - tps # add one because of zero-based indexing 60 | 61 | thresholds = y_score[threshold_idxs] 62 | 63 | recall = tps / tps[-1] 64 | 65 | last_ind = tps.searchsorted(tps[-1]) 66 | sl = slice(last_ind, None, -1) # [last_ind::-1] 67 | recall, fps, tps, thresholds = ( 68 | np.r_[recall[sl], 1], 69 | np.r_[fps[sl], 0], 70 | np.r_[tps[sl], 0], 71 | thresholds[sl], 72 | ) 73 | 74 | cutoff = np.argmin(np.abs(recall - recall_level)) 75 | 76 | return fps[cutoff] / (np.sum(np.logical_not(y_true))) 77 | 78 | 79 | def get_measures(_pos, _neg, recall_level=recall_level_default): 80 | pos = np.array(_pos[:]).reshape((-1, 1)) 81 | neg = np.array(_neg[:]).reshape((-1, 1)) 82 | examples = np.squeeze(np.vstack((pos, neg))) 83 | labels = np.zeros(len(examples), dtype=np.int32) 84 | labels[: len(pos)] += 1 85 | 86 | auroc = sk.roc_auc_score(labels, examples) 87 | aupr = sk.average_precision_score(labels, examples) 88 | fpr = fpr_and_fdr_at_recall(labels, examples, recall_level) 89 | 90 | return auroc, aupr, fpr 91 | 92 | 93 | def show_performance(pos, neg, method_name="Ours", recall_level=recall_level_default): 94 | """ 95 | :param pos: 1's class, class to detect, outliers, or wrongly predicted 96 | example scores 97 | :param neg: 0's class scores 98 | """ 99 | 100 | auroc, aupr, fpr = get_measures(pos[:], neg[:], recall_level) 101 | 102 | print("\t\t\t" + method_name) 103 | print("FPR{:d}:\t\t\t{:.2f}".format(int(100 * recall_level), 100 * fpr)) 104 | print("AUROC:\t\t\t{:.2f}".format(100 * auroc)) 105 | print("AUPR:\t\t\t{:.2f}".format(100 * 
aupr))
106 | 
107 | 
108 | def print_measures(
109 |     auroc, aupr, fpr, method_name="Ours", recall_level=recall_level_default
110 | ):
111 |     print("\t\t\t\t" + method_name)
112 |     print(" FPR{:d} AUROC AUPR".format(int(100 * recall_level)))
113 |     print("& {:.2f} & {:.2f} & {:.2f}".format(100 * fpr, 100 * auroc, 100 * aupr))
114 | 
115 | 
116 | def print_measures_with_std(
117 |     aurocs, auprs, fprs, method_name="Ours", recall_level=recall_level_default
118 | ):
119 |     print("\t\t\t\t" + method_name)
120 |     print(" FPR{:d} AUROC AUPR".format(int(100 * recall_level)))
121 |     print(
122 |         "& {:.2f} & {:.2f} & {:.2f}".format(
123 |             100 * np.mean(fprs), 100 * np.mean(aurocs), 100 * np.mean(auprs)
124 |         )
125 |     )
126 |     print(
127 |         "& {:.2f} & {:.2f} & {:.2f}".format(
128 |             100 * np.std(fprs), 100 * np.std(aurocs), 100 * np.std(auprs)
129 |         )
130 |     )
131 | 
132 | 
133 | def show_performance_comparison(
134 |     pos_base,
135 |     neg_base,
136 |     pos_ours,
137 |     neg_ours,
138 |     baseline_name="Baseline",
139 |     method_name="Ours",
140 |     recall_level=recall_level_default,
141 | ):
142 |     """
143 |     :param pos_base: 1's class, class to detect, outliers, or wrongly predicted
144 |         example scores from the baseline
145 |     :param neg_base: 0's class scores generated by the baseline
146 |     """
147 |     auroc_base, aupr_base, fpr_base = get_measures(
148 |         pos_base[:], neg_base[:], recall_level
149 |     )
150 |     auroc_ours, aupr_ours, fpr_ours = get_measures(
151 |         pos_ours[:], neg_ours[:], recall_level
152 |     )
153 | 
154 |     print("\t\t\t" + baseline_name + "\t" + method_name)
155 |     print(
156 |         "FPR{:d}:\t\t\t{:.2f}\t\t{:.2f}".format(
157 |             int(100 * recall_level), 100 * fpr_base, 100 * fpr_ours
158 |         )
159 |     )
160 |     print("AUROC:\t\t\t{:.2f}\t\t{:.2f}".format(100 * auroc_base, 100 * auroc_ours))
161 |     print("AUPR:\t\t\t{:.2f}\t\t{:.2f}".format(100 * aupr_base, 100 * aupr_ours))
162 | 
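Editor's note: a minimal usage sketch of the metric helpers above (illustrative only, not repository code; the toy scores are hypothetical). Higher scores should indicate the class to detect:

import numpy as np

from ood_utils.display_results import get_measures, show_performance

rng = np.random.default_rng(0)
pos = rng.normal(1.0, 1.0, 500)   # scores for the positive class (to detect)
neg = rng.normal(-1.0, 1.0, 500)  # scores for the negative class

auroc, aupr, fpr = get_measures(pos, neg)  # FPR is measured at 95% recall
print(f"AUROC={100 * auroc:.2f} AUPR={100 * aupr:.2f} FPR95={100 * fpr:.2f}")
show_performance(pos, neg, method_name="toy-example")

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import time
5 | from datetime import datetime
6 | 
7 | import torch
8 | from timm import utils
9 | from torch.utils.tensorboard import SummaryWriter
10 | 
11 | from cifar10.dataloader import Cifar10DataLoaderFactory
12 | from cifar100.dataloader import Cifar100DataLoaderFactory
13 | from models.optimizers import initialize_optimizer
14 | from models.resnets import parse_model_from_name
15 | 
16 | parser = argparse.ArgumentParser(description="PyTorch CIFAR training")
17 | 
18 | parser.add_argument("model", type=str, help="Model name")
19 | parser.add_argument(
20 |     "dataset",
21 |     type=str,
22 |     choices=["cifar10", "cifar100"],
23 |     help="Dataset (cifar10, cifar100)",
24 | )
25 | parser.add_argument(
26 |     "-b",
27 |     "--batch-size",
28 |     type=int,
29 |     default=128,
30 |     help="Overwrite batch size (default: 128)",
31 | )
32 | parser.add_argument(
33 |     "-e",
34 |     "--epochs",
35 |     type=int,
36 |     default=500,
37 |     help="Number of epochs (default: 500)",
38 | )
39 | parser.add_argument(
40 |     "--opt",
41 |     type=str,
42 |     default="sgd",
43 |     help="Optimizer (default: sgd)",
44 | )
45 | parser.add_argument(
46 |     "--lr",
47 |     type=float,
48 |     default=0.001,
49 |     help="Learning rate (default: 0.001)",
50 | )
51 | parser.add_argument(
52 |     "--momentum",
53 |     type=float,
54 |     default=0.9,
55 |     help="Momentum (default: 0.9)",
56 | )
57 | parser.add_argument(
58 |     "--weight-decay",
59 |     type=float,
60 |     default=1e-4,
61 |     help="Weight decay (default: 1e-4)",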
62 | )
63 | parser.add_argument(
64 |     "-s",
65 |     "--save",
66 |     action="store_const",
67 |     const=True,
68 |     default=False,
69 |     help="Save the model weights after training",
70 | )
71 | parser.add_argument(
72 |     "--criterion",
73 |     type=str,
74 |     default="top1",
75 |     choices=["losses", "top1", "top5"],
76 |     help="Choose the metric which will determine the best result (default: top1)",
77 | )
78 | 
79 | 
80 | def main():
81 |     args = parser.parse_args()
82 | 
83 |     # Create some strings for file management
84 |     now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
85 |     dir_path = os.path.dirname(os.path.realpath(__file__))
86 |     exp_dir = os.path.join(dir_path, "runs", args.dataset, args.model, now)
87 |     os.makedirs(exp_dir)
88 | 
89 |     # Grab some dataset specific stuff
90 |     if args.dataset == "cifar10":
91 |         dataset_factory = Cifar10DataLoaderFactory
92 |         classes = 10
93 |     elif args.dataset == "cifar100":
94 |         dataset_factory = Cifar100DataLoaderFactory
95 |         classes = 100
96 | 
97 |     # Create dataloaders
98 |     train_loader, test_loader = dataset_factory.create_train_loaders(
99 |         batch_size=args.batch_size
100 |     )
101 | 
102 |     # Create model
103 |     model = parse_model_from_name(args.model, classes).cuda()
104 | 
105 |     # Initialize tensorboard logger
106 |     writer = SummaryWriter(exp_dir)
107 | 
108 |     # Create optimizers
109 |     optimizer = initialize_optimizer(
110 |         model=model,
111 |         args=args,
112 |     )
113 | 
114 |     print(f"Using optimizer: {optimizer}")
115 | 
116 |     loss_fn = torch.nn.CrossEntropyLoss()
117 | 
118 |     best_avg_metrics = {}
119 | 
120 |     for epoch in range(args.epochs):
121 |         epoch_start = time.time()
122 | 
123 |         model.train()
124 | 
125 |         for idx, (input, target) in enumerate(train_loader):
126 |             input, target = input.cuda(), target.cuda()
127 |             output = model(input)
128 |             loss = loss_fn(output, target)
129 | 
130 |             optimizer.zero_grad()
131 |             loss.backward()
132 |             optimizer.step()
133 | 
134 |         metrics = {
135 |             "losses": utils.AverageMeter(),
136 |             "top1": utils.AverageMeter(),
137 |             "top5": utils.AverageMeter(),
138 |         }
139 | 
140 |         model.eval()
141 | 
142 |         with torch.no_grad():
143 |             for input, target in test_loader:
144 |                 input, target = input.cuda(), target.cuda()
145 |                 output = model(input)
146 | 
147 |                 loss = loss_fn(output, target)
148 |                 acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
149 | 
150 |                 metrics["losses"].update(loss.data.item(), input.size(0))
151 |                 metrics["top1"].update(acc1.item(), output.size(0))
152 |                 metrics["top5"].update(acc5.item(), output.size(0))
153 | 
154 |         writer.add_scalar(f"{args.criterion}/test", metrics[args.criterion].avg, epoch)
155 | 
156 |         sign = -1 if args.criterion == "losses" else 1  # lower loss is better
157 |         if not best_avg_metrics or (
158 |             sign * metrics[args.criterion].avg > sign * best_avg_metrics[args.criterion]
159 |         ):
160 |             best_avg_metrics = {k: metrics[k].avg for k in metrics}
161 |             best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
162 | 
163 |         print(
164 |             f"Epoch {epoch}: "
165 |             f"Time: {time.time() - epoch_start:.3f} "
166 |             f"Loss: {metrics['losses'].avg:>7.4f} "
167 |             f"Acc@1: {metrics['top1'].avg:>7.4f} "
168 |             f"Acc@5: {metrics['top5'].avg:>7.4f}"
169 |         )
170 | 
171 |     output_dict = {
172 |         "best_model_state": best_model_state,
173 |         "best_avg_metrics": best_avg_metrics,
174 |         "last_model_state": model.state_dict(),
175 |         "last_avg_metrics": {k: metrics[k].avg for k in metrics},
176 |     }
177 | 
178 |     # Store model weights
179 |     if args.save:
180 |         torch.save(
181 |             output_dict["last_model_state"],
182 |             os.path.join(exp_dir, f"{args.model}_weights.pth"),
183 |         )
184 | 
185 |         weights_dir = os.path.join(dir_path, "weights", args.dataset, args.model)
"weights", args.dataset, args.model) 186 | os.makedirs(weights_dir) 187 | torch.save( 188 | output_dict["last_model_state"], 189 | os.path.join(weights_dir, f"{args.model}_weights.pth"), 190 | ) 191 | 192 | # Store metrics 193 | with open(f"{exp_dir}/metrics.json", "w") as file: 194 | json.dump(output_dict["last_avg_metrics"], file, indent=4) 195 | 196 | 197 | if __name__ == "__main__": 198 | main() 199 | -------------------------------------------------------------------------------- /gradcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch 5 | import torchvision 6 | from PIL import Image 7 | from pytorch_grad_cam import GradCAM 8 | from pytorch_grad_cam.utils.image import show_cam_on_image 9 | from torch.utils.data import DataLoader 10 | from torchvision.transforms import ToTensor 11 | from torchvision.transforms.functional import to_pil_image 12 | 13 | from cifar10.dataloader import Cifar10DataLoaderFactory 14 | from cifar100.dataloader import Cifar100DataLoaderFactory 15 | from config import config 16 | from models.resnets import parse_model_from_name 17 | 18 | parser = argparse.ArgumentParser(description="PyTorch adversarial attack evaluation") 19 | 20 | parser.add_argument("model", type=str, help="Model name") 21 | parser.add_argument("dataset", type=str, help="Dataset name") 22 | parser.add_argument("-b", "--batch-size", type=int, default=32) 23 | parser.add_argument( 24 | "--count", "-c", type=int, default=100, help="Minimum number of images to input" 25 | ) 26 | parser.add_argument( 27 | "--errors-only", 28 | "-e", 29 | action="store_true", 30 | help="Only create visualizatons for wrong predictions", 31 | ) 32 | 33 | cifar10_classes = ( 34 | "plane", 35 | "car", 36 | "bird", 37 | "cat", 38 | "deer", 39 | "dog", 40 | "frog", 41 | "horse", 42 | "ship", 43 | "truck", 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | args = parser.parse_args() 49 | root = os.path.dirname(os.path.abspath(__file__)) 50 | 51 | if args.dataset == "cifar10": 52 | _, test_loader = Cifar10DataLoaderFactory.create_train_loaders( 53 | batch_size=args.batch_size 54 | ) 55 | classes = 10 56 | original_test_data = torchvision.datasets.CIFAR10( 57 | root=config["DATASETS"]["Cifar10"], 58 | train=False, 59 | download=False, 60 | transform=ToTensor(), 61 | ) 62 | elif args.dataset == "cifar100": 63 | _, test_loader = Cifar100DataLoaderFactory.create_train_loaders( 64 | batch_size=args.batch_size 65 | ) 66 | classes = 100 67 | original_test_data = torchvision.datasets.CIFAR10( 68 | root=config["DATASETS"]["Cifar100"], 69 | train=False, 70 | download=False, 71 | transform=ToTensor(), 72 | ) 73 | 74 | original_test_loader = DataLoader( 75 | dataset=original_test_data, batch_size=args.batch_size, shuffle=False 76 | ) 77 | 78 | def create_model_from_name_and_prefix(model_name: str, prefix: str): 79 | model_name = f"{prefix}-{model_name}" 80 | model = parse_model_from_name(model_name=model_name, classes=classes).cuda() 81 | weights_path = os.path.join( 82 | root, "weights", args.dataset, model_name, f"{model_name}_weights.pth" 83 | ) 84 | state_dict = torch.load(weights_path) 85 | model.load_state_dict(state_dict) 86 | model.eval() 87 | target_layers = [model.group3] 88 | return model, target_layers 89 | 90 | hyp_model, hyp_target_layers = create_model_from_name_and_prefix( 91 | model_name=args.model, prefix="hyperbolic" 92 | ) 93 | euc_model, euc_target_layers = create_model_from_name_and_prefix( 94 | model_name=args.model, 
prefix="euclidean" 95 | ) 96 | 97 | exp_dir = os.path.join( 98 | root, 99 | "gradcam", 100 | args.dataset, 101 | f"{'ERRORS-' if args.errors_only else ''}{args.model}", 102 | ) 103 | os.makedirs(exp_dir, exist_ok=True) 104 | 105 | hyp_cam = GradCAM(model=hyp_model, target_layers=hyp_target_layers, use_cuda=True) 106 | euc_cam = GradCAM(model=euc_model, target_layers=euc_target_layers, use_cuda=True) 107 | 108 | total = 0 109 | hyp_total_correct = 0 110 | euc_total_correct = 0 111 | 112 | for batch_id, ((batch, targets), (originals, _)) in enumerate( 113 | zip(test_loader, original_test_loader) 114 | ): 115 | with torch.no_grad(): 116 | hyp_logits = hyp_model(batch.cuda()) 117 | hyp_preds = torch.argmax(hyp_logits, dim=-1) 118 | hyp_correct = hyp_preds == targets.cuda() 119 | 120 | euc_logits = euc_model(batch.cuda()) 121 | euc_preds = torch.argmax(euc_logits, dim=-1) 122 | euc_correct = euc_preds == targets.cuda() 123 | 124 | total += batch.size(0) 125 | hyp_total_correct += hyp_correct.sum() 126 | euc_total_correct += euc_correct.sum() 127 | 128 | hyp_grayscale_cam = hyp_cam(input_tensor=batch, targets=None, aug_smooth=True) 129 | euc_grayscale_cam = euc_cam(input_tensor=batch, targets=None, aug_smooth=True) 130 | 131 | for im_id in range(args.batch_size): 132 | if args.errors_only and hyp_correct[im_id] and euc_correct[im_id]: 133 | continue 134 | output_img = Image.new("RGB", (32 * 3, 32)) 135 | original_img = originals[im_id, :] 136 | 137 | hyp_visualization = show_cam_on_image( 138 | original_img.movedim([0], [2]).numpy(), 139 | hyp_grayscale_cam[im_id, :], 140 | use_rgb=True, 141 | ) 142 | hyp_grad_img = Image.fromarray(hyp_visualization, "RGB") 143 | 144 | euc_visualization = show_cam_on_image( 145 | original_img.movedim([0], [2]).numpy(), 146 | euc_grayscale_cam[im_id, :], 147 | use_rgb=True, 148 | ) 149 | euc_grad_img = Image.fromarray(euc_visualization, "RGB") 150 | 151 | original_img = to_pil_image(original_img) 152 | 153 | output_img.paste(original_img, (0, 0)) 154 | output_img.paste(euc_grad_img, (32, 0)) 155 | output_img.paste(hyp_grad_img, (64, 0)) 156 | 157 | output_img.save( 158 | os.path.join( 159 | exp_dir, 160 | f"{batch_id * args.batch_size + im_id}_euc_vs_hyp_gradcam" 161 | f"_{cifar10_classes[targets[im_id]]}_{cifar10_classes[euc_preds[im_id]]}" 162 | f"_{cifar10_classes[hyp_preds[im_id]]}.jpeg", 163 | ) 164 | ) 165 | 166 | if (batch_id + 1) * args.batch_size >= args.count: 167 | break 168 | 169 | print("Hyp accuracy:", hyp_total_correct / total) 170 | print("Euc accuracy:", euc_total_correct / total) 171 | -------------------------------------------------------------------------------- /ood_detection.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as trn 9 | 10 | import ood_utils.svhn_loader as svhn 11 | from cifar10.dataloader import Cifar10DataLoaderFactory 12 | from cifar100.dataloader import Cifar100DataLoaderFactory 13 | from config import config 14 | from models.resnets import parse_model_from_name 15 | from ood_utils.display_results import ( 16 | get_measures, 17 | print_measures, 18 | print_measures_with_std, 19 | show_performance, 20 | ) 21 | 22 | parser = argparse.ArgumentParser( 23 | description="Evaluates a CIFAR OOD Detector", 24 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 25 | ) 26 | parser.add_argument("--model", 
type=str, default="allconv", help="Choose architecture.") 27 | parser.add_argument("--dataset", type=str, default="cifar10", help="Choose dataset.") 28 | parser.add_argument("--batch_size", type=int, default=200) 29 | parser.add_argument( 30 | "--num_to_avg", type=int, default=1, help="Average measures across num_to_avg runs." 31 | ) 32 | parser.add_argument("--ngpu", type=int, default=1, help="0 = CPU.") 33 | parser.add_argument("--prefetch", type=int, default=2, help="Pre-fetching threads.") 34 | parser.add_argument("--T", default=1.0, type=float, help="temperature") 35 | args = parser.parse_args() 36 | print(args) 37 | 38 | mean = [0.5, 0.5, 0.5] 39 | std = [0.5, 0.5, 0.5] 40 | 41 | test_transform = trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)]) 42 | 43 | if args.dataset == "cifar10": 44 | _, test_loader = Cifar10DataLoaderFactory.create_train_loaders( 45 | batch_size=args.batch_size 46 | ) 47 | classes = 10 48 | elif args.dataset == "cifar100": 49 | _, test_loader = Cifar100DataLoaderFactory.create_train_loaders( 50 | batch_size=args.batch_size 51 | ) 52 | classes = 100 53 | 54 | model = parse_model_from_name(model_name=args.model, classes=classes).cuda() 55 | weights_path = os.path.join( 56 | os.path.dirname(os.path.abspath(__file__)), 57 | "weights", 58 | args.dataset, 59 | args.model, 60 | f"{args.model}_weights.pth", 61 | ) 62 | state_dict = torch.load(weights_path) 63 | model.load_state_dict(state_dict) 64 | model.eval() 65 | 66 | # Detection Prelims 67 | ood_num_examples = len(test_loader) * args.batch_size // 5 68 | expected_ap = ood_num_examples / (ood_num_examples + len(test_loader)) 69 | 70 | concat = lambda x: np.concatenate(x, axis=0) 71 | to_np = lambda x: x.data.cpu().numpy() 72 | 73 | 74 | def get_ood_scores(loader, in_dist=False): 75 | _score = [] 76 | _right_score = [] 77 | _wrong_score = [] 78 | 79 | with torch.no_grad(): 80 | for batch_idx, (data, target) in enumerate(loader): 81 | if batch_idx >= ood_num_examples // args.batch_size and in_dist is False: 82 | break 83 | 84 | data = data.cuda() 85 | 86 | output = model(data) 87 | smax = to_np(F.softmax(output, dim=1)) 88 | 89 | _score.append(-to_np((args.T * torch.logsumexp(output / args.T, dim=1)))) 90 | 91 | if in_dist: 92 | preds = np.argmax(smax, axis=1) 93 | targets = target.numpy().squeeze() 94 | right_indices = preds == targets 95 | wrong_indices = np.invert(right_indices) 96 | 97 | _right_score.append(-np.max(smax[right_indices], axis=1)) 98 | _wrong_score.append(-np.max(smax[wrong_indices], axis=1)) 99 | 100 | if in_dist: 101 | return ( 102 | concat(_score).copy(), 103 | concat(_right_score).copy(), 104 | concat(_wrong_score).copy(), 105 | ) 106 | else: 107 | return concat(_score)[:ood_num_examples].copy() 108 | 109 | 110 | in_score, right_score, wrong_score = get_ood_scores(test_loader, in_dist=True) 111 | 112 | num_right = len(right_score) 113 | num_wrong = len(wrong_score) 114 | print("Error Rate {:.2f}".format(100 * num_wrong / (num_wrong + num_right))) 115 | # End Detection Prelims 116 | 117 | print("\nUsing CIFAR-10 as typical data") if classes == 10 else print( 118 | "\nUsing CIFAR-100 as typical data" 119 | ) 120 | 121 | # Error Detection 122 | print("\n\nError Detection") 123 | show_performance(wrong_score, right_score, method_name=args.model) 124 | 125 | # OOD Detection 126 | auroc_list, aupr_list, fpr_list = [], [], [] 127 | 128 | 129 | def get_and_print_results(ood_loader, num_to_avg=args.num_to_avg): 130 | aurocs, auprs, fprs = [], [], [] 131 | 132 | for _ in range(num_to_avg): 133 | 
out_score = get_ood_scores(ood_loader)
134 |         measures = get_measures(-in_score, -out_score)
135 |         aurocs.append(measures[0])
136 |         auprs.append(measures[1])
137 |         fprs.append(measures[2])
138 | 
139 |     print(in_score[:3], out_score[:3])
140 |     auroc = np.mean(aurocs)
141 |     aupr = np.mean(auprs)
142 |     fpr = np.mean(fprs)
143 |     auroc_list.append(auroc)
144 |     aupr_list.append(aupr)
145 |     fpr_list.append(fpr)
146 | 
147 |     if num_to_avg >= 5:
148 |         print_measures_with_std(aurocs, auprs, fprs, args.model)
149 |     else:
150 |         print_measures(auroc, aupr, fpr, args.model)
151 | 
152 | 
153 | # Textures
154 | ood_data = dset.ImageFolder(
155 |     root=config["DATASETS"]["Textures"],
156 |     transform=trn.Compose(
157 |         [trn.Resize(32), trn.CenterCrop(32), trn.ToTensor(), trn.Normalize(mean, std)]
158 |     ),
159 | )
160 | ood_loader = torch.utils.data.DataLoader(
161 |     ood_data, batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True
162 | )
163 | print("\n\nTexture Detection")
164 | get_and_print_results(ood_loader)
165 | 
166 | # SVHN
167 | ood_data = svhn.SVHN(
168 |     root=config["DATASETS"]["SVHN"],
169 |     split="test",
170 |     transform=trn.Compose([trn.ToTensor(), trn.Normalize(mean, std)]),
171 |     download=False,
172 | )
173 | ood_loader = torch.utils.data.DataLoader(
174 |     ood_data, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True
175 | )
176 | print("\n\nSVHN Detection")
177 | get_and_print_results(ood_loader)
178 | 
179 | # Places365
180 | ood_data = dset.ImageFolder(
181 |     root=config["DATASETS"]["Places365"],
182 |     transform=trn.Compose(
183 |         [trn.Resize(32), trn.CenterCrop(32), trn.ToTensor(), trn.Normalize(mean, std)]
184 |     ),
185 | )
186 | ood_loader = torch.utils.data.DataLoader(
187 |     ood_data, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True
188 | )
189 | print("\n\nPlaces365 Detection")
190 | get_and_print_results(ood_loader)
191 | 
192 | # Mean Results
193 | print("\n\nMean Test Results!!!!!")
194 | print_measures(
195 |     np.mean(auroc_list), np.mean(aupr_list), np.mean(fpr_list), method_name=args.model
196 | )
197 | 
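Editor's note: the score collected by get_ood_scores above is the temperature-scaled energy, -T * logsumexp(logits / T), stored negated so that larger values mean "more OOD". A small self-contained sketch (illustrative, with made-up logits) showing that flat, low-confidence logits receive the higher OOD score:

import torch

T = 1.0
logits_in = torch.tensor([[8.0, 0.5, 0.2]])   # confident, in-distribution-like
logits_out = torch.tensor([[0.4, 0.3, 0.5]])  # flat, OOD-like

score_in = -T * torch.logsumexp(logits_in / T, dim=1)
score_out = -T * torch.logsumexp(logits_out / T, dim=1)
print(score_in.item(), score_out.item())  # score_out > score_in: flagged as OOD

--------------------------------------------------------------------------------
/models/manifolds/poincare_disk.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from .math.diffgeom import *
7 | from .math.diffgeom_autograd import *
8 | from .math.linreg import *
9 | from .math.variance import frechet_variance
10 | 
11 | 
12 | class PoincareBallStdGrad(nn.Module):
13 |     """
14 |     Class representing the Poincare ball model of hyperbolic space.
15 | 
16 |     Implementation based on the geoopt implementation,
17 |     but changed to use hyperbolic torch functions.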
18 | """ 19 | 20 | def __init__(self, c=1.0, learnable=True): 21 | super().__init__() 22 | c = torch.as_tensor(c, dtype=torch.float32) 23 | self.isp_c = nn.Parameter(c, requires_grad=learnable) 24 | self.learnable = learnable 25 | 26 | @property 27 | def c(self): 28 | return nn.functional.softplus(self.isp_c) 29 | 30 | def mobius_add(self, x: torch.Tensor, y: torch.Tensor, dim: int = -1): 31 | return dg_mobius_add(x=x, y=y, c=self.c, dim=dim) 32 | 33 | def project(self, x: torch.Tensor, dim: int = -1, eps: float = -1.0): 34 | return dg_project(x=x, c=self.c, dim=dim, eps=eps) 35 | 36 | def expmap0(self, v: torch.Tensor, dim: int = -1): 37 | return dg_expmap0(v=v, c=self.c, dim=dim) 38 | 39 | def logmap0(self, y: torch.Tensor, dim: int = -1): 40 | return dg_logmap0(y=y, c=self.c, dim=dim) 41 | 42 | def expmap(self, x: torch.Tensor, v: torch.Tensor, dim: int = -1): 43 | return dg_expmap(x=x, v=v, c=self.c, dim=dim) 44 | 45 | def logmap(self, x: torch.Tensor, y: torch.Tensor, dim: int = -1): 46 | return dg_logmap(x=x, y=y, c=self.c, dim=dim) 47 | 48 | def gyration( 49 | self, 50 | u: torch.Tensor, 51 | v: torch.Tensor, 52 | w: torch.Tensor, 53 | dim: int = -1, 54 | ): 55 | return dg_gyration(u=u, v=v, w=w, c=self.c, dim=dim) 56 | 57 | def transp( 58 | self, 59 | x: torch.Tensor, 60 | y: torch.Tensor, 61 | v: torch.Tensor, 62 | dim: int = -1, 63 | ): 64 | return dg_transp(x=x, y=y, v=v, c=self.c, dim=dim) 65 | 66 | def dist( 67 | self, 68 | x: torch.Tensor, 69 | y: torch.Tensor, 70 | dim: int = -1, 71 | ) -> torch.Tensor: 72 | return dg_dist(x=x, y=y, c=self.c, dim=dim) 73 | 74 | def mlr( 75 | self, 76 | x: torch.Tensor, 77 | z: torch.Tensor, 78 | r: torch.Tensor, 79 | ) -> torch.Tensor: 80 | return poincare_mlr(x=x, z=z, r=r, c=self.c) 81 | 82 | def fully_connected( 83 | self, 84 | x: torch.Tensor, 85 | z: torch.Tensor, 86 | bias: torch.Tensor, 87 | ) -> torch.Tensor: 88 | y = poincare_fully_connected(x=x, z=z, bias=bias, c=self.c) 89 | return self.project(y, dim=-1) 90 | 91 | def frechet_variance( 92 | self, 93 | x: torch.Tensor, 94 | mu: torch.Tensor, 95 | dim: int = -1, 96 | w: Optional[torch.Tensor] = None, 97 | ) -> torch.Tensor: 98 | return frechet_variance( 99 | x=x, mu=mu, c=self.c, dim=dim, w=w, custom_autograd=False 100 | ) 101 | 102 | 103 | class PoincareBallCustomAutograd(nn.Module): 104 | """ 105 | Class representing the Poincare ball model of hyperbolic space. 106 | 107 | Implementation based on the geoopt implementation, 108 | but changed to use custom autograd functions. 
109 | """ 110 | 111 | def __init__(self, c=1.0, learnable=True): 112 | super().__init__() 113 | c = torch.as_tensor(c, dtype=torch.float32) 114 | self.isp_c = nn.Parameter(c, requires_grad=learnable) 115 | self.learnable = learnable 116 | 117 | @property 118 | def c(self) -> torch.Tensor: 119 | if self.learnable: 120 | return nn.functional.softplus(self.isp_c) 121 | else: 122 | return self.isp_c 123 | 124 | def mobius_add( 125 | self, x: torch.Tensor, y: torch.Tensor, dim: int = -1 126 | ) -> torch.Tensor: 127 | return ag_MobiusAddition.apply(x, y, self.c, dim) 128 | 129 | def project( 130 | self, x: torch.Tensor, dim: int = -1, eps: float = -1.0 131 | ) -> torch.Tensor: 132 | return ag_Project.apply(x, self.c, dim) 133 | 134 | def expmap0(self, v: torch.Tensor, dim: int = -1) -> torch.Tensor: 135 | return ag_expmap0(v, self.c, dim) 136 | 137 | def logmap0(self, y: torch.Tensor, dim: int = -1) -> torch.Tensor: 138 | return ag_LogMap0.apply(y, self.c, dim) 139 | 140 | def expmap(self, x: torch.Tensor, v: torch.Tensor, dim: int = -1) -> torch.Tensor: 141 | return ag_expmap(x, v, self.c, dim) 142 | 143 | def logmap(self, x: torch.Tensor, y: torch.Tensor, dim: int = -1) -> torch.Tensor: 144 | return ag_logmap(x, y, self.c, dim) 145 | 146 | def gyration( 147 | self, 148 | u: torch.Tensor, 149 | v: torch.Tensor, 150 | w: torch.Tensor, 151 | dim: int = -1, 152 | ) -> torch.Tensor: 153 | return ag_gyration(u, v, w, self.c, dim) 154 | 155 | def transp( 156 | self, 157 | x: torch.Tensor, 158 | y: torch.Tensor, 159 | v: torch.Tensor, 160 | dim: int = -1, 161 | ) -> torch.Tensor: 162 | return ag_transp(x, y, v, self.c, dim) 163 | 164 | def dist( 165 | self, 166 | x: torch.Tensor, 167 | y: torch.Tensor, 168 | dim: int = -1, 169 | ) -> torch.Tensor: 170 | return ( 171 | 2 172 | / self.c.sqrt() 173 | * ( 174 | self.c.sqrt() 175 | * self.mobius_add(-x, y, dim=dim).norm(dim=dim, keepdim=True) 176 | ).atanh() 177 | ) 178 | 179 | def mlr( 180 | self, 181 | x: torch.Tensor, 182 | z: torch.Tensor, 183 | r: torch.Tensor, 184 | ) -> torch.Tensor: 185 | return poincare_mlr(x=x, z=z, r=r, c=self.c) 186 | 187 | def fully_connected( 188 | self, 189 | x: torch.Tensor, 190 | z: torch.Tensor, 191 | bias: torch.Tensor, 192 | ) -> torch.Tensor: 193 | y = poincare_fully_connected(x=x, z=z, bias=bias, c=self.c) 194 | return self.project(y, dim=-1) 195 | 196 | def frechet_variance( 197 | self, 198 | x: torch.Tensor, 199 | mu: torch.Tensor, 200 | dim: int = -1, 201 | w: Optional[torch.Tensor] = None, 202 | ) -> torch.Tensor: 203 | return frechet_variance( 204 | x=x, mu=mu, c=self.c, dim=dim, w=w, custom_autograd=True 205 | ) 206 | 207 | 208 | PoincareBall = PoincareBallStdGrad | PoincareBallCustomAutograd 209 | 210 | 211 | def poincareball_factory( 212 | c: float = 1.0, custom_autograd: bool = True, learnable: bool = True 213 | ) -> PoincareBall: 214 | if custom_autograd: 215 | return PoincareBallCustomAutograd(c=c, learnable=learnable) 216 | else: 217 | return PoincareBallStdGrad(c=c, learnable=learnable) 218 | -------------------------------------------------------------------------------- /models/manifolds/math/frechet_mean.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | from geoopt.manifolds import PoincareBallExact 6 | from geoopt.manifolds.lorentz.math import arcosh 7 | from geoopt.manifolds.stereographic.math import dist 8 | 9 | _TOLEPS = {torch.float32: 1e-6, torch.float64: 1e-12} 10 | 11 | 12 | 
class FrechetMean(torch.autograd.Function): 13 | """ 14 | This implementation is copied mostly from: 15 | https://github.com/CUVL/Differentiable-Frechet-Mean.git 16 | 17 | which is itself based on the paper: 18 | https://arxiv.org/abs/2003.00335 19 | 20 | Both by Aaron Lou (et al.) 21 | """ 22 | 23 | @staticmethod 24 | def forward(ctx, x, w, K): 25 | mean = frechet_ball_forward( 26 | x, w, K, rtol=_TOLEPS[x.dtype], atol=_TOLEPS[x.dtype] 27 | ) 28 | ctx.save_for_backward(x, mean, w, K) 29 | return mean 30 | 31 | @staticmethod 32 | def backward(ctx, grad_output): 33 | X, mean, w, K = ctx.saved_tensors 34 | dx, dw, dK = frechet_ball_backward(X, mean, grad_output, w, K) 35 | return dx, dw, dK, None 36 | 37 | 38 | def frechet_mean( 39 | x: torch.Tensor, 40 | ball: torch.Tensor, 41 | w: Optional[torch.Tensor] = None, 42 | ) -> torch.Tensor: 43 | if w is None: 44 | w = torch.ones(x.shape[:-1]).to(x) 45 | return FrechetMean.apply(x, w, -ball.c) 46 | 47 | 48 | def frechet_variance( 49 | x: torch.Tensor, 50 | mu: torch.Tensor, 51 | ball: PoincareBallExact, 52 | dim: int = -1, 53 | w: Optional[torch.Tensor] = None, 54 | ) -> torch.Tensor: 55 | """ 56 | Args 57 | ---- 58 | x (tensor): points of shape [..., points, dim] 59 | mu (tensor): mean of shape [..., dim] 60 | w (tensor): weights of shape [..., points] 61 | 62 | where the ... of the three variables line up 63 | 64 | Returns 65 | ------- 66 | tensor of shape [...] 67 | """ 68 | distance: torch.Tensor = dist( 69 | x=x, 70 | y=mu, 71 | k=-ball.c, 72 | ) 73 | distance = distance.pow(2) 74 | 75 | if w is None: 76 | return distance.mean(dim=-1) 77 | else: 78 | return (distance * w).sum(dim=-1) 79 | 80 | 81 | def l_prime(y: torch.Tensor) -> torch.Tensor: 82 | cond = y < 1e-12 83 | val = 4 * torch.ones_like(y) 84 | ret = torch.where(cond, val, 2 * arcosh(1 + 2 * y) / (y.pow(2) + y).sqrt()) 85 | return ret 86 | 87 | 88 | def frechet_ball_forward( 89 | X: torch.Tensor, 90 | w: torch.Tensor, 91 | K: torch.Tensor = torch.Tensor([-1]), 92 | max_iter: int = 1000, 93 | rtol: float = 1e-6, 94 | atol: float = 1e-6, 95 | ) -> torch.Tensor: 96 | """ 97 | Args 98 | ---- 99 | X (tensor): point of shape [..., points, dim] 100 | w (tensor): weights of shape [..., points] 101 | K (float): curvature (must be negative) 102 | 103 | Returns 104 | ------- 105 | frechet mean (tensor): shape [..., dim] 106 | """ 107 | mu = X[..., 0, :].clone() 108 | 109 | x_ss = X.pow(2).sum(dim=-1) 110 | 111 | mu_prev = mu 112 | iters = 0 113 | for _ in range(max_iter): 114 | mu_ss = mu.pow(2).sum(dim=-1) 115 | xmu_ss = (X - mu.unsqueeze(-2)).pow(2).sum(dim=-1) 116 | 117 | alphas = l_prime( 118 | -K * xmu_ss / ((1 + K * x_ss) * (1 + K * mu_ss.unsqueeze(-1))) 119 | ) / (1 + K * x_ss) 120 | 121 | alphas = alphas * w 122 | 123 | c = (alphas * x_ss).sum(dim=-1) 124 | b = (alphas.unsqueeze(-1) * X).sum(dim=-2) 125 | a = alphas.sum(dim=-1) 126 | 127 | b_ss = b.pow(2).sum(dim=-1) 128 | 129 | eta = (a - K * c - ((a - K * c).pow(2) + 4 * K * b_ss).sqrt()) / ( 130 | 2 * (-K) * b_ss 131 | ) 132 | 133 | mu = eta.unsqueeze(-1) * b 134 | 135 | dist = (mu - mu_prev).norm(dim=-1) 136 | prev_dist = mu_prev.norm(dim=-1) 137 | if (dist < atol).all() or (dist / prev_dist < rtol).all(): 138 | break 139 | 140 | mu_prev = mu 141 | iters += 1 142 | 143 | return mu 144 | 145 | 146 | def darcosh(x): 147 | cond = x < 1 + 1e-7 148 | x = torch.where(cond, 2 * torch.ones_like(x), x) 149 | x = torch.where(~cond, 2 * arcosh(x) / torch.sqrt(x**2 - 1), x) 150 | return x 151 | 152 | 153 | def d2arcosh(x): 154 | cond = x < 1 + 
1e-7 155 | x = torch.where(cond, -2 / 3 * torch.ones_like(x), x) 156 | x = torch.where( 157 | ~cond, 2 / (x**2 - 1) - 2 * x * arcosh(x) / ((x**2 - 1) ** (3 / 2)), x 158 | ) 159 | return x 160 | 161 | 162 | def grad_var( 163 | X: torch.Tensor, 164 | y: torch.Tensor, 165 | w: torch.Tensor, 166 | K: torch.Tensor, 167 | ) -> torch.Tensor: 168 | """ 169 | Args 170 | ---- 171 | X (tensor): point of shape [..., points, dim] 172 | y (tensor): mean point of shape [..., dim] 173 | w (tensor): weight tensor of shape [..., points] 174 | K (float): curvature (must be negative) 175 | 176 | Returns 177 | ------- 178 | grad (tensor): gradient of variance [..., dim] 179 | """ 180 | yl = y.unsqueeze(-2) 181 | xnorm = 1 + K * X.norm(dim=-1).pow(2) 182 | ynorm = 1 + K * yl.norm(dim=-1).pow(2) 183 | xynorm = (X - yl).norm(dim=-1).pow(2) 184 | 185 | D = xnorm * ynorm 186 | v = 1 - 2 * K * xynorm / D 187 | 188 | Dl = D.unsqueeze(-1) 189 | vl = v.unsqueeze(-1) 190 | 191 | first_term = (X - yl) / Dl 192 | sec_term = K / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1) 193 | return -(4 * darcosh(vl) * w.unsqueeze(-1) * (first_term + sec_term)).sum(dim=-2) 194 | 195 | 196 | def inverse_hessian( 197 | X: torch.Tensor, 198 | y: torch.Tensor, 199 | w: torch.Tensor, 200 | K: torch.Tensor, 201 | ) -> torch.Tensor: 202 | """ 203 | Args 204 | ---- 205 | X (tensor): point of shape [..., points, dim] 206 | y (tensor): mean point of shape [..., dim] 207 | w (tensor): weight tensor of shape [..., points] 208 | K (float): curvature (must be negative) 209 | 210 | Returns 211 | ------- 212 | inv_hess (tensor): inverse hessian of [..., points, dim, dim] 213 | """ 214 | yl = y.unsqueeze(-2) 215 | xnorm = 1 + K * X.norm(dim=-1).pow(2) 216 | ynorm = 1 + K * yl.norm(dim=-1).pow(2) 217 | xynorm = (X - yl).norm(dim=-1).pow(2) 218 | 219 | D = xnorm * ynorm 220 | v = 1 - 2 * K * xynorm / D 221 | 222 | Dl = D.unsqueeze(-1) 223 | vl = v.unsqueeze(-1) 224 | vll = vl.unsqueeze(-1) 225 | 226 | """ 227 | \partial T/ \partial y 228 | """ 229 | first_const = -8 * (K**2) * xnorm / D.pow(2) 230 | matrix_val = (first_const.unsqueeze(-1) * yl).unsqueeze(-1) * (X - yl).unsqueeze(-2) 231 | first_term = matrix_val + matrix_val.transpose(-1, -2) 232 | 233 | sec_const = -16 * (K**3) * xnorm.pow(2) / D.pow(3) * xynorm 234 | sec_term = (sec_const.unsqueeze(-1) * yl).unsqueeze(-1) * yl.unsqueeze(-2) 235 | 236 | third_const = -4 * K / D + 4 * (K**2) * xnorm / D.pow(2) * xynorm 237 | third_term = third_const.reshape(*third_const.shape, 1, 1) * torch.eye( 238 | y.shape[-1] 239 | ).to(X).reshape((1,) * len(third_const.shape) + (y.shape[-1], y.shape[-1])) 240 | 241 | Ty = first_term + sec_term + third_term 242 | 243 | """ 244 | T 245 | """ 246 | 247 | first_term = K / Dl * (X - yl) 248 | sec_term = K.pow(2) / Dl.pow(2) * yl * xynorm.unsqueeze(-1) * xnorm.unsqueeze(-1) 249 | T = 4 * (first_term + sec_term) 250 | 251 | """ 252 | inverse of shape [..., points, dim, dim] 253 | """ 254 | first_term = d2arcosh(vll) * T.unsqueeze(-1) * T.unsqueeze(-2) 255 | sec_term = darcosh(vll) * Ty 256 | hessian = ((first_term + sec_term) * w.unsqueeze(-1).unsqueeze(-1)).sum(dim=-3) / -K 257 | inv_hess = torch.inverse(hessian) 258 | return inv_hess 259 | 260 | 261 | def frechet_ball_backward( 262 | X: torch.Tensor, 263 | y: torch.Tensor, 264 | grad: torch.Tensor, 265 | w: torch.Tensor, 266 | K: torch.Tensor, 267 | ) -> tuple[torch.Tensor]: 268 | """ 269 | Args 270 | ---- 271 | X (tensor): point of shape [..., points, dim] 272 | y (tensor): mean point of shape [..., dim] 
273 |     grad (tensor): gradient w.r.t. the mean, of shape [..., dim]
274 |     K (float): curvature (must be negative)
275 | 
276 |     Returns
277 |     -------
278 |     gradients (tensor, tensor, tensor):
279 |         gradient of X [..., points, dim], weights [..., points], curvature []
280 |     """
281 |     if not torch.is_tensor(K):
282 |         K = torch.tensor(K).to(X)
283 | 
284 |     with torch.no_grad():
285 |         inv_hess = inverse_hessian(X, y, w=w, K=K)
286 | 
287 |     with torch.enable_grad():
288 |         # clone variables
289 |         X = nn.Parameter(X.detach())
290 |         y = y.detach()
291 |         w = nn.Parameter(w.detach())
292 |         K = nn.Parameter(K)
293 | 
294 |         grad = (inv_hess @ grad.unsqueeze(-1)).squeeze()
295 |         gradf = grad_var(X, y, w, K)
296 |         dx, dw, dK = torch.autograd.grad(-gradf.squeeze(), (X, w, K), grad)
297 | 
298 |     return dx, dw, dK
299 | 
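Editor's note: an illustrative sketch of frechet_mean above (not repository code). Despite the tensor annotation in its signature, the ball argument is a manifold object exposing .c, e.g. one built by poincareball_factory:

import torch

from models.manifolds import poincareball_factory
from models.manifolds.math.frechet_mean import frechet_mean

ball = poincareball_factory(c=1.0, custom_autograd=True, learnable=False)
x = ball.expmap0(0.1 * torch.randn(8, 2))  # 8 points on a 2-dimensional ball

mu = frechet_mean(x, ball)        # uniform weights; expected shape [2]
w = torch.rand(8)                 # one weight per point
mu_w = frechet_mean(x, ball, w)   # weighted Frechet mean

--------------------------------------------------------------------------------
/models/resnets/hyperbolic.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | from ..manifolds import PoincareBall, poincareball_factory
7 | from ..nn import PoincareBatchNorm2d, PoincareConvolution2d, PoincareLinear
8 | 
9 | 
10 | def _conv3x3(
11 |     in_channels: int,
12 |     out_channels: int,
13 |     ball: PoincareBall,
14 |     stride: int = 1,
15 | ) -> PoincareConvolution2d:
16 |     return PoincareConvolution2d(
17 |         in_channels=in_channels,
18 |         out_channels=out_channels,
19 |         kernel_dims=(3, 3),
20 |         ball=ball,
21 |         bias=True,
22 |         stride=stride,
23 |         padding=1,
24 |     )
25 | 
26 | 
27 | class ResidualBlock(nn.Module):
28 |     def __init__(
29 |         self,
30 |         in_channels: int,
31 |         out_channels: int,
32 |         stride: int = 1,
33 |         act_layer: Optional[nn.Module] = None,
34 |         downsample: Optional[nn.Sequential] = None,
35 |         custom_autograd: bool = True,
36 |         learnable: bool = False,
37 |         init_c: float = 1,
38 |         skip_connection: str = "fr",
39 |         bn_midpoint: bool = True,
40 |     ) -> None:
41 |         super(ResidualBlock, self).__init__()
42 |         self.in_channels = in_channels
43 |         self.out_channels = out_channels
44 |         self.stride = stride
45 |         self.act_layer = act_layer
46 |         self.learnable = learnable
47 |         self.init_c = init_c
48 |         self.skip_connection = skip_connection
49 |         self.bn_midpoint = bn_midpoint
50 | 
51 |         self.balls = {
52 |             "ball1": poincareball_factory(
53 |                 c=init_c, custom_autograd=custom_autograd, learnable=learnable
54 |             ),
55 |             "ball2": poincareball_factory(
56 |                 c=init_c, custom_autograd=custom_autograd, learnable=learnable
57 |             ),
58 |             "bn_ball": poincareball_factory(
59 |                 c=init_c, custom_autograd=custom_autograd, learnable=learnable
60 |             ),
61 |             "skip_ball": poincareball_factory(
62 |                 c=init_c, custom_autograd=custom_autograd, learnable=learnable
63 |             ),
64 |         }
65 | 
66 |         self.conv1 = _conv3x3(
67 |             in_channels=in_channels,
68 |             out_channels=out_channels,
69 |             ball=self.balls["ball1"],
70 |             stride=stride,
71 |         )
72 |         self.bn1 = PoincareBatchNorm2d(
73 |             out_channels, ball=self.balls["bn_ball"], use_midpoint=self.bn_midpoint
74 |         )
75 |         if act_layer is not None:
76 |             self.act1 = act_layer()
77 |         self.conv2 = _conv3x3(
78 |             in_channels=out_channels,
79 |             out_channels=out_channels,
80 |             ball=self.balls["ball2"],
81 |             stride=1,
82 |         )
83 |         self.bn2 = PoincareBatchNorm2d(
84 |             out_channels, ball=self.balls["bn_ball"], use_midpoint=self.bn_midpoint
85 |         )
86 |         if self.act_layer is not None:
87 |             self.act2 = act_layer()
88 |         self.downsample = downsample
89 | 
90 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
91 |         residual = x
92 |         x = self.conv1(x)
93 |         x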
= self.bn1(x) 94 | if self.act_layer is not None: 95 | x = self.act1(x) 96 | x = self.conv2(x) 97 | x = self.bn2(x) 98 | 99 | if self.downsample is not None: 100 | residual = self.downsample(residual) 101 | 102 | # Skip connection with Mobius addition in Poincare ball (fr: f(x) + res, rf: res + f(x)). 103 | x = self.balls["skip_ball"].expmap0(x, dim=-1) 104 | residual = self.balls["skip_ball"].expmap0(residual, dim=-1) 105 | if self.skip_connection == "fr": 106 | x = self.balls["skip_ball"].mobius_add(x, residual) 107 | elif self.skip_connection == "rf": 108 | x = self.balls["skip_ball"].mobius_add(residual, x) 109 | x = self.balls["skip_ball"].logmap0(x, dim=-1) 110 | 111 | if self.act_layer is not None: 112 | x = self.act2(x) 113 | 114 | return x 115 | 116 | 117 | class HyperbolicResNet(nn.Module): 118 | """Hyperbolic Residual Networks 119 | 120 | Implementation of Residual Networks as described in: https://arxiv.org/pdf/1512.03385.pdf 121 | but with hyperbolic operations defined on the Poincare disk instead of Euclidean 122 | """ 123 | 124 | def __init__( 125 | self, 126 | classes: int, 127 | channel_dims: list[int], 128 | depths: list[int], 129 | act_layer: Optional[nn.Module] = nn.ReLU, 130 | init_c: float = 0.1, 131 | custom_autograd: bool = True, 132 | learnable: bool = False, 133 | skip_connection: str = "fr", 134 | bn_midpoint: bool = True, 135 | ) -> None: 136 | super(HyperbolicResNet, self).__init__() 137 | self.classes = classes 138 | self.channel_dims = channel_dims 139 | self.depths = depths 140 | self.act_layer = act_layer 141 | self.init_c = init_c 142 | self.custom_autograd = custom_autograd 143 | self.learnable = learnable 144 | self.skip_connection = skip_connection 145 | self.bn_midpoint = bn_midpoint 146 | 147 | self.conv_ball = poincareball_factory( 148 | c=init_c, custom_autograd=custom_autograd, learnable=learnable 149 | ) 150 | self.bn_ball = poincareball_factory( 151 | c=init_c, custom_autograd=custom_autograd, learnable=learnable 152 | ) 153 | self.linear_ball = poincareball_factory( 154 | c=init_c, custom_autograd=custom_autograd, learnable=learnable 155 | ) 156 | 157 | self.conv = _conv3x3( 158 | in_channels=3, 159 | out_channels=channel_dims[0], 160 | ball=self.conv_ball, 161 | ) 162 | self.bn = PoincareBatchNorm2d( 163 | channel_dims[0], ball=self.bn_ball, use_midpoint=self.bn_midpoint 164 | ) 165 | if act_layer is not None: 166 | self.act = act_layer() 167 | 168 | self.group1 = self._make_group( 169 | in_channels=channel_dims[0], 170 | out_channels=channel_dims[0], 171 | depth=depths[0], 172 | ) 173 | 174 | self.group2 = self._make_group( 175 | in_channels=channel_dims[0], 176 | out_channels=channel_dims[1], 177 | depth=depths[1], 178 | stride=2, 179 | ) 180 | 181 | self.group3 = self._make_group( 182 | in_channels=channel_dims[1], 183 | out_channels=channel_dims[2], 184 | depth=depths[2], 185 | stride=2, 186 | ) 187 | 188 | self.avg_pool = nn.AvgPool2d(8) 189 | 190 | self.fc = PoincareLinear( 191 | in_features=channel_dims[2], 192 | out_features=classes, 193 | ball=self.linear_ball, 194 | ) 195 | 196 | def forward(self, x: torch.Tensor) -> torch.Tensor: 197 | x = self.conv(x) 198 | x = self.bn(x) 199 | if self.act_layer is not None: 200 | x = self.act(x) 201 | x = self.group1(x) 202 | x = self.group2(x) 203 | x = self.group3(x) 204 | x = self.avg_pool(x) 205 | x = self.fc(x.squeeze()) 206 | return x 207 | 208 | def _make_group( 209 | self, 210 | in_channels: int, 211 | out_channels: int, 212 | depth: int, 213 | stride: int = 1, 214 | ) -> nn.Sequential: 215 
| downsample = None
216 |         if stride != 1:
217 |             downsample_ball = poincareball_factory(
218 |                 c=self.init_c,
219 |                 custom_autograd=self.custom_autograd,
220 |                 learnable=self.learnable,
221 |             )
222 |             downsample = PoincareConvolution2d(
223 |                 in_channels=in_channels,
224 |                 out_channels=out_channels,
225 |                 kernel_dims=(1, 1),
226 |                 ball=downsample_ball,
227 |                 bias=True,
228 |                 stride=stride,
229 |                 padding=0,
230 |             )
231 | 
232 |         layers = [
233 |             ResidualBlock(
234 |                 in_channels=in_channels,
235 |                 out_channels=out_channels,
236 |                 stride=stride,
237 |                 act_layer=self.act_layer,
238 |                 downsample=downsample,
239 |                 custom_autograd=self.custom_autograd,
240 |                 learnable=self.learnable,
241 |                 init_c=self.init_c,
242 |                 skip_connection=self.skip_connection,
243 |                 bn_midpoint=self.bn_midpoint,
244 |             )
245 |         ]
246 | 
247 |         for _ in range(1, depth):
248 |             layers.append(
249 |                 ResidualBlock(
250 |                     in_channels=out_channels,
251 |                     out_channels=out_channels,
252 |                     act_layer=self.act_layer,
253 |                     custom_autograd=self.custom_autograd,
254 |                     learnable=self.learnable,
255 |                     init_c=self.init_c,
256 |                     skip_connection=self.skip_connection,
257 |                     bn_midpoint=self.bn_midpoint,
258 |                 )
259 |             )
260 | 
261 |         return nn.Sequential(*layers)
262 | 
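Editor's note: an illustrative sketch (not repository code) instantiating HyperbolicResNet above on a CIFAR-shaped dummy batch. The channel_dims/depths follow the 8-16-32 ResNet-20 pattern (6n + 2 layers with n = 3), and the expected output shape assumes the Poincare modules preserve spatial geometry like their Euclidean counterparts:

import torch

from models.resnets.hyperbolic import HyperbolicResNet

model = HyperbolicResNet(
    classes=10,
    channel_dims=[8, 16, 32],
    depths=[3, 3, 3],  # three groups of depth 3 -> a 20-layer network
)
logits = model(torch.randn(2, 3, 32, 32))  # CIFAR-shaped dummy batch
print(logits.shape)                        # expected: torch.Size([2, 10])

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |                                  Apache License
2 |                            Version 2.0, January 2004
3 |                         http://www.apache.org/licenses/
4 | 
5 |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 | 
7 |    1. Definitions.
8 | 
9 |       "License" shall mean the terms and conditions for use, reproduction,
10 |       and distribution as defined by Sections 1 through 9 of this document.
11 | 
12 |       "Licensor" shall mean the copyright owner or entity authorized by
13 |       the copyright owner that is granting the License.
14 | 
15 |       "Legal Entity" shall mean the union of the acting entity and all
16 |       other entities that control, are controlled by, or are under common
17 |       control with that entity. For the purposes of this definition,
18 |       "control" means (i) the power, direct or indirect, to cause the
19 |       direction or management of such entity, whether by contract or
20 |       otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 |       outstanding shares, or (iii) beneficial ownership of such entity.
22 | 
23 |       "You" (or "Your") shall mean an individual or Legal Entity
24 |       exercising permissions granted by this License.
25 | 
26 |       "Source" form shall mean the preferred form for making modifications,
27 |       including but not limited to software source code, documentation
28 |       source, and configuration files.
29 | 
30 |       "Object" form shall mean any form resulting from mechanical
31 |       transformation or translation of a Source form, including but
32 |       not limited to compiled object code, generated documentation,
33 |       and conversions to other media types.
34 | 
35 |       "Work" shall mean the work of authorship, whether in Source or
36 |       Object form, made available under the License, as indicated by a
37 |       copyright notice that is included in or attached to the work
38 |       (an example is provided in the Appendix below).
39 | 
40 |       "Derivative Works" shall mean any work, whether in Source or Object
41 |       form, that is based on (or derived from) the Work and for which the
42 |       editorial revisions, annotations, elaborations, or other modifications
43 |       represent, as a whole, an original work of authorship.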
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 
--------------------------------------------------------------------------------
/models/manifolds/math/diffgeom_autograd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | class ag_MobiusAddition(torch.autograd.Function):
5 |     @staticmethod
6 |     def forward(ctx, x, y, c, dim=-1):
7 |         x2 = x.pow(2).sum(dim=dim, keepdim=True)
8 |         y2 = y.pow(2).sum(dim=dim, keepdim=True)
9 |         xy = (x * y).sum(dim=dim, keepdim=True)
10 |         a = 1 + 2 * c * xy + c * y2
11 |         b = 1 - c * x2
12 |         denom = (1 + 2 * c * xy + c**2 * x2 * y2).clamp_min(1e-15)
13 | 
14 |         ctx.save_for_backward(x, y, c, x2, y2, xy, a, b, denom)
15 |         ctx.dim = dim
16 | 
17 |         return (a * x + b * y) / denom
18 | 
19 |     @staticmethod
20 |     def backward(ctx, grad_output):
21 |         x, y, c, x2, y2, xy, a, b, denom = ctx.saved_tensors
22 |         dim = ctx.dim
23 | 
24 |         denom_pow = (1 / denom).pow(2)
25 | 
26 |         utx = (grad_output * x).sum(dim=dim, keepdim=True)
27 |         uty = (grad_output * y).sum(dim=dim, keepdim=True)
28 |         theta = a * utx + b * uty
29 |         k = 2 * c / denom
30 |         theta_frac = theta / denom
31 | 
32 |         x_grad = (
33 |             a / denom * grad_output
34 |             - k * (theta_frac * c * y2 + uty) * x
35 |             - k * (theta_frac - utx) * y
36 |         )
37 | 
38 |         y_grad = (
39 |             b / denom * grad_output
40 |             + k * (utx - theta_frac) * x
41 |             + k * (utx - c * x2 * theta_frac) * y
42 |         )
43 | 
44 |         c_grad = 1 / (
45 |             denom * ((2 * xy + y2) * utx - x2 * uty).clamp_min(1e-15)
46 |         ) - denom_pow * 2 * (xy + c * x2 * y2) * (a * utx + b * uty)
47 | 
48 |         if x_grad.isinf().any() or y_grad.isinf().any() or c_grad.isinf().any():
49 |             raise ValueError("Exploded gradient encountered")
50 | 
51 |         if x_grad.isnan().any() or y_grad.isnan().any() or c_grad.isnan().any():
52 |             raise ValueError("Exploded gradient encountered")
53 | 
54 |         return (
55 |             x_grad,
56 |             y_grad,
57 |             c_grad,
58 |             None,
59 |         )
60 | 
61 | 
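Editor's note: an illustrative sketch (not repository code) of one way to sanity-check a hand-written backward such as ag_MobiusAddition against PyTorch's numerical gradients; the curvature is held fixed here, so only the x and y gradients are checked:

import torch

x = (0.05 * torch.randn(3, 4, dtype=torch.float64)).requires_grad_()
y = (0.05 * torch.randn(3, 4, dtype=torch.float64)).requires_grad_()
c = torch.tensor(1.0, dtype=torch.float64)  # curvature held fixed

# Raises if the analytic x/y gradients above disagree with finite differences;
# gradcheck needs double precision, hence float64 points well inside the ball.
torch.autograd.gradcheck(lambda a, b: ag_MobiusAddition.apply(a, b, c, -1), (x, y))

62 | class ag_Project(torch.autograd.Function):
63 |     """
64 |     Autograd implementation of Poincare project function.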
65 | """ 66 | 67 | @staticmethod 68 | def forward(ctx, x, c, dim=-1): 69 | if x.dtype == torch.float32: 70 | eps = 4e-3 71 | else: 72 | eps = 1e-5 73 | maxnorm = (1 - eps) / ((c + 1e-15) ** 0.5) 74 | maxnorm = torch.where(c.gt(0), maxnorm, c.new_full((), 1e15)) 75 | norm = x.norm(dim=dim, keepdim=True, p=2).clamp_min(1e-15) 76 | cond = norm > maxnorm 77 | projected = x / norm * maxnorm 78 | 79 | ctx.save_for_backward(x, c, maxnorm, norm, cond) 80 | ctx.dim = dim 81 | 82 | return torch.where(cond, projected, x) 83 | 84 | @staticmethod 85 | def backward(ctx, grad_output): 86 | x, c, maxnorm, norm, cond = ctx.saved_tensors 87 | dim = ctx.dim 88 | 89 | utx = (grad_output * x).sum(dim=dim, keepdim=True) 90 | 91 | x_grad = ( 92 | torch.where( 93 | cond, maxnorm / norm, torch.as_tensor(1, dtype=x.dtype, device=x.device) 94 | ) 95 | * grad_output 96 | - cond * utx * maxnorm / (norm.pow(3)).clamp_min(1e-15) * x 97 | ) 98 | 99 | c_grad = -(cond * utx * maxnorm / (2 * (c + 1e-15) * norm)) 100 | 101 | if x_grad.isinf().any() or c_grad.isinf().any(): 102 | raise ValueError("Exploded gradient encountered") 103 | 104 | if x_grad.isnan().any() or c_grad.isnan().any(): 105 | raise ValueError("Exploded gradient encountered") 106 | 107 | if c_grad.abs().gt(1e5).any(): 108 | print(ctx.__class__.__name__, c_grad) 109 | 110 | return x_grad, c_grad, None 111 | 112 | 113 | class ag_ExpMap0(torch.autograd.Function): 114 | @staticmethod 115 | def forward(ctx, v, c, dim=-1): 116 | v_norm = v.norm(dim=dim, keepdim=True).clamp_min(1e-15) 117 | v_norm_c_sqrt = v_norm * c.sqrt() 118 | v_norm_c_sqrt_tanh = v_norm_c_sqrt.tanh() 119 | ctx.save_for_backward(v, c, v_norm, v_norm_c_sqrt, v_norm_c_sqrt_tanh) 120 | ctx.dim = dim 121 | return v_norm_c_sqrt_tanh * v / v_norm_c_sqrt 122 | 123 | @staticmethod 124 | def backward(ctx, grad_output): 125 | v, c, v_norm, v_norm_c_sqrt, v_norm_c_sqrt_tanh = ctx.saved_tensors 126 | dim = ctx.dim 127 | 128 | v_norm_c_sqrt_cosh2 = v_norm_c_sqrt.cosh().pow(2).clamp_min(1e-15) 129 | v_norm2 = v_norm.pow(2) 130 | 131 | utv = (grad_output * v).sum(dim=dim, keepdim=True) 132 | 133 | v_grad = ( 134 | 1 / (v_norm2 * v_norm_c_sqrt_cosh2).clamp_min(1e-15) 135 | - v_norm_c_sqrt_tanh / (v_norm_c_sqrt * v_norm2).clamp_min(1e-15) 136 | ) * utv * v + v_norm_c_sqrt_tanh / v_norm_c_sqrt * grad_output 137 | 138 | c_grad = ( 139 | utv 140 | / (2 * c) 141 | * (1 / v_norm_c_sqrt_cosh2 - v_norm_c_sqrt_tanh / v_norm_c_sqrt) 142 | ) 143 | 144 | if v_grad.isinf().any() or c_grad.isinf().any(): 145 | raise ValueError("Exploded gradient encountered") 146 | 147 | if v_grad.isnan().any() or c_grad.isnan().any(): 148 | raise ValueError("Exploded gradient encountered") 149 | 150 | if c_grad.abs().gt(1e5).any(): 151 | print(ctx.__class__.__name__, c_grad) 152 | 153 | return v_grad, c_grad, None 154 | 155 | 156 | class ag_LogMap0(torch.autograd.Function): 157 | @staticmethod 158 | def forward(ctx, y, c, dim=-1): 159 | y_norm = y.norm(dim=dim, keepdim=True) 160 | y_norm_c_sqrt = y_norm.clamp_min(1e-15) * c.sqrt() 161 | y_norm_c_sqrt_atanh = y_norm_c_sqrt.atanh() 162 | 163 | ctx.save_for_backward(y, c, y_norm, y_norm_c_sqrt, y_norm_c_sqrt_atanh) 164 | ctx.dim = dim 165 | 166 | return torch.atanh(y_norm_c_sqrt) * y / y_norm_c_sqrt 167 | 168 | @staticmethod 169 | def backward(ctx, grad_output): 170 | y, c, y_norm, y_norm_c_sqrt, y_norm_c_sqrt_atanh = ctx.saved_tensors 171 | dim = ctx.dim 172 | 173 | y_norm2 = y_norm.pow(2) 174 | 175 | uty = (grad_output * y).sum(dim=dim, keepdim=True) 176 | 177 | y_grad = ( 178 | 1 / (y_norm2 


class ag_LogMap0(torch.autograd.Function):
    """Autograd implementation of the logarithmic map at the origin."""

    @staticmethod
    def forward(ctx, y, c, dim=-1):
        y_norm = y.norm(dim=dim, keepdim=True)
        y_norm_c_sqrt = y_norm.clamp_min(1e-15) * c.sqrt()
        y_norm_c_sqrt_atanh = y_norm_c_sqrt.atanh()

        ctx.save_for_backward(y, c, y_norm, y_norm_c_sqrt, y_norm_c_sqrt_atanh)
        ctx.dim = dim

        return y_norm_c_sqrt_atanh * y / y_norm_c_sqrt

    @staticmethod
    def backward(ctx, grad_output):
        y, c, y_norm, y_norm_c_sqrt, y_norm_c_sqrt_atanh = ctx.saved_tensors
        dim = ctx.dim

        y_norm2 = y_norm.pow(2)

        uty = (grad_output * y).sum(dim=dim, keepdim=True)

        y_grad = (
            1 / (y_norm2 * (1 - c * y_norm2)).clamp_min(1e-15)
            - y_norm_c_sqrt_atanh / (y_norm_c_sqrt * y_norm2).clamp_min(1e-15)
        ) * uty * y + y_norm_c_sqrt_atanh / y_norm_c_sqrt * grad_output

        c_grad = (
            uty
            / (2 * c)
            * (
                1 / (1 - c * y_norm2).clamp_min(1e-15)
                - y_norm_c_sqrt_atanh / y_norm_c_sqrt.clamp_min(1e-15)
            )
        )

        if y_grad.isinf().any() or c_grad.isinf().any():
            raise ValueError("Exploded gradient encountered")

        if y_grad.isnan().any() or c_grad.isnan().any():
            print(c, y_norm)
            raise ValueError("NaN gradient encountered")

        if c_grad.abs().gt(1e5).any():
            print(ctx.__class__.__name__, c_grad)

        return y_grad, c_grad, None


class ag_ConfFactor(torch.autograd.Function):
    """
    Autograd implementation of the conformal factor
    lambda_x = 2 / (1 - c * ||x||^2).
    """

    @staticmethod
    def forward(ctx, x, c, dim=-1):
        x2 = x.pow(2).sum(dim=dim, keepdim=True)
        ctx.save_for_backward(x, c, x2)
        return 2 / (1 - c * x2).clamp_min(1e-15)

    @staticmethod
    def backward(ctx, grad_output):
        x, c, x2 = ctx.saved_tensors

        cond = c * x2 < 1

        x_grad = grad_output * cond * (4 * c / (1 - c * x2).pow(2).clamp_min(1e-15) * x)
        c_grad = grad_output * cond * (2 * x2 / (1 - c * x2).pow(2).clamp_min(1e-15))

        if x_grad.isinf().any() or c_grad.isinf().any():
            raise ValueError("Exploded gradient encountered")

        if x_grad.isnan().any() or c_grad.isnan().any():
            raise ValueError("NaN gradient encountered")

        if c_grad.abs().gt(1e5).any():
            print(ctx.__class__.__name__, c_grad)

        return x_grad, c_grad, None
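

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: the conformal factor
# lambda_x = 2 / (1 - c * ||x||^2) equals 2 at the origin and grows without
# bound as x approaches the boundary of the ball (||x|| -> 1 / sqrt(c)).
# The helper name is hypothetical.
def _conformal_factor_sketch():
    c = torch.tensor(1.0)
    origin = torch.zeros(1, 2)
    near_boundary = torch.tensor([[0.99, 0.0]])
    lam_origin = ag_ConfFactor.apply(origin, c)  # tensor([[2.]])
    lam_boundary = ag_ConfFactor.apply(near_boundary, c)  # roughly 100.5
    return lam_origin, lam_boundary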
239 | """ 240 | 241 | @staticmethod 242 | def forward(ctx, x, v, c, dim=-1): 243 | lambda_denom = 1 - c * x.pow(2).sum(dim=dim, keepdim=True) 244 | v_norm = v.norm(dim=dim, keepdim=True).clamp_min(1e-15) 245 | c_sqrt_v_norm = v_norm * c.sqrt() 246 | prod_of_terms = c_sqrt_v_norm / lambda_denom 247 | tanh_term = (prod_of_terms).tanh() 248 | 249 | ctx.save_for_backward( 250 | x, v, c, lambda_denom, v_norm, c_sqrt_v_norm, prod_of_terms, tanh_term 251 | ) 252 | ctx.dim = dim 253 | 254 | return tanh_term / c_sqrt_v_norm * v 255 | 256 | @staticmethod 257 | def backward(ctx, grad_output): 258 | ( 259 | x, 260 | v, 261 | c, 262 | lambda_denom, 263 | v_norm, 264 | c_sqrt_v_norm, 265 | prod_of_terms, 266 | tanh_term, 267 | ) = ctx.saved_tensors 268 | dim = ctx.dim 269 | 270 | prod_of_terms_cosh2 = prod_of_terms.cosh().pow(2).clamp_min(1e-15) 271 | 272 | utv = (grad_output * v).sum(dim=dim, keepdim=True) 273 | 274 | x_grad = ( 275 | 2 276 | * utv 277 | * c 278 | / (prod_of_terms_cosh2 * lambda_denom.pow(2)).clamp_min(1e-15) 279 | * x 280 | ) 281 | 282 | v_grad = ( 283 | 1 / (prod_of_terms_cosh2 * v_norm.pow(2) * lambda_denom).clamp_min(1e-15) 284 | - tanh_term / (c_sqrt_v_norm * v_norm.pow(2)).clamp_min(1e-15) 285 | ) * utv * v + tanh_term / c_sqrt_v_norm * grad_output 286 | 287 | c_grad = ( 288 | utv 289 | / (2 * c) 290 | * ( 291 | 1 292 | / prod_of_terms_cosh2 293 | * (2 - lambda_denom) 294 | / lambda_denom.pow(2).clamp_min(1e-15) 295 | - tanh_term / c_sqrt_v_norm 296 | ) 297 | ) 298 | 299 | if x_grad.isinf().any() or v_grad.isinf().any() or c_grad.isinf().any(): 300 | raise ValueError("Exploded gradient encountered") 301 | 302 | if x_grad.isnan().any() or v_grad.isnan().any() or c_grad.isnan().any(): 303 | raise ValueError("Exploded gradient encountered") 304 | 305 | if c_grad.abs().gt(1e5).any(): 306 | print(ctx.__class__.__name__, c_grad) 307 | 308 | return ( 309 | x_grad, 310 | v_grad, 311 | c_grad, 312 | None, 313 | ) 314 | 315 | 316 | class ag_LogScaledTerm(torch.autograd.Function): 317 | """ 318 | The scaled version of the Mobius addition that forms the output of the Logarithmic map. 
--------------------------------------------------------------------------------