├── requirements.txt ├── .gitignore ├── tests ├── requirements_tests.txt └── test.py ├── setup.py ├── CONTRIBUTING ├── README.md ├── LICENSE └── TruncatedNormal.py /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | deploy* 4 | -------------------------------------------------------------------------------- /tests/requirements_tests.txt: -------------------------------------------------------------------------------- 1 | torch>=1.5.0 2 | scipy -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name='torch_truncnorm', 8 | version='0.0.2', 9 | long_description=long_description, 10 | long_description_content_type="text/markdown", 11 | description='Truncated Normal distribution in PyTorch', 12 | python_requires='>=3.6', 13 | packages=setuptools.find_packages(), 14 | author='Anton Obukhov', 15 | license='BSD', 16 | url='https://www.github.com/toshas/torch_truncnorm', 17 | ) 18 | -------------------------------------------------------------------------------- /CONTRIBUTING: -------------------------------------------------------------------------------- 1 | When contributing to this repository, please first discuss the change you wish to make via the GitHub issue with the 2 | owners of this repository before making a change. Unsolicited pull requests may be rejected without consideration. 
By 3 | submitting a pull request into this repository, the contributor accepts the terms of the LICENSE agreement(s) and 4 | acknowledges that the proposed change is legally compatible with the repository in every possible way and does not 5 | impose any additional constraints on the repository after inclusion. In case the submitted change contains portions of 6 | another open-source project, the pull request description must contain references to the original source code, and a 7 | proper acknowledgement must be added to the README file. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # torch_truncnorm 2 | Truncated Normal distribution in PyTorch. The module provides: 3 | - `TruncatedStandardNormal` class - zero mean unit variance of the parent Normal distribution, parameterized by the 4 | cut-off range `[a, b]` (similar to `scipy.stats.truncnorm`); 5 | - `TruncatedNormal` class - a wrapper with extra `loc` and `scale` parameters of the parent Normal distribution; 6 | - Differentiability wrt parameters of the distribution; 7 | - Batching support. 8 | 9 | # Why 10 | I just needed differentiation with respect to parameters of the distribution and found out that truncated normal 11 | distribution is not bundled in `torch.distributions` as of 1.6.0. 12 | 13 | # Known issues 14 | `icdf` is numerically unstable; as a consequence, so is `rsample`. This issue is also seen in 15 | `torch.distributions.normal.Normal`, so it is sort of *normal* (ba-dum-tss). 
16 | 17 | # Tests 18 | ```shell script 19 | CUDA_VISIBLE_DEVICES=0 python -m tests.test 20 | ``` 21 | 22 | # Links 23 | https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf 24 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Anton Obukhov 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from unittest.util import safe_repr 3 | from warnings import warn 4 | 5 | import torch 6 | from scipy.stats import truncnorm 7 | from TruncatedNormal import TruncatedNormal as TruncatedNormalPT 8 | 9 | 10 | class TruncatedNormalSC: 11 | def __init__(self, loc, scale, a, b): 12 | self.loc = loc 13 | self.scale = scale 14 | self.alpha = (a - loc) / scale 15 | self.beta = (b - loc) / scale 16 | 17 | @property 18 | def mean(self): 19 | return truncnorm.stats(self.alpha, self.beta, loc=self.loc, scale=self.scale, moments='m') 20 | 21 | @property 22 | def variance(self): 23 | return truncnorm.stats(self.alpha, self.beta, loc=self.loc, scale=self.scale, moments='v') 24 | 25 | def cdf(self, value): 26 | return truncnorm.cdf(value, self.alpha, self.beta, loc=self.loc, scale=self.scale) 27 | 28 | def icdf(self, value): 29 | return truncnorm.ppf(value, self.alpha, self.beta, loc=self.loc, scale=self.scale) 30 | 31 | def log_prob(self, value): 32 | return truncnorm.logpdf(value, self.alpha, self.beta, loc=self.loc, scale=self.scale) 33 | 34 | @property 35 | def entropy(self): 36 | return truncnorm.entropy(self.alpha, self.beta, loc=self.loc, scale=self.scale) 37 | 38 | 39 | class Tests(unittest.TestCase): 40 | 41 | def 
assertRelativelyEqual(self, first, second, tol=1e-6, error=1e-5, msg=None): 42 | if first == second: 43 | return 44 | diff = abs(first - second) 45 | rel = diff / max(abs(first), abs(second)) 46 | if rel <= tol or diff <= error: 47 | return 48 | standardMsg = '%s != %s within tol=%s abs=%s (rel=%s diff=%s)' % (safe_repr(first), safe_repr(second), 49 | safe_repr(tol), safe_repr(error), safe_repr(rel), safe_repr(diff)) 50 | msg = self._formatMessage(msg, standardMsg) 51 | raise self.failureException(msg) 52 | 53 | def _test_numerical(self, loc, scale, a, b, do_icdf=True): 54 | pt = TruncatedNormalPT(loc, scale, a, b, validate_args=None) 55 | sc = TruncatedNormalSC(loc, scale, a, b) 56 | 57 | mean_sc = sc.mean 58 | mean_pt = pt.mean.numpy() 59 | self.assertRelativelyEqual(mean_sc, mean_pt) 60 | 61 | var_sc = sc.variance 62 | var_pt = pt.variance.numpy() 63 | self.assertRelativelyEqual(var_sc, var_pt) 64 | 65 | entropy_sc = sc.entropy 66 | entropy_pt = pt.entropy.numpy() 67 | self.assertRelativelyEqual(entropy_sc, entropy_pt) 68 | 69 | N = 10 70 | for i in range(N): 71 | p = i / (N - 1) 72 | x = a + (b - a) * p 73 | 74 | cdf_sc = sc.cdf(x) 75 | cdf_pt = float(pt.cdf(torch.tensor(x))) 76 | self.assertRelativelyEqual(cdf_sc, cdf_pt) 77 | 78 | log_prob_sc = sc.log_prob(x) 79 | log_prob_pt = float(pt.log_prob(torch.tensor(x))) 80 | self.assertRelativelyEqual(log_prob_sc, log_prob_pt) 81 | 82 | if do_icdf: 83 | icdf_sc = sc.icdf(p) 84 | icdf_pt = float(pt.icdf(torch.tensor(p))) 85 | self.assertRelativelyEqual(icdf_sc, icdf_pt, tol=1e-4, error=1e-3) 86 | 87 | def test_simple(self): 88 | self._test_numerical(0., 1., -2., 0.) 89 | self._test_numerical(0., 1., -2., 1.) 90 | self._test_numerical(0., 1., -2., 2.) 91 | self._test_numerical(0., 1., -1., 0.) 92 | self._test_numerical(0., 1., -1., 1.) 93 | self._test_numerical(0., 1., -1., 2.) 94 | self._test_numerical(0., 1., 0., 1.) 95 | self._test_numerical(0., 1., 0., 2.) 96 | self._test_numerical(1., 2., 1., 2.) 
97 | self._test_numerical(1., 2., 2., 4.) 98 | 99 | def test_precision(self): 100 | self._test_numerical(0., 1., 2., 3.) 101 | self._test_numerical(0., 1., 2., 4.) 102 | # self._test_numerical(0., 1., 2., 8.) # fails due to .icdf returning inf 103 | self._test_numerical(0., 1., 2., 8., do_icdf=False) 104 | self._test_numerical(0., 1., 2., 16., do_icdf=False) 105 | self._test_numerical(0., 1., 2., 32., do_icdf=False) 106 | self._test_numerical(0., 1., 2., 64., do_icdf=False) 107 | self._test_numerical(0., 1., 2., 128., do_icdf=False) 108 | self._test_numerical(0., 1., 2., 256., do_icdf=False) 109 | self._test_numerical(0., 1., 2., 512., do_icdf=False) 110 | 111 | def test_support(self): 112 | pt = TruncatedNormalPT(0., 1., -1., 2., validate_args=None) 113 | with self.assertRaises(ValueError) as e: 114 | pt.log_prob(torch.tensor(-10)) 115 | self.assertEqual(str(e.exception), 'The value argument must be within the support') 116 | 117 | def test_cuda(self): 118 | if not torch.cuda.is_available(): 119 | warn('Skipping CUDA tests') 120 | return 121 | loc = torch.tensor([0., 1.]).cuda() 122 | scale = torch.tensor([1., 2.]).cuda() 123 | a = torch.tensor([-1., -10.]).cuda() 124 | b = torch.tensor([0., 100.]).cuda() 125 | pt = TruncatedNormalPT(loc, scale, a, b, validate_args=None) 126 | s = pt.rsample() 127 | self.assertTrue(s.is_cuda) 128 | 129 | 130 | if __name__ == '__main__': 131 | unittest.main() 132 | -------------------------------------------------------------------------------- /TruncatedNormal.py: -------------------------------------------------------------------------------- 1 | import math 2 | from numbers import Number 3 | 4 | import torch 5 | from torch.distributions import Distribution, constraints 6 | from torch.distributions.utils import broadcast_all 7 | 8 | CONST_SQRT_2 = math.sqrt(2) 9 | CONST_INV_SQRT_2PI = 1 / math.sqrt(2 * math.pi) 10 | CONST_INV_SQRT_2 = 1 / math.sqrt(2) 11 | CONST_LOG_INV_SQRT_2PI = math.log(CONST_INV_SQRT_2PI) 12 | 
CONST_LOG_SQRT_2PI_E = 0.5 * math.log(2 * math.pi * math.e)


