├── LICENSE ├── Readme.md ├── docs ├── Makefile ├── make.bat └── source │ ├── conf.py │ ├── examples │ ├── backprojection.ipynb │ ├── index.rst │ └── pga.ipynb │ ├── index.rst │ ├── modules.rst │ └── torchbp.rst ├── examples ├── Readme.md ├── sar_polar_to_cart.py ├── sar_process_safetensor.py └── sar_process_safetensor_gpga.py ├── pyproject.toml ├── requirements.txt ├── setup.py ├── tests ├── benchmark_backprojection.py └── test_torchbp.py └── torchbp ├── __init__.py ├── autofocus.py ├── csrc ├── cuda │ ├── std_complex.h │ └── torchbp.cu └── torchbp.cpp ├── ops.py └── util.py /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2025 Henrik Forstén 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
8 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | # Torchbp 2 | 3 | Fast C++ PyTorch extension library for differentiable synthetic aperture radar image formation and autofocus on CPU and GPU. 4 | 5 | Only Nvidia GPUs are supported. Currently, some operations are not supported on CPU. 6 | 7 | On an RTX 3090 Ti, backprojection on a polar grid achieves 225 billion backprojections/s. 8 | 9 | ## Installation 10 | 11 | Tested with CUDA version 12.1; some newer versions might cause build issues. 12 | 13 | ### From source 14 | 15 | ```bash 16 | git clone https://github.com/Ttl/torchbp.git 17 | cd torchbp 18 | pip install . 19 | ``` 20 | 21 | ## Documentation 22 | 23 | API documentation and examples can be built with Sphinx. 24 | 25 | ```bash 26 | pip install .[docs] 27 | cd docs 28 | make html 29 | ``` 30 | 31 | Open `docs/build/html/index.html`. 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | sphinx-apidoc -f -o "$(SOURCEDIR)" ../torchbp 21 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 22 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
2 | # 3 | # For the full list of built-in configuration values, see the documentation: 4 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 5 | 6 | # -- Project information ----------------------------------------------------- 7 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 8 | import sys 9 | import os 10 | project = 'torchbp' 11 | copyright = '2024, Henrik Forstén' 12 | author = 'Henrik Forstén' 13 | release = '0.0.1' 14 | 15 | sys.path.insert(0, os.path.abspath('../../')) 16 | 17 | import torchbp 18 | 19 | # -- General configuration --------------------------------------------------- 20 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 21 | 22 | extensions = [ 23 | 'sphinx.ext.autodoc', 24 | 'sphinx.ext.autosectionlabel', 25 | 'sphinx.ext.autosummary', 26 | 'sphinx.ext.napoleon', 27 | 'sphinx.ext.mathjax', 28 | 'sphinx.ext.viewcode', 29 | 'sphinx.ext.intersphinx', 30 | 'sphinx_rtd_theme', 31 | 'nbsphinx', 32 | 'IPython.sphinxext.ipython_directive', 33 | 'IPython.sphinxext.ipython_console_highlighting', 34 | ] 35 | 36 | # Napoleon settings 37 | napoleon_google_docstring = False 38 | napoleon_numpy_docstring = True 39 | napoleon_include_init_with_doc = False 40 | napoleon_include_private_with_doc = False 41 | napoleon_include_special_with_doc = False 42 | napoleon_use_admonition_for_examples = False 43 | napoleon_use_admonition_for_notes = False 44 | napoleon_use_admonition_for_references = False 45 | napoleon_use_ivar = False 46 | napoleon_use_param = True 47 | napoleon_use_rtype = True 48 | 49 | # NBsphinx settings 50 | nbsphinx_execute = 'always' 51 | nbsphinx_allow_errors = True 52 | nbsphinx_kernel_name = 'python' 53 | numpydoc_show_class_members = False 54 | nbsphinx_timeout = 180 55 | 56 | # Autodoc settings 57 | autodoc_typehints = "description" 58 | 59 | # Add any paths that contain templates here, relative to this directory. 
60 | templates_path = ['_templates'] 61 | 62 | # The suffix of source filenames. 63 | source_suffix = '.rst' 64 | 65 | # The master toctree document. 66 | master_doc = 'index' 67 | 68 | #templates_path = ['_templates'] 69 | exclude_patterns = [] 70 | 71 | autosummary_generate = True 72 | 73 | # -- Options for HTML output ------------------------------------------------- 74 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 75 | 76 | html_theme = "sphinx_rtd_theme" 77 | html_static_path = ['_static'] 78 | -------------------------------------------------------------------------------- /docs/source/examples/backprojection.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "897bb39b-547f-4993-a49c-2018cc9648e3", 6 | "metadata": {}, 7 | "source": [ 8 | "# Backprojection image formation" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "3a958b90-c742-4fa2-aef9-52202b7c8aee", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torchbp\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from numpy import hamming" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "d247ef25-faf9-469d-bf68-54c716f41a37", 27 | "metadata": {}, 28 | "source": [ 29 | "Use CUDA if it's available, CPU otherwise" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "4062655e-b0dd-4f95-ad06-493ffb13445a", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "if torch.cuda.is_available():\n", 40 | " device = \"cuda\"\n", 41 | "else:\n", 42 | " device = \"cpu\"\n", 43 | "print(\"Device:\", device)" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "f0d43373-ef7f-4895-852b-d43941fe1201", 49 | "metadata": {}, 50 | "source": [ 51 | "Constant definitions" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | 
"execution_count": null, 57 | "id": "a31e5901-925c-49bb-839c-b1bbaabf9294", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "nr = 128 # Range points\n", 62 | "ntheta = 128 # Azimuth points\n", 63 | "nsweeps = 128 # Number of measurements\n", 64 | "fc = 6e9 # RF center frequency\n", 65 | "bw = 100e6 # RF bandwidth\n", 66 | "tsweep = 100e-6 # Sweep length\n", 67 | "fs = 1e6 # Sampling frequency\n", 68 | "nsamples = int(fs * tsweep) # Time domain samples per sweep\n", 69 | "\n", 70 | "# Imaging grid definition. Azimuth angle \"theta\" is sine of radians. 0.2 = 11.5 degrees.\n", 71 | "grid_polar = {\"r\": (90, 110), \"theta\": (-0.2, 0.2), \"nr\": nr, \"ntheta\": ntheta}" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "bed5dbac-3f81-4a93-a47f-72b9b5fcaa39", 77 | "metadata": {}, 78 | "source": [ 79 | "Define target and radar positions. There is one point target at 100 m distance and zero azimuth angle.\n", 80 | "For polar image formation radar motion should be in direction of Y-axis.\n", 81 | "If this is not the case positions should be rotated." 
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "id": "56b18c7f-303b-44e1-b4ec-e56eaea7fff7", 88 | "metadata": {}, 89 | "outputs": [], 90 | "source": [ 91 | "target_pos = torch.tensor([[100, 0, 0]], dtype=torch.float32, device=device)\n", 92 | "target_rcs = torch.tensor([[1]], dtype=torch.float32, device=device)\n", 93 | "pos = torch.zeros([nsweeps, 3], dtype=torch.float32, device=device)\n", 94 | "pos[:,1] = torch.linspace(-nsweeps/2, nsweeps/2, nsweeps) * 0.25 * 3e8 / fc\n", 95 | "pos[:,2] = 50 # Platform height" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "a4047b6f-d590-4ee9-a4fc-c3e13f7d770c", 101 | "metadata": {}, 102 | "source": [ 103 | "Generate synthetic radar data" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "73133328-4cdb-4cbb-b4cd-ac2ce81301ba", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "# Oversampling input data decreases interpolation errors\n", 114 | "oversample = 3\n", 115 | "\n", 116 | "with torch.no_grad():\n", 117 | " data = torchbp.util.generate_fmcw_data(target_pos, target_rcs, pos, fc, bw, tsweep, fs)\n", 118 | " # Apply windowing function in range direction\n", 119 | " w = torch.tensor(hamming(data.shape[-1])[None,:], dtype=torch.float32, device=device)\n", 120 | " data = torch.fft.fft(data * w, dim=-1, n=nsamples * oversample)\n", 121 | "\n", 122 | "data_db = 20*torch.log10(torch.abs(data)).detach()\n", 123 | "m = torch.max(data_db)\n", 124 | "\n", 125 | "plt.figure()\n", 126 | "plt.imshow(data_db.cpu().numpy(), origin=\"lower\", vmin=m-30, vmax=m, aspect=\"auto\")\n", 127 | "plt.xlabel(\"Range samples\")\n", 128 | "plt.ylabel(\"Azimuth samples\");" 129 | ] 130 | }, 131 | { 132 | "cell_type": "markdown", 133 | "id": "963bda9b-1f72-48f6-9ee1-d0ea54ff8cb9", 134 | "metadata": {}, 135 | "source": [ 136 | "Image formation.\n", 137 | "Hamming window was applied in range direction so low sidelobes in range are 
expected.\n", 138 | "Azimuth direction has no windowing function and high sidelobes (Highest -13 dB) are expected.\n", 139 | "Azimuth sidelobes could be decreased by windowing the input data also in the other dimension." 140 | ] 141 | }, 142 | { 143 | "cell_type": "code", 144 | "execution_count": null, 145 | "id": "ddd6b005-5d5f-432f-ba40-b4db21699609", 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "r_res = 3e8 / (2 * bw * oversample) # Range bin size in input data\n", 150 | "\n", 151 | "img = torchbp.ops.backprojection_polar_2d(data, grid_polar, fc, r_res, pos, dealias=True)\n", 152 | "img = img.squeeze() # Removes singular batch dimension\n", 153 | "\n", 154 | "img_db = 20*torch.log10(torch.abs(img)).detach()\n", 155 | "\n", 156 | "m = torch.max(img_db)\n", 157 | "\n", 158 | "extent = [*grid_polar[\"r\"], *grid_polar[\"theta\"]]\n", 159 | "\n", 160 | "plt.figure()\n", 161 | "plt.imshow(img_db.cpu().numpy().T, origin=\"lower\", vmin=m-30, vmax=m, extent=extent, aspect=\"auto\")\n", 162 | "plt.xlabel(\"Range (m)\")\n", 163 | "plt.ylabel(\"Angle (sin radians)\");" 164 | ] 165 | }, 166 | { 167 | "cell_type": "markdown", 168 | "id": "5d3d87a9-7c0b-4f46-bf05-873db9d7a336", 169 | "metadata": {}, 170 | "source": [ 171 | "Image entropy. Can be used as a loss function for optimization." 
172 | ] 173 | }, 174 | { 175 | "cell_type": "code", 176 | "execution_count": null, 177 | "id": "1b44fa95-3a4e-4b95-900b-3a89a8290d06", 178 | "metadata": {}, 179 | "outputs": [], 180 | "source": [ 181 | "entropy = torchbp.util.entropy(img)\n", 182 | "print(\"Entropy:\", entropy.item())" 183 | ] 184 | }, 185 | { 186 | "cell_type": "markdown", 187 | "id": "768dc700-737b-4787-924d-a734ff9e7832", 188 | "metadata": {}, 189 | "source": [ 190 | "Convert image to cartesian coordinates:" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "4d5ab162-2b7f-43d9-af03-1e23eb55f8db", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "# Origin of the polar coordinates\n", 201 | "origin = torch.mean(pos, axis=0)\n", 202 | "# Cartesian grid definition\n", 203 | "grid_cart = {\"x\": (90, 110), \"y\": (-10, 10), \"nx\": 128, \"ny\": 128}\n", 204 | "\n", 205 | "img_cart = torchbp.ops.polar_to_cart_linear(img, origin, grid_polar, grid_cart, fc, rotation=0)\n", 206 | "\n", 207 | "img_db = 20*torch.log10(torch.abs(img_cart)).detach()\n", 208 | "\n", 209 | "m = torch.max(img_db)\n", 210 | "\n", 211 | "extent = [*grid_cart[\"x\"], *grid_cart[\"y\"]]\n", 212 | "\n", 213 | "plt.figure()\n", 214 | "plt.imshow(img_db.cpu().numpy().T, origin=\"lower\", vmin=m-30, vmax=m, extent=extent, aspect=\"equal\")\n", 215 | "plt.xlabel(\"Range (m)\")\n", 216 | "plt.ylabel(\"Cross-range (m)\");" 217 | ] 218 | }, 219 | { 220 | "cell_type": "markdown", 221 | "id": "d5dfb0a3-61fd-41e1-aa7c-ef1d2e3cee5c", 222 | "metadata": {}, 223 | "source": [ 224 | "Backprojection directly onto Cartesian grid" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "7d470981-89c2-4352-802e-e41cb4581ea4", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "img_cart2 = torchbp.ops.backprojection_cart_2d(data, grid_cart, fc, r_res, pos)\n", 235 | "img_cart2 = img_cart2.squeeze() # Removes singular batch dimension\n", 
236 | "\n", 237 | "img_db = 20*torch.log10(torch.abs(img_cart2)).detach()\n", 238 | "\n", 239 | "m = torch.max(img_db)\n", 240 | "\n", 241 | "extent = [*grid_cart[\"x\"], *grid_cart[\"y\"]]\n", 242 | "\n", 243 | "plt.figure()\n", 244 | "plt.imshow(img_db.cpu().numpy().T, origin=\"lower\", vmin=m-30, vmax=m, extent=extent, aspect=\"equal\")\n", 245 | "plt.xlabel(\"Range (m)\")\n", 246 | "plt.ylabel(\"Cross-range (m)\");" 247 | ] 248 | }, 249 | { 250 | "cell_type": "markdown", 251 | "id": "dcaf0207-4f62-4404-8aaf-3954dbfb31d2", 252 | "metadata": {}, 253 | "source": [ 254 | "Difference between the results should be very small" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | "execution_count": null, 260 | "id": "be8f4727-c5c8-451b-bbf4-f743c2f246c7", 261 | "metadata": {}, 262 | "outputs": [], 263 | "source": [ 264 | "plt.figure()\n", 265 | "plt.title(\"Phase difference\")\n", 266 | "plt.imshow(torch.angle(img_cart * torch.conj(img_cart2)).cpu().numpy().T, origin=\"lower\", extent=extent, aspect=\"equal\")" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": null, 272 | "id": "528054b9-3174-41a4-af41-4622439a59f4", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "torch.mean(torch.abs(img_cart - img_cart2))" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": null, 282 | "id": "121c3eff-0dc1-47b1-9da6-af9245420117", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [] 286 | } 287 | ], 288 | "metadata": { 289 | "kernelspec": { 290 | "display_name": "Python 3 (ipykernel)", 291 | "language": "python", 292 | "name": "python3" 293 | }, 294 | "language_info": { 295 | "codemirror_mode": { 296 | "name": "ipython", 297 | "version": 3 298 | }, 299 | "file_extension": ".py", 300 | "mimetype": "text/x-python", 301 | "name": "python", 302 | "nbconvert_exporter": "python", 303 | "pygments_lexer": "ipython3", 304 | "version": "3.10.12" 305 | } 306 | }, 307 | "nbformat": 4, 308 | "nbformat_minor": 5 
309 | } 310 | -------------------------------------------------------------------------------- /docs/source/examples/index.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | ========= 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | :glob: 9 | 10 | ./* 11 | -------------------------------------------------------------------------------- /docs/source/examples/pga.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "1192e1d8-9f6b-47f2-804c-820d3773dd6d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Phase gradient autofocus" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": null, 14 | "id": "4b06324e-7f14-4371-8f67-18afa9397a52", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "import torchbp\n", 20 | "import matplotlib.pyplot as plt\n", 21 | "from scipy.signal import get_window\n", 22 | "import numpy as np\n", 23 | "\n", 24 | "if torch.cuda.is_available():\n", 25 | " device = \"cuda\"\n", 26 | "else:\n", 27 | " device = \"cpu\"\n", 28 | "print(\"Device:\", device)" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "id": "6c6fe062-389d-4be3-b537-753b46d88f56", 34 | "metadata": {}, 35 | "source": [ 36 | "Generate synthetic radar data" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "id": "4be6f29d-2c72-443a-99fd-46590d98c54f", 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "nr = 100 # Range points\n", 47 | "ntheta = 128 # Azimuth points\n", 48 | "nsweeps = 128 # Number of measurements\n", 49 | "fc = 6e9 # RF center frequency\n", 50 | "bw = 100e6 # RF bandwidth\n", 51 | "tsweep = 100e-6 # Sweep length\n", 52 | "fs = 1e6 # Sampling frequency\n", 53 | "nsamples = int(fs * tsweep) # Time domain samples per sweep\n", 54 | "\n", 55 | "# Imaging grid definition. 
Azimuth angle \"theta\" is sine of radians. 0.2 = 11.5 degrees.\n", 56 | "grid_polar = {\"r\": (90, 110), \"theta\": (-0.2, 0.2), \"nr\": nr, \"ntheta\": ntheta}" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "id": "1d107200-f4bd-4649-9999-5f756fd92fd5", 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "target_pos = torch.tensor([[100, 0, 0], [105, 10, 0], [97, -5, 0], [102, -10, 0], [95, 5, 0]], dtype=torch.float32, device=device)\n", 67 | "target_rcs = torch.tensor([1,1,1,1,1], dtype=torch.float32, device=device)\n", 68 | "pos = torch.zeros([nsweeps, 3], dtype=torch.float32, device=device)\n", 69 | "pos[:,1] = torch.linspace(-nsweeps/2, nsweeps/2, nsweeps) * 0.25 * 3e8 / fc" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "de9bcbfd-ff44-4d79-8a2e-fc7c8d480ea9", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "# Oversampling input data decreases interpolation errors\n", 80 | "oversample = 3\n", 81 | "\n", 82 | "with torch.no_grad():\n", 83 | " data = torchbp.util.generate_fmcw_data(target_pos, target_rcs, pos, fc, bw, tsweep, fs)\n", 84 | " # Apply windowing function in range direction\n", 85 | " wr = torch.tensor(get_window((\"taylor\", 3, 30), data.shape[-1])[None,:], dtype=torch.float32, device=device)\n", 86 | " wa = torch.tensor(get_window((\"taylor\", 3, 30), data.shape[0])[:,None], dtype=torch.float32, device=device)\n", 87 | " data = torch.fft.fft(data * wa * wr, dim=-1, n=nsamples * oversample)\n", 88 | "\n", 89 | "data_db = 20*torch.log10(torch.abs(data)).detach()\n", 90 | "m = torch.max(data_db)\n", 91 | "\n", 92 | "plt.figure()\n", 93 | "plt.imshow(data_db.cpu().numpy(), origin=\"lower\", vmin=m-30, vmax=m, aspect=\"auto\")\n", 94 | "plt.xlabel(\"Range samples\")\n", 95 | "plt.ylabel(\"Azimuth samples\");" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "ddc7fc94-caff-4e75-a276-358fe4880fb5", 101 | "metadata": {}, 102 | "source": [ 
103 | "Focused image without motion error" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "05e995c4-b67d-4a7e-a2a5-346e295034c9", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "r_res = 3e8 / (2 * bw * oversample) # Range bin size in input data\n", 114 | "\n", 115 | "# dealias=True removes range spectrum aliasing\n", 116 | "img = torchbp.ops.backprojection_polar_2d(data, grid_polar, fc, r_res, pos, dealias=True)\n", 117 | "img = img.squeeze(0) # Removes singular batch dimension\n", 118 | "# Backprojection image has spectrum with DC at zero index.\n", 119 | "# Shifting the spectrum shifts the DC to center bin.\n", 120 | "# This makes the solved phase to have same order as the position vector\n", 121 | "# Without shifting of the image, fftshift needs to be applied to\n", 122 | "# the solved phase for it to be in the same order as the position vector.\n", 123 | "# This doesn't affect the absolute value of the image.\n", 124 | "img = torchbp.util.shift_spectrum(img)\n", 125 | "\n", 126 | "img_db = 20*torch.log10(torch.abs(img)).detach()\n", 127 | "\n", 128 | "m = torch.max(img_db)\n", 129 | "\n", 130 | "extent = [*grid_polar[\"r\"], *grid_polar[\"theta\"]]\n", 131 | "\n", 132 | "plt.figure()\n", 133 | "plt.imshow(img_db.cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n", 134 | "plt.xlabel(\"Range (m)\")\n", 135 | "plt.ylabel(\"Angle (sin radians)\");" 136 | ] 137 | }, 138 | { 139 | "cell_type": "markdown", 140 | "id": "29782eea-9576-4351-a630-41bf52bebd8d", 141 | "metadata": {}, 142 | "source": [ 143 | "Create corrupted image" 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "id": "7f270d8e-634e-4fbf-9737-972367e1e608", 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "phase_error = torch.exp(1j*2*torch.pi*torch.linspace(-3, 3, ntheta, dtype=torch.float32, device=device)[None,:]**2)\n", 154 | "\n", 155 | 
"plt.figure()\n", 156 | "plt.plot(torch.angle(phase_error.squeeze()).cpu().numpy())\n", 157 | "plt.xlabel(\"Azimuth sample\")\n", 158 | "plt.ylabel(\"Phase error (radians)\")\n", 159 | "\n", 160 | "img_corrupted = torch.fft.ifft(torch.fft.fft(img, dim=-1) * phase_error, dim=-1)\n", 161 | "\n", 162 | "plt.figure()\n", 163 | "plt.imshow(20*torch.log10(torch.abs(img_corrupted)).cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n", 164 | "plt.xlabel(\"Range (m)\")\n", 165 | "plt.ylabel(\"Angle (sin radians)\");" 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "b24ce554-3f32-4f89-bf82-f23fc561f85a", 171 | "metadata": {}, 172 | "source": [ 173 | "Phase gradient autofocus with phase difference estimator" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "a5d7ba36-09fe-4885-b194-a1a013beae97", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "img_pga, phi = torchbp.autofocus.pga(img_corrupted, remove_trend=False, estimator=\"pd\")\n", 184 | "\n", 185 | "plt.figure()\n", 186 | "plt.imshow(20*torch.log10(torch.abs(img_pga)).cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n", 187 | "plt.xlabel(\"Range (m)\")\n", 188 | "plt.ylabel(\"Angle (sin radians)\");\n", 189 | "\n", 190 | "plt.figure()\n", 191 | "plt.plot(torch.angle(torch.exp(1j*phi)).cpu().numpy())\n", 192 | "plt.xlabel(\"Azimuth samples\")\n", 193 | "plt.ylabel(\"Phase error (radians)\");" 194 | ] 195 | }, 196 | { 197 | "cell_type": "markdown", 198 | "id": "e1036ccb-2383-4f41-be81-8ae96366a314", 199 | "metadata": {}, 200 | "source": [ 201 | "Apply maximum likelihood phase gradient autofocus" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "2170e6dc-c32e-4a7a-8ee5-0b9e3e1e6ee5", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [ 211 | "img_pga, phi = torchbp.autofocus.pga(img_corrupted, remove_trend=False, 
estimator=\"ml\")\n", 212 | "\n", 213 | "plt.figure()\n", 214 | "plt.imshow(20*torch.log10(torch.abs(img_pga)).cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n", 215 | "plt.xlabel(\"Range (m)\")\n", 216 | "plt.ylabel(\"Angle (sin radians)\");\n", 217 | "\n", 218 | "plt.figure()\n", 219 | "plt.plot(torch.angle(torch.exp(1j*phi)).cpu().numpy())\n", 220 | "plt.xlabel(\"Azimuth samples\")\n", 221 | "plt.ylabel(\"Phase error (radians)\");" 222 | ] 223 | }, 224 | { 225 | "cell_type": "markdown", 226 | "id": "055e8f56-d7d2-4225-a4cd-f302de7ba3c2", 227 | "metadata": {}, 228 | "source": [ 229 | "Multiplying the FFT of the corrupted image by the conjugate of the solved phase term and taking the inverse FFT gives the focused image. This should be identical to the image returned by `torchbp.autofocus.pga` above." 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "id": "45663e08-5bbf-42ca-9599-962df0dbc6d1", 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "img_focused = torch.fft.ifft(torch.fft.fft(img_corrupted, dim=-1) * torch.exp(-1j*phi), dim=-1)\n", 240 | "\n", 241 | "plt.figure()\n", 242 | "plt.imshow(20*torch.log10(torch.abs(img_focused)).cpu().numpy().T, origin=\"lower\", vmin=m-40, vmax=m, extent=extent, aspect=\"auto\")\n", 243 | "plt.xlabel(\"Range (m)\")\n", 244 | "plt.ylabel(\"Angle (sin radians)\");" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": null, 250 | "id": "ceddc632-0865-4369-9fdb-9abe2af8c5c6", 251 | "metadata": {}, 252 | "outputs": [], 253 | "source": [] 254 | } 255 | ], 256 | "metadata": { 257 | "kernelspec": { 258 | "display_name": "Python 3 (ipykernel)", 259 | "language": "python", 260 | "name": "python3" 261 | }, 262 | "language_info": { 263 | "codemirror_mode": { 264 | "name": "ipython", 265 | "version": 3 266 | }, 267 | "file_extension": ".py", 268 | "mimetype": "text/x-python", 269 | "name": "python", 270 | "nbconvert_exporter": "python", 271 | 
"pygments_lexer": "ipython3", 272 | "version": "3.10.12" 273 | } 274 | }, 275 | "nbformat": 4, 276 | "nbformat_minor": 5 277 | } 278 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. torchbp documentation master file, created by 2 | sphinx-quickstart on Wed Dec 25 08:36:09 2024. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | torchbp documentation 7 | =================================== 8 | 9 | torchbp is a PyTorch library for fast differentiable synthetic aperture radar image formation 10 | and autofocus with GPU and CPU support. 11 | 12 | .. toctree:: 13 | :maxdepth: 2 14 | :caption: Contents: 15 | 16 | examples/index.rst 17 | torchbp.rst 18 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | torchbp 2 | ======= 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | torchbp 8 | -------------------------------------------------------------------------------- /docs/source/torchbp.rst: -------------------------------------------------------------------------------- 1 | torchbp package 2 | =============== 3 | 4 | Submodules 5 | ---------- 6 | 7 | torchbp.autofocus module 8 | ------------------------ 9 | 10 | .. automodule:: torchbp.autofocus 11 | :members: 12 | :undoc-members: 13 | :show-inheritance: 14 | 15 | torchbp.ops module 16 | ------------------ 17 | 18 | .. automodule:: torchbp.ops 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | 23 | torchbp.util module 24 | ------------------- 25 | 26 | .. automodule:: torchbp.util 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | 31 | Module contents 32 | --------------- 33 | 34 | .. 
automodule:: torchbp 35 | :members: 36 | :undoc-members: 37 | :show-inheritance: 38 | -------------------------------------------------------------------------------- /examples/Readme.md: -------------------------------------------------------------------------------- 1 | # Example SAR data processing script 2 | 3 | Instructions: 4 | 5 | 1. Download the sample data from: https://hforsten.com/sar.safetensors.zip 6 | 2. Unzip the file to this directory. 7 | 3. Run `sar_process_safetensor.py` (optimization based minimum entropy 8 | autofocus) or `sar_process_safetensor_gpga.py` (generalized phase gradient 9 | autofocus). It will process the file and display the polar formatted image. 10 | The processed image is also saved to disk for the next step. 11 | 4. Run `sar_polar_to_cart.py` to display the previously saved image on a Cartesian grid. 12 | 13 | Some processing parameters can be modified in the `sar_process_safetensor.py` 14 | file. For example, for a higher resolution image, set `nsweeps = 51200` and 15 | increase `max_step_limit` to 5. 
16 | -------------------------------------------------------------------------------- /examples/sar_polar_to_cart.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # Visualize pickled radar image 3 | import pickle 4 | import matplotlib.pyplot as plt 5 | import torchbp 6 | from torchbp.util import entropy 7 | import torch 8 | import sys 9 | from torchvision.transforms.functional import resize 10 | 11 | if __name__ == "__main__": 12 | filename = "sar_img.p" 13 | if len(sys.argv) > 1: 14 | filename = sys.argv[1] 15 | with open(filename, "rb") as f: 16 | sar_img, mission, grid, grid_polar, origin, origin_angle = pickle.load(f) 17 | 18 | dev = "cuda" 19 | sar_img = torch.from_numpy(sar_img).to(dtype=torch.complex64, device=dev) 20 | fc = mission["fc"] 21 | print("Entropy", entropy(sar_img).item()) 22 | 23 | # Increase Cartesian image size 24 | oversample = 1 25 | # Increases image size, but then resamples it down by the same amount 26 | # Can be used for multilook processing, when the input polar format data 27 | # resolution is higher than can fit into the Cartesian grid 28 | multilook = 2 29 | grid["nx"] = int(oversample * grid["nx"] * multilook) 30 | grid["ny"] = int(oversample * grid["ny"] * multilook) 31 | 32 | plt.figure() 33 | origin = torch.from_numpy(origin).to(dtype=torch.float32, device=dev) 34 | # Amplitude scaling in image 35 | m = 20 * torch.log10(torch.median(torch.abs(sar_img))) - 3 36 | m = m.cpu().numpy() 37 | m2 = m + 40 38 | 39 | sar_img_cart = torchbp.ops.polar_to_cart_bicubic( 40 | torch.abs(sar_img), 41 | origin, 42 | grid_polar, 43 | grid, 44 | fc, 45 | origin_angle 46 | ) 47 | extent = [grid["x"][0], grid["x"][1], grid["y"][0], grid["y"][1]] 48 | img_db = torch.abs(sar_img_cart) + 1e-10 49 | out_shape = [img_db.shape[-2] // multilook, img_db.shape[-1] // multilook] 50 | img_db = resize(img_db, out_shape).squeeze() 51 | img_db = 20 * torch.log10(img_db) 52 | img_db = 
def grid_extent(pos, att, min_range, max_range, bw=0, origin_angle=0):
    """
    Return the axis-aligned X/Y extent that contains the radar footprint.

    For every platform position the line of sight is evaluated in three
    directions (yaw - bw, yaw, yaw + bw, all rotated by `origin_angle`) and at
    two distances (`min_range` and `max_range`). The returned extent is the
    bounding box of all of those points.

    Parameters
    ----------
    pos : np.array
        Platform xyz-position vector. Shape: [N, 3]. Only the first two
        columns (x, y) are used.
    att : np.array
        Antenna Euler angle vector. Shape: [N, 3]. Only the third column
        (yaw) is used.
    min_range : float
        Minimum range from radar in m.
    max_range : float
        Maximum range from radar in m.
    bw : float
        Antenna beam width in radians.
    origin_angle : float
        Input position rotation angle.

    Returns
    -------------
    x, y : tuple
        Minimum and maximum X and Y coordinates for image grid, each as a
        (min, max) tuple.
    """
    # The horizontal position does not depend on the beam direction, so slice
    # it once instead of re-slicing (and rebinding the argument) per direction.
    pos_xy = pos[:, :2]
    yaw0 = att[:, 2] + origin_angle

    footprint = []
    for beam_offset in (-bw, 0, bw):
        yaw = yaw0 + beam_offset
        # Unit line-of-sight vectors, shape [N, 2].
        direction = np.stack([np.cos(yaw), np.sin(yaw)], axis=-1)
        footprint.append(pos_xy + direction * max_range)
        footprint.append(pos_xy + direction * min_range)

    # Bounding box over every sampled footprint point.
    footprint = np.concatenate(footprint, axis=0)
    lower = np.min(footprint, axis=0)
    upper = np.max(footprint, axis=0)
    x = (lower[0], upper[0])
    y = (lower[1], upper[1])
    return x, y
87 | theta_limit = 1 88 | # Decrease the number of sweeps to speed up the calculation 89 | nsweeps = 10000 # Max 51200 90 | sweep_start = 0 91 | # Maximum number of autofocus iterations 92 | max_steps = 15 93 | # Maximum autofocus position update in wavelengths 94 | # Optimal value depends on the maximum error in the image 95 | max_step_limit = 0.5 # Try 5 with 50k sweeps 96 | data_dtype = torch.complex64 # Can be `torch.complex32` to save VRAM 97 | 98 | # Windowing functions 99 | range_window = "hamming" 100 | angle_window = ("taylor", 4, 50) 101 | # FFT oversampling factor, decrease to 2 to save some VRAM 102 | fft_oversample = 3 103 | dev = torch.device("cuda") 104 | # Distance in radar data corresponding to zero actual distance 105 | # Slightly higher than zero due to antenna feedlines and other delays. 106 | d0 = 0.5 107 | 108 | # Calculate initial estimate using PGA 109 | initial_pga = False 110 | 111 | c0 = 299792458 112 | 113 | # Load the input data 114 | try: 115 | mission, tensors = load_data(filename) 116 | except FileNotFoundError: 117 | print("Input file {filename} not found.") 118 | 119 | sweeps = tensors["data"][sweep_start:nsweeps].to(dtype=torch.float32) 120 | pos = tensors["pos"][sweep_start:nsweeps].cpu().numpy() 121 | att = tensors["att"][sweep_start:nsweeps].cpu().numpy() 122 | counts = tensors["counts"][sweep_start:nsweeps] 123 | nsweeps = sweeps.shape[0] 124 | del tensors 125 | 126 | bw = mission["bw"] 127 | fc = mission["fc"] 128 | fs = mission["fsample"] 129 | origin_angle = mission["origin_angle"] 130 | tsweep = sweeps.shape[-1] / fs 131 | sweep_interval = mission["pri"] 132 | res = c0 / (2 * mission["bw"]) 133 | 134 | # Calculate Cartesian grid that fits the radar image 135 | antenna_bw = 50 * np.pi / 180 136 | x, y = grid_extent(pos, att, x0, x1, bw=antenna_bw, origin_angle=origin_angle) 137 | nx = int((x[1] - x[0]) / res) 138 | ny = int((y[1] - y[0]) / res) 139 | grid = {"x": x, "y": y, "nx": nx, "ny": ny} 140 | print("mission", 
mission) 141 | print("grid_cart", grid) 142 | 143 | # Calculate polar grid 144 | d = np.linalg.norm(pos[-1] - pos[0]) 145 | wl = c0 / fc 146 | spacing = d / wl / nsweeps 147 | # Critically spaced array would be 0.25 wavelengths apart 148 | ntheta = int(1 + nsweeps * spacing * theta_limit / 0.25) 149 | nr = int((x1 - x0) / res) 150 | az = att[:, 2] 151 | mean_az = np.angle(np.mean(np.exp(1j * az))) 152 | grid_polar = make_polar_grid( 153 | x0, 154 | x1, 155 | nr, 156 | ntheta, 157 | theta_limit=theta_limit, 158 | squint=mean_az if theta_limit < 1 else 0, 159 | ) 160 | 161 | nr = int((autofocus_x1 - autofocus_x0) / res) 162 | ntheta = int(1 + nsweeps * spacing * autofocus_theta_limit / 0.25) 163 | grid_polar_autofocus = make_polar_grid( 164 | autofocus_x0, 165 | autofocus_x1, 166 | nr, 167 | ntheta, 168 | theta_limit=autofocus_theta_limit, 169 | squint=mean_az, 170 | ) 171 | print("grid", grid_polar) 172 | print("grid autofocus", grid_polar_autofocus) 173 | 174 | pos = torch.from_numpy(pos).to(dtype=torch.float32, device=dev) 175 | 176 | # Generate window functions 177 | nsamples = sweeps.shape[-1] 178 | wr = signal.get_window(range_window, nsamples) 179 | wr /= np.mean(wr) 180 | wr = torch.tensor(wr).to(dtype=torch.float32, device=dev) 181 | wa = torch.tensor( 182 | signal.get_window(angle_window, sweeps.shape[0], fftbins=False) 183 | ).to(dtype=torch.float32, device=dev) 184 | wa /= torch.mean(wa) 185 | 186 | # Residual video phase compensation 187 | nsamples = sweeps.shape[-1] 188 | f = torch.fft.rfftfreq(int(nsamples * fft_oversample), d=1 / fs).to(dev) 189 | rvp = torch.exp(-1j * torch.pi * f**2 * tsweep / bw) 190 | r_res = c0 / (2 * bw * fft_oversample) 191 | del f 192 | 193 | # Timestamp of each sweep 194 | data_time = sweep_interval * counts 195 | 196 | v = torch.diff(pos, dim=0, prepend=pos[0].unsqueeze(0)) / sweep_interval 197 | pos_mean = torch.mean(pos, dim=0) 198 | v_orig = v.detach().clone() 199 | 200 | # Apply windowing 201 | sweeps *= wa[:, None, 
None].cpu() 202 | sweeps *= wr[None, None, :].cpu() 203 | 204 | # FFT radar data in blocks to decrease the maximum needed VRAM 205 | n = int(nsamples * fft_oversample) 206 | fsweeps = torch.zeros((sweeps.shape[0], n // 2 + 1), dtype=data_dtype, device=dev) 207 | blocks = 16 208 | block = (sweeps.shape[0] + blocks - 1) // blocks 209 | for b in range(blocks): 210 | s0 = b * block 211 | s1 = min((b + 1) * block, sweeps.shape[0]) 212 | fsw = torch.fft.rfft( 213 | sweeps[s0:s1, 0, :].to(device=dev), n=n, norm="forward", dim=-1 214 | ) 215 | fsw = torch.conj(fsw) 216 | fsw *= rvp[None, :] 217 | fsweeps[s0:s1] = fsw.to(dtype=data_dtype) 218 | del sweeps 219 | del fsw 220 | 221 | pos = pos.to(device=dev) 222 | data_time = data_time.to(device=dev) 223 | 224 | if max_steps > 1: 225 | if initial_pga: 226 | print("Calculating initial estimate with PGA") 227 | origin = torch.tensor([torch.mean(pos[:,0]), torch.mean(pos[:,1]), 0], 228 | device=dev, dtype=torch.float32)[None,:] 229 | pos_centered = pos - origin 230 | sar_img, phi = torchbp.autofocus.gpga_ml_bp_polar(None, fsweeps, 231 | pos_centered, fc, r_res, grid_polar_autofocus, 232 | window_width=nsweeps//8, d0=d0, target_threshold_db=20) 233 | 234 | d = torchbp.util.phase_to_distance(phi, fc) 235 | d -= torch.mean(d) 236 | pos[:,0] = pos[:,0] + d 237 | 238 | print("Calculating autofocus. This might take a while. 
Press Ctrl-C to interrupt.") 239 | sar_img, origin, pos, steps = torchbp.autofocus.bp_polar_grad_minimum_entropy( 240 | fsweeps, 241 | data_time, 242 | pos, 243 | fc, 244 | r_res, 245 | grid_polar_autofocus, 246 | wa, 247 | tx_norm=None, 248 | max_steps=max_steps, 249 | lr_max=10000, 250 | d0=d0, 251 | pos_reg=0.1, 252 | lr_reduce=0.8, 253 | verbose=True, 254 | convergence_limit=0.01, 255 | max_step_limit=max_step_limit, 256 | grad_limit_quantile=0.99, 257 | fixed_pos=0, 258 | ) 259 | v = torch.diff(pos, dim=0, prepend=pos[0].unsqueeze(0)) / sweep_interval 260 | 261 | plt.figure() 262 | plt.title("Original and optimized velocity") 263 | p = v.detach().cpu().numpy() 264 | plt.plot(p[:, 0], label="vx opt") 265 | plt.plot(p[:, 1], label="vy opt") 266 | plt.plot(p[:, 2], label="vz opt") 267 | po = v_orig.detach().cpu().numpy() 268 | plt.plot(po[:, 0], label="vx") 269 | plt.plot(po[:, 1], label="vy") 270 | plt.plot(po[:, 2], label="vz") 271 | plt.legend(loc="best") 272 | plt.xlabel("Sweep index") 273 | plt.ylabel("Velocity (m/s)") 274 | 275 | origin = torch.tensor( 276 | [torch.mean(pos[:, 0]), torch.mean(pos[:, 1]), 0], 277 | device=dev, 278 | dtype=torch.float32, 279 | )[None, :] 280 | pos_centered = pos - origin 281 | print("Focusing final image") 282 | sar_img = torchbp.ops.backprojection_polar_2d( fsweeps, grid_polar, fc, 283 | r_res, pos_centered, d0).squeeze() 284 | print("Entropy", torchbp.util.entropy(sar_img).item()) 285 | sar_img = sar_img.cpu().numpy() 286 | 287 | plt.figure() 288 | extent = [ 289 | grid_polar["r"][0], 290 | grid_polar["r"][1], 291 | grid_polar["theta"][0], 292 | grid_polar["theta"][1], 293 | ] 294 | abs_img = np.abs(sar_img) 295 | m = 20 * np.log10(np.median(abs_img)) - 13 296 | plt.imshow( 297 | 20 * np.log10(abs_img).T, aspect="auto", origin="lower", extent=extent, vmin=m 298 | ) 299 | plt.grid(False) 300 | plt.xlabel("Range (m)") 301 | plt.ylabel("Angle (sin(radians))") 302 | print("Exporting image") 303 | plt.savefig("sar_img.png", 
if __name__ == "__main__":
    # Process a safetensors SAR recording with generalized phase gradient
    # autofocus and export the focused polar image (PNG + pickle).
    filename = "sar.safetensors"
    if len(sys.argv) > 1:
        filename = sys.argv[1]

    # Final image dimensions
    x0 = 1
    x1 = 2000
    # Image dimensions during autofocus, typically smaller than the final image
    autofocus_x0 = 400
    autofocus_x1 = 1200
    autofocus_theta_limit = 0.8
    # Azimuth range in polar image in sin of radians. 1 for full 180 degrees.
    theta_limit = 1
    # Decrease the number of sweeps to speed up the calculation
    nsweeps = 10000  # Max 51200
    sweep_start = 0

    # Windowing functions
    range_window = "hamming"
    angle_window = ("taylor", 4, 50)
    # FFT oversampling factor, decrease to 2 to save some VRAM
    fft_oversample = 3
    dev = torch.device("cuda")
    # Distance in radar data corresponding to zero actual distance
    # Slightly higher than zero due to antenna feedlines and other delays.
    d0 = 0.5
    data_dtype = torch.complex64  # Can be `torch.complex32` to save VRAM
    # Use fast factorized backprojection, slightly reduces the image quality
    # but is faster.
    ffbp = True

    c0 = 299792458

    # Load the input data. Nothing below can run without it, so stop here
    # instead of continuing into a NameError on `tensors`.
    try:
        mission, tensors = load_data(filename)
    except FileNotFoundError:
        sys.exit(f"Input file {filename} not found.")

    sweeps = tensors["data"][sweep_start:nsweeps].to(dtype=torch.float32)
    pos = tensors["pos"][sweep_start:nsweeps].cpu().numpy()
    att = tensors["att"][sweep_start:nsweeps].cpu().numpy()
    counts = tensors["counts"][sweep_start:nsweeps]
    # From here on `nsweeps` is the number of sweeps actually loaded.
    nsweeps = sweeps.shape[0]
    del tensors

    bw = mission["bw"]
    fc = mission["fc"]
    fs = mission["fsample"]
    origin_angle = mission["origin_angle"]
    tsweep = sweeps.shape[-1] / fs
    sweep_interval = mission["pri"]
    res = c0 / (2 * mission["bw"])

    # Calculate Cartesian grid that fits the radar image
    antenna_bw = 50 * np.pi / 180
    x, y = grid_extent(pos, att, x0, x1, bw=antenna_bw, origin_angle=origin_angle)
    nx = int((x[1] - x[0]) / res)
    ny = int((y[1] - y[0]) / res)
    grid = {"x": x, "y": y, "nx": nx, "ny": ny}
    print("mission", mission)
    print("grid_cart", grid)

    # Calculate polar grid
    d = np.linalg.norm(pos[-1] - pos[0])
    wl = c0 / fc
    spacing = d / wl / nsweeps
    # Critically spaced array would be 0.25 wavelengths apart
    ntheta = int(1 + 2 * nsweeps * spacing * theta_limit / 0.25)
    nr = int(2 * (x1 - x0) / res)

    az = att[:, 2]
    # Circular mean of the yaw angles.
    mean_az = np.angle(np.mean(np.exp(1j * az)))
    grid_polar = make_polar_grid(
        x0,
        x1,
        nr,
        ntheta,
        theta_limit=theta_limit,
        squint=mean_az if theta_limit < 1 else 0,
    )

    nr = int((autofocus_x1 - autofocus_x0) / res)
    ntheta = int(1 + nsweeps * spacing * autofocus_theta_limit / 0.25)
    grid_polar_autofocus = make_polar_grid(
        autofocus_x0,
        autofocus_x1,
        nr,
        ntheta,
        theta_limit=autofocus_theta_limit,
        squint=mean_az,
    )
    print("grid", grid_polar)
    print("grid autofocus", grid_polar_autofocus)

    pos = torch.from_numpy(pos).to(dtype=torch.float32, device=dev)

    # Generate window functions
    nsamples = sweeps.shape[-1]
    wr = signal.get_window(range_window, nsamples)
    wr /= np.mean(wr)
    wr = torch.tensor(wr).to(dtype=torch.float32, device=dev)
    wa = torch.tensor(
        signal.get_window(angle_window, sweeps.shape[0], fftbins=False)
    ).to(dtype=torch.float32, device=dev)
    wa /= torch.mean(wa)

    # Residual video phase compensation
    f = torch.fft.rfftfreq(int(nsamples * fft_oversample), d=1 / fs).to(dev)
    rvp = torch.exp(-1j * torch.pi * f**2 * tsweep / bw)
    r_res = c0 / (2 * bw * fft_oversample)
    del f

    # Timestamp of each sweep
    data_time = sweep_interval * counts

    # Apply windowing (the sweeps are still on the CPU at this point)
    sweeps *= wa[:, None, None].cpu()
    sweeps *= wr[None, None, :].cpu()

    # FFT radar data in blocks to decrease the maximum needed VRAM
    n = int(nsamples * fft_oversample)
    fsweeps = torch.zeros((sweeps.shape[0], n // 2 + 1), dtype=data_dtype, device=dev)
    blocks = 16
    block = (sweeps.shape[0] + blocks - 1) // blocks
    for b in range(blocks):
        s0 = b * block
        s1 = min((b + 1) * block, sweeps.shape[0])
        fsw = torch.fft.rfft(
            sweeps[s0:s1, 0, :].to(device=dev), n=n, norm="forward", dim=-1
        )
        fsw = torch.conj(fsw)
        fsw *= rvp[None, :]
        fsweeps[s0:s1] = fsw.to(dtype=data_dtype)
    del sweeps
    del fsw

    pos = pos.to(device=dev)
    data_time = data_time.to(device=dev)

    print("Calculating autofocus. This might take a while.")

    # Center the platform track on its mean horizontal position.
    origin = torch.tensor([torch.mean(pos[:,0]), torch.mean(pos[:,1]), 0],
            device=dev, dtype=torch.float32)[None,:]
    pos_centered = pos - origin
    sar_img, phi = torchbp.autofocus.gpga_ml_bp_polar(None, fsweeps,
            pos_centered, fc, r_res, grid_polar_autofocus,
            window_width=nsweeps, d0=d0, target_threshold_db=15)

    # Convert the solved phase to a distance and remove its mean.
    d = torchbp.util.phase_to_distance(phi, fc)
    d -= torch.mean(d)

    plt.figure()
    plt.plot(d.cpu().numpy())
    plt.xlabel("Sweep index")
    plt.ylabel("Solved position error (m)")

    # Apply the solved correction to the x-coordinate of the track.
    pos_centered[:,0] = pos_centered[:,0] + d

    print("Focusing final image")
    # Synchronize around the image formation so the wall-clock time is meaningful.
    torch.cuda.synchronize()
    tstart = time.time()
    if ffbp:
        sar_img = torchbp.ops.ffbp(fsweeps, grid_polar, fc, r_res, pos_centered,
                stages=5, divisions=2, d0=d0)
    else:
        sar_img = torchbp.ops.backprojection_polar_2d(
            fsweeps, grid_polar, fc, r_res, pos_centered, d0)[0]
    torch.cuda.synchronize()
    print(f"Final image created in {time.time() - tstart:.3g} s")
    print("Entropy", torchbp.util.entropy(sar_img).item())
    sar_img = sar_img.cpu().numpy()

    plt.figure()
    extent = [
        grid_polar["r"][0],
        grid_polar["r"][1],
        grid_polar["theta"][0],
        grid_polar["theta"][1],
    ]
    abs_img = np.abs(sar_img)
    # Lower color limit: 13 dB below the median amplitude.
    m = 20 * np.log10(np.median(abs_img)) - 13
    plt.imshow(
        20 * np.log10(abs_img).T, aspect="auto", origin="lower", extent=extent, vmin=m
    )
    plt.grid(False)
    plt.xlabel("Range (m)")
    plt.ylabel("Angle (sin(radians))")
    print("Exporting image")
    plt.savefig("sar_img.png", dpi=400)

    # Export image as pickle file
    with open("sar_img.p", "wb") as f:
        origin = origin.cpu().numpy().squeeze()
        pickle.dump((sar_img, mission, grid, grid_polar, origin, origin_angle), f)

    plt.show(block=True)
os.path.dirname(os.path.curdir) 49 | extensions_dir = os.path.join(this_dir, library_name, "csrc") 50 | sources = list(glob.glob(os.path.join(extensions_dir, "*.cpp"))) 51 | 52 | extensions_cuda_dir = os.path.join(extensions_dir, "cuda") 53 | cuda_sources = list(glob.glob(os.path.join(extensions_cuda_dir, "*.cu"))) 54 | 55 | if use_cuda: 56 | sources += cuda_sources 57 | 58 | ext_modules = [ 59 | extension( 60 | f"{library_name}._C", 61 | sources, 62 | extra_compile_args=extra_compile_args, 63 | extra_link_args=extra_link_args, 64 | ) 65 | ] 66 | 67 | return ext_modules 68 | 69 | 70 | setup( 71 | name=library_name, 72 | version="0.0.1", 73 | packages=find_packages(), 74 | ext_modules=get_extensions(), 75 | install_requires=["torch", "numpy"], 76 | extras_require = { 77 | 'docs': [ 78 | "matplotlib >=3.5", 79 | "nbval >=0.9", 80 | "jupyter-client >=7.3.5", 81 | "sphinx-rtd-theme >=1.0", 82 | "sphinx >=4", 83 | "nbsphinx >= 0.8.9", 84 | "openpyxl >= 3", 85 | "lxml-html-clean >= 0.4.1"] 86 | }, 87 | description="Differentiable synthetic aperture radar library", 88 | long_description=open("Readme.md").read(), 89 | long_description_content_type="text/markdown", 90 | #url="", 91 | cmdclass={"build_ext": BuildExtension}, 92 | ) 93 | -------------------------------------------------------------------------------- /tests/benchmark_backprojection.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import torch 3 | import torchbp 4 | import time 5 | import numpy as np 6 | import torch.utils.benchmark as benchmark 7 | 8 | device = "cuda" 9 | 10 | nbatch = 1 11 | nr = 1024 12 | ntheta = 1024 13 | nsweeps = 16 if device == "cpu" else 1024 14 | nsamples = 1024 15 | data_dtype = torch.complex64 16 | 17 | fc = 6e9 18 | r_res = 0.5 19 | 20 | grid_polar = {"r": (10, 500), "theta": (-1, 1), "nr": nr, "ntheta": ntheta} 21 | 22 | data = torch.randn((nbatch, nsweeps, nsamples), dtype=data_dtype, device=device) 23 | 24 | pos = 
class TestEntropy(TestCase):
    """Compare `torchbp.ops.entropy` on CUDA with `torchbp.util.entropy` on CPU."""

    def sample_inputs(self, device, *, requires_grad=False, dtype=torch.complex64):
        """Return a list of keyword-argument dicts, each with a random 3x3 `img`."""
        img = torch.randn((3, 3), device=device, requires_grad=requires_grad, dtype=dtype)
        return [{'img': img}]

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_ref(self):
        """Value and input gradient of the CUDA op match the CPU computation."""
        samples = self.sample_inputs("cuda", requires_grad=True)
        for sample in samples:
            # Detached CPU copies are independent leaves, so each side
            # accumulates its own gradient.
            sample_cpu = {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in sample.items()
            }
            for k, v in sample.items():
                if isinstance(v, torch.Tensor) and v.requires_grad:
                    sample_cpu[k].requires_grad = True

            res_gpu = torchbp.ops.entropy(sample["img"])
            res_gpu.backward()
            # Collect the gradients themselves; comparing the input tensors
            # would pass regardless of what backward computed.
            grads_gpu = [
                v.grad.cpu()
                for v in sample.values()
                if isinstance(v, torch.Tensor) and v.requires_grad
            ]

            res_cpu = torchbp.util.entropy(sample_cpu["img"])
            res_cpu.backward()
            grads_cpu = [
                v.grad
                for v in sample_cpu.values()
                if isinstance(v, torch.Tensor) and v.requires_grad
            ]
            torch.testing.assert_close(grads_cpu, grads_gpu)
            torch.testing.assert_close(res_gpu.cpu(), res_cpu)
make_tensor((nbatch, grid_polar["nr"], grid_polar["ntheta"]), dtype=complex_dtype), 60 | 'dorigin': dorigin, 61 | 'grid_polar': grid_polar, 62 | 'fc': 6e9, 63 | 'rotation': 0.3, 64 | 'grid_polar_new': grid_polar_new, 65 | 'z0': 2 66 | } 67 | return [args] 68 | 69 | @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 70 | def test_cpu_and_gpu_grad(self): 71 | samples = self.sample_inputs("cuda", requires_grad=True) 72 | for sample in samples: 73 | sample_cpu = {k:sample[k].detach().cpu() if isinstance(sample[k], torch.Tensor) else sample[k] for k in sample.keys()} 74 | for k in sample.keys(): 75 | if isinstance(sample[k], torch.Tensor) and sample[k].requires_grad: 76 | sample_cpu[k].requires_grad = True 77 | 78 | res_gpu = torchbp.ops.polar_interp_linear(**sample) 79 | loss_gpu = torch.mean(torch.abs(res_gpu)) 80 | loss_gpu.backward() 81 | grads_gpu = [sample[k].cpu() for k in sample.keys() if isinstance(sample[k], torch.Tensor) and sample[k].requires_grad] 82 | 83 | res_cpu = torchbp.ops.polar_interp_linear(**sample_cpu) 84 | loss_cpu = torch.mean(torch.abs(res_cpu)) 85 | loss_cpu.backward() 86 | grads_cpu = [sample_cpu[k] for k in sample_cpu.keys() if isinstance(sample_cpu[k], torch.Tensor) and sample_cpu[k].requires_grad] 87 | torch.testing.assert_close(grads_cpu, grads_gpu, atol=1e-3, rtol=1e-2) 88 | 89 | @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") 90 | def test_cpu_and_gpu(self): 91 | samples = self.sample_inputs("cuda") 92 | for sample in samples: 93 | res_gpu = torchbp.ops.polar_interp_linear(**sample).cpu() 94 | sample_cpu = {k:sample[k].cpu() if isinstance(sample[k], torch.Tensor) else sample[k] for k in sample.keys()} 95 | res_cpu = torchbp.ops.polar_interp_linear(**sample_cpu) 96 | torch.testing.assert_close(res_cpu, res_gpu, rtol=5e-4, atol=5e-4) 97 | 98 | def _test_gradients(self, device, dtype=torch.float32): 99 | samples = self.sample_inputs(device, requires_grad=True, dtype=dtype) 100 | eps = 1e-3 if dtype == 
class TestPolarToCartLinear(TestCase):
    """Tests for `torchbp.ops.polar_to_cart_linear`."""

    def sample_inputs(self, device, *, requires_grad=False, dtype=torch.float32):
        """Return a list of keyword-argument dicts for `polar_to_cart_linear`.

        `dtype` is the real dtype; the image uses the matching complex dtype.
        Tensor inputs are leaf tensors so `.grad` is populated after backward.
        """
        complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
        nbatch = 2
        grid_polar = {"r": (10, 20), "theta": (-1, 1), "nr": 2, "ntheta": 2}
        grid_cart = {"x": (12, 18), "y": (-5, 5), "nx": 3, "ny": 3}
        # Scale and offset before enabling autograd: a tensor derived from a
        # tensor that requires grad would be a non-leaf and never receive `.grad`.
        origin = 0.1 * torch.randn((nbatch, 3), device=device, dtype=dtype)
        origin[:, 2] += 4  # Offset height
        origin.requires_grad_(requires_grad)
        img = torch.randn(
            (nbatch, grid_polar["nr"], grid_polar["ntheta"]),
            device=device,
            requires_grad=requires_grad,
            dtype=complex_dtype,
        )
        # Key order matters: `_test_gradients` passes the values positionally.
        args = {
            'img': img,
            'origin': origin,
            'grid_polar': grid_polar,
            'grid_cart': grid_cart,
            'fc': 6e9,
            'rotation': 0.1,
        }
        return [args]

    #@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @unittest.skip
    def test_cpu_and_gpu_grad(self):
        """Input gradients computed on CUDA and on CPU agree."""
        samples = self.sample_inputs("cuda", requires_grad=True)
        for sample in samples:
            # Detached CPU copies are independent leaves with their own gradients.
            sample_cpu = {
                k: v.detach().cpu() if isinstance(v, torch.Tensor) else v
                for k, v in sample.items()
            }
            for k, v in sample.items():
                if isinstance(v, torch.Tensor) and v.requires_grad:
                    sample_cpu[k].requires_grad = True

            res_gpu = torchbp.ops.polar_to_cart_linear(**sample)
            loss_gpu = torch.mean(torch.abs(res_gpu))
            loss_gpu.backward()
            # Compare the gradients, not the inputs they belong to.
            grads_gpu = [
                v.grad.cpu()
                for v in sample.values()
                if isinstance(v, torch.Tensor) and v.requires_grad
            ]

            res_cpu = torchbp.ops.polar_to_cart_linear(**sample_cpu)
            loss_cpu = torch.mean(torch.abs(res_cpu))
            loss_cpu.backward()
            grads_cpu = [
                v.grad
                for v in sample_cpu.values()
                if isinstance(v, torch.Tensor) and v.requires_grad
            ]
            torch.testing.assert_close(grads_cpu, grads_gpu, atol=1e-3, rtol=1e-2)

    #@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @unittest.skip
    def test_cpu_and_gpu(self):
        """Forward results computed on CUDA and on CPU agree."""
        samples = self.sample_inputs("cuda")
        for sample in samples:
            res_gpu = torchbp.ops.polar_to_cart_linear(**sample).cpu()
            sample_cpu = {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in sample.items()
            }
            res_cpu = torchbp.ops.polar_to_cart_linear(**sample_cpu)
            torch.testing.assert_close(res_cpu, res_gpu)

    def _test_gradients(self, device, dtype=torch.float32):
        """Run `torch.autograd.gradcheck` on the op for the given device and dtype."""
        samples = self.sample_inputs(device, requires_grad=True, dtype=dtype)
        eps = 5e-4 if dtype == torch.float32 else 1e-4
        rtol = 0.15 if dtype == torch.float32 else 0.05
        for args in samples:
            torch.autograd.gradcheck(
                torchbp.ops.polar_to_cart_linear,
                list(args.values()),
                eps=eps,  # This test is very sensitive to eps
                rtol=rtol,  # Also to rtol
            )

    @unittest.skip
    def test_gradients_cpu(self):
        self._test_gradients("cpu")
        self._test_gradients("cpu", dtype=torch.float64)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_gradients_cuda(self):
        self._test_gradients("cuda")
class TestPolarToCartBicubic(TestCase):
    """Tests for ``torchbp.ops.polar_to_cart_bicubic``."""

    def sample_inputs(self, device, *, requires_grad=False, dtype=torch.float32):
        """
        Build keyword arguments for ``polar_to_cart_bicubic``.

        Parameters
        ----------
        device : str
            Device on which the random tensors are created.
        requires_grad : bool
            Whether the tensor inputs track gradients.
        dtype : torch.dtype
            Real dtype. ``img`` is complex64 for float32 and complex128
            otherwise.

        Returns
        ----------
        samples : list of dict
            One dictionary of keyword arguments per sample.
        """

        def make_tensor(size, dtype=dtype):
            return torch.randn(
                size, device=device, requires_grad=requires_grad, dtype=dtype
            )

        complex_dtype = torch.complex64 if dtype == torch.float32 else torch.complex128
        nbatch = 2
        grid_polar = {"r": (10, 20), "theta": (-1, 1), "nr": 3, "ntheta": 3}
        grid_cart = {"x": (12, 18), "y": (-5, 5), "nx": 3, "ny": 3}
        origin = 0.1 * make_tensor((nbatch, 3), dtype=dtype)
        origin[:, 2] += 4  # Offset height
        args = {
            'img': make_tensor(
                (nbatch, grid_polar["nr"], grid_polar["ntheta"]), dtype=complex_dtype
            ),
            'origin': origin,
            'grid_polar': grid_polar,
            'grid_cart': grid_cart,
            'fc': 6e9,
            'rotation': -0.1,
        }
        return [args]

    @staticmethod
    def _leaf_copy(sample, device):
        """Copy ``sample`` to ``device`` with differentiable tensors as leaves."""
        copied = {}
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().to(device).requires_grad_(value.requires_grad)
            copied[key] = value
        return copied

    def _input_grads(self, sample):
        """Return the CPU gradients of every differentiable tensor in ``sample``."""
        grads = []
        for key, value in sample.items():
            if isinstance(value, torch.Tensor) and value.requires_grad:
                self.assertIsNotNone(value.grad, f"no gradient for input '{key}'")
                grads.append(value.grad.cpu())
        return grads

    # @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @unittest.skip
    def test_cpu_and_gpu_grad(self):
        """Input gradients must agree between the CPU and CUDA implementations."""
        for sample in self.sample_inputs("cuda", requires_grad=True):
            # sample_inputs returns non-leaf tensors (the scaled origin), whose
            # .grad is never populated, so both sides work on fresh leaf copies.
            sample_gpu = self._leaf_copy(sample, "cuda")
            sample_cpu = self._leaf_copy(sample, "cpu")

            res_gpu = torchbp.ops.polar_to_cart_bicubic(**sample_gpu)
            torch.mean(torch.abs(res_gpu)).backward()
            grads_gpu = self._input_grads(sample_gpu)

            res_cpu = torchbp.ops.polar_to_cart_bicubic(**sample_cpu)
            torch.mean(torch.abs(res_cpu)).backward()
            grads_cpu = self._input_grads(sample_cpu)

            torch.testing.assert_close(grads_cpu, grads_gpu, atol=1e-3, rtol=1e-2)

    # @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @unittest.skip
    def test_cpu_and_gpu(self):
        """Forward results must agree between the CPU and CUDA implementations."""
        for sample in self.sample_inputs("cuda"):
            res_gpu = torchbp.ops.polar_to_cart_bicubic(**sample).cpu()
            sample_cpu = {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in sample.items()
            }
            res_cpu = torchbp.ops.polar_to_cart_bicubic(**sample_cpu)
            torch.testing.assert_close(res_cpu, res_gpu)

    def _test_gradients(self, device, dtype=torch.float32):
        """Run ``gradcheck`` on ``polar_to_cart_bicubic`` for every sample."""
        samples = self.sample_inputs(device, requires_grad=True, dtype=dtype)
        eps = 7e-4 if dtype == torch.float32 else 1e-4
        rtol = 0.2 if dtype == torch.float32 else 0.05
        for args in samples:
            torch.autograd.gradcheck(
                torchbp.ops.polar_to_cart_bicubic,
                list(args.values()),
                eps=eps,  # This test is very sensitive to eps
                rtol=rtol,  # Also to rtol
            )

    @unittest.skip
    def test_gradients_cpu(self):
        self._test_gradients("cpu")
        self._test_gradients("cpu", dtype=torch.float64)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_gradients_cuda(self):
        self._test_gradients("cuda")
class TestBackprojectionPolar(TestCase):
    """Tests for ``torchbp.ops.backprojection_polar_2d``."""

    def sample_inputs(self, device, *, requires_grad=False):
        """
        Build keyword arguments for ``backprojection_polar_2d``.

        Parameters
        ----------
        device : str
            Device on which the random tensors are created.
        requires_grad : bool
            Whether the tensor inputs track gradients.

        Returns
        ----------
        samples : list of dict
            One dictionary of keyword arguments per sample.
        """

        def make_tensor(size, dtype=torch.float32):
            return torch.randn(
                size, device=device, requires_grad=requires_grad, dtype=dtype
            )

        # Make sure that scene is in view
        def make_pos_tensor(size, dtype=torch.float32):
            x = torch.randn(
                size, device=device, requires_grad=requires_grad, dtype=dtype
            )
            return x - torch.max(x[:, 0]) - 2

        nbatch = 2
        sweeps = 2
        sweep_samples = 64
        grid = {"r": (1, 10), "theta": (-0.9, 0.9), "nr": 4, "ntheta": 4}
        args = {
            'data': make_tensor((nbatch, sweeps, sweep_samples), dtype=torch.complex64),
            'grid': grid,
            'fc': 6e9,
            'r_res': 0.15,
            'pos': make_pos_tensor((nbatch, sweeps, 3), dtype=torch.float32),
            'd0': 0.2,
        }
        return [args]

    @staticmethod
    def _leaf_copy(sample, device):
        """Copy ``sample`` to ``device`` with differentiable tensors as leaves."""
        copied = {}
        for key, value in sample.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().to(device).requires_grad_(value.requires_grad)
            copied[key] = value
        return copied

    def _input_grads(self, sample):
        """Return the CPU gradients of every differentiable tensor in ``sample``."""
        grads = []
        for key, value in sample.items():
            if isinstance(value, torch.Tensor) and value.requires_grad:
                self.assertIsNotNone(value.grad, f"no gradient for input '{key}'")
                grads.append(value.grad.cpu())
        return grads

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cpu_and_gpu(self):
        """Forward results must agree between the CPU and CUDA implementations."""
        for sample in self.sample_inputs("cuda"):
            res_gpu = torchbp.ops.backprojection_polar_2d(**sample).cpu()
            sample_cpu = {
                k: v.cpu() if isinstance(v, torch.Tensor) else v
                for k, v in sample.items()
            }
            res_cpu = torchbp.ops.backprojection_polar_2d(**sample_cpu)
            torch.testing.assert_close(res_cpu, res_gpu, atol=1e-3, rtol=1e-2)

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_cpu_and_gpu_grad(self):
        """Input gradients must agree between the CPU and CUDA implementations."""
        for sample in self.sample_inputs("cuda", requires_grad=True):
            # The shifted position tensor is a non-leaf, so its .grad is never
            # populated; both sides work on fresh leaf copies instead.
            sample_gpu = self._leaf_copy(sample, "cuda")
            sample_cpu = self._leaf_copy(sample, "cpu")

            res_gpu = torchbp.ops.backprojection_polar_2d(**sample_gpu)
            torch.mean(torch.abs(res_gpu)).backward()
            grads_gpu = self._input_grads(sample_gpu)

            res_cpu = torchbp.ops.backprojection_polar_2d(**sample_cpu)
            torch.mean(torch.abs(res_cpu)).backward()
            grads_cpu = self._input_grads(sample_cpu)

            torch.testing.assert_close(grads_cpu, grads_gpu, atol=1e-3, rtol=1e-2)

    def _test_gradients(self, device):
        """Run ``gradcheck`` on ``backprojection_polar_2d`` for every sample."""
        samples = self.sample_inputs(device, requires_grad=True)
        for args in samples:
            torch.autograd.gradcheck(
                torchbp.ops.backprojection_polar_2d,
                list(args.values()),
                eps=5e-4,  # This test is very sensitive to eps
                rtol=0.2,  # Also to rtol
                atol=0.05,
            )

    def test_gradients_cpu(self):
        self._test_gradients("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_gradients_cuda(self):
        self._test_gradients("cuda")
class TestBackprojectionCart(TestCase):
    """Gradient checks for ``torchbp.ops.backprojection_cart_2d``."""

    def sample_inputs(self, device, *, requires_grad=False):
        """
        Build keyword arguments for ``backprojection_cart_2d``.

        Parameters
        ----------
        device : str
            Device on which the random tensors are created.
        requires_grad : bool
            Whether the tensor inputs track gradients.

        Returns
        ----------
        samples : list of dict
            One dictionary of keyword arguments per sample.
        """
        nbatch, sweeps, sweep_samples = 2, 2, 128

        def random(size, dtype):
            return torch.randn(
                size, device=device, requires_grad=requires_grad, dtype=dtype
            )

        # Data is drawn before the positions, as in the original ordering.
        data = random((nbatch, sweeps, sweep_samples), torch.complex64)
        # Make sure that scene is in view
        pos = random((nbatch, sweeps, 3), torch.float32)
        pos = pos - torch.max(pos[:, 0]) - 2

        return [
            {
                'data': data,
                'grid': {"x": (2, 10), "y": (-5, 5), "nx": 4, "ny": 4},
                'fc': 6e9,
                'r_res': 0.15,
                'pos': pos,
                'beamwidth': 3.14,
                'd0': 0.2,
            }
        ]

    def _test_gradients(self, device):
        """Run ``gradcheck`` on ``backprojection_cart_2d`` for every sample."""
        for kwargs in self.sample_inputs(device, requires_grad=True):
            torch.autograd.gradcheck(
                torchbp.ops.backprojection_cart_2d,
                list(kwargs.values()),
                eps=5e-4,  # This test is very sensitive to eps
                rtol=0.2,  # Also to rtol
                atol=0.05,
            )

    #def test_gradients_cpu(self):
    #    self._test_gradients("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    def test_gradients_cuda(self):
        self._test_gradients("cuda")

    #def _opcheck(self, device):
    #    # Use opcheck to check for incorrect usage of operator registration APIs
    #    samples = self.sample_inputs(device, requires_grad=True)
    #    samples.extend(self.sample_inputs(device, requires_grad=False))
    #    for args in samples:
    #        opcheck(torch.ops.torchbp.backprojection_cart_2d, list(args.values()))

    #def test_opcheck_cpu(self):
    #    self._opcheck("cpu")

    #@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    #def test_opcheck_cuda(self):
    #    self._opcheck("cuda")
def pga_estimator(g: Tensor, estimator: str = "wls", eps: float = 1e-3) -> Tensor:
    """
    Estimate phase error from set of measurements.

    Parameters
    ----------
    g : Tensor
        Demodulated phase from each target. Shape [Ntargets, Nazimuth].
    estimator : str
        Estimator to use.
        "pd": Phase difference. [1]_
        "ml": Maximum likelihood. [2]_
        "wls": Weighted least squares using estimated signal-to-clutter weighting. [3]_
    eps : float
        Minimum weight for weighted PGA.
        Normalized to the maximum weight.

    References
    ----------
    .. [1] D. E. Wahl, P. H. Eichel, D. C. Ghiglia and C. V. Jakowatz, "Phase
        gradient autofocus - A robust tool for high resolution SAR phase
        correction," in IEEE Transactions on Aerospace and Electronic Systems, vol.
        30, no. 3, pp. 827-835, July 1994.

    .. [2] Charles V. Jakowatz and Daniel E. Wahl, "Eigenvector method for
        maximum-likelihood estimation of phase errors in synthetic-aperture-radar
        imagery," J. Opt. Soc. Am. A 10, 2539-2546 (1993).

    .. [3] Wei Ye, Tat Soon Yeo and Zheng Bao, "Weighted least-squares
        estimation of phase errors for SAR/ISAR autofocus," in IEEE Transactions on
        Geoscience and Remote Sensing, vol. 37, no. 5, pp. 2487-2494, Sept. 1999.

    Returns
    ----------
    phi : Tensor
        Solved phase error.

    Raises
    ----------
    ValueError
        If `estimator` is not one of "pd", "ml" or "wls".
    """
    if estimator == "ml":
        # Only the leading right singular vector is used, so the reduced SVD
        # avoids building the full [Nazimuth, Nazimuth] basis.
        _, _, vh = torch.linalg.svd(g, full_matrices=False)
        return torch.angle(vh[0, :])

    if estimator == "wls":
        amp = torch.abs(g)
        c = torch.mean(amp, dim=1, keepdim=True)
        d = torch.mean(amp**2, dim=1, keepdim=True)
        # Per-target weight from the first two amplitude moments. Invalid
        # values (NaN from the square root, division by zero) are sanitized.
        w = torch.nan_to_num(
            d / (4 * (2 * c**2 - d) - 4 * c * torch.sqrt(4 * c**2 - 3 * d))
        )
        w = w / (torch.max(w) + 1e-6) + eps
        # Previous azimuth sample of each target, zero for the first one.
        gshift = torch.nn.functional.pad(g[..., :-1], (1, 0))
        phidot = torch.angle(torch.sum(w * (g * torch.conj(gshift)), axis=0))
        return torch.cumsum(phidot, dim=0)

    if estimator == "pd":
        z = torch.zeros((g.shape[0], 1), device=g.device, dtype=g.dtype)
        gdot = torch.diff(g, prepend=z, axis=-1)
        phidot = torch.sum((torch.conj(g) * gdot).imag, axis=0) / torch.sum(
            torch.abs(g) ** 2, axis=0
        )
        return torch.cumsum(phidot, dim=0)

    raise ValueError(f"Unknown estimator {estimator}")
def pga(
    img: Tensor,
    window_width: int | None = None,
    max_iters: int = 10,
    window_exp: float = 0.5,
    min_window: int = 5,
    remove_trend: bool = True,
    offload: bool = False,
    estimator: str = "wls",
    eps: float = 1e-2,
) -> (Tensor, Tensor):
    """
    Phase gradient autofocus

    Parameters
    ----------
    img : Tensor
        Complex input image. Shape should be: [Range, azimuth].
    window_width : int
        Initial window width. Default is None which uses full image size.
    max_iters : int
        Maximum number of iterations.
    window_exp : float
        Exponent for decreasing the window size for each iteration.
    min_window : int
        Minimum window size.
    remove_trend : bool
        Remove linear trend that shifts the image.
    offload : bool
        Offload some variable to CPU to save VRAM on GPU at
        the expense of longer running time.
    estimator : str
        Estimator to use.
        See `pga_estimator` function for possible choices.
    eps : float
        Minimum weight for weighted PGA.

    Returns
    ----------
    img : Tensor
        Focused image.
    phi : Tensor
        Solved phase error.

    Raises
    ----------
    ValueError
        If `img` is not 2D or `window_exp` is outside the 0 to 1 range.
    """
    if img.ndim != 2:
        raise ValueError("Input image should be 2D.")
    if window_exp > 1 or window_exp < 0:
        raise ValueError(f"Invalid window_exp {window_exp}")
    ntheta = img.shape[1]
    phi_sum = torch.zeros(ntheta, device=img.device)
    if window_width is None or window_width > ntheta:
        window_width = ntheta
    dev = img.device
    for i in range(max_iters):
        window = int(window_width * window_exp**i)
        if window < min_window:
            break
        g = img.clone()
        if offload:
            img = img.to(device="cpu")
        # Peak for each range bin
        rpeaks = torch.argmax(torch.abs(g), axis=1)
        # Roll theta axis so that peak is at 0 bin. The peak indices are
        # fetched in one transfer instead of one `.item()` sync per row.
        for j, shift in enumerate(rpeaks.tolist()):
            g[j, :] = torch.roll(g[j, :], -shift)
        # Apply window
        g[:, 1 + window // 2 : 1 - window // 2] = 0
        # FFT across theta
        g = torch.fft.fft(g, axis=-1)
        phi = pga_estimator(g, estimator, eps)
        del g
        if remove_trend:
            phi = detrend(unwrap(phi))
        phi_sum += phi

        if offload:
            img = img.to(device=dev)
        # Apply the estimated phase correction in the azimuth frequency domain.
        img_ifft = torch.fft.fft(img, axis=-1)
        img_ifft *= torch.exp(-1j * phi[None, :])
        img = torch.fft.ifft(img_ifft, axis=-1)

    return img, phi_sum
def gpga_bp_polar(
    img: Tensor | None,
    data: Tensor,
    pos: Tensor,
    fc: float,
    r_res: float,
    grid_polar: dict,
    window_width: int | None = None,
    max_iters: int = 10,
    window_exp: float = 0.8,
    min_window: int = 5,
    d0: float = 0.0,
    target_threshold_db: float = 20,
    remove_trend: bool = True,
    estimator: str = "pd",
    lowpass_window: str = "boxcar",
    eps: float = 1e-3,
    interp_method: str = "linear",
) -> (Tensor, Tensor):
    """
    Generalized phase gradient autofocus using 2D polar coordinate
    backprojection image formation.

    Parameters
    ----------
    img : Tensor or None
        Complex input image. Shape should be: [Range, azimuth].
        If None image is generated from the data.
    data : Tensor
        Range compressed input data. Shape should be [nsweeps, samples].
    pos : Tensor
        Position of the platform at each data point. Shape should be [nsweeps, 3].
    fc : float
        RF center frequency in Hz.
    r_res : float
        Range bin resolution in data (meters).
        For FMCW radar: c/(2*bw*oversample), where c is speed of light, bw is sweep bandwidth,
        and oversample is FFT oversampling factor.
    grid_polar : dict
        Grid definition. Dictionary with keys "r", "theta", "nr", "ntheta".
        "r": (r0, r1), tuple of min and max range,
        "theta": (theta0, theta1), sin of min and max angle. (-1, 1) for 180 degree view.
        "nr": nr, number of range bins.
        "ntheta": number of angle bins.
    window_width : int or None
        Initial low-pass filter window width in samples. None for initial
        maximum size.
    max_iters : int
        Maximum number of iterations.
    window_exp : float
        Exponent on window_width decrease for each iteration.
    min_window : int
        Minimum window size.
    d0 : float
        Zero range correction.
    target_threshold_db : float
        Filter out targets that are this many dB below the maximum amplitude
        target.
    remove_trend : bool
        Remove linear trend in phase correction.
    estimator : str
        Estimator to use.
        See `pga_estimator` function for possible choices.
    lowpass_window : str
        FFT window to use for lowpass filtering.
        See `scipy.signal.get_window` for syntax.
    eps : float
        Minimum weight for weighted PGA.
    interp_method : str
        Interpolation method
        "linear": linear interpolation.
        ("lanczos", N): Lanczos interpolation with order 2*N+1.

    References
    ----------
    .. [#] A. Evers and J. A. Jackson, "A Generalized Phase Gradient Autofocus
        Algorithm," in IEEE Transactions on Computational Imaging, vol. 5, no. 4,
        pp. 606-619, Dec. 2019.

    Returns
    ----------
    img : Tensor
        Focused SAR image.
    phi : Tensor
        Solved phase error.
    """
    r0, r1 = grid_polar["r"]
    theta0, theta1 = grid_polar["theta"]
    ntheta = grid_polar["ntheta"]
    nr = grid_polar["nr"]
    dtheta = (theta1 - theta0) / ntheta
    dr = (r1 - r0) / nr

    phi_sum = torch.zeros(data.shape[0], dtype=torch.float32, device=data.device)

    r = r0 + dr * torch.arange(nr, device=data.device, dtype=torch.float32)
    pos_new = pos.clone()

    if window_width is None:
        window_width = data.shape[0]

    if img is None:
        # Use the same zero range correction as every later image formation
        # and as the per-target range profile sampling below.
        img = backprojection_polar_2d(data, grid_polar, fc, r_res, pos_new, d0=d0)[0]

    c0 = 299792458
    for _ in range(max_iters):
        # Strongest azimuth bin of every range bin and its amplitude.
        rpeaks = torch.argmax(torch.abs(img), axis=1)
        a = torch.abs(img[torch.arange(img.size(0)), rpeaks])
        max_a = torch.max(a)

        # Keep only targets within target_threshold_db of the strongest one.
        target_idx = a > max_a * 10 ** (-target_threshold_db / 20)
        target_theta = theta0 + dtheta * rpeaks[target_idx].to(torch.float32)
        target_r = r[target_idx]

        # Polar (r, sin(angle)) to Cartesian target positions on the ground plane.
        x = target_r * torch.sqrt(1 - target_theta**2)
        y = target_r * target_theta
        z = torch.zeros_like(target_r)
        target_pos = torch.stack([x, y, z], dim=1)

        # Get range profile samples for each target
        target_data = gpga_backprojection_2d_core(
            target_pos, data, pos_new, fc, r_res, d0, interp_method=interp_method
        )
        # Filter samples
        if window_width < target_data.shape[1]:
            target_data = fft_lowpass_filter_window(
                target_data, window=lowpass_window, window_width=window_width
            )
        phi = pga_estimator(target_data, estimator, eps)
        phi_sum = unwrap(phi_sum + phi)
        if remove_trend:
            phi_sum = detrend(phi_sum)
        # Phase to distance
        d = phi_sum * c0 / (4 * torch.pi * fc)
        d = d - torch.mean(d)
        pos_new[:, 0] = pos[:, 0] + d

        img = backprojection_polar_2d(data, grid_polar, fc, r_res, pos_new, d0=d0)[0]
        window_width = int(window_width**window_exp)
        if window_width < min_window:
            break
    return img, phi_sum
window_width < target_data.shape[1]: 291 | target_data = fft_lowpass_filter_window( 292 | target_data, window=lowpass_window, window_width=window_width 293 | ) 294 | phi = pga_estimator(target_data, estimator, eps) 295 | phi_sum = unwrap(phi_sum + phi) 296 | if remove_trend: 297 | phi_sum = detrend(phi_sum) 298 | # Phase to distance 299 | c0 = 299792458 300 | d = phi_sum * c0 / (4 * torch.pi * fc) 301 | d = d - torch.mean(d) 302 | pos_new[:, 0] = pos[:, 0] + d 303 | 304 | img = backprojection_polar_2d(data, grid_polar, fc, r_res, pos_new, d0=d0)[0] 305 | window_width = int(window_width**window_exp) 306 | if window_width < min_window: 307 | break 308 | return img, phi_sum 309 | 310 | 311 | def _get_kwargs() -> dict: 312 | frame = inspect.currentframe().f_back 313 | keys, _, _, values = inspect.getargvalues(frame) 314 | kwargs = {} 315 | for key in keys: 316 | if key != "self": 317 | kwargs[key] = values[key] 318 | return kwargs 319 | 320 | 321 | def minimum_entropy_grad_autofocus( 322 | f, 323 | data: Tensor, 324 | data_time: Tensor, 325 | pos: Tensor, 326 | fc: float, 327 | r_res: float, 328 | grid: dict, 329 | wa: Tensor, 330 | tx_norm: Tensor = None, 331 | max_steps: float = 100, 332 | lr_max: float = 10000, 333 | d0: float = 0, 334 | pos_reg: float = 1, 335 | lr_reduce: float = 0.8, 336 | verbose: bool = True, 337 | convergence_limit: float = 0.01, 338 | max_step_limit: float = 0.25, 339 | grad_limit_quantile: float = 0.9, 340 | fixed_pos: int = 0, 341 | minimize_only: bool = False, 342 | ) -> (Tensor, Tensor, Tensor, int): 343 | """ 344 | Minimum entropy autofocus using backpropagation optimization through image 345 | formation. 346 | 347 | Parameters 348 | ---------- 349 | f : function 350 | Radar image generation function. 351 | data : Tensor 352 | Radar data. 353 | data_time : Tensor 354 | Recording time of each data sample. 355 | pos : Tensor 356 | Position at each data sample. 357 | fc : float 358 | RF frequency in Hz. 
def minimum_entropy_grad_autofocus(
    f,
    data: Tensor,
    data_time: Tensor,
    pos: Tensor,
    fc: float,
    r_res: float,
    grid: dict,
    wa: Tensor,
    tx_norm: Tensor = None,
    max_steps: int = 100,
    lr_max: float = 10000,
    d0: float = 0,
    pos_reg: float = 1,
    lr_reduce: float = 0.8,
    verbose: bool = True,
    convergence_limit: float = 0.01,
    max_step_limit: float = 0.25,
    grad_limit_quantile: float = 0.9,
    fixed_pos: int = 0,
    minimize_only: bool = False,
) -> (Tensor, Tensor, Tensor, int):
    """
    Minimum entropy autofocus using backpropagation optimization through image
    formation.

    Parameters
    ----------
    f : function
        Radar image generation function.
    data : Tensor
        Radar data.
    data_time : Tensor
        Recording time of each data sample.
    pos : Tensor
        Position at each data sample.
    fc : float
        RF frequency in Hz.
    r_res : float
        Range bin resolution in data (meters).
        For FMCW radar: c/(2*bw*oversample), where c is speed of light, bw is sweep bandwidth,
        and oversample is FFT oversampling factor.
    grid : dict
        Grid definition. Correct definition depends on the radar image function.
    wa : Tensor
        Azimuth windowing function.
        Should be applied to data already, used for scaling gradient.
    tx_norm : Tensor
        Radar image is divided by this tensor before calculating entropy.
        If None no division is done.
    max_steps : int
        Maximum number of optimization steps. Must be at least 1.
    lr_max : float
        Maximum learning rate.
        Too large learning rate is scaled automatically.
    d0 : float
        Zero range correction.
    pos_reg : float
        Position regularization value.
    lr_reduce : float
        Learning rate is multiplied with this value if new entropy is larger than previously.
    verbose : bool
        Print progress during optimization.
    convergence_limit : float
        If maximum position change is below this value stop optimization.
        Units in wavelengths.
    max_step_limit : float
        Maximum step size in wavelengths.
    grad_limit_quantile : float
        Quantile used for maximum step size calculation.
        0 to 1 range.
    fixed_pos : int
        First `fixed_pos` positions are kept fixed and are not optimized.
    minimize_only : bool
        Reject steps that would increase image entropy.

    Returns
    ----------
    sar_img : Tensor
        Optimized radar image.
    origin : Tensor
        Mean of position tensor.
    pos : Tensor
        Platform position.
    step : int
        Index of the last executed optimization step.

    Raises
    ----------
    ValueError
        If `max_steps` is smaller than 1.
    """
    # Earlier signatures annotated max_steps as float; range() needs an int.
    max_steps = int(max_steps)
    if max_steps < 1:
        raise ValueError(f"max_steps must be at least 1, got {max_steps}")

    dev = data.device
    t = data_time.unsqueeze(1)
    dt = torch.diff(t, dim=0, prepend=t[0].unsqueeze(0))
    dt[0] = dt[1]
    # The optimized variable is the velocity; positions are its time integral.
    vopt = torch.diff(pos, dim=0, prepend=pos[0].unsqueeze(0)) / dt
    pos_mean = torch.mean(pos, dim=0)

    if fixed_pos > 0:
        v_fixed = vopt[:fixed_pos].detach().clone()

    pos_orig = pos.clone()
    vopt.requires_grad = True

    wl = 3e8 / fc
    lr = lr_max

    # The optimizer runs with unit learning rate; the effective rate comes
    # from the scheduler lambda, which reads the current value of `lr`.
    opt = torch.optim.SGD([vopt], momentum=0, lr=1)

    def lr_sch(epoch):
        # Constant rate for the first 75 % of the steps, then a linear ramp down.
        p = int(0.75 * max_steps)
        if epoch > p:
            a = -lr / (max_steps + 1 - p)
            b = lr * max_steps / (max_steps + 1 - p)
            return a * epoch + b
        return lr

    scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_sch)

    last_entr = None
    v_prev = vopt.detach().clone()

    try:
        for step in range(max_steps):
            if fixed_pos > 0:
                v = torch.cat([v_fixed, vopt[fixed_pos:]], dim=0)
            else:
                v = vopt
            pos = torch.cumsum(v * dt, 0)
            pos = pos - torch.mean(pos, dim=0) + pos_mean

            pos_loss = pos_reg * torch.mean(torch.square(pos - pos_orig))

            origin = torch.tensor(
                [torch.mean(pos[:, 0]), torch.mean(pos[:, 1]), 0],
                device=dev,
                dtype=torch.float32,
            )[None, :]
            pos_centered = pos - origin

            sar_img = f(data, grid, fc, r_res, pos_centered, d0).squeeze()
            if tx_norm is not None:
                entr = entropy(sar_img / tx_norm)
            else:
                entr = entropy(sar_img)
            loss = entr + pos_loss
            if last_entr is not None and entr > last_entr:
                lr *= lr_reduce
                if minimize_only:
                    # Reject the step: restore the previous velocities and retry.
                    vopt.data = v_prev.data
                    continue
            last_entr = entr
            v_prev = vopt.detach().clone()
            if step < max_steps - 1:
                loss.backward()
                l = scheduler.get_last_lr()[0]
                with torch.no_grad():
                    vopt.grad /= wa[:, None]
                    g = vopt.grad.detach()
                    # Position change this step would cause along the first axis.
                    gpos = torch.cumsum(l * g * dt, 0)
                    dp = torch.abs(gpos[:, 0])
                    maxd = torch.quantile(dp, grad_limit_quantile)
                    s = max_step_limit * wl / (1e-5 + maxd)
                    if maxd < convergence_limit * wl:
                        if verbose:
                            print("Optimization converged")
                        break
                    if s < 1:
                        # Step would exceed max_step_limit: scale it down and
                        # lower the learning rate permanently.
                        vopt.grad *= s
                        lr *= s.item()
                opt.step()
                opt.zero_grad()
                scheduler.step()
            if verbose:
                print(
                    step,
                    "Entropy",
                    entr.detach().cpu().numpy(),
                    "loss",
                    loss.detach().cpu().numpy(),
                )
    except KeyboardInterrupt:
        print("Interrupted")

    return sar_img.detach(), origin.detach(), pos.detach(), step
def bp_polar_grad_minimum_entropy(
    data: Tensor,
    data_time: Tensor,
    pos: Tensor,
    fc: float,
    r_res: float,
    grid: dict,
    wa: Tensor,
    tx_norm: Tensor = None,
    max_steps: int = 100,
    lr_max: float = 10000,
    d0: float = 0,
    pos_reg: float = 1,
    lr_reduce: float = 0.8,
    verbose: bool = True,
    convergence_limit: float = 0.01,
    max_step_limit: float = 0.25,
    grad_limit_quantile: float = 0.9,
    fixed_pos: int = 0,
    minimize_only: bool = False,
) -> (Tensor, Tensor, Tensor, int):
    """
    Minimum entropy autofocus using polar coordinate backprojection.

    Wrapper around `minimum_entropy_grad_autofocus` with
    `backprojection_polar_2d` as the image formation function.

    Parameters
    ----------
    data : Tensor
        Radar data.
    data_time : Tensor
        Recording time of each data sample.
    pos : Tensor
        Position at each data sample.
    fc : float
        RF frequency in Hz.
    r_res : float
        Range bin resolution in data (meters).
        For FMCW radar: c/(2*bw*oversample), where c is speed of light, bw is sweep bandwidth,
        and oversample is FFT oversampling factor.
    grid : dict
        Grid definition. Correct definition depends on the radar image function.
    wa : Tensor
        Azimuth windowing function.
        Should be applied to data already, used for scaling gradient.
    tx_norm : Tensor
        Radar image is divided by this tensor before calculating entropy.
        If None no division is done.
    max_steps : int
        Maximum number of optimization steps.
    lr_max : float
        Maximum learning rate.
        Too large learning rate is scaled automatically.
    d0 : float
        Zero range correction.
    pos_reg : float
        Position regularization value.
    lr_reduce : float
        Learning rate is multiplied with this value if new entropy is larger than previously.
    verbose : bool
        Print progress during optimization.
    convergence_limit : float
        If maximum position change is below this value stop optimization.
        Units in wavelengths.
    max_step_limit : float
        Maximum step size in wavelengths.
    grad_limit_quantile : float
        Quantile used for maximum step size calculation.
        0 to 1 range.
    fixed_pos : int
        First `fixed_pos` positions are kept fixed and are not optimized.
    minimize_only : bool
        Reject steps that would increase image entropy.

    Returns
    ----------
    sar_img : Tensor
        Optimized radar image.
    origin : Tensor
        Mean of position tensor.
    pos : Tensor
        Platform position.
    step : int
        Number of steps.
    """
    # Must stay the first statement: _get_kwargs reads this frame's arguments.
    kw = _get_kwargs()
    return minimum_entropy_grad_autofocus(backprojection_polar_2d, **kw)
def bp_cart_grad_minimum_entropy(
    data: Tensor,
    data_time: Tensor,
    pos: Tensor,
    fc: float,
    r_res: float,
    grid: dict,
    wa: Tensor,
    tx_norm: Tensor = None,
    max_steps: int = 100,
    lr_max: float = 10000,
    d0: float = 0,
    pos_reg: float = 1,
    lr_reduce: float = 0.8,
    verbose: bool = True,
    convergence_limit: float = 0.01,
    max_step_limit: float = 0.25,
    grad_limit_quantile: float = 0.9,
    fixed_pos: int = 0,
    minimize_only: bool = False,
) -> (Tensor, Tensor, Tensor, int):
    """
    Minimum entropy autofocus using Cartesian coordinate backprojection.

    Wrapper around `minimum_entropy_grad_autofocus` with
    `backprojection_cart_2d` as the image formation function.

    Parameters
    ----------
    data : Tensor
        Radar data.
    data_time : Tensor
        Recording time of each data sample.
    pos : Tensor
        Position at each data sample.
    fc : float
        RF frequency in Hz.
    r_res : float
        Range bin resolution in data (meters).
        For FMCW radar: c/(2*bw*oversample), where c is speed of light, bw is sweep bandwidth,
        and oversample is FFT oversampling factor.
    grid : dict
        Grid definition. Correct definition depends on the radar image function.
    wa : Tensor
        Azimuth windowing function.
        Should be applied to data already, used for scaling gradient.
    tx_norm : Tensor
        Radar image is divided by this tensor before calculating entropy.
        If None no division is done.
    max_steps : int
        Maximum number of optimization steps.
    lr_max : float
        Maximum learning rate.
        Too large learning rate is scaled automatically.
    d0 : float
        Zero range correction.
    pos_reg : float
        Position regularization value.
    lr_reduce : float
        Learning rate is multiplied with this value if new entropy is larger than previously.
    verbose : bool
        Print progress during optimization.
    convergence_limit : float
        If maximum position change is below this value stop optimization.
        Units in wavelengths.
    max_step_limit : float
        Maximum step size in wavelengths.
    grad_limit_quantile : float
        Quantile used for maximum step size calculation.
        0 to 1 range.
    fixed_pos : int
        First `fixed_pos` positions are kept fixed and are not optimized.
    minimize_only : bool
        Reject steps that would increase image entropy.

    Returns
    ----------
    sar_img : Tensor
        Optimized radar image.
    origin : Tensor
        Mean of position tensor.
    pos : Tensor
        Platform position.
    step : int
        Number of steps.
    """
    # Must stay the first statement: _get_kwargs reads this frame's arguments.
    kw = _get_kwargs()
    return minimum_entropy_grad_autofocus(backprojection_cart_2d, **kw)
653 | verbose : bool 654 | Print progress during optimization. 655 | convergence_limit : float 656 | If maximum position change is below this value stop optimization. 657 | Units in wavelengths. 658 | max_step_limit : float 659 | Maximum step size in wavelengths. 660 | grad_limit_quantile : float 661 | Quantile used for maximum step size calculation. 662 | 0 to 1 range. 663 | fixed_pos : int 664 | First `fixed_pos` positions are kept fixed and are not optimized. 665 | 666 | Returns 667 | ---------- 668 | sar_img : Tensor 669 | Optimized radar image. 670 | origin : Tensor 671 | Mean of position tensor. 672 | pos : Tensor 673 | Platform position. 674 | step : int 675 | Number of steps. 676 | """ 677 | kw = _get_kwargs() 678 | return minimum_entropy_grad_autofocus(backprojection_cart_2d, **kw) 679 | -------------------------------------------------------------------------------- /torchbp/csrc/cuda/std_complex.h: -------------------------------------------------------------------------------- 1 | // -*- C++ -*- 2 | //===--------------------------- complex ----------------------------------===// 3 | // 4 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 5 | // See https://llvm.org/LICENSE.txt for license information. 
6 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 7 | // 8 | //===----------------------------------------------------------------------===// 9 | 10 | #ifndef _LIBCUDACXX_COMPLEX 11 | #define _LIBCUDACXX_COMPLEX 12 | 13 | /* 14 | complex synopsis 15 | 16 | namespace std 17 | { 18 | 19 | template 20 | class complex 21 | { 22 | public: 23 | typedef T value_type; 24 | 25 | complex(const T& re = T(), const T& im = T()); // constexpr in C++14 26 | complex(const complex&); // constexpr in C++14 27 | template complex(const complex&); // constexpr in C++14 28 | 29 | T real() const; // constexpr in C++14 30 | T imag() const; // constexpr in C++14 31 | 32 | void real(T); 33 | void imag(T); 34 | 35 | complex& operator= (const T&); 36 | complex& operator+=(const T&); 37 | complex& operator-=(const T&); 38 | complex& operator*=(const T&); 39 | complex& operator/=(const T&); 40 | 41 | complex& operator=(const complex&); 42 | template complex& operator= (const complex&); 43 | template complex& operator+=(const complex&); 44 | template complex& operator-=(const complex&); 45 | template complex& operator*=(const complex&); 46 | template complex& operator/=(const complex&); 47 | }; 48 | 49 | template<> 50 | class complex 51 | { 52 | public: 53 | typedef float value_type; 54 | 55 | constexpr complex(float re = 0.0f, float im = 0.0f); 56 | explicit constexpr complex(const complex&); 57 | explicit constexpr complex(const complex&); 58 | 59 | constexpr float real() const; 60 | void real(float); 61 | constexpr float imag() const; 62 | void imag(float); 63 | 64 | complex& operator= (float); 65 | complex& operator+=(float); 66 | complex& operator-=(float); 67 | complex& operator*=(float); 68 | complex& operator/=(float); 69 | 70 | complex& operator=(const complex&); 71 | template complex& operator= (const complex&); 72 | template complex& operator+=(const complex&); 73 | template complex& operator-=(const complex&); 74 | template complex& operator*=(const complex&); 75 | 
template complex& operator/=(const complex&); 76 | }; 77 | 78 | template<> 79 | class complex 80 | { 81 | public: 82 | typedef double value_type; 83 | 84 | constexpr complex(double re = 0.0, double im = 0.0); 85 | constexpr complex(const complex&); 86 | explicit constexpr complex(const complex&); 87 | 88 | constexpr double real() const; 89 | void real(double); 90 | constexpr double imag() const; 91 | void imag(double); 92 | 93 | complex& operator= (double); 94 | complex& operator+=(double); 95 | complex& operator-=(double); 96 | complex& operator*=(double); 97 | complex& operator/=(double); 98 | complex& operator=(const complex&); 99 | 100 | template complex& operator= (const complex&); 101 | template complex& operator+=(const complex&); 102 | template complex& operator-=(const complex&); 103 | template complex& operator*=(const complex&); 104 | template complex& operator/=(const complex&); 105 | }; 106 | 107 | template<> 108 | class complex 109 | { 110 | public: 111 | typedef long double value_type; 112 | 113 | constexpr complex(long double re = 0.0L, long double im = 0.0L); 114 | constexpr complex(const complex&); 115 | constexpr complex(const complex&); 116 | 117 | constexpr long double real() const; 118 | void real(long double); 119 | constexpr long double imag() const; 120 | void imag(long double); 121 | 122 | complex& operator=(const complex&); 123 | complex& operator= (long double); 124 | complex& operator+=(long double); 125 | complex& operator-=(long double); 126 | complex& operator*=(long double); 127 | complex& operator/=(long double); 128 | 129 | template complex& operator= (const complex&); 130 | template complex& operator+=(const complex&); 131 | template complex& operator-=(const complex&); 132 | template complex& operator*=(const complex&); 133 | template complex& operator/=(const complex&); 134 | }; 135 | 136 | // 26.3.6 operators: 137 | template complex operator+(const complex&, const complex&); 138 | template complex operator+(const complex&, 
const T&); 139 | template complex operator+(const T&, const complex&); 140 | template complex operator-(const complex&, const complex&); 141 | template complex operator-(const complex&, const T&); 142 | template complex operator-(const T&, const complex&); 143 | template complex operator*(const complex&, const complex&); 144 | template complex operator*(const complex&, const T&); 145 | template complex operator*(const T&, const complex&); 146 | template complex operator/(const complex&, const complex&); 147 | template complex operator/(const complex&, const T&); 148 | template complex operator/(const T&, const complex&); 149 | template complex operator+(const complex&); 150 | template complex operator-(const complex&); 151 | template bool operator==(const complex&, const complex&); // constexpr in C++14 152 | template bool operator==(const complex&, const T&); // constexpr in C++14 153 | template bool operator==(const T&, const complex&); // constexpr in C++14 154 | template bool operator!=(const complex&, const complex&); // constexpr in C++14 155 | template bool operator!=(const complex&, const T&); // constexpr in C++14 156 | template bool operator!=(const T&, const complex&); // constexpr in C++14 157 | 158 | template 159 | basic_istream& 160 | operator>>(basic_istream&, complex&); 161 | template 162 | basic_ostream& 163 | operator<<(basic_ostream&, const complex&); 164 | 165 | // 26.3.7 values: 166 | 167 | template T real(const complex&); // constexpr in C++14 168 | long double real(long double); // constexpr in C++14 169 | double real(double); // constexpr in C++14 170 | template double real(T); // constexpr in C++14 171 | float real(float); // constexpr in C++14 172 | 173 | template T imag(const complex&); // constexpr in C++14 174 | long double imag(long double); // constexpr in C++14 175 | double imag(double); // constexpr in C++14 176 | template double imag(T); // constexpr in C++14 177 | float imag(float); // constexpr in C++14 178 | 179 | template T 
abs(const complex&); 180 | 181 | template T arg(const complex&); 182 | long double arg(long double); 183 | double arg(double); 184 | template double arg(T); 185 | float arg(float); 186 | 187 | template T norm(const complex&); 188 | long double norm(long double); 189 | double norm(double); 190 | template double norm(T); 191 | float norm(float); 192 | 193 | template complex conj(const complex&); 194 | complex conj(long double); 195 | complex conj(double); 196 | template complex conj(T); 197 | complex conj(float); 198 | 199 | template complex proj(const complex&); 200 | complex proj(long double); 201 | complex proj(double); 202 | template complex proj(T); 203 | complex proj(float); 204 | 205 | template complex polar(const T&, const T& = T()); 206 | 207 | // 26.3.8 transcendentals: 208 | template complex acos(const complex&); 209 | template complex asin(const complex&); 210 | template complex atan(const complex&); 211 | template complex acosh(const complex&); 212 | template complex asinh(const complex&); 213 | template complex atanh(const complex&); 214 | template complex cos (const complex&); 215 | template complex cosh (const complex&); 216 | template complex exp (const complex&); 217 | template complex log (const complex&); 218 | template complex log10(const complex&); 219 | 220 | template complex pow(const complex&, const T&); 221 | template complex pow(const complex&, const complex&); 222 | template complex pow(const T&, const complex&); 223 | 224 | template complex sin (const complex&); 225 | template complex sinh (const complex&); 226 | template complex sqrt (const complex&); 227 | template complex tan (const complex&); 228 | template complex tanh (const complex&); 229 | 230 | template 231 | basic_istream& 232 | operator>>(basic_istream& is, complex& x); 233 | 234 | template 235 | basic_ostream& 236 | operator<<(basic_ostream& o, const complex& x); 237 | 238 | } // std 239 | 240 | */ 241 | 242 | #ifndef __cuda_std__ 243 | #include <__config> 244 | #include 245 | 
#include 246 | #include 247 | #include 248 | #include 249 | #include <__pragma_push> 250 | #endif //__cuda_std__ 251 | 252 | #if defined(_LIBCUDACXX_USE_PRAGMA_GCC_SYSTEM_HEADER) 253 | #pragma GCC system_header 254 | #endif 255 | 256 | # if _LIBCUDACXX_CUDA_ABI_VERSION > 3 257 | # define _LIBCUDACXX_COMPLEX_ALIGNAS(V) _ALIGNAS(V) 258 | # else 259 | # define _LIBCUDACXX_COMPLEX_ALIGNAS(V) 260 | # endif 261 | 262 | _LIBCUDACXX_BEGIN_NAMESPACE_STD 263 | 264 | template class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(_Tp)) complex; 265 | 266 | template _LIBCUDACXX_INLINE_VISIBILITY 267 | complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w); 268 | 269 | template _LIBCUDACXX_INLINE_VISIBILITY 270 | complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y); 271 | 272 | template 273 | class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(_Tp)) complex 274 | { 275 | public: 276 | typedef _Tp value_type; 277 | private: 278 | value_type __re_; 279 | value_type __im_; 280 | public: 281 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 282 | complex(const value_type& __re = value_type(), const value_type& __im = value_type()) 283 | : __re_(__re), __im_(__im) {} 284 | template _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 285 | complex(const complex<_Xp>& __c) 286 | : __re_(__c.real()), __im_(__c.imag()) {} 287 | 288 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 value_type real() const {return __re_;} 289 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 value_type imag() const {return __im_;} 290 | 291 | _LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__re_ = __re;} 292 | _LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__im_ = __im;} 293 | 294 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (const value_type& __re) 295 | {__re_ = __re; __im_ = value_type(); return *this;} 296 | _LIBCUDACXX_INLINE_VISIBILITY complex& 
operator+=(const value_type& __re) {__re_ += __re; return *this;} 297 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const value_type& __re) {__re_ -= __re; return *this;} 298 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const value_type& __re) {__re_ *= __re; __im_ *= __re; return *this;} 299 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const value_type& __re) {__re_ /= __re; __im_ /= __re; return *this;} 300 | 301 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (const complex<_Xp>& __c) 302 | { 303 | __re_ = __c.real(); 304 | __im_ = __c.imag(); 305 | return *this; 306 | } 307 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex<_Xp>& __c) 308 | { 309 | __re_ += __c.real(); 310 | __im_ += __c.imag(); 311 | return *this; 312 | } 313 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex<_Xp>& __c) 314 | { 315 | __re_ -= __c.real(); 316 | __im_ -= __c.imag(); 317 | return *this; 318 | } 319 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const complex<_Xp>& __c) 320 | { 321 | *this = *this * complex(__c.real(), __c.imag()); 322 | return *this; 323 | } 324 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const complex<_Xp>& __c) 325 | { 326 | *this = *this / complex(__c.real(), __c.imag()); 327 | return *this; 328 | } 329 | }; 330 | 331 | template<> class complex; 332 | #ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 333 | template<> class complex; 334 | #endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 335 | 336 | template<> 337 | class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(float)) complex 338 | { 339 | float __re_; 340 | float __im_; 341 | public: 342 | typedef float value_type; 343 | 344 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR complex(float __re = 0.0f, float __im = 0.0f) 345 | : __re_(__re), __im_(__im) {} 346 | _LIBCUDACXX_INLINE_VISIBILITY 347 | explicit _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 348 | 
_LIBCUDACXX_INLINE_VISIBILITY 349 | explicit _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 350 | 351 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR float real() const {return __re_;} 352 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR float imag() const {return __im_;} 353 | 354 | _LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__re_ = __re;} 355 | _LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__im_ = __im;} 356 | 357 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (float __re) 358 | {__re_ = __re; __im_ = value_type(); return *this;} 359 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(float __re) {__re_ += __re; return *this;} 360 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(float __re) {__re_ -= __re; return *this;} 361 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(float __re) {__re_ *= __re; __im_ *= __re; return *this;} 362 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(float __re) {__re_ /= __re; __im_ /= __re; return *this;} 363 | 364 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (const complex<_Xp>& __c) 365 | { 366 | __re_ = __c.real(); 367 | __im_ = __c.imag(); 368 | return *this; 369 | } 370 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex<_Xp>& __c) 371 | { 372 | __re_ += __c.real(); 373 | __im_ += __c.imag(); 374 | return *this; 375 | } 376 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex<_Xp>& __c) 377 | { 378 | __re_ -= __c.real(); 379 | __im_ -= __c.imag(); 380 | return *this; 381 | } 382 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const complex<_Xp>& __c) 383 | { 384 | *this = *this * complex(__c.real(), __c.imag()); 385 | return *this; 386 | } 387 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const complex<_Xp>& __c) 388 | { 389 | *this = *this / complex(__c.real(), __c.imag()); 390 | return *this; 391 | } 392 | }; 393 | 394 | template<> 395 | class _LIBCUDACXX_TEMPLATE_VIS 
_LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(double)) complex 396 | { 397 | double __re_; 398 | double __im_; 399 | public: 400 | typedef double value_type; 401 | 402 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR complex(double __re = 0.0, double __im = 0.0) 403 | : __re_(__re), __im_(__im) {} 404 | _LIBCUDACXX_INLINE_VISIBILITY 405 | _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 406 | _LIBCUDACXX_INLINE_VISIBILITY 407 | explicit _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 408 | 409 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR double real() const {return __re_;} 410 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR double imag() const {return __im_;} 411 | 412 | _LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__re_ = __re;} 413 | _LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__im_ = __im;} 414 | 415 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (double __re) 416 | {__re_ = __re; __im_ = value_type(); return *this;} 417 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(double __re) {__re_ += __re; return *this;} 418 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(double __re) {__re_ -= __re; return *this;} 419 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(double __re) {__re_ *= __re; __im_ *= __re; return *this;} 420 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(double __re) {__re_ /= __re; __im_ /= __re; return *this;} 421 | 422 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (const complex<_Xp>& __c) 423 | { 424 | __re_ = __c.real(); 425 | __im_ = __c.imag(); 426 | return *this; 427 | } 428 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex<_Xp>& __c) 429 | { 430 | __re_ += __c.real(); 431 | __im_ += __c.imag(); 432 | return *this; 433 | } 434 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex<_Xp>& __c) 435 | { 436 | __re_ -= __c.real(); 437 | __im_ -= __c.imag(); 438 | return *this; 439 | } 440 | template _LIBCUDACXX_INLINE_VISIBILITY 
complex& operator*=(const complex<_Xp>& __c) 441 | { 442 | *this = *this * complex(__c.real(), __c.imag()); 443 | return *this; 444 | } 445 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const complex<_Xp>& __c) 446 | { 447 | *this = *this / complex(__c.real(), __c.imag()); 448 | return *this; 449 | } 450 | }; 451 | 452 | template<> 453 | class _LIBCUDACXX_TEMPLATE_VIS _LIBCUDACXX_COMPLEX_ALIGNAS(2*sizeof(long double)) complex 454 | { 455 | #ifndef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 456 | public: 457 | template 458 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR complex(long double __re = 0.0, long double __im = 0.0) 459 | {static_assert(is_same<_Dummy, void>::value, "complex is not supported");} 460 | 461 | template 462 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR complex(const complex<_Tp> &__c) 463 | {static_assert(is_same<_Dummy, void>::value, "complex is not supported");} 464 | 465 | #else 466 | long double __re_; 467 | long double __im_; 468 | public: 469 | typedef long double value_type; 470 | 471 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR complex(long double __re = 0.0L, long double __im = 0.0L) 472 | : __re_(__re), __im_(__im) {} 473 | _LIBCUDACXX_INLINE_VISIBILITY 474 | _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 475 | _LIBCUDACXX_INLINE_VISIBILITY 476 | _LIBCUDACXX_CONSTEXPR complex(const complex& __c); 477 | 478 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR long double real() const {return __re_;} 479 | _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR long double imag() const {return __im_;} 480 | 481 | _LIBCUDACXX_INLINE_VISIBILITY void real(value_type __re) {__re_ = __re;} 482 | _LIBCUDACXX_INLINE_VISIBILITY void imag(value_type __im) {__im_ = __im;} 483 | 484 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (long double __re) 485 | {__re_ = __re; __im_ = value_type(); return *this;} 486 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(long double __re) {__re_ += __re; return *this;} 487 
| _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(long double __re) {__re_ -= __re; return *this;} 488 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(long double __re) {__re_ *= __re; __im_ *= __re; return *this;} 489 | _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(long double __re) {__re_ /= __re; __im_ /= __re; return *this;} 490 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator= (const complex<_Xp>& __c) 491 | { 492 | __re_ = __c.real(); 493 | __im_ = __c.imag(); 494 | return *this; 495 | } 496 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator+=(const complex<_Xp>& __c) 497 | { 498 | __re_ += __c.real(); 499 | __im_ += __c.imag(); 500 | return *this; 501 | } 502 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator-=(const complex<_Xp>& __c) 503 | { 504 | __re_ -= __c.real(); 505 | __im_ -= __c.imag(); 506 | return *this; 507 | } 508 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator*=(const complex<_Xp>& __c) 509 | { 510 | *this = *this * complex(__c.real(), __c.imag()); 511 | return *this; 512 | } 513 | template _LIBCUDACXX_INLINE_VISIBILITY complex& operator/=(const complex<_Xp>& __c) 514 | { 515 | *this = *this / complex(__c.real(), __c.imag()); 516 | return *this; 517 | } 518 | #endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 519 | }; 520 | 521 | #if defined(_LIBCUDACXX_USE_PRAGMA_MSVC_WARNING) 522 | // MSVC complains about narrowing conversions on these copy constructors regardless if they are used 523 | #pragma warning(push) 524 | #pragma warning(disable : 4244) 525 | #endif 526 | 527 | inline 528 | _LIBCUDACXX_CONSTEXPR 529 | complex::complex(const complex& __c) 530 | : __re_(__c.real()), __im_(__c.imag()) {} 531 | 532 | inline 533 | _LIBCUDACXX_CONSTEXPR 534 | complex::complex(const complex& __c) 535 | : __re_(__c.real()), __im_(__c.imag()) {} 536 | 537 | #ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 538 | inline 539 | _LIBCUDACXX_CONSTEXPR 540 | complex::complex(const complex& __c) 541 | : __re_(__c.real()), 
__im_(__c.imag()) {} 542 | 543 | inline 544 | _LIBCUDACXX_CONSTEXPR 545 | complex::complex(const complex& __c) 546 | : __re_(__c.real()), __im_(__c.imag()) {} 547 | 548 | inline 549 | _LIBCUDACXX_CONSTEXPR 550 | complex::complex(const complex& __c) 551 | : __re_(__c.real()), __im_(__c.imag()) {} 552 | 553 | inline 554 | _LIBCUDACXX_CONSTEXPR 555 | complex::complex(const complex& __c) 556 | : __re_(__c.real()), __im_(__c.imag()) {} 557 | #endif // _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 558 | 559 | #if defined(_LIBCUDACXX_USE_PRAGMA_MSVC_WARNING) 560 | #pragma warning(pop) 561 | #endif 562 | 563 | // 26.3.6 operators: 564 | 565 | template 566 | inline _LIBCUDACXX_INLINE_VISIBILITY 567 | complex<_Tp> 568 | operator+(const complex<_Tp>& __x, const complex<_Tp>& __y) 569 | { 570 | complex<_Tp> __t(__x); 571 | __t += __y; 572 | return __t; 573 | } 574 | 575 | template 576 | inline _LIBCUDACXX_INLINE_VISIBILITY 577 | complex<_Tp> 578 | operator+(const complex<_Tp>& __x, const _Tp& __y) 579 | { 580 | complex<_Tp> __t(__x); 581 | __t += __y; 582 | return __t; 583 | } 584 | 585 | template 586 | inline _LIBCUDACXX_INLINE_VISIBILITY 587 | complex<_Tp> 588 | operator+(const _Tp& __x, const complex<_Tp>& __y) 589 | { 590 | complex<_Tp> __t(__y); 591 | __t += __x; 592 | return __t; 593 | } 594 | 595 | template 596 | inline _LIBCUDACXX_INLINE_VISIBILITY 597 | complex<_Tp> 598 | operator-(const complex<_Tp>& __x, const complex<_Tp>& __y) 599 | { 600 | complex<_Tp> __t(__x); 601 | __t -= __y; 602 | return __t; 603 | } 604 | 605 | template 606 | inline _LIBCUDACXX_INLINE_VISIBILITY 607 | complex<_Tp> 608 | operator-(const complex<_Tp>& __x, const _Tp& __y) 609 | { 610 | complex<_Tp> __t(__x); 611 | __t -= __y; 612 | return __t; 613 | } 614 | 615 | template 616 | inline _LIBCUDACXX_INLINE_VISIBILITY 617 | complex<_Tp> 618 | operator-(const _Tp& __x, const complex<_Tp>& __y) 619 | { 620 | complex<_Tp> __t(-__y); 621 | __t += __x; 622 | return __t; 623 | } 624 | 625 | template 626 | 
complex<_Tp> 627 | operator*(const complex<_Tp>& __z, const complex<_Tp>& __w) 628 | { 629 | _Tp __a = __z.real(); 630 | _Tp __b = __z.imag(); 631 | _Tp __c = __w.real(); 632 | _Tp __d = __w.imag(); 633 | _Tp __ac = __a * __c; 634 | _Tp __bd = __b * __d; 635 | _Tp __ad = __a * __d; 636 | _Tp __bc = __b * __c; 637 | _Tp __x = __ac - __bd; 638 | _Tp __y = __ad + __bc; 639 | /* 640 | if (__libcpp_isnan_or_builtin(__x) && __libcpp_isnan_or_builtin(__y)) 641 | { 642 | bool __recalc = false; 643 | if (__libcpp_isinf_or_builtin(__a) || __libcpp_isinf_or_builtin(__b)) 644 | { 645 | __a = copysign(__libcpp_isinf_or_builtin(__a) ? _Tp(1) : _Tp(0), __a); 646 | __b = copysign(__libcpp_isinf_or_builtin(__b) ? _Tp(1) : _Tp(0), __b); 647 | if (__libcpp_isnan_or_builtin(__c)) 648 | __c = copysign(_Tp(0), __c); 649 | if (__libcpp_isnan_or_builtin(__d)) 650 | __d = copysign(_Tp(0), __d); 651 | __recalc = true; 652 | } 653 | if (__libcpp_isinf_or_builtin(__c) || __libcpp_isinf_or_builtin(__d)) 654 | { 655 | __c = copysign(__libcpp_isinf_or_builtin(__c) ? _Tp(1) : _Tp(0), __c); 656 | __d = copysign(__libcpp_isinf_or_builtin(__d) ? 
_Tp(1) : _Tp(0), __d); 657 | if (__libcpp_isnan_or_builtin(__a)) 658 | __a = copysign(_Tp(0), __a); 659 | if (__libcpp_isnan_or_builtin(__b)) 660 | __b = copysign(_Tp(0), __b); 661 | __recalc = true; 662 | } 663 | if (!__recalc && (__libcpp_isinf_or_builtin(__ac) || __libcpp_isinf_or_builtin(__bd) || 664 | __libcpp_isinf_or_builtin(__ad) || __libcpp_isinf_or_builtin(__bc))) 665 | { 666 | if (__libcpp_isnan_or_builtin(__a)) 667 | __a = copysign(_Tp(0), __a); 668 | if (__libcpp_isnan_or_builtin(__b)) 669 | __b = copysign(_Tp(0), __b); 670 | if (__libcpp_isnan_or_builtin(__c)) 671 | __c = copysign(_Tp(0), __c); 672 | if (__libcpp_isnan_or_builtin(__d)) 673 | __d = copysign(_Tp(0), __d); 674 | __recalc = true; 675 | } 676 | if (__recalc) 677 | { 678 | __x = _Tp(INFINITY) * (__a * __c - __b * __d); 679 | __y = _Tp(INFINITY) * (__a * __d + __b * __c); 680 | } 681 | } 682 | */ 683 | return complex<_Tp>(__x, __y); 684 | } 685 | 686 | template 687 | inline _LIBCUDACXX_INLINE_VISIBILITY 688 | complex<_Tp> 689 | operator*(const complex<_Tp>& __x, const _Tp& __y) 690 | { 691 | complex<_Tp> __t(__x); 692 | __t *= __y; 693 | return __t; 694 | } 695 | 696 | template 697 | inline _LIBCUDACXX_INLINE_VISIBILITY 698 | complex<_Tp> 699 | operator*(const _Tp& __x, const complex<_Tp>& __y) 700 | { 701 | complex<_Tp> __t(__y); 702 | __t *= __x; 703 | return __t; 704 | } 705 | 706 | namespace detail { 707 | template 708 | inline _LIBCUDACXX_INLINE_VISIBILITY 709 | _Tp __scalbn(_Tp __x, int __i) { 710 | return static_cast<_Tp>(scalbn(static_cast(__x), __i)); 711 | } 712 | 713 | template <> 714 | inline _LIBCUDACXX_INLINE_VISIBILITY 715 | float __scalbn(float __x, int __i) { 716 | return scalbnf(__x, __i); 717 | } 718 | 719 | template <> 720 | inline _LIBCUDACXX_INLINE_VISIBILITY 721 | double __scalbn(double __x, int __i) { 722 | return scalbn(__x, __i); 723 | } 724 | 725 | #ifndef _LIBCUDACXX_COMPILER_NVRTC 726 | template <> 727 | inline _LIBCUDACXX_INLINE_VISIBILITY 728 | long double 
__scalbn(long double __x, int __i) { 729 | return scalbnl(__x, __i); 730 | } 731 | #endif 732 | } 733 | 734 | template 735 | complex<_Tp> 736 | operator/(const complex<_Tp>& __z, const complex<_Tp>& __w) 737 | { 738 | int __ilogbw = 0; 739 | _Tp __a = __z.real(); 740 | _Tp __b = __z.imag(); 741 | _Tp __c = __w.real(); 742 | _Tp __d = __w.imag(); 743 | _Tp __logbw = logb(fmax(fabs(__c), fabs(__d))); 744 | if (__libcpp_isfinite_or_builtin(__logbw)) 745 | { 746 | __ilogbw = static_cast(__logbw); 747 | __c = detail::__scalbn(__c, -__ilogbw); 748 | __d = detail::__scalbn(__d, -__ilogbw); 749 | } 750 | _Tp __denom = __c * __c + __d * __d; 751 | _Tp __x = detail::__scalbn((__a * __c + __b * __d) / __denom, -__ilogbw); 752 | _Tp __y = detail::__scalbn((__b * __c - __a * __d) / __denom, -__ilogbw); 753 | /* 754 | if (__libcpp_isnan_or_builtin(__x) && __libcpp_isnan_or_builtin(__y)) 755 | { 756 | if ((__denom == _Tp(0)) && (!__libcpp_isnan_or_builtin(__a) || !__libcpp_isnan_or_builtin(__b))) 757 | { 758 | __x = copysign(_Tp(INFINITY), __c) * __a; 759 | __y = copysign(_Tp(INFINITY), __c) * __b; 760 | } 761 | else if ((__libcpp_isinf_or_builtin(__a) || __libcpp_isinf_or_builtin(__b)) && __libcpp_isfinite_or_builtin(__c) && __libcpp_isfinite_or_builtin(__d)) 762 | { 763 | __a = copysign(__libcpp_isinf_or_builtin(__a) ? _Tp(1) : _Tp(0), __a); 764 | __b = copysign(__libcpp_isinf_or_builtin(__b) ? _Tp(1) : _Tp(0), __b); 765 | __x = _Tp(INFINITY) * (__a * __c + __b * __d); 766 | __y = _Tp(INFINITY) * (__b * __c - __a * __d); 767 | } 768 | else if (__libcpp_isinf_or_builtin(__logbw) && __logbw > _Tp(0) && __libcpp_isfinite_or_builtin(__a) && __libcpp_isfinite_or_builtin(__b)) 769 | { 770 | __c = copysign(__libcpp_isinf_or_builtin(__c) ? _Tp(1) : _Tp(0), __c); 771 | __d = copysign(__libcpp_isinf_or_builtin(__d) ? 
_Tp(1) : _Tp(0), __d); 772 | __x = _Tp(0) * (__a * __c + __b * __d); 773 | __y = _Tp(0) * (__b * __c - __a * __d); 774 | } 775 | } 776 | */ 777 | return complex<_Tp>(__x, __y); 778 | } 779 | 780 | template 781 | inline _LIBCUDACXX_INLINE_VISIBILITY 782 | complex<_Tp> 783 | operator/(const complex<_Tp>& __x, const _Tp& __y) 784 | { 785 | return complex<_Tp>(__x.real() / __y, __x.imag() / __y); 786 | } 787 | 788 | template 789 | inline _LIBCUDACXX_INLINE_VISIBILITY 790 | complex<_Tp> 791 | operator/(const _Tp& __x, const complex<_Tp>& __y) 792 | { 793 | complex<_Tp> __t(__x); 794 | __t /= __y; 795 | return __t; 796 | } 797 | 798 | template 799 | inline _LIBCUDACXX_INLINE_VISIBILITY 800 | complex<_Tp> 801 | operator+(const complex<_Tp>& __x) 802 | { 803 | return __x; 804 | } 805 | 806 | template 807 | inline _LIBCUDACXX_INLINE_VISIBILITY 808 | complex<_Tp> 809 | operator-(const complex<_Tp>& __x) 810 | { 811 | return complex<_Tp>(-__x.real(), -__x.imag()); 812 | } 813 | 814 | template 815 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 816 | bool 817 | operator==(const complex<_Tp>& __x, const complex<_Tp>& __y) 818 | { 819 | return __x.real() == __y.real() && __x.imag() == __y.imag(); 820 | } 821 | 822 | template 823 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 824 | bool 825 | operator==(const complex<_Tp>& __x, const _Tp& __y) 826 | { 827 | return __x.real() == __y && __x.imag() == 0; 828 | } 829 | 830 | template 831 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 832 | bool 833 | operator==(const _Tp& __x, const complex<_Tp>& __y) 834 | { 835 | return __x == __y.real() && 0 == __y.imag(); 836 | } 837 | 838 | template 839 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 840 | bool 841 | operator!=(const complex<_Tp>& __x, const complex<_Tp>& __y) 842 | { 843 | return !(__x == __y); 844 | } 845 | 846 | template 847 | inline _LIBCUDACXX_INLINE_VISIBILITY 
_LIBCUDACXX_CONSTEXPR_AFTER_CXX11 848 | bool 849 | operator!=(const complex<_Tp>& __x, const _Tp& __y) 850 | { 851 | return !(__x == __y); 852 | } 853 | 854 | template 855 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 856 | bool 857 | operator!=(const _Tp& __x, const complex<_Tp>& __y) 858 | { 859 | return !(__x == __y); 860 | } 861 | 862 | // 26.3.7 values: 863 | 864 | template ::value, 865 | bool = is_floating_point<_Tp>::value 866 | > 867 | struct __libcpp_complex_overload_traits {}; 868 | 869 | // Integral Types 870 | template 871 | struct __libcpp_complex_overload_traits<_Tp, true, false> 872 | { 873 | typedef double _ValueType; 874 | typedef complex _ComplexType; 875 | }; 876 | 877 | // Floating point types 878 | template 879 | struct __libcpp_complex_overload_traits<_Tp, false, true> 880 | { 881 | typedef _Tp _ValueType; 882 | typedef complex<_Tp> _ComplexType; 883 | }; 884 | 885 | // real 886 | 887 | template 888 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 889 | _Tp 890 | real(const complex<_Tp>& __c) 891 | { 892 | return __c.real(); 893 | } 894 | 895 | template 896 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 897 | typename __libcpp_complex_overload_traits<_Tp>::_ValueType 898 | real(_Tp __re) 899 | { 900 | return __re; 901 | } 902 | 903 | // imag 904 | 905 | template 906 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 907 | _Tp 908 | imag(const complex<_Tp>& __c) 909 | { 910 | return __c.imag(); 911 | } 912 | 913 | template 914 | inline _LIBCUDACXX_INLINE_VISIBILITY _LIBCUDACXX_CONSTEXPR_AFTER_CXX11 915 | typename __libcpp_complex_overload_traits<_Tp>::_ValueType 916 | imag(_Tp) 917 | { 918 | return 0; 919 | } 920 | 921 | // abs 922 | 923 | template 924 | inline _LIBCUDACXX_INLINE_VISIBILITY 925 | _Tp 926 | abs(const complex<_Tp>& __c) 927 | { 928 | return hypot(__c.real(), __c.imag()); 929 | } 930 | 931 | // arg 932 | 933 | template 934 | inline 
_LIBCUDACXX_INLINE_VISIBILITY 935 | _Tp 936 | arg(const complex<_Tp>& __c) 937 | { 938 | return atan2(__c.imag(), __c.real()); 939 | } 940 | 941 | template 942 | inline _LIBCUDACXX_INLINE_VISIBILITY 943 | typename enable_if< 944 | is_same<_Tp, long double>::value, 945 | long double 946 | >::type 947 | arg(_Tp __re) 948 | { 949 | return atan2l(0.L, __re); 950 | } 951 | 952 | template 953 | inline _LIBCUDACXX_INLINE_VISIBILITY 954 | typename enable_if 955 | < 956 | is_integral<_Tp>::value || is_same<_Tp, double>::value, 957 | double 958 | >::type 959 | arg(_Tp __re) 960 | { 961 | // integrals need to be promoted to double 962 | return atan2(0., static_cast(__re)); 963 | } 964 | 965 | template 966 | inline _LIBCUDACXX_INLINE_VISIBILITY 967 | typename enable_if< 968 | is_same<_Tp, float>::value, 969 | float 970 | >::type 971 | arg(_Tp __re) 972 | { 973 | return atan2f(0.F, __re); 974 | } 975 | 976 | // norm 977 | 978 | template 979 | inline _LIBCUDACXX_INLINE_VISIBILITY 980 | _Tp 981 | norm(const complex<_Tp>& __c) 982 | { 983 | if (__libcpp_isinf_or_builtin(__c.real())) 984 | return abs(__c.real()); 985 | if (__libcpp_isinf_or_builtin(__c.imag())) 986 | return abs(__c.imag()); 987 | return __c.real() * __c.real() + __c.imag() * __c.imag(); 988 | } 989 | 990 | template 991 | inline _LIBCUDACXX_INLINE_VISIBILITY 992 | typename __libcpp_complex_overload_traits<_Tp>::_ValueType 993 | norm(_Tp __re) 994 | { 995 | typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; 996 | return static_cast<_ValueType>(__re) * __re; 997 | } 998 | 999 | // conj 1000 | 1001 | template 1002 | inline _LIBCUDACXX_INLINE_VISIBILITY 1003 | complex<_Tp> 1004 | conj(const complex<_Tp>& __c) 1005 | { 1006 | return complex<_Tp>(__c.real(), -__c.imag()); 1007 | } 1008 | 1009 | template 1010 | inline _LIBCUDACXX_INLINE_VISIBILITY 1011 | typename __libcpp_complex_overload_traits<_Tp>::_ComplexType 1012 | conj(_Tp __re) 1013 | { 1014 | typedef typename 
__libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; 1015 | return _ComplexType(__re); 1016 | } 1017 | 1018 | 1019 | 1020 | // proj 1021 | 1022 | template 1023 | inline _LIBCUDACXX_INLINE_VISIBILITY 1024 | complex<_Tp> 1025 | proj(const complex<_Tp>& __c) 1026 | { 1027 | std::complex<_Tp> __r = __c; 1028 | if (__libcpp_isinf_or_builtin(__c.real()) || __libcpp_isinf_or_builtin(__c.imag())) 1029 | __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag())); 1030 | return __r; 1031 | } 1032 | 1033 | template 1034 | inline _LIBCUDACXX_INLINE_VISIBILITY 1035 | typename enable_if 1036 | < 1037 | is_floating_point<_Tp>::value, 1038 | typename __libcpp_complex_overload_traits<_Tp>::_ComplexType 1039 | >::type 1040 | proj(_Tp __re) 1041 | { 1042 | if (__libcpp_isinf_or_builtin(__re)) 1043 | __re = abs(__re); 1044 | return complex<_Tp>(__re); 1045 | } 1046 | 1047 | template 1048 | inline _LIBCUDACXX_INLINE_VISIBILITY 1049 | typename enable_if 1050 | < 1051 | is_integral<_Tp>::value, 1052 | typename __libcpp_complex_overload_traits<_Tp>::_ComplexType 1053 | >::type 1054 | proj(_Tp __re) 1055 | { 1056 | typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; 1057 | return _ComplexType(__re); 1058 | } 1059 | 1060 | // polar 1061 | 1062 | template 1063 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1064 | polar(const _Tp& __rho, const _Tp& __theta = _Tp()) 1065 | { 1066 | if (__libcpp_isnan_or_builtin(__rho) || signbit(__rho)) 1067 | return complex<_Tp>(_Tp(NAN), _Tp(NAN)); 1068 | if (__libcpp_isnan_or_builtin(__theta)) 1069 | { 1070 | if (__libcpp_isinf_or_builtin(__rho)) 1071 | return complex<_Tp>(__rho, __theta); 1072 | return complex<_Tp>(__theta, __theta); 1073 | } 1074 | if (__libcpp_isinf_or_builtin(__theta)) 1075 | { 1076 | if (__libcpp_isinf_or_builtin(__rho)) 1077 | return complex<_Tp>(__rho, _Tp(NAN)); 1078 | return complex<_Tp>(_Tp(NAN), _Tp(NAN)); 1079 | } 1080 | _Tp __x = __rho * cos(__theta); 1081 | if 
(__libcpp_isnan_or_builtin(__x)) 1082 | __x = 0; 1083 | _Tp __y = __rho * sin(__theta); 1084 | if (__libcpp_isnan_or_builtin(__y)) 1085 | __y = 0; 1086 | return complex<_Tp>(__x, __y); 1087 | } 1088 | 1089 | // log 1090 | 1091 | template 1092 | inline _LIBCUDACXX_INLINE_VISIBILITY 1093 | complex<_Tp> 1094 | log(const complex<_Tp>& __x) 1095 | { 1096 | return complex<_Tp>(log(abs(__x)), arg(__x)); 1097 | } 1098 | 1099 | // log10 1100 | 1101 | template 1102 | inline _LIBCUDACXX_INLINE_VISIBILITY 1103 | complex<_Tp> 1104 | log10(const complex<_Tp>& __x) 1105 | { 1106 | return log(__x) / log(_Tp(10)); 1107 | } 1108 | 1109 | // sqrt 1110 | 1111 | template 1112 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1113 | sqrt(const complex<_Tp>& __x) 1114 | { 1115 | if (__libcpp_isinf_or_builtin(__x.imag())) 1116 | return complex<_Tp>(_Tp(INFINITY), __x.imag()); 1117 | if (__libcpp_isinf_or_builtin(__x.real())) 1118 | { 1119 | if (__x.real() > _Tp(0)) 1120 | return complex<_Tp>(__x.real(), __libcpp_isnan_or_builtin(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag())); 1121 | return complex<_Tp>(__libcpp_isnan_or_builtin(__x.imag()) ? 
__x.imag() : _Tp(0), copysign(__x.real(), __x.imag())); 1122 | } 1123 | return polar(sqrt(abs(__x)), arg(__x) / _Tp(2)); 1124 | } 1125 | 1126 | // exp 1127 | 1128 | template 1129 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1130 | exp(const complex<_Tp>& __x) 1131 | { 1132 | _Tp __i = __x.imag(); 1133 | if (__i == 0) { 1134 | return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag())); 1135 | } 1136 | if (__libcpp_isinf_or_builtin(__x.real())) 1137 | { 1138 | if (__x.real() < _Tp(0)) 1139 | { 1140 | if (!__libcpp_isfinite_or_builtin(__i)) 1141 | __i = _Tp(1); 1142 | } 1143 | else if (__i == 0 || !__libcpp_isfinite_or_builtin(__i)) 1144 | { 1145 | if (__libcpp_isinf_or_builtin(__i)) 1146 | __i = _Tp(NAN); 1147 | return complex<_Tp>(__x.real(), __i); 1148 | } 1149 | } 1150 | _Tp __e = exp(__x.real()); 1151 | return complex<_Tp>(__e * cos(__i), __e * sin(__i)); 1152 | } 1153 | 1154 | // pow 1155 | 1156 | template 1157 | inline _LIBCUDACXX_INLINE_VISIBILITY 1158 | complex<_Tp> 1159 | pow(const complex<_Tp>& __x, const complex<_Tp>& __y) 1160 | { 1161 | return exp(__y * log(__x)); 1162 | } 1163 | 1164 | template 1165 | inline _LIBCUDACXX_INLINE_VISIBILITY 1166 | complex::type> 1167 | pow(const complex<_Tp>& __x, const complex<_Up>& __y) 1168 | { 1169 | typedef complex::type> result_type; 1170 | return _CUDA_VSTD::pow(result_type(__x), result_type(__y)); 1171 | } 1172 | 1173 | template 1174 | inline _LIBCUDACXX_INLINE_VISIBILITY 1175 | typename enable_if 1176 | < 1177 | is_arithmetic<_Up>::value, 1178 | complex::type> 1179 | >::type 1180 | pow(const complex<_Tp>& __x, const _Up& __y) 1181 | { 1182 | typedef complex::type> result_type; 1183 | return _CUDA_VSTD::pow(result_type(__x), result_type(__y)); 1184 | } 1185 | 1186 | template 1187 | inline _LIBCUDACXX_INLINE_VISIBILITY 1188 | typename enable_if 1189 | < 1190 | is_arithmetic<_Tp>::value, 1191 | complex::type> 1192 | >::type 1193 | pow(const _Tp& __x, const complex<_Up>& __y) 1194 | { 1195 | typedef 
complex::type> result_type; 1196 | return _CUDA_VSTD::pow(result_type(__x), result_type(__y)); 1197 | } 1198 | 1199 | // __sqr, computes pow(x, 2) 1200 | 1201 | template 1202 | inline _LIBCUDACXX_INLINE_VISIBILITY 1203 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1204 | __sqr(const complex<_Tp>& __x) 1205 | { 1206 | return complex<_Tp>((__x.real() - __x.imag()) * (__x.real() + __x.imag()), 1207 | _Tp(2) * __x.real() * __x.imag()); 1208 | } 1209 | 1210 | // asinh 1211 | 1212 | template 1213 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1214 | asinh(const complex<_Tp>& __x) 1215 | { 1216 | const _Tp __pi(static_cast<_Tp>(atan2(+0., -0.))); 1217 | if (__libcpp_isinf_or_builtin(__x.real())) 1218 | { 1219 | if (__libcpp_isnan_or_builtin(__x.imag())) 1220 | return __x; 1221 | if (__libcpp_isinf_or_builtin(__x.imag())) 1222 | return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); 1223 | return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); 1224 | } 1225 | if (__libcpp_isnan_or_builtin(__x.real())) 1226 | { 1227 | if (__libcpp_isinf_or_builtin(__x.imag())) 1228 | return complex<_Tp>(__x.imag(), __x.real()); 1229 | if (__x.imag() == 0) 1230 | return __x; 1231 | return complex<_Tp>(__x.real(), __x.real()); 1232 | } 1233 | if (__libcpp_isinf_or_builtin(__x.imag())) 1234 | return complex<_Tp>(copysign(__x.imag(), __x.real()), copysign(__pi/_Tp(2), __x.imag())); 1235 | complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1))); 1236 | return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); 1237 | } 1238 | 1239 | // acosh 1240 | 1241 | template 1242 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1243 | acosh(const complex<_Tp>& __x) 1244 | { 1245 | const _Tp __pi(static_cast<_Tp>(atan2(+0., -0.))); 1246 | if (__libcpp_isinf_or_builtin(__x.real())) 1247 | { 1248 | if (__libcpp_isnan_or_builtin(__x.imag())) 1249 | return complex<_Tp>(abs(__x.real()), __x.imag()); 1250 | if (__libcpp_isinf_or_builtin(__x.imag())) 1251 | { 1252 | 
if (__x.real() > 0) 1253 | return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); 1254 | else 1255 | return complex<_Tp>(-__x.real(), copysign(__pi * _Tp(0.75), __x.imag())); 1256 | } 1257 | if (__x.real() < 0) 1258 | return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag())); 1259 | return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); 1260 | } 1261 | if (__libcpp_isnan_or_builtin(__x.real())) 1262 | { 1263 | if (__libcpp_isinf_or_builtin(__x.imag())) 1264 | return complex<_Tp>(abs(__x.imag()), __x.real()); 1265 | return complex<_Tp>(__x.real(), __x.real()); 1266 | } 1267 | if (__libcpp_isinf_or_builtin(__x.imag())) 1268 | return complex<_Tp>(abs(__x.imag()), copysign(__pi/_Tp(2), __x.imag())); 1269 | complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); 1270 | return complex<_Tp>(copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag())); 1271 | } 1272 | 1273 | // atanh 1274 | 1275 | template 1276 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1277 | atanh(const complex<_Tp>& __x) 1278 | { 1279 | const _Tp __pi(static_cast<_Tp>(atan2(+0., -0.))); 1280 | if (__libcpp_isinf_or_builtin(__x.imag())) 1281 | { 1282 | return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag())); 1283 | } 1284 | if (__libcpp_isnan_or_builtin(__x.imag())) 1285 | { 1286 | if (__libcpp_isinf_or_builtin(__x.real()) || __x.real() == 0) 1287 | return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag()); 1288 | return complex<_Tp>(__x.imag(), __x.imag()); 1289 | } 1290 | if (__libcpp_isnan_or_builtin(__x.real())) 1291 | { 1292 | return complex<_Tp>(__x.real(), __x.real()); 1293 | } 1294 | if (__libcpp_isinf_or_builtin(__x.real())) 1295 | { 1296 | return complex<_Tp>(copysign(_Tp(0), __x.real()), copysign(__pi/_Tp(2), __x.imag())); 1297 | } 1298 | if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) 1299 | { 1300 | return complex<_Tp>(copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag())); 1301 | } 1302 | complex<_Tp> __z 
= log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2); 1303 | return complex<_Tp>(copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); 1304 | } 1305 | 1306 | // sinh 1307 | 1308 | template 1309 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1310 | sinh(const complex<_Tp>& __x) 1311 | { 1312 | if (__libcpp_isinf_or_builtin(__x.real()) && !__libcpp_isfinite_or_builtin(__x.imag())) 1313 | return complex<_Tp>(__x.real(), _Tp(NAN)); 1314 | if (__x.real() == 0 && !__libcpp_isfinite_or_builtin(__x.imag())) 1315 | return complex<_Tp>(__x.real(), _Tp(NAN)); 1316 | if (__x.imag() == 0 && !__libcpp_isfinite_or_builtin(__x.real())) 1317 | return __x; 1318 | return complex<_Tp>(sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag())); 1319 | } 1320 | 1321 | // cosh 1322 | 1323 | template 1324 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1325 | cosh(const complex<_Tp>& __x) 1326 | { 1327 | if (__libcpp_isinf_or_builtin(__x.real()) && !__libcpp_isfinite_or_builtin(__x.imag())) 1328 | return complex<_Tp>(abs(__x.real()), _Tp(NAN)); 1329 | if (__x.real() == 0 && !__libcpp_isfinite_or_builtin(__x.imag())) 1330 | return complex<_Tp>(_Tp(NAN), __x.real()); 1331 | if (__x.real() == 0 && __x.imag() == 0) 1332 | return complex<_Tp>(_Tp(1), __x.imag()); 1333 | if (__x.imag() == 0 && !__libcpp_isfinite_or_builtin(__x.real())) 1334 | return complex<_Tp>(abs(__x.real()), __x.imag()); 1335 | return complex<_Tp>(cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag())); 1336 | } 1337 | 1338 | // tanh 1339 | 1340 | template 1341 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1342 | tanh(const complex<_Tp>& __x) 1343 | { 1344 | if (__libcpp_isinf_or_builtin(__x.real())) 1345 | { 1346 | if (!__libcpp_isfinite_or_builtin(__x.imag())) 1347 | return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0)); 1348 | return complex<_Tp>(copysign(_Tp(1), __x.real()), copysign(_Tp(0), sin(_Tp(2) * __x.imag()))); 1349 | } 1350 | if (__libcpp_isnan_or_builtin(__x.real()) && 
__x.imag() == 0) 1351 | return __x; 1352 | _Tp __2r(_Tp(2) * __x.real()); 1353 | _Tp __2i(_Tp(2) * __x.imag()); 1354 | _Tp __d(cosh(__2r) + cos(__2i)); 1355 | _Tp __2rsh(sinh(__2r)); 1356 | if (__libcpp_isinf_or_builtin(__2rsh) && __libcpp_isinf_or_builtin(__d)) 1357 | return complex<_Tp>(__2rsh > _Tp(0) ? _Tp(1) : _Tp(-1), 1358 | __2i > _Tp(0) ? _Tp(0) : _Tp(-0.)); 1359 | return complex<_Tp>(__2rsh/__d, sin(__2i)/__d); 1360 | } 1361 | 1362 | // asin 1363 | 1364 | template 1365 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1366 | asin(const complex<_Tp>& __x) 1367 | { 1368 | complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real())); 1369 | return complex<_Tp>(__z.imag(), -__z.real()); 1370 | } 1371 | 1372 | // acos 1373 | 1374 | template 1375 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1376 | acos(const complex<_Tp>& __x) 1377 | { 1378 | const _Tp __pi(static_cast<_Tp>(atan2(+0., -0.))); 1379 | if (__libcpp_isinf_or_builtin(__x.real())) 1380 | { 1381 | if (__libcpp_isnan_or_builtin(__x.imag())) 1382 | return complex<_Tp>(__x.imag(), __x.real()); 1383 | if (__libcpp_isinf_or_builtin(__x.imag())) 1384 | { 1385 | if (__x.real() < _Tp(0)) 1386 | return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag()); 1387 | return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag()); 1388 | } 1389 | if (__x.real() < _Tp(0)) 1390 | return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real()); 1391 | return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? 
__x.real() : -__x.real()); 1392 | } 1393 | if (__libcpp_isnan_or_builtin(__x.real())) 1394 | { 1395 | if (__libcpp_isinf_or_builtin(__x.imag())) 1396 | return complex<_Tp>(__x.real(), -__x.imag()); 1397 | return complex<_Tp>(__x.real(), __x.real()); 1398 | } 1399 | if (__libcpp_isinf_or_builtin(__x.imag())) 1400 | return complex<_Tp>(__pi/_Tp(2), -__x.imag()); 1401 | if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) 1402 | return complex<_Tp>(__pi/_Tp(2), -__x.imag()); 1403 | complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); 1404 | if (signbit(__x.imag())) 1405 | return complex<_Tp>(abs(__z.imag()), abs(__z.real())); 1406 | return complex<_Tp>(abs(__z.imag()), -abs(__z.real())); 1407 | } 1408 | 1409 | // atan 1410 | 1411 | template 1412 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1413 | atan(const complex<_Tp>& __x) 1414 | { 1415 | complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real())); 1416 | return complex<_Tp>(__z.imag(), -__z.real()); 1417 | } 1418 | 1419 | // sin 1420 | 1421 | template 1422 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1423 | sin(const complex<_Tp>& __x) 1424 | { 1425 | complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real())); 1426 | return complex<_Tp>(__z.imag(), -__z.real()); 1427 | } 1428 | 1429 | // cos 1430 | 1431 | template 1432 | inline _LIBCUDACXX_INLINE_VISIBILITY 1433 | complex<_Tp> 1434 | cos(const complex<_Tp>& __x) 1435 | { 1436 | return cosh(complex<_Tp>(-__x.imag(), __x.real())); 1437 | } 1438 | 1439 | // tan 1440 | 1441 | template 1442 | _LIBCUDACXX_INLINE_VISIBILITY complex<_Tp> 1443 | tan(const complex<_Tp>& __x) 1444 | { 1445 | complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real())); 1446 | return complex<_Tp>(__z.imag(), -__z.real()); 1447 | } 1448 | 1449 | #ifndef __cuda_std__ 1450 | 1451 | template 1452 | basic_istream<_CharT, _Traits>& 1453 | operator>>(basic_istream<_CharT, _Traits>& __is, complex<_Tp>& __x) 1454 | { 1455 | if (__is.good()) 1456 | { 1457 | ws(__is); 1458 | if 
(__is.peek() == _CharT('(')) 1459 | { 1460 | __is.get(); 1461 | _Tp __r; 1462 | __is >> __r; 1463 | if (!__is.fail()) 1464 | { 1465 | ws(__is); 1466 | _CharT __c = __is.peek(); 1467 | if (__c == _CharT(',')) 1468 | { 1469 | __is.get(); 1470 | _Tp __i; 1471 | __is >> __i; 1472 | if (!__is.fail()) 1473 | { 1474 | ws(__is); 1475 | __c = __is.peek(); 1476 | if (__c == _CharT(')')) 1477 | { 1478 | __is.get(); 1479 | __x = complex<_Tp>(__r, __i); 1480 | } 1481 | else 1482 | __is.setstate(ios_base::failbit); 1483 | } 1484 | else 1485 | __is.setstate(ios_base::failbit); 1486 | } 1487 | else if (__c == _CharT(')')) 1488 | { 1489 | __is.get(); 1490 | __x = complex<_Tp>(__r, _Tp(0)); 1491 | } 1492 | else 1493 | __is.setstate(ios_base::failbit); 1494 | } 1495 | else 1496 | __is.setstate(ios_base::failbit); 1497 | } 1498 | else 1499 | { 1500 | _Tp __r; 1501 | __is >> __r; 1502 | if (!__is.fail()) 1503 | __x = complex<_Tp>(__r, _Tp(0)); 1504 | else 1505 | __is.setstate(ios_base::failbit); 1506 | } 1507 | } 1508 | else 1509 | __is.setstate(ios_base::failbit); 1510 | return __is; 1511 | } 1512 | 1513 | template 1514 | basic_ostream<_CharT, _Traits>& 1515 | operator<<(basic_ostream<_CharT, _Traits>& __os, const complex<_Tp>& __x) 1516 | { 1517 | basic_ostringstream<_CharT, _Traits> __s; 1518 | __s.flags(__os.flags()); 1519 | __s.imbue(__os.getloc()); 1520 | __s.precision(__os.precision()); 1521 | __s << '(' << __x.real() << ',' << __x.imag() << ')'; 1522 | return __os << __s.str(); 1523 | } 1524 | 1525 | #endif // __cuda_std__ 1526 | 1527 | #if _LIBCUDACXX_STD_VER > 11 && defined(_LIBCUDACXX_HAS_STL_LITERALS) 1528 | // Literal suffix for complex number literals [complex.literals] 1529 | inline namespace literals 1530 | { 1531 | inline namespace complex_literals 1532 | { 1533 | #ifdef _LIBCUDACXX_HAS_COMPLEX_LONG_DOUBLE 1534 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""il(long double __im) 1535 | { 1536 | return { 0.0l, __im }; 1537 | } 1538 | 1539 | 
_LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""il(unsigned long long __im) 1540 | { 1541 | return { 0.0l, static_cast(__im) }; 1542 | } 1543 | 1544 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""i(long double __im) 1545 | { 1546 | return { 0.0, static_cast(__im) }; 1547 | } 1548 | 1549 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""i(unsigned long long __im) 1550 | { 1551 | return { 0.0, static_cast(__im) }; 1552 | } 1553 | 1554 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""if(long double __im) 1555 | { 1556 | return { 0.0f, static_cast(__im) }; 1557 | } 1558 | 1559 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""if(unsigned long long __im) 1560 | { 1561 | return { 0.0f, static_cast(__im) }; 1562 | } 1563 | #else 1564 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""i(double __im) 1565 | { 1566 | return { 0.0, static_cast(__im) }; 1567 | } 1568 | 1569 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""i(unsigned long long __im) 1570 | { 1571 | return { 0.0, static_cast(__im) }; 1572 | } 1573 | 1574 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""if(double __im) 1575 | { 1576 | return { 0.0f, static_cast(__im) }; 1577 | } 1578 | 1579 | _LIBCUDACXX_INLINE_VISIBILITY constexpr complex operator""if(unsigned long long __im) 1580 | { 1581 | return { 0.0f, static_cast(__im) }; 1582 | } 1583 | #endif 1584 | } 1585 | } 1586 | #endif 1587 | 1588 | _LIBCUDACXX_END_NAMESPACE_STD 1589 | 1590 | #ifndef __cuda_std__ 1591 | #include <__pragma_pop> 1592 | #endif //__cuda_std__ 1593 | 1594 | #endif // _LIBCUDACXX_COMPLEX 1595 | -------------------------------------------------------------------------------- /torchbp/csrc/torchbp.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | namespace torchbp { 5 | 6 | #define kPI 3.1415926535897932384626433f 7 | #define kC0 299792458.0f 8 | 9 | // Registers _C as a Python extension 
// Computes sin(pi * x) and cos(pi * x) and stores them in *sinx and *cosx.
//
// Pi is defined here at the precision of T. The file-level kPI macro is a
// float literal, so casting it to double keeps only float accuracy and the
// error is multiplied by x in the phase argument. For T = float the value is
// identical to kPI.
template<typename T>
static void sincospi(T x, T *sinx, T *cosx) {
    constexpr T pi = static_cast<T>(3.14159265358979323846264338327950288L);
    const T phase = pi * x;
    *sinx = sin(phase);
    *cosx = cos(phase);
}
| const T d = r1 + dr1 * idr; 67 | T t = theta1 + dtheta1 * idtheta; 68 | if (rotation != 0.0f) { 69 | t = sinf(asinf(t) - rotation); 70 | } 71 | if (t < -1.0f || t > 1.0f) { 72 | return; 73 | } 74 | const T dorig0 = dorigin[idbatch * 3 + 0]; 75 | const T dorig1 = dorigin[idbatch * 3 + 1]; 76 | const T sint = t; 77 | const T cost = sqrt(1.0f - t*t); 78 | const T rp = sqrt(d*d + dorig0*dorig0 + dorig1*dorig1 + 2*d*(dorig0*cost + dorig1*sint)); 79 | const T arg = (d*sint + dorig1) / (d*cost + dorig0); 80 | const T tp = arg / sqrt(1.0f + arg*arg); 81 | 82 | const T dri = (rp - r0) / dr; 83 | const T dti = (tp - theta0) / dtheta; 84 | 85 | const int dri_int = dri; 86 | const T dri_frac = dri - dri_int; 87 | const int dti_int = dti; 88 | const T dti_frac = dti - dti_int; 89 | 90 | if (dri_int >= 0 && dri_int < Nr-1 && dti_int >= 0 && dti_int < Ntheta-1) { 91 | c10::complex v = interp2d>(&img[idbatch * Nr * Ntheta], Nr, Ntheta, dri_int, dri_frac, dti_int, dti_frac); 92 | T ref_sin, ref_cos; 93 | const T z0 = z1 + dorigin[idbatch * 3 + 2]; 94 | const T dz = sqrt(z1*z1 + d*d); 95 | const T rpz = sqrt(z0*z0 + rp*rp); 96 | sincospi(ref_phase * (rpz - dz), &ref_sin, &ref_cos); 97 | c10::complex ref = {ref_cos, ref_sin}; 98 | out[idbatch * Nr1 * Ntheta1 + idr*Ntheta1 + idtheta] = v * ref; 99 | } else { 100 | out[idbatch * Nr1 * Ntheta1 + idr*Ntheta1 + idtheta] = {0.0f, 0.0f}; 101 | } 102 | } 103 | 104 | template 105 | static void polar_interp_kernel_linear_grad_cpu(const c10::complex *img, const T *dorigin, T rotation, 106 | T ref_phase, T r0, T dr, T theta0, T dtheta, int Nr, int Ntheta, 107 | T r1, T dr1, T theta1, T dtheta1, int Nr1, int Ntheta1, T z1, 108 | const c10::complex *grad, c10::complex *img_grad, T *dorigin_grad, 109 | int idx, int idbatch) { 110 | const int idtheta = idx % Ntheta1; 111 | const int idr = idx / Ntheta1; 112 | c10::complex I = {0.0f, 1.0f}; 113 | 114 | const T d = r1 + dr1 * idr; 115 | T t = theta1 + dtheta1 * idtheta; 116 | if (t > 1.0f) { 117 | t 
= 1.0f; 118 | } 119 | if (rotation != 0.0f) { 120 | t = sinf(asinf(t) - rotation); 121 | } 122 | 123 | if (idx >= Nr1 * Ntheta1) { 124 | return; 125 | } 126 | if (t < -1.0f || t > 1.0f) { 127 | return; 128 | } 129 | const T dorig0 = dorigin[idbatch * 3 + 0]; 130 | const T dorig1 = dorigin[idbatch * 3 + 1]; 131 | const T dorig2 = dorigin[idbatch * 3 + 2]; 132 | const T sint = t; 133 | const T cost = sqrt(1.0f - t*t); 134 | // TODO: Add dorig2 135 | const T rp = sqrt(d*d + dorig0*dorig0 + dorig1*dorig1 + 2*d*(dorig0*cost + dorig1*sint)); 136 | const T arg = (d*sint + dorig1) / (d*cost + dorig0); 137 | const T cosarg = sqrt(1.0f + arg*arg); 138 | const T tp = arg / cosarg; 139 | 140 | const T dri = (rp - r0) / dr; 141 | const T dti = (tp - theta0) / dtheta; 142 | 143 | const int dri_int = dri; 144 | const T dri_frac = dri - dri_int; 145 | 146 | const int dti_int = dti; 147 | const T dti_frac = dti - dti_int; 148 | 149 | c10::complex v = {0.0f, 0.0f}; 150 | c10::complex ref = {0.0f, 0.0f}; 151 | 152 | const T z0 = z1 + dorig2; 153 | const T rpz = sqrt(z0*z0 + rp*rp); 154 | if (dri_int >= 0 && dri_int < Nr-1 && dti_int >= 0 && dti_int < Ntheta-1) { 155 | v = interp2d>(&img[idbatch * Nr * Ntheta], Nr, Ntheta, dri_int, dri_frac, dti_int, dti_frac); 156 | T ref_sin, ref_cos; 157 | const T dz = sqrt(z1*z1 + d*d); 158 | sincospi(ref_phase * (rpz - dz), &ref_sin, &ref_cos); 159 | ref = {ref_cos, ref_sin}; 160 | } 161 | 162 | if (dorigin_grad != nullptr) { 163 | const c10::complex dref_drpz = I * kPI * ref_phase * ref; 164 | const c10::complex dv_drp = interp2d_gradx>( 165 | &img[idbatch * Nr * Ntheta], Nr, Ntheta, dri_int, dri_frac, 166 | dti_int, dti_frac) / dr; 167 | const c10::complex dv_dt = interp2d_grady>( 168 | &img[idbatch * Nr * Ntheta], Nr, Ntheta, dri_int, dri_frac, 169 | dti_int, dti_frac) / dtheta; 170 | const T drp_dorig0 = (cost*d + dorig0) / rp; 171 | const T drp_dorig1 = (sint*d + dorig1) / rp; 172 | const T drpz_dorig0 = (cost*d + dorig0) / rpz; 173 | const 
T drpz_dorig1 = (sint*d + dorig1) / rpz; 174 | const T drpz_dorig2 = (dorig2 + z1) / rpz; 175 | const T dt_darg = -arg*arg/(cosarg*cosarg*cosarg) + 1.0f / cosarg; 176 | const T darg_dorig0 = -(d*sint + dorig1) / ((dorig0 + d*cost)*(dorig0 + d*cost)); 177 | const T darg_dorig1 = 1.0f / (cost*d + dorig0); 178 | 179 | const c10::complex g = grad[idbatch * Nr1 * Ntheta1 + idr*Ntheta1 + idtheta]; 180 | const c10::complex dout_dorig0 = ref * (dv_drp * drp_dorig0 + dv_dt * dt_darg * darg_dorig0) + v * dref_drpz * drpz_dorig0; 181 | const c10::complex dout_dorig1 = ref * (dv_drp * drp_dorig1 + dv_dt * dt_darg * darg_dorig1) + v * dref_drpz * drpz_dorig1; 182 | const c10::complex dout_dorig2 = v * dref_drpz * drpz_dorig2; 183 | T g_dorig0 = std::real(g * std::conj(dout_dorig0)); 184 | T g_dorig1 = std::real(g * std::conj(dout_dorig1)); 185 | T g_dorig2 = std::real(g * std::conj(dout_dorig2)); 186 | 187 | #pragma omp atomic 188 | dorigin_grad[idbatch * 3 + 0] += g_dorig0; 189 | #pragma omp atomic 190 | dorigin_grad[idbatch * 3 + 1] += g_dorig1; 191 | #pragma omp atomic 192 | dorigin_grad[idbatch * 3 + 2] += g_dorig2; 193 | } 194 | 195 | if (img_grad != nullptr) { 196 | if (dri_int >= 0 && dri_int < Nr-1 && dti_int >= 0 && dti_int < Ntheta-1) { 197 | c10::complex g = grad[idbatch * Nr1 * Ntheta1 + idr*Ntheta1 + idtheta] * std::conj(ref); 198 | 199 | c10::complex g11 = g * (1.0f-dri_frac)*(1.0f-dti_frac); 200 | c10::complex g12 = g * (1.0f-dri_frac)*dti_frac; 201 | c10::complex g21 = g * dri_frac*(1.0f-dti_frac); 202 | c10::complex g22 = g * dri_frac*dti_frac; 203 | // Slow 204 | #pragma omp critical 205 | { 206 | img_grad[idbatch * Nr * Ntheta + dri_int*Ntheta + dti_int] += g11; 207 | img_grad[idbatch * Nr * Ntheta + dri_int*Ntheta + dti_int + 1] += g12; 208 | img_grad[idbatch * Nr * Ntheta + (dri_int+1)*Ntheta + dti_int] += g21; 209 | img_grad[idbatch * Nr * Ntheta + (dri_int+1)*Ntheta + dti_int + 1] += g22; 210 | } 211 | } 212 | } 213 | } 214 | 215 | at::Tensor 
polar_interp_linear_cpu( 216 | const at::Tensor &img, 217 | const at::Tensor &dorigin, 218 | int64_t nbatch, 219 | double rotation, 220 | double fc, 221 | double r0, 222 | double dr0, 223 | double theta0, 224 | double dtheta0, 225 | int64_t nr0, 226 | int64_t ntheta0, 227 | double r1, 228 | double dr1, 229 | double theta1, 230 | double dtheta1, 231 | int64_t nr1, 232 | int64_t ntheta1, 233 | double z1) { 234 | TORCH_CHECK(img.dtype() == at::kComplexFloat || img.dtype() == at::kComplexDouble); 235 | TORCH_CHECK(dorigin.dtype() == at::kFloat || dorigin.dtype() == at::kDouble); 236 | TORCH_INTERNAL_ASSERT(img.device().type() == at::DeviceType::CPU); 237 | TORCH_INTERNAL_ASSERT(dorigin.device().type() == at::DeviceType::CPU); 238 | at::Tensor img_contig = img.contiguous(); 239 | at::Tensor out = torch::empty({nbatch, nr1, ntheta1}, img_contig.options()); 240 | at::Tensor dorigin_contig = dorigin.contiguous(); 241 | 242 | if (img.dtype() == at::kComplexFloat) { 243 | TORCH_CHECK(dorigin.dtype() == at::kFloat); 244 | const float* dorigin_ptr = dorigin_contig.data_ptr(); 245 | c10::complex* img_ptr = img.data_ptr>(); 246 | c10::complex* out_ptr = out.data_ptr>(); 247 | const float ref_phase = 4.0f * fc / kC0; 248 | 249 | #pragma omp parallel for collapse(2) 250 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 251 | for(int idx = 0; idx < nr1 * ntheta1; idx++) { 252 | polar_interp_kernel_linear_cpu(img_ptr, out_ptr, dorigin_ptr, rotation, 253 | ref_phase, r0, dr0, theta0, dtheta0, nr0, ntheta0, 254 | r1, dr1, theta1, dtheta1, nr1, ntheta1, z1, idx, idbatch); 255 | } 256 | } 257 | } else { 258 | TORCH_CHECK(dorigin.dtype() == at::kDouble); 259 | const double* dorigin_ptr = dorigin_contig.data_ptr(); 260 | c10::complex* img_ptr = img.data_ptr>(); 261 | c10::complex* out_ptr = out.data_ptr>(); 262 | const double ref_phase = 4.0 * fc / kC0; 263 | 264 | #pragma omp parallel for collapse(2) 265 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 266 | for(int idx = 0; idx 
< nr1 * ntheta1; idx++) { 267 | polar_interp_kernel_linear_cpu(img_ptr, out_ptr, dorigin_ptr, rotation, 268 | ref_phase, r0, dr0, theta0, dtheta0, nr0, ntheta0, 269 | r1, dr1, theta1, dtheta1, nr1, ntheta1, z1, idx, idbatch); 270 | } 271 | } 272 | } 273 | return out; 274 | } 275 | 276 | std::vector polar_interp_linear_grad_cpu( 277 | const at::Tensor &grad, 278 | const at::Tensor &img, 279 | const at::Tensor &dorigin, 280 | int64_t nbatch, 281 | double rotation, 282 | double fc, 283 | double r0, 284 | double dr0, 285 | double theta0, 286 | double dtheta0, 287 | int64_t nr0, 288 | int64_t ntheta0, 289 | double r1, 290 | double dr1, 291 | double theta1, 292 | double dtheta1, 293 | int64_t nr1, 294 | int64_t ntheta1, 295 | double z1) { 296 | TORCH_CHECK(img.dtype() == at::kComplexFloat || img.dtype() == at::kComplexDouble); 297 | TORCH_CHECK(dorigin.dtype() == at::kFloat || dorigin.dtype() == at::kDouble); 298 | TORCH_CHECK(grad.dtype() == at::kComplexFloat || grad.dtype() == at::kComplexDouble); 299 | TORCH_INTERNAL_ASSERT(img.device().type() == at::DeviceType::CPU); 300 | TORCH_INTERNAL_ASSERT(dorigin.device().type() == at::DeviceType::CPU); 301 | TORCH_INTERNAL_ASSERT(grad.device().type() == at::DeviceType::CPU); 302 | at::Tensor dorigin_contig = dorigin.contiguous(); 303 | at::Tensor img_contig = img.contiguous(); 304 | at::Tensor grad_contig = grad.contiguous(); 305 | at::Tensor img_grad; 306 | at::Tensor dorigin_grad; 307 | 308 | if (img.dtype() == at::kComplexFloat) { 309 | TORCH_CHECK(dorigin.dtype() == at::kFloat); 310 | TORCH_CHECK(grad.dtype() == at::kComplexFloat); 311 | const float* dorigin_ptr = dorigin_contig.data_ptr(); 312 | c10::complex* img_ptr = img.data_ptr>(); 313 | c10::complex* grad_ptr = grad.data_ptr>(); 314 | c10::complex* img_grad_ptr = nullptr; 315 | if (img.requires_grad()) { 316 | img_grad = torch::zeros_like(img); 317 | img_grad_ptr = img_grad.data_ptr>(); 318 | } else { 319 | img_grad = torch::Tensor(); 320 | } 321 | 322 | float* 
dorigin_grad_ptr = nullptr; 323 | if (dorigin.requires_grad()) { 324 | dorigin_grad = torch::zeros_like(dorigin); 325 | dorigin_grad_ptr = dorigin_grad.data_ptr(); 326 | } else { 327 | dorigin_grad = torch::Tensor(); 328 | } 329 | 330 | const float ref_phase = 4.0f * fc / kC0; 331 | 332 | #pragma omp parallel for collapse(2) 333 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 334 | for(int idx = 0; idx < nr1 * ntheta1; idx++) { 335 | polar_interp_kernel_linear_grad_cpu(img_ptr, dorigin_ptr, rotation, 336 | ref_phase, r0, dr0, theta0, dtheta0, nr0, ntheta0, 337 | r1, dr1, theta1, dtheta1, nr1, ntheta1, z1, 338 | grad_ptr, img_grad_ptr, dorigin_grad_ptr, 339 | idx, idbatch); 340 | } 341 | } 342 | } else { 343 | TORCH_CHECK(dorigin.dtype() == at::kDouble); 344 | TORCH_CHECK(grad.dtype() == at::kComplexDouble); 345 | const double* dorigin_ptr = dorigin_contig.data_ptr(); 346 | c10::complex* img_ptr = img.data_ptr>(); 347 | c10::complex* grad_ptr = grad.data_ptr>(); 348 | c10::complex* img_grad_ptr = nullptr; 349 | if (img.requires_grad()) { 350 | img_grad = torch::zeros_like(img); 351 | img_grad_ptr = img_grad.data_ptr>(); 352 | } else { 353 | img_grad = torch::Tensor(); 354 | } 355 | 356 | double* dorigin_grad_ptr = nullptr; 357 | if (dorigin.requires_grad()) { 358 | dorigin_grad = torch::zeros_like(dorigin); 359 | dorigin_grad_ptr = dorigin_grad.data_ptr(); 360 | } else { 361 | dorigin_grad = torch::Tensor(); 362 | } 363 | 364 | const double ref_phase = 4.0 * fc / kC0; 365 | 366 | #pragma omp parallel for collapse(2) 367 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 368 | for(int idx = 0; idx < nr1 * ntheta1; idx++) { 369 | polar_interp_kernel_linear_grad_cpu(img_ptr, dorigin_ptr, rotation, 370 | ref_phase, r0, dr0, theta0, dtheta0, nr0, ntheta0, 371 | r1, dr1, theta1, dtheta1, nr1, ntheta1, z1, 372 | grad_ptr, img_grad_ptr, dorigin_grad_ptr, 373 | idx, idbatch); 374 | } 375 | } 376 | } 377 | 378 | std::vector ret; 379 | ret.push_back(img_grad); 380 | 
ret.push_back(dorigin_grad); 381 | return ret; 382 | } 383 | 384 | static void backprojection_polar_2d_kernel_cpu( 385 | const complex64_t* data, 386 | const float* pos, 387 | complex64_t* img, 388 | int sweep_samples, 389 | int nsweeps, 390 | float ref_phase, 391 | float delta_r, 392 | float r0, 393 | float dr, 394 | float theta0, 395 | float dtheta, 396 | int Nr, 397 | int Ntheta, 398 | float d0, 399 | bool dealias, float z0, 400 | int idx, 401 | int idbatch) { 402 | const int idtheta = idx % Ntheta; 403 | const int idr = idx / Ntheta; 404 | if (idr >= Nr || idtheta >= Ntheta) { 405 | return; 406 | } 407 | 408 | const float r = r0 + idr * dr; 409 | const float theta = theta0 + idtheta * dtheta; 410 | const float x = r * sqrtf(1.0f - theta*theta); 411 | const float y = r * theta; 412 | 413 | complex64_t pixel = {0, 0}; 414 | 415 | for(int i = 0; i < nsweeps; i++) { 416 | // Sweep reference position. 417 | float pos_x = pos[idbatch * nsweeps * 3 + i * 3 + 0]; 418 | float pos_y = pos[idbatch * nsweeps * 3 + i * 3 + 1]; 419 | float pos_z = pos[idbatch * nsweeps * 3 + i * 3 + 2]; 420 | float px = (x - pos_x); 421 | float py = (y - pos_y); 422 | float pz2 = pos_z * pos_z; 423 | 424 | // Calculate distance to the pixel. 425 | float d = sqrtf(px * px + py * py + pz2) + d0; 426 | 427 | float sx = delta_r * d; 428 | 429 | // Linear interpolation. 
430 | int id0 = sx; 431 | int id1 = id0 + 1; 432 | if (id0 < 0 || id1 >= sweep_samples) { 433 | continue; 434 | } 435 | complex64_t s0 = data[idbatch * sweep_samples * nsweeps + i * sweep_samples + id0]; 436 | complex64_t s1 = data[idbatch * sweep_samples * nsweeps + i * sweep_samples + id1]; 437 | 438 | float interp_idx = sx - id0; 439 | complex64_t s = (1.0f - interp_idx) * s0 + interp_idx * s1; 440 | 441 | float ref_sin, ref_cos; 442 | sincospi(ref_phase * d, &ref_sin, &ref_cos); 443 | complex64_t ref = {ref_cos, ref_sin}; 444 | pixel += s * ref; 445 | } 446 | if (dealias) { 447 | const float d = sqrtf(x*x + y*y + z0*z0); 448 | float ref_sin, ref_cos; 449 | sincospi(-ref_phase * d, &ref_sin, &ref_cos); 450 | complex64_t ref = {ref_cos, ref_sin}; 451 | pixel *= ref; 452 | } 453 | img[idbatch * Nr * Ntheta + idr * Ntheta + idtheta] = pixel; 454 | } 455 | 456 | at::Tensor backprojection_polar_2d_cpu( 457 | const at::Tensor &data, 458 | const at::Tensor &pos, 459 | int64_t nbatch, 460 | int64_t sweep_samples, 461 | int64_t nsweeps, 462 | double fc, 463 | double r_res, 464 | double r0, 465 | double dr, 466 | double theta0, 467 | double dtheta, 468 | int64_t Nr, 469 | int64_t Ntheta, 470 | double d0, 471 | int64_t dealias, 472 | double z0) { 473 | TORCH_CHECK(pos.dtype() == at::kFloat); 474 | TORCH_CHECK(data.dtype() == at::kComplexFloat); 475 | TORCH_INTERNAL_ASSERT(pos.device().type() == at::DeviceType::CPU); 476 | TORCH_INTERNAL_ASSERT(data.device().type() == at::DeviceType::CPU); 477 | 478 | at::Tensor pos_contig = pos.contiguous(); 479 | at::Tensor data_contig = data.contiguous(); 480 | auto options = 481 | torch::TensorOptions() 482 | .dtype(torch::kComplexFloat) 483 | .layout(torch::kStrided) 484 | .device(torch::kCPU, 1); 485 | at::Tensor img = torch::zeros({nbatch, Nr, Ntheta}, options); 486 | const float* pos_ptr = pos_contig.data_ptr(); 487 | const c10::complex* data_ptr = data_contig.data_ptr>(); 488 | c10::complex* img_ptr = img.data_ptr>(); 489 | 490 | 
const float delta_r = 1.0f / r_res; 491 | const float ref_phase = 4.0f * fc / kC0; 492 | 493 | #pragma omp parallel for collapse(2) 494 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 495 | for(int idx = 0; idx < Nr * Ntheta; idx++) { 496 | backprojection_polar_2d_kernel_cpu( 497 | data_ptr, 498 | pos_ptr, 499 | img_ptr, 500 | sweep_samples, 501 | nsweeps, 502 | ref_phase, 503 | delta_r, 504 | r0, dr, 505 | theta0, dtheta, 506 | Nr, Ntheta, 507 | d0, 508 | dealias, z0, 509 | idx, idbatch); 510 | } 511 | } 512 | return img; 513 | } 514 | 515 | static void backprojection_polar_2d_grad_kernel_cpu( 516 | const complex64_t* data, 517 | const float* pos, 518 | int sweep_samples, 519 | int nsweeps, 520 | float ref_phase, 521 | float delta_r, 522 | float r0, 523 | float dr, 524 | float theta0, 525 | float dtheta, 526 | int Nr, 527 | int Ntheta, 528 | float d0, 529 | bool dealias, 530 | float z0, 531 | const complex64_t* grad, 532 | float* pos_grad, 533 | complex64_t *data_grad, 534 | int idx, 535 | int idbatch) { 536 | const int idtheta = idx % Ntheta; 537 | const int idr = idx / Ntheta; 538 | if (idx >= Nr * Ntheta) { 539 | return; 540 | } 541 | 542 | bool have_pos_grad = pos_grad != nullptr; 543 | bool have_data_grad = data_grad != nullptr; 544 | 545 | const float r = r0 + idr * dr; 546 | const float theta = theta0 + idtheta * dtheta; 547 | const float x = r * sqrtf(1.0f - theta*theta); 548 | const float y = r * theta; 549 | 550 | complex64_t g = grad[idbatch * Nr * Ntheta + idr * Ntheta + idtheta]; 551 | 552 | if (dealias) { 553 | const float d = sqrtf(x*x + y*y + z0*z0); 554 | float ref_sin, ref_cos; 555 | sincospi(-ref_phase * d, &ref_sin, &ref_cos); 556 | complex64_t ref = {ref_cos, ref_sin}; 557 | g *= ref; 558 | } 559 | 560 | complex64_t I = {0.0f, 1.0f}; 561 | 562 | for(int i = 0; i < nsweeps; i++) { 563 | // Sweep reference position. 
564 | float pos_x = pos[idbatch * nsweeps * 3 + i * 3 + 0]; 565 | float pos_y = pos[idbatch * nsweeps * 3 + i * 3 + 1]; 566 | float pos_z = pos[idbatch * nsweeps * 3 + i * 3 + 2]; 567 | float px = (x - pos_x); 568 | float py = (y - pos_y); 569 | // Image plane is assumed to be at z=0 570 | float pz2 = pos_z * pos_z; 571 | 572 | // Calculate distance to the pixel. 573 | float d = sqrtf(px * px + py * py + pz2) + d0; 574 | 575 | float sx = delta_r * d; 576 | 577 | float dx = 0.0f; 578 | float dy = 0.0f; 579 | float dz = 0.0f; 580 | complex64_t ds0 = 0.0f; 581 | complex64_t ds1 = 0.0f; 582 | 583 | // Linear interpolation. 584 | int id0 = sx; 585 | int id1 = id0 + 1; 586 | if (id0 >= 0 && id1 < sweep_samples) { 587 | complex64_t s0 = data[idbatch * sweep_samples * nsweeps + i * sweep_samples + id0]; 588 | complex64_t s1 = data[idbatch * sweep_samples * nsweeps + i * sweep_samples + id1]; 589 | 590 | float interp_idx = sx - id0; 591 | complex64_t s = (1.0f - interp_idx) * s0 + interp_idx * s1; 592 | 593 | float ref_sin, ref_cos; 594 | sincospi(ref_phase * d, &ref_sin, &ref_cos); 595 | complex64_t ref = {ref_cos, ref_sin}; 596 | 597 | if (have_pos_grad) { 598 | complex64_t dout = ref * ((I * kPI * ref_phase) * s + (s1 - s0) * delta_r); 599 | complex64_t gdout = g * std::conj(dout); 600 | 601 | // Take real part 602 | float gd = std::real(gdout); 603 | 604 | dx = -px / (d - d0); 605 | dy = -py / (d - d0); 606 | // Different from x,y because pos_z is handled differently. 
607 | dz = pos_z / (d - d0); 608 | dx *= gd; 609 | dy *= gd; 610 | dz *= gd; 611 | // Avoid issues with zero range 612 | if (!isfinite(dx)) dx = 0.0f; 613 | if (!isfinite(dy)) dy = 0.0f; 614 | if (!isfinite(dz)) dz = 0.0f; 615 | } 616 | 617 | if (have_data_grad) { 618 | ds0 = g * std::conj((1.0f - interp_idx) * ref); 619 | ds1 = g * std::conj(interp_idx * ref); 620 | } 621 | } 622 | 623 | if (have_pos_grad) { 624 | #pragma omp atomic 625 | pos_grad[idbatch * nsweeps * 3 + i * 3 + 0] += dx; 626 | #pragma omp atomic 627 | pos_grad[idbatch * nsweeps * 3 + i * 3 + 1] += dy; 628 | #pragma omp atomic 629 | pos_grad[idbatch * nsweeps * 3 + i * 3 + 2] += dz; 630 | } 631 | 632 | if (have_data_grad) { 633 | if (id0 >= 0 && id1 < sweep_samples) { 634 | // Slow 635 | #pragma omp critical 636 | { 637 | data_grad[idbatch * sweep_samples * nsweeps + i * sweep_samples + id0] += ds0; 638 | data_grad[idbatch * sweep_samples * nsweeps + i * sweep_samples + id1] += ds1; 639 | } 640 | } 641 | } 642 | } 643 | } 644 | 645 | std::vector backprojection_polar_2d_grad_cpu( 646 | const at::Tensor &grad, 647 | const at::Tensor &data, 648 | const at::Tensor &pos, 649 | int64_t nbatch, 650 | int64_t sweep_samples, 651 | int64_t nsweeps, 652 | double fc, 653 | double r_res, 654 | double r0, 655 | double dr, 656 | double theta0, 657 | double dtheta, 658 | int64_t Nr, 659 | int64_t Ntheta, 660 | double d0, 661 | int64_t dealias, 662 | double z0) { 663 | TORCH_CHECK(pos.dtype() == at::kFloat); 664 | TORCH_CHECK(data.dtype() == at::kComplexFloat); 665 | TORCH_CHECK(grad.dtype() == at::kComplexFloat); 666 | TORCH_INTERNAL_ASSERT(pos.device().type() == at::DeviceType::CPU); 667 | TORCH_INTERNAL_ASSERT(data.device().type() == at::DeviceType::CPU); 668 | TORCH_INTERNAL_ASSERT(grad.device().type() == at::DeviceType::CPU); 669 | at::Tensor pos_contig = pos.contiguous(); 670 | at::Tensor data_contig = data.contiguous(); 671 | at::Tensor grad_contig = grad.contiguous(); 672 | const float* pos_ptr = 
pos_contig.data_ptr(); 673 | const c10::complex* data_ptr = data_contig.data_ptr>(); 674 | const c10::complex* grad_ptr = grad_contig.data_ptr>(); 675 | 676 | at::Tensor pos_grad; 677 | float* pos_grad_ptr = nullptr; 678 | if (pos.requires_grad()) { 679 | pos_grad = torch::zeros_like(pos); 680 | pos_grad_ptr = pos_grad.data_ptr(); 681 | } else { 682 | pos_grad = torch::Tensor(); 683 | } 684 | 685 | at::Tensor data_grad; 686 | c10::complex* data_grad_ptr = nullptr; 687 | if (data.requires_grad()) { 688 | data_grad = torch::zeros_like(data); 689 | data_grad_ptr = data_grad.data_ptr>(); 690 | } else { 691 | data_grad = torch::Tensor(); 692 | } 693 | 694 | const float delta_r = 1.0f / r_res; 695 | const float ref_phase = 4.0f * fc / kC0; 696 | 697 | #pragma omp parallel for collapse(2) 698 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 699 | for(int idx = 0; idx < Nr * Ntheta; idx++) { 700 | backprojection_polar_2d_grad_kernel_cpu( 701 | data_ptr, 702 | pos_ptr, 703 | sweep_samples, 704 | nsweeps, 705 | ref_phase, 706 | delta_r, 707 | r0, dr, 708 | theta0, dtheta, 709 | Nr, Ntheta, 710 | d0, 711 | dealias, z0, 712 | grad_ptr, 713 | pos_grad_ptr, 714 | data_grad_ptr, 715 | idx, 716 | idbatch 717 | ); 718 | } 719 | } 720 | std::vector ret; 721 | ret.push_back(data_grad); 722 | ret.push_back(pos_grad); 723 | return ret; 724 | } 725 | 726 | template 727 | static void polar_to_cart_kernel_linear_cpu(const T *img, T 728 | *out, const float *origin, float rotation, float ref_phase, float r0, 729 | float dr, float theta0, float dtheta, int Nr, int Ntheta, float x0, 730 | float dx, float y0, float dy, int Nx, int Ny, 731 | int id1, int idbatch) { 732 | const int idy = id1 % Ny; 733 | const int idx = id1 / Ny; 734 | 735 | if (id1 >= Nx * Ny) { 736 | return; 737 | } 738 | 739 | const float orig0 = origin[idbatch * 3 + 0]; 740 | const float orig1 = origin[idbatch * 3 + 1]; 741 | const float orig2 = origin[idbatch * 3 + 2]; 742 | const float x = x0 + dx * idx; 743 | const 
float y = y0 + dy * idy; 744 | const float d = sqrtf((x-orig0)*(x-orig0) + (y-orig1)*(y-orig1)); 745 | const float dz = sqrtf(d*d + orig2*orig2); 746 | float t = (y - orig1) / d; // Sin of angle 747 | float tc = (x - orig0) / d; // Cos of angle 748 | float rs = sinf(rotation); 749 | float rc = cosf(rotation); 750 | float cosa = t*rs + tc*rc; 751 | if (rotation != 0.0f) { 752 | t = rc * t - rs * tc; 753 | } 754 | const float dri = (d - r0) / dr; 755 | const float dti = (t - theta0) / dtheta; 756 | 757 | const int dri_int = dri; 758 | const float dri_frac = dri - dri_int; 759 | const int dti_int = dti; 760 | const float dti_frac = dti - dti_int; 761 | 762 | if (cosa >= 0 && dri_int >= 0 && dri_int < Nr-1 && dti_int >= 0 && dti_int < Ntheta-1) { 763 | T v = interp2d(&img[idbatch * Nr * Ntheta], Nr, Ntheta, dri_int, dri_frac, dti_int, dti_frac); 764 | if constexpr (std::is_same_v) { 765 | float ref_sin, ref_cos; 766 | sincospi(ref_phase * dz, &ref_sin, &ref_cos); 767 | complex64_t ref = {ref_cos, ref_sin}; 768 | out[idbatch * Nx * Ny + idx*Ny + idy] = v * ref; 769 | } else { 770 | out[idbatch * Nx * Ny + idx*Ny + idy] = v; 771 | } 772 | } else { 773 | if constexpr (std::is_same_v) { 774 | out[idbatch * Nx * Ny + idx*Ny + idy] = {0.0f, 0.0f}; 775 | } else { 776 | out[idbatch * Nx * Ny + idx*Ny + idy] = 0.0f; 777 | } 778 | } 779 | } 780 | 781 | at::Tensor polar_to_cart_linear_cpu( 782 | const at::Tensor &img, 783 | const at::Tensor &origin, 784 | int64_t nbatch, 785 | double rotation, 786 | double fc, 787 | double r0, 788 | double dr, 789 | double theta0, 790 | double dtheta, 791 | int64_t Nr, 792 | int64_t Ntheta, 793 | double x0, 794 | double y0, 795 | double dx, 796 | double dy, 797 | int64_t Nx, 798 | int64_t Ny) { 799 | TORCH_CHECK(img.dtype() == at::kComplexFloat || img.dtype() == at::kFloat); 800 | TORCH_CHECK(origin.dtype() == at::kFloat); 801 | TORCH_INTERNAL_ASSERT(img.device().type() == at::DeviceType::CPU); 802 | TORCH_INTERNAL_ASSERT(origin.device().type() 
== at::DeviceType::CPU); 803 | at::Tensor origin_contig = origin.contiguous(); 804 | at::Tensor img_contig = img.contiguous(); 805 | at::Tensor out = torch::empty({nbatch, Nx, Ny}, img_contig.options()); 806 | const float* origin_ptr = origin_contig.data_ptr(); 807 | 808 | const float ref_phase = 4.0f * fc / kC0; 809 | 810 | #pragma omp parallel for collapse(2) 811 | for(int idbatch = 0; idbatch < nbatch; idbatch++) { 812 | for(int id1 = 0; id1 < Nx * Ny; id1++) { 813 | if (img.dtype() == at::kComplexFloat) { 814 | c10::complex* img_ptr = img_contig.data_ptr>(); 815 | c10::complex* out_ptr = out.data_ptr>(); 816 | polar_to_cart_kernel_linear_cpu( 817 | (const complex64_t*)img_ptr, 818 | (complex64_t*)out_ptr, 819 | origin_ptr, 820 | rotation, 821 | ref_phase, 822 | r0, 823 | dr, 824 | theta0, 825 | dtheta, 826 | Nr, 827 | Ntheta, 828 | x0, 829 | dx, 830 | y0, 831 | dy, 832 | Nx, 833 | Ny, 834 | id1, 835 | idbatch 836 | ); 837 | } else { 838 | float* img_ptr = img_contig.data_ptr(); 839 | float* out_ptr = out.data_ptr(); 840 | polar_to_cart_kernel_linear_cpu( 841 | img_ptr, 842 | out_ptr, 843 | origin_ptr, 844 | rotation, 845 | ref_phase, 846 | r0, 847 | dr, 848 | theta0, 849 | dtheta, 850 | Nr, 851 | Ntheta, 852 | x0, 853 | dx, 854 | y0, 855 | dy, 856 | Nx, 857 | Ny, 858 | id1, 859 | idbatch 860 | ); 861 | } 862 | } 863 | } 864 | return out; 865 | } 866 | 867 | // Defines the operators 868 | TORCH_LIBRARY(torchbp, m) { 869 | m.def("backprojection_polar_2d(Tensor data, Tensor pos, int nbatch, int sweep_samples, int nsweeps, float fc, float r_res, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float d0, int dealias, float z0) -> Tensor"); 870 | m.def("backprojection_polar_2d_grad(Tensor grad, Tensor data, Tensor pos, int nbatch, int sweep_samples, int nsweeps, float fc, float r_res, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float d0, int dealias, float z0) -> Tensor[]"); 871 | 
m.def("backprojection_polar_2d_lanczos(Tensor data, Tensor pos, int nbatch, int sweep_samples, int nsweeps, float fc, float r_res, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float d0, int dealias, float z0, int order) -> Tensor"); 872 | m.def("backprojection_cart_2d(Tensor data, Tensor pos, int nbatch, int sweep_samples, int nsweeps, float fc, float r_res, float x0, float dx, float y0, float dy, int Nx, int Ny, float beamwidth, float d0) -> Tensor"); 873 | m.def("backprojection_cart_2d_grad(Tensor grad, Tensor data, Tensor pos, int nbatch, int sweep_samples, int nsweeps, float fc, float r_res, float x0, float dx, float y0, float dy, int Nx, int Ny, float beamwidth, float d0) -> Tensor[]"); 874 | m.def("gpga_backprojection_2d(Tensor target_pos, Tensor data, Tensor pos, int sweep_samples, int nsweeps, float fc, float r_res, int Ntarget, float d0) -> Tensor"); 875 | m.def("gpga_backprojection_2d_lanczos(Tensor target_pos, Tensor data, Tensor pos, int sweep_samples, int nsweeps, float fc, float r_res, int Ntarget, float d0, int order) -> Tensor"); 876 | m.def("cfar_2d(Tensor img, int nbatch, int N0, int N1, int Navg0, int Navg1, int Nguard0, int Nguard1, float threshold, int peaks_only) -> Tensor"); 877 | m.def("polar_interp_linear(Tensor img, Tensor dorigin, int nbatch, float rotation, float fc, float r0, float dr0, float theta0, float dtheta0, int Nr0, int Ntheta0, float r1, float dr1, float theta1, float dtheta1, int Nr1, int Ntheta1, float z1) -> Tensor"); 878 | m.def("polar_interp_linear_grad(Tensor grad, Tensor img, Tensor dorigin, int nbatch, float rotation, float fc, float r0, float dr0, float theta0, float dtheta0, int Nr0, int Ntheta0, float r1, float dr1, float theta1, float dtheta1, int Nr1, int Ntheta1, float z1) -> Tensor[]"); 879 | m.def("polar_interp_bicubic(Tensor img, Tensor img_gx, Tensor img_gy, Tensor img_gxy, Tensor dorigin, int nbatch, float rotation, float fc, float r0, float dr0, float theta0, float dtheta0, int Nr0, 
int Ntheta0, float r1, float dr1, float theta1, float dtheta1, int Nr1, int Ntheta1, float z1) -> Tensor"); 880 | m.def("polar_interp_lanczos(Tensor img, Tensor dorigin, int nbatch, float rotation, float fc, float r0, float dr0, float theta0, float dtheta0, int Nr0, int Ntheta0, float r1, float dr1, float theta1, float dtheta1, int Nr1, int Ntheta1, float z1, int order) -> Tensor"); 881 | m.def("polar_to_cart_linear(Tensor img, Tensor origin, int nbatch, float rotation, float fc, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float x0, float y0, float dx, float dy, int Nx, int Ny) -> Tensor"); 882 | m.def("polar_to_cart_linear_grad(Tensor grad, Tensor img, Tensor origin, int nbatch, float rotation, float fc, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float x0, float y0, float dx, float dy, int Nx, int Ny) -> Tensor[]"); 883 | m.def("polar_to_cart_bicubic(Tensor img, Tensor img_gx, Tensor img_gy, Tensor img_gxy, Tensor origin, int nbatch, float rotation, float fc, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float x0, float y0, float dx, float dy, int Nx, int Ny) -> Tensor"); 884 | m.def("polar_to_cart_bicubic_grad(Tensor grad, Tensor img, Tensor img_gx, Tensor img_gy, Tensor img_gxy, Tensor origin, int nbatch, float rotation, float fc, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, float x0, float y0, float dx, float dy, int Nx, int Ny) -> Tensor[]"); 885 | m.def("backprojection_polar_2d_tx_power(Tensor wa, Tensor pos, Tensor att, Tensor gtx, Tensor grx, int nbatch, float g_az0, float g_el0, float g_daz, float g_del, int g_naz, int g_nel, int nsweeps, float r_res, float r0, float dr, float theta0, float dtheta, int Nr, int Ntheta, int sin_look_angle) -> Tensor"); 886 | m.def("entropy(Tensor data, Tensor norm, int nbatch) -> Tensor"); 887 | m.def("entropy_grad(Tensor data, Tensor norm, Tensor grad, int nbatch) -> Tensor[]"); 888 | m.def("abs_sum(Tensor data, int nbatch) -> 
// Registers the CPU implementations of the torchbp operators. Only the
// operators listed here are bound to a CPU function in this block; each name
// must match a schema declared with m.def in the TORCH_LIBRARY block.
TORCH_LIBRARY_IMPL(torchbp, CPU, m) {
  // Polar backprojection, forward and backward.
  m.impl("backprojection_polar_2d", &backprojection_polar_2d_cpu);
  m.impl("backprojection_polar_2d_grad", &backprojection_polar_2d_grad_cpu);
  // Polar-to-polar linear interpolation, forward and backward.
  m.impl("polar_interp_linear", &polar_interp_linear_cpu);
  m.impl("polar_interp_linear_grad", &polar_interp_linear_grad_cpu);
  // Polar-to-Cartesian linear interpolation, forward only.
  m.impl("polar_to_cart_linear", &polar_to_cart_linear_cpu);
}
def diff(x: Tensor, dim: int = -1, same_size: bool = False) -> Tensor:
    """
    ``np.diff`` implemented in torch.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : int
        Dimension along which the difference is taken. Default is the last
        dimension.
    same_size : bool
        Pad output with one leading zero along ``dim`` so that it has the
        same size as the input.

    Returns
    ----------
    d : Tensor
        Difference tensor.
    """
    # Work on a view with the requested dimension last; for dim=-1 this is a
    # no-op, so the result matches the original last-dimension behavior.
    moved = x.movedim(dim, -1)
    d = moved[..., 1:] - moved[..., :-1]
    if same_size:
        d = torch.nn.functional.pad(d, (1, 0))
    return d.movedim(-1, dim)
def find_image_shift_1d(x: Tensor, y: Tensor, dim: int = -1) -> Tensor:
    """
    Find shift between images that maximizes correlation.

    The shift is estimated with phase correlation along ``dim``; the
    correlation is averaged over all other dimensions.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    y : Tensor
        Input tensor. Should have same shape as x.
    dim : int
        Dimension along which the shift is searched.

    Returns
    ----------
    c : Tensor
        Index of the correlation peak along ``dim``.

    Raises
    ----------
    ValueError
        If ``x`` and ``y`` have different shapes.
    """
    if x.shape != y.shape:
        raise ValueError("Input shapes should be identical")
    if dim < 0:
        dim = x.dim() + dim
    fx = torch.fft.fft(x, dim=dim)
    fy = torch.fft.fft(y, dim=dim)
    cross = fx * fy.conj()
    norm = torch.abs(fx) * torch.abs(fy)
    # Empty spectral bins would give 0/0 = NaN and poison the whole inverse
    # FFT. Clamping the denominator makes those bins contribute zero instead.
    c = cross / norm.clamp_min(torch.finfo(norm.dtype).tiny)
    other_dims = [i for i in range(x.dim()) if i != dim]
    c = torch.abs(torch.fft.ifft(c, dim=dim))
    if len(other_dims) > 0:
        c = torch.mean(c, dim=other_dims)
    return torch.argmax(c)
def entropy(x: Tensor) -> Tensor:
    """
    Entropy of the normalized magnitude of ``x``:

    ``-sum(y*log(y))``

    with ``y = abs(x) / sum(abs(x))``. Terms where ``y`` is zero contribute
    zero.

    Parameters
    ----------
    x : Tensor
        Input tensor.

    Returns
    ----------
    entropy : Tensor
        Calculated entropy of the input.
    """
    magnitude = torch.abs(x)
    y = magnitude / torch.sum(magnitude)
    return -torch.sum(torch.xlogy(y, y))
def shift_spectrum(x: Tensor, dim: int = -1) -> Tensor:
    """
    Multiply every odd-indexed element of the last dimension by -1.

    Stated by the original author to be equivalent to
    ``fft(ifftshift(ifft(x, dim), dim), dim)`` without calculating FFTs.

    Parameters
    ----------
    x : Tensor
        Input tensor.
    dim : int
        Must be -1; other dimensions are not implemented.

    Returns
    ----------
    y : Tensor
        Shifted tensor.
    """
    if dim != -1:
        raise NotImplementedError("dim should be -1")
    n = x.shape[dim]
    # Alternating +1/-1 pattern, shaped to broadcast over the leading dims.
    signs = torch.ones(n, dtype=torch.float32, device=x.device)
    signs[1::2] = -1
    return x * signs.reshape([1] * (x.dim() - 1) + [n])
def make_polar_grid(
    r0: float, r1: float, nr: int, ntheta: int, theta_limit: int = 1, squint: float = 0
) -> dict:
    """
    Build the polar grid dict used by the other polar functions.

    Parameters
    ----------
    r0 : float
        Minimum range in m.
    r1 : float
        Maximum range in m.
    nr : int
        Number of range points.
    ntheta : int
        Number of azimuth points.
    theta_limit : float
        Half-width of the theta axis around the squint direction.
        Units are sin of angle (0 to 1 valid range).
        Default is 1.
    squint : float
        Grid azimuth mean angle, radians.

    Returns
    ----------
    grid_polar : dict
        Dict with keys ``"r"``, ``"theta"``, ``"nr"`` and ``"ntheta"``.
    """
    # The theta axis is centered on sin(squint) and clipped to [-1, 1].
    center = np.sin(squint)
    theta_min = np.clip(center - theta_limit, -1, 1)
    theta_max = np.clip(center + theta_limit, -1, 1)
    return dict(r=(r0, r1), theta=(theta_min, theta_max), nr=nr, ntheta=ntheta)
386 | 387 | Parameters 388 | ---------- 389 | p : Tensor 390 | Phase shift tensor. 391 | fc : float 392 | RF center frequency. 393 | """ 394 | c0 = 299792458 395 | return c0 * p / (4 * torch.pi * fc) 396 | 397 | 398 | def fft_lowpass_filter_window( 399 | target_data: Tensor, window: str | tuple = "hamming", window_width: int = None 400 | ) -> Tensor: 401 | """ 402 | FFT low-pass filtering with a configurable window function. 403 | 404 | Parameters 405 | ---------- 406 | target_data : Tensor 407 | Input data. 408 | window_type : str 409 | Window to apply. See scipy.get_window for syntax. 410 | e.g., 'hann', 'hamming', 'blackman'. 411 | window_width : int 412 | Width of the window in samples. If None or larger than signal, returns 413 | the input unchanged. 414 | 415 | Returns 416 | ---------- 417 | Filtered tensor (same shape as input) 418 | """ 419 | fdata = torch.fft.fft(target_data, dim=-1) 420 | n = target_data.size(-1) 421 | 422 | # If window_width is None, do nothing 423 | if window_width is None or window_width > n: 424 | return target_data 425 | 426 | # Window needs to be centered at DC in FFT 427 | half_width = (window_width + 1) // 2 428 | half_window = get_window(window, 2 * half_width - 1, fftbins=True)[half_width - 1 :] 429 | w = np.zeros(n, dtype=np.float32) 430 | w[:half_width] = half_window 431 | w[-half_width + 1 :] = np.flip(half_window[1:]) 432 | 433 | w = torch.tensor(w).to(target_data.device) 434 | filtered_data = torch.fft.ifft(fdata * w, dim=-1) 435 | return filtered_data 436 | 437 | def center_pos(pos: Tensor): 438 | """ 439 | Center position to origin. Centers X and Y coordinates, but doesn't modify Z. 440 | Useful for preparing positions for polar backprojection 441 | 442 | Parameters 443 | ---------- 444 | pos : Tensor 445 | 3D positions. Shape should be [N, 3]. 446 | 447 | Returns 448 | ---------- 449 | pos_local : Tensor 450 | Centered positions. 451 | origin : Tensor 452 | Position subtracted from the pos. 
453 | h : Tensor 454 | Mean height. 455 | """ 456 | origin = torch.tensor( 457 | [ 458 | torch.mean(pos[:, 0]), 459 | torch.mean(pos[:, 1]), 460 | 0 461 | ], 462 | device=pos.device, 463 | dtype=torch.float32, 464 | )[None, :] 465 | pos_local = pos - origin 466 | return pos_local, origin 467 | --------------------------------------------------------------------------------