├── .gitignore ├── CITATION.cff ├── LICENSE ├── README.md ├── setup.py └── torchqmet ├── __init__.py ├── iqe.py ├── mrn.py ├── neural_norms.py ├── pqe ├── README.md ├── __init__.py ├── cdf_ops │ ├── .gitignore │ ├── README.md │ ├── __init__.py │ ├── cdf_ops.cpp │ ├── cdf_ops.h │ ├── cpu │ │ ├── cdflib │ │ │ ├── cdflib.cpp │ │ │ └── cdflib.hpp │ │ ├── cephes │ │ │ ├── chbevl.h │ │ │ ├── i0.h │ │ │ ├── i1.h │ │ │ ├── ndtr.h │ │ │ └── polevl.h │ │ └── kernels.h │ ├── cuda │ │ ├── cdflib │ │ │ ├── chndtr.cuh │ │ │ ├── cumchi.cuh │ │ │ ├── cumchn.cuh │ │ │ ├── cumgam.cuh │ │ │ ├── error_fc.cuh │ │ │ ├── exparg.cuh │ │ │ ├── fifidint.cuh │ │ │ ├── gam1.cuh │ │ │ ├── gamma_inc.cuh │ │ │ ├── rexp.cuh │ │ │ └── rlog.cuh │ │ ├── cephes │ │ │ ├── chbevl.cuh │ │ │ ├── i0.cuh │ │ │ ├── i1.cuh │ │ │ └── ndtr.cuh │ │ ├── kernels.cuh │ │ ├── kernels_bessel_i.cu │ │ ├── kernels_chndtr_scalar.cu │ │ ├── kernels_chndtr_scalar_double.cu │ │ ├── kernels_chndtr_scalar_scalar_t.cu │ │ ├── kernels_chndtr_tensor.cu │ │ ├── kernels_ndtr.cu │ │ ├── kernels_prob_two_poisson_backward_mu1.cu │ │ ├── kernels_prob_two_poisson_backward_mu2_gt.cu │ │ ├── kernels_prob_two_poisson_backward_mu2_le.cu │ │ ├── kernels_prob_two_poisson_forward_gt.cu │ │ ├── kernels_prob_two_poisson_forward_le.cu │ │ └── kernels_prob_two_poisson_templates.cuh │ ├── load_ext.py │ └── op_wrappers.py ├── measures.py └── shapes.py ├── reductions.py ├── transforms.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .idea/ 106 | *.cfg 107 | !setup.cfg 108 | exman 109 | .tests 110 | cfg/ 111 | data/ 112 | notebooks/ 113 | testing-report.html 114 | test-output.xml 115 | *.el 116 | .vscode/ 117 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Wang" 5 | given-names: "Tongzhou" 6 | title: "torchqmet: PyTorch Package for Quasimetric Learning" 7 | version: 0.1.0 8 | date-released: 2022-11-28 9 | url: "https://github.com/quasimetric-learning/torch-quasimetric" 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2022 Tongzhou Wang 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions 6 | are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright 9 | notice, this list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above 12 | copyright notice, this list of conditions and the following 13 | disclaimer in the documentation and/or other materials provided 14 | with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived 18 | from this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 21 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 22 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 23 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT 24 | OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 25 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 26 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 27 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 28 | THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import re 4 | from setuptools import setup, find_packages 5 | 6 | PROJECT_ROOT = os.path.dirname(os.path.realpath(__file__)) 7 | readme = open(os.path.join(PROJECT_ROOT, 'README.md')).read() 8 | 9 | def get_version(*path): 10 | version_file = os.path.join(*path) 11 | lines = open(version_file, "rt").readlines() 12 | version_regex = r"^__version__ = ['\"]([^'\"]*)['\"]" 13 | for line in lines: 14 | mo = re.search(version_regex, line, re.M) 15 | if mo: 16 | return mo.group(1) 17 | raise RuntimeError("Unable to find version in %s." % (version_file,)) 18 | 19 | 20 | 21 | setup( 22 | # Metadata 23 | name='torchqmet', 24 | author='Tongzhou Wang', 25 | author_email='tongzhou.wang.1994@gmail.com', 26 | url='https://github.com/quasimetric-learning/torch-quasimetric', 27 | install_requires=["torch>=1.11.0"], 28 | python_requires=">=3.7.0", 29 | description='PyTorch Package for Quasimetric Learning', 30 | long_description=readme, 31 | license='BSD', 32 | 33 | # Package info 34 | packages=find_packages(exclude=('test',)), 35 | version=get_version(PROJECT_ROOT, "torchqmet", "__init__.py"), 36 | 37 | zip_safe=True, 38 | ) 39 | -------------------------------------------------------------------------------- /torchqmet/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import abc 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .transforms import TransformBase, make_transform 9 | from .reductions import ReductionBase, make_reduction 10 | 11 | 12 | class QuasimetricBase(nn.Module, metaclass=abc.ABCMeta): 13 | input_size: int # dimensionality of input latent space 14 | num_components: int # number of components to be combined to form the latent quasimetric 15 | discount: Optional[float] # if set, output the discounted quasimetric, `discount ** d` 16 | guaranteed_quasimetric: bool # whether this is guaranteed to satisfy quasimetric constraints 17 | 18 | transforms: nn.Sequential # Sequential[TransformBase] 19 | reduction: ReductionBase 20 | 21 | def __init__(self, input_size: int, num_components: int, *, 22 | warn_if_not_quasimetric: bool = True, guaranteed_quasimetric: bool, 23 | transforms: Collection[str], reduction: str, discount: Optional[float] = None) -> None: 24 | super().__init__() 25 | self.input_size = input_size 26 | self.num_components = num_components 27 | self.guaranteed_quasimetric = guaranteed_quasimetric 28 | self.discount = discount 29 | 30 | _transforms: List[TransformBase] = [] 31 | for transform in transforms: 32 | _transforms.append(make_transform(transform, num_components)) 33 | num_components = _transforms[-1].output_num_components 34 | self.transforms = nn.Sequential(*_transforms) 35 | self.reduction = make_reduction(reduction, num_components, 
discount) 36 | 37 | @abc.abstractmethod 38 | def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 39 | r''' 40 | Inputs: 41 | x (torch.Tensor): Shape [..., input_size] 42 | y (torch.Tensor): Shape [..., input_size] 43 | 44 | Output: 45 | d (torch.Tensor): Shape [..., num_components] 46 | ''' 47 | pass 48 | 49 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 50 | assert x.shape[-1] == y.shape[-1] == self.input_size 51 | d = self.compute_components(x, y) 52 | d: torch.Tensor = self.transforms(d) 53 | return self.reduction(d) 54 | 55 | def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 56 | # Manually define for typing 57 | # https://github.com/pytorch/pytorch/issues/45414 58 | return super().__call__(x, y) 59 | 60 | def extra_repr(self) -> str: 61 | return f"guaranteed_quasimetric={self.guaranteed_quasimetric}\ninput_size={self.input_size}, num_components={self.num_components}" + ( 62 | ", discount=None" if self.discount is None else f", discount={self.discount:g}" 63 | ) 64 | 65 | 66 | from .pqe import PQE, PQELH, PQEGG 67 | from .iqe import IQE 68 | from .mrn import MRN, MRNFixed 69 | from .neural_norms import DeepNorm, WideNorm 70 | 71 | __all__ = ['PQE', 'PQELH', 'PQEGG', 'IQE', 'MRN', 'MRNFixed', 'DeepNorm', 'WideNorm'] 72 | __version__ = "0.1.0" 73 | -------------------------------------------------------------------------------- /torchqmet/iqe.py: -------------------------------------------------------------------------------- 1 | r''' 2 | Inteval Quasimetric Embedding (IQE) 3 | https://arxiv.org/abs/2211.15120 4 | ''' 5 | 6 | from typing import * 7 | 8 | import torch 9 | 10 | from . import QuasimetricBase 11 | 12 | 13 | @torch.jit.script 14 | def iqe(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 15 | D = x.shape[-1] # D: dim_per_component 16 | 17 | # ignore pairs that x >= y 18 | valid = x < y 19 | 20 | # sort to better count 21 | xy = torch.cat(torch.broadcast_tensors(x, y), dim=-1) 22 | sxy, ixy = xy.sort(dim=-1) 23 | 24 | # f(c) = indic( c > 0 ) 25 | # at each location `x` along the real line, get `c` the number of intervals covering `x`, and apply `f`: 26 | # \int f(c(x)) dx 27 | 28 | # neg_inc_copies: the **negated** increment of **input** of f at sorted locations, in terms of **#copies of delta** 29 | neg_inc_copies = torch.gather(valid, dim=-1, index=ixy % D) * torch.where(ixy < D, -1, 1) 30 | 31 | # neg_incf: the **negated** increment of **output** of f at sorted locations 32 | neg_inp_copies = torch.cumsum(neg_inc_copies, dim=-1) 33 | 34 | # delta = inf 35 | # f input: 0 -> 0, x -> -inf. 36 | # neg_inp = torch.where(neg_inp_copies == 0, 0., -delta) 37 | # f output: 0 -> 0, x -> 1. 38 | neg_f = (neg_inp_copies < 0) * (-1.) 39 | neg_incf = torch.cat([neg_f.narrow(-1, 0, 1), torch.diff(neg_f, dim=-1)], dim=-1) 40 | 41 | # reduction 42 | return (sxy * neg_incf).sum(-1) 43 | 44 | 45 | class IQE(QuasimetricBase): 46 | r''' 47 | Inteval Quasimetric Embedding (IQE): 48 | https://arxiv.org/abs/2211.15120 49 | 50 | One-line Usage: 51 | 52 | IQE(input_size: int, dim_per_component: int = 16, ...) 53 | 54 | 55 | Default arguments implements IQE-maxmean. Set `reduction="sum"` to create IQE-sum. 56 | 57 | IQE-Specific Args: 58 | input_size (int): Dimension of input latent vectors 59 | dim_per_component (int): IQE splits latent vectors into chunks, where ach chunk computes gives an IQE component. 60 | This is the number of latent dimensions assigned to each chunk. This number must 61 | perfectly divide ``input_size``. 
IQE paper recomments at least ``8``. 62 | Default: ``16``. 63 | 64 | Common Args (Exist for all quasimetrics, **Keyword-only**, Default values may be different for different quasimetrics): 65 | transforms (Collection[str]): A sequence of transforms to apply to the components, before reducing them to form 66 | the final latent quasimetric. 67 | Supported choices: 68 | + "concave_activation": Concave activation transform from Neural Norms paper. 69 | Default: ``()`` (no transforms). 70 | reduction (str): Reduction method to aggregate components into final quasimetric value. 71 | Supported choices: 72 | + "sum": Sum of components. 73 | + "max": Max of components. 74 | + "mean": Average of components. 75 | + "maxmean": Convex combination of max and mean. Used in original Deep Norm, Wide Norm, and IQE. 76 | + "deep_linear_net_weighted_sum": Weighted sum with weights given by a deep linear net. Used in 77 | original PQE, whose components have limited range [0, 1). 78 | Default: ``"maxmean"``. 79 | discounted (Optional[float]): If not ``None``, this module instead estimates discounted distances with the 80 | base as ``discounted``. 81 | Default ``None``. 82 | warn_if_not_quasimetric (bool): If ``True``, issue a warning if this module does not always obey quasimetric 83 | constraints. IQEs always obey quasimetric constraints. 84 | Default: ``True``. 85 | 86 | Shape: 87 | - Input: Two broadcastable tensors of shape ``(..., input_size)`` 88 | - Output: ``(...)`` 89 | 90 | Non-Module Attributes: 91 | input_size (int) 92 | num_components (int): Number of components to be combined to form the latent quasimetric. For IQEs, this is 93 | ``input_size // dim_per_component``. 94 | discount (Optional[float]) 95 | guaranteed_quasimetric (bool): Whether this is guaranteed to satisfy quasimetric constraints. 96 | 97 | Module Attributes: 98 | transforms (nn.Sequential[TransformBase]): Transforms to be applied on quasimetric components. 99 | reduction (ReductionBase): Reduction methods to aggregate components. 
100 | 101 | Examples:: 102 | 103 | >>> iqe = IQE(128, dim_per_component=16) 104 | >>> print(iqe) 105 | IQE( 106 | guaranteed_quasimetric=True 107 | input_size=128, num_components=8, discount=None 108 | (transforms): Sequential() 109 | (reduction): MaxMean(input_num_components=8) 110 | ) 111 | >>> x = torch.randn(5, 128, requires_grad=True) 112 | >>> y = torch.randn(5, 128, requires_grad=True) 113 | >>> print(iqe(x, y)) 114 | tensor([3.3045, 3.8072, 3.9671, 3.3521, 3.7831],, grad_fn=) 115 | >>> print(iqe(y, x)) 116 | tensor([3.3850, 3.8457, 4.0870, 3.1757, 3.9459], grad_fn=) 117 | >>> print(iqe(x[:, None], x)) # pdist 118 | tensor([[0.0000, 3.8321, 3.7907, 3.5915, 3.3326], 119 | [3.9845, 0.0000, 4.0173, 3.8059, 3.7177], 120 | [3.7934, 4.3673, 0.0000, 4.0536, 3.6068], 121 | [3.1764, 3.4881, 3.5300, 0.0000, 2.9292], 122 | [3.7184, 3.8690, 3.8321, 3.5905, 0.0000]], grad_fn=) 123 | ''' 124 | 125 | def __init__(self, input_size: int, dim_per_component: int = 16, *, 126 | transforms: Collection[str] = (), reduction: str = 'maxmean', 127 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 128 | assert dim_per_component > 0, "dim_per_component must be positive" 129 | assert input_size % dim_per_component == 0, \ 130 | f"input_size={input_size} is not divisible by dim_per_component={dim_per_component}" 131 | num_components = input_size // dim_per_component 132 | super().__init__(input_size, num_components, guaranteed_quasimetric=True, warn_if_not_quasimetric=warn_if_not_quasimetric, 133 | transforms=transforms, reduction=reduction, discount=discount) 134 | self.latent_2d_shape = torch.Size([num_components, dim_per_component]) 135 | 136 | def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 137 | return iqe( 138 | x=x.unflatten(-1, self.latent_2d_shape), 139 | y=y.unflatten(-1, self.latent_2d_shape), 140 | ) 141 | -------------------------------------------------------------------------------- /torchqmet/mrn.py: -------------------------------------------------------------------------------- 1 | r""" 2 | Metric Residual Network (MRN) 3 | https://arxiv.org/abs/2208.08133 4 | """ 5 | 6 | from typing import * 7 | 8 | import warnings 9 | 10 | import torch 11 | import torch.nn as nn 12 | 13 | from . import QuasimetricBase 14 | 15 | 16 | class MRNProjector(nn.Sequential): 17 | output_size: int 18 | 19 | def __init__(self, input_size: int, *, output_size: int = 16, hidden_sizes: List[int] = [176]): 20 | modules = [] 21 | for hidden_size in hidden_sizes: 22 | modules.append(nn.Linear(input_size, hidden_size)) 23 | modules.append(nn.ReLU()) 24 | input_size = hidden_size 25 | modules.append(nn.Linear(input_size, output_size)) 26 | super().__init__(*modules) 27 | 28 | def __call__(self, z: torch.Tensor) -> torch.Tensor: 29 | return super().__call__(z) 30 | 31 | 32 | class MRN(QuasimetricBase): 33 | r""" 34 | Metric Residual Network (MRN): 35 | https://arxiv.org/abs/2208.08133 36 | 37 | One-line Usage: 38 | 39 | MRN(input_size: int, sym_p: float = 2, ...) 40 | 41 | Default arguments implement the MRN as described in the original MRN paper: 42 | 43 | d_z(x, y) = ( 1/d_sym * \sum_i (f_sym(x)[i] - f_sym(y))^2 )^(p/2) + \max_j ReLU( f_asym(x)[j] - f_asym(y)[j] ), 44 | 45 | where `f_sym` and `f_asym` are 2-layer MLPs, and `d_sym` is the output size of `f_sym`. 46 | 47 | + The first term is simply a (scaled) Euclidean distance raised to the `p`-th power, representing the symmetrical port. 48 | + The second term is the asymmetrical part. 
49 | 50 | These two terms are used as two **components** of the quasimetric. With default arguments, a summation reduction 51 | combines them. 52 | 53 | NOTE:: 54 | Default arguments does not guarantee a true quasimetric, since one of the component is the **squared** Euclidean 55 | distance, rather than regular Euclidean distance. 56 | 57 | Following a fix proposed in the IQE paper (https://arxiv.org/abs/2211.15120), we allow setting 58 | `sym_p=1`, which uses the regular Euclidean distance instead, and guarantees a quasimetric. 59 | 60 | Alternatively, simply use subclass :class:`MRNFixed`, which changes the default of `sym_p` to `1`. 61 | 62 | MRN-Specific Args: 63 | input_size (int): Dimension of input latent vectors. 64 | sym_p (float): Exponent applied to the symmetrical term of Euclidean distance. 65 | Default: ``2``. 66 | proj_hidden_size (int): Hidden size of `f_sym` and `f_asym` MLPs. 67 | Default: ``176``. 68 | proj_output_size (int): Output size of `f_sym` and `f_asym` MLPs. 69 | Default: ``16``. 70 | 71 | Common Args (Exist for all quasimetrics, **Keyword-only**, Default values may be different for different quasimetrics): 72 | transforms (Collection[str]): A sequence of transforms to apply to the components, before reducing them to form 73 | the final latent quasimetric. 74 | Supported choices: 75 | + "concave_activation": Concave activation transform from Neural Norms paper. 76 | Default: ``()`` (no transforms). 77 | reduction (str): Reduction method to aggregate components into final quasimetric value. 78 | Supported choices: 79 | + "sum": Sum of components. 80 | + "max": Max of components. 81 | + "mean": Average of components. 82 | + "maxmean": Convex combination of max and mean. Used in original Deep Norm, Wide Norm, and IQE. 83 | + "deep_linear_net_weighted_sum": Weighted sum with weights given by a deep linear net. Used in 84 | original PQE, whose components have limited range [0, 1). 85 | Default: ``"sum"``. 86 | discounted (Optional[float]): If not ``None``, this module instead estimates discounted distances with the 87 | base as ``discounted``. 88 | Default ``None``, but recommended for PQEs (following original paper). 89 | warn_if_not_quasimetric (bool): If ``True``, issue a warning if this module does not always obey quasimetric 90 | constraints. MRNs always obey quasimetric constraints if `0 < sym_p <= 1`. 91 | Default: ``True``. 92 | 93 | Shape: 94 | - Input: Two broadcastable tensors of shape ``(..., input_size)`` 95 | - Output: ``(...)`` 96 | 97 | Non-Module Attributes: 98 | input_size (int) 99 | sym_p (float) 100 | num_components (int): Number of components to be combined to form the latent quasimetric. For MRN, this is always ``2``. 101 | discount (Optional[float]) 102 | guaranteed_quasimetric (bool): Whether this is guaranteed to satisfy quasimetric constraints. 103 | 104 | Module Attributes: 105 | transforms (nn.Sequential[TransformBase]): Transforms to be applied on quasimetric components. 106 | reduction (ReductionBase): Reduction methods to aggregate components. 107 | 108 | Examples:: 109 | 110 | >>> mrn = MRN(128) # default MRN 111 | .../torchqmet/mrn.py:61: UserWarning: MRN with `sym_p=2` may not be a quasimetric (see IQE paper Sec. C.2). Use 112 | `torchqmet.MRNFixed` with default `sym_p=1` to guarantee a quasimetric. 
113 | >>> print(mrn) 114 | MRN( 115 | guaranteed_quasimetric=False 116 | input_size=128, num_components=2, discount=None 117 | sym_p=2 118 | (transforms): Sequential() 119 | (reduction): Sum(input_num_components=2) 120 | (sym_proj): MRNProjector( 121 | (0): Linear(in_features=128, out_features=176, bias=True) 122 | (1): ReLU() 123 | (2): Linear(in_features=176, out_features=16, bias=True) 124 | ) 125 | (asym_proj): MRNProjector( 126 | (0): Linear(in_features=128, out_features=176, bias=True) 127 | (1): ReLU() 128 | (2): Linear(in_features=176, out_features=16, bias=True) 129 | ) 130 | ) 131 | >>> x = torch.randn(5, 128, requires_grad=True) 132 | >>> y = torch.randn(5, 128, requires_grad=True) 133 | >>> print(mrn(x, y)) 134 | tensor([0.3584, 0.8246, 0.4646, 0.5300, 0.5409], grad_fn=) 135 | >>> print(mrn(y, x)) 136 | tensor([0.5899, 0.5375, 0.7205, 0.4931, 0.5727], grad_fn=) 137 | >>> print(mrn(x[:, None], x)) # pdist 138 | tensor([[0.0000, 0.3609, 0.5478, 0.6326, 0.4724], 139 | [0.5219, 0.0000, 0.5700, 0.7597, 0.5657], 140 | [0.4636, 0.5970, 0.0000, 0.4545, 0.5955], 141 | [0.8028, 0.8550, 1.1630, 0.0000, 0.7704], 142 | [0.6520, 0.5160, 0.8666, 0.4677, 0.0000]], grad_fn=) 143 | >>> 144 | >>> # MRN with fix to guarantee quasimetric constraints 145 | >>> mrn = MRNFixed(128) # or use MRN(..., sym_p=1) 146 | >>> print(mrn) 147 | MRNFixed( 148 | guaranteed_quasimetric=True 149 | input_size=128, num_components=2, discount=None 150 | sym_p=1 151 | (transforms): Sequential() 152 | (reduction): Sum(input_num_components=2) 153 | (sym_proj): MRNProjector( 154 | (0): Linear(in_features=128, out_features=176, bias=True) 155 | (1): ReLU() 156 | (2): Linear(in_features=176, out_features=16, bias=True) 157 | ) 158 | (asym_proj): MRNProjector( 159 | (0): Linear(in_features=128, out_features=176, bias=True) 160 | (1): ReLU() 161 | (2): Linear(in_features=176, out_features=16, bias=True) 162 | ) 163 | ) 164 | >>> print(mrn(x[:, None], x)) # pdist 165 | tensor([[0.0000, 0.7640, 0.7091, 0.5985, 0.7392], 166 | [0.7220, 0.0000, 0.8448, 0.9160, 0.8006], 167 | [0.8715, 0.7199, 0.0000, 0.9072, 0.8582], 168 | [0.7666, 0.8370, 0.7094, 0.0000, 0.9459], 169 | [0.7773, 0.6895, 0.7869, 0.8662, 0.0000]], grad_fn=) 170 | """ 171 | 172 | sym_p: float 173 | 174 | def __init__(self, input_size: int, sym_p: float = 2, proj_hidden_size: int = 176, proj_output_size: int = 16, *, 175 | transforms: Collection[str] = (), reduction: str = 'sum', 176 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 177 | if sym_p > 1: 178 | guaranteed_quasimetric = False 179 | if warn_if_not_quasimetric: 180 | warnings.warn( 181 | f'MRN with `sym_p={sym_p:g}` may not be a quasimetric (see IQE paper Sec. C.2). 
' 182 | 'Use `torchqmet.MRNFixed` with default `sym_p=1` to guarantee a quasimetric.') 183 | elif sym_p <= 0: 184 | raise ValueError(f"Expect positive `sym_p`, but `sym_p={sym_p:g}`") 185 | else: 186 | guaranteed_quasimetric = True 187 | super().__init__(input_size, num_components=2, guaranteed_quasimetric=guaranteed_quasimetric, warn_if_not_quasimetric=warn_if_not_quasimetric, 188 | transforms=transforms, reduction=reduction, discount=discount) 189 | self.sym_p = sym_p 190 | self.sym_proj = MRNProjector(input_size, output_size=proj_output_size, hidden_sizes=[proj_hidden_size]) 191 | self.asym_proj = MRNProjector(input_size, output_size=proj_output_size, hidden_sizes=[proj_hidden_size]) 192 | 193 | def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 194 | xy = torch.stack(torch.broadcast_tensors(x, y), dim=0) 195 | sym_projx, sym_projy = self.sym_proj(xy).unbind(0) 196 | sym_dist = (sym_projx - sym_projy).square().mean(dim=-1).pow(self.sym_p / 2) 197 | asym_projx, asym_projy = self.asym_proj(xy).unbind(0) 198 | asym_dist = (asym_projx - asym_projy).max(dim=-1).values.relu() 199 | return torch.stack([sym_dist, asym_dist], dim=-1) 200 | 201 | def extra_repr(self) -> str: 202 | return super().extra_repr() + f'\nsym_p={self.sym_p:g}' 203 | 204 | 205 | class MRNFixed(MRN): 206 | r""" 207 | Metric Residual Network (MRN): 208 | https://arxiv.org/abs/2208.08133 209 | with fix proposed by the IQE paper (Sec. C.2): 210 | https://arxiv.org/abs/2211.15120 211 | 212 | One-line Usage: 213 | 214 | MRNFixed(input_size, sym_p=1, ...) 215 | 216 | Defaults to `sym_p=1`. This guarantees a quasimetric, unlike the original official MRN (where `sym_p=2`). 217 | 218 | See :class:`MRN` for details of other arguments. 219 | """ 220 | 221 | def __init__(self, input_size: int, sym_p: float = 1, proj_hidden_size: int = 176, proj_output_size: int = 16, *, 222 | transforms: Collection[str] = (), reduction: str = 'sum', 223 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 224 | super().__init__(input_size, sym_p, proj_hidden_size, proj_output_size, warn_if_not_quasimetric=warn_if_not_quasimetric, 225 | transforms=transforms, reduction=reduction, discount=discount) 226 | -------------------------------------------------------------------------------- /torchqmet/pqe/README.md: -------------------------------------------------------------------------------- 1 | Contents of this folder are modified from the [official PQE repository](https://github.com/ssnl/poisson_quasimetric_embedding). -------------------------------------------------------------------------------- /torchqmet/pqe/__init__.py: -------------------------------------------------------------------------------- 1 | r''' 2 | Poisson Quasimetric Embedding (PQE) 3 | https://arxiv.org/abs/2206.15478 4 | ''' 5 | 6 | from typing import * 7 | 8 | import torch 9 | 10 | from .measures import MeasureBase, LebesgueMeasure, GaussianBasedMeasure 11 | from .shapes import ShapeBase, HalfLineShape, GaussianShape 12 | from .. import QuasimetricBase 13 | 14 | 15 | class PQE(QuasimetricBase): 16 | r''' 17 | Poisson Quasimetric Embedding (PQE): 18 | https://arxiv.org/abs/2206.15478 19 | 20 | One-line Usage: 21 | 22 | PQE(input_size, dim_per_component=16, measure="lebesgue", shape="halfline", ...) 23 | 24 | 25 | PQE requires a specification of "shape" and "measure" for defining the Poisson process counts. We support 26 | + Measure: Lebesgue measure, a Gaussian-based measure. 27 | + Shape: Half-line, a Gaussian shape. 
28 | These choices are sufficient to implement PQE-LH (Lebesgue + Half-line) and PQE-GG (Gaussian-based measure + Gaussian shape), 29 | the two PQE variants used in the original PQE paper. 30 | 31 | Default arguments implement PQE-LH, which has a simple form and generally works well according to the PQE paper. 32 | To use PQE-GG, the PQE paper's other proposed variant, set `shape="gaussian", measure="gaussian"`, or simply use subclass 33 | :class:`PQEGG`. Similarly, subclass :class:`PQELH` is guaranteed to be PQE-LH. 34 | 35 | PQE-Specific Args: 36 | input_size (int): Dimension of input latent vectors 37 | dim_per_component (int): PQE splits latent vectors into chunks, where each chunk gives a PQE component. 38 | This is the number of latent dimensions assigned to each chunk. This number must 39 | perfectly divide ``input_size``. 40 | Default: ``4``. 41 | measure (str): Measure used in the Poisson processes. Choices are ``"lebesgue"`` and ``"gaussian"``. 42 | Default: ``"lebesgue"``. 43 | shape (str): Shape parametrizations used in the Poisson processes. Choices are ``"halfline"`` and ``"gaussian"``. 44 | ``"gaussian"`` can only be used with ``"gaussian"`` measure. 45 | Default: ``"halfline"``. 46 | 47 | Common Args (Exist for all quasimetrics, **Keyword-only**, Default values may be different for different quasimetrics): 48 | transforms (Collection[str]): A sequence of transforms to apply to the components, before reducing them to form 49 | the final latent quasimetric. 50 | Supported choices: 51 | + "concave_activation": Concave activation transform from Neural Norms paper. 52 | Default: ``()`` (no transforms). 53 | reduction (str): Reduction method to aggregate components into final quasimetric value. 54 | Supported choices: 55 | + "sum": Sum of components. 56 | + "max": Max of components. 57 | + "mean": Average of components. 58 | + "maxmean": Convex combination of max and mean. Used in original Deep Norm, Wide Norm, and IQE. 59 | + "deep_linear_net_weighted_sum": Weighted sum with weights given by a deep linear net. Used in 60 | original PQE, whose components have limited range [0, 1). 61 | Default: ``"deep_linear_net_weighted_sum"``. 62 | discounted (Optional[float]): If not ``None``, this module instead estimates discounted distances with the 63 | base as ``discounted``. 64 | Default: ``None``, but recommended for PQEs (following original paper). 65 | warn_if_not_quasimetric (bool): If ``True``, issue a warning if this module does not always obey quasimetric 66 | constraints. PQEs always obey quasimetric constraints. 67 | Default: ``True``. 68 | 69 | Shape: 70 | - Input: Two broadcastable tensors of shape ``(..., input_size)`` 71 | - Output: ``(...)`` 72 | 73 | Non-Module Attributes: 74 | input_size (int) 75 | num_components (int): Number of components to be combined to form the latent quasimetric. For PQEs, this is 76 | ``input_size // dim_per_component``. 77 | discount (Optional[float]) 78 | guaranteed_quasimetric (bool): Whether this is guaranteed to satisfy quasimetric constraints. 79 | 80 | Module Attributes: 81 | measure (MeasureBase): Poisson process measure used. 82 | shape (ShapeBase): Poisson process shape parametrization used. 83 | transforms (nn.Sequential[TransformBase]): Transforms to be applied on quasimetric components. 84 | reduction (ReductionBase): Reduction methods to aggregate components.
85 | 86 | Examples:: 87 | 88 | >>> pqe = PQE(128, dim_per_component=16) # default is PQE-LH, see `measure` and `shape` below 89 | >>> print(pqe) 90 | PQE( 91 | guaranteed_quasimetric=True 92 | input_size=128, num_components=8, discount=None 93 | (transforms): Sequential() 94 | (reduction): DeepLinearNetWeightedSum( 95 | input_num_components=8 96 | (alpha_net): DeepLinearNet( 97 | bias=True, non_negative=True 98 | (mats): ParameterList( 99 | (0): Parameter containing: [torch.float32 of size 1x64] 100 | (1): Parameter containing: [torch.float32 of size 64x64] 101 | (2): Parameter containing: [torch.float32 of size 64x64] 102 | (3): Parameter containing: [torch.float32 of size 64x8] 103 | ) 104 | ) 105 | ) 106 | (measure): LebesgueMeasure() 107 | (shape): HalfLineShape() 108 | ) 109 | >>> x = torch.randn(5, 128, requires_grad=True) 110 | >>> y = torch.randn(5, 128, requires_grad=True) 111 | >>> print(pqe(x, y)) 112 | tensor([0.5994, 0.7079, 0.6474, 0.7858, 0.6954], grad_fn=) 113 | >>> print(pqe(y, x)) 114 | tensor([0.5731, 0.7868, 0.9577, 0.5707, 0.7005], grad_fn=) 115 | >>> print(pqe(x[:, None], x)) # pdist 116 | tensor([[0.0000, 0.8147, 0.9515, 0.6505, 0.8131], 117 | [0.6491, 0.0000, 0.8892, 0.4910, 0.7271], 118 | [0.5663, 0.6442, 0.0000, 0.4402, 0.6461], 119 | [0.6756, 0.7252, 0.9157, 0.0000, 0.7032], 120 | [0.6689, 0.7006, 0.8784, 0.4509, 0.0000]], grad_fn=) 121 | >>> 122 | >>> # PQE-GG, modeling discounted distances 123 | >>> pqe = PQEGG(128, dim_per_component=16, discount=0.9) # or use PQE(..., shape="guassian", measure="gaussian") 124 | >>> # PQE-GG requires the `cdf_ops` extension. First usage of PQE-GG will trigger compile. 125 | >>> # See `PQE` docstring for details. 126 | >>> print(pqe(x, y)) # discounted distance 127 | tensor([0.9429, 0.9435, 0.9402, 0.9404, 0.9428], grad_fn=) 128 | >>> print(pqe(x[:, None], x)) # discounted pdist 129 | tensor([[1.0000, 0.9423, 0.9313, 0.9473, 0.9470], 130 | [0.9452, 1.0000, 0.9400, 0.9520, 0.9517], 131 | [0.9395, 0.9456, 1.0000, 0.9489, 0.9531], 132 | [0.9380, 0.9397, 0.9313, 1.0000, 0.9484], 133 | [0.9395, 0.9412, 0.9371, 0.9502, 1.0000]], grad_fn=) 134 | ''' 135 | 136 | measure: MeasureBase 137 | shape: ShapeBase 138 | 139 | def __init__(self, input_size: int, dim_per_component: int = 4, measure: str = 'lebesgue', shape: str = 'halfline', *, 140 | transforms: Collection[str] = (), reduction: str = 'deep_linear_net_weighted_sum', 141 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 142 | assert dim_per_component > 0, "dim_per_component must be positive" 143 | assert input_size % dim_per_component == 0, \ 144 | f"input_size={input_size} is not divisible by dim_per_component={dim_per_component}" 145 | num_components = input_size // dim_per_component 146 | super().__init__(input_size, num_components, guaranteed_quasimetric=True, warn_if_not_quasimetric=warn_if_not_quasimetric, 147 | transforms=transforms, reduction=reduction, discount=discount) 148 | # Will need to reshape the latents to be 2D so that 149 | # - the last dim represents Poisson processes that parametrize a distribution of quasipartitions 150 | # - the second last dim represents the number of mixtures of such quasipartition distributions 151 | self.latent_2d_shape = torch.Size([num_components, dim_per_component]) 152 | if measure == 'lebesgue': 153 | self.measure = LebesgueMeasure(shape=self.latent_2d_shape) 154 | elif measure == 'gaussian': 155 | self.measure = GaussianBasedMeasure(shape=self.latent_2d_shape) 156 | else: 157 | raise ValueError(f'Unsupported 
measure={repr(measure)}') 158 | if shape == 'halfline': 159 | self.shape = HalfLineShape() 160 | elif shape == 'gaussian': 161 | self.shape = GaussianShape() 162 | else: 163 | raise ValueError(f'Unsupported shape={repr(shape)}') 164 | 165 | def compute_components(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 166 | return self.shape.expected_quasipartiton( 167 | x.unflatten(-1, self.latent_2d_shape), 168 | y.unflatten(-1, self.latent_2d_shape), 169 | measure=self.measure, 170 | ) 171 | 172 | 173 | class PQELH(PQE): 174 | r""" 175 | PQE-LH variant of Poisson Quasimetric Embedding (PQE), using Lebesgue measure and Half-line shape: 176 | https://arxiv.org/abs/2206.15478 177 | 178 | One-line Usage: 179 | 180 | PQELH(input_size, dim_per_component=16, ...) 181 | 182 | Unlike :class:`PQE`, arguments `measure="lebesgue"` and `shape="halfline"` are fixed and not configurable. 183 | 184 | See :class:`PQE` for details of other arguments. 185 | """ 186 | 187 | def __init__(self, input_size: int, dim_per_component: int = 4, *, 188 | transforms: Collection[str] = (), reduction: str = 'deep_linear_net_weighted_sum', 189 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 190 | super().__init__(input_size, dim_per_component, measure='lebesgue', shape='halfline', 191 | warn_if_not_quasimetric=warn_if_not_quasimetric, transforms=transforms, reduction=reduction, discount=discount) 192 | 193 | 194 | class PQEGG(PQE): 195 | r""" 196 | PQE-GG variant of Poisson Quasimetric Embedding (PQE), using Gaussian-based measure and Gaussian-based shape: 197 | https://arxiv.org/abs/2206.15478 198 | 199 | One-line Usage: 200 | 201 | PQEGG(input_size, dim_per_component=16, ...) 202 | 203 | Unlike :class:`PQE`, arguments `measure="gaussian"` and `shape="gaussian"` are fixed and not configurable. 204 | 205 | See :class:`PQE` for details of other arguments. 206 | """ 207 | 208 | def __init__(self, input_size: int, dim_per_component: int = 4, *, 209 | transforms: Collection[str] = (), reduction: str = 'deep_linear_net_weighted_sum', 210 | discount: Optional[float] = None, warn_if_not_quasimetric: bool = True): 211 | super().__init__(input_size, dim_per_component, measure='gaussian', shape='gaussian', 212 | warn_if_not_quasimetric=warn_if_not_quasimetric, transforms=transforms, reduction=reduction, discount=discount) 213 | 214 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | dist/ 4 | *.egg-info/ 5 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/README.md: -------------------------------------------------------------------------------- 1 | # `cdf_ops` Extension 2 | 3 | This extension implements a couple of functions for computing Bessel functions, Poisson race probabilities, and Gaussian/non-central-chi-square CDFs. 4 | 5 | ## Documentation 6 | 7 | The provided functions are listed below. They work on both CPU and CUDA PyTorch Tensors. 8 | 9 | At the first use of a function (if it is not found in the PyTorch installation), a compilation of the extension will be triggered, which may take up to 10 minutes. Subsequent uses will use the cached compilation results, as long as the GPU compute capabilities, CUDA version, and PyTorch version remain the same.
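As a quick orientation before the full listing, here is a minimal usage sketch. The tensor shapes and values are made up for illustration; it assumes the extension builds successfully on your machine (the first call may trigger the compilation described above) and imports the ops as exported by this package's `__init__.py`:

```py
import torch
from torchqmet.pqe.cdf_ops import ndtr, prod_ndtr, prob_two_poisson_gt

x = torch.randn(4, 8)
p = ndtr(x)             # elementwise standard Gaussian CDF, shape (4, 8)
joint = prod_ndtr(x)    # numerically stable ndtr(x).prod(dim=-1), shape (4,)

# Poisson race probability between two sets of non-negative rates.
mu1 = torch.rand(4) * 3
mu2 = torch.rand(4) * 3
p_gt = prob_two_poisson_gt(mu1, mu2)  # Prob[ Poisson(mu1) > Poisson(mu2) ], shape (4,)
```

The full signatures and docstrings of the provided functions follow.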
10 | 11 | ```py 12 | def chndtr(x: torch.Tensor, df: Union[torch.Tensor, float], nc: torch.Tensor) -> torch.Tensor: 13 | r""" 14 | Computes the non-central Chi-square CDF. 15 | 16 | For a distribution with :attr:`df` degrees of freedom and :attr:`nc` non-centrality parameter, 17 | this evaluates the CDF at :attr:`x`. 18 | """ 19 | ... 20 | 21 | 22 | def i0(input: torch.Tensor) -> torch.Tensor: 23 | r""" 24 | Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. 25 | 26 | .. math:: 27 | \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 28 | """ 29 | ... 30 | 31 | 32 | def i0e(input: torch.Tensor) -> torch.Tensor: 33 | r""" 34 | Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below) 35 | for each element of :attr:`input`. 36 | 37 | .. math:: 38 | \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 39 | """ 40 | ... 41 | 42 | 43 | def i1(input: torch.Tensor) -> torch.Tensor: 44 | r""" 45 | Computes the first order modified Bessel function of the first kind (as defined below) 46 | for each element of :attr:`input`. 47 | 48 | .. math:: 49 | \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 50 | """ 51 | ... 52 | 53 | 54 | def i1e(input: torch.Tensor) -> torch.Tensor: 55 | r""" 56 | Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below) 57 | for each element of :attr:`input`. 58 | 59 | .. math:: 60 | \text{out}_{i} = \exp(-|x|) * i1(x) = 61 | \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 62 | """ 63 | ... 64 | 65 | 66 | def prob_two_poisson_gt(mu1: torch.Tensor, mu2: torch.Tensor) -> torch.Tensor: 67 | r""" 68 | Computes the elementwise ``Prob[ Poisson(mu1) > Poisson(mu2) ]``. 69 | """ 70 | ... 71 | 72 | 73 | def prob_two_poisson_le(mu1: torch.Tensor, mu2: torch.Tensor) -> torch.Tensor: 74 | r""" 75 | Computes the elementwise ``Prob[ Poisson(mu1) <= Poisson(mu2) ]``. 76 | """ 77 | ... 78 | 79 | 80 | def ndtr(x: torch.Tensor) -> torch.Tensor: 81 | r""" 82 | Computes the standard Gaussian CDF evaluated at :attr:`x`. 83 | """ 84 | ... 85 | 86 | 87 | def log_ndtr(x: torch.Tensor) -> torch.Tensor: 88 | r""" 89 | Computes the log of the standard Gaussian CDF evaluated at :attr:`x`. 90 | 91 | This is numerically more stable than calling ``ndtr(x).log()``, in both forward and backward. 92 | """ 93 | ... 94 | 95 | 96 | def prod_ndtr(x: torch.Tensor, *, dim: int = -1) -> torch.Tensor: 97 | r""" 98 | Computes ``ndtr(x).prod(dim=dim)``. 99 | 100 | This is numerically more stable than calling ``ndtr(x).prod(dim=dim)``, in both forward and backward. 101 | """ 102 | ... 103 | ``` 104 | 105 | ## FAQ 106 | 107 | **Q:** How to compile so that the compiled extension can be used for machines with GPUs of different compute capabilities (e.g., on a cluster with many types of GPUs)? 108 | 109 | **A:** Specify a environment flag like `TORCH_CUDA_ARCH_LIST='6.0;6.1;7.0;7.5+PTX'`. 110 | 111 | ## License 112 | 113 | Part of the code is modified from [`cephes`](https://www.netlib.org/cephes/) and [`CDFLIB`](https://people.sc.fsu.edu/~jburkardt/cpp_src/cdflib/cdflib.html). 114 | 115 | `cephes` is available from [`scipy`](https://github.com/scipy/scipy) under 3-clause BSD. 
All derived code from `cephes` are located under [`./cpu/cephes`](./cpu/cephes) and [`./cuda/cephes`](./cuda/cephes). 116 | 117 | For `CDFLIB`, while it is website release its under LGPL. We derive the code from [`scipy`](https://github.com/scipy/scipy), which is under 3-clause BSD. All derived code from `CDFLIB` are located under [`./cpu/cdflib`](./cpu/cdflib) and [`./cuda/cdflib`](./cuda/cdflib). 118 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .op_wrappers import ( 2 | chndtr, i0, i0e, i1, i1e, prob_two_poisson_gt, prob_two_poisson_le, \ 3 | ndtr, log_ndtr, prod_ndtr 4 | ) 5 | 6 | __all__ = [ 7 | 'chndtr', 'i0', 'i0e', 'i1', 'i1e', 8 | 'prob_two_poisson_gt', 'prob_two_poisson_le', 9 | 'ndtr', 'log_ndtr', 'prod_ndtr', 10 | ] 11 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cdf_ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | namespace cdf_ops { 4 | 5 | enum TwoPoissonComparisonProb { 6 | GT, 7 | LE 8 | }; 9 | 10 | } // namespace cdf_ops 11 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cdflib/cdflib.hpp: -------------------------------------------------------------------------------- 1 | # include 2 | # include 3 | 4 | 5 | namespace cdflib { namespace cpu { 6 | 7 | double algdiv ( double *a, double *b ); 8 | double alnrel ( double *a ); 9 | double apser ( double *a, double *b, double *x, double *eps ); 10 | double bcorr ( double *a0, double *b0 ); 11 | double beta ( double a, double b ); 12 | double beta_asym ( double *a, double *b, double *lambda, double *eps ); 13 | double beta_frac ( double *a, double *b, double *x, double *y, double *lambda, 14 | double *eps ); 15 | void beta_grat ( double *a, double *b, double *x, double *y, double *w, 16 | double *eps,int *ierr ); 17 | void beta_inc ( double *a, double *b, double *x, double *y, double *w, 18 | double *w1, int *ierr ); 19 | void beta_inc_values ( int *n_data, double *a, double *b, double *x, double *fx ); 20 | double beta_log ( double *a0, double *b0 ); 21 | double beta_pser ( double *a, double *b, double *x, double *eps ); 22 | double beta_rcomp ( double *a, double *b, double *x, double *y ); 23 | double beta_rcomp1 ( int *mu, double *a, double *b, double *x, double *y ); 24 | double beta_up ( double *a, double *b, double *x, double *y, int *n, double *eps ); 25 | void binomial_cdf_values ( int *n_data, int *a, double *b, int *x, double *fx ); 26 | void cdfbet ( int *which, double *p, double *q, double *x, double *y, 27 | double *a, double *b, int *status, double *bound ); 28 | void cdfbin ( int *which, double *p, double *q, double *s, double *xn, 29 | double *pr, double *ompr, int *status, double *bound ); 30 | void cdfchi ( int *which, double *p, double *q, double *x, double *df, 31 | int *status, double *bound ); 32 | void cdfchn ( int *which, double *p, double *q, double *x, double *df, 33 | double *pnonc, int *status, double *bound ); 34 | void cdff ( int *which, double *p, double *q, double *f, double *dfn, 35 | double *dfd, int *status, double *bound ); 36 | void cdffnc ( int *which, double *p, double *q, double *f, double *dfn, 37 | double *dfd, double *phonc, int *status, double *bound ); 38 | void cdfgam ( int *which, double *p, double *q, double *x, double *shape, 39 | double 
*scale, int *status, double *bound ); 40 | void cdfnbn ( int *which, double *p, double *q, double *s, double *xn, 41 | double *pr, double *ompr, int *status, double *bound ); 42 | void cdfnor ( int *which, double *p, double *q, double *x, double *mean, 43 | double *sd, int *status, double *bound ); 44 | void cdfpoi ( int *which, double *p, double *q, double *s, double *xlam, 45 | int *status, double *bound ); 46 | void cdft ( int *which, double *p, double *q, double *t, double *df, 47 | int *status, double *bound ); 48 | void chi_noncentral_cdf_values ( int *n_data, double *x, double *lambda, 49 | int *df, double *cdf ); 50 | void chi_square_cdf_values ( int *n_data, int *a, double *x, double *fx ); 51 | void cumbet ( double *x, double *y, double *a, double *b, double *cum, 52 | double *ccum ); 53 | void cumbin ( double *s, double *xn, double *pr, double *ompr, 54 | double *cum, double *ccum ); 55 | void cumchi ( double *x, double *df, double *cum, double *ccum ); 56 | void cumchn ( double *x, double *df, double *pnonc, double *cum, 57 | double *ccum ); 58 | void cumf ( double *f, double *dfn, double *dfd, double *cum, double *ccum ); 59 | void cumfnc ( double *f, double *dfn, double *dfd, double *pnonc, 60 | double *cum, double *ccum ); 61 | void cumgam ( double *x, double *a, double *cum, double *ccum ); 62 | void cumnbn ( double *s, double *xn, double *pr, double *ompr, 63 | double *cum, double *ccum ); 64 | void cumnor ( double *arg, double *result, double *ccum ); 65 | void cumpoi ( double *s, double *xlam, double *cum, double *ccum ); 66 | void cumt ( double *t, double *df, double *cum, double *ccum ); 67 | double dbetrm ( double *a, double *b ); 68 | double dexpm1 ( double *x ); 69 | double dinvnr ( double *p, double *q ); 70 | void dinvr ( int *status, double *x, double *fx, 71 | unsigned long *qleft, unsigned long *qhi ); 72 | double dlanor ( double *x ); 73 | double dpmpar ( int *i ); 74 | void dstinv ( double *zsmall, double *zbig, double *zabsst, 75 | double *zrelst, double *zstpmu, double *zabsto, double *zrelto ); 76 | double dstrem ( double *z ); 77 | void dstzr ( double *zxlo, double *zxhi, double *zabstl, double *zreltl ); 78 | double dt1 ( double *p, double *q, double *df ); 79 | void dzror ( int *status, double *x, double *fx, double *xlo, 80 | double *xhi, unsigned long *qleft, unsigned long *qhi ); 81 | void erf_values ( int *n_data, double *x, double *fx ); 82 | double error_f ( double *x ); 83 | double error_fc ( int *ind, double *x ); 84 | double esum ( int *mu, double *x ); 85 | double eval_pol ( const double a[], const int *n, const double *x ); 86 | double exparg ( int *l ); 87 | void f_cdf_values ( int *n_data, int *a, int *b, double *x, double *fx ); 88 | void f_noncentral_cdf_values ( int *n_data, int *a, int *b, double *lambda, 89 | double *x, double *fx ); 90 | double fifdint ( const double a ); 91 | double fifdmax1 ( const double a, const double b ); 92 | double fifdmin1 ( const double a, const double b ); 93 | double fifdsign ( const double mag, const double sign ); 94 | inline long fifidint ( const double a ); 95 | long fifmod ( const long a, const long b ); 96 | double fpser ( double *a, double *b, double *x, double *eps ); 97 | void ftnstop ( std::string msg ); 98 | double gam1 ( double *a ); 99 | void gamma_inc ( double *a, double *x, double *ans, double *qans, int *ind ); 100 | void gamma_inc_inv ( double *a, double *x, double *x0, double *p, double *q, 101 | int *ierr ); 102 | void gamma_inc_values ( int *n_data, double *a, double *x, double *fx ); 
103 | double gamma_ln1 ( double *a ); 104 | double gamma_log ( double *a ); 105 | void gamma_rat1 ( double *a, double *x, double *r, double *p, double *q, 106 | double *eps ); 107 | void gamma_values ( int *n_data, double *x, double *fx ); 108 | double gamma_x ( double *a ); 109 | double gsumln ( double *a, double *b ); 110 | int ipmpar ( int *i ); 111 | void negative_binomial_cdf_values ( int *n_data, int *f, int *s, double *p, 112 | double *cdf ); 113 | void normal_cdf_values ( int *n_data, double *x, double *fx ); 114 | void poisson_cdf_values ( int *n_data, double *a, int *x, double *fx ); 115 | double psi ( double *xx ); 116 | void psi_values ( int *n_data, double *x, double *fx ); 117 | double rcomp ( double *a, double *x ); 118 | double rexp ( double *x ); 119 | double rlog ( double *x ); 120 | double rlog1 ( double *x ); 121 | void student_cdf_values ( int *n_data, int *a, double *x, double *fx ); 122 | double stvaln ( double *p ); 123 | void timestamp ( void ); 124 | 125 | } // namespace cpu 126 | } // namespace cdflib 127 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cephes/chbevl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/chbevl.c 6 | */ 7 | 8 | /* chbevl.c 9 | * 10 | * Evaluate Chebyshev series 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * int N; 17 | * double x, y, coef[N], chebevl(); 18 | * 19 | * y = chbevl( x, coef, N ); 20 | * 21 | * 22 | * 23 | * DESCRIPTION: 24 | * 25 | * Evaluates the series 26 | * 27 | * N-1 28 | * - ' 29 | * y = > coef[i] T (x/2) 30 | * - i 31 | * i=0 32 | * 33 | * of Chebyshev polynomials Ti at argument x/2. 34 | * 35 | * Coefficients are stored in reverse order, i.e. the zero 36 | * order term is last in the array. Note N is the number of 37 | * coefficients, not the order. 38 | * 39 | * If coefficients are for the interval a to b, x must 40 | * have been transformed to x -> 2(2x - b - a)/(b-a) before 41 | * entering the routine. This maps x from (a, b) to (-1, 1), 42 | * over which the Chebyshev polynomials are defined. 43 | * 44 | * If the coefficients are for the inverted interval, in 45 | * which (a, b) is mapped to (1/b, 1/a), the transformation 46 | * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, 47 | * this becomes x -> 4a/x - 1. 48 | * 49 | * 50 | * 51 | * SPEED: 52 | * 53 | * Taking advantage of the recurrence properties of the 54 | * Chebyshev polynomials, the routine requires one more 55 | * addition per loop than evaluating a nested polynomial of 56 | * the same degree. 57 | * 58 | */ 59 | /* chbevl.c */ 60 | 61 | /* 62 | * Cephes Math Library Release 2.0: April, 1987 63 | * Copyright 1985, 1987 by Stephen L. 
Moshier 64 | * Direct inquiries to 30 Frost Street, Cambridge, MA 02140 65 | */ 66 | 67 | 68 | namespace cephes { namespace cpu { 69 | 70 | template 71 | static inline typename std::enable_if::value, scalar_t>::type 72 | chbevl(const scalar_t x, const scalar_t array[], const size_t len) 73 | { 74 | scalar_t b0, b1, b2; 75 | 76 | b0 = array[0]; 77 | b1 = static_cast(0.0); 78 | 79 | for (size_t i = 1; i < len; ++i) { 80 | b2 = b1; 81 | b1 = b0; 82 | b0 = x * b1 - b2 + array[i]; 83 | } 84 | 85 | return (0.5 * (b0 - b2)); 86 | } 87 | 88 | } // namespace cpu 89 | } // namespace cephes 90 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cephes/i0.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/i0.c 6 | */ 7 | 8 | /* i0.c 9 | * 10 | * Modified Bessel function of order zero 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, i0(); 17 | * 18 | * y = i0( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns modified Bessel function of order zero of the 25 | * argument. 26 | * 27 | * The function is defined as i0(x) = j0( ix ). 28 | * 29 | * The range is partitioned into the two intervals [0,8] and 30 | * (8, infinity). Chebyshev polynomial expansions are employed 31 | * in each interval. 32 | * 33 | * 34 | * 35 | * ACCURACY: 36 | * 37 | * Relative error: 38 | * arithmetic domain # trials peak rms 39 | * IEEE 0,30 30000 5.8e-16 1.4e-16 40 | * 41 | */ 42 | /* i0e.c 43 | * 44 | * Modified Bessel function of order zero, 45 | * exponentially scaled 46 | * 47 | * 48 | * 49 | * SYNOPSIS: 50 | * 51 | * double x, y, i0e(); 52 | * 53 | * y = i0e( x ); 54 | * 55 | * 56 | * 57 | * DESCRIPTION: 58 | * 59 | * Returns exponentially scaled modified Bessel function 60 | * of order zero of the argument. 61 | * 62 | * The function is defined as i0e(x) = exp(-|x|) j0( ix ). 63 | * 64 | * 65 | * 66 | * ACCURACY: 67 | * 68 | * Relative error: 69 | * arithmetic domain # trials peak rms 70 | * IEEE 0,30 30000 5.4e-16 1.2e-16 71 | * See i0(). 72 | * 73 | */ 74 | 75 | /* i0.c */ 76 | 77 | 78 | /* 79 | * Cephes Math Library Release 2.8: June, 2000 80 | * Copyright 1984, 1987, 2000 by Stephen L. Moshier 81 | */ 82 | 83 | #include "chbevl.h" 84 | #include 85 | 86 | 87 | namespace cephes { namespace cpu { 88 | 89 | namespace { 90 | 91 | using std::exp; 92 | using std::sqrt; 93 | 94 | template 95 | static inline scalar_t calc_i0(scalar_t x) { 96 | 97 | /* Chebyshev coefficients for exp(-x) I0(x) 98 | * in the interval [0,8]. 99 | * 100 | * lim(x->0){ exp(-x) I0(x) } = 1. 
101 | */ 102 | static const scalar_t A[] = { 103 | -4.41534164647933937950E-18, 104 | 3.33079451882223809783E-17, 105 | -2.43127984654795469359E-16, 106 | 1.71539128555513303061E-15, 107 | -1.16853328779934516808E-14, 108 | 7.67618549860493561688E-14, 109 | -4.85644678311192946090E-13, 110 | 2.95505266312963983461E-12, 111 | -1.72682629144155570723E-11, 112 | 9.67580903537323691224E-11, 113 | -5.18979560163526290666E-10, 114 | 2.65982372468238665035E-9, 115 | -1.30002500998624804212E-8, 116 | 6.04699502254191894932E-8, 117 | -2.67079385394061173391E-7, 118 | 1.11738753912010371815E-6, 119 | -4.41673835845875056359E-6, 120 | 1.64484480707288970893E-5, 121 | -5.75419501008210370398E-5, 122 | 1.88502885095841655729E-4, 123 | -5.76375574538582365885E-4, 124 | 1.63947561694133579842E-3, 125 | -4.32430999505057594430E-3, 126 | 1.05464603945949983183E-2, 127 | -2.37374148058994688156E-2, 128 | 4.93052842396707084878E-2, 129 | -9.49010970480476444210E-2, 130 | 1.71620901522208775349E-1, 131 | -3.04682672343198398683E-1, 132 | 6.76795274409476084995E-1 133 | }; 134 | 135 | /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) 136 | * in the inverted interval [8,infinity]. 137 | * 138 | * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 139 | */ 140 | static const scalar_t B[] = { 141 | -7.23318048787475395456E-18, 142 | -4.83050448594418207126E-18, 143 | 4.46562142029675999901E-17, 144 | 3.46122286769746109310E-17, 145 | -2.82762398051658348494E-16, 146 | -3.42548561967721913462E-16, 147 | 1.77256013305652638360E-15, 148 | 3.81168066935262242075E-15, 149 | -9.55484669882830764870E-15, 150 | -4.15056934728722208663E-14, 151 | 1.54008621752140982691E-14, 152 | 3.85277838274214270114E-13, 153 | 7.18012445138366623367E-13, 154 | -1.79417853150680611778E-12, 155 | -1.32158118404477131188E-11, 156 | -3.14991652796324136454E-11, 157 | 1.18891471078464383424E-11, 158 | 4.94060238822496958910E-10, 159 | 3.39623202570838634515E-9, 160 | 2.26666899049817806459E-8, 161 | 2.04891858946906374183E-7, 162 | 2.89137052083475648297E-6, 163 | 6.88975834691682398426E-5, 164 | 3.36911647825569408990E-3, 165 | 8.04490411014108831608E-1 166 | }; 167 | 168 | 169 | scalar_t y; 170 | 171 | if (x < 0) 172 | x = -x; 173 | 174 | if (x <= 8.0) { 175 | y = (x / 2.0) - 2.0; 176 | if (e) { 177 | return chbevl(y, A, 30); 178 | } else { 179 | return (exp(x) * chbevl(y, A, 30)); 180 | } 181 | } 182 | 183 | if (e) { 184 | return (chbevl(32.0 / x - 2.0, B, 25) / sqrt(x)); 185 | } else { 186 | return (exp(x) * chbevl(32.0 / x - 2.0, B, 25) / sqrt(x)); 187 | } 188 | } 189 | 190 | } 191 | 192 | template 193 | static inline scalar_t i0(scalar_t x) { 194 | return calc_i0(x); 195 | } 196 | 197 | template 198 | static inline scalar_t i0e(scalar_t x) { 199 | return calc_i0(x); 200 | } 201 | 202 | } // namespace cpu 203 | } // namespace cephes 204 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cephes/i1.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/i1.c 6 | */ 7 | 8 | /* i1.c 9 | * 10 | * Modified Bessel function of order one 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, i1(); 17 | * 18 | * y = i1( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns modified Bessel function of order one of the 25 | * argument. 
26 | * 27 | * The function is defined as i1(x) = -i j1( ix ). 28 | * 29 | * The range is partitioned into the two intervals [0,8] and 30 | * (8, infinity). Chebyshev polynomial expansions are employed 31 | * in each interval. 32 | * 33 | * 34 | * 35 | * ACCURACY: 36 | * 37 | * Relative error: 38 | * arithmetic domain # trials peak rms 39 | * IEEE 0, 30 30000 1.9e-15 2.1e-16 40 | * 41 | * 42 | */ 43 | /* i1e.c 44 | * 45 | * Modified Bessel function of order one, 46 | * exponentially scaled 47 | * 48 | * 49 | * 50 | * SYNOPSIS: 51 | * 52 | * double x, y, i1e(); 53 | * 54 | * y = i1e( x ); 55 | * 56 | * 57 | * 58 | * DESCRIPTION: 59 | * 60 | * Returns exponentially scaled modified Bessel function 61 | * of order one of the argument. 62 | * 63 | * The function is defined as i1(x) = -i exp(-|x|) j1( ix ). 64 | * 65 | * 66 | * 67 | * ACCURACY: 68 | * 69 | * Relative error: 70 | * arithmetic domain # trials peak rms 71 | * IEEE 0, 30 30000 2.0e-15 2.0e-16 72 | * See i1(). 73 | * 74 | */ 75 | 76 | /* i1.c 2 */ 77 | 78 | 79 | /* 80 | * Cephes Math Library Release 2.8: June, 2000 81 | * Copyright 1985, 1987, 2000 by Stephen L. Moshier 82 | */ 83 | 84 | #include "chbevl.h" 85 | #include 86 | 87 | 88 | namespace cephes { namespace cpu { 89 | 90 | namespace { 91 | 92 | using std::abs; 93 | using std::exp; 94 | using std::sqrt; 95 | 96 | 97 | template 98 | static inline scalar_t calc_i1(scalar_t x) { 99 | /* Chebyshev coefficients for exp(-x) I1(x) / x 100 | * in the interval [0,8]. 101 | * 102 | * lim(x->0){ exp(-x) I1(x) / x } = 1/2. 103 | */ 104 | 105 | static const scalar_t A[] = { 106 | 2.77791411276104639959E-18, 107 | -2.11142121435816608115E-17, 108 | 1.55363195773620046921E-16, 109 | -1.10559694773538630805E-15, 110 | 7.60068429473540693410E-15, 111 | -5.04218550472791168711E-14, 112 | 3.22379336594557470981E-13, 113 | -1.98397439776494371520E-12, 114 | 1.17361862988909016308E-11, 115 | -6.66348972350202774223E-11, 116 | 3.62559028155211703701E-10, 117 | -1.88724975172282928790E-9, 118 | 9.38153738649577178388E-9, 119 | -4.44505912879632808065E-8, 120 | 2.00329475355213526229E-7, 121 | -8.56872026469545474066E-7, 122 | 3.47025130813767847674E-6, 123 | -1.32731636560394358279E-5, 124 | 4.78156510755005422638E-5, 125 | -1.61760815825896745588E-4, 126 | 5.12285956168575772895E-4, 127 | -1.51357245063125314899E-3, 128 | 4.15642294431288815669E-3, 129 | -1.05640848946261981558E-2, 130 | 2.47264490306265168283E-2, 131 | -5.29459812080949914269E-2, 132 | 1.02643658689847095384E-1, 133 | -1.76416518357834055153E-1, 134 | 2.52587186443633654823E-1 135 | }; 136 | 137 | /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) 138 | * in the inverted interval [8,infinity]. 139 | * 140 | * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). 
141 | */ 142 | static const scalar_t B[] = { 143 | 7.51729631084210481353E-18, 144 | 4.41434832307170791151E-18, 145 | -4.65030536848935832153E-17, 146 | -3.20952592199342395980E-17, 147 | 2.96262899764595013876E-16, 148 | 3.30820231092092828324E-16, 149 | -1.88035477551078244854E-15, 150 | -3.81440307243700780478E-15, 151 | 1.04202769841288027642E-14, 152 | 4.27244001671195135429E-14, 153 | -2.10154184277266431302E-14, 154 | -4.08355111109219731823E-13, 155 | -7.19855177624590851209E-13, 156 | 2.03562854414708950722E-12, 157 | 1.41258074366137813316E-11, 158 | 3.25260358301548823856E-11, 159 | -1.89749581235054123450E-11, 160 | -5.58974346219658380687E-10, 161 | -3.83538038596423702205E-9, 162 | -2.63146884688951950684E-8, 163 | -2.51223623787020892529E-7, 164 | -3.88256480887769039346E-6, 165 | -1.10588938762623716291E-4, 166 | -9.76109749136146840777E-3, 167 | 7.78576235018280120474E-1 168 | }; 169 | 170 | 171 | scalar_t y, z; 172 | 173 | z = abs(x); 174 | if (z <= 8.0) { 175 | y = (z / 2.0) - 2.0; 176 | if (e) { 177 | z = chbevl(y, A, 29) * z; 178 | } else { 179 | z = chbevl(y, A, 29) * z * exp(z); 180 | } 181 | } 182 | else { 183 | if (e) { 184 | z = chbevl(32.0 / z - 2.0, B, 25) / sqrt(z); 185 | } else { 186 | z = exp(z) * chbevl(32.0 / z - 2.0, B, 25) / sqrt(z); 187 | } 188 | } 189 | if (x < 0.0) { 190 | z = -z; 191 | } 192 | return (z); 193 | 194 | } 195 | 196 | } 197 | 198 | template 199 | static inline scalar_t i1(scalar_t x) { 200 | return calc_i1(x); 201 | } 202 | 203 | template 204 | static inline scalar_t i1e(scalar_t x) { 205 | return calc_i1(x); 206 | } 207 | 208 | } // namespace cpu 209 | } // namespace cephes 210 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cephes/ndtr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/7c7a5f8393e7b16e5bc81c739c84fe2e639c367f/scipy/special/cephes/ndtr.c 6 | */ 7 | 8 | /* ndtr.c 9 | * 10 | * Normal distribution function 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, ndtr(); 17 | * 18 | * y = ndtr( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns the area under the Gaussian probability density 25 | * function, integrated from minus infinity to x: 26 | * 27 | * x 28 | * - 29 | * 1 | | 2 30 | * ndtr(x) = --------- | exp( - t /2 ) dt 31 | * sqrt(2pi) | | 32 | * - 33 | * -inf. 34 | * 35 | * = ( 1 + erf(z) ) / 2 36 | * = erfc(z) / 2 37 | * 38 | * where z = x/sqrt(2). Computation is via the functions 39 | * erf and erfc. 40 | * 41 | * 42 | * ACCURACY: 43 | * 44 | * Relative error: 45 | * arithmetic domain # trials peak rms 46 | * IEEE -13,0 30000 3.4e-14 6.7e-15 47 | * 48 | * 49 | * ERROR MESSAGES: 50 | * 51 | * message condition value returned 52 | * erfc underflow x > 37.519379347 0.0 53 | * 54 | */ 55 | /* erf.c 56 | * 57 | * Error function 58 | * 59 | * 60 | * 61 | * SYNOPSIS: 62 | * 63 | * double x, y, erf(); 64 | * 65 | * y = erf( x ); 66 | * 67 | * 68 | * 69 | * DESCRIPTION: 70 | * 71 | * The integral is 72 | * 73 | * x 74 | * - 75 | * 2 | | 2 76 | * erf(x) = -------- | exp( - t ) dt. 77 | * sqrt(pi) | | 78 | * - 79 | * 0 80 | * 81 | * For 0 <= |x| < 1, erf(x) = x * P4(x**2)/Q5(x**2); otherwise 82 | * erf(x) = 1 - erfc(x). 
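 *
 * (Porting note: in this header the Cephes rational approximations for erf
 * and erfc survive only as commented-out reference code further down; the
 * ndtr() template below delegates to std::erf / std::erfc. Mathematically
 * ndtr(a) = 0.5 * erfc(-a / sqrt(2)); the branch on |a / sqrt(2)| merely
 * picks whichever of erf or erfc is more accurate in that range.)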
83 | * 84 | * 85 | * 86 | * ACCURACY: 87 | * 88 | * Relative error: 89 | * arithmetic domain # trials peak rms 90 | * IEEE 0,1 30000 3.7e-16 1.0e-16 91 | * 92 | */ 93 | /* erfc.c 94 | * 95 | * Complementary error function 96 | * 97 | * 98 | * 99 | * SYNOPSIS: 100 | * 101 | * double x, y, erfc(); 102 | * 103 | * y = erfc( x ); 104 | * 105 | * 106 | * 107 | * DESCRIPTION: 108 | * 109 | * 110 | * 1 - erf(x) = 111 | * 112 | * inf. 113 | * - 114 | * 2 | | 2 115 | * erfc(x) = -------- | exp( - t ) dt 116 | * sqrt(pi) | | 117 | * - 118 | * x 119 | * 120 | * 121 | * For small x, erfc(x) = 1 - erf(x); otherwise rational 122 | * approximations are computed. 123 | * 124 | * 125 | * 126 | * ACCURACY: 127 | * 128 | * Relative error: 129 | * arithmetic domain # trials peak rms 130 | * IEEE 0,26.6417 30000 5.7e-14 1.5e-14 131 | */ 132 | 133 | 134 | /* 135 | * Cephes Math Library Release 2.2: June, 1992 136 | * Copyright 1984, 1987, 1988, 1992 by Stephen L. Moshier 137 | * Direct inquiries to 30 Frost Street, Cambridge, MA 02140 138 | */ 139 | 140 | // #include "polevl.h" 141 | 142 | #ifndef _USE_MATH_DEFINES 143 | #define _USE_MATH_DEFINES 144 | #endif 145 | #include 146 | #include 147 | #include 148 | 149 | namespace cephes { namespace cpu { 150 | 151 | 152 | 153 | template 154 | static inline scalar_t ndtr(scalar_t a) 155 | { 156 | scalar_t x, y, z; 157 | 158 | if (std::isnan(a)) { 159 | // throw std::runtime_error("ndtr sees NaN") 160 | return NAN; 161 | } 162 | 163 | // std::sqrt(0.5); std::sqrt is not constexpr 164 | constexpr scalar_t SQRT1_2 = 0.707106781186547524400844362104849039284835937688474036588339868995366239231053519425193767163820786367507; 165 | 166 | x = a * SQRT1_2; 167 | z = std::fabs(x); 168 | 169 | if (z < SQRT1_2) { 170 | y = 0.5 + 0.5 * std::erf(x); 171 | 172 | } else { 173 | y = 0.5 * std::erfc(z); 174 | 175 | if (x > 0) { 176 | y = 1.0 - y; 177 | } 178 | } 179 | 180 | return (y); 181 | } 182 | 183 | 184 | // template 185 | // static inline scalar_t erfc(scalar_t a) 186 | // { 187 | // scalar_t p, q, x, y, z; 188 | 189 | // if (std::isnan(a)) { 190 | // // throw std::runtime_error("erfc sees NaN") 191 | // return NAN; 192 | // } 193 | 194 | // if (a < 0.0) 195 | // x = -a; 196 | // else 197 | // x = a; 198 | 199 | // if (x < 1.0) 200 | // return (1.0 - erf(a)); 201 | 202 | // z = -a * a; 203 | 204 | // if (z < -MAXLOG) { 205 | // under: 206 | // // sf_error("erfc", SF_ERROR_UNDERFLOW, NULL); 207 | // if (a < 0) 208 | // return (2.0); 209 | // else 210 | // return (0.0); 211 | // } 212 | 213 | // z = std::exp(z); 214 | 215 | // if (x < 8.0) { 216 | 217 | // static const scalar_t P[] = { 218 | // 2.46196981473530512524E-10, 219 | // 5.64189564831068821977E-1, 220 | // 7.46321056442269912687E0, 221 | // 4.86371970985681366614E1, 222 | // 1.96520832956077098242E2, 223 | // 5.26445194995477358631E2, 224 | // 9.34528527171957607540E2, 225 | // 1.02755188689515710272E3, 226 | // 5.57535335369399327526E2 227 | // }; 228 | 229 | // static const scalar_t Q[] = { 230 | // /* 1.00000000000000000000E0, */ 231 | // 1.32281951154744992508E1, 232 | // 8.67072140885989742329E1, 233 | // 3.54937778887819891062E2, 234 | // 9.75708501743205489753E2, 235 | // 1.82390916687909736289E3, 236 | // 2.24633760818710981792E3, 237 | // 1.65666309194161350182E3, 238 | // 5.57535340817727675546E2 239 | // }; 240 | 241 | // p = polevl(x, P, 8); 242 | // q = p1evl(x, Q, 8); 243 | // } 244 | // else { 245 | 246 | // static const scalar_t R[] = { 247 | // 5.64189583547755073984E-1, 248 | // 
1.27536670759978104416E0, 249 | // 5.01905042251180477414E0, 250 | // 6.16021097993053585195E0, 251 | // 7.40974269950448939160E0, 252 | // 2.97886665372100240670E0 253 | // }; 254 | 255 | // static const scalar_t S[] = { 256 | // /* 1.00000000000000000000E0, */ 257 | // 2.26052863220117276590E0, 258 | // 9.39603524938001434673E0, 259 | // 1.20489539808096656605E1, 260 | // 1.70814450747565897222E1, 261 | // 9.60896809063285878198E0, 262 | // 3.36907645100081516050E0 263 | // }; 264 | 265 | // p = polevl(x, R, 5); 266 | // q = p1evl(x, S, 6); 267 | // } 268 | // y = (z * p) / q; 269 | 270 | // if (a < 0) 271 | // y = 2.0 - y; 272 | 273 | // if (y == 0.0) 274 | // goto under; 275 | 276 | // return (y); 277 | // } 278 | 279 | 280 | // template 281 | // static inline scalar_t erf(scalar_t x) 282 | // { 283 | // scalar_t y, z; 284 | 285 | // if (std::isnan(x)) { 286 | // // throw std::runtime_error("erf sees NaN") 287 | // return NAN; 288 | // } 289 | 290 | // if (x < 0.0) { 291 | // return -erf(-x); 292 | // } 293 | 294 | // if (std::fabs(x) > 1.0) 295 | // return (1.0 - erfc(x)); 296 | 297 | // z = x * x; 298 | 299 | // static const scalar_t T[] = { 300 | // 9.60497373987051638749E0, 301 | // 9.00260197203842689217E1, 302 | // 2.23200534594684319226E3, 303 | // 7.00332514112805075473E3, 304 | // 5.55923013010394962768E4 305 | // }; 306 | 307 | // static const scalar_t U[] = { 308 | // /* 1.00000000000000000000E0, */ 309 | // 3.35617141647503099647E1, 310 | // 5.21357949780152679795E2, 311 | // 4.59432382970980127987E3, 312 | // 2.26290000613890934246E4, 313 | // 4.92673942608635921086E4 314 | // }; 315 | 316 | // y = x * polevl(z, T, 4) / p1evl(z, U, 5); 317 | // return (y); 318 | 319 | // } 320 | 321 | /* 322 | * double log_ndtr(double a) 323 | * 324 | * For a > -20, use the existing ndtr technique and take a log. 325 | * for a <= -20, we use the Taylor series approximation of erf to compute 326 | * the log CDF directly. The Taylor series consists of two parts which we will name "left" 327 | * and "right" accordingly. The right part involves a summation which we compute until the 328 | * difference in terms falls below the machine-specific EPSILON. 329 | * 330 | * \Phi(z) &=& 331 | * \frac{e^{-z^2/2}}{-z\sqrt{2\pi}} * [1 + \sum_{n=1}^{N-1} (-1)^n \frac{(2n-1)!!}{(z^2)^n}] 332 | * + O(z^{-2N+2}) 333 | * = [\mbox{LHS}] * [\mbox{RHS}] + \mbox{error}. 
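 *
 * (Illustrative example, not in the original SciPy comment: at a = -40 the
 * direct route fails, since Phi(-40) ~ 4e-350 underflows even double-precision
 * subnormals and log(ndtr(a)) would return -inf; the expansion instead gives
 * log_LHS = -0.5 * 40^2 - log(40) - 0.5 * log(2*pi) ~ -804.61, with the
 * summation contributing only a further correction of about -6e-4.)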
334 | * 335 | */ 336 | 337 | template 338 | static inline scalar_t log_ndtr(scalar_t a) 339 | { 340 | 341 | if (a > 6) { 342 | return -ndtr(-a); /* log(1+x) \approx x */ 343 | } 344 | if (a > -20) { 345 | return std::log(ndtr(a)); 346 | } 347 | 348 | scalar_t log_LHS, /* we compute the left hand side of the approx (LHS) in one shot */ 349 | last_total = 0, /* variable used to check for convergence */ 350 | right_hand_side = 1, /* includes first term from the RHS summation */ 351 | numerator = 1, /* numerator for RHS summand */ 352 | denom_factor = 1, /* use reciprocal for denominator to avoid division */ 353 | denom_cons = 1.0 / (a * a); /* the precomputed division we use to adjust the denominator */ 354 | long sign = 1, i = 0; 355 | 356 | log_LHS = -0.5 * a * a - std::log(-a) - 0.5 * std::log(2 * M_PI); 357 | 358 | while (std::fabs(last_total - right_hand_side) > std::numeric_limits::epsilon()) { 359 | i += 1; 360 | last_total = right_hand_side; 361 | sign = -sign; 362 | denom_factor *= denom_cons; 363 | numerator *= 2 * i - 1; 364 | right_hand_side += sign * numerator * denom_factor; 365 | } 366 | return log_LHS + std::log(right_hand_side); 367 | } 368 | 369 | 370 | template 371 | static inline std::pair ndtr_log_ndtr(scalar_t a) 372 | { 373 | 374 | if (a > 6) { 375 | /* log(1+x) \approx x */ 376 | scalar_t x = -ndtr(-a); /* x = -ndtr_remain_val */ 377 | return std::make_pair(1 + x, x); 378 | } 379 | if (a > -20) { 380 | scalar_t ndtr_val = ndtr(a); 381 | return std::make_pair(ndtr_val, std::log(ndtr_val)); 382 | } 383 | 384 | scalar_t log_LHS, /* we compute the left hand side of the approx (LHS) in one shot */ 385 | last_total = 0, /* variable used to check for convergence */ 386 | right_hand_side = 1, /* includes first term from the RHS summation */ 387 | numerator = 1, /* numerator for RHS summand */ 388 | denom_factor = 1, /* use reciprocal for denominator to avoid division */ 389 | denom_cons = 1.0 / (a * a); /* the precomputed division we use to adjust the denominator */ 390 | long sign = 1, i = 0; 391 | 392 | log_LHS = -0.5 * a * a - std::log(-a) - 0.5 * std::log(2 * M_PI); 393 | 394 | while (std::fabs(last_total - right_hand_side) > std::numeric_limits::epsilon()) { 395 | i += 1; 396 | last_total = right_hand_side; 397 | sign = -sign; 398 | denom_factor *= denom_cons; 399 | numerator *= 2 * i - 1; 400 | right_hand_side += sign * numerator * denom_factor; 401 | } 402 | return std::make_pair(ndtr(a), log_LHS + std::log(right_hand_side)); 403 | } 404 | 405 | } // namespace cpu 406 | } // namespace cephes 407 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/cephes/polevl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5f4c4d802e5a56708d86909af6e5685cd95e6e66/scipy/special/cephes/polevl.h 6 | */ 7 | 8 | /* polevl.c 9 | * p1evl.c 10 | * 11 | * Evaluate polynomial 12 | * 13 | * 14 | * 15 | * SYNOPSIS: 16 | * 17 | * int N; 18 | * double x, y, coef[N+1], polevl[]; 19 | * 20 | * y = polevl( x, coef, N ); 21 | * 22 | * 23 | * 24 | * DESCRIPTION: 25 | * 26 | * Evaluates polynomial of degree N: 27 | * 28 | * 2 N 29 | * y = C + C x + C x +...+ C x 30 | * 0 1 2 N 31 | * 32 | * Coefficients are stored in reverse order: 33 | * 34 | * coef[0] = C , ..., coef[N] = C . 35 | * N 0 36 | * 37 | * The function p1evl() assumes that coef[N] = 1.0 and is 38 | * omitted from the array. 
Its calling arguments are 39 | * otherwise the same as polevl(). 40 | * 41 | * 42 | * SPEED: 43 | * 44 | * In the interest of speed, there are no checks for out 45 | * of bounds arithmetic. This routine is used by most of 46 | * the functions in the library. Depending on available 47 | * equipment features, the user may wish to rewrite the 48 | * program in microcode or assembly language. 49 | * 50 | */ 51 | 52 | 53 | /* 54 | * Cephes Math Library Release 2.1: December, 1988 55 | * Copyright 1984, 1987, 1988 by Stephen L. Moshier 56 | * Direct inquiries to 30 Frost Street, Cambridge, MA 02140 57 | */ 58 | 59 | /* Sources: 60 | * [1] Holin et. al., "Polynomial and Rational Function Evaluation", 61 | * https://www.boost.org/doc/libs/1_61_0/libs/math/doc/html/math_toolkit/roots/rational.html 62 | */ 63 | 64 | /* Scipy changes: 65 | * - 06-23-2016: add code for evaluating rational functions 66 | */ 67 | 68 | #include 69 | 70 | 71 | namespace cephes { namespace cpu { 72 | 73 | template 74 | static inline 75 | scalar_t polevl(const scalar_t x, const scalar_t coef[], const size_t N) 76 | { 77 | scalar_t ans; 78 | size_t i; 79 | const scalar_t *p; 80 | 81 | p = coef; 82 | ans = *p++; 83 | i = N; 84 | 85 | do 86 | ans = ans * x + *p++; 87 | while (--i); 88 | 89 | return (ans); 90 | } 91 | 92 | /* p1evl() */ 93 | /* N 94 | * Evaluate polynomial when coefficient of x is 1.0. 95 | * Otherwise same as polevl. 96 | */ 97 | 98 | template 99 | static inline 100 | scalar_t p1evl(scalar_t x, const scalar_t coef[], const size_t N) 101 | { 102 | scalar_t ans; 103 | const scalar_t *p; 104 | size_t i; 105 | 106 | p = coef; 107 | ans = x + *p++; 108 | i = N - 1; 109 | 110 | do 111 | ans = ans * x + *p++; 112 | while (--i); 113 | 114 | return (ans); 115 | } 116 | 117 | /* Evaluate a rational function. See [1]. */ 118 | 119 | template 120 | static inline 121 | scalar_t ratevl(scalar_t x, const scalar_t num[], const size_t M, 122 | const scalar_t denom[], const size_t N) 123 | { 124 | size_t i; 125 | int dir; 126 | scalar_t y, num_ans, denom_ans; 127 | scalar_t absx = std::fabs(x); 128 | const scalar_t *p; 129 | 130 | if (absx > 1) { 131 | /* Evaluate as a polynomial in 1/x. 
*/ 132 | dir = -1; 133 | p = num + M; 134 | y = 1 / x; 135 | } else { 136 | dir = 1; 137 | p = num; 138 | y = x; 139 | } 140 | 141 | /* Evaluate the numerator */ 142 | num_ans = *p; 143 | p += dir; 144 | for (i = 1; i <= M; i++) { 145 | num_ans = num_ans * y + *p; 146 | p += dir; 147 | } 148 | 149 | /* Evaluate the denominator */ 150 | if (absx > 1) { 151 | p = denom + N; 152 | } else { 153 | p = denom; 154 | } 155 | 156 | denom_ans = *p; 157 | p += dir; 158 | for (i = 1; i <= N; i++) { 159 | denom_ans = denom_ans * y + *p; 160 | p += dir; 161 | } 162 | 163 | if (absx > 1) { 164 | i = N - M; 165 | return std::pow(x, i) * num_ans / denom_ans; 166 | } else { 167 | return num_ans / denom_ans; 168 | } 169 | } 170 | 171 | } // namespace cpu 172 | } // namespace cephes 173 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cpu/kernels.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../cdf_ops.h" 4 | #include "cdflib/cdflib.hpp" 5 | #include "cephes/i0.h" 6 | #include "cephes/i1.h" 7 | #include "cephes/ndtr.h" 8 | 9 | #ifndef _USE_MATH_DEFINES 10 | #define _USE_MATH_DEFINES 11 | #endif 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | #include 18 | #include 19 | #include 20 | #include 21 | 22 | 23 | namespace cdf_ops { 24 | namespace cpu { 25 | 26 | namespace { 27 | 28 | // https://github.com/scipy/scipy/blob/5f4c4d802e5a56708d86909af6e5685cd95e6e66/scipy/special/cdf_wrappers.c#L65-L98 29 | template 30 | static inline 31 | double cdflib_get_result(const char* name, int status, double bound, double result) { 32 | TORCH_CHECK(status >= 0, name, ": Invalid input parameter ", -status, " is out of range"); 33 | switch (status) { 34 | case 0: 35 | /* no error */ 36 | return result; 37 | case 1: 38 | if (return_bound) { 39 | TORCH_WARN(name, ": Answer appears to be lower than lowest search bound (", bound, ")"); 40 | return bound; 41 | } else { 42 | TORCH_CHECK(false, name, ": Answer appears to be lower than lowest search bound (", bound, ")"); 43 | } 44 | break; 45 | case 2: 46 | if (return_bound) { 47 | TORCH_WARN(name, ": Answer appears to be higher than highest search bound (", bound, ")"); 48 | return bound; 49 | } else { 50 | TORCH_CHECK(false, name, ": Answer appears to be higher than highest search bound (", bound, ")"); 51 | } 52 | break; 53 | case 3: 54 | case 4: 55 | TORCH_CHECK(false, name, ": Two parameters that should sum to 1.0 do not"); 56 | break; 57 | case 10: 58 | TORCH_CHECK(false, name, ": Computational error"); 59 | break; 60 | default: 61 | TORCH_CHECK(false, name, ": Unknown error"); 62 | } 63 | return std::numeric_limits::quiet_NaN(); 64 | } 65 | 66 | } 67 | 68 | template 69 | static inline 70 | scalar_t cdflib_chndtr(scalar_t _x, scalar_t _df, scalar_t _nc) { 71 | if (_x != _x || _df != _df || _nc != _nc) { 72 | // cdf_ops doesn't handle NaN well. 73 | return NAN; 74 | } 75 | 76 | // NOTE [ Non-central Chi-square CDF Bounds ] 77 | // 78 | // The CDFLIB algorithm is *really* slow when both `x` and `nc` are large 79 | // and can scale with `nc`. 80 | // 81 | // See documentation for CDFLIB CDFCHN, copied below: 82 | // 83 | // The computation time required for this routine is proportional 84 | // to the noncentrality parameter (PNONC). Very large values of 85 | // this parameter can consume immense computer resources. This is 86 | // why the search range is bounded by 1e9. 
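    //
    // (Illustration of the cutoff applied just below; the numbers assume
    // double precision and are not part of the original note: with
    // eps ~ 2.22e-16 we get -log(eps) ~ 36.04 and sqrt(-log(eps)) ~ 6.00, so
    // for df = 2 and nc = 2000 the upper threshold is
    //     mean + term1 + 2*neglogeps ~ 2002 + 2*sqrt(4002)*6.00 + 72.1 ~ 2834,
    // i.e. any x above ~2834 returns CDF = 1, and any x below
    //     mean - term1 ~ 1242 returns CDF = 0, without ever calling cdfchn.)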
87 | // 88 | // Hence, we employ the CDF bounds from 89 | // http://proceedings.mlr.press/v22/kolar12/kolar12Supple.pdf 90 | // Lemma8. 91 | 92 | if (use_bounds) { 93 | constexpr scalar_t neglogeps = -std::log(std::numeric_limits::epsilon()); 94 | constexpr scalar_t sqrt_neglogeps = std::sqrt(neglogeps); 95 | const scalar_t mean = _df + _nc; 96 | const scalar_t term1 = 2 * std::sqrt(_df + 2 * _nc) * sqrt_neglogeps; 97 | if (_x >= mean + term1 + 2 * neglogeps) { 98 | return 1; 99 | } else if (_x <= mean - term1) { 100 | return 0; 101 | } 102 | } 103 | 104 | int which = 1; 105 | double q = 0, p = 0, bound = 0; 106 | int status = 10; 107 | 108 | double x = (double) _x; 109 | double df = (double) _df; 110 | double nc = (double) _nc; 111 | 112 | cdflib::cpu::cdfchn(&which, &p, &q, &x, &df, &nc, &status, &bound); 113 | scalar_t cdf = cdflib_get_result("cdfchn", status, bound, p); 114 | 115 | if (clamp_01) { 116 | if (cdf < 0) { 117 | return 0; 118 | } else if (cdf > 1) { 119 | return 1; 120 | } 121 | } 122 | return cdf; 123 | } 124 | 125 | static inline 126 | void chndtr_kernel_cpu(at::TensorIterator& iter) { 127 | AT_DISPATCH_FLOATING_TYPES( 128 | iter.dtype(), "chndtr_cpu", [&] { 129 | at::native::cpu_kernel(iter, &cdflib_chndtr); 130 | } 131 | ); 132 | } 133 | 134 | 135 | static inline 136 | void chndtr_scalar_kernel_cpu(at::TensorIterator& iter, double df) { 137 | AT_DISPATCH_FLOATING_TYPES( 138 | iter.dtype(), "chndtr_scalar_cpu", [&] { 139 | at::native::cpu_kernel(iter, [=](scalar_t x, scalar_t nc) -> scalar_t { 140 | return cdflib_chndtr(x, df, nc); // it is internally double anyways, so avoid precision lost on df... 141 | }); 142 | } 143 | ); 144 | } 145 | 146 | 147 | static inline 148 | void i0e_kernel_cpu(at::TensorIterator& iter) { 149 | AT_DISPATCH_FLOATING_TYPES( 150 | iter.dtype(), "i0e_cpu", [&] { 151 | at::native::cpu_kernel(iter, &cephes::cpu::i0e); 152 | } 153 | ); 154 | } 155 | 156 | 157 | static inline 158 | void i1_kernel_cpu(at::TensorIterator& iter) { 159 | AT_DISPATCH_FLOATING_TYPES( 160 | iter.dtype(), "i1_cpu", [&] { 161 | at::native::cpu_kernel(iter, &cephes::cpu::i1); 162 | } 163 | ); 164 | } 165 | 166 | 167 | static inline 168 | void i1e_kernel_cpu(at::TensorIterator& iter) { 169 | AT_DISPATCH_FLOATING_TYPES( 170 | iter.dtype(), "i1e_cpu", [&] { 171 | at::native::cpu_kernel(iter, &cephes::cpu::i1e); 172 | } 173 | ); 174 | } 175 | 176 | 177 | static inline 178 | void ndtr_kernel_cpu(at::TensorIterator& iter) { 179 | AT_DISPATCH_FLOATING_TYPES( 180 | iter.dtype(), "ndtr_cpu", [&] { 181 | at::native::cpu_kernel(iter, &cephes::cpu::ndtr); 182 | } 183 | ); 184 | } 185 | 186 | 187 | static inline 188 | void log_ndtr_kernel_cpu(at::TensorIterator& iter) { 189 | AT_DISPATCH_FLOATING_TYPES( 190 | iter.dtype(), "log_ndtr_cpu", [&] { 191 | at::native::cpu_kernel(iter, &cephes::cpu::log_ndtr); 192 | } 193 | ); 194 | } 195 | 196 | 197 | static inline 198 | void ndtr_log_ndtr_kernel_cpu(at::TensorIterator& iter) { 199 | AT_DISPATCH_FLOATING_TYPES( 200 | iter.input_dtype(), "ndtr_log_ndtr_cpu", [&] { 201 | at::native::cpu_kernel(iter, [](scalar_t x) -> c10::complex { 202 | auto result = cephes::cpu::ndtr_log_ndtr(x); 203 | return c10::complex(result.first, result.second); 204 | }); 205 | } 206 | ); 207 | } 208 | 209 | 210 | static inline 211 | void ndtr_backward_kernel_cpu(at::TensorIterator& iter) { 212 | AT_DISPATCH_FLOATING_TYPES( 213 | iter.dtype(), "ndtr_backward_cpu", [&] { 214 | at::native::cpu_kernel(iter, [](scalar_t x, scalar_t gout) -> scalar_t { 215 | // Forward ndtr(x) 
= 0.5 [ 1 + erf( x / sqrt(2) )] 216 | // 217 | // Erf backward: 218 | // - name: erf(Tensor self) -> Tensor 219 | // self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad 220 | 221 | // Backward grad * 0.5 * 2 / sqrt(pi) * exp( - (x / sqrt(2)).pow(2) ) / sqrt(2) 222 | // = grad / sqrt(2 pi) * exp( - x * x / 2) 223 | 224 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 225 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 226 | return std::exp(- x * x / 2) * gout / SQRT_2PI; 227 | }); 228 | } 229 | ); 230 | } 231 | 232 | 233 | static inline 234 | void log_ndtr_backward_kernel_cpu(at::TensorIterator& iter) { 235 | AT_DISPATCH_FLOATING_TYPES( 236 | iter.dtype(), "log_ndtr_backward_cpu", [&] { 237 | at::native::cpu_kernel(iter, [](scalar_t x, scalar_t gout) -> scalar_t { 238 | scalar_t ndtr_val = cephes::cpu::ndtr(x); 239 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 240 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 241 | return std::exp(- x * x / 2) * gout / ndtr_val / SQRT_2PI; 242 | }); 243 | } 244 | ); 245 | } 246 | 247 | 248 | static inline 249 | void log_ndtr_backward_with_ndtr_kernel_cpu(at::TensorIterator& iter) { 250 | AT_DISPATCH_FLOATING_TYPES( 251 | iter.dtype(), "log_ndtr_backward_with_ndtr_cpu", [&] { 252 | at::native::cpu_kernel(iter, [](scalar_t x, scalar_t ndtr_val, scalar_t gout) -> scalar_t { 253 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 254 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 255 | return std::exp(- x * x / 2) * gout / ndtr_val / SQRT_2PI; 256 | }); 257 | } 258 | ); 259 | } 260 | 261 | 262 | template 263 | static inline 264 | void prob_two_poisson_kernel_cpu(at::TensorIterator& iter) { 265 | // NOTE [ Skellam CDF Bounds at 0 ] 266 | // 267 | // In `chndtr`, we used bounds for general noncentral chi-sq CDFs. Here, 268 | // we can use a slightly different one based on Poisson race. 269 | // https://www.wikiwand.com/en/Poisson_distribution#/Poisson_races 270 | // 271 | // Poisson race says that for independent X ~ Poisson(mu1), Y ~ Poisson(mu2), 272 | // mu1 > mu2, 273 | // Pr[ X - Y >= 0 ] <= exp{ -(\sqrt{mu1} - \sqrt{mu2})^2 }, 274 | // given by standard Chernoff bound. 275 | // 276 | // This implies that 277 | // 278 | // 1. if mu1 > mu2, 279 | // 280 | // (\sqrt{mu1} - \sqrt{mu2})^2 >= t 281 | // => Pr[ Y <= X ] > 1 - exp(-t). 282 | // 283 | // 2. if mu1 < mu2, 284 | // 285 | // Pr[ Y <= X ] = Pr[X = Y] + Pr[ Y < X ] 286 | // = Pr[X = Y] + 1 - Pr[Y >= X] 287 | // >= 1 - Pr[Y >= X] 288 | // >= 1 - exp{ -(\sqrt{mu1} - \sqrt{mu2})^2 }. 289 | // 290 | // Side note: we have following fact, which maybe integrated to strengthen the bound: 291 | // Pr[ X = Y ] = Q_1 ( \sqrt{2 mu1}, \sqrt{2 mu2} ) 292 | // = exp{ -(\sqrt{mu1} - \sqrt{mu2})^2 } I0e( 2\sqrt{mu1 mu2} ). 293 | // 294 | // Hence, if (\sqrt{mu1} - \sqrt{mu2})^2 is large enough, we can just act like an indicator 295 | // function! 296 | // 297 | // But how different is this from the bounds used in `chndtr`. When setting 298 | // ` x = 2 mu1` 299 | // `df = 2` 300 | // `nc = 2 mu2` 301 | // 302 | // The `chndtr` bounds reverts to indicator when one of the following is true 303 | // 1. mu1 >= 1 + mu2 + \sqrt{ (2 + 4 mu2) t } + t 304 | // 2. mu1 <= 1 + mu2 - \sqrt{ (2 + 4 mu2) t }. 
305 | // 306 | // The above bound reverts to indicator when one of the following is true 307 | // 1. mu1 >= mu2 + (\sqrt{mu1} + \sqrt{mu2}) \sqrt{t} 308 | // 2. mu1 <= mu2 - (\sqrt{mu1} + \sqrt{mu2}) \sqrt{t}. 309 | // 310 | // It is not really clear whether the above bound would lead to more improvement 311 | // frequently. So we do not implement it. 312 | 313 | AT_DISPATCH_FLOATING_TYPES( 314 | iter.dtype(), "prob_two_poisson_cpu", [&] { 315 | at::native::cpu_kernel(iter, [](scalar_t mu1, scalar_t mu2) -> scalar_t { 316 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 317 | // 318 | // Compute Prob[ NCX2( 2, 2*mu2 ) < 2*mu1 ] 319 | // 320 | // df = 2. 321 | // nc = 2 * mu2. 322 | // x = 2 * mu1 323 | 324 | double df = 2; 325 | double nc = 2 * mu2; 326 | double x = 2 * mu1; 327 | 328 | if (comp == TwoPoissonComparisonProb::GT) { 329 | return cdflib_chndtr(x, df, nc); 330 | } else if (comp == TwoPoissonComparisonProb::LE) { 331 | return 1 - cdflib_chndtr(x, df, nc); 332 | } else { 333 | __builtin_unreachable(); 334 | } 335 | }); 336 | } 337 | ); 338 | } 339 | 340 | template 341 | static inline 342 | void prob_two_poisson_grad_mu1_kernel_cpu(at::TensorIterator& iter) { 343 | AT_DISPATCH_FLOATING_TYPES( 344 | iter.dtype(), "prob_two_poisson_grad_mu1_cpu", [&] { 345 | at::native::cpu_kernel(iter, [](scalar_t mu1, scalar_t mu2, scalar_t gout) -> scalar_t { 346 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 347 | // 348 | // Compute 349 | // auto grad_mu1 = (-mu1 - mu2).exp() * torch::i0(2 * (mu1 * mu2).sqrt()) * gout; 350 | 351 | double g_gt; 352 | 353 | if (!use_besselIe) { 354 | g_gt = std::exp(-mu1 - mu2) * ::calc_i0(std::sqrt(mu1 * mu2) * 2) * gout; 355 | } else { 356 | scalar_t twice_sqrtmu12 = std::sqrt(mu1 * mu2) * 2; 357 | g_gt = std::exp(twice_sqrtmu12 - mu1 - mu2) * cephes::cpu::i0e(twice_sqrtmu12) * gout; 358 | } 359 | 360 | if (comp == TwoPoissonComparisonProb::GT) { 361 | return g_gt; 362 | } else if (comp == TwoPoissonComparisonProb::LE) { 363 | return -g_gt; 364 | } else { 365 | __builtin_unreachable(); 366 | } 367 | 368 | }); 369 | } 370 | ); 371 | } 372 | 373 | template 374 | static inline 375 | void prob_two_poisson_grad_mu2_kernel_cpu(at::TensorIterator& iter) { 376 | AT_DISPATCH_FLOATING_TYPES( 377 | iter.dtype(), "prob_two_poisson_grad_mu2_cpu", [&] { 378 | at::native::cpu_kernel(iter, [](scalar_t mu1, scalar_t mu2, scalar_t out, scalar_t gout) -> scalar_t { 379 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 380 | // 381 | // Compute 382 | // auto grad_mu2 = (chndtr_scalar(2 * mu1, 4, 2 * mu2) - out) * gout; 383 | if (mu1 == 0) { 384 | 385 | if (comp == TwoPoissonComparisonProb::GT) { 386 | return -out * gout; 387 | } else if (comp == TwoPoissonComparisonProb::LE) { 388 | return (1 - out) * gout; 389 | } else { 390 | __builtin_unreachable(); 391 | } 392 | 393 | } else if (!use_besselIe || mu2 == 0) { 394 | // nc = 2 * mu2. When mu2 == 0, the chndtr code computes cdf for Chi-Square isntead. 395 | // So let it handle. 
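                // (Note on the else-branch further below, describing what that code
                // computes rather than adding to it: the same difference is rewritten as
                //     +/- exp((log(mu1) - log(mu2))/2 + 2*sqrt(mu1*mu2) - mu1 - mu2)
                //         * i1e(2*sqrt(mu1*mu2)) * gout,
                // with everything folded into a single exponent; since
                // 2*sqrt(mu1*mu2) - mu1 - mu2 = -(sqrt(mu1) - sqrt(mu2))^2 <= 0,
                // neither exp(mu) nor an unscaled I1 value is ever formed, which
                // avoids overflow for large rates.)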
396 | 397 | if (comp == TwoPoissonComparisonProb::GT) { 398 | return (cdflib_chndtr(2 * mu1, 4, 2 * mu2) - out) * gout; 399 | } else if (comp == TwoPoissonComparisonProb::LE) { 400 | return (1 - out - cdflib_chndtr(2 * mu1, 4, 2 * mu2)) * gout; 401 | } else { 402 | __builtin_unreachable(); 403 | } 404 | 405 | 406 | } else { 407 | scalar_t twice_sqrtmu12 = std::sqrt(mu1 * mu2) * 2; 408 | scalar_t log_mu1 = std::log(mu1); 409 | scalar_t log_mu2 = std::log(mu2); 410 | 411 | scalar_t g_le = std::exp( 412 | (log_mu1 - log_mu2) / 2 + twice_sqrtmu12 - mu1 - mu2 413 | ) * cephes::cpu::i1e(twice_sqrtmu12) * gout; 414 | 415 | if (comp == TwoPoissonComparisonProb::GT) { 416 | return -g_le; 417 | } else if (comp == TwoPoissonComparisonProb::LE) { 418 | return g_le; 419 | } else { 420 | __builtin_unreachable(); 421 | } 422 | 423 | 424 | } 425 | }); 426 | } 427 | ); 428 | } 429 | 430 | } // namespace cpu 431 | } // namespace cdf_ops 432 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/chndtr.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cumchn.cuh" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | namespace cdflib { 22 | namespace cuda { 23 | 24 | template 25 | __host__ __device__ __forceinline__ 26 | scalar_t _chndtr(scalar_t x, scalar_t df, scalar_t pnonc) { 27 | 28 | static_assert(std::is_same::value || std::is_same::value ,"Unsupported scalar_t"); 29 | 30 | // constexpr scalar_t tent9 = 1.0e9; 31 | // constexpr double tol = 1.0e-8; 32 | // constexpr double atol = 1.0e-50; 33 | // constexpr double zero = 1.0e-300; 34 | // constexpr double one = 1.0e0 - 1.0e-16; 35 | // constexpr scalar_t inf = std::is_same::value ? 1.0e300 : 1.0e32; 36 | 37 | // if (x > inf) { 38 | // x = inf; 39 | // } 40 | // if (df > inf) { 41 | // df = inf; 42 | // } 43 | // if (pnonc > tent9){ 44 | // pnonc = tent9; 45 | // } 46 | 47 | // execute cumchn 48 | // cumchn(x,df,pnonc,p,q); 49 | // 50 | // Somehow the following commented code reproduces the CPU result (which always computes in float64) 51 | // fully with *float32*, yet the scalar_t version (uncommented code) does not with *float64*. 52 | // I'll just understand this as vectorization. 53 | // double p, q; 54 | // cumchn((double) x, (double) df, (double) pnonc, &p, &q); 55 | scalar_t p, q; 56 | cumchn(x, df, pnonc, &p, &q); 57 | 58 | if (clamp_01) { 59 | if (p > 1) { 60 | p = 1; 61 | } else if (p < 0) { 62 | p = 0; 63 | } 64 | } 65 | return p; 66 | }; 67 | 68 | 69 | constexpr double chndtr_double_thresh = 1.0e3; 70 | 71 | template 72 | __host__ __device__ __forceinline__ 73 | scalar_t chndtr(scalar_t x, scalar_t df, scalar_t pnonc) { 74 | 75 | if (x != x || df != df || pnonc != pnonc) { 76 | // cdf_ops doesn't handle NaN well. 77 | return NAN; 78 | } 79 | 80 | // bound checks 81 | CUDA_KERNEL_ASSERT(!(x < 0.0e0)); 82 | CUDA_KERNEL_ASSERT(!(df <= 0.0e0)); 83 | CUDA_KERNEL_ASSERT(!(pnonc < 0.0e0)); 84 | 85 | // NOTE [ Non-central Chi-square CDF Bounds ] 86 | // 87 | // The following algorithm is *really* slow when both `x` and `nc` are large 88 | // and can scale with `nc`. 89 | // 90 | // See documentation for CDFLIB CDFCHN, copied below: 91 | // 92 | // The computation time required for this routine is proportional 93 | // to the noncentrality parameter (PNONC). 
Very large values of 94 | // this parameter can consume immense computer resources. This is 95 | // why the search range is bounded by 1e9. 96 | // 97 | // Hence, we employ the CDF bounds from 98 | // http://proceedings.mlr.press/v22/kolar12/klar12Supple.pdf 99 | // Lemma8. 100 | 101 | if (use_bounds) { 102 | // CUDA doesn't like constexpr std::log and std::sqrt. Do it ourselves! 103 | // constexpr scalar_t neglogeps = -std::log(std::numeric_limits::epsilon()); 104 | // constexpr scalar_t sqrt_neglogeps = std::sqrt(neglogeps); 105 | static_assert(std::is_same::value || std::is_same::value ,"Unsupported scalar_t"); 106 | static_assert(std::numeric_limits::epsilon() < 1.193e-07 && 1.192e-07 < std::numeric_limits::epsilon()); 107 | static_assert(std::numeric_limits::epsilon() < 2.221e-16 && 1.220e-16 < std::numeric_limits::epsilon()); 108 | 109 | constexpr scalar_t neglogeps = std::is_same::value 110 | ? 36.0436533891171535515240975655615329742431640625 111 | : 15.9423847198486328125; 112 | constexpr scalar_t sqrt_neglogeps = std::is_same::value 113 | ? 6.00363668030612540604806781630031764507293701171875 114 | : 3.992791652679443359375; 115 | 116 | const scalar_t mean = df + pnonc; 117 | const scalar_t term1 = 2 * std::sqrt(df + 2 * pnonc) * sqrt_neglogeps; 118 | if (x >= mean + term1 + 2 * neglogeps) { 119 | return 1; 120 | } else if (x <= mean - term1) { 121 | return 0; 122 | } 123 | } 124 | 125 | if (!std::is_same::value && ( 126 | (check_x && x > chndtr_double_thresh) || 127 | (check_df && df > chndtr_double_thresh) || 128 | (check_pnonc && pnonc > chndtr_double_thresh))) { 129 | return static_cast(_chndtr(x, df, pnonc)); 130 | } else { 131 | return _chndtr(x, df, pnonc); 132 | } 133 | } 134 | 135 | } // namespace cuda 136 | } // namespace cdflib 137 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/cumchi.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cumgam.cuh" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | namespace cdflib { 22 | namespace cuda { 23 | 24 | template 25 | __host__ __device__ __forceinline__ 26 | void cumchi ( scalar_t x, scalar_t df, scalar_t *cum, scalar_t *ccum ) 27 | 28 | //****************************************************************************80 29 | // 30 | // Purpose: 31 | // 32 | // CUMCHI evaluates the cumulative chi-square distribution. 33 | // 34 | // Parameters: 35 | // 36 | // Input, double *X, the upper limit of integration. 37 | // 38 | // Input, double *DF, the degrees of freedom of the 39 | // chi-square distribution. 40 | // 41 | // Output, double *CUM, the cumulative chi-square distribution. 42 | // 43 | // Output, double *CCUM, the complement of the cumulative 44 | // chi-square distribution. 
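//
//    (A standard identity, noted for reference rather than taken from the
//    original CDFLIB header: for X ~ chi-square(DF), Pr[X <= x] equals the
//    regularized lower incomplete gamma function P(DF/2, x/2), which is
//    exactly what the body below computes by passing a = DF*0.5 and
//    xx = x*0.5 to cumgam().)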
45 | // 46 | { 47 | scalar_t a = df * 0.5; 48 | scalar_t xx = x * 0.5; 49 | cumgam ( xx, a, cum, ccum ); 50 | } 51 | 52 | } // namespace cuda 53 | } // namespace cdflib 54 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/cumchn.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cumchi.cuh" 4 | #include "fifidint.cuh" 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | 21 | 22 | namespace cdflib { 23 | namespace cuda { 24 | 25 | template 26 | __host__ __device__ __forceinline__ 27 | void cumchn ( scalar_t x, scalar_t df, scalar_t pnonc, scalar_t *cum, scalar_t *ccum ) 28 | 29 | //****************************************************************************80 30 | // 31 | // Purpose: 32 | // 33 | // CUMCHN evaluates the cumulative noncentral chi-square distribution. 34 | // 35 | // Discussion: 36 | // 37 | // Calculates the cumulative noncentral chi-square 38 | // distribution, i.e., the probability that a random variable 39 | // which follows the noncentral chi-square distribution, with 40 | // noncentrality parameter PNONC and continuous degrees of 41 | // freedom DF, is less than or equal to X. 42 | // 43 | // Reference: 44 | // 45 | // Milton Abramowitz and Irene Stegun, 46 | // Handbook of Mathematical Functions 47 | // 1966, Formula 26.4.25. 48 | // 49 | // Parameters: 50 | // 51 | // Input, double *X, the upper limit of integration. 52 | // 53 | // Input, double *DF, the number of degrees of freedom. 54 | // 55 | // Input, double *PNONC, the noncentrality parameter of 56 | // the noncentral chi-square distribution. 57 | // 58 | // Output, double *CUM, *CCUM, the CDF and complementary 59 | // CDF of the noncentral chi-square distribution. 60 | // 61 | // Local Parameters: 62 | // 63 | // Local, double EPS, the convergence criterion. The sum 64 | // stops when a term is less than EPS*SUM. 65 | // 66 | { 67 | # define dg(i) (df+2.0e0*(scalar_t)(i)) 68 | static_assert(std::is_same::value || std::is_same::value ,"Unsupported scalar_t"); 69 | constexpr scalar_t abstol = std::is_same::value ? 1.0e-300 : 1.0e-32; 70 | # define qsmall(xx) (int)(!(sum >= abstol && (xx) >= eps*sum)) 71 | int ntired = 1e9; 72 | # define qtired(i) (int)((i) > ntired) 73 | 74 | constexpr scalar_t eps = std::is_same::value ? 1.0e-15 : 1.0e-6; 75 | scalar_t adj,centaj,centwt,chid2,dfd2,lcntaj,lcntwt,lfact,pcent,pterm,sum, 76 | sumadj,term,wt,xnonc; 77 | int i, icent, iterb, iterf; 78 | 79 | if(!(x <= 0.0e0)) goto S10; 80 | *cum = 0.0e0; 81 | *ccum = 1.0e0; 82 | return; 83 | S10: 84 | if(!(pnonc <= 1.0e-10)) goto S20; 85 | // 86 | // When non-centrality parameter is (essentially) zero, 87 | // use cumulative chi-square distribution 88 | // 89 | cumchi(x, df, cum, ccum); 90 | return; 91 | S20: 92 | xnonc = pnonc/2.0e0; 93 | // 94 | // The following code calculates the weight, chi-square, and 95 | // adjustment term for the central term in the infinite series. 96 | // The central term is the one in which the poisson weight is 97 | // greatest. The adjustment term is the amount that must 98 | // be subtracted from the chi-square to move up two degrees 99 | // of freedom. 
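  //
  //  (The series this routine implements, Abramowitz & Stegun formula 26.4.25,
  //  spelled out; the loop sketch below the formula is illustrative only and
  //  is never called here:
  //
  //      P_ncx2(x; df, pnonc) = sum_{j>=0} exp(-pnonc/2) * (pnonc/2)^j / j!
  //                             * P_chi2(x; df + 2*j),
  //
  //  i.e. a Poisson(pnonc/2) mixture of central chi-square CDFs. A naive
  //  evaluation with a fixed term cap would read:
  //
  //      scalar_t naive = 0, w = std::exp(-xnonc);   // xnonc = pnonc / 2
  //      for (int j = 0; j < 200; ++j) {             // 200: arbitrary cap
  //          scalar_t cj, ccj;
  //          cumchi(x, df + scalar_t(2 * j), &cj, &ccj);
  //          naive += w * cj;                        // Poisson-weighted CDF term
  //          w *= xnonc / (j + 1);                   // next Poisson weight
  //      }
  //
  //  The code below instead starts at the index of the largest Poisson weight
  //  (icent ~ xnonc) and sums outward in both directions, stopping once terms
  //  become negligible relative to the running sum.)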
100 | // 101 | icent = static_cast(fifidint(xnonc)); 102 | if(icent == 0) icent = 1; 103 | chid2 = x/2.0e0; 104 | // 105 | // Calculate central weight term 106 | // 107 | lfact = std::lgamma ( static_cast(icent+1) ); 108 | lcntwt = -xnonc+static_cast(icent)*std::log(xnonc)-lfact; 109 | centwt = std::exp(lcntwt); 110 | // 111 | // Calculate central chi-square 112 | // 113 | scalar_t T2 = dg(icent); 114 | cumchi(x,T2,&pcent,ccum); 115 | // 116 | // Calculate central adjustment term 117 | // 118 | dfd2 = dg(icent)/2.0e0; 119 | lfact = std::lgamma ( 1.0e0+dfd2 ); 120 | lcntaj = dfd2*std::log(chid2)-chid2-lfact; 121 | centaj = std::exp(lcntaj); 122 | sum = centwt*pcent; 123 | // 124 | // Sum backwards from the central term towards zero. 125 | // Quit whenever either 126 | // (1) the zero term is reached, or 127 | // (2) the term gets small relative to the sum, or 128 | // (3) More than NTIRED terms are totaled. 129 | // 130 | iterb = 0; 131 | sumadj = 0.0e0; 132 | adj = centaj; 133 | wt = centwt; 134 | i = icent; 135 | goto S40; 136 | S30: 137 | if( qtired(iterb) || qsmall(term) || i == 0 ) goto S50; 138 | S40: 139 | dfd2 = dg(i)/2.0e0; 140 | // 141 | // Adjust chi-square for two fewer degrees of freedom. 142 | // The adjusted value ends up in PTERM. 143 | // 144 | adj = adj*dfd2/chid2; 145 | sumadj = sumadj + adj; 146 | pterm = pcent+sumadj; 147 | // 148 | // Adjust poisson weight for J decreased by one 149 | // 150 | wt *= static_cast(i) / xnonc; 151 | term = wt*pterm; 152 | sum = sum + term; 153 | i -= 1; 154 | iterb += 1; 155 | goto S30; 156 | S50: 157 | iterf = 0; 158 | // 159 | // Now sum forward from the central term towards infinity. 160 | // Quit when either 161 | // (1) the term gets small relative to the sum, or 162 | // (2) More than NTIRED terms are totaled. 163 | // 164 | sumadj = adj = centaj; 165 | wt = centwt; 166 | i = icent; 167 | goto S70; 168 | S60: 169 | if ( qtired(iterf) || qsmall(term) ) goto S80; 170 | S70: 171 | // 172 | // Update weights for next higher J 173 | // 174 | wt *= xnonc / static_cast(i+1); 175 | // 176 | // Calculate PTERM and add term to sum 177 | // 178 | pterm = pcent-sumadj; 179 | term = wt*pterm; 180 | sum = sum + term; 181 | // 182 | // Update adjustment term for DF for next iteration 183 | // 184 | i = i + 1; 185 | dfd2 = dg(i)/2.0e0; 186 | adj = adj*chid2/dfd2; 187 | sumadj = sumadj + adj; 188 | iterf = iterf + 1; 189 | goto S60; 190 | S80: 191 | *cum = sum; 192 | *ccum = 0.5e0+(0.5e0-*cum); 193 | return; 194 | # undef dg 195 | # undef qsmall 196 | # undef qtired 197 | } 198 | 199 | } // namespace cuda 200 | } // namespace cdflib 201 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/cumgam.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "gamma_inc.cuh" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | namespace cdflib { 22 | namespace cuda { 23 | 24 | template 25 | __host__ __device__ __forceinline__ 26 | void cumgam ( scalar_t x, scalar_t a, scalar_t *cum, scalar_t *ccum ) 27 | 28 | //****************************************************************************80 29 | // 30 | // Purpose: 31 | // 32 | // CUMGAM evaluates the cumulative incomplete gamma distribution. 
33 | // 34 | // Discussion: 35 | // 36 | // This routine computes the cumulative distribution function of the 37 | // incomplete gamma distribution, i.e., the integral from 0 to X of 38 | // 39 | // (1/GAM(A))*EXP(-T)*T**(A-1) DT 40 | // 41 | // where GAM(A) is the complete gamma function of A, i.e., 42 | // 43 | // GAM(A) = integral from 0 to infinity of EXP(-T)*T**(A-1) DT 44 | // 45 | // Parameters: 46 | // 47 | // Input, double *X, the upper limit of integration. 48 | // 49 | // Input, double *A, the shape parameter of the incomplete 50 | // Gamma distribution. 51 | // 52 | // Output, double *CUM, *CCUM, the incomplete Gamma CDF and 53 | // complementary CDF. 54 | // 55 | { 56 | if(!(x <= 0.0e0)) goto S10; 57 | *cum = 0.0e0; 58 | *ccum = 1.0e0; 59 | return; 60 | S10: 61 | gamma_inc ( a, x, cum, ccum); 62 | // 63 | // Call gratio routine 64 | // 65 | return; 66 | } 67 | 68 | } // namespace cuda 69 | } // namespace cdflib 70 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/error_fc.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "exparg.cuh" 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | #include 18 | #include 19 | 20 | 21 | namespace cdflib { 22 | namespace cuda { 23 | 24 | // This is used due to the availability of scale_with_expxx 25 | 26 | template 27 | __host__ __device__ __forceinline__ 28 | scalar_t error_fc ( scalar_t x ) 29 | 30 | //****************************************************************************80 31 | // 32 | // Purpose: 33 | // 34 | // ERROR_FC evaluates the complementary error function ERFC. 35 | // 36 | // Modified: 37 | // 38 | // 09 December 1999 39 | // 40 | // Parameters: 41 | // 42 | // Input, int *IND, chooses the scaling. 43 | // If IND is nonzero, then the value returned has been multiplied by 44 | // EXP(X*X). 45 | // 46 | // Input, double *X, the argument of the function. 47 | // 48 | // Output, double ERROR_FC, the value of the complementary 49 | // error function. 50 | // 51 | { 52 | constexpr scalar_t abs_negative_exp_arg_bound = -exparg(); 53 | static const scalar_t c = .564189583547756e0; 54 | static const scalar_t a[5] = { 55 | .771058495001320e-04,-.133733772997339e-02,.323076579225834e-01, 56 | .479137145607681e-01,.128379167095513e+00 57 | }; 58 | static const scalar_t b[3] = { 59 | .301048631703895e-02,.538971687740286e-01,.375795757275549e+00 60 | }; 61 | static const scalar_t p[8] = { 62 | -1.36864857382717e-07,5.64195517478974e-01,7.21175825088309e+00, 63 | 4.31622272220567e+01,1.52989285046940e+02,3.39320816734344e+02, 64 | 4.51918953711873e+02,3.00459261020162e+02 65 | }; 66 | static const scalar_t q[8] = { 67 | 1.00000000000000e+00,1.27827273196294e+01,7.70001529352295e+01, 68 | 2.77585444743988e+02,6.38980264465631e+02,9.31354094850610e+02, 69 | 7.90950925327898e+02,3.00459260956983e+02 70 | }; 71 | static const scalar_t r[5] = { 72 | 2.10144126479064e+00,2.62370141675169e+01,2.13688200555087e+01, 73 | 4.65807828718470e+00,2.82094791773523e-01 74 | }; 75 | static const scalar_t s[4] = { 76 | 9.41537750555460e+01,1.87114811799590e+02,9.90191814623914e+01, 77 | 1.80124575948747e+01 78 | }; 79 | scalar_t erfc1,ax,bot,e,t,top,w; 80 | 81 | // 82 | // ABS(X) .LE. 
0.5 83 | // 84 | ax = std::abs(x); 85 | if(ax > 0.5e0) goto S10; 86 | t = x*x; 87 | top = (((a[0]*t+a[1])*t+a[2])*t+a[3])*t+a[4]+1.0e0; 88 | bot = ((b[0]*t+b[1])*t+b[2])*t+1.0e0; 89 | erfc1 = 0.5e0+(0.5e0-x*(top/bot)); 90 | if(scale_with_expxx) erfc1 = std::exp(t)*erfc1; 91 | return erfc1; 92 | S10: 93 | // 94 | // 0.5 .LT. ABS(X) .LE. 4 95 | // 96 | if(ax > 4.0e0) goto S20; 97 | top = ((((((p[0]*ax+p[1])*ax+p[2])*ax+p[3])*ax+p[4])*ax+p[5])*ax+p[6])*ax+p[ 98 | 7]; 99 | bot = ((((((q[0]*ax+q[1])*ax+q[2])*ax+q[3])*ax+q[4])*ax+q[5])*ax+q[6])*ax+q[ 100 | 7]; 101 | erfc1 = top/bot; 102 | goto S40; 103 | S20: 104 | // 105 | // ABS(X) .GT. 4 106 | // 107 | if(x <= -5.6e0) goto S60; 108 | if(scale_with_expxx) goto S30; 109 | if(x > 100.0e0) goto S70; 110 | if(x*x > abs_negative_exp_arg_bound) goto S70; 111 | S30: 112 | t = std::pow(1.0e0/ x,2.0); 113 | top = (((r[0]*t+r[1])*t+r[2])*t+r[3])*t+r[4]; 114 | bot = (((s[0]*t+s[1])*t+s[2])*t+s[3])*t+1.0e0; 115 | erfc1 = (c-t*top/bot)/ax; 116 | S40: 117 | // 118 | // FINAL ASSEMBLY 119 | // 120 | if(!scale_with_expxx) goto S50; 121 | if(x < 0.0e0) erfc1 = 2.0e0*std::exp(x*x)-erfc1; 122 | return erfc1; 123 | S50: 124 | w = x*x; 125 | t = w; 126 | e = w-t; 127 | erfc1 = (0.5e0+(0.5e0-e))*std::exp(-t)*erfc1; 128 | if(x < 0.0e0) erfc1 = 2.0e0-erfc1; 129 | return erfc1; 130 | S60: 131 | // 132 | // LIMIT VALUE FOR LARGE NEGATIVE X 133 | // 134 | erfc1 = 2.0e0; 135 | if(scale_with_expxx) erfc1 = 2.0e0*std::exp(x*x); 136 | return erfc1; 137 | S70: 138 | // 139 | // LIMIT VALUE FOR LARGE POSITIVE X 140 | // WHEN IND = 0 141 | // 142 | erfc1 = 0.0e0; 143 | return erfc1; 144 | } 145 | 146 | } // namespace cuda 147 | } // namespace cdflib 148 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/exparg.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdflib { 20 | namespace cuda { 21 | 22 | template 23 | static constexpr scalar_t exparg() 24 | 25 | //****************************************************************************80 26 | // 27 | // Purpose: 28 | // 29 | // EXPARG returns the largest or smallest legal argument for EXP. 30 | // 31 | // Discussion: 32 | // 33 | // Only an approximate limit for the argument of EXP is desired. 34 | // 35 | // Modified: 36 | // 37 | // 09 December 1999 38 | // 39 | // Parameters: 40 | // 41 | // Input, int *L, indicates which limit is desired. 42 | // If L = 0, then the largest positive argument for EXP is desired. 43 | // Otherwise, the largest negative argument for EXP for which the 44 | // result is nonzero is desired. 45 | // 46 | // Output, double EXPARG, the desired value. 47 | // 48 | { 49 | constexpr double lnb = .69314718055995e0; // ln of base = 2 50 | int m = 0; 51 | 52 | if (positive) { 53 | m = std::numeric_limits::max_exponent; // the largest exponent E for double precision. 54 | } else { 55 | m = std::numeric_limits::min_exponent - 1; // the smallest exponent E for double precision, then - 1. 
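        // (For concreteness; these are standard IEEE-754 values, not part of the
        //  original CDFLIB comment: for double, max_exponent = 1024 makes the
        //  positive limit ~0.99999 * 1024 * ln(2) ~ 709.7, and
        //  min_exponent - 1 = -1022 gives ~ -708.4; these track
        //  log(DBL_MAX) ~ 709.78 and log(DBL_MIN) ~ -708.40, pulled slightly
        //  toward zero by the 0.99999 safety factor on the return line below.)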
56 | } 57 | return static_cast(0.99999e0*((double)m*lnb)); 58 | } 59 | 60 | } // namespace cuda 61 | } // namespace cdflib 62 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/fifidint.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdflib { 20 | namespace cuda { 21 | 22 | template 23 | __device__ __host__ __forceinline__ 24 | static long fifidint(scalar_t val) { 25 | return val < 1.0 ? 0 : static_cast(val); 26 | } 27 | 28 | } // namespace cuda 29 | } // namespace cdflib 30 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/gam1.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdflib { 20 | namespace cuda { 21 | 22 | template 23 | __host__ __device__ __forceinline__ 24 | scalar_t gam1 ( scalar_t a ) 25 | 26 | //****************************************************************************80 27 | // 28 | // Purpose: 29 | // 30 | // GAM1 computes 1 / GAMMA(A+1) - 1 for -0.5D+00 <= A <= 1.5 31 | // 32 | // Parameters: 33 | // 34 | // Input, double *A, forms the argument of the Gamma function. 35 | // 36 | // Output, double GAM1, the value of 1 / GAMMA ( A + 1 ) - 1. 37 | // 38 | { 39 | static const scalar_t s1 = .273076135303957e+00; 40 | static const scalar_t s2 = .559398236957378e-01; 41 | static const scalar_t p[7] = { 42 | .577215664901533e+00,-.409078193005776e+00,-.230975380857675e+00, 43 | .597275330452234e-01,.766968181649490e-02,-.514889771323592e-02, 44 | .589597428611429e-03 45 | }; 46 | static const scalar_t q[5] = { 47 | .100000000000000e+01,.427569613095214e+00,.158451672430138e+00, 48 | .261132021441447e-01,.423244297896961e-02 49 | }; 50 | static const scalar_t r[9] = { 51 | -.422784335098468e+00,-.771330383816272e+00,-.244757765222226e+00, 52 | .118378989872749e+00,.930357293360349e-03,-.118290993445146e-01, 53 | .223047661158249e-02,.266505979058923e-03,-.132674909766242e-03 54 | }; 55 | scalar_t gam1,bot,d,t,top,w,T1; 56 | 57 | t = a; 58 | d = a-0.5e0; 59 | if(d > 0.0e0) t = d-0.5e0; 60 | T1 = t; 61 | if(T1 < 0) goto S40; 62 | else if(T1 == 0) goto S10; 63 | else goto S20; 64 | S10: 65 | gam1 = 0.0e0; 66 | return gam1; 67 | S20: 68 | top = (((((p[6]*t+p[5])*t+p[4])*t+p[3])*t+p[2])*t+p[1])*t+p[0]; 69 | bot = (((q[4]*t+q[3])*t+q[2])*t+q[1])*t+1.0e0; 70 | w = top/bot; 71 | if(d > 0.0e0) goto S30; 72 | gam1 = a*w; 73 | return gam1; 74 | S30: 75 | gam1 = t/ a*(w-0.5e0-0.5e0); 76 | return gam1; 77 | S40: 78 | top = (((((((r[8]*t+r[7])*t+r[6])*t+r[5])*t+r[4])*t+r[3])*t+r[2])*t+r[1])*t+ 79 | r[0]; 80 | bot = (s2*t+s1)*t+1.0e0; 81 | w = top/bot; 82 | if(d > 0.0e0) goto S50; 83 | gam1 = a*(w+0.5e0+0.5e0); 84 | return gam1; 85 | S50: 86 | gam1 = t*w/ a; 87 | return gam1; 88 | } 89 | 90 | } // namespace cuda 91 | } // namespace cdflib 92 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/gamma_inc.cuh: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "error_fc.cuh" 4 | #include "fifidint.cuh" 5 | #include "gam1.cuh" 6 | #include "rlog.cuh" 7 | #include "rexp.cuh" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | #include 18 | #include 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | namespace cdflib { 26 | namespace cuda { 27 | 28 | template 29 | __host__ __device__ __forceinline__ 30 | void gamma_inc ( scalar_t a, scalar_t x, scalar_t *ans, scalar_t *qans ) 31 | 32 | //****************************************************************************80 33 | // 34 | // Purpose: 35 | // 36 | // GAMMA_INC evaluates the incomplete gamma ratio functions P(A,X) and Q(A,X). 37 | // 38 | // Discussion: 39 | // 40 | // This is certified spaghetti code. 41 | // 42 | // Author: 43 | // 44 | // Alfred H Morris, Jr, 45 | // Naval Surface Weapons Center, 46 | // Dahlgren, Virginia. 47 | // 48 | // Parameters: 49 | // 50 | // Input, double *A, *X, the arguments of the incomplete 51 | // gamma ratio. A and X must be nonnegative. A and X cannot 52 | // both be zero. 53 | // 54 | // Output, double *ANS, *QANS. On normal output, 55 | // ANS = P(A,X) and QANS = Q(A,X). However, ANS is set to 2 if 56 | // A or X is negative, or both are 0, or when the answer is 57 | // computationally indeterminate because A is extremely large 58 | // and X is very close to A. 59 | // 60 | // Input, bool PREC_IND, indicates the accuracy request: 61 | // 0, as much accuracy as possible. 62 | // 1, to within 1 unit of the 6-th significant digit, 63 | // otherwise, to within 1 unit of the 3rd significant digit. 64 | // 65 | { 66 | static const scalar_t alog10 = 2.30258509299405e0; 67 | static const scalar_t d10 = -.185185185185185e-02; 68 | static const scalar_t d20 = .413359788359788e-02; 69 | static const scalar_t d30 = .649434156378601e-03; 70 | static const scalar_t d40 = -.861888290916712e-03; 71 | static const scalar_t d50 = -.336798553366358e-03; 72 | static const scalar_t d60 = .531307936463992e-03; 73 | static const scalar_t d70 = .344367606892378e-03; 74 | static const scalar_t rt2pin = .398942280401433e0; 75 | static const scalar_t rtpi = 1.77245385090552e0; 76 | static const scalar_t third = .333333333333333e0; 77 | static const scalar_t acc0[3] = { 78 | 5.e-15,5.e-7,5.e-4 79 | }; 80 | static const scalar_t big[3] = { 81 | 20.0e0,14.0e0,10.0e0 82 | }; 83 | static const scalar_t d0[13] = { 84 | .833333333333333e-01,-.148148148148148e-01,.115740740740741e-02, 85 | .352733686067019e-03,-.178755144032922e-03,.391926317852244e-04, 86 | -.218544851067999e-05,-.185406221071516e-05,.829671134095309e-06, 87 | -.176659527368261e-06,.670785354340150e-08,.102618097842403e-07, 88 | -.438203601845335e-08 89 | }; 90 | static const scalar_t d1[12] = { 91 | -.347222222222222e-02,.264550264550265e-02,-.990226337448560e-03, 92 | .205761316872428e-03,-.401877572016461e-06,-.180985503344900e-04, 93 | .764916091608111e-05,-.161209008945634e-05,.464712780280743e-08, 94 | .137863344691572e-06,-.575254560351770e-07,.119516285997781e-07 95 | }; 96 | static const scalar_t d2[10] = { 97 | -.268132716049383e-02,.771604938271605e-03,.200938786008230e-05, 98 | -.107366532263652e-03,.529234488291201e-04,-.127606351886187e-04, 99 | .342357873409614e-07,.137219573090629e-05,-.629899213838006e-06, 100 | .142806142060642e-06 101 | }; 102 | static const scalar_t d3[8] = { 103 | 
.229472093621399e-03,-.469189494395256e-03,.267720632062839e-03, 104 | -.756180167188398e-04,-.239650511386730e-06,.110826541153473e-04, 105 | -.567495282699160e-05,.142309007324359e-05 106 | }; 107 | static const scalar_t d4[6] = { 108 | .784039221720067e-03,-.299072480303190e-03,-.146384525788434e-05, 109 | .664149821546512e-04,-.396836504717943e-04,.113757269706784e-04 110 | }; 111 | static const scalar_t d5[4] = { 112 | -.697281375836586e-04,.277275324495939e-03,-.199325705161888e-03, 113 | .679778047793721e-04 114 | }; 115 | static const scalar_t d6[2] = { 116 | -.592166437353694e-03,.270878209671804e-03 117 | }; 118 | static const scalar_t e00[3] = { 119 | .25e-3,.25e-1,.14e0 120 | }; 121 | static const scalar_t x00[3] = { 122 | 31.0e0,17.0e0,9.7e0 123 | }; 124 | scalar_t a2n,a2nm1,acc,am0,amn,an,an0,apn,b2n,b2nm1,c,c0,c1,c2,c3,c4,c5,c6, 125 | cma,e0,g,h,j,l,r,rta,rtx,s,sum,t,t1,tol,twoa,u,w,x0,y,z; 126 | int i,iop,m,max,n; 127 | scalar_t wk[20],T3; 128 | int T4,T5; 129 | scalar_t T6,T7; 130 | 131 | // 132 | // E IS A MACHINE DEPENDENT CONSTANT. E IS THE SMALLEST 133 | // NUMBER FOR WHICH 1.0 + E .GT. 1.0 . 134 | // 135 | constexpr scalar_t e = std::numeric_limits::epsilon(); 136 | if(a < 0.0e0 || x < 0.0e0) goto S430; 137 | if(a == 0.0e0 && x == 0.0e0) goto S430; 138 | if(a*x == 0.0e0) goto S420; 139 | iop = prec_ind+1; 140 | if(iop != 1 && iop != 2) iop = 3; 141 | acc = std::max(acc0[iop-1],e); 142 | e0 = e00[iop-1]; 143 | x0 = x00[iop-1]; 144 | // 145 | // SELECT THE APPROPRIATE ALGORITHM 146 | // 147 | if(a >= 1.0e0) goto S10; 148 | if(a == 0.5e0) goto S390; 149 | if(x < 1.1e0) goto S160; 150 | t1 = a*std::log(x)-x; 151 | u = a*std::exp(t1); 152 | if(u == 0.0e0) goto S380; 153 | r = u*(1.0e0+gam1(a)); 154 | goto S250; 155 | S10: 156 | if(a >= big[iop-1]) goto S30; 157 | if(a > x || x >= x0) goto S20; 158 | twoa = a+a; 159 | m = static_cast(fifidint(twoa)); 160 | if(twoa != static_cast(m)) goto S20; 161 | i = m/2; 162 | if(a == static_cast(i)) goto S210; 163 | goto S220; 164 | S20: 165 | t1 = a*std::log(x)-x; 166 | r = std::exp(t1)/ std::tgamma(a); 167 | goto S40; 168 | S30: 169 | l = x/ a; 170 | if(l == 0.0e0) goto S370; 171 | s = 0.5e0+(0.5e0-l); 172 | z = rlog(l); 173 | if(z >= 700.0e0/ a) goto S410; 174 | y = a*z; 175 | rta = std::sqrt(a); 176 | if(std::abs(s) <= e0/rta) goto S330; 177 | if(std::abs(s) <= 0.4e0) goto S270; 178 | t = std::pow(1.0e0/ a,2.0); 179 | t1 = (((0.75e0*t-1.0e0)*t+3.5e0)*t-105.0e0)/(a*1260.0e0); 180 | t1 -= y; 181 | r = rt2pin*rta*std::exp(t1); 182 | S40: 183 | if(r == 0.0e0) goto S420; 184 | if(x <= std::max(a,alog10)) goto S50; 185 | if(x < x0) goto S250; 186 | goto S100; 187 | S50: 188 | // 189 | // TAYLOR SERIES FOR P/R 190 | // 191 | apn = a+1.0e0; 192 | t = x/apn; 193 | wk[0] = t; 194 | for ( n = 2; n <= 20; n++ ) 195 | { 196 | apn += 1.0e0; 197 | t *= (x/apn); 198 | if(t <= 1.e-3) goto S70; 199 | wk[n-1] = t; 200 | } 201 | n = 20; 202 | S70: 203 | sum = t; 204 | tol = 0.5e0*acc; 205 | S80: 206 | apn += 1.0e0; 207 | t *= (x/apn); 208 | sum += t; 209 | if(t > tol) goto S80; 210 | max = n-1; 211 | for ( m = 1; m <= max; m++ ) 212 | { 213 | n -= 1; 214 | sum += wk[n-1]; 215 | } 216 | *ans = r/ a*(1.0e0+sum); 217 | *qans = 0.5e0+(0.5e0-*ans); 218 | return; 219 | S100: 220 | // 221 | // ASYMPTOTIC EXPANSION 222 | // 223 | amn = a-1.0e0; 224 | t = amn/ x; 225 | wk[0] = t; 226 | for ( n = 2; n <= 20; n++ ) 227 | { 228 | amn -= 1.0e0; 229 | t *= (amn/ x); 230 | if(std::abs(t) <= 1.e-3) goto S120; 231 | wk[n-1] = t; 232 | } 233 | n = 20; 234 | S120: 235 | 
sum = t; 236 | S130: 237 | if(std::abs(t) <= acc) goto S140; 238 | amn -= 1.0e0; 239 | t *= (amn/ x); 240 | sum += t; 241 | goto S130; 242 | S140: 243 | max = n-1; 244 | for ( m = 1; m <= max; m++ ) 245 | { 246 | n -= 1; 247 | sum += wk[n-1]; 248 | } 249 | *qans = r/ x*(1.0e0+sum); 250 | *ans = 0.5e0+(0.5e0-*qans); 251 | return; 252 | S160: 253 | // 254 | // TAYLOR SERIES FOR P(A,X)/X**A 255 | // 256 | an = 3.0e0; 257 | c = x; 258 | sum = x/(a+3.0e0); 259 | tol = 3.0e0*acc/(a+1.0e0); 260 | S170: 261 | an += 1.0e0; 262 | c = -(c*(x/an)); 263 | t = c/(a+an); 264 | sum += t; 265 | if(std::abs(t) > tol) goto S170; 266 | j = a*x*((sum/6.0e0-0.5e0/(a+2.0e0))*x+1.0e0/(a+1.0e0)); 267 | z = a*std::log(x); 268 | h = gam1(a); 269 | g = 1.0e0+h; 270 | if(x < 0.25e0) goto S180; 271 | if(a < x/2.59e0) goto S200; 272 | goto S190; 273 | S180: 274 | if(z > -.13394e0) goto S200; 275 | S190: 276 | w = std::exp(z); 277 | *ans = w*g*(0.5e0+(0.5e0-j)); 278 | *qans = 0.5e0+(0.5e0-*ans); 279 | return; 280 | S200: 281 | l = rexp(z); 282 | w = 0.5e0+(0.5e0+l); 283 | *qans = (w*j-l)*g-h; 284 | if(*qans < 0.0e0) goto S380; 285 | *ans = 0.5e0+(0.5e0-*qans); 286 | return; 287 | S210: 288 | // 289 | // FINITE SUMS FOR Q WHEN A .GE. 1 AND 2*A IS AN INTEGER 290 | // 291 | sum = std::exp(-x); 292 | t = sum; 293 | n = 1; 294 | c = 0.0e0; 295 | goto S230; 296 | S220: 297 | rtx = std::sqrt(x); 298 | sum = std::erfc ( rtx ); 299 | t = std::exp(-x)/(rtpi*rtx); 300 | n = 0; 301 | c = -0.5e0; 302 | S230: 303 | if(n == i) goto S240; 304 | n += 1; 305 | c += 1.0e0; 306 | t = x*t/c; 307 | sum += t; 308 | goto S230; 309 | S240: 310 | *qans = sum; 311 | *ans = 0.5e0+(0.5e0-*qans); 312 | return; 313 | S250: 314 | // 315 | // CONTINUED FRACTION EXPANSION 316 | // 317 | tol = std::max( static_cast(5.0e0) * e, acc); 318 | a2nm1 = a2n = 1.0e0; 319 | b2nm1 = x; 320 | b2n = x+(1.0e0-a); 321 | c = 1.0e0; 322 | S260: 323 | a2nm1 = x*a2n+c*a2nm1; 324 | b2nm1 = x*b2n+c*b2nm1; 325 | am0 = a2nm1/b2nm1; 326 | c += 1.0e0; 327 | cma = c-a; 328 | a2n = a2nm1+cma*a2n; 329 | b2n = b2nm1+cma*b2n; 330 | an0 = a2n/b2n; 331 | if(std::abs(an0-am0) >= tol*an0) goto S260; 332 | *qans = r*an0; 333 | *ans = 0.5e0+(0.5e0-*qans); 334 | return; 335 | S270: 336 | // 337 | // GENERAL TEMME EXPANSION 338 | // 339 | if(std::abs(s) <= 2.0e0*e && a*e*e > 3.28e-3) goto S430; 340 | c = std::exp(-y); 341 | T3 = std::sqrt(y); 342 | w = 0.5e0 * error_fc ( T3 ); 343 | u = 1.0e0/ a; 344 | z = std::sqrt(z+z); 345 | if(l < 1.0e0) z = -z; 346 | T4 = iop-2; 347 | if(T4 < 0) goto S280; 348 | else if(T4 == 0) goto S290; 349 | else goto S300; 350 | S280: 351 | if(std::abs(s) <= 1.e-3) goto S340; 352 | c0 = ((((((((((((d0[12]*z+d0[11])*z+d0[10])*z+d0[9])*z+d0[8])*z+d0[7])*z+d0[ 353 | 6])*z+d0[5])*z+d0[4])*z+d0[3])*z+d0[2])*z+d0[1])*z+d0[0])*z-third; 354 | c1 = (((((((((((d1[11]*z+d1[10])*z+d1[9])*z+d1[8])*z+d1[7])*z+d1[6])*z+d1[5] 355 | )*z+d1[4])*z+d1[3])*z+d1[2])*z+d1[1])*z+d1[0])*z+d10; 356 | c2 = (((((((((d2[9]*z+d2[8])*z+d2[7])*z+d2[6])*z+d2[5])*z+d2[4])*z+d2[3])*z+ 357 | d2[2])*z+d2[1])*z+d2[0])*z+d20; 358 | c3 = (((((((d3[7]*z+d3[6])*z+d3[5])*z+d3[4])*z+d3[3])*z+d3[2])*z+d3[1])*z+ 359 | d3[0])*z+d30; 360 | c4 = (((((d4[5]*z+d4[4])*z+d4[3])*z+d4[2])*z+d4[1])*z+d4[0])*z+d40; 361 | c5 = (((d5[3]*z+d5[2])*z+d5[1])*z+d5[0])*z+d50; 362 | c6 = (d6[1]*z+d6[0])*z+d60; 363 | t = ((((((d70*u+c6)*u+c5)*u+c4)*u+c3)*u+c2)*u+c1)*u+c0; 364 | goto S310; 365 | S290: 366 | c0 = (((((d0[5]*z+d0[4])*z+d0[3])*z+d0[2])*z+d0[1])*z+d0[0])*z-third; 367 | c1 = (((d1[3]*z+d1[2])*z+d1[1])*z+d1[0])*z+d10; 
368 | c2 = d2[0]*z+d20; 369 | t = (c2*u+c1)*u+c0; 370 | goto S310; 371 | S300: 372 | t = ((d0[2]*z+d0[1])*z+d0[0])*z-third; 373 | S310: 374 | if(l < 1.0e0) goto S320; 375 | *qans = c*(w+rt2pin*t/rta); 376 | *ans = 0.5e0+(0.5e0-*qans); 377 | return; 378 | S320: 379 | *ans = c*(w-rt2pin*t/rta); 380 | *qans = 0.5e0+(0.5e0-*ans); 381 | return; 382 | S330: 383 | // 384 | // TEMME EXPANSION FOR L = 1 385 | // 386 | if(a*e*e > 3.28e-3) goto S430; 387 | c = 0.5e0+(0.5e0-y); 388 | w = (0.5e0-std::sqrt(y)*(0.5e0+(0.5e0-y/3.0e0))/rtpi)/c; 389 | u = 1.0e0/ a; 390 | z = std::sqrt(z+z); 391 | if(l < 1.0e0) z = -z; 392 | T5 = iop-2; 393 | if(T5 < 0) goto S340; 394 | else if(T5 == 0) goto S350; 395 | else goto S360; 396 | S340: 397 | c0 = ((((((d0[6]*z+d0[5])*z+d0[4])*z+d0[3])*z+d0[2])*z+d0[1])*z+d0[0])*z- 398 | third; 399 | c1 = (((((d1[5]*z+d1[4])*z+d1[3])*z+d1[2])*z+d1[1])*z+d1[0])*z+d10; 400 | c2 = ((((d2[4]*z+d2[3])*z+d2[2])*z+d2[1])*z+d2[0])*z+d20; 401 | c3 = (((d3[3]*z+d3[2])*z+d3[1])*z+d3[0])*z+d30; 402 | c4 = (d4[1]*z+d4[0])*z+d40; 403 | c5 = (d5[1]*z+d5[0])*z+d50; 404 | c6 = d6[0]*z+d60; 405 | t = ((((((d70*u+c6)*u+c5)*u+c4)*u+c3)*u+c2)*u+c1)*u+c0; 406 | goto S310; 407 | S350: 408 | c0 = (d0[1]*z+d0[0])*z-third; 409 | c1 = d1[0]*z+d10; 410 | t = (d20*u+c1)*u+c0; 411 | goto S310; 412 | S360: 413 | t = d0[0]*z-third; 414 | goto S310; 415 | S370: 416 | // 417 | // SPECIAL CASES 418 | // 419 | *ans = 0.0e0; 420 | *qans = 1.0e0; 421 | return; 422 | S380: 423 | *ans = 1.0e0; 424 | *qans = 0.0e0; 425 | return; 426 | S390: 427 | if(x >= 0.25e0) goto S400; 428 | T6 = std::sqrt(x); 429 | *ans = std::erf ( T6 ); 430 | *qans = 0.5e0+(0.5e0-*ans); 431 | return; 432 | S400: 433 | T7 = std::sqrt(x); 434 | *qans = std::erfc ( T7 ); 435 | *ans = 0.5e0+(0.5e0-*qans); 436 | return; 437 | S410: 438 | if(std::abs(s) <= 2.0e0*e) goto S430; 439 | S420: 440 | if(x <= a) goto S370; 441 | goto S380; 442 | S430: 443 | // 444 | // ERROR RETURN 445 | // 446 | *ans = 2.0e0; 447 | return; 448 | } 449 | 450 | } // namespace cuda 451 | } // namespace cdflib 452 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/rexp.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdflib { 20 | namespace cuda { 21 | 22 | template 23 | __host__ __device__ __forceinline__ 24 | scalar_t rexp ( scalar_t x ) 25 | 26 | //****************************************************************************80 27 | // 28 | // Purpose: 29 | // 30 | // REXP evaluates the function EXP(X) - 1. 31 | // 32 | // Modified: 33 | // 34 | // 09 December 1999 35 | // 36 | // Parameters: 37 | // 38 | // Input, double *X, the argument of the function. 39 | // 40 | // Output, double REXP, the value of EXP(X)-1. 
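//
//  Discussion:
//
//    (Note added for clarity; not part of the original CDFLIB source.)
//    Computing std::exp(X) - 1 directly loses nearly all significant
//    digits when |X| is small, because exp(X) is then 1 + O(X).  The
//    branch below for |X| <= 0.15 therefore evaluates EXP(X) - 1 via a
//    rational approximation in X (the same idea as std::expm1); for
//    larger |X| it falls back to forming EXP(X) and subtracting 1.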
41 | // 42 | { 43 | static const scalar_t p1 = .914041914819518e-09; 44 | static const scalar_t p2 = .238082361044469e-01; 45 | static const scalar_t q1 = -.499999999085958e+00; 46 | static const scalar_t q2 = .107141568980644e+00; 47 | static const scalar_t q3 = -.119041179760821e-01; 48 | static const scalar_t q4 = .595130811860248e-03; 49 | scalar_t rexp,w; 50 | 51 | if(std::abs(x) > 0.15e0) goto S10; 52 | rexp = x*(((p2*x+p1)*x+1.0e0)/((((q4*x+q3)*x+q2)*x+q1)*x+1.0e0)); 53 | return rexp; 54 | S10: 55 | w = std::exp(x); 56 | if(x > 0.0e0) goto S20; 57 | rexp = w-0.5e0-0.5e0; 58 | return rexp; 59 | S20: 60 | rexp = w*(0.5e0+(0.5e0-1.0e0/w)); 61 | return rexp; 62 | } 63 | 64 | } // namespace cuda 65 | } // namespace cdflib 66 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cdflib/rlog.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdflib { 20 | namespace cuda { 21 | 22 | template 23 | __host__ __device__ __forceinline__ 24 | scalar_t rlog ( scalar_t x ) 25 | 26 | //****************************************************************************80 27 | // 28 | // Purpose: 29 | // 30 | // RLOG computes X - 1 - LN(X). 31 | // 32 | // Modified: 33 | // 34 | // 09 December 1999 35 | // 36 | // Parameters: 37 | // 38 | // Input, double *X, the argument of the function. 39 | // 40 | // Output, double RLOG, the value of the function. 41 | // 42 | { 43 | static const scalar_t a = .566749439387324e-01; 44 | static const scalar_t b = .456512608815524e-01; 45 | static const scalar_t p0 = .333333333333333e+00; 46 | static const scalar_t p1 = -.224696413112536e+00; 47 | static const scalar_t p2 = .620886815375787e-02; 48 | static const scalar_t q1 = -.127408923933623e+01; 49 | static const scalar_t q2 = .354508718369557e+00; 50 | scalar_t rlog,r,t,u,w,w1; 51 | 52 | if(x < 0.61e0 || x > 1.57e0) goto S40; 53 | if(x < 0.82e0) goto S10; 54 | if(x > 1.18e0) goto S20; 55 | // 56 | // ARGUMENT REDUCTION 57 | // 58 | u = x-0.5e0-0.5e0; 59 | w1 = 0.0e0; 60 | goto S30; 61 | S10: 62 | u = x-0.7e0; 63 | u /= 0.7e0; 64 | w1 = a-u*0.3e0; 65 | goto S30; 66 | S20: 67 | u = 0.75e0*x-1.e0; 68 | w1 = b+u/3.0e0; 69 | S30: 70 | // 71 | // SERIES EXPANSION 72 | // 73 | r = u/(u+2.0e0); 74 | t = r*r; 75 | w = ((p2*t+p1)*t+p0)/((q2*t+q1)*t+1.0e0); 76 | rlog = 2.0e0*t*(1.0e0/(1.0e0-r)-r*w)+w1; 77 | return rlog; 78 | S40: 79 | r = x-0.5e0-0.5e0; 80 | rlog = r-std::log(x); 81 | return rlog; 82 | } 83 | //****************************************************************************80 84 | 85 | } // namespace cuda 86 | } // namespace cdflib 87 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cephes/chbevl.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/chbevl.c 6 | */ 7 | 8 | /* chbevl.c 9 | * 10 | * Evaluate Chebyshev series 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * int N; 17 | * double x, y, coef[N], chebevl(); 18 | * 19 | * y = chbevl( x, coef, N ); 20 | * 21 | * 22 | * 23 | * DESCRIPTION: 24 | * 25 | * Evaluates the series 26 | * 27 | * N-1 28 | * - ' 29 | * y = 
> coef[i] T (x/2) 30 | * - i 31 | * i=0 32 | * 33 | * of Chebyshev polynomials Ti at argument x/2. 34 | * 35 | * Coefficients are stored in reverse order, i.e. the zero 36 | * order term is last in the array. Note N is the number of 37 | * coefficients, not the order. 38 | * 39 | * If coefficients are for the interval a to b, x must 40 | * have been transformed to x -> 2(2x - b - a)/(b-a) before 41 | * entering the routine. This maps x from (a, b) to (-1, 1), 42 | * over which the Chebyshev polynomials are defined. 43 | * 44 | * If the coefficients are for the inverted interval, in 45 | * which (a, b) is mapped to (1/b, 1/a), the transformation 46 | * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, 47 | * this becomes x -> 4a/x - 1. 48 | * 49 | * 50 | * 51 | * SPEED: 52 | * 53 | * Taking advantage of the recurrence properties of the 54 | * Chebyshev polynomials, the routine requires one more 55 | * addition per loop than evaluating a nested polynomial of 56 | * the same degree. 57 | * 58 | */ 59 | /* chbevl.c */ 60 | 61 | /* 62 | * Cephes Math Library Release 2.0: April, 1987 63 | * Copyright 1985, 1987 by Stephen L. Moshier 64 | * Direct inquiries to 30 Frost Street, Cambridge, MA 02140 65 | */ 66 | 67 | 68 | #include 69 | 70 | namespace cephes { 71 | namespace cuda { 72 | 73 | template 74 | static __host__ __device__ __forceinline__ 75 | typename std::enable_if::value, scalar_t>::type 76 | chbevl(const scalar_t _x, const scalar_t array[], const size_t len) 77 | { 78 | using accscalar_t = at::acc_type; 79 | 80 | accscalar_t x = static_cast(_x); 81 | accscalar_t b0, b1, b2; 82 | 83 | b0 = static_cast(array[0]); 84 | b1 = 0.0; 85 | 86 | for (size_t i = 1; i < len; ++i) { 87 | b2 = b1; 88 | b1 = b0; 89 | b0 = x * b1 - b2 + static_cast(array[i]);; 90 | } 91 | 92 | return static_cast(0.5 * (b0 - b2)); 93 | } 94 | 95 | } // namespace cuda 96 | } // namespace cephes 97 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cephes/i0.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/i0.c 6 | */ 7 | 8 | /* i0.c 9 | * 10 | * Modified Bessel function of order zero 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, i0(); 17 | * 18 | * y = i0( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns modified Bessel function of order zero of the 25 | * argument. 26 | * 27 | * The function is defined as i0(x) = j0( ix ). 28 | * 29 | * The range is partitioned into the two intervals [0,8] and 30 | * (8, infinity). Chebyshev polynomial expansions are employed 31 | * in each interval. 32 | * 33 | * 34 | * 35 | * ACCURACY: 36 | * 37 | * Relative error: 38 | * arithmetic domain # trials peak rms 39 | * IEEE 0,30 30000 5.8e-16 1.4e-16 40 | * 41 | */ 42 | /* i0e.c 43 | * 44 | * Modified Bessel function of order zero, 45 | * exponentially scaled 46 | * 47 | * 48 | * 49 | * SYNOPSIS: 50 | * 51 | * double x, y, i0e(); 52 | * 53 | * y = i0e( x ); 54 | * 55 | * 56 | * 57 | * DESCRIPTION: 58 | * 59 | * Returns exponentially scaled modified Bessel function 60 | * of order zero of the argument. 61 | * 62 | * The function is defined as i0e(x) = exp(-|x|) j0( ix ). 63 | * 64 | * 65 | * 66 | * ACCURACY: 67 | * 68 | * Relative error: 69 | * arithmetic domain # trials peak rms 70 | * IEEE 0,30 30000 5.4e-16 1.2e-16 71 | * See i0(). 
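 *
 * NOTE (added; not in the Cephes original):
 *
 * The exponentially scaled form matters because i0(x) grows like
 * exp(x)/sqrt(2*pi*x) and overflows for moderately large arguments.
 * The Poisson-comparison backward kernels in this extension therefore
 * pair i0e with an explicit exponential, e.g.
 *
 *     exp(2*sqrt(mu1*mu2) - mu1 - mu2) * i0e(2*sqrt(mu1*mu2)),
 *
 * which stays bounded (by AM-GM the exponent is <= 0), whereas the
 * equivalent exp(-mu1-mu2) * i0(2*sqrt(mu1*mu2)) can overflow in its
 * i0 factor.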
72 | * 73 | */ 74 | 75 | /* i0.c */ 76 | 77 | 78 | /* 79 | * Cephes Math Library Release 2.8: June, 2000 80 | * Copyright 1984, 1987, 2000 by Stephen L. Moshier 81 | */ 82 | 83 | #include "chbevl.cuh" 84 | 85 | #include 86 | #include 87 | 88 | namespace cephes { 89 | namespace cuda { 90 | 91 | namespace { 92 | 93 | using std::exp; 94 | using std::sqrt; 95 | 96 | template 97 | static __host__ __device__ __forceinline__ 98 | scalar_t calc_i0(scalar_t x) { 99 | 100 | /* Chebyshev coefficients for exp(-x) I0(x) 101 | * in the interval [0,8]. 102 | * 103 | * lim(x->0){ exp(-x) I0(x) } = 1. 104 | */ 105 | static const scalar_t A[] = { 106 | -4.41534164647933937950E-18, 107 | 3.33079451882223809783E-17, 108 | -2.43127984654795469359E-16, 109 | 1.71539128555513303061E-15, 110 | -1.16853328779934516808E-14, 111 | 7.67618549860493561688E-14, 112 | -4.85644678311192946090E-13, 113 | 2.95505266312963983461E-12, 114 | -1.72682629144155570723E-11, 115 | 9.67580903537323691224E-11, 116 | -5.18979560163526290666E-10, 117 | 2.65982372468238665035E-9, 118 | -1.30002500998624804212E-8, 119 | 6.04699502254191894932E-8, 120 | -2.67079385394061173391E-7, 121 | 1.11738753912010371815E-6, 122 | -4.41673835845875056359E-6, 123 | 1.64484480707288970893E-5, 124 | -5.75419501008210370398E-5, 125 | 1.88502885095841655729E-4, 126 | -5.76375574538582365885E-4, 127 | 1.63947561694133579842E-3, 128 | -4.32430999505057594430E-3, 129 | 1.05464603945949983183E-2, 130 | -2.37374148058994688156E-2, 131 | 4.93052842396707084878E-2, 132 | -9.49010970480476444210E-2, 133 | 1.71620901522208775349E-1, 134 | -3.04682672343198398683E-1, 135 | 6.76795274409476084995E-1 136 | }; 137 | 138 | /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) 139 | * in the inverted interval [8,infinity]. 140 | * 141 | * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
142 | */ 143 | static const scalar_t B[] = { 144 | -7.23318048787475395456E-18, 145 | -4.83050448594418207126E-18, 146 | 4.46562142029675999901E-17, 147 | 3.46122286769746109310E-17, 148 | -2.82762398051658348494E-16, 149 | -3.42548561967721913462E-16, 150 | 1.77256013305652638360E-15, 151 | 3.81168066935262242075E-15, 152 | -9.55484669882830764870E-15, 153 | -4.15056934728722208663E-14, 154 | 1.54008621752140982691E-14, 155 | 3.85277838274214270114E-13, 156 | 7.18012445138366623367E-13, 157 | -1.79417853150680611778E-12, 158 | -1.32158118404477131188E-11, 159 | -3.14991652796324136454E-11, 160 | 1.18891471078464383424E-11, 161 | 4.94060238822496958910E-10, 162 | 3.39623202570838634515E-9, 163 | 2.26666899049817806459E-8, 164 | 2.04891858946906374183E-7, 165 | 2.89137052083475648297E-6, 166 | 6.88975834691682398426E-5, 167 | 3.36911647825569408990E-3, 168 | 8.04490411014108831608E-1 169 | }; 170 | 171 | 172 | scalar_t y; 173 | 174 | if (x < 0) 175 | x = -x; 176 | 177 | if (x <= 8.0) { 178 | y = (x / 2.0) - 2.0; 179 | if (e) { 180 | return chbevl(y, A, 30); 181 | } else { 182 | return (exp(x) * chbevl(y, A, 30)); 183 | } 184 | } 185 | 186 | if (e) { 187 | return (chbevl(32.0 / x - 2.0, B, 25) / sqrt(x)); 188 | } else { 189 | return (exp(x) * chbevl(32.0 / x - 2.0, B, 25) / sqrt(x)); 190 | } 191 | } 192 | 193 | } 194 | 195 | template 196 | static __host__ __device__ __forceinline__ 197 | scalar_t i0(scalar_t x) { 198 | return calc_i0(x); 199 | } 200 | 201 | template 202 | static __host__ __device__ __forceinline__ 203 | scalar_t i0e(scalar_t x) { 204 | return calc_i0(x); 205 | } 206 | 207 | } // namespace cuda 208 | } // namespace cephes 209 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cephes/i1.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/5caee5d4ad564cfae4596f8dfa8b45997767035b/scipy/special/cephes/i1.c 6 | */ 7 | 8 | /* i1.c 9 | * 10 | * Modified Bessel function of order one 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, i1(); 17 | * 18 | * y = i1( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns modified Bessel function of order one of the 25 | * argument. 26 | * 27 | * The function is defined as i1(x) = -i j1( ix ). 28 | * 29 | * The range is partitioned into the two intervals [0,8] and 30 | * (8, infinity). Chebyshev polynomial expansions are employed 31 | * in each interval. 32 | * 33 | * 34 | * 35 | * ACCURACY: 36 | * 37 | * Relative error: 38 | * arithmetic domain # trials peak rms 39 | * IEEE 0, 30 30000 1.9e-15 2.1e-16 40 | * 41 | * 42 | */ 43 | /* i1e.c 44 | * 45 | * Modified Bessel function of order one, 46 | * exponentially scaled 47 | * 48 | * 49 | * 50 | * SYNOPSIS: 51 | * 52 | * double x, y, i1e(); 53 | * 54 | * y = i1e( x ); 55 | * 56 | * 57 | * 58 | * DESCRIPTION: 59 | * 60 | * Returns exponentially scaled modified Bessel function 61 | * of order one of the argument. 62 | * 63 | * The function is defined as i1(x) = -i exp(-|x|) j1( ix ). 64 | * 65 | * 66 | * 67 | * ACCURACY: 68 | * 69 | * Relative error: 70 | * arithmetic domain # trials peak rms 71 | * IEEE 0, 30 30000 2.0e-15 2.0e-16 72 | * See i1(). 73 | * 74 | */ 75 | 76 | /* i1.c 2 */ 77 | 78 | 79 | /* 80 | * Cephes Math Library Release 2.8: June, 2000 81 | * Copyright 1985, 1987, 2000 by Stephen L. 
Moshier 82 | */ 83 | 84 | #include "chbevl.cuh" 85 | #include 86 | 87 | 88 | namespace cephes { 89 | namespace cuda { 90 | 91 | namespace { 92 | 93 | using std::abs; 94 | using std::exp; 95 | using std::sqrt; 96 | 97 | 98 | template 99 | static __host__ __device__ __forceinline__ 100 | scalar_t calc_i1(scalar_t x) { 101 | /* Chebyshev coefficients for exp(-x) I1(x) / x 102 | * in the interval [0,8]. 103 | * 104 | * lim(x->0){ exp(-x) I1(x) / x } = 1/2. 105 | */ 106 | 107 | static scalar_t A[] = { 108 | 2.77791411276104639959E-18, 109 | -2.11142121435816608115E-17, 110 | 1.55363195773620046921E-16, 111 | -1.10559694773538630805E-15, 112 | 7.60068429473540693410E-15, 113 | -5.04218550472791168711E-14, 114 | 3.22379336594557470981E-13, 115 | -1.98397439776494371520E-12, 116 | 1.17361862988909016308E-11, 117 | -6.66348972350202774223E-11, 118 | 3.62559028155211703701E-10, 119 | -1.88724975172282928790E-9, 120 | 9.38153738649577178388E-9, 121 | -4.44505912879632808065E-8, 122 | 2.00329475355213526229E-7, 123 | -8.56872026469545474066E-7, 124 | 3.47025130813767847674E-6, 125 | -1.32731636560394358279E-5, 126 | 4.78156510755005422638E-5, 127 | -1.61760815825896745588E-4, 128 | 5.12285956168575772895E-4, 129 | -1.51357245063125314899E-3, 130 | 4.15642294431288815669E-3, 131 | -1.05640848946261981558E-2, 132 | 2.47264490306265168283E-2, 133 | -5.29459812080949914269E-2, 134 | 1.02643658689847095384E-1, 135 | -1.76416518357834055153E-1, 136 | 2.52587186443633654823E-1 137 | }; 138 | 139 | /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) 140 | * in the inverted interval [8,infinity]. 141 | * 142 | * lim(x->inf){ exp(-x) sqrt(x) I1(x) } = 1/sqrt(2pi). 143 | */ 144 | static scalar_t B[] = { 145 | 7.51729631084210481353E-18, 146 | 4.41434832307170791151E-18, 147 | -4.65030536848935832153E-17, 148 | -3.20952592199342395980E-17, 149 | 2.96262899764595013876E-16, 150 | 3.30820231092092828324E-16, 151 | -1.88035477551078244854E-15, 152 | -3.81440307243700780478E-15, 153 | 1.04202769841288027642E-14, 154 | 4.27244001671195135429E-14, 155 | -2.10154184277266431302E-14, 156 | -4.08355111109219731823E-13, 157 | -7.19855177624590851209E-13, 158 | 2.03562854414708950722E-12, 159 | 1.41258074366137813316E-11, 160 | 3.25260358301548823856E-11, 161 | -1.89749581235054123450E-11, 162 | -5.58974346219658380687E-10, 163 | -3.83538038596423702205E-9, 164 | -2.63146884688951950684E-8, 165 | -2.51223623787020892529E-7, 166 | -3.88256480887769039346E-6, 167 | -1.10588938762623716291E-4, 168 | -9.76109749136146840777E-3, 169 | 7.78576235018280120474E-1 170 | }; 171 | 172 | 173 | scalar_t y, z; 174 | 175 | z = abs(x); 176 | if (z <= 8.0) { 177 | y = (z / 2.0) - 2.0; 178 | if (e) { 179 | z = chbevl(y, A, 29) * z; 180 | } else { 181 | z = chbevl(y, A, 29) * z * exp(z); 182 | } 183 | } 184 | else { 185 | if (e) { 186 | z = chbevl(32.0 / z - 2.0, B, 25) / sqrt(z); 187 | } else { 188 | z = exp(z) * chbevl(32.0 / z - 2.0, B, 25) / sqrt(z); 189 | } 190 | } 191 | if (x < 0.0) { 192 | z = -z; 193 | } 194 | return (z); 195 | 196 | } 197 | 198 | } 199 | 200 | 201 | template 202 | static __host__ __device__ __forceinline__ 203 | scalar_t i1(scalar_t x) { 204 | return calc_i1(x); 205 | } 206 | 207 | template 208 | static __host__ __device__ __forceinline__ 209 | scalar_t i1e(scalar_t x) { 210 | return calc_i1(x); 211 | } 212 | 213 | } // namespace cuda 214 | } // namespace cephes 215 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/cephes/ndtr.cuh: 
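The file below is a port of the SciPy/Cephes ndtr routines (its header cites the SciPy source), so scipy.special can serve as a convenient oracle when testing the device functions. A minimal cross-check sketch, assuming NumPy and SciPy are available:

import numpy as np
from scipy.special import ndtr, log_ndtr  # same Cephes-derived algorithms

x = np.array([-30.0, -1.0, 0.0, 1.0, 6.5])
print(ndtr(x))      # standard normal CDF Phi(x)
print(log_ndtr(x))  # log Phi(x); remains finite even where Phi(x) underflows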
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | /* 4 | * From 5 | * https://github.com/scipy/scipy/blob/7c7a5f8393e7b16e5bc81c739c84fe2e639c367f/scipy/special/cephes/ndtr.c 6 | */ 7 | 8 | /* ndtr.c 9 | * 10 | * Normal distribution function 11 | * 12 | * 13 | * 14 | * SYNOPSIS: 15 | * 16 | * double x, y, ndtr(); 17 | * 18 | * y = ndtr( x ); 19 | * 20 | * 21 | * 22 | * DESCRIPTION: 23 | * 24 | * Returns the area under the Gaussian probability density 25 | * function, integrated from minus infinity to x: 26 | * 27 | * x 28 | * - 29 | * 1 | | 2 30 | * ndtr(x) = --------- | exp( - t /2 ) dt 31 | * sqrt(2pi) | | 32 | * - 33 | * -inf. 34 | * 35 | * = ( 1 + erf(z) ) / 2 36 | * = erfc(z) / 2 37 | * 38 | * where z = x/sqrt(2). Computation is via the functions 39 | * erf and erfc. 40 | * 41 | * 42 | * ACCURACY: 43 | * 44 | * Relative error: 45 | * arithmetic domain # trials peak rms 46 | * IEEE -13,0 30000 3.4e-14 6.7e-15 47 | * 48 | * 49 | * ERROR MESSAGES: 50 | * 51 | * message condition value returned 52 | * erfc underflow x > 37.519379347 0.0 53 | * 54 | */ 55 | /* erf.c 56 | * 57 | * Error function 58 | * 59 | * 60 | * 61 | * SYNOPSIS: 62 | * 63 | * double x, y, erf(); 64 | * 65 | * y = erf( x ); 66 | * 67 | * 68 | * 69 | * DESCRIPTION: 70 | * 71 | * The integral is 72 | * 73 | * x 74 | * - 75 | * 2 | | 2 76 | * erf(x) = -------- | exp( - t ) dt. 77 | * sqrt(pi) | | 78 | * - 79 | * 0 80 | * 81 | * For 0 <= |x| < 1, erf(x) = x * P4(x**2)/Q5(x**2); otherwise 82 | * erf(x) = 1 - erfc(x). 83 | * 84 | * 85 | * 86 | * ACCURACY: 87 | * 88 | * Relative error: 89 | * arithmetic domain # trials peak rms 90 | * IEEE 0,1 30000 3.7e-16 1.0e-16 91 | * 92 | */ 93 | /* erfc.c 94 | * 95 | * Complementary error function 96 | * 97 | * 98 | * 99 | * SYNOPSIS: 100 | * 101 | * double x, y, erfc(); 102 | * 103 | * y = erfc( x ); 104 | * 105 | * 106 | * 107 | * DESCRIPTION: 108 | * 109 | * 110 | * 1 - erf(x) = 111 | * 112 | * inf. 113 | * - 114 | * 2 | | 2 115 | * erfc(x) = -------- | exp( - t ) dt 116 | * sqrt(pi) | | 117 | * - 118 | * x 119 | * 120 | * 121 | * For small x, erfc(x) = 1 - erf(x); otherwise rational 122 | * approximations are computed. 123 | * 124 | * 125 | * 126 | * ACCURACY: 127 | * 128 | * Relative error: 129 | * arithmetic domain # trials peak rms 130 | * IEEE 0,26.6417 30000 5.7e-14 1.5e-14 131 | */ 132 | 133 | 134 | /* 135 | * Cephes Math Library Release 2.2: June, 1992 136 | * Copyright 1984, 1987, 1988, 1992 by Stephen L. 
Moshier 137 | * Direct inquiries to 30 Frost Street, Cambridge, MA 02140 138 | */ 139 | 140 | // #include "polevl.h" 141 | 142 | #ifndef _USE_MATH_DEFINES 143 | #define _USE_MATH_DEFINES 144 | #endif 145 | #include 146 | #include 147 | #include 148 | 149 | namespace cephes { namespace cuda { 150 | 151 | 152 | 153 | template 154 | static __host__ __device__ __forceinline__ 155 | scalar_t ndtr(scalar_t a) 156 | { 157 | scalar_t x, y, z; 158 | 159 | if (std::isnan(a)) { 160 | // throw std::runtime_error("ndtr sees NaN") 161 | return NAN; 162 | } 163 | 164 | // std::sqrt(0.5); std::sqrt is not constexpr 165 | constexpr scalar_t SQRT1_2 = 0.707106781186547524400844362104849039284835937688474036588339868995366239231053519425193767163820786367507; 166 | 167 | x = a * SQRT1_2; 168 | z = std::fabs(x); 169 | 170 | if (z < SQRT1_2) { 171 | y = 0.5 + 0.5 * std::erf(x); 172 | 173 | } else { 174 | y = 0.5 * std::erfc(z); 175 | 176 | if (x > 0) { 177 | y = 1.0 - y; 178 | } 179 | } 180 | 181 | return (y); 182 | } 183 | 184 | 185 | // template 186 | // static inline scalar_t erfc(scalar_t a) 187 | // { 188 | // scalar_t p, q, x, y, z; 189 | 190 | // if (std::isnan(a)) { 191 | // // throw std::runtime_error("erfc sees NaN") 192 | // return NAN; 193 | // } 194 | 195 | // if (a < 0.0) 196 | // x = -a; 197 | // else 198 | // x = a; 199 | 200 | // if (x < 1.0) 201 | // return (1.0 - erf(a)); 202 | 203 | // z = -a * a; 204 | 205 | // if (z < -MAXLOG) { 206 | // under: 207 | // // sf_error("erfc", SF_ERROR_UNDERFLOW, NULL); 208 | // if (a < 0) 209 | // return (2.0); 210 | // else 211 | // return (0.0); 212 | // } 213 | 214 | // z = std::exp(z); 215 | 216 | // if (x < 8.0) { 217 | 218 | // static const scalar_t P[] = { 219 | // 2.46196981473530512524E-10, 220 | // 5.64189564831068821977E-1, 221 | // 7.46321056442269912687E0, 222 | // 4.86371970985681366614E1, 223 | // 1.96520832956077098242E2, 224 | // 5.26445194995477358631E2, 225 | // 9.34528527171957607540E2, 226 | // 1.02755188689515710272E3, 227 | // 5.57535335369399327526E2 228 | // }; 229 | 230 | // static const scalar_t Q[] = { 231 | // /* 1.00000000000000000000E0, */ 232 | // 1.32281951154744992508E1, 233 | // 8.67072140885989742329E1, 234 | // 3.54937778887819891062E2, 235 | // 9.75708501743205489753E2, 236 | // 1.82390916687909736289E3, 237 | // 2.24633760818710981792E3, 238 | // 1.65666309194161350182E3, 239 | // 5.57535340817727675546E2 240 | // }; 241 | 242 | // p = polevl(x, P, 8); 243 | // q = p1evl(x, Q, 8); 244 | // } 245 | // else { 246 | 247 | // static const scalar_t R[] = { 248 | // 5.64189583547755073984E-1, 249 | // 1.27536670759978104416E0, 250 | // 5.01905042251180477414E0, 251 | // 6.16021097993053585195E0, 252 | // 7.40974269950448939160E0, 253 | // 2.97886665372100240670E0 254 | // }; 255 | 256 | // static const scalar_t S[] = { 257 | // /* 1.00000000000000000000E0, */ 258 | // 2.26052863220117276590E0, 259 | // 9.39603524938001434673E0, 260 | // 1.20489539808096656605E1, 261 | // 1.70814450747565897222E1, 262 | // 9.60896809063285878198E0, 263 | // 3.36907645100081516050E0 264 | // }; 265 | 266 | // p = polevl(x, R, 5); 267 | // q = p1evl(x, S, 6); 268 | // } 269 | // y = (z * p) / q; 270 | 271 | // if (a < 0) 272 | // y = 2.0 - y; 273 | 274 | // if (y == 0.0) 275 | // goto under; 276 | 277 | // return (y); 278 | // } 279 | 280 | 281 | // template 282 | // static inline scalar_t erf(scalar_t x) 283 | // { 284 | // scalar_t y, z; 285 | 286 | // if (std::isnan(x)) { 287 | // // throw std::runtime_error("erf sees NaN") 288 | // return 
NAN; 289 | // } 290 | 291 | // if (x < 0.0) { 292 | // return -erf(-x); 293 | // } 294 | 295 | // if (std::fabs(x) > 1.0) 296 | // return (1.0 - erfc(x)); 297 | 298 | // z = x * x; 299 | 300 | // static const scalar_t T[] = { 301 | // 9.60497373987051638749E0, 302 | // 9.00260197203842689217E1, 303 | // 2.23200534594684319226E3, 304 | // 7.00332514112805075473E3, 305 | // 5.55923013010394962768E4 306 | // }; 307 | 308 | // static const scalar_t U[] = { 309 | // /* 1.00000000000000000000E0, */ 310 | // 3.35617141647503099647E1, 311 | // 5.21357949780152679795E2, 312 | // 4.59432382970980127987E3, 313 | // 2.26290000613890934246E4, 314 | // 4.92673942608635921086E4 315 | // }; 316 | 317 | // y = x * polevl(z, T, 4) / p1evl(z, U, 5); 318 | // return (y); 319 | 320 | // } 321 | 322 | /* 323 | * double log_ndtr(double a) 324 | * 325 | * For a > -20, use the existing ndtr technique and take a log. 326 | * for a <= -20, we use the Taylor series approximation of erf to compute 327 | * the log CDF directly. The Taylor series consists of two parts which we will name "left" 328 | * and "right" accordingly. The right part involves a summation which we compute until the 329 | * difference in terms falls below the machine-specific EPSILON. 330 | * 331 | * \Phi(z) &=& 332 | * \frac{e^{-z^2/2}}{-z\sqrt{2\pi}} * [1 + \sum_{n=1}^{N-1} (-1)^n \frac{(2n-1)!!}{(z^2)^n}] 333 | * + O(z^{-2N+2}) 334 | * = [\mbox{LHS}] * [\mbox{RHS}] + \mbox{error}. 335 | * 336 | */ 337 | 338 | template 339 | static __host__ __device__ __forceinline__ 340 | scalar_t log_ndtr(scalar_t a) 341 | { 342 | 343 | if (a > 6) { 344 | return -ndtr(-a); /* log(1+x) \approx x */ 345 | } 346 | if (a > -20) { 347 | return std::log(ndtr(a)); 348 | } 349 | 350 | scalar_t log_LHS, /* we compute the left hand side of the approx (LHS) in one shot */ 351 | last_total = 0, /* variable used to check for convergence */ 352 | right_hand_side = 1, /* includes first term from the RHS summation */ 353 | numerator = 1, /* numerator for RHS summand */ 354 | denom_factor = 1, /* use reciprocal for denominator to avoid division */ 355 | denom_cons = 1.0 / (a * a); /* the precomputed division we use to adjust the denominator */ 356 | long sign = 1, i = 0; 357 | 358 | log_LHS = -0.5 * a * a - std::log(-a) - 0.5 * std::log(2 * M_PI); 359 | 360 | while (std::fabs(last_total - right_hand_side) > std::numeric_limits::epsilon()) { 361 | i += 1; 362 | last_total = right_hand_side; 363 | sign = -sign; 364 | denom_factor *= denom_cons; 365 | numerator *= 2 * i - 1; 366 | right_hand_side += sign * numerator * denom_factor; 367 | } 368 | return log_LHS + std::log(right_hand_side); 369 | } 370 | 371 | 372 | template 373 | static __host__ __device__ __forceinline__ 374 | std::pair ndtr_log_ndtr(scalar_t a) 375 | { 376 | 377 | if (a > 6) { 378 | /* log(1+x) \approx x */ 379 | scalar_t x = -ndtr(-a); /* x = -ndtr_remain_val */ 380 | return std::make_pair(1 + x, x); 381 | } 382 | if (a > -20) { 383 | scalar_t ndtr_val = ndtr(a); 384 | return std::make_pair(ndtr_val, std::log(ndtr_val)); 385 | } 386 | 387 | scalar_t log_LHS, /* we compute the left hand side of the approx (LHS) in one shot */ 388 | last_total = 0, /* variable used to check for convergence */ 389 | right_hand_side = 1, /* includes first term from the RHS summation */ 390 | numerator = 1, /* numerator for RHS summand */ 391 | denom_factor = 1, /* use reciprocal for denominator to avoid division */ 392 | denom_cons = 1.0 / (a * a); /* the precomputed division we use to adjust the denominator */ 393 | long sign = 1, 
i = 0; 394 | 395 | log_LHS = -0.5 * a * a - std::log(-a) - 0.5 * std::log(2 * M_PI); 396 | 397 | while (std::fabs(last_total - right_hand_side) > std::numeric_limits::epsilon()) { 398 | i += 1; 399 | last_total = right_hand_side; 400 | sign = -sign; 401 | denom_factor *= denom_cons; 402 | numerator *= 2 * i - 1; 403 | right_hand_side += sign * numerator * denom_factor; 404 | } 405 | return std::make_pair(ndtr(a), log_LHS + std::log(right_hand_side)); 406 | } 407 | 408 | } // namespace cuda 409 | } // namespace cephes 410 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels.cuh: -------------------------------------------------------------------------------- 1 | #include "../cdf_ops.h" 2 | 3 | #include 4 | #include 5 | 6 | namespace cdf_ops { 7 | namespace cuda { 8 | 9 | void chndtr_kernel_cuda(at::TensorIterator& iter); 10 | 11 | void chndtr_scalar_double_kernel_cuda(at::TensorIterator& iter, double df); 12 | void chndtr_scalar_scalar_t_kernel_cuda(at::TensorIterator& iter, double df); 13 | void chndtr_scalar_kernel_cuda(at::TensorIterator& iter, double df); 14 | 15 | void i0e_kernel_cuda(at::TensorIterator& iter); 16 | void i1_kernel_cuda(at::TensorIterator& iter); 17 | void i1e_kernel_cuda(at::TensorIterator& iter); 18 | 19 | template 20 | void prob_two_poisson_kernel_cuda(at::TensorIterator& iter); 21 | template 22 | void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter); 23 | template 24 | void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 25 | 26 | void ndtr_kernel_cuda(at::TensorIterator& iter); 27 | void log_ndtr_kernel_cuda(at::TensorIterator& iter); 28 | void ndtr_log_ndtr_kernel_cuda(at::TensorIterator& iter); 29 | void ndtr_backward_kernel_cuda(at::TensorIterator& iter); 30 | void log_ndtr_backward_kernel_cuda(at::TensorIterator& iter); 31 | void log_ndtr_backward_with_ndtr_kernel_cuda(at::TensorIterator& iter); 32 | 33 | } // namespace cdf_ops 34 | } // namespace cuda 35 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_bessel_i.cu: -------------------------------------------------------------------------------- 1 | #include "cephes/i0.cuh" 2 | #include "cephes/i1.cuh" 3 | #include "kernels.cuh" 4 | 5 | #ifndef _USE_MATH_DEFINES 6 | #define _USE_MATH_DEFINES 7 | #endif 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | 18 | namespace cdf_ops { 19 | namespace cuda { 20 | 21 | 22 | void i0e_kernel_cuda(at::TensorIterator& iter) { 23 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 24 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "i0e_cuda", [&]() { 25 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> scalar_t { 26 | return cephes::cuda::i0e(x); 27 | }); 28 | }); 29 | } 30 | 31 | 32 | void i1_kernel_cuda(at::TensorIterator& iter) { 33 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 34 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "i1_cuda", [&]() { 35 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> scalar_t { 36 | return cephes::cuda::i1(x); 37 | }); 38 | }); 39 | } 40 | 41 | 42 | void i1e_kernel_cuda(at::TensorIterator& iter) { 43 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 44 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "i1e_cuda", [&]() { 45 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> scalar_t { 46 | return cephes::cuda::i1e(x); 47 | }); 48 | }); 49 | } 50 | 51 | } // 
namespace cuda 52 | } // namespace cdf_ops 53 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_chndtr_scalar.cu: -------------------------------------------------------------------------------- 1 | #include "cdflib/chndtr.cuh" 2 | #include "kernels.cuh" 3 | 4 | namespace cdf_ops { 5 | namespace cuda { 6 | 7 | 8 | void chndtr_scalar_kernel_cuda(at::TensorIterator& iter, double df) { 9 | // Do dispatch our selves :) 10 | if (df > cdflib::cuda::chndtr_double_thresh) { 11 | chndtr_scalar_double_kernel_cuda(iter, df); 12 | } else { 13 | chndtr_scalar_scalar_t_kernel_cuda(iter, df); 14 | } 15 | } 16 | 17 | } // namespace cuda 18 | } // namespace cdf_ops 19 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_chndtr_scalar_double.cu: -------------------------------------------------------------------------------- 1 | #include "cdflib/chndtr.cuh" 2 | #include "kernels.cuh" 3 | 4 | #ifndef _USE_MATH_DEFINES 5 | #define _USE_MATH_DEFINES 6 | #endif 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | namespace cdf_ops { 18 | namespace cuda { 19 | 20 | void chndtr_scalar_double_kernel_cuda(at::TensorIterator& iter, double df) { 21 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 22 | AT_ASSERT (df > cdflib::cuda::chndtr_double_thresh); 23 | // double 24 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "chndtr_scalar_cuda", [&]() { 25 | at::native::gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t pnonc) -> scalar_t { 26 | return static_cast(cdflib::cuda::chndtr(x, df, pnonc)); 27 | }); 28 | }); 29 | } 30 | 31 | } // namespace cuda 32 | } // namespace cdf_ops 33 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_chndtr_scalar_scalar_t.cu: -------------------------------------------------------------------------------- 1 | #include "cdflib/chndtr.cuh" 2 | #include "kernels.cuh" 3 | 4 | #ifndef _USE_MATH_DEFINES 5 | #define _USE_MATH_DEFINES 6 | #endif 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | namespace cdf_ops { 18 | namespace cuda { 19 | 20 | void chndtr_scalar_scalar_t_kernel_cuda(at::TensorIterator& iter, double df) { 21 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 22 | AT_ASSERT (df <= cdflib::cuda::chndtr_double_thresh); 23 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "chndtr_scalar_cuda", [&]() { 24 | at::native::gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t x, scalar_t pnonc) -> scalar_t { 25 | return cdflib::cuda::chndtr(x, df, pnonc); 26 | }); 27 | }); 28 | } 29 | 30 | } // namespace cuda 31 | } // namespace cdf_ops 32 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_chndtr_tensor.cu: -------------------------------------------------------------------------------- 1 | #include "cdflib/chndtr.cuh" 2 | #include "kernels.cuh" 3 | 4 | #ifndef _USE_MATH_DEFINES 5 | #define _USE_MATH_DEFINES 6 | #endif 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | namespace cdf_ops { 18 | namespace cuda { 19 | 20 | void chndtr_kernel_cuda(at::TensorIterator& iter) { 21 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 22 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "chndtr_cuda", [&]() { 23 | 
at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t df, scalar_t pnonc) -> scalar_t { 24 | return cdflib::cuda::chndtr(x, df, pnonc); 25 | }); 26 | }); 27 | } 28 | 29 | } // namespace cuda 30 | } // namespace cdf_ops 31 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_ndtr.cu: -------------------------------------------------------------------------------- 1 | #include "cephes/ndtr.cuh" 2 | #include "kernels.cuh" 3 | 4 | #ifndef _USE_MATH_DEFINES 5 | #define _USE_MATH_DEFINES 6 | #endif 7 | #include 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | 17 | namespace cdf_ops { 18 | namespace cuda { 19 | 20 | 21 | void ndtr_kernel_cuda(at::TensorIterator& iter) { 22 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 23 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "ndtr_cuda", [&]() { 24 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> scalar_t { 25 | return cephes::cuda::ndtr(x); 26 | }); 27 | }); 28 | } 29 | 30 | 31 | void log_ndtr_kernel_cuda(at::TensorIterator& iter) { 32 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 33 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_ndtr_cuda", [&]() { 34 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> scalar_t { 35 | return cephes::cuda::log_ndtr(x); 36 | }); 37 | }); 38 | } 39 | 40 | 41 | void ndtr_log_ndtr_kernel_cuda(at::TensorIterator& iter) { 42 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 43 | AT_DISPATCH_FLOATING_TYPES(iter.input_dtype(), "ndtr_log_ndtr_cuda", [&]() { 44 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x) -> c10::complex { 45 | auto result = cephes::cuda::ndtr_log_ndtr(x); 46 | return c10::complex(result.first, result.second); 47 | }); 48 | }); 49 | } 50 | 51 | 52 | void ndtr_backward_kernel_cuda(at::TensorIterator& iter) { 53 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 54 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "ndtr_backward_cuda", [&]() { 55 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t gout) -> scalar_t { 56 | // Forward ndtr(x) = 0.5 [ 1 + erf( x / sqrt(2) )] 57 | // 58 | // Erf backward: 59 | // - name: erf(Tensor self) -> Tensor 60 | // self: 2.0 / sqrt(M_PI) * exp(-(self.pow(2))) * grad 61 | 62 | // Backward grad * 0.5 * 2 / sqrt(pi) * exp( - (x / sqrt(2)).pow(2) ) / sqrt(2) 63 | // = grad / sqrt(2 pi) * exp( - x * x / 2) 64 | 65 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 66 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 67 | return std::exp(- x * x / 2) * gout / SQRT_2PI; 68 | }); 69 | }); 70 | } 71 | 72 | 73 | void log_ndtr_backward_kernel_cuda(at::TensorIterator& iter) { 74 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 75 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_ndtr_backward_cuda", [&]() { 76 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t gout) -> scalar_t { 77 | scalar_t ndtr_val = cephes::cuda::ndtr(x); 78 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 
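// (Added note.) This implements d/dx log Phi(x) = phi(x) / Phi(x):
// the incoming gradient is scaled by exp(-x*x/2) / (sqrt(2*pi) * Phi(x)),
// with Phi(x) taken from the ndtr_val computed just above.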
79 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 80 | return std::exp(- x * x / 2) * gout / ndtr_val / SQRT_2PI; 81 | }); 82 | }); 83 | } 84 | 85 | 86 | void log_ndtr_backward_with_ndtr_kernel_cuda(at::TensorIterator& iter) { 87 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 88 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "log_ndtr_backward_with_ndtr_cuda", [&]() { 89 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t x, scalar_t ndtr_val, scalar_t gout) -> scalar_t { 90 | // std::sqrt(2 * M_PI); std::sqrt is not constexpr. 91 | constexpr scalar_t SQRT_2PI = 2.5066282746310005024157652848110452530069867406099383166299235763422936546078419749466; 92 | return std::exp(- x * x / 2) * gout / ndtr_val / SQRT_2PI; 93 | }); 94 | }); 95 | } 96 | 97 | } // namespace cuda 98 | } // namespace cdf_ops 99 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_backward_mu1.cu: -------------------------------------------------------------------------------- 1 | #include "kernels_prob_two_poisson_templates.cuh" 2 | 3 | 4 | namespace cdf_ops { 5 | namespace cuda { 6 | 7 | // Explicit instantiations 8 | template void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter); 9 | // template void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter); 10 | template void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter); 11 | // template void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter); 12 | 13 | } // namespace cuda 14 | } // namespace cdf_ops 15 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_backward_mu2_gt.cu: -------------------------------------------------------------------------------- 1 | #include "kernels_prob_two_poisson_templates.cuh" 2 | 3 | 4 | namespace cdf_ops { 5 | namespace cuda { 6 | 7 | // Explicit instantiations 8 | template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 9 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 10 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 11 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 12 | 13 | } // namespace cuda 14 | } // namespace cdf_ops 15 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_backward_mu2_le.cu: -------------------------------------------------------------------------------- 1 | #include "kernels_prob_two_poisson_templates.cuh" 2 | 3 | 4 | namespace cdf_ops { 5 | namespace cuda { 6 | 7 | // Explicit instantiations 8 | template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 9 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 10 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 11 | // template void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter); 12 | 13 | } // namespace cuda 14 | } // namespace cdf_ops 15 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_forward_gt.cu: -------------------------------------------------------------------------------- 1 | #include "kernels_prob_two_poisson_templates.cuh" 2 | 3 | 4 | namespace 
cdf_ops { 5 | namespace cuda { 6 | 7 | // Explicit instantiations 8 | template void prob_two_poisson_kernel_cuda(at::TensorIterator& iter); 9 | // template void prob_two_poisson_kernel_cuda(at::TensorIterator& iter); 10 | 11 | } // namespace cuda 12 | } // namespace cdf_ops 13 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_forward_le.cu: -------------------------------------------------------------------------------- 1 | #include "kernels_prob_two_poisson_templates.cuh" 2 | 3 | 4 | namespace cdf_ops { 5 | namespace cuda { 6 | 7 | // Explicit instantiations 8 | template void prob_two_poisson_kernel_cuda(at::TensorIterator& iter); 9 | // template void prob_two_poisson_kernel_cuda(at::TensorIterator& iter); 10 | 11 | } // namespace cuda 12 | } // namespace cdf_ops 13 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/cuda/kernels_prob_two_poisson_templates.cuh: -------------------------------------------------------------------------------- 1 | #include "cdflib/chndtr.cuh" 2 | #include "cephes/i0.cuh" 3 | #include "cephes/i1.cuh" 4 | #include "kernels.cuh" 5 | 6 | #ifndef _USE_MATH_DEFINES 7 | #define _USE_MATH_DEFINES 8 | #endif 9 | #include 10 | #include 11 | 12 | #include 13 | #include 14 | #include 15 | #include 16 | #include 17 | 18 | 19 | namespace cdf_ops { 20 | namespace cuda { 21 | 22 | template 23 | void prob_two_poisson_kernel_cuda(at::TensorIterator& iter) { 24 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 25 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "prob_two_poisson_cuda", [&]() { 26 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t mu1, scalar_t mu2) -> scalar_t { 27 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 28 | // 29 | // Compute Prob[ NCX2( 2, 2*mu2 ) < 2*mu1 ] 30 | // 31 | // df = 2. 32 | // nc = 2 * mu2. 33 | // x = 2 * mu1 34 | scalar_t df = 2; 35 | scalar_t nc = 2 * mu2; 36 | scalar_t x = 2 * mu1; 37 | 38 | if (comp == TwoPoissonComparisonProb::GT) { 39 | return cdflib::cuda::chndtr(x, df, nc); 40 | } else if (comp == TwoPoissonComparisonProb::LE) { 41 | return 1 - cdflib::cuda::chndtr(x, df, nc); 42 | } else { 43 | // __builtin_unreachable(); 44 | // Until CUDA 11.3, can't use above. 
45 | // https://developer.nvidia.com/blog/boosting-productivity-and-performance-with-the-nvidia-cuda-11-2-c-compiler/ 46 | return NAN; 47 | } 48 | 49 | }); 50 | }); 51 | } 52 | 53 | template 54 | void prob_two_poisson_grad_mu1_kernel_cuda(at::TensorIterator& iter) { 55 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 56 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "prob_two_poisson_grad_mu1_cuda", [&]() { 57 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t mu1, scalar_t mu2, scalar_t gout) -> scalar_t { 58 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 59 | // 60 | // Compute 61 | // auto grad_mu1 = (-mu1 - mu2).exp() * torch::i0(2 * (mu1 * mu2).sqrt()) * gout; 62 | 63 | scalar_t g_gt; 64 | 65 | if (!use_besselIe) { 66 | g_gt = std::exp(-mu1 - mu2) * cephes::cuda::i0(std::sqrt(mu1 * mu2) * 2) * gout; 67 | } else { 68 | scalar_t twice_sqrtmu12 = std::sqrt(mu1 * mu2) * 2; 69 | g_gt = std::exp(twice_sqrtmu12 - mu1 - mu2) * cephes::cuda::i0e(twice_sqrtmu12) * gout; 70 | } 71 | 72 | if (comp == TwoPoissonComparisonProb::GT) { 73 | return g_gt; 74 | } else if (comp == TwoPoissonComparisonProb::LE) { 75 | return -g_gt; 76 | } else { 77 | // __builtin_unreachable(); 78 | // Until CUDA 11.3, can't use above. 79 | // https://developer.nvidia.com/blog/boosting-productivity-and-performance-with-the-nvidia-cuda-11-2-c-compiler/ 80 | return NAN; 81 | } 82 | 83 | }); 84 | }); 85 | } 86 | 87 | template 88 | void prob_two_poisson_grad_mu2_kernel_cuda(at::TensorIterator& iter) { 89 | const at::cuda::OptionalCUDAGuard device_guard(iter.device()); 90 | AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "prob_two_poisson_gt_grad_mu2_cuda", [&]() { 91 | at::native::gpu_kernel(iter, []GPU_LAMBDA(scalar_t mu1, scalar_t mu2, scalar_t out, scalar_t gout) -> scalar_t { 92 | // See NOTE [ Relation between Non-central Chi Square and Skellam ] 93 | // 94 | // Compute 95 | // auto grad_mu2 = (chndtr_scalar(2 * mu1, 4, 2 * mu2) - out) * gout; 96 | if (mu1 == 0) { 97 | 98 | if (comp == TwoPoissonComparisonProb::GT) { 99 | return -out * gout; 100 | } else if (comp == TwoPoissonComparisonProb::LE) { 101 | return (1 - out) * gout; 102 | } else { 103 | // __builtin_unreachable(); 104 | // Until CUDA 11.3, can't use above. 105 | // https://developer.nvidia.com/blog/boosting-productivity-and-performance-with-the-nvidia-cuda-11-2-c-compiler/ 106 | return NAN; 107 | } 108 | 109 | } else if (!use_besselIe || mu2 == 0) { 110 | // nc = 2 * mu2. When mu2 == 0, the chndtr code computes cdf for Chi-Square isntead. 111 | // So let it handle. 112 | 113 | if (comp == TwoPoissonComparisonProb::GT) { 114 | return (cdflib::cuda::chndtr(2 * mu1, 4, 2 * mu2) - out) * gout; 115 | } else if (comp == TwoPoissonComparisonProb::LE) { 116 | return (1 - out - cdflib::cuda::chndtr(2 * mu1, 4, 2 * mu2)) * gout; 117 | } else { 118 | // __builtin_unreachable(); 119 | // Until CUDA 11.3, can't use above. 
120 | // https://developer.nvidia.com/blog/boosting-productivity-and-performance-with-the-nvidia-cuda-11-2-c-compiler/ 121 | return NAN; 122 | } 123 | 124 | } else { 125 | scalar_t twice_sqrtmu12 = std::sqrt(mu1 * mu2) * 2; 126 | scalar_t log_mu1 = std::log(mu1); 127 | scalar_t log_mu2 = std::log(mu2); 128 | 129 | scalar_t g_le = std::exp( 130 | (log_mu1 - log_mu2) / 2 + twice_sqrtmu12 - mu1 - mu2 131 | ) * cephes::cuda::i1e(twice_sqrtmu12) * gout; 132 | 133 | if (comp == TwoPoissonComparisonProb::GT) { 134 | return -g_le; 135 | } else if (comp == TwoPoissonComparisonProb::LE) { 136 | return g_le; 137 | } else { 138 | // __builtin_unreachable(); 139 | // Until CUDA 11.3, can't use above. 140 | // https://developer.nvidia.com/blog/boosting-productivity-and-performance-with-the-nvidia-cuda-11-2-c-compiler/ 141 | return NAN; 142 | } 143 | } 144 | }); 145 | }); 146 | } 147 | 148 | } // namespace cuda 149 | } // namespace cdf_ops 150 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/load_ext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import warnings 4 | import packaging.version 5 | 6 | import torch 7 | import torch.version 8 | 9 | 10 | def check_env_flag(name, default=''): 11 | return os.getenv(name, default).upper() in ['ON', '1', 'YES', 'TRUE', 'Y'] 12 | 13 | 14 | DEBUG_FLAG = check_env_flag('DEBUG', default='0') 15 | 16 | 17 | def get_source_files(): 18 | files = [ 19 | # Entry 20 | os.path.join(os.path.dirname(__file__), "cdf_ops.cpp"), 21 | # CPU 22 | os.path.join(os.path.dirname(__file__), "cpu", "cdflib", "cdflib.cpp"), 23 | ] 24 | # CUDA 25 | if torch.cuda.is_available(): 26 | files.extend(glob.glob(os.path.join( 27 | glob.escape(os.path.join(os.path.dirname(__file__), "cuda")), 28 | "kernels_*.cu", 29 | ))) 30 | return tuple(files) 31 | 32 | 33 | def get_extra_cflags(): 34 | if DEBUG_FLAG: 35 | return ['-O0', '-fopenmp', '-march=native', '-g'] 36 | else: 37 | return ['-O3', '-fopenmp', '-march=native', '-funroll-loops'] 38 | 39 | 40 | def get_extra_cuda_cflags(): 41 | if not torch.cuda.is_available(): 42 | return [] 43 | if DEBUG_FLAG: 44 | return ['--expt-relaxed-constexpr', '--expt-extended-lambda', '-O0', '-Xcicc', '-O0', '-Xptxas', '-O0', '-g'] 45 | else: 46 | return ['--expt-relaxed-constexpr', '--expt-extended-lambda', '-O3'] 47 | 48 | 49 | _extension_loaded: bool = False 50 | _warn_first_load: bool = True 51 | 52 | 53 | def disable_load_extension_warning(): 54 | global _warn_first_load 55 | _warn_first_load = False 56 | 57 | 58 | def load_extension_if_needed(): 59 | global _extension_loaded 60 | if _extension_loaded: 61 | return 62 | 63 | if _warn_first_load: 64 | warnings.warn( 65 | 'Loading `cdf_ops` extension. If this is the first compilation on this machine, ' 66 | 'up to 10 minutes is needed. Subsequent loading will use cached results. ' 67 | 'Use `pqe.cdf_ops.disable_load_extension_warning()` to suppress this warning.') 68 | 69 | 70 | if torch.cuda.is_available() and torch.version.cuda is not None: 71 | if packaging.version.parse(torch.version.cuda) == packaging.version.parse('11.3'): 72 | raise RuntimeError( 73 | 'cdf_ops: CUDA 11.3 has a compiler bug that causes compiling `cdf_ops` to hang. 
' 74 | 'Please use a newer CUDA version.') 75 | 76 | # JIT load 77 | from torch.utils.cpp_extension import load 78 | load( 79 | name="cdf_ops", 80 | sources=get_source_files(), 81 | extra_cflags=get_extra_cflags(), 82 | extra_cuda_cflags=get_extra_cuda_cflags(), 83 | is_python_module=False, 84 | with_cuda=torch.cuda.is_available(), 85 | ) 86 | _extension_loaded = True 87 | -------------------------------------------------------------------------------- /torchqmet/pqe/cdf_ops/op_wrappers.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | try: 5 | import torch.special as tsp 6 | except ImportError: 7 | tsp = None 8 | 9 | 10 | from .load_ext import load_extension_if_needed # load lazily 11 | 12 | 13 | def chndtr(x: torch.Tensor, df: Union[torch.Tensor, float], nc: torch.Tensor) -> torch.Tensor: 14 | r""" 15 | Computes the non-central Chi-square CDF. 16 | 17 | For a distribution with :attr:`df` degrees of freedom and :attr:`nc` non-centrality parameter, 18 | this evaluates the CDF at :attr:`x`. 19 | """ 20 | load_extension_if_needed() 21 | return torch.ops.cdf_ops.chndtr(x, df, nc) 22 | 23 | 24 | if not hasattr(torch, 'i0'): # i0 was first added as torch.i0 25 | def i0(input: torch.Tensor) -> torch.Tensor: 26 | r""" 27 | Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. 28 | 29 | .. math:: 30 | \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 31 | """ 32 | load_extension_if_needed() 33 | return torch.ops.cdf_ops.i0(input) 34 | else: 35 | def i0(input: torch.Tensor) -> torch.Tensor: 36 | r""" 37 | Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. 38 | 39 | .. math:: 40 | \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 41 | """ 42 | return torch.i0(input) 43 | 44 | 45 | if not hasattr(tsp, 'i0e'): 46 | def i0e(input: torch.Tensor) -> torch.Tensor: 47 | r""" 48 | Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below) 49 | for each element of :attr:`input`. 50 | 51 | .. math:: 52 | \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 53 | """ 54 | load_extension_if_needed() 55 | return torch.ops.cdf_ops.i0e(input) 56 | else: 57 | def i0e(input: torch.Tensor) -> torch.Tensor: 58 | r""" 59 | Computes the exponentially scaled zeroth order modified Bessel function of the first kind (as defined below) 60 | for each element of :attr:`input`. 61 | 62 | .. math:: 63 | \text{out}_{i} = \exp(-|x|) * i0(x) = \exp(-|x|) * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} 64 | """ 65 | return tsp.i0e(input) 66 | 67 | 68 | if not hasattr(tsp, 'i1'): 69 | def i1(input: torch.Tensor) -> torch.Tensor: 70 | r""" 71 | Computes the first order modified Bessel function of the first kind (as defined below) 72 | for each element of :attr:`input`. 73 | 74 | .. math:: 75 | \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 76 | """ 77 | load_extension_if_needed() 78 | return torch.ops.cdf_ops.i1(input) 79 | else: 80 | def i1(input: torch.Tensor) -> torch.Tensor: 81 | r""" 82 | Computes the first order modified Bessel function of the first kind (as defined below) 83 | for each element of :attr:`input`. 84 | 85 | .. 
math:: 86 | \text{out}_{i} = \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 87 | """ 88 | return tsp.i1(input) 89 | 90 | 91 | if not hasattr(tsp, 'i1e'): 92 | def i1e(input: torch.Tensor) -> torch.Tensor: 93 | r""" 94 | Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below) 95 | for each element of :attr:`input`. 96 | 97 | .. math:: 98 | \text{out}_{i} = \exp(-|x|) * i1(x) = 99 | \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 100 | """ 101 | load_extension_if_needed() 102 | return torch.ops.cdf_ops.i1e(input) 103 | else: 104 | def i1e(input: torch.Tensor) -> torch.Tensor: 105 | r""" 106 | Computes the exponentially scaled first order modified Bessel function of the first kind (as defined below) 107 | for each element of :attr:`input`. 108 | 109 | .. math:: 110 | \text{out}_{i} = \exp(-|x|) * i1(x) = 111 | \exp(-|x|) * \frac{(\text{input}_{i})}{2} * \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!) * (k+1)!} 112 | """ 113 | return tsp.i1e(input) 114 | 115 | 116 | def prob_two_poisson_gt(mu1: torch.Tensor, mu2: torch.Tensor) -> torch.Tensor: 117 | r""" 118 | Computes the elementwise ``Prob[ Poisson(mu1) > Poisson(mu2) ]``. 119 | """ 120 | load_extension_if_needed() 121 | return torch.ops.cdf_ops.prob_two_poisson_gt(mu1, mu2) 122 | 123 | 124 | def prob_two_poisson_le(mu1: torch.Tensor, mu2: torch.Tensor) -> torch.Tensor: 125 | r""" 126 | Computes the elementwise ``Prob[ Poisson(mu1) <= Poisson(mu2) ]``. 127 | """ 128 | load_extension_if_needed() 129 | return torch.ops.cdf_ops.prob_two_poisson_le(mu1, mu2) 130 | 131 | 132 | def ndtr(x: torch.Tensor) -> torch.Tensor: 133 | r""" 134 | Computes the standard Gaussian CDF evaluated at :attr:`x`. 135 | """ 136 | load_extension_if_needed() 137 | return torch.ops.cdf_ops.ndtr(x) 138 | 139 | 140 | def log_ndtr(x: torch.Tensor) -> torch.Tensor: 141 | r""" 142 | Computes the log of the standard Gaussian CDF evaluated at :attr:`x`. 143 | 144 | This is numerically more stable than calling ``ndtr(x).log()``, in both forward and backward. 145 | """ 146 | load_extension_if_needed() 147 | return torch.ops.cdf_ops.log_ndtr(x) 148 | 149 | 150 | def prod_ndtr(x: torch.Tensor, *, dim: int = -1) -> torch.Tensor: 151 | r""" 152 | Computes ``ndtr(x).prod(dim=dim)``. 153 | 154 | This is numerically more stable than naively calling ``ndtr(x).prod(dim=dim)``, in both forward and backward. 155 | """ 156 | load_extension_if_needed() 157 | return torch.ops.cdf_ops.prod_ndtr(x, dim=dim) 158 | 159 | 160 | __all__ = [ 161 | 'chndtr', 'i0', 'i0e', 'i1', 'i1e', 162 | 'prob_two_poisson_gt', 'prob_two_poisson_le', 163 | 'ndtr', 'log_ndtr', 'prod_ndtr', 164 | ] 165 | -------------------------------------------------------------------------------- /torchqmet/pqe/measures.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from ..utils import DeepLinearNet 7 | 8 | 9 | class MeasureBase(nn.Module): 10 | def __init__(self, shape: torch.Size): 11 | r""" 12 | ``shape`` defines how many measures (each of a different Poisson process 13 | space) this parameterizes. 
14 | """ 15 | super().__init__() 16 | self.shape = shape 17 | assert len(shape) == 2, "shape should be [num_quasipartition_mixtures, num_process_per_mixture]" 18 | 19 | 20 | class LebesgueMeasure(MeasureBase): 21 | pass 22 | 23 | 24 | class GaussianBasedMeasure(MeasureBase): 25 | def __init__(self, shape: torch.Size, *, init_sigma2=1): 26 | super().__init__(shape) 27 | self.log_sigma2 = nn.Parameter(torch.empty(shape).fill_(init_sigma2).log_().requires_grad_()) 28 | self.scales_net = DeepLinearNet(shape.numel(), shape.numel(), non_negative=True, bias=True) # Sec. C.4.3 29 | 30 | def scale(self, x: torch.Tensor) -> torch.Tensor: 31 | return self.scale_multiple(x)[0] 32 | 33 | def scale_multiple(self, *xs: torch.Tensor) -> Tuple[torch.Tensor, ...]: 34 | assert all(x.shape[-2:] == self.shape for x in xs) 35 | return tuple(scaled_x.unflatten(-1, self.shape) for scaled_x in self.scales_net(*[x.flatten(-2) for x in xs])) 36 | 37 | @property 38 | def sigma2(self): 39 | return self.log_sigma2.exp() 40 | 41 | @property 42 | def sigma(self): 43 | return self.log_sigma2.div(2).exp() 44 | -------------------------------------------------------------------------------- /torchqmet/pqe/shapes.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from . import cdf_ops 8 | 9 | from .measures import MeasureBase, LebesgueMeasure, GaussianBasedMeasure 10 | 11 | 12 | class ShapeBase(nn.Module, metaclass=abc.ABCMeta): 13 | @abc.abstractmethod 14 | def expected_quasipartiton(self: 'ShapeBase', u: torch.Tensor, v: torch.Tensor, *, measure: MeasureBase): 15 | r""" 16 | Computes the expected quasipartition as defined in Equation (13): 17 | 18 | For pi, a random quasipartition given by the Poisson processes defined with `measure` and shape parametrization 19 | `self`, 20 | 21 | E[ pi(u, v) ] = 1 - \prod_j Pr[ Count( Shape(u) ) <= Count( Shape(v) ) ] 22 | = 1 - \prod_j Pr[ Poisson( Measure(Shape(u) \ Shape(v)) ) <= Poisson( Measure(Shape(v) \ Shape(u)) ) ], 23 | 24 | where `Count(*)` is the Poisson process count. 25 | """ 26 | pass 27 | 28 | 29 | class HalfLineShape(ShapeBase): 30 | def expected_quasipartiton(self, u: torch.Tensor, v: torch.Tensor, *, measure: MeasureBase): 31 | # Shapes are (-infty, u) and (-\infty, v) 32 | if isinstance(measure, LebesgueMeasure): 33 | # PQE-LH, Eqn. (9) 34 | # 35 | # 1 - \prod_j exp( -max(u_j - v_j, 0) ) 36 | # = 1 - exp( \sum_j min(v_j - u_j, 0) ) 37 | # 38 | # Use mean instead of sum as a normalization for better stability at initialization (Sec. C.4.1). 39 | return -torch.expm1((v - u).clamp(max=0).mean(-1)) 40 | elif isinstance(measure, GaussianBasedMeasure): 41 | measure_exceed = measure.scale((cdf_ops.ndtr(v / measure.sigma) - cdf_ops.ndtr(u / measure.sigma)).clamp(max=0)) 42 | return -torch.expm1(measure_exceed.mean(-1)) 43 | else: 44 | raise NotImplementedError(f"measure={measure} is not supported") 45 | 46 | 47 | class GaussianShape(ShapeBase): 48 | sigma2: float = 1. 49 | sigma: float = 1. 50 | log_2pi: float = math.log(2 * math.pi) 51 | 52 | def expected_quasipartiton(self, u: torch.Tensor, v: torch.Tensor, *, measure: MeasureBase): 53 | # Shapes are areas under the unit Gaussian CDF centered at u or v 54 | if isinstance(measure, LebesgueMeasure): 55 | raise RuntimeError('Gaussian shape under lebesgue measure is always symmetrical') 56 | elif isinstance(measure, GaussianBasedMeasure): 57 | # See Sec. C.2 for details. 
58 | 59 | # To obtain the rate of the Gaussian shape (centered at mu) over an interval, 60 | # we can view it as integrating the density product of two independent Gaussians along a 61 | # line of the form Y = X + a. 62 | # 63 | # Following many algebraic manipulations, one can obtain, for the 1D case, 64 | # 65 | # 1 / \sqrt{2pi (shape_sig2 + measure_sig2)} * exp( - mu^2 / (2 (shape_sig2 + measure_sig2)) ) 66 | # * \Int GaussianDensity(mean=mu (measure_sig2 / (shape_sig2 + measure_sig2)), 67 | # sig2=shape_sig2 * measure_sig2 / (shape_sig2 + measure_sig2)) 68 | # 69 | # which can simply be viewed as the scaled density of another "induced" Gaussian. 70 | # 71 | # The total rate is thus 72 | # lambda / \sqrt{2pi (shape_sig2 + measure_sig2)} * exp( - mu^2 / (2 (shape_sig2 + measure_sig2)) ) 73 | 74 | measure_sigma2 = measure.sigma2 75 | shape_sigma2 = self.sigma2 76 | sum_sigma2 = shape_sigma2 + measure_sigma2 77 | log_sum_sigma2 = torch.log(sum_sigma2) 78 | 79 | log_total_base = - (self.log_2pi + log_sum_sigma2) / 2 80 | log_total_u = log_total_base - 0.5 * (u ** 2) / sum_sigma2 81 | log_total_v = log_total_base - 0.5 * (v ** 2) / sum_sigma2 82 | total_u = log_total_u.exp() 83 | total_v = log_total_v.exp() 84 | 85 | mid = (u + v) / 2 86 | 87 | # `mid` in the new mu Gaussian (the "induced" Gaussian in formula above) 88 | # after normalization would be 89 | # mid - mu * measure_sig2 / (measure_sig2 + shape_sig2) 90 | # -------------------------------------------------------- 91 | # \sqrt{measure_sig2 shape_sig2 / (measure_sig2 + shape_sig2)} 92 | 93 | new_mu_mult = measure_sigma2 / sum_sigma2 94 | new_sig2 = shape_sigma2 * new_mu_mult 95 | new_sig = torch.sqrt(new_sig2) 96 | u2mid_frac = cdf_ops.ndtr( (mid - u * new_mu_mult) / new_sig ) # noqa: E201, E202 97 | v2mid_frac = cdf_ops.ndtr( (mid - v * new_mu_mult) / new_sig ) # noqa: E201, E202 98 | 99 | intersection = torch.where( 100 | u < v, 101 | v2mid_frac * total_v + (1 - u2mid_frac) * total_u, 102 | u2mid_frac * total_u + (1 - v2mid_frac) * total_v, 103 | ) 104 | 105 | u_only = total_u - intersection 106 | v_only = total_v - intersection 107 | 108 | # Mathematically, `intersection` is smaller than both. But numerically it is sometimes computed to be 109 | # slightly larger. So we fix that here (w/o changing the computation graph for autograd). 
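# (Clamping through `.data` below edits the stored values in place without recording the clamp in the
#  autograd graph, so gradients still flow through the `total_* - intersection` subtractions above.)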
110 | u_only.data.clamp_(min=0) 111 | v_only.data.clamp_(min=0) 112 | 113 | u_only, v_only = measure.scale_multiple(u_only, v_only) 114 | 115 | ple: torch.Tensor = cdf_ops.prob_two_poisson_le(u_only / u.shape[-1], v_only / v.shape[-1]) 116 | return 1 - ple.prod(dim=-1) 117 | else: 118 | raise NotImplementedError(f"measure={measure} is not supported") 119 | -------------------------------------------------------------------------------- /torchqmet/reductions.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import abc 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from .utils import DeepLinearNet, multidot, sigmoid_pow 10 | 11 | 12 | class ReductionBase(nn.Module, metaclass=abc.ABCMeta): 13 | input_num_components: int 14 | discount: Optional[float] 15 | 16 | def __init__(self, input_num_components: int, discount: Optional[float] = None) -> None: 17 | super().__init__() 18 | self.input_num_components = input_num_components 19 | self.discount = discount 20 | 21 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 22 | if self.discount is None: 23 | raise RuntimeError(f"{self} does not support non-discounted distances") 24 | return self.reduce_discounted_distance(d).log() / math.log(self.discount) 25 | 26 | def reduce_discounted_distance(self, d: torch.Tensor) -> torch.Tensor: 27 | return self.discount ** self.reduce_distance(d) 28 | 29 | def forward(self, d: torch.Tensor) -> torch.Tensor: 30 | if self.discount is None: 31 | return self.reduce_distance(d) 32 | else: 33 | return self.reduce_discounted_distance(d) 34 | 35 | def __call__(self, d: torch.Tensor) -> torch.Tensor: 36 | # Manually define for typing 37 | # https://github.com/pytorch/pytorch/issues/45414 38 | return super().__call__(d) 39 | 40 | def extra_repr(self) -> str: 41 | if self.discount is None: 42 | return f"input_num_components={self.input_num_components}" 43 | else: 44 | return f"input_num_components={self.input_num_components}, discount={self.discount:g}" 45 | 46 | 47 | class Max(ReductionBase): 48 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 49 | return d.max(dim=-1).values 50 | 51 | 52 | class Sum(ReductionBase): 53 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 54 | return d.sum(dim=-1) 55 | 56 | 57 | class Mean(ReductionBase): 58 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 59 | return d.mean(dim=-1) 60 | 61 | 62 | class MaxMean(ReductionBase): 63 | r''' 64 | `maxmean` from Neural Norms paper: 65 | https://arxiv.org/abs/2002.05825 66 | 67 | Implementation follows the official implementation: 68 | https://github.com/spitis/deepnorms/blob/6c8db1b1178eb92df23149c6d6bfb10782daac86/metrics_tf1.py#L26 69 | ''' 70 | 71 | def __init__(self, input_num_components: int, discount: Optional[float] = None) -> None: 72 | super().__init__(input_num_components=input_num_components, discount=discount) 73 | self.raw_alpha = nn.Parameter(torch.ones(()).neg_().requires_grad_()) # pre sigmoid 74 | 75 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 76 | alpha: torch.Tensor = self.raw_alpha.sigmoid() 77 | return torch.lerp( 78 | d.mean(dim=-1), # * (1 - alpha) 79 | d.max(dim=-1).values, # * alpha 80 | alpha, 81 | ) 82 | 83 | 84 | class DeepLinearNetWeightedSum(ReductionBase): 85 | r''' 86 | PQE-style aggregation by weighted sum from deep linear networks: 87 | https://arxiv.org/abs/2206.15478 88 | 89 | When using `discount`, we follow the original paper Sec. 
C.4.2 (and official PQE repository), and use a deep linear 90 | network to parametrize the input to a `sigmoid` function, whose output is used in an exponentiation. 91 | I.e., \prod_i sigmoid( deep_linear_net_output[i] ) ** components[i]. 92 | ''' 93 | 94 | def __init__(self, input_num_components: int, discount: Optional[float] = None) -> None: 95 | super().__init__(input_num_components=input_num_components, discount=discount) 96 | if self.discount is None: 97 | self.alpha_net = DeepLinearNet(input_dim=input_num_components, output_dim=1, non_negative=True) 98 | else: 99 | self.beta_net = DeepLinearNet(input_dim=1, output_dim=input_num_components, non_negative=False) 100 | 101 | # Initialize logits so initial output is between 0.5 and 0.75. (Sec. C.4.3) 102 | # 103 | # Note that this is important since we are multiplying a bunch of things < 1 together, 104 | # and thus must take care to not make the result close to 0. 105 | # 106 | # Say the quasipartitions are 0.5. For output = y, with k quasipartitions, 107 | # we want the base to be roughly 108 | # - log ( y^{-2/k} - 1). 109 | 110 | k = input_num_components 111 | low_out = 0.5 112 | high_out = 0.75 113 | low = -math.log(low_out ** (-2 / k) - 1) 114 | high = -math.log(high_out ** (-2 / k) - 1) 115 | # `DeepLinearNet` should initialize s.t. the collapsed vector roughly 116 | # has zero mean and 1 variance. This holds even for intermediate activations. 117 | # NB that we crucially used `input_dim=1` rather than `output_dim=1`, which will make 118 | # weights have variance O(1/n). 119 | 120 | ms: List[torch.Tensor] = list(self.beta_net.mats) 121 | 122 | with torch.no_grad(): 123 | # collapse all but last 124 | out_before_last: torch.Tensor = multidot(ms[1:]) 125 | norm_out_before_last: torch.Tensor = out_before_last.norm() 126 | unit_out_before_last: torch.Tensor = out_before_last / norm_out_before_last 127 | 128 | # now simply constrain the projection dimension 129 | ms[0].sub_((ms[0] @ unit_out_before_last) @ unit_out_before_last.T) \ 130 | .add_(torch.empty(k, 1).uniform_(low, high).div(norm_out_before_last) @ unit_out_before_last.T) # noqa: E501 131 | q = self.beta_net.collapse().squeeze(1).sigmoid().pow(0.5).prod().item() 132 | assert low_out <= q <= high_out, q 133 | 134 | 135 | def reduce_distance(self, d: torch.Tensor) -> torch.Tensor: 136 | return self.alpha_net(d).squeeze(-1) 137 | 138 | def reduce_discounted_distance(self, d: torch.Tensor) -> torch.Tensor: 139 | logits = self.beta_net.collapse().squeeze(1) 140 | return sigmoid_pow(logits, d).prod(-1) 141 | 142 | 143 | REDUCTIONS: Mapping[str, Type[ReductionBase]] = dict( 144 | sum=Sum, 145 | mean=Mean, 146 | maxmean=MaxMean, 147 | max=Max, 148 | deep_linear_net_weighted_sum=DeepLinearNetWeightedSum, 149 | ) 150 | 151 | 152 | def make_reduction(kind: str, input_num_components: int, discount: Optional[float] = None) -> ReductionBase: 153 | return REDUCTIONS[kind](input_num_components, discount) 154 | -------------------------------------------------------------------------------- /torchqmet/transforms.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import abc 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | class TransformBase(nn.Module, metaclass=abc.ABCMeta): 11 | input_num_components: int 12 | output_num_components: int 13 | 14 | def __init__(self, input_num_components: int, output_num_components: int) -> None: 15 | super().__init__() 16 | self.input_num_components = 
input_num_components 17 | self.output_num_components = output_num_components 18 | 19 | @abc.abstractmethod 20 | def forward(self, d: torch.Tensor) -> torch.Tensor: 21 | pass 22 | 23 | def __call__(self, d: torch.Tensor) -> torch.Tensor: 24 | # Manually define for typing 25 | # https://github.com/pytorch/pytorch/issues/45414 26 | return super().__call__(d) 27 | 28 | def extra_repr(self) -> str: 29 | return f"input_num_components={self.input_num_components}, output_num_components={self.output_num_components}" 30 | 31 | 32 | @torch.jit.script 33 | def apply_concave_activations(x: torch.Tensor, bs_first_constant: torch.Tensor, raw_bs_after_first: torch.Tensor, 34 | raw_ms: torch.Tensor) -> torch.Tensor: 35 | bs = torch.cat([bs_first_constant, F.softplus(raw_bs_after_first)], dim=-1) 36 | ms = torch.sigmoid(raw_ms) * 2 37 | v = torch.addcmul(bs, x.unsqueeze(-1), ms) 38 | return v.min(-1).values 39 | 40 | 41 | class ConcaveActivation(TransformBase): 42 | r''' 43 | Learned concave activations used in neural norms (Deep Norm and Wide Norm): 44 | https://arxiv.org/abs/2002.05825 45 | 46 | Follows their official implementation: 47 | https://github.com/spitis/deepnorms/blob/6c8db1b1178eb92df23149c6d6bfb10782daac86/metrics_tf1.py#L30 48 | ''' 49 | 50 | bs_first_constant: torch.Tensor 51 | 52 | def __init__(self, input_num_components: int, num_units_per_input: int = 5): 53 | super().__init__(input_num_components, input_num_components) 54 | self.num_units_per_input = num_units_per_input 55 | self.register_buffer('bs_first_constant', torch.zeros(input_num_components, 1)) 56 | self.raw_bs_after_first = nn.Parameter( 57 | torch.randn(input_num_components, num_units_per_input - 1).mul_(1e-3).sub_(1).requires_grad_()) 58 | self.raw_ms = nn.Parameter( 59 | torch.randn(input_num_components, num_units_per_input).mul_(1e-3).requires_grad_()) 60 | 61 | def forward(self, d: torch.Tensor) -> torch.Tensor: 62 | return apply_concave_activations(d, self.bs_first_constant, self.raw_bs_after_first, self.raw_ms) 63 | 64 | def extra_repr(self) -> str: 65 | return super().extra_repr() + f"\nnum_units_per_input={self.num_units_per_input}" 66 | 67 | 68 | TRANSFORMS: Mapping[str, Type[TransformBase]] = dict( 69 | concave_activation=ConcaveActivation, 70 | ) 71 | 72 | def make_transform(kind: str, input_num_components: int) -> TransformBase: 73 | return TRANSFORMS[kind](input_num_components) 74 | -------------------------------------------------------------------------------- /torchqmet/utils.py: -------------------------------------------------------------------------------- 1 | from typing import * 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SigmoidPow(torch.autograd.Function): 9 | # Computes `sigmoid(x)^y` and avoids NaN gradients when x << 0 and sigmoid(x) = 0 numerically. 10 | 11 | @staticmethod 12 | def forward(ctx, x: torch.Tensor, y: torch.Tensor): 13 | # Compute sigmoid(x)^y = exp( logsigmoid(x) * y ), 14 | # where logsigmoid(x) = - softplus(-x). 15 | logsigmoid = F.softplus(-x).neg() 16 | out = logsigmoid.mul(y).exp() 17 | ctx.save_for_backward(out, logsigmoid, x, y) 18 | return out 19 | 20 | @staticmethod 21 | def backward(ctx, gout: torch.Tensor): 22 | # Formula is simple; obtained from Mathematica. 
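# With s = sigmoid(x) and out = s^y, the two partial derivatives computed below are
#   d out / dx = y * s^y * (1 - s) = y * exp((y + 1) * logsigmoid(x) - x)   [since 1 - s = s * exp(-x)]
#   d out / dy = log(s) * s^y = logsigmoid(x) * out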
23 | out, logsigmoid, x, y = ctx.saved_tensors 24 | gx = ((y + 1) * logsigmoid - x).exp() * y * gout 25 | gy = logsigmoid * out * gout 26 | return gx, gy 27 | 28 | 29 | def sigmoid_pow(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: 30 | # Computes `sigmoid(x)^y` and avoids NaN gradients when x << 0 and sigmoid(x) = 0 numerically. 31 | return SigmoidPow.apply(x, y) 32 | 33 | 34 | # https://stackoverflow.com/a/14267825 35 | def ceilpow2(x: int): 36 | assert x > 0 37 | return 1 << (x - 1).bit_length() 38 | 39 | 40 | if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'multi_dot'): 41 | multidot = torch.linalg.multi_dot 42 | else: 43 | def multidot(mats: List[torch.Tensor]): 44 | return torch.chain_matmul(*mats) 45 | 46 | 47 | class DeepLinearNet(nn.Module): 48 | r""" 49 | Parametrize a vector/matrix as the matrix multiplication of a couple of matrices (Sec. C.4.3). 50 | 51 | By default, this uses 3 hidden layers with dimension max(64, InDim', OutDim'), 52 | 53 | where XDim' is the smallest power of 2 that is >= XDim. 54 | """ 55 | 56 | bias: Optional[torch.Tensor] 57 | 58 | def __init__(self, input_dim, output_dim, *, hidden_dims=None, non_negative=False, bias=True): 59 | super().__init__() 60 | 61 | if hidden_dims is None: 62 | hidden_dim = max(ceilpow2(input_dim), 64) 63 | hidden_dim = max(ceilpow2(output_dim), hidden_dim) 64 | hidden_dims = [hidden_dim, hidden_dim, hidden_dim] 65 | elif not bias and len(hidden_dims) > 0: 66 | assert min(hidden_dims) >= min(input_dim, output_dim), "cannot lose rank" 67 | 68 | self.input_dim = input_dim 69 | self.output_dim = output_dim 70 | self.hidden_dims = hidden_dims 71 | self.non_negative = non_negative 72 | 73 | dims = [input_dim] + list(hidden_dims) + [output_dim] 74 | mats = [] 75 | layer_in_dim = dims[0] 76 | for nh in dims[1:]: 77 | mats.append(nn.Linear(layer_in_dim, nh, bias=False).weight) # [linout, linin] 78 | nn.init.kaiming_normal_(mats[-1], a=1) # a=1 for linear! 79 | layer_in_dim = nh 80 | self.mats = nn.ParameterList(mats[::-1]) 81 | 82 | if bias: 83 | self.register_parameter('bias', nn.Parameter(torch.zeros(self.output_dim, self.input_dim), requires_grad=True)) 84 | else: 85 | self.register_buffer('bias', None) 86 | 87 | def collapse(self) -> torch.Tensor: 88 | # Returns A of Ax 89 | m: torch.Tensor = multidot(list(self.mats)) 90 | if self.bias is not None: 91 | m = m + self.bias 92 | if self.non_negative: 93 | m = m.pow(2) 94 | return m 95 | 96 | def forward(self, *xs: torch.Tensor): 97 | m = self.collapse().T 98 | if len(xs) == 1: 99 | return xs[0] @ m 100 | else: 101 | return tuple(x @ m for x in xs) 102 | 103 | def __call__(self, *xs: torch.Tensor) -> torch.Tensor: 104 | return super().__call__(*xs) 105 | 106 | def extra_repr(self) -> str: 107 | return f"bias={self.bias is not None}, non_negative={self.non_negative}" 108 | 109 | def get_num_effective_nparameters(self): 110 | return self.input_dim * self.output_dim 111 | 112 | 113 | def get_num_effective_parameters(module: nn.Module) -> int: 114 | total = 0 115 | 116 | def add(m: nn.Module): 117 | nonlocal total 118 | if isinstance(m, DeepLinearNet): 119 | total += m.get_num_effective_nparameters() 120 | else: 121 | total += sum(p.numel() for p in m.parameters(recurse=False)) 122 | for c in m.children(): 123 | add(c) 124 | 125 | add(module) 126 | return total 127 | --------------------------------------------------------------------------------
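A minimal usage sketch (illustrative only, not a file from the repository): it chains the learned concave transform with the PQE-style weighted-sum reduction defined above on a batch of per-component distances. The `torchqmet.transforms` / `torchqmet.reductions` import paths are assumed from the directory layout, and the composition order (transform, then reduction) is only for demonstration.

import torch

from torchqmet.reductions import make_reduction
from torchqmet.transforms import make_transform

d = torch.rand(32, 16)  # [batch, num_components], non-negative per-component distances
transform = make_transform('concave_activation', input_num_components=16)
reduction = make_reduction('deep_linear_net_weighted_sum', input_num_components=16, discount=0.9)

out = reduction(transform(d))  # with a discount, this is prod_i sigmoid(logit_i) ** d_i
print(out.shape)  # torch.Size([32])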