├── tfkbnufft ├── mri │ ├── __init__.py │ └── dcomp_calc.py ├── nufft │ ├── __init__.py │ ├── utils.py │ ├── fft_functions.py │ └── interp_functions.py ├── tests │ ├── __init__.py │ ├── nufft │ │ ├── __init__.py │ │ ├── fft_functions_test.py │ │ └── interp_functions_test.py │ ├── utils │ │ └── itertools.py │ ├── utils.py │ ├── kbnufft_test.py │ ├── mri │ │ └── dcomp_calc_test.py │ └── ndft_test.py ├── utils │ ├── __init__.py │ └── itertools.py ├── __init__.py ├── kbmodule.py └── kbnufft.py ├── .gitattributes ├── requirements.txt ├── .gitignore ├── run_tests.sh ├── .github └── workflows │ ├── test.yml │ └── publish.yml ├── LICENSE ├── setup.py ├── README.md └── profile_tfkbnufft.py /tfkbnufft/mri/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tfkbnufft/nufft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tfkbnufft/tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tfkbnufft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tfkbnufft/tests/nufft/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-documentation -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow>=2.1.0 2 | numpy 3 | scipy 4 | -------------------------------------------------------------------------------- /tfkbnufft/tests/utils/itertools.py: -------------------------------------------------------------------------------- 1 | # TODO: write tests for product 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache/ 2 | __pycache__/ 3 | .vscode/ 4 | dist/ 5 | build/ 6 | torchkbnufft.egg-info/ 7 | *checkpoint.ipynb 8 | speedtests/ 9 | _build/ 10 | gram/ 11 | tfkbnufft.egg-info/ 12 | -------------------------------------------------------------------------------- /run_tests.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -e 3 | pip install torch==1.7 torchkbnufft==0.3.4 scikit-image pytest 4 | # We test ndft_test.py separately as it causes some issues with tracing resulting in hangs. 
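# (Those tests check the NUFFT outputs and gradients against an explicit non-uniform Fourier matrix built in tfkbnufft/tests/ndft_test.py.)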
5 | python -m pytest tfkbnufft --ignore=tfkbnufft/tests/ndft_test.py 6 | python -m pytest tfkbnufft/tests/ndft_test.py 7 | -------------------------------------------------------------------------------- /tfkbnufft/utils/itertools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def cartesian_product(arrays): 5 | # from https://stackoverflow.com/a/11146645/4332585 6 | la = len(arrays) 7 | dtype = np.result_type(*arrays) 8 | arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype) 9 | for i, a in enumerate(np.ix_(*arrays)): 10 | arr[...,i] = a 11 | return arr.reshape(-1, la) 12 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | name: Continuous testing 4 | 5 | on: 6 | push: 7 | branches: 8 | - 'master' 9 | 10 | pull_request: 11 | branches: 12 | - 'master' 13 | 14 | jobs: 15 | test: 16 | name: Test Code 17 | runs-on: ubuntu-latest 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Set up Python 3.8 21 | uses: actions/setup-python@v2 22 | with: 23 | python-version: 3.8 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install . 28 | - name: Test with pytest 29 | run: bash run_tests.sh 30 | -------------------------------------------------------------------------------- /tfkbnufft/tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | 5 | 6 | def to_torch_arg(arg): 7 | if isinstance(arg, list): 8 | return [to_torch_arg(a) for a in arg] 9 | elif isinstance(arg, dict): 10 | return {k: to_torch_arg(v) for k, v in arg.items()} 11 | else: 12 | if np.iscomplex(arg).any(): 13 | torch_x = np.stack((np.real(arg), np.imag(arg))) 14 | return torch.tensor(torch_x) 15 | else: 16 | return torch.tensor(arg) 17 | 18 | def to_tf_arg(arg): 19 | if isinstance(arg, list): 20 | return [to_tf_arg(a) for a in arg] 21 | elif isinstance(arg, dict): 22 | return {k: to_tf_arg(v) for k, v in arg.items()} 23 | else: 24 | return tf.convert_to_tensor(arg) 25 | 26 | def torch_to_numpy(array, complex_dim=None): 27 | if complex_dim is not None: 28 | assert array.shape[complex_dim] == 2 29 | return array.select(complex_dim, 0).numpy() + 1j * array.select(complex_dim, 1).numpy() 30 | else: 31 | return array.numpy() 32 | -------------------------------------------------------------------------------- /tfkbnufft/__init__.py: -------------------------------------------------------------------------------- 1 | """Package info""" 2 | 3 | __version__ = "develop" 4 | __author__ = 'Zaccharie Ramzi' 5 | __author_email__ = 'zaccharie.ramzi@inria.fr' 6 | __license__ = 'MIT' 7 | __homepage__ = 'https://github.com/zaccharieramzi/tfkbnufft' 8 | __docs__ = 'A robust, easy-to-deploy non-uniform Fast Fourier Transform in TensorFlow.' 9 | 10 | try: 11 | # This variable is injected in the __builtins__ by the build 12 | # process. 
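# (setup.py sets builtins.__TFKBNUFFT_SETUP__ = True before importing the package.)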
13 | __TFKBNUFFT_SETUP__ 14 | except NameError: 15 | __TFKBNUFFT_SETUP__ = False 16 | 17 | if __TFKBNUFFT_SETUP__: 18 | import sys 19 | sys.stderr.write('Partial import of tfkbnufft during the build process.\n') 20 | else: 21 | # from .kbinterp import KbInterpBack, KbInterpForw 22 | from .kbnufft import kbnufft_forward, kbnufft_adjoint 23 | # from .mrisensenufft import MriSenseNufft, AdjMriSenseNufft, ToepSenseNufft 24 | from .nufft import utils as nufft_utils 25 | 26 | __all__ = [ 27 | # 'KbInterpForw', 28 | # 'KbInterpBack', 29 | 'kbnufft_forward', 30 | 'kbnufft_adjoint', 31 | # 'MriSenseNufft', 32 | # 'AdjMriSenseNufft' 33 | ] 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Zaccharie Ramzi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /tfkbnufft/kbmodule.py: -------------------------------------------------------------------------------- 1 | # I don't necessarily want this to be a tf Layer. It has no weights that can be trained 2 | class KbModule: 3 | """Parent class for tfkbnufft modules. 4 | 5 | This class is a plain Python class: it has no trainable weights, so it does not 6 | need to be a TensorFlow layer. It is mostly used to have a central location for all __repr__ calls.
7 | """ 8 | 9 | def __repr__(self): 10 | filter_list = ['interpob', 'buffer', 'parameters', 'hook', 'module'] 11 | tablecheck = False 12 | out = '\n{}\n'.format(self.__class__.__name__) 13 | out = out + '----------------------------------------\n' 14 | for attr, value in self.__dict__.items(): 15 | if 'table' in attr: 16 | if not tablecheck: 17 | out = out + ' table: {} arrays, lengths: {}\n'.format( 18 | len(self.table), self.table_oversamp) 19 | tablecheck = True 20 | elif ('traj' in attr and attr != 'grad_traj') or 'scaling_coef' in attr: 21 | out = out + ' {}: {} {} array\n'.format( 22 | attr, value.shape, value.dtype) 23 | elif not any([item in attr for item in filter_list]): 24 | out = out + ' {}: {}\n'.format(attr, value) 25 | return out 26 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | # from https://github.com/grst/python-ci-versioneer/blob/master/.github/workflows/python-publish.yml 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [published] 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: "3.x" 19 | - name: Autobump version 20 | run: | 21 | # from refs/tags/v1.2.3 get 1.2.3 22 | VERSION=$(echo $GITHUB_REF | sed 's#.*/v##') 23 | PLACEHOLDER='__version__ = "develop"' 24 | VERSION_FILE='tfkbnufft/__init__.py' 25 | 26 | # ensure the placeholder is there. If grep doesn't find the placeholder 27 | # it exits with exit code 1 and github actions aborts the build. 
28 | grep "$PLACEHOLDER" "$VERSION_FILE" 29 | sed -i "s/$PLACEHOLDER/__version__ = \"${VERSION}\"/g" "$VERSION_FILE" 30 | shell: bash 31 | - name: Build and publish to testpypi 32 | env: 33 | TWINE_USERNAME: __token__ 34 | TWINE_PASSWORD: ${{ secrets.TEST_PYPI_API_TOKEN }} 35 | run: | 36 | sh build_package.sh 37 | - name: Publish to pypi 38 | env: 39 | TWINE_USERNAME: __token__ 40 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 41 | run: | 42 | twine upload dist/* -------------------------------------------------------------------------------- /tfkbnufft/tests/kbnufft_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import tensorflow as tf 4 | 5 | from tfkbnufft import kbnufft_forward, kbnufft_adjoint 6 | from tfkbnufft.kbnufft import KbNufftModule 7 | 8 | 9 | image_shape = (640, 400) 10 | nspokes = 15 11 | spokelength = image_shape[-1] * 2 12 | kspace_shape = spokelength * nspokes 13 | 14 | def ktraj_function(): 15 | # radial trajectory creation 16 | ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2)) 17 | kx = np.zeros(shape=(spokelength, nspokes)) 18 | ky = np.zeros(shape=(spokelength, nspokes)) 19 | ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength) 20 | for i in range(1, nspokes): 21 | kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1] 22 | ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1] 23 | 24 | ky = np.transpose(ky) 25 | kx = np.transpose(kx) 26 | 27 | traj = np.stack((ky.flatten(), kx.flatten()), axis=0) 28 | traj = tf.convert_to_tensor(traj)[None, ...] 29 | return traj 30 | 31 | @pytest.mark.parametrize('multiprocessing', [True, False]) 32 | def test_adjoint_gradient(multiprocessing): 33 | traj = ktraj_function() 34 | kspace = tf.zeros([1, 1, kspace_shape], dtype=tf.complex64) 35 | nufft_ob = KbNufftModule( 36 | im_size=(640, 400), 37 | grid_size=None, 38 | norm='ortho', 39 | ) 40 | backward_op = kbnufft_adjoint(nufft_ob._extract_nufft_interpob(), multiprocessing) 41 | with tf.GradientTape() as tape: 42 | tape.watch(kspace) 43 | res = backward_op(kspace, traj) 44 | grad = tape.gradient(res, kspace) 45 | tf_test = tf.test.TestCase() 46 | tf_test.assertEqual(grad.shape, kspace.shape) 47 | 48 | @pytest.mark.parametrize('multiprocessing', [True, False]) 49 | def test_forward_gradient(multiprocessing): 50 | traj = ktraj_function() 51 | image = tf.zeros([1, 1, *image_shape], dtype=tf.complex64) 52 | nufft_ob = KbNufftModule( 53 | im_size=(640, 400), 54 | grid_size=None, 55 | norm='ortho', 56 | ) 57 | forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob(), multiprocessing) 58 | with tf.GradientTape() as tape: 59 | tape.watch(image) 60 | res = forward_op(image, traj) 61 | grad = tape.gradient(res, image) 62 | tf_test = tf.test.TestCase() 63 | tf_test.assertEqual(grad.shape, image.shape) 64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | 3 | from setuptools import setup, find_packages 4 | 5 | # https://github.com/williamFalcon/pytorch-lightning/blob/master/setup.py 6 | 7 | try: 8 | import builtins 9 | except ImportError: 10 | import __builtin__ as builtins 11 | 12 | # The directory containing this file 13 | HERE = pathlib.Path(__file__).parent 14 | 15 | # The text of the README file 16 | README = (HERE / "README.md").read_text() 17 | 18 | builtins.__TFKBNUFFT_SETUP__ = True 19 | 20 | import tfkbnufft # noqa: E402 21 | 22 | 
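# Note: the import of tfkbnufft above only pulls in the package metadata (__version__, __author__, __docs__, ...): since __TFKBNUFFT_SETUP__ is set, tfkbnufft/__init__.py skips the TensorFlow-dependent imports during the build.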
23 | def load_requirements(path_dir=HERE, comment_char='#'): 24 | with open(path_dir / 'requirements.txt', 'r') as file: 25 | lines = [ln.strip() for ln in file.readlines()] 26 | reqs = [] 27 | for ln in lines: 28 | # filer all comments 29 | if comment_char in ln: 30 | ln = ln[:ln.index(comment_char)] 31 | if ln: # if requirement is not empty 32 | reqs.append(ln) 33 | return reqs 34 | 35 | 36 | # https://packaging.python.org/discussions/install-requires-vs-requirements 37 | setup( 38 | name='tfkbnufft', 39 | version=tfkbnufft.__version__, 40 | description=tfkbnufft.__docs__, 41 | author=tfkbnufft.__author__, 42 | author_email=tfkbnufft.__author_email__, 43 | url=tfkbnufft.__homepage__, 44 | download_url='https://github.com/zaccharieramzi/tfkbnufft', 45 | license=tfkbnufft.__license__, 46 | packages=find_packages(), 47 | 48 | long_description=open('README.md', encoding='utf-8').read(), 49 | long_description_content_type='text/markdown', 50 | include_package_data=True, 51 | zip_safe=False, 52 | 53 | keywords=['MRI', 'tensorflow'], 54 | python_requires='>=3.5', 55 | setup_requires=[], 56 | install_requires=load_requirements(HERE), 57 | 58 | classifiers=[ 59 | 'Environment :: Console', 60 | 'Natural Language :: English', 61 | # How mature is this project? Common values are 62 | # 3 - Alpha, 4 - Beta, 5 - Production/Stable 63 | 'Development Status :: 4 - Beta', 64 | # Indicate who your project is intended for 65 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 66 | # Pick your license as you wish 67 | "License :: OSI Approved :: MIT License", 68 | # Specify the Python versions you support here. In particular, ensure 69 | # that you indicate whether you support Python 2, Python 3 or both. 70 | 'Programming Language :: Python :: 3', 71 | 'Programming Language :: Python :: 3.7', 72 | ], 73 | ) 74 | -------------------------------------------------------------------------------- /tfkbnufft/tests/mri/dcomp_calc_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import torch 4 | 5 | from tfkbnufft import kbnufft_forward, kbnufft_adjoint 6 | from tfkbnufft.kbnufft import KbNufftModule 7 | from tfkbnufft.mri.dcomp_calc import calculate_radial_dcomp_tf, \ 8 | calculate_density_compensator 9 | from torchkbnufft import KbNufft, AdjKbNufft 10 | from torchkbnufft.mri.dcomp_calc import calculate_radial_dcomp_pytorch 11 | 12 | 13 | def setup(): 14 | spokelength = 400 15 | grid_size = (spokelength, spokelength) 16 | nspokes = 10 17 | 18 | ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2)) 19 | kx = np.zeros(shape=(spokelength, nspokes)) 20 | ky = np.zeros(shape=(spokelength, nspokes)) 21 | ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength) 22 | for i in range(1, nspokes): 23 | kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1] 24 | ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1] 25 | 26 | ky = np.transpose(ky) 27 | kx = np.transpose(kx) 28 | 29 | ktraj = np.stack((ky.flatten(), kx.flatten()), axis=0) 30 | im_size = (200, 200) 31 | nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho') 32 | torch_forward = KbNufft(im_size=im_size, grid_size=grid_size, norm='ortho') 33 | torch_backward = AdjKbNufft(im_size=im_size, grid_size=grid_size, norm='ortho') 34 | return ktraj, nufft_ob, torch_forward, torch_backward 35 | 36 | def test_calculate_radial_dcomp_tf(): 37 | ktraj, nufft_ob, torch_forward, torch_backward = setup() 38 | interpob = 
nufft_ob._extract_nufft_interpob() 39 | tf_nufftob_forw = kbnufft_forward(interpob) 40 | tf_nufftob_back = kbnufft_adjoint(interpob) 41 | tf_ktraj = tf.convert_to_tensor(ktraj) 42 | torch_ktraj = torch.tensor(ktraj).unsqueeze(0) 43 | tf_dcomp = calculate_radial_dcomp_tf(interpob, tf_nufftob_forw, tf_nufftob_back, tf_ktraj) 44 | torch_dcomp = calculate_radial_dcomp_pytorch(torch_forward, torch_backward, torch_ktraj) 45 | np.testing.assert_allclose( 46 | tf_dcomp.numpy(), 47 | torch_dcomp[0].numpy(), 48 | rtol=1e-5, 49 | atol=1e-5, 50 | ) 51 | 52 | def test_density_compensators_tf(): 53 | # This is a simple test that only ensures the code runs! 54 | # We still don't have a method to test whether the results are correct 55 | ktraj, nufft_ob, torch_forward, torch_backward = setup() 56 | interpob = nufft_ob._extract_nufft_interpob() 57 | tf_ktraj = tf.convert_to_tensor(ktraj) 58 | nufftob_back = kbnufft_adjoint(interpob) 59 | nufftob_forw = kbnufft_forward(interpob) 60 | tf_dcomp = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=False) 61 | tf_dcomp_no_grad = calculate_density_compensator(interpob, nufftob_forw, nufftob_back, tf_ktraj, zero_grad=True) 62 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TF KB-NUFFT 2 | 3 | [GitHub](https://github.com/zaccharieramzi/tfkbnufft) | [![Build Status](https://travis-ci.com/zaccharieramzi/tfkbnufft.svg?branch=master)](https://travis-ci.com/zaccharieramzi/tfkbnufft) 4 | 5 | 6 | Simple installation from pypi: 7 | ``` 8 | pip install tfkbnufft 9 | ``` 10 | 11 | ## About 12 | 13 | This package is a very early-stage and modest adaptation to TensorFlow of the [torchkbnufft](https://github.com/mmuckley/torchkbnufft) package written by Matthew Muckley for PyTorch. 14 | Please cite his work appropriately if you use this package. 15 | 16 | ## Computation speed 17 | 18 | The computation speeds are given in seconds, for a 256x256 image with a spokelength of 512 and 405 spokes. 19 | These numbers are not to be directly compared to those of [torchkbnufft](https://github.com/mmuckley/torchkbnufft#computation-speed), since the computation is not the same. 20 | They are just to give a sense of the time required for computation. 21 | 22 | | Operation | CPU | GPU | 23 | |---------------|--------|--------| 24 | | Forward NUFFT | 0.1676 | 0.0626 | 25 | | Adjoint NUFFT | 0.7005 | 0.0635 | 26 | 27 | To obtain these numbers for your machine, run the following commands, after installing this package: 28 | ``` 29 | pip install scikit-image Pillow 30 | python profile_tfkbnufft.py 31 | ``` 32 | 33 | These numbers were obtained with a Quadro P5000. 34 | 35 | 36 | ## Gradients 37 | 38 | ### w.r.t trajectory 39 | 40 | This is currently experimental and a work in progress; please be cautious. 41 | It is tested in CI against results from the NDFT, but the mathematical backing of some 42 | aspects of applying the chain rule is still being worked out. 43 | 44 | 45 | ## References 46 | 47 | 1. Fessler, J. A., & Sutton, B. P. (2003). Nonuniform fast Fourier transforms using min-max interpolation. *IEEE transactions on signal processing*, 51(2), 560-574. 48 | 49 | 2. Beatty, P. J., Nishimura, D. G., & Pauly, J. M. (2005). Rapid gridding reconstruction with a minimal oversampling ratio. *IEEE transactions on medical imaging*, 24(6), 799-808. 50 | 51 | 3. Feichtinger, H. G., Gröchenig, K., & Strohmer, T. (1995).
Efficient numerical methods in non-uniform sampling theory. Numerische Mathematik, 69(4), 423-440. 52 | 53 | ## Citation 54 | 55 | If you want to cite the package, you can use any of the following: 56 | 57 | ```bibtex 58 | @conference{muckley:20:tah, 59 | author = {M. J. Muckley and R. Stern and T. Murrell and F. Knoll}, 60 | title = {{TorchKbNufft}: A High-Level, Hardware-Agnostic Non-Uniform Fast Fourier Transform}, 61 | booktitle = {ISMRM Workshop on Data Sampling \& Image Reconstruction}, 62 | year = 2020 63 | } 64 | 65 | @misc{Muckley2019, 66 | author = {Muckley, M.J. et al.}, 67 | title = {Torch KB-NUFFT}, 68 | year = {2019}, 69 | publisher = {GitHub}, 70 | journal = {GitHub repository}, 71 | howpublished = {\url{https://github.com/mmuckley/torchkbnufft}} 72 | } 73 | ``` 74 | -------------------------------------------------------------------------------- /profile_tfkbnufft.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | from PIL import Image 5 | from skimage.data import camera 6 | import tensorflow as tf 7 | 8 | from tfkbnufft import kbnufft_forward, kbnufft_adjoint 9 | from tfkbnufft.kbnufft import KbNufftModule 10 | 11 | 12 | def profile_tfkbnufft( 13 | image, 14 | ktraj, 15 | im_size, 16 | device, 17 | ): 18 | if device == 'CPU': 19 | num_nuffts = 20 20 | else: 21 | num_nuffts = 50 22 | print(f'Using {device}') 23 | device_name = f'/{device}:0' 24 | with tf.device(device_name): 25 | image = tf.constant(image) 26 | if device == 'GPU': 27 | image = tf.cast(image, tf.complex64) 28 | ktraj = tf.constant(ktraj) 29 | nufft_ob = KbNufftModule(im_size=im_size, grid_size=None, norm='ortho') 30 | forward_op = kbnufft_forward(nufft_ob._extract_nufft_interpob()) 31 | adjoint_op = kbnufft_adjoint(nufft_ob._extract_nufft_interpob()) 32 | 33 | # warm-up computation 34 | for _ in range(2): 35 | y = forward_op(image, ktraj) 36 | 37 | start_time = time.perf_counter() 38 | for _ in range(num_nuffts): 39 | y = forward_op(image, ktraj) 40 | end_time = time.perf_counter() 41 | avg_time = (end_time-start_time) / num_nuffts 42 | print('forward average time: {}'.format(avg_time)) 43 | 44 | # warm-up computation 45 | for _ in range(2): 46 | x = adjoint_op(y, ktraj) 47 | 48 | # run the adjoint speed tests 49 | start_time = time.perf_counter() 50 | for _ in range(num_nuffts): 51 | x = adjoint_op(y, ktraj) 52 | end_time = time.perf_counter() 53 | avg_time = (end_time-start_time) / num_nuffts 54 | print('backward average time: {}'.format(avg_time)) 55 | 56 | 57 | def run_all_profiles(): 58 | print('running profiler...') 59 | spokelength = 512 60 | nspokes = 405 61 | 62 | print('problem size (radial trajectory, 2-factor oversampling):') 63 | print('number of spokes: {}'.format(nspokes)) 64 | print('spokelength: {}'.format(spokelength)) 65 | 66 | # create an example to run on 67 | image = np.array(Image.fromarray(camera()).resize((256, 256))) 68 | image = image.astype(np.complex) 69 | im_size = image.shape 70 | 71 | image = image[None, None, ...] 
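# tfkbnufft works on batched tensors: complex images are shaped (batch, coil, *im_size) and trajectories (batch, ndim, klength) with coordinates in radians, hence the singleton batch/coil axes added above and the leading batch axis added to ktraj below.
# A minimal usage sketch mirroring the calls wrapped in profile_tfkbnufft above (assuming `image` and `ktraj` are shaped as described):
#   nufft_ob = KbNufftModule(im_size=im_size, grid_size=None, norm='ortho')
#   kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(image, ktraj)
#   image_back = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj)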
72 | 73 | # create k-space trajectory 74 | ga = np.deg2rad(180 / ((1 + np.sqrt(5)) / 2)) 75 | kx = np.zeros(shape=(spokelength, nspokes)) 76 | ky = np.zeros(shape=(spokelength, nspokes)) 77 | ky[:, 0] = np.linspace(-np.pi, np.pi, spokelength) 78 | for i in range(1, nspokes): 79 | kx[:, i] = np.cos(ga) * kx[:, i - 1] - np.sin(ga) * ky[:, i - 1] 80 | ky[:, i] = np.sin(ga) * kx[:, i - 1] + np.cos(ga) * ky[:, i - 1] 81 | 82 | ky = np.transpose(ky) 83 | kx = np.transpose(kx) 84 | 85 | ktraj = np.stack((ky.flatten(), kx.flatten()), axis=0) 86 | 87 | ktraj = ktraj[None, ...] 88 | 89 | profile_tfkbnufft(image, ktraj, im_size, device='CPU') 90 | profile_tfkbnufft(image, ktraj, im_size, device='GPU') 91 | 92 | 93 | if __name__ == '__main__': 94 | run_all_profiles() 95 | -------------------------------------------------------------------------------- /tfkbnufft/tests/ndft_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import tensorflow as tf 3 | import numpy as np 4 | from tfkbnufft import kbnufft_forward, kbnufft_adjoint 5 | from tfkbnufft.kbnufft import KbNufftModule 6 | 7 | 8 | def get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False): 9 | r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] 10 | grid_r =tf.cast(tf.reshape(tf.meshgrid(*r ,indexing='ij'), (im_rank, tf.reduce_prod(im_size))), tf.float32) 11 | traj_grid = tf.cast(tf.matmul(tf.transpose(ktraj, [0, 2, 1]), tf.repeat(grid_r[None], ktraj.shape[0], axis=0)), tf.complex64) 12 | if do_ifft: 13 | A = tf.exp(1j * traj_grid) 14 | else: 15 | A = tf.exp(-1j * traj_grid) 16 | A = A / (np.sqrt(tf.reduce_prod(im_size)) * np.power(np.sqrt(2), im_rank)) 17 | return A 18 | 19 | 20 | @pytest.mark.parametrize('im_size', [(10, 10)]) 21 | @pytest.mark.parametrize('batch_size', [1, 2]) 22 | def test_adjoint_and_gradients(im_size, batch_size): 23 | tf.random.set_seed(0) 24 | grid_size = tuple(np.array(im_size)*2) 25 | im_rank = len(im_size) 26 | M = im_size[0] * 2**im_rank 27 | nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) 28 | # Generate Trajectory 29 | ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) 30 | # Have a random signal 31 | signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64)) 32 | kdata = tf.Variable(kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori)) 33 | Idata = tf.Variable(kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj_ori)) 34 | ktraj_noise = np.copy(ktraj_ori) 35 | ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) 36 | ktraj = tf.Variable(ktraj_noise) 37 | with tf.GradientTape(persistent=True) as g: 38 | I_nufft = kbnufft_adjoint(nufft_ob._extract_nufft_interpob())(kdata, ktraj) 39 | A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=True) 40 | I_ndft = tf.reshape(tf.transpose(tf.matmul(kdata, A), [0, 1, 2]), (batch_size, 1, *im_size)) 41 | loss_nufft = tf.math.reduce_mean(tf.abs(Idata - I_nufft)**2) 42 | loss_ndft = tf.math.reduce_mean(tf.abs(Idata - I_ndft)**2) 43 | 44 | tf_test = tf.test.TestCase() 45 | # Test if the NUFFT and NDFT operation is same 46 | tf_test.assertAllClose(I_nufft, I_ndft, atol=2e-3) 47 | 48 | # Test gradients with respect to kdata 49 | gradient_ndft_kdata = g.gradient(I_ndft, kdata)[0] 50 | gradient_nufft_kdata = g.gradient(I_nufft, kdata)[0] 51 | 
tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3) 52 | 53 | # Test gradients with respect to trajectory location 54 | gradient_ndft_traj = g.gradient(I_ndft, ktraj)[0] 55 | gradient_nufft_traj = g.gradient(I_nufft, ktraj)[0] 56 | tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3) 57 | 58 | # Test gradients in chain rule with respect to ktraj 59 | gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0] 60 | gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0] 61 | tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4) 62 | 63 | # This is gradient of NDFT from matrix, will help in debug 64 | # gradient_from_matrix = 2*np.pi*1j*tf.matmul(tf.cast(r, tf.complex64), tf.transpose(A))*kdata[0][0] 65 | 66 | 67 | @pytest.mark.parametrize('im_size', [(10, 10)]) 68 | @pytest.mark.parametrize('batch_size', [1, 2]) 69 | def test_forward_and_gradients(im_size, batch_size): 70 | tf.random.set_seed(0) 71 | grid_size = tuple(np.array(im_size)*2) 72 | im_rank = len(im_size) 73 | M = im_size[0] * 2**im_rank 74 | nufft_ob = KbNufftModule(im_size=im_size, grid_size=grid_size, norm='ortho', grad_traj=True) 75 | # Generate Trajectory 76 | ktraj_ori = tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) 77 | # Have a random signal 78 | signal = tf.Variable(tf.cast(tf.random.uniform((batch_size, 1, *im_size)), tf.complex64)) 79 | kdata = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj_ori) 80 | ktraj_noise = np.copy(ktraj_ori) 81 | ktraj_noise += 0.01 * tf.Variable(tf.random.uniform((batch_size, im_rank, M), minval=-1/2, maxval=1/2)*2*np.pi) 82 | ktraj = tf.Variable(ktraj_noise) 83 | with tf.GradientTape(persistent=True) as g: 84 | kdata_nufft = kbnufft_forward(nufft_ob._extract_nufft_interpob())(signal, ktraj) 85 | A = get_fourier_matrix(ktraj, im_size, im_rank, do_ifft=False) 86 | kdata_ndft = tf.matmul(tf.reshape(signal, (batch_size, 1, tf.reduce_prod(im_size))), tf.transpose(A, [0, 2, 1])) 87 | loss_nufft = tf.math.reduce_mean(tf.abs(kdata - kdata_nufft)**2) 88 | loss_ndft = tf.math.reduce_mean(tf.abs(kdata - kdata_ndft)**2) 89 | 90 | tf_test = tf.test.TestCase() 91 | # Test if the NUFFT and NDFT operation is same 92 | tf_test.assertAllClose(kdata_nufft, kdata_ndft, atol=2e-3) 93 | 94 | # Test gradients with respect to kdata 95 | gradient_ndft_kdata = g.gradient(kdata_ndft, signal)[0] 96 | gradient_nufft_kdata = g.gradient(kdata_nufft, signal)[0] 97 | tf_test.assertAllClose(gradient_ndft_kdata, gradient_nufft_kdata, atol=6e-3) 98 | 99 | # Test gradients with respect to trajectory location 100 | gradient_ndft_traj = g.gradient(kdata_ndft, ktraj)[0] 101 | gradient_nufft_traj = g.gradient(kdata_nufft, ktraj)[0] 102 | tf_test.assertAllClose(gradient_ndft_traj, gradient_nufft_traj, atol=6e-3) 103 | 104 | # Test gradients in chain rule with respect to ktraj 105 | gradient_ndft_loss = g.gradient(loss_ndft, ktraj)[0] 106 | gradient_nufft_loss = g.gradient(loss_nufft, ktraj)[0] 107 | tf_test.assertAllClose(gradient_ndft_loss, gradient_nufft_loss, atol=5e-4) 108 | # This is gradient of NDFT from matrix, will help in debug 109 | # gradient_ndft_matrix = -1j * tf.transpose(tf.matmul(A, tf.transpose(tf.cast(grid_r, tf.complex64) * signal[0][0]))) 110 | -------------------------------------------------------------------------------- /tfkbnufft/mri/dcomp_calc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from 
..nufft.interp_functions import kbinterp, adjkbinterp 4 | 5 | 6 | def calculate_radial_dcomp_tf(interpob, nufftob_forw, nufftob_back, ktraj, stacks=False): 7 | """Numerical density compensation estimation for a radial trajectory. 8 | 9 | Estimates the density compensation function numerically using a NUFFT 10 | operator (nufftob_forw and nufftob_back) and a k-space trajectory (ktraj). 11 | The function applies A'A1 (where A is the nufftob and 1 is a ones vector) 12 | and estimates the signal accumulation at the origin of k-space. It then 13 | returns a vector of density compensation values that are computed based on 14 | the distance from the k-space center and thresholded above the center 15 | density estimate. Then, a density-compensated image can be calculated by 16 | applying A'Wy, where W is a diagonal matrix with the density compensation 17 | values. 18 | 19 | This function uses a nufft hyper parameter dictionary, the associated nufft 20 | operators and k-space trajectory. 21 | 22 | Args: 23 | interpob (dict): the output of `KbNufftModule._extract_nufft_interpob` 24 | containing all the hyper-parameters for the nufft computation. 25 | nufftob_forw (fun) 26 | nufftob_back (fun) 27 | ktraj (tensor): The k-space trajectory in radians/voxel dimension (d, m). 28 | d is the number of spatial dimensions, and m is the length of the 29 | trajectory. 30 | stacks (bool): whether the trajectory is actually a stack of radials 31 | for 3D imaging rather than a pure radial trajectory. Not tested. 32 | Defaults to False. 33 | 34 | Returns: 35 | tensor: The density compensation coefficients for ktraj of size (m). 36 | """ 37 | # account for the missing FFT normalization when norm is not 'ortho' 38 | if not interpob['norm'] == 'ortho': 39 | norm_factor = tf.reduce_prod(interpob['grid_size']) 40 | else: 41 | norm_factor = 1 42 | 43 | # append 0s for batch, first coil 44 | im_size = interpob['im_size'] 45 | if len(im_size) != 3 and stacks: 46 | raise ValueError('`stacks` argument can only be used for 3d data') 47 | image_loc = tf.concat([ 48 | (0, 0,), 49 | im_size // 2, 50 | ], axis=0) 51 | 52 | 53 | # get the size of the test signal (add batch, coil) 54 | test_size = tf.concat([(1, 1,), im_size], axis=0) 55 | 56 | test_sig = tf.ones(test_size, dtype=tf.complex64) 57 | 58 | # get one dcomp for each batch 59 | # extract the signal amplitude increase from center of image 60 | query_point = tf.gather_nd( 61 | nufftob_back( 62 | nufftob_forw( 63 | test_sig, 64 | ktraj[None, :] 65 | ), 66 | ktraj[None, :] 67 | ), 68 | [image_loc], 69 | ) / norm_factor 70 | 71 | # use query point to get ramp intercept 72 | threshold_level = tf.cast(1 / query_point, ktraj.dtype) 73 | 74 | # compute the dcomp as a ramp thresholded at the estimated center density 75 | pi = tf.constant(np.pi, dtype=ktraj.dtype) 76 | if stacks: 77 | ktraj_thresh = ktraj[0:2] 78 | else: 79 | ktraj_thresh = ktraj 80 | dcomp = tf.maximum( 81 | tf.sqrt(tf.reduce_sum(ktraj_thresh ** 2, axis=0)) / pi, 82 | threshold_level, 83 | ) 84 | 85 | return dcomp 86 | 87 | 88 | def calculate_density_compensator(interpob, nufftob_forw, nufftob_back, ktraj, num_iterations=10, zero_grad=True): 89 | """Numerical density compensation estimation for any trajectory. 90 | 91 | Estimates the density compensation function numerically using a NUFFT 92 | interpolator operator and a k-space trajectory (ktraj). 93 | This function implements the iterative density compensation method of Pipe et al. 94 | 95 | This function uses a nufft hyper parameter dictionary, the associated nufft 96 | operators and k-space trajectory.
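At each iteration, the current weights w are updated as w <- w / |G(G^H(w))|, where G and G^H are the forward and adjoint gridding interpolators (kbinterp and adjkbinterp); the weights are finally rescaled so that the adjoint NUFFT of the compensated k-space data of a constant test image has unit mean magnitude.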
97 | 98 | Args: 99 | interpob (dict): the output of `KbNufftModule._extract_nufft_interpob` 100 | containing all the hyper-parameters for the nufft computation. 101 | nufftob_forw (fun) 102 | nufftob_back (fun) 103 | ktraj (tensor): The k-space trajectory in radians/voxel dimension (d, m). 104 | d is the number of spatial dimensions, and m is the length of the 105 | trajectory. 106 | num_iterations (int): default 10 107 | number of iterations 108 | zero_grad (bool): default True 109 | when true, assumes that the density compensator is a constant and 110 | returns zero gradients 111 | 112 | Returns: 113 | tensor: The density compensation coefficients for ktraj of size (m). 114 | """ 115 | def _calculate_density_compensator(ktraj): 116 | test_sig = tf.ones([1, 1, ktraj.shape[1]], dtype=tf.float32) 117 | for i in range(num_iterations): 118 | test_sig = test_sig / tf.math.abs(kbinterp( 119 | adjkbinterp(tf.cast(test_sig, tf.complex64), ktraj[None, :], interpob), 120 | ktraj[None, :], 121 | interpob 122 | )) 123 | im_size = interpob['im_size'] 124 | test_sig = tf.cast(test_sig, tf.complex64) 125 | test_size = tf.concat([(1, 1,), im_size], axis=0) 126 | test_im = tf.ones(test_size, dtype=tf.complex64) 127 | test_im_recon = nufftob_back( 128 | test_sig * nufftob_forw( 129 | test_im, 130 | ktraj[None, :] 131 | ), 132 | ktraj[None, :] 133 | ) 134 | ratio = tf.reduce_mean(tf.math.abs(test_im_recon)) 135 | test_sig = test_sig / tf.cast(ratio, test_sig.dtype) 136 | test_sig = test_sig[0, 0] 137 | return test_sig 138 | 139 | @tf.custom_gradient 140 | def _calculate_density_compensator_no_grad(ktraj): 141 | """Internal function that returns density compensators, but also returns 142 | no gradients""" 143 | dc_weights = _calculate_density_compensator(ktraj) 144 | def grad(dy): 145 | return None 146 | return dc_weights, grad 147 | 148 | if zero_grad: 149 | return _calculate_density_compensator_no_grad(ktraj) 150 | else: 151 | return _calculate_density_compensator(ktraj) 152 | -------------------------------------------------------------------------------- /tfkbnufft/tests/nufft/fft_functions_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from skimage.data import shepp_logan_phantom 4 | import tensorflow as tf 5 | import torch 6 | 7 | from tfkbnufft.nufft import fft_functions as tf_fft_functions 8 | from torchkbnufft.nufft import fft_functions as torch_fft_functions 9 | 10 | def _crop_center(img, cropx, cropy): 11 | y, x = img.shape 12 | startx = x//2-(cropx//2) 13 | starty = y//2-(cropy//2) 14 | return img[starty:starty+cropy, startx:startx+cropx] 15 | 16 | @pytest.mark.parametrize('norm', ['ortho', None]) 17 | @pytest.mark.parametrize('multiprocessing', [True, False]) 18 | def test_scale_and_fft_on_image_volume(norm, multiprocessing): 19 | # problem definition 20 | x = shepp_logan_phantom().astype(np.complex64) 21 | im_size = x.shape 22 | scaling_coeffs = np.random.randn(*im_size) + 1j * np.random.randn(*im_size) 23 | scaling_coeffs = scaling_coeffs.astype(np.complex64) 24 | grid_size = [2*im_dim for im_dim in im_size] 25 | # torch computations 26 | torch_x = np.stack((np.real(x), np.imag(x))) 27 | torch_x = torch.tensor(torch_x).unsqueeze(0).unsqueeze(0) 28 | torch_scaling_coeffs = torch.tensor( 29 | np.stack((np.real(scaling_coeffs), np.imag(scaling_coeffs))) 30 | ) 31 | 32 | res_torch = torch_fft_functions.scale_and_fft_on_image_volume( 33 | torch_x, 34 | torch_scaling_coeffs, 35 | torch.tensor(grid_size).float(), 
36 | torch.tensor(im_size), 37 | norm, 38 | 39 | ).numpy() 40 | res_torch = res_torch[:, :, 0] + 1j *res_torch[:, :, 1] 41 | # tf computations 42 | res_tf = tf_fft_functions.scale_and_fft_on_image_volume( 43 | tf.convert_to_tensor(x)[None, None, ...], 44 | tf.convert_to_tensor(scaling_coeffs), 45 | tf.convert_to_tensor(grid_size), 46 | tf.convert_to_tensor(im_size), 47 | norm, 48 | multiprocessing=multiprocessing, 49 | ).numpy() 50 | np.testing.assert_allclose(res_torch, res_tf, rtol=1e-4, atol=2*1e-2) 51 | 52 | @pytest.mark.parametrize('norm', ['ortho', None]) 53 | @pytest.mark.parametrize('multiprocessing', [True, False]) 54 | def test_ifft_and_scale_on_gridded_data(norm, multiprocessing): 55 | # problem definition 56 | x = shepp_logan_phantom().astype(np.complex64) 57 | grid_size = x.shape 58 | im_size = [im_dim//2 for im_dim in grid_size] 59 | scaling_coeffs = np.random.randn(*im_size) + 1j * np.random.randn(*im_size) 60 | scaling_coeffs = scaling_coeffs.astype(np.complex64) 61 | # torch computations 62 | torch_x = np.stack((np.real(x), np.imag(x))) 63 | torch_x = torch.tensor(torch_x).unsqueeze(0).unsqueeze(0) 64 | torch_scaling_coeffs = torch.tensor( 65 | np.stack((np.real(scaling_coeffs), np.imag(scaling_coeffs))) 66 | ) 67 | res_torch = torch_fft_functions.ifft_and_scale_on_gridded_data( 68 | torch_x, 69 | torch_scaling_coeffs, 70 | torch.tensor(grid_size).float(), 71 | torch.tensor(im_size), 72 | norm, 73 | ).numpy() 74 | res_torch = res_torch[:, :, 0] + 1j *res_torch[:, :, 1] 75 | # tf computations 76 | res_tf = tf_fft_functions.ifft_and_scale_on_gridded_data( 77 | tf.convert_to_tensor(x)[None, None, ...], 78 | tf.convert_to_tensor(scaling_coeffs), 79 | tf.convert_to_tensor(grid_size), 80 | tf.convert_to_tensor(im_size), 81 | norm, 82 | multiprocessing=multiprocessing, 83 | ).numpy() 84 | np.testing.assert_allclose(res_torch, res_tf, rtol=1e-4, atol=2) 85 | 86 | @pytest.mark.parametrize('norm', ['ortho', None]) 87 | @pytest.mark.parametrize('multiprocessing', [True, False]) 88 | def test_scale_and_fft_on_image_volume_3d(norm, multiprocessing): 89 | # problem definition 90 | x = shepp_logan_phantom().astype(np.complex64) 91 | x = _crop_center(x, 128, 128) 92 | x = x[None, ...] 
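# stack 128 copies of the 2D phantom along a new leading axis to obtain a 128x128x128 volume for the 3D case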
93 | x = np.tile(x, [128, 1, 1]) 94 | im_size = x.shape 95 | scaling_coeffs = np.random.randn(*im_size) + 1j * np.random.randn(*im_size) 96 | scaling_coeffs = scaling_coeffs.astype(np.complex64) 97 | grid_size = [2*im_dim for im_dim in im_size] 98 | # torch computations 99 | torch_x = np.stack((np.real(x), np.imag(x))) 100 | torch_x = torch.tensor(torch_x).unsqueeze(0).unsqueeze(0) 101 | torch_scaling_coeffs = torch.tensor( 102 | np.stack((np.real(scaling_coeffs), np.imag(scaling_coeffs))) 103 | ) 104 | 105 | res_torch = torch_fft_functions.scale_and_fft_on_image_volume( 106 | torch_x, 107 | torch_scaling_coeffs, 108 | torch.tensor(grid_size).float(), 109 | torch.tensor(im_size), 110 | norm, 111 | 112 | ).numpy() 113 | res_torch = res_torch[:, :, 0] + 1j *res_torch[:, :, 1] 114 | # tf computations 115 | res_tf = tf_fft_functions.scale_and_fft_on_image_volume( 116 | tf.convert_to_tensor(x)[None, None, ...], 117 | tf.convert_to_tensor(scaling_coeffs), 118 | tf.convert_to_tensor(grid_size), 119 | tf.convert_to_tensor(im_size), 120 | norm, 121 | im_rank=3, 122 | multiprocessing=multiprocessing, 123 | ).numpy() 124 | np.testing.assert_allclose(res_torch, res_tf, rtol=1e-4, atol=2*1e-2) 125 | 126 | @pytest.mark.parametrize('norm', ['ortho', None]) 127 | @pytest.mark.parametrize('multiprocessing', [True, False]) 128 | def test_ifft_and_scale_on_gridded_data_3d(norm, multiprocessing): 129 | # problem definition 130 | x = shepp_logan_phantom().astype(np.complex64) 131 | x = _crop_center(x, 128, 128) 132 | x = x[None, ...] 133 | x = np.tile(x, [128, 1, 1]) 134 | grid_size = x.shape 135 | im_size = [im_dim//2 for im_dim in grid_size] 136 | scaling_coeffs = np.random.randn(*im_size) + 1j * np.random.randn(*im_size) 137 | scaling_coeffs = scaling_coeffs.astype(np.complex64) 138 | # torch computations 139 | torch_x = np.stack((np.real(x), np.imag(x))) 140 | torch_x = torch.tensor(torch_x).unsqueeze(0).unsqueeze(0) 141 | torch_scaling_coeffs = torch.tensor( 142 | np.stack((np.real(scaling_coeffs), np.imag(scaling_coeffs))) 143 | ) 144 | res_torch = torch_fft_functions.ifft_and_scale_on_gridded_data( 145 | torch_x, 146 | torch_scaling_coeffs, 147 | torch.tensor(grid_size).float(), 148 | torch.tensor(im_size), 149 | norm, 150 | ).numpy() 151 | res_torch = res_torch[:, :, 0] + 1j *res_torch[:, :, 1] 152 | # tf computations 153 | res_tf = tf_fft_functions.ifft_and_scale_on_gridded_data( 154 | tf.convert_to_tensor(x)[None, None, ...], 155 | tf.convert_to_tensor(scaling_coeffs), 156 | tf.convert_to_tensor(grid_size), 157 | tf.convert_to_tensor(im_size), 158 | norm, 159 | im_rank=3, 160 | multiprocessing=multiprocessing, 161 | ).numpy() 162 | np.testing.assert_allclose(res_torch, res_tf, rtol=1e-4, atol=2) 163 | -------------------------------------------------------------------------------- /tfkbnufft/nufft/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import special 3 | from scipy.sparse import coo_matrix 4 | 5 | 6 | def build_spmatrix(om, numpoints, im_size, grid_size, n_shift, order, alpha): 7 | """Builds a sparse matrix with the interpolation coefficients. 8 | 9 | Args: 10 | om (ndarray): An array of coordinates to interpolate to. 11 | im_size (tuple): Size of base image. 12 | grid_size (tuple): Size of the grid to interpolate from. 13 | n_shift (tuple): Number of points to shift for fftshifts. 14 | order (tuple): Order of Kaiser-Bessel kernel. 15 | alpha (tuple): KB parameter. 
16 | 17 | Returns: 18 | coo_matrix: A scipy sparse interpolation matrix. 19 | """ 20 | spmat = -1 21 | 22 | ndims = om.shape[0] 23 | klength = om.shape[1] 24 | 25 | # calculate interpolation coefficients using kb kernel 26 | def interp_coeff(om, npts, grdsz, alpha, order): 27 | gam = 2 * np.pi / grdsz 28 | interp_dist = om / gam - np.floor(om / gam - npts / 2) 29 | Jvec = np.reshape(np.array(range(1, npts + 1)), (1, npts)) 30 | kern_in = -1 * Jvec + np.expand_dims(interp_dist, 1) 31 | 32 | cur_coeff = np.zeros(shape=kern_in.shape, dtype=np.complex64) 33 | indices = abs(kern_in) < npts / 2 34 | bess_arg = np.sqrt(1 - (kern_in[indices] / (npts / 2))**2) 35 | denom = special.iv(order, alpha) 36 | cur_coeff[indices] = special.iv(order, alpha * bess_arg) / denom 37 | cur_coeff = np.real(cur_coeff) 38 | 39 | return cur_coeff, kern_in 40 | 41 | full_coef = [] 42 | kd = [] 43 | for i in range(ndims): 44 | N = im_size[i] 45 | J = numpoints[i] 46 | K = grid_size[i] 47 | 48 | # get the interpolation coefficients 49 | coef, kern_in = interp_coeff(om[i, :], J, K, alpha[i], order[i]) 50 | 51 | gam = 2 * np.pi / K 52 | phase_scale = 1j * gam * (N - 1) / 2 53 | 54 | phase = np.exp(phase_scale * kern_in) 55 | full_coef.append(phase * coef) 56 | 57 | # nufft_offset 58 | koff = np.expand_dims(np.floor(om[i, :] / gam - J / 2), 1) 59 | Jvec = np.reshape(np.array(range(1, J + 1)), (1, J)) 60 | kd.append(np.mod(Jvec + koff, K) + 1) 61 | 62 | for i in range(len(kd)): 63 | kd[i] = (kd[i] - 1) * np.prod(grid_size[i + 1:]) 64 | 65 | # build the sparse matrix 66 | kk = kd[0] 67 | spmat_coef = full_coef[0] 68 | for i in range(1, ndims): 69 | Jprod = np.prod(numpoints[:i + 1]) 70 | # block outer sum 71 | kk = np.reshape( 72 | np.expand_dims(kk, 1) + np.expand_dims(kd[i], 2), 73 | (klength, Jprod) 74 | ) 75 | # block outer prod 76 | spmat_coef = np.reshape( 77 | np.expand_dims(spmat_coef, 1) * 78 | np.expand_dims(full_coef[i], 2), 79 | (klength, Jprod) 80 | ) 81 | 82 | # build in fftshift 83 | phase = np.exp(1j * np.dot(np.transpose(om), 84 | np.expand_dims(n_shift, 1))) 85 | spmat_coef = np.conj(spmat_coef) * phase 86 | 87 | # get coordinates in sparse matrix 88 | trajind = np.expand_dims(np.array(range(klength)), 1) 89 | trajind = np.repeat(trajind, np.prod(numpoints), axis=1) 90 | 91 | # build the sparse matrix 92 | spmat = coo_matrix(( 93 | spmat_coef.flatten(), 94 | (trajind.flatten(), kk.flatten())), 95 | shape=(klength, np.prod(grid_size)) 96 | ) 97 | 98 | return spmat 99 | 100 | 101 | def build_table(numpoints, table_oversamp, grid_size, im_size, ndims, order, alpha): 102 | """Builds an interpolation table. 103 | 104 | Args: 105 | numpoints (tuple): Number of points to use for interpolation in each 106 | dimension. Default is six points in each direction. 107 | table_oversamp (tuple): Table oversampling factor. 108 | grid_size (tuple): Size of the grid to interpolate from. 109 | im_size (tuple): Size of base image. 110 | ndims (int): Number of image dimensions. 111 | order (tuple): Order of Kaiser-Bessel kernel. 112 | alpha (tuple): KB parameter. 113 | 114 | Returns: 115 | list: A list of tables for each dimension. 116 | """ 117 | table = [] 118 | 119 | # build one table for each dimension 120 | for i in range(ndims): 121 | J = numpoints[i] 122 | L = table_oversamp[i] 123 | K = grid_size[i] 124 | N = im_size[i] 125 | 126 | # The following is a trick of Fessler. 127 | # It uses broadcasting semantics to quickly build the table. 
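# t1 holds L finely spaced fractional offsets spanning one grid cell; build_spmatrix is evaluated at the corresponding synthetic frequencies om1 and its interpolation coefficients are unrolled column by column into a flat 1D lookup table h (with a trailing zero), so that the Kaiser-Bessel kernel does not have to be re-evaluated at interpolation time.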
128 | t1 = J / 2 - 1 + np.array(range(L)) / L # [L] 129 | om1 = t1 * 2 * np.pi / K # gam 130 | s1 = build_spmatrix( 131 | np.expand_dims(om1, 0), 132 | numpoints=(J,), 133 | im_size=(N,), 134 | grid_size=(K,), 135 | n_shift=(0,), 136 | order=(order[i],), 137 | alpha=(alpha[i],) 138 | ) 139 | h = np.array(s1.getcol(J - 1).todense()) 140 | for col in range(J - 2, -1, -1): 141 | h = np.concatenate( 142 | (h, np.array(s1.getcol(col).todense())), axis=0) 143 | h = np.concatenate((h.flatten(), np.array([0]))) 144 | 145 | table.append(h) 146 | 147 | return table 148 | 149 | 150 | def kaiser_bessel_ft(om, npts, alpha, order, d): 151 | """Computes FT of KB function for scaling in image domain. 152 | 153 | Args: 154 | om (ndarray): An array of coordinates to interpolate to. 155 | npts (int): Number of points to use for interpolation in each 156 | dimension. 157 | order (int): Order of Kaiser-Bessel kernel. 158 | alpha (double or array of doubles): KB parameter. 159 | d (int): ## TODO: find what d is 160 | 161 | Returns: 162 | ndarray: The scaling coefficients. 163 | """ 164 | z = np.sqrt((2 * np.pi * (npts / 2) * om)**2 - alpha**2 + 0j) 165 | nu = d / 2 + order 166 | scaling_coef = (2 * np.pi)**(d / 2) * ((npts / 2)**d) * (alpha**order) / \ 167 | special.iv(order, alpha) * special.jv(nu, z) / (z**nu) 168 | scaling_coef = np.real(scaling_coef) 169 | 170 | return scaling_coef 171 | 172 | 173 | def compute_scaling_coefs(im_size, grid_size, numpoints, alpha, order): 174 | """Computes scaling coefficients for NUFFT operation. 175 | 176 | Args: 177 | im_size (tuple): Size of base image. 178 | grid_size (tuple): Size of the grid to interpolate from. 179 | numpoints (tuple): Number of points to use for interpolation in each 180 | dimension. Default is six points in each direction. 181 | alpha (tuple): KB parameter. 182 | order (tuple): Order of Kaiser-Bessel kernel. 183 | 184 | Returns: 185 | ndarray: The scaling coefficients. 186 | """ 187 | num_coefs = np.array(range(im_size[0])) - (im_size[0] - 1) / 2 188 | scaling_coef = 1 / kaiser_bessel_ft( 189 | num_coefs / grid_size[0], 190 | numpoints[0], 191 | alpha[0], 192 | order[0], 193 | 1 194 | ) 195 | if numpoints[0] == 1: 196 | scaling_coef = np.ones(scaling_coef.shape) 197 | for i in range(1, len(im_size)): 198 | indlist = np.array(range(im_size[i])) - (im_size[i] - 1) / 2 199 | scaling_coef = np.expand_dims(scaling_coef, axis=-1) 200 | tmp = 1 / kaiser_bessel_ft( 201 | indlist / grid_size[i], 202 | numpoints[i], 203 | alpha[i], 204 | order[i], 205 | 1 206 | ) 207 | 208 | for _ in range(i): 209 | tmp = tmp[np.newaxis] 210 | 211 | if numpoints[i] == 1: 212 | tmp = np.ones(tmp.shape) 213 | 214 | scaling_coef = scaling_coef * tmp 215 | 216 | return scaling_coef 217 | -------------------------------------------------------------------------------- /tfkbnufft/tests/nufft/interp_functions_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | from torchkbnufft.nufft import interp_functions as torch_interp_functions 4 | from torchkbnufft.nufft.utils import build_table 5 | 6 | from tfkbnufft.nufft import interp_functions as tf_interp_functions 7 | from tfkbnufft.utils.itertools import cartesian_product 8 | from ..utils import to_torch_arg, torch_to_numpy, to_tf_arg 9 | 10 | def setup(): 11 | grid_size = np.array([800, 800]) 12 | im_size = grid_size / 2 13 | n_samples = 324000 14 | # normalized frequency locations (i.e. 
between -grid_size/2 and grid_size/2) 15 | tm = np.random.uniform(-grid_size/2, grid_size/2, size=(n_samples, len(im_size))).T 16 | numpoints = np.array([6,] * len(im_size)) 17 | Jgen = np.array([np.random.uniform(numpoints).astype(int)]).T 18 | table_oversamp = 2**10 19 | L = np.array([table_oversamp for i in im_size]) 20 | table = build_table( 21 | numpoints=numpoints, 22 | table_oversamp=L, 23 | grid_size=grid_size, 24 | im_size=im_size, 25 | ndims=len(im_size), 26 | order=(0,) * len(im_size), 27 | alpha=tuple(np.array(2.34) * np.array(numpoints)), 28 | ) 29 | numpoints = numpoints.astype('float') 30 | L = L.astype('float') 31 | return tm, Jgen, table, numpoints, L, grid_size 32 | 33 | 34 | @pytest.mark.parametrize('conjcoef', [True, False]) 35 | def test_calc_coef_and_indices(conjcoef): 36 | tm, Jgen, table, numpoints, L, grid_size = setup() 37 | kofflist = 1 + np.floor(tm - numpoints[:, None] / 2).astype(int) 38 | Jval = Jgen[0] 39 | centers = np.floor(numpoints * L / 2).astype(int) 40 | args = [tm, kofflist, Jval, table, centers, L, grid_size.astype('int')] 41 | torch_args = [to_torch_arg(arg) for arg in args] + [conjcoef] 42 | res_torch_coefs, res_torch_ind = torch_interp_functions.calc_coef_and_indices(*torch_args) 43 | res_torch_coefs = torch_to_numpy(res_torch_coefs, complex_dim=0) 44 | tf_args = [to_tf_arg(arg) for arg in args] + [conjcoef] 45 | res_tf_coefs, res_tf_ind = tf_interp_functions.calc_coef_and_indices(*tf_args) 46 | np.testing.assert_equal(res_torch_ind.numpy(), res_tf_ind.numpy()) 47 | np.testing.assert_allclose(res_torch_coefs, res_tf_coefs.numpy()) 48 | 49 | @pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) 50 | @pytest.mark.parametrize('conjcoef', [True, False]) 51 | def test_run_interp(n_coil, conjcoef): 52 | tm, Jgen, table, numpoints, L, grid_size = setup() 53 | grid_size = grid_size.astype('int') 54 | griddat = np.stack([ 55 | np.reshape( 56 | np.random.randn(*grid_size) + 1j * np.random.randn(*grid_size), 57 | [-1], 58 | ) for i in range(n_coil) 59 | ]) 60 | params = { 61 | 'dims': grid_size, 62 | 'table': table, 63 | 'numpoints': numpoints, 64 | 'Jlist': Jgen, 65 | 'table_oversamp': L, 66 | 'conjcoef': conjcoef, 67 | } 68 | args = [griddat, tm, params] 69 | if not conjcoef: 70 | torch_args = [to_torch_arg(arg) for arg in args] 71 | # I need this because griddat is first n_coil then real/imag 72 | torch_args[0] = torch_args[0].permute(1, 0, 2) 73 | res_torch = torch_interp_functions.run_interp(*torch_args) 74 | # I need this because I create Jlist in a neater way for tensorflow 75 | params['Jlist'] = Jgen.T 76 | tf_args = [to_tf_arg(arg) for arg in args] 77 | res_tf = tf_interp_functions.run_interp(*tf_args) 78 | if not conjcoef: 79 | # Compare results with torch 80 | np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=1), res_tf.numpy()) 81 | 82 | @pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) 83 | def test_run_interp_back(n_coil): 84 | tm, Jgen, table, numpoints, L, grid_size = setup() 85 | num_samples = tm.shape[1] 86 | grid_size = grid_size.astype('int') 87 | kdat = np.stack([ 88 | np.random.randn(num_samples) + 1j * np.random.randn(num_samples) 89 | for i in range(n_coil) 90 | ]) 91 | params = { 92 | 'dims': grid_size, 93 | 'table': table, 94 | 'numpoints': numpoints, 95 | 'Jlist': Jgen, 96 | 'table_oversamp': L, 97 | } 98 | args = [kdat, tm, params] 99 | torch_args = [to_torch_arg(arg) for arg in args] 100 | # I need this because griddat is first n_coil then real/imag 101 | torch_args[0] = torch_args[0].permute(1, 0, 2) 102 | res_torch 
= torch_interp_functions.run_interp_back(*torch_args) 103 | # I need this because I create Jlist in a neater way for tensorflow 104 | params['Jlist'] = Jgen.T 105 | tf_args = [to_tf_arg(arg) for arg in args] 106 | res_tf = tf_interp_functions.run_interp_back(*tf_args) 107 | np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=1), res_tf.numpy()) 108 | 109 | @pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) 110 | def test_kbinterp(n_coil): 111 | tm, _, table, numpoints, L, grid_size = setup() 112 | grid_size = grid_size.astype('int') 113 | x = np.stack([ 114 | np.random.randn(*grid_size) + 1j * np.random.randn(*grid_size) 115 | for i in range(n_coil) 116 | ])[None, ...] # adding batch dimension 117 | tm = tm[None, ...] # adding batch dimension 118 | n_shift = np.array((grid_size//2) // 2).astype('float') 119 | Jgen = [] 120 | for i in range(2): 121 | # number of points to use for interpolation is numpoints 122 | Jgen.append(np.arange(numpoints[i])) 123 | Jgen = cartesian_product(Jgen) 124 | interpob = { 125 | 'grid_size': grid_size.astype('float'), 126 | 'table': table, 127 | 'numpoints': numpoints, 128 | 'Jlist': Jgen.astype('int64'), 129 | 'table_oversamp': L, 130 | 'n_shift': n_shift, 131 | } 132 | om = np.zeros_like(tm) 133 | for i in range(tm.shape[1]): 134 | gam = (2 * np.pi / grid_size[i]) 135 | om[:, i, :] = tm[:, i, :] * gam 136 | args = [x, om, interpob] 137 | torch_args = [to_torch_arg(arg) for arg in args] 138 | # I need this because griddat is first nbatch, n_coil then real/imag 139 | torch_args[0] = torch_args[0].permute(1, 2, 0, 3, 4) 140 | res_torch = torch_interp_functions.kbinterp(*torch_args) 141 | tf_args = [to_tf_arg(arg) for arg in args] 142 | res_tf = tf_interp_functions.kbinterp(*tf_args) 143 | # those tols seem like a lot, but for now it'll do 144 | np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=2), res_tf.numpy(), rtol=1e-1, atol=2*1e-2) 145 | 146 | @pytest.mark.parametrize('n_coil', [1, 2, 5, 16]) 147 | def test_adjkbinterp(n_coil): 148 | tm, _, table, numpoints, L, grid_size = setup() 149 | num_samples = tm.shape[1] 150 | grid_size = grid_size.astype('int') 151 | y = np.stack([ 152 | np.random.randn(num_samples) + 1j * np.random.randn(num_samples) 153 | for i in range(n_coil) 154 | ])[None, ...] # adding batch dimension 155 | tm = tm[None, ...] 
# adding batch dimension 156 | n_shift = np.array((grid_size//2) // 2).astype('float') 157 | Jgen = [] 158 | for i in range(2): 159 | # number of points to use for interpolation is numpoints 160 | Jgen.append(np.arange(numpoints[i])) 161 | Jgen = cartesian_product(Jgen) 162 | interpob = { 163 | 'grid_size': grid_size.astype('float'), 164 | 'table': table, 165 | 'numpoints': numpoints, 166 | 'Jlist': Jgen.astype('int64'), 167 | 'table_oversamp': L, 168 | 'n_shift': n_shift, 169 | } 170 | om = np.zeros_like(tm) 171 | for i in range(tm.shape[1]): 172 | gam = (2 * np.pi / grid_size[i]) 173 | om[:, i, :] = tm[:, i, :] * gam 174 | args = [y, om, interpob] 175 | torch_args = [to_torch_arg(arg) for arg in args] 176 | # I need this because griddat is first nbatch, n_coil then real/imag 177 | torch_args[0] = torch_args[0].permute(1, 2, 0, 3) 178 | res_torch = torch_interp_functions.adjkbinterp(*torch_args) 179 | tf_args = [to_tf_arg(arg) for arg in args] 180 | res_tf = tf_interp_functions.adjkbinterp(*tf_args) 181 | # those tols seem like a lot, but for now it'll do 182 | np.testing.assert_allclose(torch_to_numpy(res_torch, complex_dim=2), res_tf.numpy(), rtol=1e-1, atol=2*1e-2) 183 | -------------------------------------------------------------------------------- /tfkbnufft/nufft/fft_functions.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | 3 | import tensorflow as tf 4 | from tensorflow.python.ops.signal.fft_ops import ifft2d, fft2d, fft, ifft 5 | 6 | def tf_mp_ifft(kspace): 7 | k_shape_x = tf.shape(kspace)[-1] 8 | batched_kspace = tf.reshape(kspace, (-1, k_shape_x)) 9 | batched_image = tf.map_fn( 10 | ifft, 11 | batched_kspace, 12 | parallel_iterations=multiprocessing.cpu_count(), 13 | ) 14 | image = tf.reshape(batched_image, tf.shape(kspace)) 15 | return image 16 | 17 | def tf_mp_fft(kspace): 18 | k_shape_x = tf.shape(kspace)[-1] 19 | batched_kspace = tf.reshape(kspace, (-1, k_shape_x)) 20 | batched_image = tf.map_fn( 21 | fft, 22 | batched_kspace, 23 | parallel_iterations=multiprocessing.cpu_count(), 24 | ) 25 | image = tf.reshape(batched_image, tf.shape(kspace)) 26 | return image 27 | 28 | def tf_mp_ifft2d(kspace): 29 | k_shape_x = tf.shape(kspace)[-2] 30 | k_shape_y = tf.shape(kspace)[-1] 31 | batched_kspace = tf.reshape(kspace, (-1, k_shape_x, k_shape_y)) 32 | batched_image = tf.map_fn( 33 | ifft2d, 34 | batched_kspace, 35 | parallel_iterations=multiprocessing.cpu_count(), 36 | ) 37 | image = tf.reshape(batched_image, tf.shape(kspace)) 38 | return image 39 | 40 | def tf_mp_fft2d(image): 41 | shape_x = tf.shape(image)[-2] 42 | shape_y = tf.shape(image)[-1] 43 | batched_image = tf.reshape(image, (-1, shape_x, shape_y)) 44 | batched_kspace = tf.map_fn( 45 | fft2d, 46 | batched_image, 47 | parallel_iterations=multiprocessing.cpu_count(), 48 | ) 49 | kspace = tf.reshape(batched_kspace, tf.shape(image)) 50 | return kspace 51 | 52 | def tf_mp_ifft3d(kspace): 53 | image = tf_mp_fourier3d(kspace, trans_type='inv') 54 | return image 55 | 56 | def tf_mp_fft3d(image): 57 | kspace = tf_mp_fourier3d(image, trans_type='forw') 58 | return kspace 59 | 60 | def tf_mp_fourier3d(x, trans_type='inv'): 61 | fn_2d, fn_1d = (ifft2d, ifft) if trans_type == 'inv' else (fft2d, fft) 62 | n_slices = tf.shape(x)[0] 63 | n_coils = tf.shape(x)[1] 64 | shape_z = tf.shape(x)[-3] 65 | shape_x = tf.shape(x)[-2] 66 | shape_y = tf.shape(x)[-1] 67 | reshaped_x = tf.reshape(x, (-1, shape_x, shape_y)) 68 | batched_incomplete_y = tf.map_fn( 69 | fn_2d, 70 | 
reshaped_x, 71 | parallel_iterations=multiprocessing.cpu_count(), 72 | ) 73 | incomplete_y = tf.reshape(batched_incomplete_y, tf.shape(x)) 74 | incomplete_y_reshaped = tf.transpose(incomplete_y, [0, 1, 3, 4, 2]) 75 | batched_incomplete_y_reshaped = tf.reshape(incomplete_y_reshaped, (-1, shape_z)) 76 | batched_y = tf.map_fn( 77 | fn_1d, 78 | batched_incomplete_y_reshaped, 79 | parallel_iterations=multiprocessing.cpu_count(), 80 | ) 81 | y_reshaped = tf.reshape(batched_y, [n_slices, n_coils, shape_x, shape_y, shape_z]) 82 | y = tf.transpose(y_reshaped, [0, 1, 4, 2, 3]) 83 | return y 84 | 85 | # Generate a fourier dictionary to simplify its use below. 86 | # In the end we have the following list: 87 | # fourier_dict[do_ifft][multiprocessing][rank of image - 1] 88 | fourier_list = [ 89 | [ 90 | [ 91 | tf.signal.fft, 92 | tf.signal.fft2d, 93 | tf.signal.fft3d, 94 | ], 95 | [ 96 | tf_mp_fft, 97 | tf_mp_fft2d, 98 | tf_mp_fft3d, 99 | ] 100 | ], 101 | [ 102 | [ 103 | tf.signal.ifft, 104 | tf.signal.ifft2d, 105 | tf.signal.ifft3d, 106 | ], 107 | [ 108 | tf_mp_ifft, 109 | tf_mp_ifft2d, 110 | tf_mp_ifft3d, 111 | ] 112 | ] 113 | ] 114 | 115 | 116 | def scale_and_fft_on_image_volume(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False, 117 | do_ifft=False): 118 | """Applies the FFT and any relevant scaling factors to x. 119 | 120 | Args: 121 | x (tensor): The image to be FFT'd. 122 | scaling_coef (tensor): The NUFFT scaling coefficients to be multiplied 123 | prior to FFT. 124 | grid_size (tensor): The oversampled grid size. 125 | im_size (tensor): The image dimensions for x. 126 | norm (str): Type of normalization factor to use. If 'ortho', uses 127 | orthogonal FFT, otherwise, no normalization is applied. 128 | do_ifft (bool, optional, default False): When true, the IFFT is 129 | carried out on signal rather than FFT. This is needed for gradient. 130 | 131 | Returns: 132 | tensor: The oversampled FFT of x. 133 | """ 134 | # zero pad for oversampled nufft 135 | # we don't need permutations since the fft in fourier is done on the 136 | # innermost dimensions and we are handling complex tensors 137 | pad_sizes = [ 138 | (0, 0), # batch dimension 139 | (0, 0), # coil dimension 140 | ] + [ 141 | (0, grid_size[0] - im_size[0]), # nx 142 | ] 143 | if im_rank >= 2: 144 | pad_sizes += [(0, grid_size[1] - im_size[1])] 145 | if im_rank == 3: 146 | pad_sizes += [(0, grid_size[2] - im_size[2])] # nz 147 | scaling_coef = tf.cast(scaling_coef, x.dtype) 148 | scaling_coef = scaling_coef[None, None, ...] 149 | # multiply by scaling coefs 150 | if do_ifft: 151 | x = x * tf.math.conj(scaling_coef) 152 | else: 153 | x = x * scaling_coef 154 | 155 | # zero pad and fft 156 | x = tf.pad(x, pad_sizes) 157 | x = fourier_list[do_ifft][multiprocessing][im_rank - 1](x) 158 | if norm == 'ortho': 159 | scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) 160 | if do_ifft: 161 | x = x * tf.sqrt(scaling_factor) 162 | else: 163 | x = x / tf.sqrt(scaling_factor) 164 | 165 | return x 166 | 167 | 168 | def ifft_and_scale_on_gridded_data(x, scaling_coef, grid_size, im_size, norm, im_rank=2, multiprocessing=False): 169 | """Applies the iFFT and any relevant scaling factors to x. 170 | 171 | Args: 172 | x (tensor): The gridded data to be iFFT'd. 173 | scaling_coef (tensor): The NUFFT scaling coefficients to be multiplied 174 | after iFFT. 175 | grid_size (tensor): The oversampled grid size. 176 | im_size (tensor): The image dimensions for x. 177 | norm (str): Type of normalization factor to use. 
If 'ortho', uses 178 | orthogonal iFFT, otherwise, no normalization is applied. 179 | 180 | Returns: 181 | tensor: The iFFT of x. 182 | """ 183 | # we don't need permutations since the fft in fourier is done on the 184 | # innermost dimensions and we are handling complex tensors 185 | # do the inverse fft 186 | x = fourier_list[True][multiprocessing][im_rank - 1](x) 187 | im_size = tf.cast(im_size, tf.int32) 188 | # crop to output size 189 | x = x[:, :, :im_size[0]] 190 | if im_rank >=2: 191 | if im_rank == 3: 192 | x = x[..., :im_size[1], :im_size[2]] 193 | else: 194 | x = x[..., :im_size[1]] 195 | 196 | # scaling 197 | scaling_factor = tf.cast(tf.reduce_prod(grid_size), x.dtype) 198 | if norm == 'ortho': 199 | x = x * tf.sqrt(scaling_factor) 200 | else: 201 | x = x * scaling_factor 202 | 203 | # scaling coefficient multiply 204 | scaling_coef = tf.cast(scaling_coef, x.dtype) 205 | scaling_coef = scaling_coef[None, None, ...] 206 | 207 | x = x * tf.math.conj(scaling_coef) 208 | # this might be nice to try at some point more like an option rather 209 | # than a try except. 210 | # # try to broadcast multiply - batch over coil if not enough memory 211 | # raise_error = False 212 | # try: 213 | # x = x * tf.math.conj(scaling_coef) 214 | # except RuntimeError as e: 215 | # if 'out of memory' in str(e) and not raise_error: 216 | # torch.cuda.empty_cache() 217 | # for coilind in range(x.shape[1]): 218 | # x[:, coilind, ...] = conj_complex_mult( 219 | # x[:, coilind:coilind + 1, ...], scaling_coef, dim=2) 220 | # raise_error = True 221 | # else: 222 | # raise e 223 | # except BaseException: 224 | # raise e 225 | # 226 | return x 227 | 228 | # used for toep thing 229 | # def fft_filter(x, kern, norm=None): 230 | # """FFT-based filtering on a 2-size oversampled grid. 
231 | # """ 232 | # x = x.clone() 233 | # 234 | # im_size = torch.tensor(x.shape).to(torch.long)[3:] 235 | # grid_size = im_size * 2 236 | # 237 | # # set up n-dimensional zero pad 238 | # pad_sizes = [] 239 | # permute_dims = [0, 1] 240 | # inv_permute_dims = [0, 1, 2 + grid_size.shape[0]] 241 | # for i in range(grid_size.shape[0]): 242 | # pad_sizes.append(0) 243 | # pad_sizes.append(int(grid_size[-1 - i] - im_size[-1 - i])) 244 | # permute_dims.append(3 + i) 245 | # inv_permute_dims.append(2 + i) 246 | # permute_dims.append(2) 247 | # pad_sizes = tuple(pad_sizes) 248 | # permute_dims = tuple(permute_dims) 249 | # inv_permute_dims = tuple(inv_permute_dims) 250 | # 251 | # # zero pad and fft 252 | # x = F.pad(x, pad_sizes) 253 | # x = x.permute(permute_dims) 254 | # x = torch.fft(x, grid_size.numel()) 255 | # if norm == 'ortho': 256 | # x = x / torch.sqrt(torch.prod(grid_size.to(torch.double))) 257 | # x = x.permute(inv_permute_dims) 258 | # 259 | # # apply the filter 260 | # x = complex_mult(x, kern, dim=2) 261 | # 262 | # # inverse fft 263 | # x = x.permute(permute_dims) 264 | # x = torch.ifft(x, grid_size.numel()) 265 | # x = x.permute(inv_permute_dims) 266 | # 267 | # # crop to input size 268 | # crop_starts = tuple(np.array(x.shape).astype(np.int) * 0) 269 | # crop_ends = [x.shape[0], x.shape[1], x.shape[2]] 270 | # for dim in im_size: 271 | # crop_ends.append(int(dim)) 272 | # x = x[tuple(map(slice, crop_starts, crop_ends))] 273 | # 274 | # # scaling, assume user handled adjoint scaling with their kernel 275 | # if norm == 'ortho': 276 | # x = x / torch.sqrt(torch.prod(grid_size.to(torch.double))) 277 | # 278 | # return x 279 | -------------------------------------------------------------------------------- /tfkbnufft/nufft/interp_functions.py: -------------------------------------------------------------------------------- 1 | import math as m 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def calc_coef_and_indices(tm, kofflist, Jval, table, centers, L, dims, conjcoef=False): 8 | """Calculates interpolation coefficients and on-grid indices. 9 | 10 | Args: 11 | tm (tensor): normalized frequency locations. 12 | kofflist (tensor): A tensor with offset locations to first elements in 13 | list of nearest neighbords. 14 | Jval (tensor): A tuple-like tensor for how much to increment offsets. 15 | table (list): A list of tensors tabulating a Kaiser-Bessel 16 | interpolation kernel. 17 | centers (tensor): A tensor with the center locations of the table for 18 | each dimension. 19 | L (tensor): A tensor with the table size in each dimension. 20 | dims (tensor): A tensor with image dimensions. 21 | conjcoef (boolean, default=False): A boolean for whether to compute 22 | normal or complex conjugate interpolation coefficients 23 | (conjugate needed for adjoint). 24 | 25 | Returns: 26 | tuple: A tuple with interpolation coefficients and indices. 
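    Notes:
        For a spatial dimension d, the kernel weight is read from table[d] at
        index round((tm[d] - gridind[d]) * L[d]) + centers[d], with
        gridind = kofflist + Jval; the per-dimension weights are multiplied
        together to form coef. arr_ind is the row-major flattened grid index,
        accumulated as the sum over d of wrap(gridind[d]) * prod(dims[d + 1:]),
        where wrap adds dims[d] to negative indices.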
27 | """ 28 | # type values 29 | dtype = tm.dtype 30 | int_type = tf.int64 31 | 32 | # array shapes 33 | M = tf.shape(tm)[1] 34 | ndims = tm.shape[0] 35 | 36 | # indexing locations 37 | gridind = tf.cast(kofflist + Jval[:, None], dtype) 38 | distind = tf.cast(tf.round((tm - gridind) * L[:, None]), int_type) 39 | gridind = tf.cast(gridind, int_type) 40 | 41 | arr_ind = tf.zeros((M,), dtype=int_type) 42 | coef = tf.ones(M, dtype=table[0].dtype) 43 | 44 | for d in range(ndims): # spatial dimension 45 | sliced_table = tf.gather_nd(table[d], (distind[d, :] + centers[d])[:, None]) 46 | if conjcoef: 47 | coef = coef * tf.math.conj(sliced_table) 48 | else: 49 | coef = coef * sliced_table 50 | 51 | floormod = tf.where( 52 | tf.less(gridind[d, :], 0), 53 | gridind[d, :] + dims[d], 54 | gridind[d, :], 55 | ) 56 | arr_ind = arr_ind + floormod * tf.reduce_prod(dims[d + 1:]) 57 | 58 | return coef, arr_ind 59 | 60 | @tf.function(experimental_relax_shapes=True) 61 | def run_interp(griddat, tm, params): 62 | """Interpolates griddat to off-grid coordinates with input sparse matrices. 63 | 64 | Args: 65 | griddat (tensor): The on-grid frequency data. 66 | tm (tensor): Normalized frequency coordinates. 67 | params (dict): Dictionary with elements 'dims', 'table', 'numpoints', 68 | 'Jlist', and 'table_oversamp'. 69 | 70 | Returns: 71 | tensor: griddat interpolated to off-grid locations. 72 | """ 73 | # extract parameters 74 | dims = params['dims'] 75 | table = params['table'] 76 | numpoints = params['numpoints'] 77 | Jlist = params['Jlist'] 78 | L = params['table_oversamp'] 79 | L = tf.cast(L, tm.dtype) 80 | numpoints = tf.cast(numpoints, tm.dtype) 81 | 82 | # extract data types 83 | int_type = tf.int64 84 | 85 | # center of tables 86 | centers = tf.cast(tf.floor(numpoints * L / 2), int_type) 87 | 88 | # offset from k-space to first coef loc 89 | kofflist = 1 + \ 90 | tf.cast(tf.floor(tm - numpoints[:, None] / 2.0), int_type) 91 | 92 | # initialize output array 93 | kdat = tf.zeros( 94 | shape=(tf.shape(griddat)[0], tf.shape(tm)[-1]), 95 | dtype=griddat.dtype, 96 | ) 97 | 98 | # loop over offsets and take advantage of broadcasting 99 | for J in Jlist: 100 | coef, arr_ind = calc_coef_and_indices( 101 | tm, kofflist, J, table, centers, L, dims, conjcoef=params['conjcoef']) 102 | coef = tf.cast(coef, griddat.dtype) 103 | # I don't need to expand on coil dimension since I use tf gather and not 104 | # gather_nd 105 | # gather and multiply coefficients 106 | kdat += coef[None, ...] * tf.gather(griddat, arr_ind, axis=1) 107 | 108 | return kdat 109 | 110 | @tf.function(experimental_relax_shapes=True) 111 | def run_interp_back(kdat, tm, params): 112 | """Interpolates kdat to on-grid coordinates. 113 | 114 | Args: 115 | kdat (tensor): The off-grid frequency data. 116 | tm (tensor): Normalized frequency coordinates. 117 | params (dict): Dictionary with elements 'dims', 'table', 'numpoints', 118 | 'Jlist', and 'table_oversamp'. 119 | 120 | Returns: 121 | tensor: kdat interpolated to on-grid locations. 
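    Notes:
        kdat is expected to have shape (ncoil, klength); the returned gridded
        data has shape (ncoil, prod(dims)), i.e. the oversampled grid flattened
        in row-major order. Real and imaginary parts are scattered separately
        (see the in-line comment below) because complex scatter updates are not
        supported.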
122 | """ 123 | # extract parameters 124 | dims = params['dims'] 125 | table = params['table'] 126 | numpoints = params['numpoints'] 127 | Jlist = params['Jlist'] 128 | L = params['table_oversamp'] 129 | L = tf.cast(L, tm.dtype) 130 | numpoints = tf.cast(numpoints, tm.dtype) 131 | 132 | # extract data types 133 | int_type = tf.int64 134 | 135 | # center of tables 136 | centers = tf.cast(tf.floor(numpoints * L / 2), int_type) 137 | 138 | # offset from k-space to first coef loc 139 | kofflist = 1 + \ 140 | tf.cast(tf.floor(tm - numpoints[:, None] / 2.0), int_type) 141 | 142 | # initialize output array 143 | griddat = tf.zeros( 144 | shape=(tf.cast(tf.reduce_prod(dims), tf.int32), tf.shape(kdat)[0]), 145 | dtype=kdat.dtype, 146 | ) 147 | griddat_real = tf.math.real(griddat) 148 | griddat_imag = tf.math.imag(griddat) 149 | 150 | # loop over offsets and take advantage of numpy broadcasting 151 | for J in Jlist: 152 | coef, arr_ind = calc_coef_and_indices( 153 | tm, kofflist, J, table, centers, L, dims, conjcoef=True) 154 | coef = tf.cast(coef, kdat.dtype) 155 | updates = tf.transpose(coef[None, ...] * kdat) 156 | # TODO: change because the array of indexes was only in one dimension 157 | arr_ind = arr_ind[:, None] 158 | # a hack related to https://github.com/tensorflow/tensorflow/issues/40672 159 | # is to deal with real and imag parts separately 160 | griddat_real = tf.tensor_scatter_nd_add(griddat_real, arr_ind, tf.math.real(updates)) 161 | griddat_imag = tf.tensor_scatter_nd_add(griddat_imag, arr_ind, tf.math.imag(updates)) 162 | 163 | griddat = tf.transpose(tf.complex(griddat_real, griddat_imag)) 164 | return griddat 165 | 166 | @tf.function(experimental_relax_shapes=True) 167 | def kbinterp(x, om, interpob, conj=False): 168 | """Apply table interpolation. 169 | 170 | Inputs are assumed to be batch/chans x coil x image dims. 171 | Om should be nbatch x ndims x klength. 172 | 173 | Args: 174 | x (tensor): The oversampled DFT of the signal. 175 | om (tensor, optional): A custom set of k-space points to 176 | interpolate to in radians/voxel. 177 | interpob (dict): An interpolation object with 'table', 'n_shift', 178 | 'grid_size', 'numpoints', and 'table_oversamp' keys. 179 | conj (bool, optional, default False): Boolean value to check if 180 | conjugate value of interpolator coefficient must be used. 181 | This is need for gradients calculation 182 | 183 | Returns: 184 | tensor: The signal interpolated to off-grid locations. 
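    Example:
        A minimal sketch; nufft_ob, x and om are illustrative names, with
        nufft_ob = KbNufftModule(im_size=(256, 256)) so that the oversampled
        grid is (512, 512):

            interpob = nufft_ob._extract_nufft_interpob()
            # x: oversampled spectrum of the image, complex, (nbatch, ncoil, 512, 512)
            # om: k-space locations in radians/voxel, (nbatch, 2, klength)
            kdata = kbinterp(x, om, interpob)  # -> (nbatch, ncoil, klength)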
185 | """ 186 | # extract interpolation params 187 | n_shift = interpob['n_shift'] 188 | n_shift = tf.cast(n_shift, om.dtype) 189 | # TODO: refactor all of this with adjkbinterp 190 | grid_size = interpob['grid_size'] 191 | grid_size = tf.cast(grid_size, om.dtype) 192 | numpoints = interpob['numpoints'] 193 | 194 | # convert to normalized freq locs 195 | # the frequencies are originally in [-pi; pi] 196 | # we put them in [-grid_size/2; grid_size/2] 197 | pi = tf.constant(m.pi) 198 | tm = om * grid_size[None, :, None] / tf.cast(2 * pi, om.dtype) 199 | # build an iterator for going over all J values 200 | # set up params if not using sparse mats 201 | params = { 202 | 'dims': None, 203 | 'table': interpob['table'], 204 | 'numpoints': numpoints, 205 | 'Jlist': interpob['Jlist'], 206 | 'table_oversamp': interpob['table_oversamp'], 207 | } 208 | # run the table interpolator for each batch element 209 | # TODO: look into how to use tf.while_loop 210 | params['dims'] = tf.cast(tf.shape(x[0])[1:], 'int64') 211 | params['conjcoef'] = conj 212 | def _map_body(inputs): 213 | _x, _tm, _om = inputs 214 | y_not_shifted = run_interp(tf.reshape(_x, (tf.shape(_x)[0], -1)), _tm, params) 215 | shift = tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), y_not_shifted.dtype))[None, ...] 216 | if conj: 217 | y = y_not_shifted * tf.math.conj(shift) 218 | else: 219 | y = y_not_shifted * shift 220 | return y 221 | 222 | y = tf.map_fn(_map_body, [x, tm, om], dtype=x.dtype) 223 | 224 | return y 225 | 226 | @tf.function(experimental_relax_shapes=True) 227 | def adjkbinterp(y, om, interpob): 228 | """Apply table interpolation adjoint. 229 | 230 | Inputs are assumed to be batch/chans x coil x x kspace length. 231 | Om should be nbatch x ndims x klength. 232 | 233 | Args: 234 | y (tensor): The off-grid DFT of the signal. 235 | om (tensor, optional): A set of k-space points to 236 | interpolate from in radians/voxel. 237 | interpob (dict): An interpolation object with 'table', 'n_shift', 238 | 'grid_size', 'numpoints', and 'table_oversamp' keys. 239 | 240 | Returns: 241 | tensor: The signal interpolated to on-grid locations. 
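    Notes:
        y is expected to have shape (nbatch, ncoil, klength); the result is
        reshaped to batch x coil x grid_size on the oversampled grid, ready to
        be passed to ifft_and_scale_on_gridded_data.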
242 | """ 243 | n_shift = interpob['n_shift'] 244 | n_shift = tf.cast(n_shift, om.dtype) 245 | 246 | # TODO: refactor with kbinterp 247 | grid_size = interpob['grid_size'] 248 | grid_size = tf.cast(grid_size, om.dtype) 249 | numpoints = interpob['numpoints'] 250 | 251 | # convert to normalized freq locs 252 | # the frequencies are originally in [-pi; pi] 253 | # we put them in [-grid_size/2; grid_size/2] 254 | pi = tf.constant(m.pi) 255 | tm = om * grid_size[None, :, None] / tf.cast(2 * pi, om.dtype) 256 | # set up params if not using sparse mats 257 | params = { 258 | 'dims': None, 259 | 'table': interpob['table'], 260 | 'numpoints': numpoints, 261 | 'Jlist': interpob['Jlist'], 262 | 'table_oversamp': interpob['table_oversamp'], 263 | } 264 | 265 | # run the table interpolator for each batch element 266 | # TODO: look into how to use tf.while_loop 267 | params['dims'] = tf.cast(grid_size, 'int64') 268 | 269 | def _map_body(inputs): 270 | _y, _om, _tm = inputs 271 | y_shifted = _y * tf.math.conj(tf.exp(1j * tf.cast(tf.linalg.matvec(tf.transpose(_om), n_shift), _y.dtype))[None, ...]) 272 | x = run_interp_back(y_shifted, _tm, params) 273 | return x 274 | 275 | x = tf.map_fn(_map_body, [y, om, tm], dtype=y.dtype) 276 | 277 | bsize = tf.shape(y)[0] 278 | ncoil = tf.shape(y)[1] 279 | out_size = tf.concat([[bsize, ncoil], tf.cast(grid_size, 'int64')], 0) 280 | 281 | x = tf.reshape(x, out_size) 282 | 283 | return x 284 | -------------------------------------------------------------------------------- /tfkbnufft/kbnufft.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | # from .functional.kbnufft import AdjKbNufftFunction, KbNufftFunction 7 | # ToepNufftFunction) 8 | from .kbmodule import KbModule 9 | from .nufft.fft_functions import scale_and_fft_on_image_volume, ifft_and_scale_on_gridded_data 10 | from .nufft.interp_functions import kbinterp, adjkbinterp 11 | from .nufft.utils import build_spmatrix, build_table, compute_scaling_coefs 12 | from .utils.itertools import cartesian_product 13 | 14 | 15 | 16 | class KbNufftModule(KbModule): 17 | """Parent class for KbNufft classes. 18 | 19 | This implementation collects all init functions into one place. 20 | 21 | Args: 22 | im_size (int or tuple of ints): Size of base image. 23 | grid_size (int or tuple of ints, default=2*im_size): Size of the grid 24 | to interpolate from. 25 | numpoints (int or tuple of ints, default=6): Number of points to use 26 | for interpolation in each dimension. Default is six points in each 27 | direction. 28 | n_shift (int or tuple of ints, default=im_size//2): Number of points to 29 | shift for fftshifts. 30 | table_oversamp (int, default=2^10): Table oversampling factor. 31 | kbwidth (double, default=2.34): Kaiser-Bessel width parameter. 32 | order (double, default=0): Order of Kaiser-Bessel kernel. 33 | norm (str, default='None'): Normalization for FFT. Default uses no 34 | normalization. Use 'ortho' to use orthogonal FFTs and preserve 35 | energy. 36 | """ 37 | 38 | def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None, 39 | table_oversamp=2**10, kbwidth=2.34, order=0, norm='None', 40 | coil_broadcast=False, matadj=False, grad_traj=False): 41 | super(KbNufftModule, self).__init__() 42 | 43 | self.im_size = im_size 44 | self.im_rank = len(im_size) 45 | self.grad_traj = grad_traj 46 | if self.grad_traj: 47 | warnings.warn('The gradient w.r.t trajectory is Experimental and WIP. 
' 48 | 'Please use with caution') 49 | if grid_size is None: 50 | self.grid_size = tuple(np.array(self.im_size) * 2) 51 | else: 52 | self.grid_size = grid_size 53 | if n_shift is None: 54 | self.n_shift = tuple(np.array(self.im_size) // 2) 55 | else: 56 | self.n_shift = n_shift 57 | if isinstance(numpoints, int): 58 | self.numpoints = (numpoints,) * len(self.grid_size) 59 | else: 60 | self.numpoints = numpoints 61 | self.alpha = tuple(np.array(kbwidth) * np.array(self.numpoints)) 62 | if isinstance(order, int) or isinstance(order, float): 63 | self.order = (order,) * len(self.grid_size) 64 | else: 65 | self.order = order 66 | if isinstance(table_oversamp, float) or isinstance(table_oversamp, int): 67 | self.table_oversamp = (table_oversamp,) * len(self.grid_size) 68 | else: 69 | self.table_oversamp = table_oversamp 70 | 71 | # dimension checking 72 | assert len(self.grid_size) == len(self.im_size) 73 | assert len(self.n_shift) == len(self.im_size) 74 | assert len(self.numpoints) == len(self.im_size) 75 | assert len(self.alpha) == len(self.im_size) 76 | assert len(self.order) == len(self.im_size) 77 | assert len(self.table_oversamp) == len(self.im_size) 78 | 79 | table = build_table( 80 | numpoints=self.numpoints, 81 | table_oversamp=self.table_oversamp, 82 | grid_size=self.grid_size, 83 | im_size=self.im_size, 84 | ndims=len(self.im_size), 85 | order=self.order, 86 | alpha=self.alpha 87 | ) 88 | self.table = table 89 | assert len(self.table) == len(self.im_size) 90 | 91 | scaling_coef = compute_scaling_coefs( 92 | im_size=self.im_size, 93 | grid_size=self.grid_size, 94 | numpoints=self.numpoints, 95 | alpha=self.alpha, 96 | order=self.order 97 | ) 98 | self.scaling_coef = scaling_coef 99 | self.norm = norm 100 | self.coil_broadcast = coil_broadcast 101 | self.matadj = matadj 102 | 103 | if coil_broadcast == True: 104 | warnings.warn( 105 | 'coil_broadcast will be deprecated in a future release', 106 | DeprecationWarning) 107 | if matadj == True: 108 | warnings.warn( 109 | 'matadj will be deprecated in a future release', 110 | DeprecationWarning) 111 | 112 | self.scaling_coef_tensor = tf.convert_to_tensor(self.scaling_coef) 113 | self.table_tensors = [] 114 | for item in self.table: 115 | self.table_tensors.append(tf.convert_to_tensor(item)) 116 | # register buffer is not necessary in tf, you just have the variable in 117 | # your class, point. 118 | self.n_shift_tensor = tf.convert_to_tensor(np.array(self.n_shift, dtype=np.int64)) 119 | self.grid_size_tensor = tf.convert_to_tensor(np.array(self.grid_size, dtype=np.int64)) 120 | self.im_size_tensor = tf.convert_to_tensor(np.array(self.im_size, dtype=np.int64)) 121 | self.numpoints_tensor = tf.convert_to_tensor(np.array(self.numpoints, dtype=np.double)) 122 | self.table_oversamp_tensor = tf.convert_to_tensor(np.array(self.table_oversamp, dtype=np.double)) 123 | 124 | def _extract_nufft_interpob(self): 125 | """Extracts interpolation object from self. 126 | 127 | Returns: 128 | dict: An interpolation object for the NUFFT operation. 
129 | """ 130 | interpob = dict() 131 | interpob['scaling_coef'] = self.scaling_coef_tensor 132 | interpob['table'] = self.table_tensors 133 | interpob['n_shift'] = self.n_shift_tensor 134 | interpob['grid_size'] = self.grid_size_tensor 135 | interpob['im_size'] = self.im_size_tensor 136 | interpob['im_rank'] = self.im_rank 137 | interpob['numpoints'] = self.numpoints_tensor 138 | interpob['table_oversamp'] = self.table_oversamp_tensor 139 | interpob['norm'] = self.norm 140 | interpob['coil_broadcast'] = self.coil_broadcast 141 | interpob['matadj'] = self.matadj 142 | interpob['grad_traj'] = self.grad_traj 143 | Jgen = [] 144 | for i in range(self.im_rank): 145 | # number of points to use for interpolation is numpoints 146 | Jgen.append(np.arange(self.numpoints[i])) 147 | Jgen = cartesian_product(Jgen) 148 | interpob['Jlist'] = Jgen.astype('int64') 149 | 150 | return interpob 151 | 152 | def kbnufft_forward(interpob, multiprocessing=False): 153 | @tf.function(experimental_relax_shapes=True) 154 | @tf.custom_gradient 155 | def kbnufft_forward_for_interpob(x, om): 156 | """Apply FFT and interpolate from gridded data to scattered data. 157 | 158 | Inputs are assumed to be batch/chans x coil x image dims. 159 | Om should be nbatch x ndims x klength. 160 | 161 | Args: 162 | x (tensor): The original imagel. 163 | om (tensor, optional): A new set of omega coordinates at which to 164 | calculate the signal in radians/voxel. 165 | 166 | Returns: 167 | tensor: x computed at off-grid locations in om. 168 | """ 169 | # this is with registered gradient, I would like to try without 170 | # y = KbNufftFunction.apply(x, om, interpob, interp_mats) 171 | # extract interpolation params 172 | scaling_coef = interpob['scaling_coef'] 173 | grid_size = interpob['grid_size'] 174 | im_size = interpob['im_size'] 175 | norm = interpob['norm'] 176 | grad_traj = interpob['grad_traj'] 177 | im_rank = interpob.get('im_rank', 2) 178 | 179 | fft_x = scale_and_fft_on_image_volume( 180 | x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) 181 | 182 | y = kbinterp(fft_x, om, interpob) 183 | 184 | def grad(dy): 185 | # Gradients with respect to image 186 | grid_dy = adjkbinterp(dy, om, interpob) 187 | ifft_dy = ifft_and_scale_on_gridded_data( 188 | grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) 189 | if grad_traj: 190 | # Gradients with respect to trajectory locations 191 | r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] 192 | grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...] 
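                # Note on the next steps: for y(om) = sum_r x(r) exp(-i om . r), the
                # derivative w.r.t. a trajectory coordinate om_j is
                # -1j * sum_r r_j * x(r) * exp(-i om . r), i.e. the forward NUFFT of
                # r_j * x(r) scaled by -1j; grid_r holds those spatial coordinates r_j.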
193 | fft_dx_dom = scale_and_fft_on_image_volume( 194 | x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) 195 | # Do this when handling batches 196 | fft_dx_dom = tf.reshape(fft_dx_dom, shape=(-1, 1, *fft_dx_dom.shape[2:])) 197 | nufft_dx_dom = kbinterp(fft_dx_dom, tf.repeat(om, im_rank, axis=0), interpob) 198 | # Unbatch back the data 199 | nufft_dx_dom = tf.reshape(nufft_dx_dom, shape=(-1, im_rank, *nufft_dx_dom.shape[2:])) 200 | dy_dom = tf.cast(-1j * tf.math.conj(dy) * nufft_dx_dom, om.dtype) 201 | # dy_dom = tf.math.reduce_sum(dy_dom, axis=1)[None, :] 202 | else: 203 | dy_dom = None 204 | return ifft_dy, dy_dom 205 | 206 | return y, grad 207 | return kbnufft_forward_for_interpob 208 | 209 | def kbnufft_adjoint(interpob, multiprocessing=False): 210 | @tf.function(experimental_relax_shapes=True) 211 | @tf.custom_gradient 212 | def kbnufft_adjoint_for_interpob(y, om): 213 | """Interpolate from scattered data to gridded data and then iFFT. 214 | 215 | Inputs are assumed to be batch/chans x coil x kspace 216 | length. Om should be nbatch x ndims x klength. 217 | 218 | Args: 219 | y (tensor): The off-grid signal. 220 | om (tensor, optional): The off-grid coordinates in radians/voxel. 221 | 222 | Returns: 223 | tensor: The image after adjoint NUFFT. 224 | """ 225 | grid_y = adjkbinterp(y, om, interpob) 226 | scaling_coef = interpob['scaling_coef'] 227 | grid_size = interpob['grid_size'] 228 | im_size = interpob['im_size'] 229 | norm = interpob['norm'] 230 | grad_traj = interpob['grad_traj'] 231 | im_rank = interpob.get('im_rank', 2) 232 | ifft_y = ifft_and_scale_on_gridded_data( 233 | grid_y, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, multiprocessing=multiprocessing) 234 | 235 | def grad(dx): 236 | # Gradients with respect to off grid signal 237 | fft_dx = scale_and_fft_on_image_volume( 238 | dx, scaling_coef, grid_size, im_size, norm, im_rank=im_rank) 239 | dx_dy = kbinterp(fft_dx, om, interpob) 240 | if grad_traj: 241 | # Gradients with respect to trajectory locations 242 | r = [tf.linspace(-im_size[i]/2, im_size[i]/2-1, im_size[i]) for i in range(im_rank)] 243 | # This wont work for multicoil case as the dimension for dx is `batch_size x coil x Nx x Ny` 244 | grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...] 245 | ifft_dxr = scale_and_fft_on_image_volume( 246 | tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank, do_ifft=True) 247 | # Do this when handling batches 248 | ifft_dxr = tf.reshape(ifft_dxr, shape=(-1, 1, *ifft_dxr.shape[2:])) 249 | inufft_dxr = kbinterp(ifft_dxr, tf.repeat(om, im_rank, axis=0), interpob, conj=True) 250 | # Unbatch back the data 251 | inufft_dxr = tf.reshape(inufft_dxr, shape=(-1, im_rank, *inufft_dxr.shape[2:])) 252 | dx_dom = tf.cast(1j * y * inufft_dxr, om.dtype) 253 | # dx_dom = tf.math.reduce_sum(dx_dom, axis=1)[None, :] 254 | else: 255 | dx_dom = None 256 | return dx_dy, dx_dom 257 | return ifft_y, grad 258 | return kbnufft_adjoint_for_interpob 259 | 260 | 261 | # class ToepNufft(KbModule): 262 | # """Forward/backward NUFFT with Toeplitz embedding. 263 | # 264 | # This module applies Tx, where T is a matrix such that T = A'A, where A is 265 | # a NUFFT matrix. Using Toeplitz embedding, this module computes the A'A 266 | # operation without interpolations, which is extremely fast. 267 | # 268 | # The module is intended to be used in combination with an fft kernel 269 | # computed to be the frequency response of an embedded Toeplitz matrix. 
The 270 | # kernel is calculated offline via 271 | # 272 | # torchkbnufft.nufft.toep_functions.calc_toep_kernel 273 | # 274 | # The corresponding kernel is then passed to this module in its forward 275 | # forward operation, which applies a (zero-padded) fft filter using the 276 | # kernel. 277 | # """ 278 | # 279 | # def __init__(self): 280 | # super(ToepNufft, self).__init__() 281 | # 282 | # def forward(self, x, kern, norm=None): 283 | # """Toeplitz NUFFT forward function. 284 | # 285 | # Args: 286 | # x (tensor): The image (or images) to apply the forward/backward 287 | # Toeplitz-embedded NUFFT to. 288 | # kern (tensor): The filter response taking into account Toeplitz 289 | # embedding. 290 | # norm (str, default=None): Use 'ortho' if kern was designed to use 291 | # orthogonal FFTs. 292 | # 293 | # Returns: 294 | # tensor: x after applying the Toeplitz NUFFT. 295 | # """ 296 | # x = ToepNufftFunction.apply(x, kern, norm) 297 | # 298 | # return x 299 | --------------------------------------------------------------------------------