├── .gitignore
├── neural_collapse
│   ├── __init__.py
│   ├── util.py
│   ├── kernels.py
│   ├── measure.py
│   └── accumulate.py
├── setup.py
├── LICENSE
├── examples
│   ├── mnist.py
│   └── layerwise_mnist.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | data
3 | wandb
4 | *egg*
--------------------------------------------------------------------------------
/neural_collapse/__init__.py:
--------------------------------------------------------------------------------
1 | from . import accumulate, kernels, measure, util
2 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
 1 | # setup.py
 2 | 
 3 | from setuptools import find_packages, setup
 4 | 
 5 | setup(
 6 |     name="neural_collapse",
 7 |     version="0.1",
 8 |     author="Robert Wu",
 9 |     author_email="rupert@cs.toronto.edu",
10 |     description="A generic library for accumulating feature statistics and computing neural collapse (NC) metrics.",
11 |     long_description=open("README.md").read(),
12 |     long_description_content_type="text/markdown",
13 |     url="https://github.com/rhubarbwu/neural-collapse",  # Link to your repository
14 |     packages=find_packages(),  # Automatically find the 'neural_collapse' package
15 |     install_requires=[
16 |         "numpy",
17 |         "scipy",
18 |         "torch",
19 |     ],
20 |     extras_require={
21 |         "faiss": ["faiss-gpu", "numpy<2"],
22 |     },
23 |     classifiers=[
24 |         "Programming Language :: Python :: 3",
25 |         "License :: OSI Approved :: MIT License",
26 |         "Operating System :: OS Independent",
27 |     ],
28 |     python_requires=">=3.6",
29 | )
30 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2024-2025 Robert Wu, Aditya Mehrotra
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/neural_collapse/util.py:
--------------------------------------------------------------------------------
 1 | from hashlib import sha256
 2 | 
 3 | import torch as pt
 4 | import torch.linalg as la
 5 | from torch import Tensor
 6 | 
 7 | hashify = lambda O: sha256(O.cpu().numpy().tobytes()).hexdigest()
 8 | resolve = lambda A, B: B if not A or A == B else None
 9 | normalize = lambda x: x / (la.norm(x, dim=-1, keepdim=True) + pt.finfo(x.dtype).eps)
10 | 
11 | 
12 | def tiling(data: Tensor, kernel: callable, tile_size: int = None) -> Tensor:
13 |     """Create a grid of kernel evaluations based on tiles of the input data.
14 | 
15 |     The function divides the input tensor `data` into contiguous
16 |     (non-overlapping) tiles of size `tile_size` and applies the specified
17 |     kernel function to each pair of tiles. The results are stored in a 2D
18 |     tensor of the kernel evaluations over all pairs of tiles.
19 | 
20 |     Args:
21 |         data (Tensor): Input tensor to be tiled.
22 |         kernel (callable): Function that takes two tiles as input and
23 |             returns a tensor representing their kernel evaluation.
24 |         tile_size (int, optional): Size of the tiles to be extracted from the
25 |             data. Set tile_size << K to avoid OOM. Defaults to None.
26 | 
27 |     Returns:
28 |         Tensor: Matrix whose (i, j) block contains the kernel evaluations
29 |             between the i-th and j-th tiles of the input.
30 |     """
31 |     N = len(data)
32 |     outgrid = pt.zeros((N, N), device=data.device)
33 |     if not tile_size:
34 |         tile_size = N
35 |     n_tiles = (N + tile_size - 1) // tile_size
36 | 
37 |     for i in range(n_tiles):
38 |         i0, i1 = i * tile_size, min((i + 1) * tile_size, N)
39 |         tile_i = data[i0:i1]
40 |         for j in range(n_tiles):
41 |             j0, j1 = j * tile_size, min((j + 1) * tile_size, N)
42 |             tile_j = data[j0:j1]
43 |             outgrid[i0:i1, j0:j1] = kernel(tile_i, tile_j)
44 | 
45 |     return outgrid
46 | 
47 | 
48 | def symm_reduce(data: Tensor, reduce: callable = pt.sum) -> Tensor:
49 |     """Compute a symmetric reduction of the upper triangle of a square tensor.
50 | 
51 |     This function computes a specified reduction over the upper triangle of a
52 |     square tensor `data`, ignoring the diagonal. It also handles the case of
53 |     an even-sized tensor by including the middle row in the reduction.
54 | 
55 |     Args:
56 |         data (Tensor): Square matrix from which to compute the reduction.
57 |         reduce (callable, optional): A callable function to apply for the
58 |             reduction. Defaults to `pt.sum`.
59 | 
60 |     Returns:
61 |         Tensor: Mean of the reduction applied to the upper triangle of the
62 |             symmetric tensor.
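    Example (illustrative; with the default `reduce=pt.sum` the result is the
    mean over the strict upper triangle, here {1, 2, 3}):
        >>> M = pt.tensor([[0.0, 1.0, 2.0], [1.0, 0.0, 3.0], [2.0, 3.0, 0.0]])
        >>> symm_reduce(M)
        tensor(2.)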
63 | """ 64 | N = data.shape[0] 65 | total = 0 66 | 67 | assert N == data.shape[1] 68 | for i in range((N - 1) // 2): 69 | upper = data[i][i + 1 :] 70 | lower = data[N - i - 2][N - i - 1 :] 71 | folded = pt.cat((upper, lower)) 72 | total += reduce(folded) 73 | if N % 2 == 0: 74 | row = data[N // 2 - 1][N // 2 :] 75 | total += reduce(row) 76 | 77 | return total / (N * (N - 1) / 2) 78 | -------------------------------------------------------------------------------- /examples/mnist.py: -------------------------------------------------------------------------------- 1 | import torch as pt 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import torchvision.models as models 5 | from neural_collapse.accumulate import (CovarAccumulator, DecAccumulator, 6 | MeanAccumulator, VarNormAccumulator) 7 | from neural_collapse.kernels import kernel_stats, log_kernel 8 | from neural_collapse.measure import (clf_ncc_agreement, covariance_ratio, 9 | orthogonality_deviation, 10 | self_duality_error, similarities, 11 | simplex_etf_error, variability_cdnv) 12 | from torch.utils.data import DataLoader 13 | from torchvision.datasets import MNIST, FashionMNIST 14 | from torchvision.transforms import Compose, Normalize, ToTensor 15 | 16 | # Device configuration 17 | device = pt.device("cuda" if pt.cuda.is_available() else "cpu") 18 | 19 | # Hyperparameters 20 | n_epochs = 200 21 | batch_size = 128 22 | lr, epochs_lr_decay, lr_decay = 0.0679, [n_epochs // 3, n_epochs * 2 // 3], 0.1 23 | momentum = 0.9 24 | weight_decay = 5e-4 25 | 26 | # MNIST dataset 27 | transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) 28 | train_dataset = MNIST("./data", True, transform, download=True) 29 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 30 | test_dataset = MNIST("./data", False, transform, download=True) 31 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True) 32 | # OoD dataset (Fashion MNIST), for NC5 33 | ood_dataset = FashionMNIST("./data", False, transform, download=True) 34 | ood_loader = DataLoader(dataset=ood_dataset, batch_size=batch_size, shuffle=True) 35 | 36 | 37 | # ResNet model 38 | model = models.resnet18(num_classes=10, weights=None).to(device) 39 | model.conv1 = nn.Conv2d(1, model.conv1.weight.shape[0], 3, 1, 1, bias=False) 40 | model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0) 41 | model.to(device) 42 | 43 | 44 | class Features: 45 | pass 46 | 47 | 48 | def hook(self, input, output): 49 | Features.value = input[0].clone() 50 | 51 | 52 | # register hook that saves last-layer input into features 53 | classifier = model.fc 54 | classifier.register_forward_hook(hook) 55 | 56 | # Loss and optimizer 57 | criterion = nn.CrossEntropyLoss() 58 | optimizer = optim.SGD(model.parameters(), lr, momentum, weight_decay=weight_decay) 59 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, epochs_lr_decay, lr_decay) 60 | 61 | try: 62 | import wandb 63 | 64 | wandb.init(project="neural-collapse") 65 | WANDB = True 66 | except: 67 | WANDB = False 68 | 69 | 70 | # Train the model 71 | total_step = len(train_loader) 72 | log_line = lambda epoch, i: f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{total_step}]" 73 | for epoch in range(n_epochs): 74 | model.train() 75 | for i, (images, labels) in enumerate(train_loader): 76 | images, labels = images.to(device), labels.to(device) 77 | 78 | logits = model(images) 79 | loss = criterion(logits, labels) 80 | 81 | optimizer.zero_grad() 82 | loss.backward() 83 | 
optimizer.step() 84 | 85 | if (i + 1) % 100 == 0: 86 | print(f"{log_line(epoch, i)}, Loss: {loss.item():.4f}") 87 | lr_scheduler.step() 88 | 89 | with pt.no_grad(): 90 | model.eval() 91 | weights = model.fc.weight 92 | 93 | # NC collections 94 | mean_accum = MeanAccumulator(10, 512, "cuda") 95 | for i, (images, labels) in enumerate(train_loader): 96 | images, labels = images.to(device), labels.to(device) 97 | outputs = model(images) 98 | mean_accum.accumulate(Features.value, labels) 99 | means, mG = mean_accum.compute() 100 | 101 | var_norms_accum = VarNormAccumulator(10, 512, "cuda", M=means) 102 | covar_accum = CovarAccumulator(10, 512, "cuda", M=means) 103 | for i, (images, labels) in enumerate(train_loader): 104 | images, labels = images.to(device), labels.to(device) 105 | outputs = model(images) 106 | var_norms_accum.accumulate(Features.value, labels, means) 107 | covar_accum.accumulate(Features.value, labels, means) 108 | var_norms, _ = var_norms_accum.compute() 109 | covar_within = covar_accum.compute() 110 | 111 | dec_accum = DecAccumulator(10, 512, "cuda", M=means, W=weights) 112 | dec_accum.create_index(means) # optionally use FAISS index for NCC 113 | for i, (images, labels) in enumerate(test_loader): 114 | images, labels = images.to(device), labels.to(device) 115 | outputs = model(images) 116 | 117 | # mean embeddings (only) necessary again if not using FAISS index 118 | if dec_accum.index is None: 119 | dec_accum.accumulate(Features.value, labels, weights, means) 120 | else: 121 | dec_accum.accumulate(Features.value, labels, weights) 122 | 123 | ood_mean_accum = MeanAccumulator(10, 512, "cuda") 124 | for i, (images, labels) in enumerate(ood_loader): 125 | images, labels = images.to(device), labels.to(device) 126 | outputs = model(images) 127 | ood_mean_accum.accumulate(Features.value, labels) 128 | _, mG_ood = ood_mean_accum.compute() 129 | 130 | # NC measurements 131 | results = { 132 | "nc1_pinv": covariance_ratio(covar_within, means, mG), 133 | "nc1_svd": covariance_ratio(covar_within, means, mG, "svd"), 134 | "nc1_quot": covariance_ratio(covar_within, means, mG, "quotient"), 135 | "nc1_cdnv": variability_cdnv(var_norms, means, tile_size=64), 136 | "nc2_etf_err": simplex_etf_error(means, mG), 137 | "nc2g_dist": kernel_stats(means, mG, tile_size=64)[1], 138 | "nc2g_log": kernel_stats(means, mG, kernel=log_kernel, tile_size=64)[1], 139 | "nc3_dual_err": self_duality_error(weights, means, mG), 140 | "nc3u_uni_dual": similarities(weights, means, mG).var().item(), 141 | "nc4_agree": clf_ncc_agreement(dec_accum), 142 | "nc5_ood_dev": orthogonality_deviation(means, mG_ood), 143 | } 144 | 145 | if WANDB: 146 | wandb.log(results) 147 | else: 148 | print(results) 149 | -------------------------------------------------------------------------------- /neural_collapse/kernels.py: -------------------------------------------------------------------------------- 1 | from math import copysign 2 | from typing import Tuple 3 | 4 | import torch as pt 5 | from torch import Tensor 6 | 7 | from .util import normalize, symm_reduce, tiling 8 | 9 | 10 | def class_dist_norm_vars( 11 | V_norms: Tensor, 12 | M: Tensor, 13 | dist_exp: float = 1.0, 14 | tile_size: int = None, 15 | ) -> Tensor: 16 | """Compute the matrix grid of class-distance normalized variances (CDNV). 17 | This metric reflects pairwise variability adjusted for mean distances. 18 | Galanti et al. (2021): https://arxiv.org/abs/2112.15121 19 | 20 | Arguments: 21 | V_norms (Tensor): Matrix of within-class variance norms. 
22 | M (Tensor): Matrix of feature (or class mean) embeddings. 23 | dist_exp (int): Power with which to exponentiate the distance 24 | normalizer. A greater power further diminishes the contribution of 25 | mutual variability between already-disparate classes. Defaults to 26 | 1, equivalent to the CDNV introduced by Galanti et al. (2021). 27 | tile_size (int, optional): Size of the tile for kernel computation. 28 | Set tile_size << K to avoid OOM. Defaults to None. 29 | 30 | Returns: 31 | Tensor: A tensor representing the matrix grid of pairwise CDNVs. 32 | """ 33 | V_norms = V_norms.view(-1, 1) 34 | bundled = pt.cat((M, V_norms), dim=1) 35 | 36 | def kernel(tile_i, tile_j): 37 | vars_i, vars_j = tile_i[:, -1], tile_j[:, -1] 38 | var_avgs = (vars_i.unsqueeze(1) + vars_j).squeeze() / 2 39 | 40 | M_i, M_j = tile_i[:, :-1], tile_j[:, :-1] 41 | 42 | M_diff = M_i.unsqueeze(1) - M_j 43 | M_diff_norm_sq = pt.sum(M_diff * M_diff, dim=-1) 44 | return var_avgs.squeeze(0) / (M_diff_norm_sq**dist_exp) 45 | 46 | return tiling(bundled, kernel, tile_size) 47 | 48 | 49 | def dist_kernel(data: Tensor, tile_size: int = None) -> Tensor: 50 | """Compute the grid of pairwise vector distances across a set of vectors. 51 | 52 | Arguments: 53 | data (Tensor): Input data tensor across which to apply the kernel. 54 | tile_size (int, optional): Size of the tile for kernel computation. 55 | Set tile_size << K to avoid OOM. Defaults to None. 56 | """ 57 | kernel = lambda tile_i, tile_j: (tile_i.unsqueeze(1) - tile_j).norm(dim=-1) 58 | return tiling(data, kernel, tile_size) 59 | 60 | 61 | def log_kernel(data: Tensor, exponent: int = -1, tile_size: int = None) -> Tensor: 62 | """Compute the grid of pairwise logarithmic distances across vectors. 63 | Liu et al. (2023): https://arxiv.org/abs/2303.06484 64 | 65 | Arguments: 66 | data (Tensor): Input data tensor across which to apply the kernel. 67 | exponent (int, optional): Power with which to exponentiate the 68 | distance norm before the logarithm. Defaults to -1 (inverse). 69 | tile_size (int, optional): Size of the tile for kernel computation. 70 | Set tile_size << K to avoid OOM. Defaults to None. 71 | """ 72 | 73 | def kernel(tile_i, tile_j): 74 | diff = tile_i.unsqueeze(1) - tile_j 75 | diff_norms = diff.norm(dim=-1) 76 | return (diff_norms ** (exponent)).log() 77 | 78 | return tiling(data, kernel, tile_size) 79 | 80 | 81 | def riesz_kernel(data: Tensor, tile_size: int = None) -> Tensor: 82 | """Compute the grid of Riesz distances across vectors. 83 | Liu et al. (2023): https://arxiv.org/abs/2303.06484 84 | 85 | Arguments: 86 | data (Tensor): Input data tensor across which to apply the kernel. 87 | tile_size (int, optional): Size of the tile for kernel computation. 88 | Set tile_size << K to avoid OOM. Defaults to None. 89 | """ 90 | S = data.shape[-1] - 2 91 | 92 | def kernel(tile_i, tile_j): 93 | diff = tile_i.unsqueeze(1) - tile_j 94 | diff_norms = diff.norm(dim=-1) 95 | return copysign(1, S) * diff_norms ** (-S) 96 | 97 | return tiling(data, kernel, tile_size) 98 | 99 | 100 | def kernel_grid( 101 | M: Tensor, 102 | m_G: Tensor = 0, 103 | kernel: callable = dist_kernel, 104 | tile_size: int = None, 105 | ) -> Tensor: 106 | """Compute the grid from the kernel function on pairwise interactions 107 | between embeddings. Self-interactions are excluded. 108 | 109 | Arguments: 110 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 111 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 
112 |         kernel (callable, optional): The kernel with which to compute
113 |             interactions. Defaults to the pairwise distance kernel
114 |             (`dist_kernel`); the logarithmic or Riesz kernels are also common.
115 |         tile_size (int, optional): Size of the tile for kernel computation.
116 |             Set tile_size << K to avoid OOM. Defaults to None.
117 | 
118 |     Returns:
119 |         Tensor: Grid (K, K) of pairwise kernel evaluations between the
120 |             centred (and normalized) embeddings.
121 |     """
122 |     M_centred_normed = normalize(M - m_G)
123 |     return kernel(M_centred_normed, tile_size=tile_size)
124 | 
125 | 
126 | def kernel_stats(
127 |     M: Tensor,
128 |     m_G: Tensor = 0,
129 |     kernel: callable = dist_kernel,
130 |     tile_size: int = None,
131 | ) -> Tuple[float, float]:
132 |     """Compute the average and variance of a kernel function on pairwise
133 |     interactions between embeddings. Self-interactions are excluded.
134 |     Liu et al. (2023): https://arxiv.org/abs/2303.06484
135 | 
136 |     Arguments:
137 |         M (Tensor): Matrix of feature (e.g. class mean) embeddings.
138 |         m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0.
139 |         kernel (callable): Kernel function with which to compute
140 |             interactions. Defaults to the pairwise distance kernel
141 |             (`dist_kernel`); the logarithmic or Riesz kernels are also common.
142 |         tile_size (int, optional): Size of the tile for kernel computation.
143 |             Set tile_size << K to avoid OOM. Defaults to None.
144 | 
145 |     Returns:
146 |         float: Average of pairwise kernel interactions.
147 |         float: Variance of pairwise kernel interactions.
148 |     """
149 |     grid: Tensor = kernel_grid(M, m_G, kernel, tile_size)
150 |     avg = symm_reduce(grid)
151 |     var = symm_reduce(grid, lambda row: pt.sum((row - avg) ** 2))
152 | 
153 |     return avg.item(), var.item()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # Neural Collapse
  2 | 
  3 | This package aims to be a reference implementation for the analysis of
  4 | [Neural Collapse (NC) (Papyan et al., 2020)](https://www.pnas.org/doi/full/10.1073/pnas.2015509117).
  5 | We provide,
  6 | 
  7 | 1. Accumulators to collect embedding statistics from the output
  8 |    representations of your pre-trained model.
  9 | 2. Measurement (kernel) functions for several canonical and modern NC metrics.
 10 | 3. Tiling support for memory-bound settings arising from large embeddings,
 11 |    many classes and/or limited parallel accelerator (e.g. GPU) memory.
 12 | 
 13 | ## Installation
 14 | 
 15 | ```sh
 16 | # install from remote
 17 | pip install git+https://github.com/rhubarbwu/neural-collapse.git
 18 | 
 19 | # install with FAISS
 20 | pip install git+https://github.com/rhubarbwu/neural-collapse.git#egg=neural_collapse[faiss]
 21 | 
 22 | # install locally from a repository clone [with FAISS]
 23 | git clone https://github.com/rhubarbwu/neural-collapse.git
 24 | pip install -e neural-collapse[faiss]
 25 | ```
 26 | 
 27 | ## Usage
 28 | 
 29 | ```py
 30 | import neural_collapse as nc
 31 | ```
 32 | 
 33 | We assume that you,
 34 | 
 35 | - Already pre-trained your model or are in the training process with a
 36 |   programmable loop, where the top-layer classifier weights are available (see the hook sketch below).
 37 | - Have your iterable dataloader(s) available. Make sure your training data is
 38 |   the same as that with which you trained your model.
 39 | - Have model evaluation functions or results; technically optional but ideal.
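The usage snippets below read penultimate-layer features from `Features.value`
and the top-layer classifier weights from `weights`. One way to populate them,
taken from the [MNIST example](./examples/mnist.py) (which uses a torchvision
ResNet whose classifier is `model.fc`), is a forward hook on the final linear
layer; adapt the hooked module to wherever your model's penultimate features
live.

```py
class Features:
    pass


def hook(self, input, output):
    Features.value = input[0].clone()


# register a hook that saves the last-layer input (penultimate features)
classifier = model.fc
classifier.register_forward_hook(hook)

weights = model.fc.weight  # top-layer classifier weights (used for NC3/NC4)
```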
 40 | 
 41 | For use cases with large embeddings or many classes, we recommend using a
 42 | hardware accelerator (e.g. `cuda`).
 43 | 
 44 | ### Accumulators
 45 | 
 46 | You'll need to collect (i.e. "accumulate") statistics from your learned
 47 | representations. Here we outline a
 48 | [basic example on the MNIST dataset](./examples/mnist.py) with `K=10`
 49 | classes and embeddings of size `D=512`.
 50 | 
 51 | ```py
 52 | from neural_collapse.accumulate import (CovarAccumulator, DecAccumulator,
 53 |                                         MeanAccumulator, VarNormAccumulator)
 54 | ```
 55 | 
 56 | #### Mean Embedding Accumulators (for NC\* in general)
 57 | 
 58 | ```py
 59 | mean_accum = MeanAccumulator(10, 512, "cuda")
 60 | for i, (images, labels) in enumerate(train_loader):
 61 |     images, labels = images.to(device), labels.to(device)
 62 |     outputs = model(images)
 63 |     mean_accum.accumulate(Features.value, labels)
 64 | means, mG = mean_accum.compute()
 65 | ```
 66 | 
 67 | #### Variance Accumulators (for NC1)
 68 | 
 69 | For measuring within-class variability collapse (NC1), you would typically
 70 | collect within-class covariances (`covar_accum` below); note that this might
 71 | be memory-intensive at order `K*D*D`.
 72 | 
 73 | ```py
 74 | covar_accum = CovarAccumulator(10, 512, "cuda", M=means)
 75 | var_norms_accum = VarNormAccumulator(10, 512, "cuda", M=means)  # for CDNV
 76 | for i, (images, labels) in enumerate(train_loader):
 77 |     images, labels = images.to(device), labels.to(device)
 78 |     outputs = model(images)
 79 |     covar_accum.accumulate(Features.value, labels, means)
 80 |     var_norms_accum.accumulate(Features.value, labels, means)
 81 | covar_within = covar_accum.compute()
 82 | var_norms, _ = var_norms_accum.compute()  # for CDNV
 83 | ```
 84 | 
 85 | NC1 can also be empirically measured using the class-distance normalized
 86 | variance [(CDNV) (Galanti et al., 2021)](https://arxiv.org/abs/2112.15121),
 87 | which only requires collecting within-class variance norms at order `K`.
 88 | 
 89 | #### Decision Agreement Accumulators (for NC4)
 90 | 
 91 | Measuring the convergence of the linear classifier's behaviour to that of the
 92 | implicit nearest-class centre (NCC) classifier has since been extended to
 93 | unseen (e.g. validation or test) data.
 94 | 
 95 | ```py
 96 | dec_accum = DecAccumulator(10, 512, "cuda", M=means, W=weights)
 97 | dec_accum.create_index(means)  # optionally use FAISS index for NCC
 98 | for i, (images, labels) in enumerate(test_loader):
 99 |     images, labels = images.to(device), labels.to(device)
100 |     outputs = model(images)
101 | 
102 |     # mean embeddings (only) necessary again if not using FAISS index
103 |     if dec_accum.index is None:
104 |         dec_accum.accumulate(Features.value, labels, weights, means)
105 |     else:
106 |         dec_accum.accumulate(Features.value, labels, weights)
107 | ```
108 | 
109 | #### Out-of-Distribution (OoD) Means (for NC5)
110 | 
111 | For OoD detection
112 | [(NC5) (Ammar et al., 2024)](https://arxiv.org/abs/2310.06823), collect
113 | class-mean embeddings from an out-of-distribution dataset.
114 | 
115 | ```py
116 | ood_mean_accum = MeanAccumulator(10, 512, "cuda")
117 | for i, (images, labels) in enumerate(ood_loader):
118 |     images, labels = images.to(device), labels.to(device)
119 |     outputs = model(images)
120 |     ood_mean_accum.accumulate(Features.value, labels)
121 | _, mG_ood = ood_mean_accum.compute()
122 | ```
123 | 
124 | ### Measurements
125 | 
126 | Here's a snippet from our [example on the MNIST dataset](./examples/mnist.py).
127 | 
128 | ```py
129 | from neural_collapse.kernels import kernel_stats, log_kernel
130 | from neural_collapse.measure import (clf_ncc_agreement, covariance_ratio,
131 |                                      orthogonality_deviation, self_duality_error,
132 |                                      similarities, simplex_etf_error, variability_cdnv)
133 | 
134 | results = {
135 |     "nc1_pinv": covariance_ratio(covar_within, means, mG),
136 |     "nc1_svd": covariance_ratio(covar_within, means, mG, "svd"),
137 |     "nc1_quot": covariance_ratio(covar_within, means, mG, "quotient"),
138 |     "nc1_cdnv": variability_cdnv(var_norms, means),
139 |     "nc2_etf_err": simplex_etf_error(means, mG),
140 |     "nc2g_dist": kernel_stats(means, mG)[1],
141 |     "nc2g_log": kernel_stats(means, mG, kernel=log_kernel)[1],
142 |     "nc3_dual_err": self_duality_error(weights, means, mG),
143 |     "nc3u_uni_dual": similarities(weights, means, mG).var().item(),
144 |     "nc4_agree": clf_ncc_agreement(dec_accum),
145 |     "nc5_ood_dev": orthogonality_deviation(means, mG_ood),
146 | }
147 | ```
148 | 
149 | #### Pre-Centring Means
150 | 
151 | Where centring is required for `means`, you can include the global mean `mG`
152 | as a bias argument (as above), or pre-centre them (as below).
153 | 
154 | ```py
155 | means_centred = means - mG
156 | results = {
157 |     "nc1_pinv": covariance_ratio(covar_within, means_centred),
158 |     "nc1_svd": covariance_ratio(covar_within, means_centred, metric="svd"),
159 |     "nc1_quot": covariance_ratio(covar_within, means_centred, metric="quotient"),
160 |     "nc1_cdnv": variability_cdnv(var_norms, means),
161 |     # ...
162 |     "nc5_ood_dev": orthogonality_deviation(means, mG_ood),
163 | }
164 | ```
165 | 
166 | Note that the uncentred means are still needed for some measurements (such as
167 | CDNV) and therefore cannot be discarded, so also storing pre-centred means
168 | may not be economical memory-wise if `K` and/or `D` are large.
169 | 
170 | #### Tiling & Reductions
171 | 
172 | For many of the NC measurement functions, we implement kernel tiling in case
173 | large embeddings or many classes strain your hardware memory. You may want to
174 | tune the tile square size to maximize accelerator throughput.
175 | 
176 | ```py
177 | results = {
178 |     # ...
179 |     "nc1_cdnv": variability_cdnv(var_norms, means, tile_size=64),
180 |     "nc2g_dist": kernel_stats(means, mG, tile_size=64)[1],  # var
181 |     "nc2g_log": kernel_stats(means, mG, kernel=log_kernel, tile_size=64)[1],  # var
182 |     # ...
183 | }
184 | ```
185 | 
186 | After `kernel_grid` produces a symmetric measurement matrix, `kernel_stats`
187 | computes the mean (`[0]`) and variance (`[1]`) using triangle row folding.
188 | 
189 | ## Development
190 | 
191 | This project is under active development. Feel free to open issues for bugs,
192 | features, optimizations, or papers you would like (us) to implement.
193 | 
194 | ## Citation
195 | 
196 | As most of the code is taken from the [linguistic-collapse](https://github.com/rhubarbwu/linguistic-collapse) repository, we ask that you cite that paper if you use this code.
197 | 
198 | ```tex
199 | @inproceedings{NEURIPS2024_f88cc893,
200 |  author = {Wu, Robert and Papyan, Vardan},
201 |  booktitle = {Advances in Neural Information Processing Systems},
202 |  editor = {A. Globerson and L. Mackey and D. Belgrave and A. Fan and U. Paquet and J. Tomczak and C.
Zhang}, 203 | pages = {137432--137473}, 204 | publisher = {Curran Associates, Inc.}, 205 | title = {Linguistic Collapse: Neural Collapse in (Large) Language Models}, 206 | url = {https://proceedings.neurips.cc/paper_files/paper/2024/file/f88cc8930b47a45ec4733123bf3039b9-Paper-Conference.pdf}, 207 | volume = {37}, 208 | year = {2024} 209 | } 210 | ``` 211 | 212 | ## References 213 | 214 | - [Prevalence of neural collapse during the terminal phase of deep learning training](https://www.pnas.org/doi/full/10.1073/pnas.2015509117) 215 | - [On the Role of Neural Collapse in Transfer Learning](https://arxiv.org/abs/2112.15121) 216 | - [Neural Collapse: A Review on Modelling Principles and Generalization](https://arxiv.org/abs/2206.04041) 217 | - [Perturbation Analysis of Neural Collapse](https://proceedings.mlr.press/v202/tirer23a) 218 | - [Generalizing and Decoupling Neural Collapse via Hyperspherical Uniformity Gap](https://arxiv.org/abs/2303.06484) 219 | - [NECO: NEural Collapse Based Out-of-distribution detection](https://arxiv.org/abs/2310.06823) 220 | - [Linguistic Collapse: Neural Collapse in (Large) Language Models](https://arxiv.org/abs/2405.17767) 221 | -------------------------------------------------------------------------------- /examples/layerwise_mnist.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch as pt 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import torchvision.models as models 8 | from torch.utils.data import DataLoader 9 | from torchvision.datasets import MNIST, FashionMNIST 10 | from torchvision.transforms import Compose, Normalize, ToTensor 11 | import wandb 12 | 13 | from neural_collapse.accumulate import ( 14 | CovarAccumulator, 15 | DecAccumulator, 16 | MeanAccumulator, 17 | VarNormAccumulator, 18 | ) 19 | from neural_collapse.measure import ( 20 | clf_ncc_agreement, 21 | covariance_ratio, 22 | orthogonality_deviation, 23 | self_duality_error, 24 | simplex_etf_error, 25 | variability_cdnv, 26 | ) 27 | 28 | 29 | def replace_layers(model, max_channels=16): 30 | for name, module in model.named_children(): 31 | # Replace Conv2d layers with in_channels > max_channels 32 | if isinstance(module, nn.Conv2d) and module.in_channels >= max_channels: 33 | new_conv = nn.Conv2d( 34 | in_channels=max_channels, 35 | out_channels=max_channels, 36 | kernel_size=module.kernel_size, 37 | stride=module.stride, 38 | padding=module.padding, 39 | dilation=module.dilation, 40 | groups=module.groups, 41 | bias=(module.bias is not None), 42 | padding_mode=module.padding_mode, 43 | ) 44 | setattr(model, name, new_conv) 45 | 46 | # Replace BatchNorm2d layers with num_features > max_channels 47 | elif isinstance(module, nn.BatchNorm2d) and module.num_features >= max_channels: 48 | new_bn = nn.BatchNorm2d( 49 | num_features=max_channels, 50 | eps=module.eps, 51 | momentum=module.momentum, 52 | affine=module.affine, 53 | track_running_stats=module.track_running_stats, 54 | ) 55 | setattr(model, name, new_bn) 56 | 57 | # Recursively apply the function to child modules 58 | replace_layers(module, max_channels) 59 | 60 | 61 | # Device configuration 62 | device = pt.device("cuda" if pt.cuda.is_available() else "cpu") 63 | 64 | # Hyperparameters 65 | n_epochs = 200 66 | batch_size = 128 67 | lr, epochs_lr_decay, lr_decay = 0.0679, [n_epochs // 3, n_epochs * 2 // 3], 0.1 68 | momentum = 0.9 69 | weight_decay = 5e-4 70 | 71 | # MNIST dataset 72 | transform = 
Compose([ToTensor(), Normalize((0.1307,), (0.3081,))]) 73 | train_dataset = MNIST("./data", True, transform, download=True) 74 | train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True) 75 | test_dataset = MNIST("./data", False, transform, download=True) 76 | test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True) 77 | # OoD dataset (Fashion MNIST), for NC5 78 | ood_dataset = FashionMNIST("./data", False, transform, download=True) 79 | ood_loader = DataLoader(dataset=ood_dataset, batch_size=batch_size, shuffle=True) 80 | 81 | 82 | # ResNet model 83 | model = models.resnet18(num_classes=10, weights=None).to(device) 84 | model.conv1 = nn.Conv2d(1, 16, 3, 1, 1, bias=False) 85 | model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0) 86 | model.fc = nn.Linear(in_features=16, out_features=10) 87 | replace_layers(model) 88 | 89 | print(model) 90 | 91 | 92 | class ModifiedResNet(nn.Module): 93 | def __init__(self, original_model): 94 | super(ModifiedResNet, self).__init__() 95 | self.features = nn.ModuleList(original_model.children()) 96 | 97 | def forward(self, x): 98 | outputs = [] 99 | for i, module in enumerate(self.features): 100 | if isinstance(module, nn.Sequential): 101 | x = module(x) 102 | outputs.append(x.flatten(start_dim=1)) 103 | elif isinstance(module, nn.Linear): 104 | x = x.flatten(start_dim=1) 105 | outputs.append(x) 106 | x = module(x) 107 | else: 108 | x = module(x) 109 | return x, outputs 110 | 111 | 112 | model = ModifiedResNet(model).to(device=device) 113 | 114 | # Loss and optimizer 115 | criterion = nn.CrossEntropyLoss() 116 | optimizer = optim.SGD(model.parameters(), lr, momentum, weight_decay=weight_decay) 117 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, epochs_lr_decay, lr_decay) 118 | 119 | 120 | try: 121 | import wandb 122 | 123 | wandb.init(project="neural-collapse") 124 | WANDB = True 125 | except: 126 | WANDB = False 127 | 128 | with pt.no_grad(): 129 | latents = model(pt.ones(1, 1, 28, 28).cuda())[-1] 130 | 131 | # Train the model 132 | total_step = len(train_loader) 133 | log_line = lambda epoch, i: f"Epoch [{epoch+1}/{n_epochs}], Step [{i+1}/{total_step}]" 134 | for epoch in range(n_epochs): 135 | model.train() 136 | for i, (images, labels) in enumerate(train_loader): 137 | images, labels = images.to(device), labels.to(device) 138 | 139 | logits, _ = model(images) 140 | loss = criterion(logits, labels) 141 | 142 | optimizer.zero_grad() 143 | loss.backward() 144 | optimizer.step() 145 | 146 | if (i + 1) % 100 == 0: 147 | print(f"{log_line(epoch, i)}, Loss: {loss.item():.4f}") 148 | lr_scheduler.step() 149 | 150 | with pt.no_grad(): 151 | model.eval() 152 | 153 | # NC collections 154 | mean_accums = [MeanAccumulator(10, t.shape[-1], "cuda") for t in latents] 155 | for i, (images, labels) in enumerate(train_loader): 156 | images, labels = images.to(device), labels.to(device) 157 | outputs, hiddens = model(images) 158 | for mean_accum, hidden in zip(mean_accums, hiddens): 159 | mean_accum.accumulate(hidden, labels) 160 | mean_stats = [mean_accum.compute() for mean_accum in mean_accums] 161 | # means, mG = mean_accum.compute() 162 | 163 | var_norms_accums = [ 164 | VarNormAccumulator(10, t.shape[-1], "cuda", M=mean) 165 | for ((mean, _), t) in zip(mean_stats, latents) 166 | ] 167 | covar_accums = [ 168 | CovarAccumulator(10, t.shape[-1], "cuda", M=mean) 169 | for ((mean, _), t) in zip(mean_stats, latents) 170 | ] 171 | for i, (images, labels) in enumerate(train_loader): 172 | images, 
labels = images.to(device), labels.to(device) 173 | outputs, hiddens = model(images) 174 | 175 | for (mean, _), var_norms_accum, covar_accum, hidden in zip( 176 | mean_stats, var_norms_accums, covar_accums, hiddens 177 | ): 178 | var_norms_accum.accumulate(hidden, labels, mean) 179 | covar_accum.accumulate(hidden, labels, mean) 180 | vars_norms = [ 181 | var_norms_accum.compute()[0] for var_norms_accum in var_norms_accums 182 | ] 183 | covars_within = [covar_accum.compute() for covar_accum in covar_accums] 184 | 185 | dec_accum = DecAccumulator( 186 | 10, 187 | latents[-1].shape[-1], 188 | "cuda", 189 | M=mean_stats[-1][0], 190 | W=model.features[-1].weight, 191 | ) 192 | for i, (images, labels) in enumerate(test_loader): 193 | images, labels = images.to(device), labels.to(device) 194 | outputs, hiddens = model(images) 195 | dec_accum.accumulate( 196 | hiddens[-1], labels, model.features[-1].weight, mean_stats[-1][0] 197 | ) 198 | 199 | ood_mean_accums = [MeanAccumulator(10, t.shape[-1], "cuda") for t in latents] 200 | for i, (images, labels) in enumerate(ood_loader): 201 | images, labels = images.to(device), labels.to(device) 202 | outputs, hiddens = model(images) 203 | for ood_mean_accum, hidden in zip(ood_mean_accums, hiddens): 204 | ood_mean_accum.accumulate(hidden, labels) 205 | ood_mean_stats = [ 206 | ood_mean_accum.compute() for ood_mean_accum in ood_mean_accums 207 | ] 208 | 209 | # NC measurements 210 | 211 | layerwise_metrics = [] 212 | for i in range(len(latents)): 213 | means, mG = mean_stats[i] 214 | covar_within = covars_within[i] 215 | var_norms = vars_norms[i] 216 | _, mG_ood = ood_mean_stats[i] 217 | 218 | results = { 219 | "nc1_covariance_pinv": covariance_ratio( 220 | covar_within, means, mG, metric="svd" 221 | ), 222 | "nc1_variability_cdnv": variability_cdnv(var_norms, means), 223 | "nc2_simplex_etf_error": simplex_etf_error(means, mG), 224 | "nc5_ood_deviation": orthogonality_deviation(means, mG_ood), 225 | } 226 | 227 | layerwise_metrics.append(copy.deepcopy(results)) 228 | if i == (len(latents) - 1): 229 | results["nc4_decs_agreement"] = clf_ncc_agreement(dec_accum) 230 | results["nc3_self_duality"] = self_duality_error( 231 | model.features[-1].weight, means, mG 232 | ) 233 | 234 | if WANDB: 235 | wandb.log({f"layer {i + 1}/{k}": v for k, v in results.items()}) 236 | 237 | 238 | # Plot layerwise plots 239 | if WANDB: 240 | for k, _ in layerwise_metrics[0].items(): 241 | 242 | vals = [] 243 | for metrics in layerwise_metrics: 244 | vals.append(np.log(metrics[k])) 245 | 246 | data = [[x, y] for x, y in zip(list(range(1, len(layerwise_metrics) + 1)), vals)] 247 | 248 | f, ax = plt.subplots(1, 1) 249 | ax.set_xlabel("Layer", fontsize=14) 250 | ax.set_xticks(np.arange(1, len(layerwise_metrics) + 1, dtype=np.int8)) 251 | ax.plot(list(range(1, len(layerwise_metrics) + 1)), vals) 252 | 253 | wandb.log( 254 | { 255 | f"layerwise_plots/{k}": f, 256 | }, 257 | ) 258 | pt.save(model.state_dict(), "collapsed_network.pt") -------------------------------------------------------------------------------- /neural_collapse/measure.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from typing import List 3 | 4 | import numpy as np 5 | import torch as pt 6 | import torch.linalg as la 7 | from scipy.sparse.linalg import svds 8 | from torch import Tensor 9 | from torch.nn.functional import cosine_similarity 10 | 11 | from .accumulate import DecAccumulator 12 | from .kernels import class_dist_norm_vars 13 | from .util
import normalize, symm_reduce 14 | 15 | 16 | def covariance_ratio( 17 | V_intra: Tensor, M: Tensor, m_G: Tensor = 0, metric: str = "pinv" 18 | ) -> float: 19 | """Compute the ratio between the within-class (intra) and between-class 20 | (inter) covariances, using some empirical metric. 21 | Papyan et al. (2020): https://doi.org/10.1073/pnas.2015509117; 22 | Han et al. (2022): https://arxiv.org/abs/2106.02073; 23 | Tirer et al. (2023): https://proceedings.mlr.press/v202/tirer23a; 24 | 25 | Arguments: 26 | V_intra (Tensor): Matrix of within-class covariance. 27 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 28 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 29 | metric (str, optional): Empirical metric for within-class variability. 30 | Defaults to "pinv" for trace of the MP pseudoinverse left-product; 31 | "svd" approximates the pinv left-product with SVD (top-K); 32 | "quotient" computes the quotient of traces of covariances; 33 | 34 | Returns: 35 | float: Ratio of between-class/within-class covariances. 36 | """ 37 | (K, _), M_centred = M.shape, M - m_G 38 | V_inter = M_centred.mT @ M_centred / K 39 | 40 | if metric == "pinv": # Papyan et al. (2020) 41 | V_intra = V_intra.to(V_inter.device) 42 | prod = la.pinv(V_inter) @ V_intra 43 | return pt.trace(prod).item() / K 44 | 45 | if metric == "svd": # Han et al. (2022) 46 | V_intra, V_inter = V_intra.cpu().numpy(), V_inter.cpu().numpy() 47 | eig_vecs, eig_vals, _ = svds(V_inter, k=K - 1) 48 | inv_Sb = eig_vecs @ np.diag(eig_vals ** (-1)) @ eig_vecs.T 49 | return float(np.trace(V_intra @ inv_Sb)) / K 50 | 51 | if metric == "quotient": # Tirer et al. (2023) 52 | return pt.trace(V_intra).item() / pt.trace(V_inter).item() 53 | 54 | raise NotImplementedError 55 | 56 | 57 | def variability_cdnv( 58 | V_norms: Tensor, M: Tensor, dist_exp: float = 1.0, tile_size: int = None 59 | ) -> float: 60 | """Compute the average class-distance normalized variances (CDNV). 61 | This metric reflects pairwise variability adjusted for mean distances. 62 | Galanti et al. (2021): https://arxiv.org/abs/2112.15121; 63 | 64 | Arguments: 65 | V_norms (Tensor): Matrix of within-class variance norms for the classes. 66 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 67 | dist_exp (int): The power with which to exponentiate the distance 68 | normalizer. A greater power further diminishes the contribution of 69 | mutual variability between already-disparate classes. Defaults to 70 | 1, equivalent to the CDNV introduced by Galanti et al. (2021). 71 | tile_size (int, optional): Size of the tile for kernel computation. 72 | Set tile_size << K to avoid OOM. Defaults to None. 73 | 74 | Returns: 75 | float: The average CDNVs across all class pairs. 76 | """ 77 | kernel_grid = class_dist_norm_vars(V_norms, M, dist_exp, tile_size) 78 | avg = symm_reduce(kernel_grid, pt.sum) 79 | return avg.item() 80 | 81 | 82 | def mean_norms(M: Tensor, m_G: Tensor = 0, post_funcs: List[callable] = []) -> Tensor: 83 | """Compute the norms of (mean) embeddings (centred). 84 | 85 | Arguments: 86 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 87 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 88 | post_funcs (List[callable], optional): Functions (Tensor -> Tensor) 89 | applied to the computed norms. Defaults to []. 90 | 91 | Returns: 92 | Tensor: A vector containing the norms for each class. 
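    Example (illustrative; `means` and `mG` as computed by a MeanAccumulator):
    the spread of these norms is one way to track NC2 "equinorm" convergence.
        >>> norms = mean_norms(means, mG)
        >>> equinorm_cv = (norms.std() / norms.mean()).item()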
93 | """ 94 | M_centred = M - m_G 95 | result = M_centred.norm(dim=-1) # (K) 96 | for post_func in post_funcs: 97 | result = post_func(result) 98 | return result 99 | 100 | 101 | def interference_grid(M: Tensor, m_G: Tensor = 0) -> Tensor: 102 | """Compute the pairwise interference grid between (mean) embeddings. 103 | 104 | Arguments: 105 | M (Tensor): The matrix of feature (or class mean) embeddings. 106 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 107 | 108 | Returns: 109 | Tensor: A matrix representing pairwise interferences. 110 | """ 111 | M_centred = M - m_G 112 | return pt.inner(M_centred, M_centred) # (K,K) 113 | 114 | 115 | def similarities(W: Tensor, M: Tensor, m_G: Tensor = 0, cos: bool = False) -> Tensor: 116 | """Compute the (cosine or dot-product) similarities between a set of (mean) 117 | embeddings and classifiers vectors. 118 | 119 | Arguments: 120 | W (Tensor): Weight vectors of the classifiers. Computations will be 121 | performed on the device of W. 122 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 123 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 124 | cos (bool, optional): Whether to use cosine similarity. Defaults to 125 | False, using dot-product similarity. 126 | 127 | Returns: 128 | Tensor: Per-class similarities between embeddings and classifiers. 129 | """ 130 | M_centred = (M - m_G).to(W.device) 131 | if cos: 132 | return cosine_similarity(W, M_centred.to(W.dtype)) 133 | return (W * M_centred).sum(dim=1) 134 | 135 | 136 | def distance_norms(W: Tensor, M: Tensor, m_G: Tensor = 0, norm: bool = True) -> Tensor: 137 | """Compute the distance between (mean) embeddings and classifier vectors. 138 | 139 | Arguments: 140 | M (Tensor): Feature (e.g. class mean) embeddings vectors. 141 | W (Tensor): Weight vectors of the classifiers. Computations will be 142 | performed on the device of W. 143 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 144 | norm (bool, optional): Whether to normalize vectors before taking 145 | their distances. Defaults to True, allowing two dual spaces. 146 | 147 | Returns: 148 | Tensor: Per-class distances between embeddings and classifiers. 149 | """ 150 | M_centred = (M - m_G).to(W.device) 151 | if norm: 152 | W, M_centred = normalize(W), normalize(M_centred) 153 | return (W - M_centred).norm(dim=-1) 154 | 155 | 156 | def structure_error(A: Tensor, B: Tensor) -> float: 157 | """Compute the error between the cross-coherence structure formed 158 | by two sets of vectors and the ideal simplex equiangular tight frame 159 | (ETF), expressed as the matrix norm of their difference. 160 | Kothapalli (2023): https://arxiv.org/abs/2206.04041 161 | 162 | Arguments: 163 | A (Tensor): First tensor for comparison. 164 | B (Tensor): Second tensor for comparison. 165 | 166 | Returns: 167 | float: Scalar error of excess incoherence from simplex ETF. 168 | """ 169 | (K, _) = A.shape 170 | 171 | struct = B.to(A.device) @ A.mT # (K,K) 172 | struct /= la.matrix_norm(struct) 173 | 174 | struct += 1 / K / sqrt(K - 1) 175 | struct.diagonal().sub_(1 / sqrt(K - 1)) 176 | 177 | return la.matrix_norm(struct).item() 178 | 179 | 180 | def simplex_etf_error(M: Tensor, m_G: Tensor = 0) -> float: 181 | """Compute the excess cross-class incoherence within a set of (mean) 182 | embeddings, relative to the ideal simplex ETF. 183 | Kothapalli (2023): https://arxiv.org/abs/2206.04041 184 | 185 | Arguments: 186 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 
187 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 188 | 189 | Returns: 190 | float: Scalar error of excess incoherence from simplex ETF. 191 | """ 192 | M_centred = M - m_G 193 | return structure_error(M_centred, M_centred) 194 | 195 | 196 | def self_duality_error(W: Tensor, M: Tensor, m_G: Tensor = 0) -> float: 197 | """Compute the excess cross-class incoherence between a set of (mean) 198 | embeddings and classifiers, relative to the ideal simplex ETF. 199 | Kothapalli (2023): https://arxiv.org/abs/2206.04041 200 | 201 | Arguments: 202 | W (Tensor): Weight vectors of the classifiers. 203 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 204 | m_G (Tensor, optional): Bias (e.g. global mean) vector. Defaults to 0. 205 | 206 | Returns: 207 | float: Scalar error of excess incoherence from simplex ETF. 208 | """ 209 | M_centred = M - m_G 210 | return structure_error(M_centred, W) 211 | 212 | 213 | def clf_ncc_agreement( 214 | accum: DecAccumulator, indices: List[int] = None, weighted: bool = True 215 | ) -> float: 216 | """Compute the rate of agreement between the linear and the implicit 217 | nearest-class centre (NCC) classifiers: percentage of hits over Ns samples. 218 | 219 | Arguments: 220 | accum (DecAccumulator): Tracker of per-class sample counts and hits. 221 | indices (List[int], optional): Indices of specific classes to include 222 | in agreement analysis. Defaults to [], to include all classes. 223 | weighted (bool, optional): Whether to weigh class hit rates by numbers 224 | of samples. Defaults to True. 225 | 226 | Returns: 227 | float: The rate of agreement as a float, or None if an error occurs 228 | (e.g. neither hits nor misses given, or shape mismatch) 229 | """ 230 | _, global_agree_rates = accum.compute(indices, weighted) 231 | return global_agree_rates.item() 232 | 233 | 234 | def orthogonality_deviation(M: Tensor, m_G_ood: Tensor = 0) -> float: 235 | """Compute the average normalized deviation of means from the global mean 236 | embedding from out-of-distribution (OoD) data. 237 | Ammar et al. (2024): https://arxiv.org/abs/2310.06823; 238 | 239 | Arguments: 240 | M (Tensor): Matrix of feature (e.g. class mean) embeddings. 241 | m_G_ood (Tensor, optional): Out-of-Distribution bias (e.g. global 242 | mean) vector. Defaults to 0. 243 | Returns: 244 | float: Average normalized deviation of means from the OoD global mean. 245 | """ 246 | deviations = pt.abs(similarities(M, m_G_ood, cos=True)) # (K) 247 | return pt.mean(deviations).item() 248 | -------------------------------------------------------------------------------- /neural_collapse/accumulate.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | from typing import List, Tuple, Union 3 | 4 | import torch as pt 5 | from torch import Tensor 6 | 7 | from .util import hashify, resolve 8 | 9 | 10 | class Accumulator(metaclass=ABCMeta): 11 | """Base class for accumulators that track sample counts and totals. 12 | 13 | This abstract class provides the foundation for different types of 14 | accumulators that calculate statistics based on input tensors. It manages 15 | sample counts for each class and includes methods to filter indices based 16 | on sample counts, accumulate data, and compute averages. 17 | 18 | Attributes: 19 | n_classes (int): Number of classes. 20 | d_vectors (int): Dimensionality of the input vectors. 21 | ctype (torch.dtype): Data type for counts. 
22 | dtype (torch.dtype): Data type for totals and computations. 23 | device (Union[str, torch.device]): Device on which tensors are stored. 24 | ns_samples (Tensor): Tensor to store per-class sample counts. 25 | 26 | Methods: 27 | filter_indices_by_n_samples(minimum=0, maximum=None): 28 | Filters class indices based on a minimum and maximum sample count. 29 | class_idxs(X, Y): 30 | Increment sample counts and return per-class sample indices. 31 | compute(idxs=None, weighted=False): 32 | Computes averages of accumulated samples (totals/counts). 33 | accumulate(*args): 34 | Abstract method to be implemented by subclasses to define specific 35 | accumulation behavior. 36 | """ 37 | 38 | def __init__( 39 | self, 40 | n_classes: int, 41 | d_vectors: int, 42 | device: Union[str, pt.device] = "cpu", 43 | dtype: pt.dtype = pt.float32, 44 | ctype: pt.dtype = pt.int32, 45 | ): 46 | self.n_classes, self.d_vectors = n_classes, d_vectors 47 | self.ctype, self.dtype, self.device = ctype, dtype, device 48 | self.ns_samples = pt.zeros(self.n_classes, dtype=ctype).to(device) # (K) 49 | 50 | def filter_indices_by_n_samples( 51 | self, minimum: int = 0, maximum: int = None 52 | ) -> Union[Tensor, float]: 53 | idxs = self.ns_samples.squeeze() >= minimum 54 | assert pt.all(minimum <= self.ns_samples[idxs]) 55 | if maximum: 56 | idxs &= self.ns_samples.squeeze() < maximum 57 | assert pt.all(self.ns_samples[idxs] < maximum) 58 | 59 | filtered = idxs.nonzero().squeeze() 60 | 61 | return filtered 62 | 63 | def class_idxs(self, X: Tensor, Y: Tensor) -> Union[Tensor, float]: 64 | Y = Y.squeeze() 65 | assert X.shape[0] == Y.shape[0] 66 | Y_range = pt.arange(self.n_classes, dtype=self.ctype) 67 | idxs = (Y[:, None] == Y_range.to(Y.device)).to(self.device) 68 | self.ns_samples += pt.sum(idxs, dim=0, dtype=self.ctype)[:, None].squeeze() 69 | return idxs 70 | 71 | def compute( 72 | self, idxs: List[int] = None, weighted: bool = False 73 | ) -> Tuple[Tensor, Tensor]: 74 | ns_samples, totals = self.ns_samples, self.totals # (K), (K,D) 75 | if idxs is not None: 76 | ns_samples, totals = ns_samples[idxs], totals[idxs] # (K'), (K',D) 77 | if len(self.totals.shape) > 1: 78 | ns_samples = ns_samples.unsqueeze(1) 79 | 80 | eps = pt.finfo(self.dtype).eps 81 | avg = totals / (ns_samples + eps).to(self.dtype) # (K, D) 82 | if weighted: 83 | avg_G = ns_samples.to(self.dtype) @ avg / (ns_samples.sum() + eps) # (D) 84 | else: 85 | avg_G = avg.mean(dim=0) # (D) 86 | 87 | return avg, avg_G 88 | 89 | @abstractmethod 90 | def accumulate(self, *args): 91 | pass 92 | 93 | 94 | class MeanAccumulator(Accumulator): 95 | """Accumulator that computes mean vectors for each class. 96 | 97 | Inherits from the Accumulator class: accumulates the totals for each 98 | class to compute the mean of the input vectors. 99 | 100 | Methods: 101 | accumulate(X, Y): 102 | Increment and return per-class mean totals and sample counts. 103 | """ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | dtype, device = self.dtype, self.device 108 | self.totals = pt.zeros(self.n_classes, self.d_vectors, dtype=dtype).to(device) 109 | 110 | def accumulate(self, X: Tensor, Y: Tensor) -> Tuple[Tensor, Tensor]: 111 | idxs = self.class_idxs(X, Y).mT.to(self.dtype) # (K,B) 112 | self.totals += idxs @ X.to(device=self.device, dtype=self.dtype) # (K,D) 113 | return self.ns_samples, self.totals 114 | 115 | 116 | class CovarAccumulator(Accumulator): 117 | """Accumulator that computes covariance matrices for each class. 
118 | 119 | Inherits from the Accumulator class: accumulates the differences between 120 | input vectors and class-specific mean vectors to compute covariance. 121 | 122 | Attributes: 123 | hash_M (hash): Hash of the mean vectors for ensuring consistency. 124 | totals (Tensor): Matrix to store the accumulated covariance totals. 125 | 126 | Methods: 127 | accumulate(X, Y, M): 128 | Increment covariance totals with squared differences between 129 | input vectors and their label-corresponding class means. 130 | compute(idxs=None): 131 | Computes the average covariance matrix from accumulated totals. 132 | """ 133 | 134 | def __init__(self, *args, M: Tensor = None, **kwargs): 135 | super().__init__(*args, **kwargs) 136 | D, dtype, device = self.d_vectors, self.dtype, self.device 137 | self.hash_M = None if M is None else hashify(M) 138 | self.totals = pt.zeros(D, D, dtype=dtype).to(device) 139 | 140 | def accumulate(self, X: Tensor, Y: Tensor, M: Tensor) -> Tuple[Tensor, Tensor]: 141 | self.hash_M = resolve(self.hash_M, hashify(M)) 142 | assert self.hash_M 143 | 144 | M = M.to(self.device, self.dtype) 145 | assert M.shape == (self.n_classes, self.d_vectors) 146 | 147 | self.class_idxs(X, Y) 148 | diff = X.to(self.device) - M[Y] # (B,D) 149 | self.totals += diff.mT @ diff # (D,D) 150 | 151 | return self.ns_samples, self.totals 152 | 153 | def compute(self, idxs: List[int] = None) -> Tuple[Tensor, Tensor]: 154 | ns_samples, totals = self.ns_samples, self.totals # (K), (K,D) 155 | if idxs is not None: 156 | ns_samples, totals = ns_samples[idxs], totals[idxs] 157 | 158 | eps = pt.finfo(self.dtype).eps 159 | return totals / (ns_samples.sum() + eps).to(self.dtype) 160 | 161 | 162 | class VarNormAccumulator(Accumulator): 163 | """Accumulator that computes the variance norms for each class. 164 | 165 | Inherits from the Accumulator class and calculates the variance of input 166 | vectors relative to class-specific mean vectors. 167 | 168 | Attributes: 169 | hash_M (hash): Hash of the mean vectors for ensuring consistency. 170 | totals (Tensor): A tensor to hold the accumulated variance totals. 171 | 172 | Methods: 173 | accumulate(X, Y, M): 174 | Increment per-class totals with norms of squared differences 175 | between input vectors and their label-corresponding class means. 176 | """ 177 | 178 | def __init__(self, *args, M: Tensor = None, **kwargs): 179 | super().__init__(*args, **kwargs) 180 | self.hash_M = None if M is None else hashify(M) 181 | self.totals = pt.zeros(self.n_classes, dtype=self.dtype).to(self.device) 182 | 183 | def accumulate(self, X: Tensor, Y: Tensor, M: Tensor) -> Tuple[Tensor, Tensor]: 184 | self.hash_M = resolve(self.hash_M, hashify(M)) 185 | assert self.hash_M 186 | 187 | M = M.to(self.device, self.dtype) 188 | assert M.shape == (self.n_classes, self.d_vectors) 189 | 190 | idxs = self.class_idxs(X, Y).mT.to(self.dtype) # (K,B) 191 | diffs_sq = (X.to(self.device) - M[Y]) ** 2 # (B,D) 192 | self.totals += (idxs @ diffs_sq).sum(dim=-1) # (K,D) 193 | 194 | return self.ns_samples, self.totals 195 | 196 | 197 | class DecAccumulator(Accumulator): 198 | """Accumulator that computes decision hits from multiple classifiers. 199 | 200 | Inherits from the Accumulator class and integrates results from 201 | nearest-class center and linear classifiers to track hits for each class. 202 | 203 | Attributes: 204 | hash_M (hash): Hash of the mean vectors for ensuring consistency. 205 | hash_W (hash): Hash of the linear classifier weights for consistency. 
206 | index (IndexFlatL2): FAISS index for efficient nearest neighbors. 207 | totals (Tensor): Tensor of per-class hit counts. 208 | 209 | Methods: 210 | create_index(M): 211 | Initializes a FAISS (if installed) index with the provided mean 212 | vectors. If FAISS not found, do nothing. 213 | accumulate(X, Y, W, M=None): 214 | Updates the totals based on hits between the nearest-class center 215 | and linear classifiers. 216 | compute(idxs=None, weighted=True): 217 | Computes per-class and global classifier decision agreement rates. 218 | """ 219 | 220 | def __init__(self, *args, M: Tensor = None, W: Tensor = None, **kwargs): 221 | super().__init__(*args, **kwargs) 222 | self.hash_M = None if M is None else hashify(M) 223 | self.hash_W = None if W is None else hashify(W) 224 | self.index = None 225 | self.totals = pt.zeros(self.n_classes, dtype=self.ctype).to(self.device) 226 | 227 | def create_index(self, M: Tensor): 228 | self.hash_M = resolve(self.hash_M, hashify(M)) 229 | assert self.hash_M 230 | 231 | try: 232 | from faiss import IndexFlatL2 233 | 234 | self.index = IndexFlatL2(self.d_vectors) 235 | self.index.add(M.cpu().numpy()) 236 | except: 237 | self.index = None 238 | 239 | def accumulate( 240 | self, 241 | X: Tensor, 242 | Y: Tensor, 243 | W: Tensor, 244 | M: Tensor = None, 245 | ) -> Tuple[Tensor, Tensor]: 246 | 247 | self.hash_W = resolve(self.hash_W, hashify(W)) 248 | assert self.hash_W 249 | assert W.shape == (self.n_classes, self.d_vectors) 250 | X, W = X.to(self.device, self.dtype), W.to(self.device, self.dtype) 251 | 252 | # NCC classifier decisions 253 | if self.index: # using FAISS index 254 | _, I = self.index.search(X.cpu().numpy(), 1) 255 | Y_ncc = pt.tensor(I).to(self.device).squeeze() # (B) 256 | else: # manual near-class centre, using given means 257 | assert type(M) == Tensor 258 | self.hash_M = resolve(self.hash_M, hashify(M)) 259 | assert self.hash_M 260 | assert M.shape == (self.n_classes, self.d_vectors) 261 | M = M.to(self.device, self.dtype) 262 | 263 | dots = pt.inner(X, M) # (B,K) 264 | feats, centre = pt.norm(X, dim=-1) ** 2, pt.norm(M, dim=-1) ** 2 # (B), (K) 265 | dists = feats.unsqueeze(1) + centre.unsqueeze(0) - 2 * dots # (B,K) 266 | Y_ncc = dists.argmin(dim=-1) # (B) 267 | 268 | # linear classifier decisions 269 | Y_lin = (X @ W.mT).argmax(dim=-1) # (B) 270 | 271 | # count matches between classifiers 272 | matches = (Y_lin == Y_ncc).to(self.ctype) # (B) 273 | self.class_idxs(X, Y) 274 | self.totals.scatter_add_(0, Y.to(self.device, pt.int64), matches) 275 | 276 | return self.ns_samples, self.totals 277 | --------------------------------------------------------------------------------