├── datashifts.egg-info ├── dependency_links.txt ├── top_level.txt ├── .ipynb_checkpoints │ ├── dependency_links-checkpoint.txt │ ├── top_level-checkpoint.txt │ ├── requires-checkpoint.txt │ ├── SOURCES-checkpoint.txt │ └── PKG-INFO-checkpoint ├── requires.txt ├── SOURCES.txt └── PKG-INFO ├── datashifts ├── __init__.py ├── .ipynb_checkpoints │ ├── __init__-checkpoint.py │ └── core-checkpoint.py ├── __pycache__ │ ├── core.cpython-38.pyc │ ├── core.cpython-39.pyc │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc └── core.py ├── dist ├── datashifts-0.8.4.tar.gz └── datashifts-0.8.4-py3-none-any.whl ├── pyproject.toml ├── LICENSE ├── Demo.ipynb ├── logo ├── datashifts.svg └── .ipynb_checkpoints │ └── datashifts-checkpoint.svg └── README.md /datashifts.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datashifts.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | datashifts 2 | -------------------------------------------------------------------------------- /datashifts.egg-info/.ipynb_checkpoints/dependency_links-checkpoint.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /datashifts.egg-info/.ipynb_checkpoints/top_level-checkpoint.txt: -------------------------------------------------------------------------------- 1 | datashifts 2 | -------------------------------------------------------------------------------- /datashifts.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16 2 | pykeops>=2.2 3 | torch>=1.8 4 | geomloss>=0.2.6 5 | -------------------------------------------------------------------------------- /datashifts/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import DataShifts 2 | __all__=["datashifts"] 3 | __version__ = "0.8.4" -------------------------------------------------------------------------------- /dist/datashifts-0.8.4.tar.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/dist/datashifts-0.8.4.tar.gz -------------------------------------------------------------------------------- /datashifts.egg-info/.ipynb_checkpoints/requires-checkpoint.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.16 2 | pykeops>=2.2 3 | torch>=1.8 4 | geomloss>=0.2.6 5 | -------------------------------------------------------------------------------- /datashifts/.ipynb_checkpoints/__init__-checkpoint.py: -------------------------------------------------------------------------------- 1 | from .core import DataShifts 2 | __all__=["datashifts"] 3 | __version__ = "0.8.4" -------------------------------------------------------------------------------- /dist/datashifts-0.8.4-py3-none-any.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/dist/datashifts-0.8.4-py3-none-any.whl -------------------------------------------------------------------------------- /datashifts/__pycache__/core.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/datashifts/__pycache__/core.cpython-38.pyc -------------------------------------------------------------------------------- /datashifts/__pycache__/core.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/datashifts/__pycache__/core.cpython-39.pyc -------------------------------------------------------------------------------- /datashifts/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/datashifts/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datashifts/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DataShifts/datashifts/HEAD/datashifts/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /datashifts.egg-info/.ipynb_checkpoints/SOURCES-checkpoint.txt: -------------------------------------------------------------------------------- 1 | LICENSE.txt 2 | README.md 3 | pyproject.toml 4 | datashifts/__init__.py 5 | datashifts/core.py 6 | datashifts.egg-info/PKG-INFO 7 | datashifts.egg-info/SOURCES.txt 8 | datashifts.egg-info/dependency_links.txt 9 | datashifts.egg-info/requires.txt 10 | datashifts.egg-info/top_level.txt -------------------------------------------------------------------------------- /datashifts.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | pyproject.toml 4 | datashifts/__init__.py 5 | datashifts/core.py 6 | datashifts.egg-info/PKG-INFO 7 | datashifts.egg-info/SOURCES.txt 8 | datashifts.egg-info/dependency_links.txt 9 | datashifts.egg-info/requires.txt 10 | datashifts.egg-info/top_level.txt 11 | datashifts.egg-info/.ipynb_checkpoints/PKG-INFO-checkpoint 12 | datashifts.egg-info/.ipynb_checkpoints/SOURCES-checkpoint.txt 13 | datashifts.egg-info/.ipynb_checkpoints/dependency_links-checkpoint.txt 14 | datashifts.egg-info/.ipynb_checkpoints/requires-checkpoint.txt 15 | datashifts.egg-info/.ipynb_checkpoints/top_level-checkpoint.txt -------------------------------------------------------------------------------- /datashifts.egg-info/.ipynb_checkpoints/PKG-INFO-checkpoint: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: datashifts 3 | Version: 0.0.1 4 | Summary: Quantify and analyze distribution shifts from samples. 5 | Author-email: Hongbo Chen 6 | License: MIT 7 | Project-URL: Homepage, https://github.com/DataShifts/datashifts 8 | Project-URL: Bug Tracker, https://github.com/DataShifts/datashifts/issues 9 | Requires-Python: >=3.8 10 | Description-Content-Type: text/markdown 11 | License-File: LICENSE.txt 12 | Requires-Dist: numpy>=1.16 13 | Requires-Dist: pykeops>=2.2 14 | Requires-Dist: torch>=1.8 15 | Requires-Dist: geomloss>=0.2.6 16 | 17 | # datashifts 18 | Quantify and analyze distribution shifts from samples. 19 | - Official Documentation Coming Soon! 
20 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=68", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "datashifts" 7 | version = "0.8.4" 8 | description = "Quantify and analyze distribution shifts from samples." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | license = {text = "MIT"} 12 | authors = [ 13 | {name = "Hongbo Chen", email = "hongboc616@gmail.com"} 14 | ] 15 | dependencies = [ 16 | "numpy>=1.16", 17 | "pykeops>=2.2", 18 | "torch>=1.8", 19 | "geomloss>=0.2.6" 20 | ] 21 | 22 | [tool.setuptools] 23 | packages = ["datashifts"] 24 | 25 | [project.urls] 26 | "Homepage" = "https://github.com/DataShifts/datashifts" 27 | "Bug Tracker" = "https://github.com/DataShifts/datashifts/issues" 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 DataShifts 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "e9d97dd6-ec22-434e-b55d-b9cb719f6833", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "from datashifts import DataShifts\n", 12 | "import time" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "id": "cfe7b883-ee18-4898-963a-c86b0f6a3662", 18 | "metadata": {}, 19 | "source": [ 20 | "# Generate data from two different distributions\n", 21 | "Take labels originating from pure noise as an example." 
22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "id": "9f8675f4-50a2-4afb-8668-0d8a8f8023fd", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "N=10000 #Number of samples\n", 32 | "x_dim=200 #Feature dimensions\n", 33 | "y_dim=10 #Label dimensions\n", 34 | "x_shift=10.0 #True covariate shift\n", 35 | "device=\"cuda\" #Device\n", 36 | "\n", 37 | "random_directions=torch.randn(1, x_dim, device=device)\n", 38 | "x_shift_vector=random_directions/((random_directions**2).sum()**(1/2))*x_shift\n", 39 | "# First distribution\n", 40 | "x1 = torch.randn(N, x_dim, device=device)\n", 41 | "y1= torch.rand(N, y_dim, device=device)\n", 42 | "# Second distribution\n", 43 | "x2 = torch.randn(N, x_dim, device=device)+x_shift_vector\n", 44 | "y2= torch.rand(N, y_dim, device=device)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "markdown", 49 | "id": "481aaf39-4d95-47a8-a594-01a1b7f4a4ca", 50 | "metadata": {}, 51 | "source": [ 52 | "# Using datashifts to quantify covariate and concept shifts" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": 4, 58 | "id": "6f0da62c-bf29-4a1e-b0f7-adc1bae8f471", 59 | "metadata": {}, 60 | "outputs": [ 61 | { 62 | "name": "stdout", 63 | "output_type": "stream", 64 | "text": [ 65 | "The sample size of (x1,y1,w1) is larger than parameter 'N_max'=5000, sampling strategy is used.\n", 66 | "The sample size of (x2,y2,w2) is larger than parameter 'N_max'=5000, sampling strategy is used.\n", 67 | "Time-consuming: 3.565563201904297\n", 68 | "Covariate shift: tensor(10.0383, device='cuda:0')\n", 69 | "Concept shift: tensor(1.2612, device='cuda:0')\n" 70 | ] 71 | } 72 | ], 73 | "source": [ 74 | "time0=time.time()\n", 75 | "covariate_shift, concept_shift=DataShifts(x1, x2, y1, y2)\n", 76 | "print(\"Time-consuming: \",time.time()-time0)\n", 77 | "print(\"Covariate shift: \", covariate_shift)\n", 78 | "print(\"Concept shift: \", concept_shift )" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "189a1964-cfc2-4dd4-aec1-747b5c47fd08", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "myconda", 93 | "language": "python", 94 | "name": "myconda" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.8.12" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /logo/datashifts.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /logo/.ipynb_checkpoints/datashifts-checkpoint.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |  3 |  4 | [image: DataShifts Logo]  5 |  6 |  7 | 
8 | 9 | -------------------------------------------------------------------------------- 10 | 11 | # DataShifts — A Toolkit for Quantifying Distribution Shifts 12 | 13 | [![PyPI version](https://img.shields.io/pypi/v/datashifts?color=blue)](https://pypi.org/project/datashifts/) [![PyPI downloads](https://pepy.tech/badge/datashifts?color=green)](https://pypi.org/project/datashifts/) [![License](https://img.shields.io/pypi/l/datashifts.svg)](https://github.com/DataShifts/datashifts/blob/main/LICENSE) 14 | 15 | DataShifts is a Python package that makes it simple to **measure and analyze the distribution shifts from labeled samples**. It can be used with tensor computation frameworks such as [PyTorch](https://github.com/pytorch/pytorch), [NumPy](https://github.com/numpy/numpy) and [KeOps](https://github.com/getkeops/keops). It is designed for data science practitioners who need a principled way to answer questions such as: 16 | 17 | * *How far has my production data shifted from the training set?* 18 | 19 | * *How do the model’s representations shift in a new domain, and are they robust to distribution shifts?* 20 | 21 | * *Are the distribution shifts mainly in the inputs (covariate shift) or in the labels (concept shift)?* 22 | 23 | * *How do these distribution shifts affect model performance?* 24 | 25 | In analysis, distribution shift is often decomposed into **covariate shift ($X$ shift)** and **concept shift ($Y|X$ shift)**. The general theory below shows that the error bound **scales linearly with** these two shifts. With a single call, **DataShifts** estimates these two shifts from labeled samples, providing a rigorous and general tool for quantifying and analyzing distribution shift. 26 | 27 | --- 28 | 29 | ## Core Theory — General Learning Bound under Distribution Shifts 30 | 31 | Let the covariate and label spaces be metric spaces $(\mathcal{X} ,\rho _{\mathcal{X}}),(\mathcal{Y} ,\rho _{\mathcal{Y}})$, and $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$ are two joint distributions of covariates and labels on $\mathcal{X}\times\mathcal{Y}$. If the hypothesis $h:\mathcal{X} \rightarrow \mathcal{Y}'$ is $L _h$-Lipschitz continuous, loss $\ell :\mathcal{Y} \times \mathcal{Y} '\rightarrow \mathbb{R}$ is separately $(L _{\ell},L _{\ell}')$-Lipschitz continuous, then: 32 | 33 | $$ 34 | \LARGE 35 | \epsilon _B(h)\le \epsilon _A(h)+L _hL _{\ell}'S _{Cov}+L _{\ell}S _{Cpt}^{\gamma ^*} 36 | $$ 37 | 38 | where $\epsilon _A(h), \epsilon _B(h)$ are the errors of hypothesis $h$ under the distributions $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$, respectively. $S _{Cov}, S _{Cpt}^{\gamma ^*}$ are **covariate shift** (= $X$ shift, distribution shift of covariates) and **concept shift** (= $Y|X$ shift, distribution shift of labels conditioned on covariates) between $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$. Both shifts are defined in closed form via **entropic optimal transport**. 39 | 40 | This elegant theory shows how distribution shifts affect the error, and has the following advantages: 41 | 42 | * **General**: Because the theory assumes no particular loss or space, it applies broadly to losses and tasks—including regression, classification, and multi-label problems, as long as the covariate and label space of the problem can define metrics. Moreover, depending on whether the covariate space is the raw feature space or the model’s representation space, the theory can measure shifts in either the original data or the learned representations. 
43 | 44 | * **Estimable**: Both covariate shift $S _{Cov}$ and concept shift $S _{Cpt}^{\gamma ^*}$ in the theory can be rigorously estimated from finite samples drawn from the two distributions—**which is the core capability of this package**. 45 | 46 | For further theoretical details, please see our [original paper](https://arxiv.org/abs/2506.12829). 47 | 48 | --- 49 | ## Installation 50 | 51 | Just use the following command to install DataShifts package: 52 | ```shell 53 | pip install datashifts 54 | ``` 55 | 56 | 57 | --- 58 | 59 | ## Quick Example 60 | 61 | ```python 62 | import torch 63 | from datashifts import DataShifts 64 | 65 | # Generate data from two different distributions (take labels originating from pure noise as an example) 66 | N=10000 #Number of samples 67 | x_dim=200 #Feature dimensions 68 | y_dim=10 #Label dimensions 69 | x_shift=10.0 #True covariate shift 70 | device="cuda" #Device 71 | 72 | random_directions=torch.randn(1, x_dim, device=device) 73 | x_shift_vector=random_directions/((random_directions**2).sum()**(1/2))*x_shift 74 | # First distribution 75 | x1 = torch.randn(N, x_dim, device=device) 76 | y1= torch.rand(N, y_dim, device=device) 77 | # Second distribution 78 | x2 = torch.randn(N, x_dim, device=device)+x_shift_vector 79 | y2= torch.rand(N, y_dim, device=device) 80 | 81 | # Using DataShifts to quantify covariate and concept shifts 82 | covariate_shift, concept_shift=DataShifts(x1, x2, y1, y2) 83 | print("Covariate shift: ", covariate_shift) 84 | print("Concept shift: ", concept_shift ) 85 | ``` 86 | 87 | Typical output 88 | 89 | ``` 90 | The sample size of (x1,y1,w1) is larger than parameter 'N_max'=5000, sampling strategy is used. 91 | The sample size of (x2,y2,w2) is larger than parameter 'N_max'=5000, sampling strategy is used. 92 | Covariate shift: tensor(9.9608, device='cuda:0') 93 | Concept shift: tensor(1.2627, device='cuda:0') 94 | ``` 95 | 96 | --- 97 | 98 | ## `datashifts.DataShifts` —  Measure Covariate  &  Concept Shift between Distributions from Samples 99 | 100 | `datashifts.DataShifts` is the core method of the DataShifts package, which estimates covariate shift and concept shift from finite labeled samples `(x1,y1), (x2,y2)` drawn from two distributions, with automatic sub‑sampling for scalability and GPU acceleration. 101 | 102 | ```python 103 | covariate_shift, concept_shift = DataShifts( 104 | x1, x2, y1, y2, # required 105 | weights1=None, weights2=None, # optional importance weights 106 | eps=0.01, # entropic regularisation 107 | N_max=5000, # max points kept per distribution 108 | device=None, # "cpu", "cuda" or None (auto) 109 | seed=None, # random seed for reproducibility 110 | verbose=True # print progress messages 111 | ) 112 | ``` 113 | 114 | *Note (temporary): For now, Euclidean distance is the only built-in metric. Custom metrics are planned.* 115 | 116 | ### Parameters 117 | 118 | | name | type | default | description | 119 | | ---------------------- | ------------------------------------- | ------- | ------------------------------------------------------------ | 120 | | `x1`, `x2` | `torch.Tensor` **or** `numpy.ndarray` | — | Covariates of the samples drawn from two distributions.
Shapes accepted:`(Batch_size, Num_samples, Dim_x)` or `(Num_samples, Dim_x)` | 121 | | `y1`, `y2` | `torch.Tensor` **or** `numpy.ndarray` | — | Corresponding labels.
Shapes accepted:`(Batch_size, Num_samples, Dim_y)` or `(Num_samples, Dim_y)`. Must match `x*` in `Batch_size` and `Num_samples` dimensions. | 122 | | `weights1`, `weights2` | `torch.Tensor` **or** `numpy.ndarray` | `None` | Sample weights.
Shapes accepted: `(Batch_size, Num_samples)` or `(Num_samples,)`. Must match `x*` in `Batch_size` and `Num_samples` dimensions. | 123 | | `eps` | `float` | `0.01` | Entropic regularisation for optimal transport. Smaller => more precise but slower. | 124 | | `N_max` | `int` | `5000` | Upper bound on samples per distribution kept for optimal transport. If `N>N_max`, the function resamples without replacement to speed up the solution (weighted if `weights*` is provided). Larger => more precise but slower. | 125 | | `device` | `str` | `None` | Running device.
`"cpu"`, `"cuda"`/`"gpu"`, or `None`(= automatically use GPU if available). | 126 | | `seed` | `int` | `None` | Random seed for shuffling and sampling. | 127 | | `verbose` | `bool` | `True` | Whether to print progress messages (sampling or automatic device choice). | 128 | 129 | 130 | ### Returns 131 | 132 | ```python 133 | covariate_shift : torch.Tensor 134 | concept_shift : torch.Tensor 135 | ``` 136 | 137 | Returned objects are PyTorch tensors placed on the chosen `device`. 138 | 139 | --- 140 | ## Licensing, Citation, Academic Use 141 | 142 | This package is released under the [MIT License](https://en.wikipedia.org/wiki/MIT_License). See the [LICENSE](https://github.com/DataShifts/datashifts/blob/main/LICENSE) file for full details. 143 | 144 | If you use this package in a research paper, **please cite** our [original paper](https://arxiv.org/abs/2506.12829): 145 | 146 | ```latex 147 | @article{chen2025general, 148 | title={General and Estimable Learning Bound Unifying Covariate and Concept Shifts}, 149 | author={Chen, Hongbo and Xia, Li Charlie}, 150 | journal={arXiv preprint arXiv:2506.12829}, 151 | year={2025} 152 | } 153 | ``` 154 | 155 | --- 156 | 157 | > **Contributions & issues** welcome at https://github.com/DataShifts/datashifts/issues 158 | -------------------------------------------------------------------------------- /datashifts.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: datashifts 3 | Version: 0.8.4 4 | Summary: Quantify and analyze distribution shifts from samples. 5 | Author-email: Hongbo Chen 6 | License: MIT 7 | Project-URL: Homepage, https://github.com/DataShifts/datashifts 8 | Project-URL: Bug Tracker, https://github.com/DataShifts/datashifts/issues 9 | Requires-Python: >=3.8 10 | Description-Content-Type: text/markdown 11 | License-File: LICENSE 12 | Requires-Dist: numpy>=1.16 13 | Requires-Dist: pykeops>=2.2 14 | Requires-Dist: torch>=1.8 15 | Requires-Dist: geomloss>=0.2.6 16 | 17 |
18 |  19 |  20 | [image: DataShifts Logo]  21 |  22 |  23 | 
24 | 25 | -------------------------------------------------------------------------------- 26 | 27 | # DataShifts — A Toolkit for Quantifying Distribution Shifts 28 | 29 | [![PyPI version](https://img.shields.io/pypi/v/datashifts?color=blue)](https://pypi.org/project/datashifts/) [![PyPI downloads](https://pepy.tech/badge/datashifts?color=green)](https://pypi.org/project/datashifts/) [![License](https://img.shields.io/pypi/l/datashifts.svg)](https://github.com/DataShifts/datashifts/blob/main/LICENSE) 30 | 31 | DataShifts is a Python package that makes it simple to **measure and analyze the distribution shifts from labeled samples**. It can be used with tensor computation frameworks such as [PyTorch](https://github.com/pytorch/pytorch), [NumPy](https://github.com/numpy/numpy) and [KeOps](https://github.com/getkeops/keops). It is designed for data science practitioners who need a principled way to answer questions such as: 32 | 33 | * *How far has my production data shifted from the training set?* 34 | 35 | * *How do the model’s representations shift in a new domain, and are they robust to distribution shifts?* 36 | 37 | * *Are the distribution shifts mainly in the inputs (covariate shift) or in the labels (concept shift)?* 38 | 39 | * *How do these distribution shifts affect model performance?* 40 | 41 | In analysis, distribution shift is often decomposed into **covariate shift ($X$ shift)** and **concept shift ($Y|X$ shift)**. The general theory below shows that the error bound **scales linearly with** these two shifts. With a single call, **DataShifts** estimates these two shifts from labeled samples, providing a rigorous and general tool for quantifying and analyzing distribution shift. 42 | 43 | --- 44 | 45 | ## Core Theory — General Learning Bound under Distribution Shifts 46 | 47 | Let the covariate and label spaces be metric spaces $(\mathcal{X} ,\rho _{\mathcal{X}}),(\mathcal{Y} ,\rho _{\mathcal{Y}})$, and $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$ are two joint distributions of covariates and labels on $\mathcal{X}\times\mathcal{Y}$. If the hypothesis $h:\mathcal{X} \rightarrow \mathcal{Y}'$ is $L _h$-Lipschitz continuous, loss $\ell :\mathcal{Y} \times \mathcal{Y} '\rightarrow \mathbb{R}$ is separately $(L _{\ell},L _{\ell}')$-Lipschitz continuous, then: 48 | 49 | $$ 50 | \LARGE 51 | \epsilon _B(h)\le \epsilon _A(h)+L _hL _{\ell}'S _{Cov}+L _{\ell}S _{Cpt}^{\gamma ^*} 52 | $$ 53 | 54 | where $\epsilon _A(h), \epsilon _B(h)$ are the errors of hypothesis $h$ under the distributions $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$, respectively. $S _{Cov}, S _{Cpt}^{\gamma ^*}$ are **covariate shift** (= $X$ shift, distribution shift of covariates) and **concept shift** (= $Y|X$ shift, distribution shift of labels conditioned on covariates) between $\mathcal{D} _{XY}^{A}, \mathcal{D} _{XY}^{B}$. Both shifts are defined in closed form via **entropic optimal transport**. 55 | 56 | This elegant theory shows how distribution shifts affect the error, and has the following advantages: 57 | 58 | * **General**: Because the theory assumes no particular loss or space, it applies broadly to losses and tasks—including regression, classification, and multi-label problems, as long as the covariate and label space of the problem can define metrics. Moreover, depending on whether the covariate space is the raw feature space or the model’s representation space, the theory can measure shifts in either the original data or the learned representations. 
59 | 60 | * **Estimable**: Both covariate shift $S _{Cov}$ and concept shift $S _{Cpt}^{\gamma ^*}$ in the theory can be rigorously estimated from finite samples drawn from the two distributions—**which is the core capability of this package**. 61 | 62 | For further theoretical details, please see our [original paper](https://arxiv.org/abs/2506.12829). 63 | 64 | --- 65 | ## Installation 66 | 67 | Just use the following command to install DataShifts package: 68 | ```shell 69 | pip install datashifts 70 | ``` 71 | 72 | 73 | --- 74 | 75 | ## Quick Example 76 | 77 | ```python 78 | import torch 79 | from datashifts import DataShifts 80 | 81 | # Generate data from two different distributions (take labels originating from pure noise as an example) 82 | N=10000 #Number of samples 83 | x_dim=200 #Feature dimensions 84 | y_dim=10 #Label dimensions 85 | x_shift=10.0 #True covariate shift 86 | device="cuda" #Device 87 | 88 | random_directions=torch.randn(1, x_dim, device=device) 89 | x_shift_vector=random_directions/((random_directions**2).sum()**(1/2))*x_shift 90 | # First distribution 91 | x1 = torch.randn(N, x_dim, device=device) 92 | y1= torch.rand(N, y_dim, device=device) 93 | # Second distribution 94 | x2 = torch.randn(N, x_dim, device=device)+x_shift_vector 95 | y2= torch.rand(N, y_dim, device=device) 96 | 97 | # Using DataShifts to quantify covariate and concept shifts 98 | covariate_shift, concept_shift=DataShifts(x1, x2, y1, y2) 99 | print("Covariate shift: ", covariate_shift) 100 | print("Concept shift: ", concept_shift ) 101 | ``` 102 | 103 | Typical output 104 | 105 | ``` 106 | The sample size of (x1,y1,w1) is larger than parameter 'N_max'=5000, sampling strategy is used. 107 | The sample size of (x2,y2,w2) is larger than parameter 'N_max'=5000, sampling strategy is used. 108 | Covariate shift: tensor(9.9608, device='cuda:0') 109 | Concept shift: tensor(1.2627, device='cuda:0') 110 | ``` 111 | 112 | --- 113 | 114 | ## `datashifts.DataShifts` —  Measure Covariate  &  Concept Shift between Distributions from Samples 115 | 116 | `datashifts.DataShifts` is the core method of the DataShifts package, which estimates covariate shift and concept shift from finite labeled samples `(x1,y1), (x2,y2)` drawn from two distributions, with automatic sub‑sampling for scalability and GPU acceleration. 117 | 118 | ```python 119 | covariate_shift, concept_shift = DataShifts( 120 | x1, x2, y1, y2, # required 121 | weights1=None, weights2=None, # optional importance weights 122 | eps=0.01, # entropic regularisation 123 | N_max=5000, # max points kept per distribution 124 | device=None, # "cpu", "cuda" or None (auto) 125 | seed=None, # random seed for reproducibility 126 | verbose=True # print progress messages 127 | ) 128 | ``` 129 | 130 | *Note (temporary): For now, Euclidean distance is the only built-in metric. Custom metrics are planned.* 131 | 132 | ### Parameters 133 | 134 | | name | type | default | description | 135 | | ---------------------- | ------------------------------------- | ------- | ------------------------------------------------------------ | 136 | | `x1`, `x2` | `torch.Tensor` **or** `numpy.ndarray` | — | Covariates of the samples drawn from two distributions.
Shapes accepted:`(Batch_size, Num_samples, Dim_x)` or `(Num_samples, Dim_x)` | 137 | | `y1`, `y2` | `torch.Tensor` **or** `numpy.ndarray` | — | Corresponding labels.
Shapes accepted:`(Batch_size, Num_samples, Dim_y)` or `(Num_samples, Dim_y)`. Must match `x*` in `Batch_size` and `Num_samples` dimensions. | 138 | | `weights1`, `weights2` | `torch.Tensor` **or** `numpy.ndarray` | `None` | Sample weights.
Shapes accepted: `(Batch_size, Num_samples)` or `(Num_samples,)`. Must match `x*` in `Batch_size` and `Num_samples` dimensions. | 139 | | `eps` | `float` | `0.01` | Entropic regularisation for optimal transport. Smaller => more precise but slower. | 140 | | `N_max` | `int` | `5000` | Upper bound on samples per distribution kept for optimal transport. If `N>N_max`, the function resamples without replacement to speed up the solution (weighted if `weights*` is provided). Larger => more precise but slower. | 141 | | `device` | `str` | `None` | Running device.
`"cpu"`, `"cuda"`/`"gpu"`, or `None`(= automatically use GPU if available). | 142 | | `seed` | `int` | `None` | Random seed for shuffling and sampling. | 143 | | `verbose` | `bool` | `True` | Whether to print progress messages (sampling or automatic device choice). | 144 | 145 | 146 | ### Returns 147 | 148 | ```python 149 | covariate_shift : torch.Tensor 150 | concept_shift : torch.Tensor 151 | ``` 152 | 153 | Returned objects are PyTorch tensors placed on the chosen `device`. 154 | 155 | --- 156 | ## Licensing, Citation, Academic Use 157 | 158 | This package is released under the [MIT License](https://en.wikipedia.org/wiki/MIT_License). See the [LICENSE](https://github.com/DataShifts/datashifts/blob/main/LICENSE) file for full details. 159 | 160 | If you use this package in a research paper, **please cite** our [original paper](https://arxiv.org/abs/2506.12829): 161 | 162 | ```latex 163 | @article{chen2025general, 164 | title={General and Estimable Learning Bound Unifying Covariate and Concept Shifts}, 165 | author={Chen, Hongbo and Xia, Li Charlie}, 166 | journal={arXiv preprint arXiv:2506.12829}, 167 | year={2025} 168 | } 169 | ``` 170 | 171 | --- 172 | 173 | > **Contributions & issues** welcome at https://github.com/DataShifts/datashifts/issues 174 | -------------------------------------------------------------------------------- /datashifts/core.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from geomloss import SamplesLoss 4 | from pykeops.torch import LazyTensor 5 | from torch import Tensor 6 | from numpy import ndarray 7 | from typing import Callable, Union, Optional 8 | 9 | #Global Parameter 10 | N_min=3 11 | 12 | def default_distance_expansion(A, B, KeOps): 13 | #Expanding the dimensions of two tensors for calculating the matrix of default distance 14 | dim=A.dim() 15 | if KeOps: 16 | packer=LazyTensor 17 | else: 18 | packer=lambda x:x 19 | if dim == 1: 20 | A_expand = packer(A[:, None, None]) # (N,1,1) 21 | B_expand = packer(B[None, :, None]) # (1,M,1) 22 | elif dim >= 2: 23 | A_expand = packer(A.unsqueeze(dim-1)) # (..B,N,1,D) 24 | B_expand = packer(B.unsqueeze(dim-2)) # (..B,1,M,D) 25 | return A_expand, B_expand 26 | 27 | def Euclidean_distance(A, B, KeOps=True, p=1): 28 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 29 | if p==1: 30 | return ((A_expand - B_expand) ** 2).sum(-1) ** (1 / 2) 31 | elif p==2: 32 | return ((A_expand - B_expand) ** 2).sum(-1)/2 33 | else: 34 | raise ValueError("The value of 'p' can only be 1 or 2.") 35 | 36 | def Manhattan_distance(A, B, KeOps=True): 37 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 38 | return ((A_expand - B_expand).abs()).sum(-1) 39 | 40 | def Chebyshev_distance(A, B, KeOps=True): 41 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 42 | return ((A_expand - B_expand).abs()).max(-1) 43 | 44 | def W1_deb(x1, x2, w1, w2, eps=0.01): 45 | loss = SamplesLoss(loss="sinkhorn", p=1, blur=eps, debias=True, scaling=0.9) 46 | index1=int(x1.shape[-2]/2) 47 | index2=int(x2.shape[-2]/2) 48 | dim_batch=len(w1.shape)-1 49 | idx11=(slice(None),)*dim_batch+(slice(None,index1),) 50 | idx12=(slice(None),)*dim_batch+(slice(index1,None),) 51 | idx21=(slice(None),)*dim_batch+(slice(None,index2),) 52 | idx22=(slice(None),)*dim_batch+(slice(index2,None),) 53 | x11,x12=x1[idx11],x1[idx12] 54 | x21,x22=x2[idx21],x2[idx22] 55 | w11,w12=w1[idx11],w1[idx12] 56 | w21,w22=w2[idx21],w2[idx22] 57 | w11=w11/(w11.sum(axis=-1).unsqueeze(-1)) 58 | 
w12=w12/(w12.sum(axis=-1).unsqueeze(-1)) 59 | w21=w21/(w21.sum(axis=-1).unsqueeze(-1)) 60 | w22=w22/(w22.sum(axis=-1).unsqueeze(-1)) 61 | W_x12_1=loss(w11, x11, w21, x21) 62 | W_x12_2=loss(w12, x12, w22, x22) 63 | W_x11=loss(w11, x11, w12, x12) 64 | W_x22=loss(w21, x21, w22, x22) 65 | W1_deb=abs(W_x12_1**2/2+W_x12_2**2/2-W_x11**2/2-W_x22**2/2)**(1/2) 66 | return W1_deb 67 | 68 | def W2_deb(x1, x2, w1, w2, eps=0.01): 69 | loss = SamplesLoss(loss="sinkhorn", p=2, blur=eps**(1/2), debias=True, scaling=0.9**(1/2)) 70 | index1=int(x1.shape[-2]/2) 71 | index2=int(x2.shape[-2]/2) 72 | dim_batch=len(w1.shape)-1 73 | idx11=(slice(None),)*dim_batch+(slice(None,index1),) 74 | idx12=(slice(None),)*dim_batch+(slice(index1,None),) 75 | idx21=(slice(None),)*dim_batch+(slice(None,index2),) 76 | idx22=(slice(None),)*dim_batch+(slice(index2,None),) 77 | x11,x12=x1[idx11],x1[idx12] 78 | x21,x22=x2[idx21],x2[idx22] 79 | w11,w12=w1[idx11],w1[idx12] 80 | w21,w22=w2[idx21],w2[idx22] 81 | w11=w11/(w11.sum(axis=-1).unsqueeze(-1)) 82 | w12=w12/(w12.sum(axis=-1).unsqueeze(-1)) 83 | w21=w21/(w21.sum(axis=-1).unsqueeze(-1)) 84 | w22=w22/(w22.sum(axis=-1).unsqueeze(-1)) 85 | W_x12_1=loss(w11, x11, w21, x21) 86 | W_x12_2=loss(w12, x12, w22, x22) 87 | W_x11=loss(w11, x11, w12, x12) 88 | W_x22=loss(w21, x21, w22, x22) 89 | W2_deb=abs(W_x12_1+W_x12_2-W_x11-W_x22)**(1/2) 90 | return W2_deb 91 | 92 | def directional_derivative(x1, x2, x1_grad, x2_grad): 93 | x1_,x2_=default_distance_expansion(x1, x2, KeOps=True) 94 | diff=x1_-x2_ 95 | dis=(diff**2).sum(-1)**(1/2)+1e-9 #For now, only use European distances. 96 | direction_vector=diff/dis 97 | if x1_grad is not None and x2_grad is not None: 98 | x1_grad_,x2_grad_=default_distance_expansion(x1_grad, x2_grad, KeOps=True) 99 | x1_dir_d=(x1_grad_*direction_vector).sum(-1).abs() 100 | x2_dir_d=(x2_grad_*direction_vector).sum(-1).abs() 101 | dir_d=(x1_dir_d.concat(x2_dir_d)).max(-1) 102 | else: 103 | if x1_grad is not None: 104 | dim=x1_grad.dim() 105 | grad_=LazyTensor(x1_grad.unsqueeze(dim-1)) if dim >= 2 else LazyTensor(x1_grad[:, None, None]) 106 | elif x2_grad is not None: 107 | dim=x2_grad.dim() 108 | grad_=LazyTensor(x2_grad.unsqueeze(dim-2)) if dim >= 2 else LazyTensor(x2_grad[None, :, None]) 109 | else: 110 | raise ValueError("The input values 'x1_grad' and 'x2_grad' cannot both be None.") 111 | dir_d=(grad_*direction_vector).sum(-1).abs() 112 | return dir_d 113 | 114 | def ensure_no_grad(x, dataname): 115 | if isinstance(x, Tensor): 116 | if x.requires_grad: 117 | return x.detach(),True 118 | else: 119 | return x,False 120 | elif isinstance(x, ndarray): 121 | return x,False 122 | else: 123 | raise TypeError("'%s' must be an instance of either torch.Tensor or numpy.ndarray."%dataname) 124 | 125 | def check_class(x1, x2, y1, y2, weights1, weights2, grad1, grad2): 126 | requires_grad={} 127 | x1_,requires_grad["x1"]=ensure_no_grad(x1, "x1") 128 | x2_,requires_grad["x2"]=ensure_no_grad(x2, "x2") 129 | y1_,requires_grad["y1"]=ensure_no_grad(y1, "y1") 130 | y2_,requires_grad["y2"]=ensure_no_grad(y2, "y2") 131 | if weights1 is None: 132 | weights1_,requires_grad["weights1"]=None,None 133 | else: 134 | weights1_,requires_grad["weights1"]=ensure_no_grad(weights1, "weights1") 135 | if weights2 is None: 136 | weights2_,requires_grad["weights2"]=None,None 137 | else: 138 | weights2_,requires_grad["weights2"]=ensure_no_grad(weights2, "weights2") 139 | if grad1 is None: 140 | grad1_,requires_grad["grad1"]=None,None 141 | else: 142 | grad1_,requires_grad["grad1"]=ensure_no_grad(grad1, 
"grad1") 143 | if grad2 is None: 144 | grad2_,requires_grad["grad2"]=None,None 145 | else: 146 | grad2_,requires_grad["grad2"]=ensure_no_grad(grad2, "grad2") 147 | return x1_, x2_, y1_, y2_, weights1_, weights2_, grad1_, grad2_, requires_grad 148 | 149 | def check_coupling_format(data1, data2, Dis, axis, N_max, dataname): 150 | if Dis!="L2": 151 | raise ValueError("Currently the 'Dis_%s' value is only supported for 'L2', other distances will be supported in subsequent releases."%dataname) 152 | 153 | if Dis in ["L1","L2","inf"]: 154 | original_dis=True 155 | elif isinstance(Dis, Callable): 156 | original_dis=False 157 | elif isinstance(Dis, str): 158 | raise ValueError("Parameter 'Dis_%s'=\'%s\': distance metric is unknown. Built‑in options are 'L1' (Manhattan), 'L2' (Euclidean), and 'inf' (Chebyshev). You can also supply a function handle that takes '%s1' and '%s2' as inputs and returns their distance matrix, allowing a custom metric."%(dataname,Dis,dataname,dataname)) 159 | else: 160 | raise TypeError("Parameter 'Dis_%s' is an instance of %s: distance metric is unknown. Built‑in options are 'L1' (Manhattan), 'L2' (Euclidean), and 'inf' (Chebyshev). You can also supply a function handle that takes '%s1' and '%s2' as inputs and returns their distance matrix, allowing a custom metric."%(dataname,str(type(Dis)),dataname,dataname)) 161 | 162 | shape_data1=tuple(data1.shape) 163 | shape_data2=tuple(data2.shape) 164 | shape_batch=None 165 | N_data1=None 166 | N_data2=None 167 | if original_dis: 168 | if len(shape_data1)>=2 and len(shape_data2)>=2 and shape_data1[:-2]==shape_data2[:-2] and shape_data1[-1]==shape_data2[-1]: 169 | shape_batch=shape_data1[:-2] 170 | N_data1=shape_data1[-2] 171 | N_data2=shape_data2[-2] 172 | if N_data10: 179 | batchs=(torch.cumprod(torch.tensor(shape_batch),dim=0)).item() 180 | else: 181 | batchs=1 182 | estimated_space=4*min(N_data1,N_max)*min(N_data2,N_max)*dim*batchs 183 | if torch.cuda.is_available(): 184 | allowed_space=1024**3 185 | if estimated_space<=allowed_space: 186 | KeOps=False 187 | else: 188 | KeOps=True 189 | else: 190 | KeOps=True 191 | else: 192 | if axis is None: 193 | raise ValueError("Parameter 'axis_%s' is unknown: when providing a custom 'Dis_%s', '%s1' and '%s2' must each have shape (Batch_size, Num_samples, Dim1, Dim2..) or (Num_samples, Dim1, Dim2..). Supply 'axis_%s' to specify which axis in the '%s1' and '%s2' corresponds to the 'Num_samples' dimension, so that sampling and indexing work correctly."%(dataname,dataname,dataname,dataname,dataname,dataname,dataname)) 194 | elif not isinstance(axis, int): 195 | raise TypeError("Parameter 'axis_%s' must be an int."%dataname) 196 | 197 | if shape_data1[:axis]==shape_data2[:axis] and shape_data1[axis+1:]==shape_data2[axis+1:]: 198 | shape_batch=shape_data1[:axis] 199 | N_data1=shape_data1[axis] 200 | N_data2=shape_data2[axis] 201 | if N_data1 N_max; otherwise, use torch.randperm. 229 | #By default, all index tensors are first placed on the CPU. 230 | if isinstance(weights, ndarray): 231 | #If weights is a numpy.ndarray, convert it to a torch.tensor so that sampling can be handled uniformly with torch.multinomial. 
232 | weights_=torch.tensor(weights,dtype=torch.float32,device="cpu") 233 | else: 234 | weights_=weights 235 | if N>N_max: 236 | sampling=True 237 | if verbose: 238 | print("The sample size of (x%s,y%s,w%s) is larger than parameter 'N_max'=%d, sampling strategy is used."%(dataname,dataname,dataname,N_max)) 239 | else: 240 | sampling=False 241 | if weights_ is None: 242 | Index=torch.randperm(N, generator=generator,device="cpu") 243 | if sampling: 244 | Index=Index[:N_max] 245 | else: 246 | if sampling: 247 | if len(shape_batch)!=0: 248 | flat_batch=int(torch.prod(torch.tensor(shape_batch))) 249 | weights_flat = weights_.reshape(flat_batch, C) 250 | Index=torch.multinomial(weights_flat, N_max, replacement=False, generator=generator) 251 | Index=Index.reshape(*shape_batch, N_max) 252 | else: 253 | Index=torch.multinomial(weights_, N_max, replacement=False, generator=generator) 254 | ReIndex=torch.randperm(Index.shape[-1], generator=generator,device=Index.device) 255 | Index=Index[(slice(None),)*len(shape_batch)+(ReIndex,)] 256 | else: 257 | #When weights is not None and N ≤ N_max, additionally verify that the weights are all non‑negative and not all zeros. 258 | if (weights_.min(axis=-1)[0]>=0).all() and (weights_.max(axis=-1)[0]>0).any(): 259 | pass 260 | else: 261 | raise ValueError("The input 'weights%s' must have all elements non‑negative and include at least one positive value."%dataname) 262 | Index=torch.randperm(N, generator=generator,device=weights_.device) 263 | #Return a 1‑D tensor or an N‑dimensional tensor. 264 | return Index 265 | 266 | def one_dimension_indexing(data, Index, d): 267 | if isinstance(data, Tensor): 268 | if len(Index.shape)==1: 269 | idx=(slice(None),)*d+(Index.to(data.device),) 270 | samples=data[idx] 271 | else: 272 | samples=torch.gather(data,d,Index.to(data.device)) 273 | else: 274 | Index_=Index.cpu().numpy() 275 | if len(Index.shape)==1: 276 | idx=(slice(None),)*d+(Index_,) 277 | samples=data[idx] 278 | else: 279 | samples=np.take_along_axis(data,Index_,d) 280 | return samples 281 | 282 | def tensorized(data, Cuda): 283 | if isinstance(data, ndarray): 284 | return torch.tensor(data, dtype=torch.float32, device="cuda") if Cuda else torch.tensor(data, dtype=torch.float32, device="cpu") 285 | else: 286 | return data.to("cuda") if Cuda else data.to("cpu") 287 | 288 | def DataShifts( 289 | x1:Union[Tensor,ndarray], 290 | x2:Union[Tensor,ndarray], 291 | y1:Union[Tensor,ndarray], 292 | y2:Union[Tensor,ndarray], 293 | weights1:Union[Tensor,ndarray]=None, 294 | weights2:Union[Tensor,ndarray]=None, 295 | grad1:Union[Tensor,ndarray]=None, 296 | grad2:Union[Tensor,ndarray]=None, 297 | P:int=1, 298 | eps:float=0.01, 299 | N_max:int=5000, 300 | device:Optional[str]=None, 301 | seed:Optional[int]=None, 302 | verbose:bool=True, 303 | ): 304 | r""" 305 | Compute covariate shift and concept shift between two labeled sample sets. 306 | 307 | This routine estimates, from finite samples, (i) the **covariate shift** in the 308 | X-space and (ii) the **concept shift** in the Y|X-space between two 309 | distributions. Covariate shift is computed as the **entropic optimal transport** 310 | in the feature space; concept shift is computed as the **expected label-space 311 | distance under the entropic optimal transport coupling** inferred from dual 312 | Sinkhorn potentials. The function supports batching, importance weights, 313 | automatic sub-sampling for scalability, and transparent GPU execution. 
314 | 315 | Parameters 316 | ---------- 317 | x1, x2 : torch.Tensor or numpy.ndarray 318 | Covariate samples from the two domains. Shapes accepted: 319 | ``(Batch, Num, Dim_x)`` or ``(Num, Dim_x)``. Batch dimensions of `x1` and 320 | `x2` must match, and their last (feature) dimensions must be equal. 321 | y1, y2 : torch.Tensor or numpy.ndarray 322 | Corresponding label samples. Shapes accepted: 323 | ``(Batch, Num, Dim_y)`` or ``(Num, Dim_y)``. Must match `x*` in both 324 | ``Batch`` and ``Num``. If the label space is effectively 1-D, a singleton 325 | dimension is added internally for consistency. 326 | weights1, weights2 : torch.Tensor or numpy.ndarray, optional 327 | Optional per-sample weights with shapes ``(Batch, Num)`` or ``(Num,)``. 328 | If provided, they are validated to be non-negative and are **normalized 329 | per batch** internally to sum to 1. If omitted, uniform weights are used. 330 | grad1, grad2 : torch.Tensor or numpy.ndarray, optional 331 | The gradient of `x*` with respect to the error, used to compute the factor 332 | of the covariate shift's effect on the error, and returned as `covariate_factor`. 333 | The shape must correspond to `x*`. 334 | P : int, 1 or 2, default 1 335 | The order of entropic optimal transport. 336 | eps : float, default 0.01 337 | Entropic regularization for optimal transport; smaller is more faithful but 338 | slower/noisier, larger is smoother/faster. 339 | N_max : int, default 5000 340 | Upper bound on the number of samples retained per domain. If 341 | ``Num > N_max``, the data are **shuffled** and (weighted) **sub-sampled 342 | without replacement**. Shuffling is applied even without sub-sampling to 343 | avoid group-specific bias. 344 | device : {"cpu","cuda","gpu"} or None, default None 345 | Target device. If ``None``, the routine uses CUDA automatically when 346 | available (and prints a note if all inputs were on CPU). 347 | seed : int or None, default None 348 | Random seed for shuffling/sampling. Two independent RNGs are used 349 | (one per domain) for reproducible yet uncorrelated draws. 350 | verbose : bool, default True 351 | Whether to print informative messages (sampling strategy, auto-device). 352 | 353 | Returns 354 | ------- 355 | covariate_shift : torch.Tensor 356 | Debiased entropic optimal transport ``W_1^deb(x1,x2)`` or ``W_2^deb(x1,x2)`` in X-space. 357 | The tensor has shape ``Batch`` (or is a scalar 0-D tensor if there is no batch). 358 | concept_shift : torch.Tensor 359 | Expected label-space distance under the optimal coupling. Same shape semantics as above. 360 | covariate_factor : torch.Tensor 361 | If `grad*` is provided, the estimated factor of the covariate shift's effect on the 362 | error is returned. 363 | 364 | Notes 365 | ----- 366 | * **Distance metric:** currently **Euclidean ("L2")** is the only built-in 367 | metric for both X and Y; hooks for user-defined metrics are in place and will 368 | be enabled in a future release. 369 | * **Covariate shift** uses a debiased entropic optimal transport (P=1 or 2), implemented by 370 | combining OT costs on random splits to remove bias. 371 | * **Concept shift** first fits dual potentials (with entropic OT, P=1 or 2), turns 372 | them into a soft coupling ``π* = w1 · exp((g1 − C_x + g2)/eps) · w2``, then 373 | averages distances in Y-space under ``π*``. 374 | * **Shapes:** both outputs follow the leading batch dimensions of the inputs. 
375 | * The routine heuristically selects a **KeOps LazyTensor** backend for large 376 | problems to control memory, otherwise uses dense tensors on GPU/CPU. 377 | 378 | Raises 379 | ------ 380 | TypeError 381 | If `eps`/`N_max` are non-numeric; if any input is neither a Tensor nor a 382 | NumPy array; if `verbose` is not bool; or for invalid custom-metric handles. 383 | ValueError 384 | If sample counts are below the global minimum; if shapes are inconsistent 385 | between domains or between X and Y; if weights have negatives or are all 0; 386 | or if an unsupported distance is requested. 387 | RuntimeError 388 | If a user-provided distance function (future pathway) raises during checks. 389 | 390 | Examples 391 | -------- 392 | >>> covariate_shift, concept_shift = DataShifts(x1, x2, y1, y2, N_max=2048, eps=0.01) 393 | >>> covariate_shift, concept_shift 394 | (tensor(..., device='cuda:0'), tensor(..., device='cuda:0')) 395 | """ 396 | 397 | #Dis_x:Union[str,Callable]="L2", 398 | #Dis_y:Union[str,Callable]="L2", 399 | #axis_x:Optional[int]=None, 400 | #axis_y:Optional[int]=None, 401 | #KeOps:Optional[bool]=None, 402 | 403 | Dis_x="L2" 404 | Dis_y="L2" 405 | axis_x=None 406 | axis_y=None 407 | KeOps=None 408 | 409 | #Perform class validation and gradient detachment for all tensor inputs. 410 | x1_, x2_, y1_, y2_, weights1_, weights2_, grad1_, grad2_, requires_grad=check_class(x1, x2, y1, y2, weights1, weights2, grad1, grad2,) 411 | #Verify that N_max is numeric and that N_max ≥ N_min. 412 | if isinstance(N_max, float) or isinstance(N_max, int): 413 | N_max=int(N_max) 414 | if N_max= 2: 23 | A_expand = packer(A.unsqueeze(dim-1)) # (..B,N,1,D) 24 | B_expand = packer(B.unsqueeze(dim-2)) # (..B,1,M,D) 25 | return A_expand, B_expand 26 | 27 | def Euclidean_distance(A, B, KeOps=True, p=1): 28 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 29 | if p==1: 30 | return ((A_expand - B_expand) ** 2).sum(-1) ** (1 / 2) 31 | elif p==2: 32 | return ((A_expand - B_expand) ** 2).sum(-1)/2 33 | else: 34 | raise ValueError("The value of 'p' can only be 1 or 2.") 35 | 36 | def Manhattan_distance(A, B, KeOps=True): 37 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 38 | return ((A_expand - B_expand).abs()).sum(-1) 39 | 40 | def Chebyshev_distance(A, B, KeOps=True): 41 | A_expand, B_expand=default_distance_expansion(A, B, KeOps) 42 | return ((A_expand - B_expand).abs()).max(-1) 43 | 44 | def W1_deb(x1, x2, w1, w2, eps=0.01): 45 | loss = SamplesLoss(loss="sinkhorn", p=1, blur=eps, debias=True, scaling=0.9) 46 | index1=int(x1.shape[-2]/2) 47 | index2=int(x2.shape[-2]/2) 48 | dim_batch=len(w1.shape)-1 49 | idx11=(slice(None),)*dim_batch+(slice(None,index1),) 50 | idx12=(slice(None),)*dim_batch+(slice(index1,None),) 51 | idx21=(slice(None),)*dim_batch+(slice(None,index2),) 52 | idx22=(slice(None),)*dim_batch+(slice(index2,None),) 53 | x11,x12=x1[idx11],x1[idx12] 54 | x21,x22=x2[idx21],x2[idx22] 55 | w11,w12=w1[idx11],w1[idx12] 56 | w21,w22=w2[idx21],w2[idx22] 57 | w11=w11/(w11.sum(axis=-1).unsqueeze(-1)) 58 | w12=w12/(w12.sum(axis=-1).unsqueeze(-1)) 59 | w21=w21/(w21.sum(axis=-1).unsqueeze(-1)) 60 | w22=w22/(w22.sum(axis=-1).unsqueeze(-1)) 61 | W_x12_1=loss(w11, x11, w21, x21) 62 | W_x12_2=loss(w12, x12, w22, x22) 63 | W_x11=loss(w11, x11, w12, x12) 64 | W_x22=loss(w21, x21, w22, x22) 65 | W1_deb=abs(W_x12_1**2/2+W_x12_2**2/2-W_x11**2/2-W_x22**2/2)**(1/2) 66 | return W1_deb 67 | 68 | def W2_deb(x1, x2, w1, w2, eps=0.01): 69 | loss = SamplesLoss(loss="sinkhorn", p=2, blur=eps**(1/2), 
debias=True, scaling=0.9**(1/2)) 70 | index1=int(x1.shape[-2]/2) 71 | index2=int(x2.shape[-2]/2) 72 | dim_batch=len(w1.shape)-1 73 | idx11=(slice(None),)*dim_batch+(slice(None,index1),) 74 | idx12=(slice(None),)*dim_batch+(slice(index1,None),) 75 | idx21=(slice(None),)*dim_batch+(slice(None,index2),) 76 | idx22=(slice(None),)*dim_batch+(slice(index2,None),) 77 | x11,x12=x1[idx11],x1[idx12] 78 | x21,x22=x2[idx21],x2[idx22] 79 | w11,w12=w1[idx11],w1[idx12] 80 | w21,w22=w2[idx21],w2[idx22] 81 | w11=w11/(w11.sum(axis=-1).unsqueeze(-1)) 82 | w12=w12/(w12.sum(axis=-1).unsqueeze(-1)) 83 | w21=w21/(w21.sum(axis=-1).unsqueeze(-1)) 84 | w22=w22/(w22.sum(axis=-1).unsqueeze(-1)) 85 | W_x12_1=loss(w11, x11, w21, x21) 86 | W_x12_2=loss(w12, x12, w22, x22) 87 | W_x11=loss(w11, x11, w12, x12) 88 | W_x22=loss(w21, x21, w22, x22) 89 | W2_deb=abs(W_x12_1+W_x12_2-W_x11-W_x22)**(1/2) 90 | return W2_deb 91 | 92 | def directional_derivative(x1, x2, x1_grad, x2_grad): 93 | x1_,x2_=default_distance_expansion(x1, x2, KeOps=True) 94 | diff=x1_-x2_ 95 | dis=(diff**2).sum(-1)**(1/2)+1e-9 #For now, only use European distances. 96 | direction_vector=diff/dis 97 | if x1_grad is not None and x2_grad is not None: 98 | x1_grad_,x2_grad_=default_distance_expansion(x1_grad, x2_grad, KeOps=True) 99 | x1_dir_d=(x1_grad_*direction_vector).sum(-1).abs() 100 | x2_dir_d=(x2_grad_*direction_vector).sum(-1).abs() 101 | dir_d=(x1_dir_d.concat(x2_dir_d)).max(-1) 102 | else: 103 | if x1_grad is not None: 104 | dim=x1_grad.dim() 105 | grad_=LazyTensor(x1_grad.unsqueeze(dim-1)) if dim >= 2 else LazyTensor(x1_grad[:, None, None]) 106 | elif x2_grad is not None: 107 | dim=x2_grad.dim() 108 | grad_=LazyTensor(x2_grad.unsqueeze(dim-2)) if dim >= 2 else LazyTensor(x2_grad[None, :, None]) 109 | else: 110 | raise ValueError("The input values 'x1_grad' and 'x2_grad' cannot both be None.") 111 | dir_d=(grad_*direction_vector).sum(-1).abs() 112 | return dir_d 113 | 114 | def ensure_no_grad(x, dataname): 115 | if isinstance(x, Tensor): 116 | if x.requires_grad: 117 | return x.detach(),True 118 | else: 119 | return x,False 120 | elif isinstance(x, ndarray): 121 | return x,False 122 | else: 123 | raise TypeError("'%s' must be an instance of either torch.Tensor or numpy.ndarray."%dataname) 124 | 125 | def check_class(x1, x2, y1, y2, weights1, weights2, grad1, grad2): 126 | requires_grad={} 127 | x1_,requires_grad["x1"]=ensure_no_grad(x1, "x1") 128 | x2_,requires_grad["x2"]=ensure_no_grad(x2, "x2") 129 | y1_,requires_grad["y1"]=ensure_no_grad(y1, "y1") 130 | y2_,requires_grad["y2"]=ensure_no_grad(y2, "y2") 131 | if weights1 is None: 132 | weights1_,requires_grad["weights1"]=None,None 133 | else: 134 | weights1_,requires_grad["weights1"]=ensure_no_grad(weights1, "weights1") 135 | if weights2 is None: 136 | weights2_,requires_grad["weights2"]=None,None 137 | else: 138 | weights2_,requires_grad["weights2"]=ensure_no_grad(weights2, "weights2") 139 | if grad1 is None: 140 | grad1_,requires_grad["grad1"]=None,None 141 | else: 142 | grad1_,requires_grad["grad1"]=ensure_no_grad(grad1, "grad1") 143 | if grad2 is None: 144 | grad2_,requires_grad["grad2"]=None,None 145 | else: 146 | grad2_,requires_grad["grad2"]=ensure_no_grad(grad2, "grad2") 147 | return x1_, x2_, y1_, y2_, weights1_, weights2_, grad1_, grad2_, requires_grad 148 | 149 | def check_coupling_format(data1, data2, Dis, axis, N_max, dataname): 150 | if Dis!="L2": 151 | raise ValueError("Currently the 'Dis_%s' value is only supported for 'L2', other distances will be supported in subsequent 
releases."%dataname) 152 | 153 | if Dis in ["L1","L2","inf"]: 154 | original_dis=True 155 | elif isinstance(Dis, Callable): 156 | original_dis=False 157 | elif isinstance(Dis, str): 158 | raise ValueError("Parameter 'Dis_%s'=\'%s\': distance metric is unknown. Built‑in options are 'L1' (Manhattan), 'L2' (Euclidean), and 'inf' (Chebyshev). You can also supply a function handle that takes '%s1' and '%s2' as inputs and returns their distance matrix, allowing a custom metric."%(dataname,Dis,dataname,dataname)) 159 | else: 160 | raise TypeError("Parameter 'Dis_%s' is an instance of %s: distance metric is unknown. Built‑in options are 'L1' (Manhattan), 'L2' (Euclidean), and 'inf' (Chebyshev). You can also supply a function handle that takes '%s1' and '%s2' as inputs and returns their distance matrix, allowing a custom metric."%(dataname,str(type(Dis)),dataname,dataname)) 161 | 162 | shape_data1=tuple(data1.shape) 163 | shape_data2=tuple(data2.shape) 164 | shape_batch=None 165 | N_data1=None 166 | N_data2=None 167 | if original_dis: 168 | if len(shape_data1)>=2 and len(shape_data2)>=2 and shape_data1[:-2]==shape_data2[:-2] and shape_data1[-1]==shape_data2[-1]: 169 | shape_batch=shape_data1[:-2] 170 | N_data1=shape_data1[-2] 171 | N_data2=shape_data2[-2] 172 | if N_data10: 179 | batchs=(torch.cumprod(torch.tensor(shape_batch),dim=0)).item() 180 | else: 181 | batchs=1 182 | estimated_space=4*min(N_data1,N_max)*min(N_data2,N_max)*dim*batchs 183 | if torch.cuda.is_available(): 184 | allowed_space=1024**3 185 | if estimated_space<=allowed_space: 186 | KeOps=False 187 | else: 188 | KeOps=True 189 | else: 190 | KeOps=True 191 | else: 192 | if axis is None: 193 | raise ValueError("Parameter 'axis_%s' is unknown: when providing a custom 'Dis_%s', '%s1' and '%s2' must each have shape (Batch_size, Num_samples, Dim1, Dim2..) or (Num_samples, Dim1, Dim2..). Supply 'axis_%s' to specify which axis in the '%s1' and '%s2' corresponds to the 'Num_samples' dimension, so that sampling and indexing work correctly."%(dataname,dataname,dataname,dataname,dataname,dataname,dataname)) 194 | elif not isinstance(axis, int): 195 | raise TypeError("Parameter 'axis_%s' must be an int."%dataname) 196 | 197 | if shape_data1[:axis]==shape_data2[:axis] and shape_data1[axis+1:]==shape_data2[axis+1:]: 198 | shape_batch=shape_data1[:axis] 199 | N_data1=shape_data1[axis] 200 | N_data2=shape_data2[axis] 201 | if N_data1 N_max; otherwise, use torch.randperm. 229 | #By default, all index tensors are first placed on the CPU. 230 | if isinstance(weights, ndarray): 231 | #If weights is a numpy.ndarray, convert it to a torch.tensor so that sampling can be handled uniformly with torch.multinomial. 
232 | weights_=torch.tensor(weights,dtype=torch.float32,device="cpu") 233 | else: 234 | weights_=weights 235 | if N>N_max: 236 | sampling=True 237 | if verbose: 238 | print("The sample size of (x%s,y%s,w%s) is larger than parameter 'N_max'=%d, sampling strategy is used."%(dataname,dataname,dataname,N_max)) 239 | else: 240 | sampling=False 241 | if weights_ is None: 242 | Index=torch.randperm(N, generator=generator,device="cpu") 243 | if sampling: 244 | Index=Index[:N_max] 245 | else: 246 | if sampling: 247 | if len(shape_batch)!=0: 248 | flat_batch=int(torch.prod(torch.tensor(shape_batch))) 249 | weights_flat = weights_.reshape(flat_batch, C) 250 | Index=torch.multinomial(weights_flat, N_max, replacement=False, generator=generator) 251 | Index=Index.reshape(*shape_batch, N_max) 252 | else: 253 | Index=torch.multinomial(weights_, N_max, replacement=False, generator=generator) 254 | ReIndex=torch.randperm(Index.shape[-1], generator=generator,device=Index.device) 255 | Index=Index[(slice(None),)*len(shape_batch)+(ReIndex,)] 256 | else: 257 | #When weights is not None and N ≤ N_max, additionally verify that the weights are all non‑negative and not all zeros. 258 | if (weights_.min(axis=-1)[0]>=0).all() and (weights_.max(axis=-1)[0]>0).any(): 259 | pass 260 | else: 261 | raise ValueError("The input 'weights%s' must have all elements non‑negative and include at least one positive value."%dataname) 262 | Index=torch.randperm(N, generator=generator,device=weights_.device) 263 | #Return a 1‑D tensor or an N‑dimensional tensor. 264 | return Index 265 | 266 | def one_dimension_indexing(data, Index, d): 267 | if isinstance(data, Tensor): 268 | if len(Index.shape)==1: 269 | idx=(slice(None),)*d+(Index.to(data.device),) 270 | samples=data[idx] 271 | else: 272 | samples=torch.gather(data,d,Index.to(data.device)) 273 | else: 274 | Index_=Index.cpu().numpy() 275 | if len(Index.shape)==1: 276 | idx=(slice(None),)*d+(Index_,) 277 | samples=data[idx] 278 | else: 279 | samples=np.take_along_axis(data,Index_,d) 280 | return samples 281 | 282 | def tensorized(data, Cuda): 283 | if isinstance(data, ndarray): 284 | return torch.tensor(data, dtype=torch.float32, device="cuda") if Cuda else torch.tensor(data, dtype=torch.float32, device="cpu") 285 | else: 286 | return data.to("cuda") if Cuda else data.to("cpu") 287 | 288 | def DataShifts( 289 | x1:Union[Tensor,ndarray], 290 | x2:Union[Tensor,ndarray], 291 | y1:Union[Tensor,ndarray], 292 | y2:Union[Tensor,ndarray], 293 | weights1:Union[Tensor,ndarray]=None, 294 | weights2:Union[Tensor,ndarray]=None, 295 | grad1:Union[Tensor,ndarray]=None, 296 | grad2:Union[Tensor,ndarray]=None, 297 | P:int=1, 298 | eps:float=0.01, 299 | N_max:int=5000, 300 | device:Optional[str]=None, 301 | seed:Optional[int]=None, 302 | verbose:bool=True, 303 | ): 304 | r""" 305 | Compute covariate shift and concept shift between two labeled sample sets. 306 | 307 | This routine estimates, from finite samples, (i) the **covariate shift** in the 308 | X-space and (ii) the **concept shift** in the Y|X-space between two 309 | distributions. Covariate shift is computed as the **entropic optimal transport** 310 | in the feature space; concept shift is computed as the **expected label-space 311 | distance under the entropic optimal transport coupling** inferred from dual 312 | Sinkhorn potentials. The function supports batching, importance weights, 313 | automatic sub-sampling for scalability, and transparent GPU execution. 
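    As a rough sketch of the covariate-shift estimator (see the ``W2_deb`` helper above; per-sample weights omitted here for brevity): each sample set is split into halves, ``x1 = [x11, x12]`` and ``x2 = [x21, x22]``, and for ``P=2`` the value is computed along the lines of ``|OT_eps(x11, x21) + OT_eps(x12, x22) - OT_eps(x11, x12) - OT_eps(x21, x22)| ** (1/2)``, where ``OT_eps`` denotes the entropic (Sinkhorn) transport cost on each split; the cross-domain terms measure the shift, while the within-domain terms are subtracted to reduce finite-sample bias. The concept shift is then an average of label-space distances under the fitted coupling, as detailed in the Notes below.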
314 | 315 | Parameters 316 | ---------- 317 | x1, x2 : torch.Tensor or numpy.ndarray 318 | Covariate samples from the two domains. Shapes accepted: 319 | ``(Batch, Num, Dim_x)`` or ``(Num, Dim_x)``. Batch dimensions of `x1` and 320 | `x2` must match, and their last (feature) dimensions must be equal. 321 | y1, y2 : torch.Tensor or numpy.ndarray 322 | Corresponding label samples. Shapes accepted: 323 | ``(Batch, Num, Dim_y)`` or ``(Num, Dim_y)``. Must match `x*` in both 324 | ``Batch`` and ``Num``. If the label space is effectively 1-D, a singleton 325 | dimension is added internally for consistency. 326 | weights1, weights2 : torch.Tensor or numpy.ndarray, optional 327 | Optional per-sample weights with shapes ``(Batch, Num)`` or ``(Num,)``. 328 | If provided, they are validated to be non-negative and are **normalized 329 | per batch** internally to sum to 1. If omitted, uniform weights are used. 330 | grad1, grad2 : torch.Tensor or numpy.ndarray, optional 331 | The gradient of the error with respect to `x*`, used to compute the factor 332 | of the covariate shift's effect on the error, and returned as `covariate_factor`. 333 | The shape must match that of `x*`. 334 | P : int, 1 or 2, default 1 335 | The order of the entropic optimal transport distance. 336 | eps : float, default 0.01 337 | Entropic regularization for optimal transport; smaller is more faithful but 338 | slower/noisier, larger is smoother/faster. 339 | N_max : int, default 5000 340 | Upper bound on the number of samples retained per domain. If 341 | ``Num > N_max``, the data are **shuffled** and (weighted) **sub-sampled 342 | without replacement**. Shuffling is applied even without sub-sampling to 343 | avoid group-specific bias. 344 | device : {"cpu","cuda","gpu"} or None, default None 345 | Target device. If ``None``, the routine uses CUDA automatically when 346 | available (and prints a note if all inputs were on CPU). 347 | seed : int or None, default None 348 | Random seed for shuffling/sampling. Two independent RNGs are used 349 | (one per domain) for reproducible yet uncorrelated draws. 350 | verbose : bool, default True 351 | Whether to print informative messages (sampling strategy, auto-device). 352 | 353 | Returns 354 | ------- 355 | covariate_shift : torch.Tensor 356 | Debiased entropic optimal transport ``W_1^deb(x1,x2)`` or ``W_2^deb(x1,x2)`` in X-space. 357 | The tensor has shape ``Batch`` (or is a scalar 0-D tensor if there is no batch). 358 | concept_shift : torch.Tensor 359 | Expected label-space distance under the optimal coupling. Same shape semantics as above. 360 | covariate_factor : torch.Tensor 361 | If `grad*` is provided, the estimated factor of the covariate shift's effect on the 362 | error is returned. 363 | 364 | Notes 365 | ----- 366 | * **Distance metric:** currently **Euclidean ("L2")** is the only built-in 367 | metric for both X and Y; hooks for user-defined metrics are in place and will 368 | be enabled in a future release. 369 | * **Covariate shift** uses a debiased entropic optimal transport (P=1 or 2), implemented by 370 | combining OT costs on random splits to remove bias. 371 | * **Concept shift** first fits dual potentials (with entropic OT, P=1 or 2), turns 372 | them into a soft coupling ``π* = w1 · exp((g1 − C_x + g2)/eps) · w2``, then 373 | averages distances in Y-space under ``π*``. 374 | * **Shapes:** both outputs follow the leading batch dimensions of the inputs.
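    * **Shape example (illustrative, hypothetical sizes):** with ``x1`` of shape ``(8, 3000, 16)``, ``x2`` of shape ``(8, 4000, 16)`` and labels of shapes ``(8, 3000, 1)`` and ``(8, 4000, 1)``, both ``covariate_shift`` and ``concept_shift`` are returned with shape ``(8,)``.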
375 | * The routine heuristically selects a **KeOps LazyTensor** backend for large 376 | problems to control memory; otherwise it uses dense tensors on GPU/CPU. 377 | 378 | Raises 379 | ------ 380 | TypeError 381 | If `eps`/`N_max` are non-numeric; if any input is neither a Tensor nor a 382 | NumPy array; if `verbose` is not bool; or for invalid custom-metric handles. 383 | ValueError 384 | If sample counts are below the global minimum; if shapes are inconsistent 385 | between domains or between X and Y; if weights have negatives or are all 0; 386 | or if an unsupported distance is requested. 387 | RuntimeError 388 | If a user-provided distance function (future pathway) raises during checks. 389 | 390 | Examples 391 | -------- 392 | >>> covariate_shift, concept_shift = DataShifts(x1, x2, y1, y2, N_max=2048, eps=0.01) 393 | >>> covariate_shift, concept_shift 394 | (tensor(..., device='cuda:0'), tensor(..., device='cuda:0')) 395 | """ 396 | 397 | #Dis_x:Union[str,Callable]="L2", 398 | #Dis_y:Union[str,Callable]="L2", 399 | #axis_x:Optional[int]=None, 400 | #axis_y:Optional[int]=None, 401 | #KeOps:Optional[bool]=None, 402 | 403 | Dis_x="L2" 404 | Dis_y="L2" 405 | axis_x=None 406 | axis_y=None 407 | KeOps=None 408 | 409 | #Perform class validation and gradient detachment for all tensor inputs. 410 | x1_, x2_, y1_, y2_, weights1_, weights2_, grad1_, grad2_, requires_grad=check_class(x1, x2, y1, y2, weights1, weights2, grad1, grad2,) 411 | #Verify that N_max is numeric and that N_max ≥ N_min. 412 | if isinstance(N_max, float) or isinstance(N_max, int): 413 | N_max=int(N_max) 414 | if N_max