├── .gitignore ├── LICENSE ├── README.md ├── deep_evidential_regression_loss_pytorch ├── __init__.py ├── loss.py └── paper_loss.py ├── examples ├── README.md ├── example_dataset_experiment_1.ipynb └── plot_loss_function.ipynb ├── setup.py └── tests └── README.md /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | 3 |
4 | 5 | # Deep Evidential Regression Loss Function 6 | 7 | [![Paper](http://img.shields.io/badge/paper-arxiv.1910.02600-B31B1B.svg)](https://arxiv.org/abs/1910.02600) 8 | [![Conference](http://img.shields.io/badge/ICLR-2020-4b44ce.svg)](https://openreview.net/forum?id=S1eSoeSYwr) 9 | 10 | 13 |
14 | 15 | ## Description 16 | The paper "Deep Evidential Uncertainty/Regression" was submitted to ICLR, where 17 | it was rejected[1]. The idea is in line with the work of Sensoy et al.[2] and Malinin & Gales[3]. 18 | It was rejected because of a lack of experiments and its similarity to ideas in Malinin's thesis. 19 | The goal is to implement the loss function and validate the results. 20 | 21 | 22 | ## Installation 23 | 24 | ### Typical Install 25 | ``` 26 | pip install git+https://github.com/deebuls/deep_evidential_regression_loss_pytorch 27 | ``` 28 | 29 | ### Development 30 | ``` 31 | git clone https://github.com/deebuls/deep_evidential_regression_loss_pytorch 32 | cd deep_evidential_regression_loss_pytorch 33 | pip install -e .[dev] 34 | ``` 35 | 36 | Tests can then be run from the root of the project using: 37 | ``` 38 | nosetests 39 | ``` 40 | 41 | ## Usage 42 | 43 | To use this code, import `EvidentialLossSumOfSquares` and create a loss function instance. `loss.py` 44 | implements the evidential loss function; `paper_loss.py` implements `PaperEvidentialLossSumOfSquares`, which follows the paper's formulation. 45 | 46 | See the `examples` directory for detailed usage examples. 47 | 48 | ## ToDo 49 | 50 | 1. Different variations of the loss (NLL, with log(alpha, beta, lambda)) 51 | 2. Support image outputs, as in the case of a VAE 52 | 3. Examples 53 | 4. Test cases 54 | 55 | 56 | ## Abstract 57 | Deterministic neural networks (NNs) are increasingly being deployed in safety 58 | critical domains, where calibrated, robust and efficient measures of 59 | uncertainty are crucial. While it is possible to train regression networks to 60 | output the parameters of a probability distribution by maximizing a Gaussian 61 | likelihood function, the resulting model remains oblivious to the underlying 62 | confidence of its predictions. In this paper, we propose a novel method for 63 | training deterministic NNs to not only estimate the desired target but also the 64 | associated evidence in support of that target.
We accomplish this by placing 65 | evidential priors over our original Gaussian likelihood function and training 66 | our NN to infer the hyperparameters of our evidential distribution. We impose 67 | priors during training such that the model is penalized when its predicted 68 | evidence is not aligned with the correct output. Thus the model estimates not 69 | only the probabilistic mean and variance of our target but also the underlying 70 | uncertainty associated with each of those parameters. We observe that our 71 | evidential regression method learns well-calibrated measures of uncertainty on 72 | various benchmarks, scales to complex computer vision tasks, and is robust to 73 | adversarial input perturbations. 74 | 75 | ## References 76 | * [1] https://openreview.net/forum?id=S1eSoeSYwr&noteId=78WcDK50Bi 77 | * [2] M. Sensoy, et al. "Evidential deep learning to quantify classification uncertainty." NeurIPS. 2018. 78 | * [3] A. Malinin, et al. Predictive uncertainty estimation via prior networks. NeurIPS 2018. 79 | -------------------------------------------------------------------------------- /deep_evidential_regression_loss_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import EvidentialLossSumOfSquares 2 | from .paper_loss import PaperEvidentialLossSumOfSquares 3 | -------------------------------------------------------------------------------- /deep_evidential_regression_loss_pytorch/loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implements the evidential loss using Normal Inverse Gamma Distribution 15 | Use this function when you want to model your regression output as a 16 | normal inverse gamma distribution. 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn as nn 26 | 27 | class EvidentialLossSumOfSquares(nn.Module): 28 | """The evidential loss function on a matrix. 29 | 30 | This class is implemented with slight modifications from the paper. The major 31 | change is in the regularizer parameter mentioned in the paper. The regularizer 32 | mentioned in the paper didnot give the required results, so we modified it 33 | with the KL divergence regularizer from the paper. In orderto overcome the problem 34 | that KL divergence are missing near zero so we add the minimum values to alpha, 35 | beta and lambda and compare distance with NIG(alpha=1.0, beta=0.1, lambda=1.0) 36 | 37 | This class only allows for rank-4 inputs for the output `targets`, and expectes 38 | `inputs` be of the form [mu, alpha, beta, lambda] 39 | 40 | alpha, beta and lambda needs to be positive values. 41 | """ 42 | 43 | def __init__(self, debug=False, return_all=False): 44 | """Sets up loss function. 
45 | 46 | Args: 47 | debug: When set to 'true' prints all the intermittent values 48 | return_all: When set to 'true' returns all loss values without taking average 49 | 50 | """ 51 | super(EvidentialLossSumOfSquares, self).__init__() 52 | 53 | self.debug = debug 54 | self.return_all_values = return_all 55 | self.MAX_CLAMP_VALUE = 5.0 # Max you can go is 85 because exp(86) is nan Now exp(5.0) is 143 which is max of a,b and l 56 | 57 | def kl_divergence_nig(self, mu1, mu2, alpha_1, beta_1, lambda_1): 58 | alpha_2 = torch.ones_like(mu1)*1.0 59 | beta_2 = torch.ones_like(mu1)*0.1 60 | lambda_2 = torch.ones_like(mu1)*1.0 61 | 62 | t1 = 0.5 * (alpha_1/beta_1) * ((mu1 - mu2)**2) * lambda_2 63 | #t1 = 0.5 * (alpha_1/beta_1) * (torch.abs(mu1 - mu2)) * lambda_2 64 | t2 = 0.5*lambda_2/lambda_1 65 | t3 = alpha_2*torch.log(beta_1/beta_2) 66 | t4 = -torch.lgamma(alpha_1) + torch.lgamma(alpha_2) 67 | t5 = (alpha_1-alpha_2)*torch.digamma(alpha_1) 68 | t6 = -(beta_1 - beta_2)*(alpha_1/beta_1) 69 | return (t1+t2-0.5+t3+t4+t5+t6) 70 | 71 | def forward(self, inputs, targets): 72 | """ Implements the loss function 73 | 74 | Args: 75 | inputs: The output of the neural network. inputs has 4 dimension 76 | in the format [mu, alpha, beta, lambda]. 
Must be a tensor of 77 | floats 78 | targets: The expected output 79 | 80 | Returns: 81 | Based on the `return_all` it will return mean loss of batch or individual loss 82 | 83 | """ 84 | assert torch.is_tensor(inputs) 85 | assert torch.is_tensor(targets) 86 | assert (inputs[:,1] > 0).all() 87 | assert (inputs[:,2] > 0).all() 88 | assert (inputs[:,3] > 0).all() 89 | 90 | targets = targets.view(-1) 91 | y = inputs[:,0].view(-1) #first column is mu,delta, predicted value 92 | a = inputs[:,1].view(-1) + 1.0 #alpha 93 | b = inputs[:,2].view(-1) + 0.1 #beta to avoid zero 94 | l = inputs[:,3].view(-1) + 1.0 #lamda 95 | 96 | if self.debug: 97 | print("a :", a) 98 | print("b :", b) 99 | print("l :", l) 100 | 101 | J1 = torch.lgamma(a - 0.5) 102 | J2 = -torch.log(torch.tensor([4.0])) 103 | J3 = -torch.lgamma(a) 104 | J4 = -torch.log(l) 105 | J5 = -0.5*torch.log(b) 106 | J6 = torch.log(2*b*(1 + l) + (2*a - 1)*l*(y-targets)**2) 107 | 108 | if self.debug: 109 | print("lgama(a - 0.5) :", J1) 110 | print("log(4):", J2) 111 | print("lgama(a) :", J3) 112 | print("log(l) :", J4) 113 | print("log( ---- ) :", J6) 114 | 115 | J = J1 + J2 + J3 + J4 + J5 + J6 116 | #Kl_divergence = torch.abs(y - targets) * (2*a + l)/b ######## ????? 
117 | #Kl_divergence = ((y - targets)**2) * (2*a + l) 118 | #Kl_divergence = torch.abs(y - targets) * (2*a + l) 119 | #Kl_divergence = 0.0 120 | #Kl_divergence = (torch.abs(y - targets) * (a-1) * l)/b 121 | Kl_divergence = self.kl_divergence_nig(y, targets, a, b, l) 122 | 123 | if self.debug: 124 | print ("KL ",Kl_divergence.data.numpy()) 125 | loss = torch.exp(J) + Kl_divergence 126 | 127 | if self.debug: 128 | print ("loss :", loss.mean()) 129 | 130 | 131 | if self.return_all_values: 132 | ret_loss = loss 133 | else: 134 | ret_loss = loss.mean() 135 | #if torch.isnan(ret_loss): 136 | # ret_loss.item() = self.prev_loss + 10 137 | #else: 138 | # self.prev_loss = ret_loss.item() 139 | 140 | return ret_loss 141 | -------------------------------------------------------------------------------- /deep_evidential_regression_loss_pytorch/paper_loss.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Implements the evidential loss using Normal Inverse Gamma Distribution 15 | Use this function when you want to model your regression output as a 16 | normal inverse gamma distribution. 
17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import numpy as np 24 | import torch 25 | import torch.nn as nn 26 | 27 | class PaperEvidentialLossSumOfSquares(nn.Module): 28 | """The evidential loss function on a matrix. 29 | 30 | This class is implemented same as the paper. 31 | 32 | This class only allows for rank-4 inputs for the output `targets`, and expectes 33 | `inputs` be of the form [mu, alpha, beta, lambda] 34 | 35 | alpha, beta and lambda needs to be positive values. 36 | """ 37 | 38 | def __init__(self, debug=False, return_all=False): 39 | """Sets up loss function. 40 | 41 | Args: 42 | debug: When set to 'true' prints all the intermittent values 43 | return_all: When set to 'true' returns all loss values without taking average 44 | 45 | """ 46 | super(PaperEvidentialLossSumOfSquares, self).__init__() 47 | 48 | self.debug = debug 49 | self.return_all_values = return_all 50 | self.MAX_CLAMP_VALUE = 5.0 # Max you can go is 85 because exp(86) is nan Now exp(5.0) is 143 which is max of a,b and l 51 | 52 | 53 | def forward(self, inputs, targets): 54 | """ Implements the loss function 55 | 56 | Args: 57 | inputs: The output of the neural network. inputs has 4 dimension 58 | in the format [mu, alpha, beta, lambda]. 
Must be a tensor of 59 | floats 60 | targets: The expected output 61 | 62 | Returns: 63 | Based on the `return_all` it will return mean loss of batch or individual loss 64 | 65 | """ 66 | assert torch.is_tensor(inputs) 67 | assert torch.is_tensor(targets) 68 | assert (inputs[:,1] > 0).all() 69 | assert (inputs[:,2] > 0).all() 70 | assert (inputs[:,3] > 0).all() 71 | 72 | targets = targets.view(-1) 73 | y = inputs[:,0].view(-1) #first column is mu,delta, predicted value 74 | a = inputs[:,1].view(-1) #alpha 75 | b = inputs[:,2].view(-1) #beta to avoid zero 76 | l = inputs[:,3].view(-1) #lamda 77 | 78 | if self.debug: 79 | print("a :", a) 80 | print("b :", b) 81 | print("l :", l) 82 | 83 | #machine epsilon for safe cases 84 | machine_epsilon = torch.tensor(np.finfo(np.float32).eps) 85 | safe_a = torch.max(machine_epsilon, a) 86 | safe_b = torch.max(machine_epsilon, b) 87 | safe_l = torch.max(machine_epsilon, l) 88 | J1 = torch.lgamma(torch.max(a - 0.5)) 89 | J2 = -torch.log(torch.tensor([4.0])) 90 | J3 = -torch.lgamma(safe_a) 91 | J4 = -torch.log(safe_l) 92 | J5 = -0.5*torch.log(safe_b) 93 | J6 = torch.log(torch.max(machine_epsilon, 2*b*(1 + l) + (2*a - 1)*l*(y-targets)**2)) 94 | 95 | if self.debug: 96 | print("lgama(a - 0.5) :", J1) 97 | print("log(4):", J2) 98 | print("lgama(a) :", J3) 99 | print("log(l) :", J4) 100 | print("log( ---- ) :", J6) 101 | 102 | J = J1 + J2 + J3 + J4 + J5 + J6 103 | #Kl_divergence = torch.abs(y - targets) * (2*a + l)/b ######## ????? 
104 | #Kl_divergence = ((y - targets)**2) * (2*safe_a + safe_l) 105 | #Kl_divergence = torch.abs(y - targets) * (2*a + l) 106 | #Kl_divergence = 0.0 107 | #Kl_divergence = (torch.abs(y - targets) * (a-1) * l)/b 108 | Kl_divergence = torch.norm(y - targets)*(2*safe_a + safe_l) 109 | 110 | if self.debug: 111 | print ("KL ",Kl_divergence.data.numpy()) 112 | loss = torch.exp(J) + Kl_divergence 113 | 114 | if self.debug: 115 | print ("loss :", loss.mean()) 116 | 117 | 118 | if self.return_all_values: 119 | ret_loss = loss 120 | else: 121 | ret_loss = loss.mean() 122 | #if torch.isnan(ret_loss): 123 | # ret_loss.item() = self.prev_loss + 10 124 | #else: 125 | # self.prev_loss = ret_loss.item() 126 | 127 | return ret_loss 128 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deebuls/deep_evidential_regression_loss_pytorch/373dbd492f47fb063b0aaaa1a62e1fe7328a0812/examples/README.md -------------------------------------------------------------------------------- /examples/plot_loss_function.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "plot_loss_function.ipynb", 7 | "provenance": [], 8 | "authorship_tag": "ABX9TyPZlZ02ifV0YVljczxIFYpH", 9 | "include_colab_link": true 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | } 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "metadata": { 30 | "id": "9soCgj4iCiGF", 31 | "outputId": "ad06ba66-21ed-4071-f1cb-c273a545ec08", 32 | "colab": { 33 | "base_uri": "https://localhost:8080/" 34 | } 35 | }, 36 | "source": [ 37 | 
"!pip install git+https://github.com/deebuls/deep_evidential_regression_loss_pytorch" 38 | ], 39 | "execution_count": 1, 40 | "outputs": [ 41 | { 42 | "output_type": "stream", 43 | "text": [ 44 | "Collecting git+https://github.com/deebuls/deep_evidential_regression_loss_pytorch\n", 45 | " Cloning https://github.com/deebuls/deep_evidential_regression_loss_pytorch to /tmp/pip-req-build-_6zgu7no\n", 46 | " Running command git clone -q https://github.com/deebuls/deep_evidential_regression_loss_pytorch /tmp/pip-req-build-_6zgu7no\n", 47 | "Requirement already satisfied: torch>=1.3.1 in /usr/local/lib/python3.6/dist-packages (from deep-evidential-regression-loss-pytorch==0.0.1) (1.7.0+cu101)\n", 48 | "Collecting torch-dct\n", 49 | " Downloading https://files.pythonhosted.org/packages/8f/20/6f6280ed77a0382ae6226c5250c02f64924b8fc73d9aa7d73b9c6b3ee6a5/torch_dct-0.1.5-py3-none-any.whl\n", 50 | "Requirement already satisfied: numpy>=1.15.4 in /usr/local/lib/python3.6/dist-packages (from deep-evidential-regression-loss-pytorch==0.0.1) (1.18.5)\n", 51 | "Requirement already satisfied: scipy>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from deep-evidential-regression-loss-pytorch==0.0.1) (1.4.1)\n", 52 | "Requirement already satisfied: absl-py>=0.1.9 in /usr/local/lib/python3.6/dist-packages (from deep-evidential-regression-loss-pytorch==0.0.1) (0.10.0)\n", 53 | "Requirement already satisfied: mpmath>=1.1.0 in /usr/local/lib/python3.6/dist-packages (from deep-evidential-regression-loss-pytorch==0.0.1) (1.1.0)\n", 54 | "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.3.1->deep-evidential-regression-loss-pytorch==0.0.1) (3.7.4.3)\n", 55 | "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.3.1->deep-evidential-regression-loss-pytorch==0.0.1) (0.7)\n", 56 | "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from 
torch>=1.3.1->deep-evidential-regression-loss-pytorch==0.0.1) (0.16.0)\n", 57 | "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from absl-py>=0.1.9->deep-evidential-regression-loss-pytorch==0.0.1) (1.15.0)\n", 58 | "Building wheels for collected packages: deep-evidential-regression-loss-pytorch\n", 59 | " Building wheel for deep-evidential-regression-loss-pytorch (setup.py) ... \u001b[?25l\u001b[?25hdone\n", 60 | " Created wheel for deep-evidential-regression-loss-pytorch: filename=deep_evidential_regression_loss_pytorch-0.0.1-cp36-none-any.whl size=9547 sha256=a103d7cfdf0ead83de6b64105baefc37afe37c2286ad61fe1dca2d3f3695ca44\n", 61 | " Stored in directory: /tmp/pip-ephem-wheel-cache-vm98g9aj/wheels/55/4b/1c/769c7b66ab5dd04cef9dcb1e43a4797d9d467680ace7b95222\n", 62 | "Successfully built deep-evidential-regression-loss-pytorch\n", 63 | "Installing collected packages: torch-dct, deep-evidential-regression-loss-pytorch\n", 64 | "Successfully installed deep-evidential-regression-loss-pytorch-0.0.1 torch-dct-0.1.5\n" 65 | ], 66 | "name": "stdout" 67 | } 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "metadata": { 73 | "id": "rHUekTvqC8kf" 74 | }, 75 | "source": [ 76 | "from deep_evidential_regression_loss_pytorch import EvidentialLossSumOfSquares" 77 | ], 78 | "execution_count": 2, 79 | "outputs": [] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "metadata": { 84 | "id": "xD3M8EzgwfaN" 85 | }, 86 | "source": [ 87 | "import numpy as np\n", 88 | "import torch\n", 89 | "import pandas as pd\n", 90 | "import seaborn as sns" 91 | ], 92 | "execution_count": 5, 93 | "outputs": [] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "metadata": { 98 | "id": "UUIWqzG60yIa", 99 | "outputId": "f4ccdf62-7ca2-4017-f17b-8a6e7388e084", 100 | "colab": { 101 | "base_uri": "https://localhost:8080/" 102 | } 103 | }, 104 | "source": [ 105 | "criterion = EvidentialLossSumOfSquares(return_all=True)\n", 106 | "alphas = [1.0, 1.5, 2.0]\n", 107 | "betas = [0.1, 0.5, 
1.0]\n", 108 | "lambdas = [1.0, 1.5, 2.0]\n", 109 | "\n", 110 | "diff = np.linspace(-1, 1, 100)\n", 111 | "val = np.zeros(100)\n", 112 | "val = val + diff\n", 113 | "\n", 114 | "all_data_loss = []\n", 115 | "\n", 116 | "for aa in alphas:\n", 117 | " for bb in betas:\n", 118 | " for ll in lambdas:\n", 119 | " temp = np.vstack((val, np.ones(100)*aa, np.ones(100)*bb, np.ones(100)*ll)).T\n", 120 | " inputs = torch.tensor(temp)\n", 121 | " targets = torch.zeros((100,1))\n", 122 | " loss_temp = criterion(inputs,targets).data.numpy()\n", 123 | " one_iter = np.hstack((temp, loss_temp.reshape(100,1)))\n", 124 | " all_data_loss.extend(one_iter)\n", 125 | "\n", 126 | "all_data_loss = pd.DataFrame(all_data_loss, columns=['residual', 'alpha', 'beta', 'lambda', 'loss'])\n", 127 | "all_data_loss.shape" 128 | ], 129 | "execution_count": 12, 130 | "outputs": [ 131 | { 132 | "output_type": "execute_result", 133 | "data": { 134 | "text/plain": [ 135 | "(2700, 5)" 136 | ] 137 | }, 138 | "metadata": { 139 | "tags": [] 140 | }, 141 | "execution_count": 12 142 | } 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "metadata": { 148 | "id": "JAOU6Oqy-MbG", 149 | "outputId": "f85cc4ee-c57b-43cf-86c0-f6ab69645344", 150 | "colab": { 151 | "base_uri": "https://localhost:8080/", 152 | "height": 455 153 | } 154 | }, 155 | "source": [ 156 | "g = sns.relplot(\n", 157 | " data=all_data_loss,\n", 158 | " x=\"residual\", y=\"loss\", col=\"alpha\", row=\"lambda\",\n", 159 | " hue=\"beta\",\n", 160 | " kind=\"line\", palette=\"crest\", linewidth=4, zorder=5,\n", 161 | " height=2, aspect=1.5, legend=True\n", 162 | " )" 163 | ], 164 | "execution_count": 14, 165 | "outputs": [ 166 | { 167 | "output_type": "display_data", 168 | "data": { 169 | "image/png": "\n", 170 | "text/plain": [ 171 | "
" 172 | ] 173 | }, 174 | "metadata": { 175 | "tags": [], 176 | "needs_background": "light" 177 | } 178 | } 179 | ] 180 | }, 181 | { 182 | "cell_type": "code", 183 | "metadata": { 184 | "id": "W32GRLv08nwq" 185 | }, 186 | "source": [ 187 | "" 188 | ], 189 | "execution_count": null, 190 | "outputs": [] 191 | } 192 | ] 193 | } -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import re 4 | 5 | from setuptools import find_packages 6 | from setuptools import setup 7 | 8 | 9 | def read(filename): 10 | filename = os.path.join(os.path.dirname(__file__), filename) 11 | text_type = type(u"") 12 | with io.open(filename, mode="r", encoding='utf-8') as fd: 13 | return re.sub(text_type(r':[a-z]+:`~?(.*?)`'), text_type(r'``\1``'), fd.read()) 14 | 15 | 16 | requirements = [ 17 | 'torch>=1.3.1', 18 | 'torch-dct', 19 | 'numpy>=1.15.4', 20 | 'scipy>=1.1.0', 21 | 'absl-py>=0.1.9', 22 | 'mpmath>=1.1.0', 23 | ] 24 | 25 | requirements_dev = [ 26 | 'Pillow', 27 | 'nose' 28 | ] 29 | 30 | 31 | setup( 32 | name="deep_evidential_regression_loss_pytorch", 33 | version="0.0.1", 34 | url="https://github.com/deebuls/deep_evidential_regression_loss_pytorch", 35 | license='Apache 2.0', 36 | author="Deebul S. 
Nair", 37 | author_email="deebul.nair@h-brs.de", 38 | description="A Loss function which predicts posterior distribution for regression problems", 39 | long_description=read("README.md"), 40 | #package_dir={'':'deep_evidential_regression_loss_pytorch'}, # Optional 41 | #packages=find_packages(where=('deep_evidential_regression_loss_pytorch'),exclude=('tests',)), 42 | packages=find_packages(exclude=('tests',)), 43 | install_requires=requirements, 44 | extras_require={ 45 | 'dev': requirements_dev 46 | }, 47 | ) 48 | 49 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/deebuls/deep_evidential_regression_loss_pytorch/373dbd492f47fb063b0aaaa1a62e1fe7328a0812/tests/README.md --------------------------------------------------------------------------------