├── LICENSE ├── README.md ├── requirements.txt └── sqrtm.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Steven Cheng-Xian Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Matrix square root for PyTorch 2 | 3 | A PyTorch function to compute the square root of a matrix with gradient support. 4 | The input matrix is assumed to be positive definite as matrix square root 5 | is not differentiable for matrices with zero eigenvalues. 6 | 7 | 8 | ## Dependency 9 | 10 | * [PyTorch](http://pytorch.org/) >= 1.0 11 | * [NumPy](http://www.numpy.org/) 12 | * [SciPy](https://www.scipy.org/) 13 | 14 | ## Example 15 | 16 | ```python 17 | import torch 18 | from sqrtm import sqrtm 19 | 20 | k = torch.randn(20, 10) 21 | # Create a (hopefully) positive definite matrix 22 | pd_mat = (k.t().matmul(k)).requires_grad_() 23 | sqrt_mat = sqrtm(pd_mat) 24 | sqrt_mat.sum().backward() 25 | ``` 26 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.14.5 2 | scipy>=1.1.0 3 | torch>=1.0.0 4 | -------------------------------------------------------------------------------- /sqrtm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | import numpy as np 4 | import scipy.linalg 5 | 6 | 7 | class MatrixSquareRoot(Function): 8 | """Square root of a positive definite matrix. 9 | 10 | NOTE: matrix square root is not differentiable for matrices with 11 | zero eigenvalues. 12 | """ 13 | @staticmethod 14 | def forward(ctx, input): 15 | m = input.detach().cpu().numpy().astype(np.float_) 16 | sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).to(input) 17 | ctx.save_for_backward(sqrtm) 18 | return sqrtm 19 | 20 | @staticmethod 21 | def backward(ctx, grad_output): 22 | grad_input = None 23 | if ctx.needs_input_grad[0]: 24 | sqrtm, = ctx.saved_tensors 25 | sqrtm = sqrtm.data.cpu().numpy().astype(np.float_) 26 | gm = grad_output.data.cpu().numpy().astype(np.float_) 27 | 28 | # Given a positive semi-definite matrix X, 29 | # since X = X^{1/2}X^{1/2}, we can compute the gradient of the 30 | # matrix square root dX^{1/2} by solving the Sylvester equation: 31 | # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}). 32 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 33 | 34 | grad_input = torch.from_numpy(grad_sqrtm).to(grad_output) 35 | return grad_input 36 | 37 | 38 | sqrtm = MatrixSquareRoot.apply 39 | 40 | 41 | def main(): 42 | from torch.autograd import gradcheck 43 | k = torch.randn(20, 10).double() 44 | # Create a positive definite matrix 45 | pd_mat = (k.t().matmul(k)).requires_grad_() 46 | test = gradcheck(sqrtm, (pd_mat,)) 47 | print(test) 48 | 49 | 50 | if __name__ == '__main__': 51 | main() 52 | --------------------------------------------------------------------------------