class TruncatedStandardNormal(Distribution):
    """
    Truncated Standard Normal distribution
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    The parent distribution is the zero-mean, unit-variance Normal; its
    density is restricted to the interval ``[a, b]`` and renormalised by
    ``Z = Phi(b) - Phi(a)``.

    Args:
        a: lower truncation bound (number or tensor).
        b: upper truncation bound (number or tensor); must be strictly
            greater than ``a`` element-wise.
        validate_args: forwarded to ``torch.distributions.Distribution``.

    Raises:
        ValueError: if the bounds have different dtypes after broadcasting,
            or if ``a >= b`` for any element.
    """

    arg_constraints = {
        'a': constraints.real,
        'b': constraints.real,
    }
    has_rsample = True

    def __init__(self, a, b, validate_args=None):
        self.a, self.b = broadcast_all(a, b)
        if isinstance(a, Number) and isinstance(b, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.a.size()
        super(TruncatedStandardNormal, self).__init__(batch_shape, validate_args=validate_args)
        if self.a.dtype != self.b.dtype:
            raise ValueError('Truncation bounds types are different')
        if bool((self.a >= self.b).any()):
            raise ValueError('Incorrect truncation range')
        eps = torch.finfo(self.a.dtype).eps
        # Open-interval limits for the uniform variates drawn in rsample().
        self._dtype_min_gt_0 = eps
        self._dtype_max_lt_1 = 1 - eps
        self._little_phi_a = self._little_phi(self.a)
        self._little_phi_b = self._little_phi(self.b)
        self._big_phi_a = self._big_phi(self.a)
        self._big_phi_b = self._big_phi(self.b)
        # Normalising constant; clamped so that log() and the divisions below
        # stay finite when both bounds sit far in the same tail.
        self._Z = (self._big_phi_b - self._big_phi_a).clamp_min(eps)
        self._log_Z = self._Z.log()
        # nan_to_num maps +/-inf bounds to the largest finite values of the
        # dtype (NaN is kept as NaN), so phi(x) * x evaluates to 0 instead of
        # 0 * inf = NaN at an infinite bound.
        little_phi_coeff_a = torch.nan_to_num(self.a, nan=math.nan)
        little_phi_coeff_b = torch.nan_to_num(self.b, nan=math.nan)
        self._lpbb_m_lpaa_d_Z = (
            self._little_phi_b * little_phi_coeff_b - self._little_phi_a * little_phi_coeff_a
        ) / self._Z
        self._mean = -(self._little_phi_b - self._little_phi_a) / self._Z
        self._variance = 1 - self._lpbb_m_lpaa_d_Z - self._mean ** 2
        self._entropy = CONST_LOG_SQRT_2PI_E + self._log_Z - 0.5 * self._lpbb_m_lpaa_d_Z

    @constraints.dependent_property
    def support(self):
        """Closed interval ``[a, b]`` on which the density is defined."""
        return constraints.interval(self.a, self.b)

    @property
    def mean(self):
        """Mean of the truncated distribution."""
        return self._mean

    @property
    def variance(self):
        """Variance of the truncated distribution."""
        return self._variance

    @property
    def entropy(self):
        """Differential entropy (exposed as a property, not a method)."""
        return self._entropy

    @property
    def auc(self):
        """Parent-Normal probability mass inside the range, ``Phi(b) - Phi(a)``."""
        return self._Z

    @staticmethod
    def _little_phi(x):
        """Standard Normal probability density function."""
        return (-(x ** 2) * 0.5).exp() * CONST_INV_SQRT_2PI

    @staticmethod
    def _big_phi(x):
        """Standard Normal cumulative distribution function."""
        return 0.5 * (1 + (x * CONST_INV_SQRT_2).erf())

    @staticmethod
    def _inv_big_phi(x):
        """Standard Normal quantile function (inverse of ``_big_phi``)."""
        return CONST_SQRT_2 * (2 * x - 1).erfinv()

    @staticmethod
    def _clamp_to_range(x, low, high):
        """Clamp ``x`` element-wise into ``[low, high]`` (tensor bounds)."""
        # Binary min/max keep the autograd path of ``x`` wherever it is
        # already inside the range and propagate NaN unchanged.
        return torch.max(torch.min(x, high.to(x.dtype)), low.to(x.dtype))

    def _std_cdf(self, value):
        """CDF of a value given in standardised units; no validation."""
        return ((self._big_phi(value) - self._big_phi_a) / self._Z).clamp(0, 1)

    def _std_icdf(self, value):
        """Quantile in standardised units; no validation and no clamping."""
        return self._inv_big_phi(self._big_phi_a + value * self._Z)

    def _std_log_prob(self, value):
        """Log-density of a value given in standardised units; no validation."""
        return CONST_LOG_INV_SQRT_2PI - self._log_Z - (value ** 2) * 0.5

    def cdf(self, value):
        """Cumulative distribution function evaluated at ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        return self._std_cdf(value)

    def icdf(self, value):
        """
        Quantile function for probabilities ``value`` in ``[0, 1]``.

        The result is clamped to ``[a, b]``: near the ends of the range the
        erfinv-based inverse can saturate and return +/-inf or a value
        slightly outside the support.
        """
        return self._clamp_to_range(self._std_icdf(value), self.a, self.b)

    def log_prob(self, value):
        """Log of the probability density evaluated at ``value``."""
        if self._validate_args:
            self._validate_sample(value)
        return self._std_log_prob(value)

    def rsample(self, sample_shape=torch.Size()):
        """Draw reparameterised samples by inverse-CDF transform of uniforms."""
        shape = self._extended_shape(sample_shape)
        # Draw in the distribution's own dtype: the eps-based limits above
        # are computed for self.a.dtype and would be rounded away otherwise.
        p = torch.empty(shape, dtype=self.a.dtype, device=self.a.device)
        p.uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1)
        return self.icdf(p)


class TruncatedNormal(TruncatedStandardNormal):
    """
    Truncated Normal distribution
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Wraps ``TruncatedStandardNormal`` with the ``loc`` and ``scale`` of the
    parent Normal. ``a`` and ``b`` are given in the original (non
    standardised) coordinates. After construction ``self.a`` / ``self.b``
    hold the *standardised* bounds ``(a - loc) / scale`` and
    ``(b - loc) / scale``, while ``support`` reports the original ``[a, b]``.

    Args:
        loc: mean of the parent Normal.
        scale: standard deviation of the parent Normal.
        a: lower truncation bound in original coordinates.
        b: upper truncation bound in original coordinates.
        validate_args: forwarded to ``torch.distributions.Distribution``.
    """

    has_rsample = True

    def __init__(self, loc, scale, a, b, validate_args=None):
        self.loc, self.scale, a, b = broadcast_all(loc, scale, a, b)
        # Keep the caller's bounds: they define the support of this
        # distribution and are used to clamp icdf() output.
        self._non_std_a = a
        self._non_std_b = b
        std_a = (a - self.loc) / self.scale
        std_b = (b - self.loc) / self.scale
        super(TruncatedNormal, self).__init__(std_a, std_b, validate_args=validate_args)
        self._log_scale = self.scale.log()
        self._mean = self._mean * self.scale + self.loc
        self._variance = self._variance * self.scale ** 2
        self._entropy = self._entropy + self._log_scale

    @constraints.dependent_property
    def support(self):
        """Closed interval ``[a, b]`` in original (non standardised) units."""
        return constraints.interval(self._non_std_a, self._non_std_b)

    def _to_std_rv(self, value):
        """Map a value from original to standardised coordinates."""
        return (value - self.loc) / self.scale

    def _from_std_rv(self, value):
        """Map a value from standardised to original coordinates."""
        return value * self.scale + self.loc

    def cdf(self, value):
        """Cumulative distribution function at ``value`` (original units)."""
        if self._validate_args:
            self._validate_sample(value)
        return self._std_cdf(self._to_std_rv(value))

    def icdf(self, value):
        """
        Quantile function for probabilities ``value`` in ``[0, 1]``.

        The result is clamped to the original ``[a, b]`` so that it always
        lies inside ``support``.
        """
        raw = self._from_std_rv(self._std_icdf(value))
        return self._clamp_to_range(raw, self._non_std_a, self._non_std_b)

    def log_prob(self, value):
        """Log of the probability density at ``value`` (original units)."""
        if self._validate_args:
            self._validate_sample(value)
        return self._std_log_prob(self._to_std_rv(value)) - self._log_scale