├── .coveragerc ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ └── python-publish.yml ├── .gitignore ├── .readthedocs.yml ├── .travis.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.rst ├── LICENSE ├── README.rst ├── codecov.yml ├── conda.recipe ├── bld.bat ├── build.sh ├── conda_build_config.yaml └── meta.yaml ├── docs ├── Makefile ├── conf.py ├── core.rst ├── core_alg.rst ├── core_app.rst ├── core_linop.rst ├── core_prox.rst ├── figures │ ├── architecture.png │ ├── device.png │ ├── multiprocess_desired.png │ └── multiprocess_mpi.png ├── guide_basic.rst ├── guide_iter.rst ├── guide_multi_devices.rst ├── index.rst ├── make.bat ├── mri.rst ├── mri_app.rst ├── mri_linop.rst ├── mri_rf.rst ├── plot.rst └── requirements.txt ├── pyproject.toml ├── requirements.txt ├── run_tests.sh ├── setup.cfg ├── setup.py ├── sigpy ├── __init__.py ├── alg.py ├── app.py ├── backend.py ├── block.py ├── config.py ├── conv.py ├── fourier.py ├── interp.py ├── linop.py ├── mri │ ├── __init__.py │ ├── app.py │ ├── dcf.py │ ├── linop.py │ ├── precond.py │ ├── rf │ │ ├── __init__.py │ │ ├── adiabatic.py │ │ ├── b1sel.py │ │ ├── io.py │ │ ├── linop.py │ │ ├── multiband.py │ │ ├── optcont.py │ │ ├── ptx.py │ │ ├── shim.py │ │ ├── sim.py │ │ ├── slr.py │ │ ├── trajgrad.py │ │ └── util.py │ ├── samp.py │ ├── sim.py │ └── util.py ├── plot.py ├── prox.py ├── pytorch.py ├── sim.py ├── thresh.py ├── util.py ├── version.py └── wavelet.py └── tests ├── __init__.py ├── learn ├── test_app.py └── test_util.py ├── mri ├── __init__.py ├── rf │ ├── __init__.py │ ├── test_adiabatic.py │ ├── test_b1sel.py │ ├── test_linop.py │ ├── test_multiband.py │ ├── test_ptx.py │ ├── test_sim.py │ ├── test_slr.py │ └── test_trajgrad.py ├── test_app.py ├── test_dcf.py ├── test_linop.py ├── test_precond.py └── test_samp.py ├── test_alg.py ├── test_app.py ├── test_block.py ├── test_conv.py ├── test_fourier.py ├── test_interp.py ├── test_linop.py ├── test_prox.py ├── test_pytorch.py ├── test_thresh.py ├── test_util.py ├── test_version.py └── test_wavelet.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | sigpy/backend.py 4 | *test* -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Browser [e.g. chrome, safari] 29 | - Version [e.g. 22] 30 | 31 | **Smartphone (please complete the following information):** 32 | - Device: [e.g. iPhone6] 33 | - OS: [e.g. iOS8.1] 34 | - Browser [e.g. stock browser, safari] 35 | - Version [e.g. 22] 36 | 37 | **Additional context** 38 | Add any other context about the problem here. 
39 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflows will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | name: Upload Python Package 5 | 6 | on: 7 | release: 8 | types: [created] 9 | 10 | jobs: 11 | deploy: 12 | 13 | runs-on: ubuntu-latest 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.x' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine 25 | - name: Build and publish 26 | env: 27 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 28 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 29 | run: | 30 | python setup.py sdist bdist_wheel 31 | twine upload dist/* 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | dist 3 | __pycache__ 4 | *.egg-info 5 | _build 6 | #* 7 | *# 8 | .DS_Store 9 | GPATH 10 | GRTAGS 11 | GTAGS 12 | .cache 13 | *.npy 14 | *.ra 15 | .idea* 16 | 17 | # sphinx build folder 18 | _build 19 | 20 | # Compiled source # 21 | ################### 22 | *.com 23 | *.class 24 | *.dll 25 | *.exe 26 | *.o 27 | *.so 28 | 29 | # Packages # 30 | ############ 31 | # it's better to unpack these files and commit the raw source 32 | # git has its own built in compression methods 33 | *.7z 34 | *.dmg 35 | *.gz 36 | *.iso 37 | *.jar 38 | *.rar 39 | *.tar 40 | *.zip 41 | 42 | # Logs and databases # 43 | ###################### 44 | *.log 45 | *.sql 46 | *.sqlite 47 | 48 | # OS generated files # 49 | ###################### 50 | .DS_Store? 51 | ehthumbs.db 52 | Icon? 
53 | Thumbs.db 54 | 55 | # Editor backup files # 56 | ####################### 57 | *~ 58 | 59 | .pytest_cache 60 | *.tar.bz2 61 | 62 | .coverage 63 | docs/generated 64 | docs/_static 65 | docs/_autosummary -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | python: 4 | version: 3.7 5 | install: 6 | - requirements: docs/requirements.txt 7 | 8 | sphinx: 9 | builder: html 10 | fail_on_warning: true 11 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.6" 4 | - "3.7" 5 | - "3.8" 6 | 7 | install: 8 | - pip install --upgrade -r requirements.txt 9 | - pip install codecov flake8 sphinx sphinx_rtd_theme matplotlib 10 | 11 | script: 12 | - bash run_tests.sh 13 | 14 | after_success: 15 | - codecov 16 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. 
Examples of
50 | representing a project or community include using an official project e-mail
51 | address, posting via an official social media account, or acting as an appointed
52 | representative at an online or offline event. Representation of a project may be
53 | further defined and clarified by project maintainers.
54 | 
55 | ## Enforcement
56 | 
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the project team at frankongh@gmail.com. All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 | 
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 | 
68 | ## Attribution
69 | 
70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
72 | 
73 | [homepage]: https://www.contributor-covenant.org
74 | 
75 | For answers to common questions about this code of conduct, see
76 | https://www.contributor-covenant.org/faq
77 | 
-------------------------------------------------------------------------------- /CONTRIBUTING.rst: --------------------------------------------------------------------------------
1 | Contribution Guide
2 | ------------------
3 | 
4 | Thank you for considering contributing to SigPy.
5 | To contribute to the source code, you can submit a `pull request `_.
6 | 
7 | To ensure your pull request can be merged quickly, you should check the following three items:
8 | 
9 | - `Coding Style`_
10 | - `Unit Testing`_
11 | - `Documentation`_
12 | 
13 | A simple way to check is to run ``run_tests.sh``. You will need to install::
14 | 
15 |     $ pip install codecov flake8 sphinx sphinx_rtd_theme matplotlib
16 | 
17 | Any new features (new functions, Linops, Apps...), bug fixes, and improved documentation are welcome.
18 | We only ask you to avoid replicating existing features in NumPy, CuPy, and SigPy.
19 | A general rule is that if a feature can already be implemented in one line,
20 | then it is probably not worth defining as a new function.
21 | 
22 | Coding Style
23 | ============
24 | 
25 | SigPy adopts the `Google code style `_.
26 | In particular, docstrings should use ``Args`` and ``Returns`` as described in the `Comments and Docstrings section `_.
27 | 
28 | You can use ``autopep8`` and ``flake8`` to check your code.
29 | 
30 | Unit Testing
31 | ============
32 | 
33 | You should write test cases for each function you commit. The unit tests are under the directory ``tests/``.
34 | The file hierarchy should follow the ``sigpy/`` directory, but each file should be prefixed with ``test_``.
35 | 
36 | SigPy uses the ``unittest`` package for testing. You can run tests by doing::
37 | 
38 |     $ python -m unittest
39 | 
40 | Documentation
41 | =============
42 | 
43 | Each new feature should be documented. The documentation is stored under the directory ``docs/``.
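For example, a new public function is typically added to the ``autosummary`` list of the relevant file under ``docs/`` (for instance ``docs/core.rst`` for core functions); ``sigpy.my_new_function`` below is only a placeholder name::

    .. autosummary::
        :toctree: generated
        :nosignatures:

        sigpy.my_new_function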
44 | 45 | You can build the docmentation in HTML format locally using Sphinx:: 46 | 47 | $ cd docs 48 | $ make html 49 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2016, Frank Ong. 2 | Copyright (c) 2016, The Regents of the University of California. 3 | All rights reserved. 4 | 5 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 6 | 7 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 8 | 9 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 10 | 11 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 12 | 13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | SigPy 2 | ===== 3 | 4 | .. image:: https://img.shields.io/badge/License-BSD%203--Clause-blue.svg 5 | :target: https://opensource.org/licenses/BSD-3-Clause 6 | 7 | .. image:: https://travis-ci.com/mikgroup/sigpy.svg?branch=master 8 | :target: https://travis-ci.com/mikgroup/sigpy 9 | 10 | .. image:: https://readthedocs.org/projects/sigpy/badge/?version=latest 11 | :target: https://sigpy.readthedocs.io/en/latest/?badge=latest 12 | :alt: Documentation Status 13 | 14 | .. image:: https://codecov.io/gh/mikgroup/sigpy/branch/master/graph/badge.svg 15 | :target: https://codecov.io/gh/mikgroup/sigpy 16 | 17 | .. image:: https://zenodo.org/badge/139635485.svg 18 | :target: https://zenodo.org/badge/latestdoi/139635485 19 | 20 | 21 | `Source Code `_ | `Documentation `_ | `MRI Recon Tutorial `_ | `MRI Pulse Design Tutorial `_ 22 | 23 | SigPy is a package for signal processing, with emphasis on iterative methods. It is built to operate directly on NumPy arrays on CPU and CuPy arrays on GPU. SigPy also provides several domain-specific submodules: ``sigpy.plot`` for multi-dimensional array plotting, ``sigpy.mri`` for MRI reconstruction, and ``sigpy.mri.rf`` for MRI pulse design. 24 | 25 | Installation 26 | ------------ 27 | 28 | SigPy requires Python version >= 3.5. The core module depends on ``numba``, ``numpy``, ``PyWavelets``, ``scipy``, and ``tqdm``. 29 | 30 | Additional features can be unlocked by installing the appropriate packages. 
To enable the plotting functions, you will need to install ``matplotlib``. To enable CUDA support, you will need to install ``cupy``. And to enable MPI support, you will need to install ``mpi4py``. 31 | 32 | Via ``conda`` 33 | ************* 34 | 35 | We recommend installing SigPy through ``conda``:: 36 | 37 | conda install -c frankong sigpy 38 | # (optional for plot support) conda install matplotlib 39 | # (optional for CUDA support) conda install cupy 40 | # (optional for MPI support) conda install mpi4py 41 | 42 | Via ``pip`` 43 | *********** 44 | 45 | SigPy can also be installed through ``pip``:: 46 | 47 | pip install sigpy 48 | # (optional for plot support) pip install matplotlib 49 | # (optional for CUDA support) pip install cupy 50 | # (optional for MPI support) pip install mpi4py 51 | 52 | Installation for Developers 53 | *************************** 54 | 55 | If you want to contribute to the SigPy source code, we recommend you install it with ``pip`` in editable mode:: 56 | 57 | cd /path/to/sigpy 58 | pip install -e . 59 | 60 | To run tests and contribute, we recommend installing the following packages:: 61 | 62 | pip install coverage ruff sphinx sphinx_rtd_theme black isort 63 | 64 | and run the script ``run_tests.sh``. 65 | 66 | Features 67 | -------- 68 | 69 | CPU/GPU Signal Processing Functions 70 | *********************************** 71 | SigPy provides signal processing functions with a unified CPU/GPU interface. For example, the same code can perform a CPU or GPU convolution on the input array device: 72 | 73 | .. code:: python 74 | 75 | # CPU convolve 76 | x = numpy.array([1, 2, 3, 4, 5]) 77 | y = numpy.array([1, 1, 1]) 78 | z = sigpy.convolve(x, y) 79 | 80 | # GPU convolve 81 | x = cupy.array([1, 2, 3, 4, 5]) 82 | y = cupy.array([1, 1, 1]) 83 | z = sigpy.convolve(x, y) 84 | 85 | Iterative Algorithms 86 | ******************** 87 | SigPy also provides convenient abstractions and classes for iterative algorithms. A compressed sensing experiment can be implemented in four lines using SigPy: 88 | 89 | .. code:: python 90 | 91 | # Given some observation vector y, and measurement matrix mat 92 | A = sigpy.linop.MatMul([n, 1], mat) # define forward linear operator 93 | proxg = sigpy.prox.L1Reg([n, 1], lamda=0.001) # define proximal operator 94 | x_hat = sigpy.app.LinearLeastSquares(A, y, proxg=proxg).run() # run iterative algorithm 95 | 96 | PyTorch Interoperability 97 | ************************ 98 | Want to do machine learning without giving up signal processing? SigPy has convenient functions to convert arrays and linear operators into PyTorch Tensors and Functions. For example, given a cupy array ``x``, and a ``Linop`` ``A``, we can convert them to Pytorch: 99 | 100 | .. code:: python 101 | 102 | x_torch = sigpy.to_pytorch(x) 103 | A_torch = sigpy.to_pytorch_function(A) 104 | 105 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | comment: false 2 | 3 | ignore: 4 | - sigpy/backend.py 5 | 6 | coverage: 7 | status: 8 | # Enable coverage measurement for diff introduced in the pull-request, 9 | # but do not mark "X" on commit status for now. 
10 | patch: 11 | default: 12 | target: '0%' 13 | -------------------------------------------------------------------------------- /conda.recipe/bld.bat: -------------------------------------------------------------------------------- 1 | "%PYTHON%" setup.py install --single-version-externally-managed --record=record.txt 2 | if errorlevel 1 exit 1 3 | -------------------------------------------------------------------------------- /conda.recipe/build.sh: -------------------------------------------------------------------------------- 1 | $PYTHON setup.py install --single-version-externally-managed --record=record.txt # Python command to install the script. 2 | -------------------------------------------------------------------------------- /conda.recipe/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.6 3 | - 3.7 4 | - 3.8 5 | - 3.9 6 | -------------------------------------------------------------------------------- /conda.recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "sigpy" %} 2 | {% set version = "0.1.27" %} 3 | 4 | package: 5 | name: '{{ name|lower }}' 6 | version: '{{ version }}' 7 | 8 | source: 9 | path: .. 10 | 11 | requirements: 12 | host: 13 | - python {{ python }} 14 | - setuptools 15 | - numpy 16 | - pywavelets 17 | - numba 18 | - scipy 19 | - tqdm 20 | run: 21 | - python {{ python }} 22 | - numpy 23 | - pywavelets 24 | - numba 25 | - scipy 26 | - tqdm 27 | 28 | test: 29 | imports: 30 | - sigpy 31 | - sigpy.mri 32 | 33 | about: 34 | home: http://github.com/mikgroup/sigpy 35 | license: BSD 36 | license_family: BSD 37 | license_file: LICENSE 38 | summary: Python package for signal reconstruction. 39 | doc_url: http://sigpy.readthedocs.io 40 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = sigpy 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
14 | # 15 | import os 16 | import sys 17 | 18 | sys.path.insert(0, os.path.abspath("..")) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = "sigpy" 24 | copyright = "2018-2019, Frank Ong" 25 | author = "Frank Ong" 26 | 27 | # The short X.Y version 28 | version = "" 29 | # The full version, including alpha/beta/rc tags 30 | release = "0.1.27" 31 | 32 | 33 | # -- General configuration --------------------------------------------------- 34 | 35 | # If your documentation needs a minimal Sphinx version, state it here. 36 | # 37 | # needs_sphinx = '1.0' 38 | 39 | # Add any Sphinx extension module names here, as strings. They can be 40 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 41 | # ones. 42 | extensions = [ 43 | "sphinx.ext.autodoc", 44 | "sphinx.ext.autosummary", 45 | "sphinx.ext.mathjax", 46 | "sphinx.ext.napoleon", 47 | "sphinx.ext.viewcode", 48 | ] 49 | autosummary_generate = True 50 | autosummary_imported_members = True 51 | 52 | # Add any paths that contain templates here, relative to this directory. 53 | templates_path = ["_templates"] 54 | 55 | # The suffix(es) of source filenames. 56 | # You can specify multiple suffix as a list of string: 57 | # 58 | # source_suffix = ['.rst', '.md'] 59 | source_suffix = ".rst" 60 | 61 | # The master toctree document. 62 | master_doc = "index" 63 | 64 | # The language for content autogenerated by Sphinx. Refer to documentation 65 | # for a list of supported languages. 66 | # 67 | # This is also used if you do content translation via gettext catalogs. 68 | # Usually you set "language" from the command line for these cases. 69 | language = "en" 70 | 71 | # List of patterns, relative to source directory, that match files and 72 | # directories to ignore when looking for source files. 73 | # This pattern also affects html_static_path and html_extra_path . 74 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 75 | 76 | # The name of the Pygments (syntax highlighting) style to use. 77 | pygments_style = "sphinx" 78 | 79 | 80 | # -- Options for HTML output ------------------------------------------------- 81 | 82 | # The theme to use for HTML and HTML Help pages. See the documentation for 83 | # a list of builtin themes. 84 | # 85 | html_theme = "sphinx_rtd_theme" 86 | 87 | html_logo = "" 88 | 89 | # Theme options are theme-specific and customize the look and feel of a theme 90 | # further. For a list of options available for each theme, see the 91 | # documentation. 92 | # 93 | html_theme_options = {"navigation_depth": 1} 94 | 95 | # Add any paths that contain custom static files (such as style sheets) here, 96 | # relative to this directory. They are copied after the builtin static files, 97 | # so a file named "default.css" will overwrite the builtin "default.css". 98 | # html_static_path = ['_static'] 99 | 100 | # Custom sidebar templates, must be a dictionary that maps document names 101 | # to template names. 102 | # 103 | # The default sidebars (for documents that don't match any pattern) are 104 | # defined by theme itself. Builtin themes are using these templates by 105 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 106 | # 'searchbox.html']``. 107 | # 108 | # html_sidebars = {} 109 | pygments_style = "default" 110 | 111 | # -- Options for HTMLHelp output --------------------------------------------- 112 | 113 | # Output file base name for HTML help builder. 
114 | htmlhelp_basename = "sigpydoc" 115 | 116 | 117 | # -- Options for LaTeX output ------------------------------------------------ 118 | 119 | latex_elements = { 120 | # The paper size ('letterpaper' or 'a4paper'). 121 | # 122 | # 'papersize': 'letterpaper', 123 | # The font size ('10pt', '11pt' or '12pt'). 124 | # 125 | # 'pointsize': '10pt', 126 | # Additional stuff for the LaTeX preamble. 127 | # 128 | # 'preamble': '', 129 | # Latex figure (float) alignment 130 | # 131 | # 'figure_align': 'htbp', 132 | } 133 | 134 | # Grouping the document tree into LaTeX files. List of tuples 135 | # (source start file, target name, title, 136 | # author, documentclass [howto, manual, or own class]). 137 | latex_documents = [ 138 | (master_doc, "sigpy.tex", "sigpy Documentation", "Frank Ong", "manual"), 139 | ] 140 | 141 | 142 | # -- Options for manual page output ------------------------------------------ 143 | 144 | # One entry per manual page. List of tuples 145 | # (source start file, name, description, authors, manual section). 146 | man_pages = [(master_doc, "sigpy", "sigpy Documentation", [author], 1)] 147 | 148 | 149 | # -- Options for Texinfo output ---------------------------------------------- 150 | 151 | # Grouping the document tree into Texinfo files. List of tuples 152 | # (source start file, target name, title, author, 153 | # dir menu entry, description, category) 154 | texinfo_documents = [ 155 | ( 156 | master_doc, 157 | "sigpy", 158 | "sigpy Documentation", 159 | author, 160 | "sigpy", 161 | "One line description of project.", 162 | "Miscellaneous", 163 | ), 164 | ] 165 | 166 | 167 | # -- Extension configuration ------------------------------------------------- 168 | -------------------------------------------------------------------------------- /docs/core.rst: -------------------------------------------------------------------------------- 1 | Functions (`sigpy`) 2 | =================== 3 | .. automodule:: 4 | sigpy 5 | 6 | Computing Backend Functions 7 | --------------------------- 8 | .. automodule:: 9 | sigpy.backend 10 | 11 | .. autosummary:: 12 | :toctree: generated 13 | :nosignatures: 14 | 15 | sigpy.Device 16 | sigpy.get_device 17 | sigpy.get_array_module 18 | sigpy.cpu_device 19 | sigpy.to_device 20 | sigpy.copyto 21 | sigpy.Communicator 22 | 23 | Block Reshape Functions 24 | ----------------------- 25 | .. automodule:: 26 | sigpy.block 27 | 28 | .. autosummary:: 29 | :toctree: generated 30 | :nosignatures: 31 | 32 | sigpy.array_to_blocks 33 | sigpy.blocks_to_array 34 | 35 | Convolution Functions 36 | --------------------- 37 | .. automodule:: 38 | sigpy.conv 39 | 40 | .. autosummary:: 41 | :toctree: generated 42 | :nosignatures: 43 | 44 | sigpy.convolve 45 | sigpy.convolve_data_adjoint 46 | sigpy.convolve_filter_adjoint 47 | 48 | Fourier Functions 49 | ----------------- 50 | .. automodule:: 51 | sigpy.fourier 52 | 53 | .. autosummary:: 54 | :toctree: generated 55 | :nosignatures: 56 | 57 | sigpy.fft 58 | sigpy.ifft 59 | sigpy.nufft 60 | sigpy.nufft_adjoint 61 | sigpy.estimate_shape 62 | 63 | Interpolation Functions 64 | ----------------------- 65 | .. automodule:: 66 | sigpy.interp 67 | 68 | .. autosummary:: 69 | :toctree: generated 70 | :nosignatures: 71 | 72 | sigpy.interpolate 73 | sigpy.gridding 74 | 75 | Pytorch Interop Functions 76 | ------------------------- 77 | .. automodule:: 78 | sigpy.pytorch 79 | 80 | .. 
autosummary:: 81 | :toctree: generated 82 | :nosignatures: 83 | 84 | sigpy.to_pytorch 85 | sigpy.from_pytorch 86 | sigpy.to_pytorch_function 87 | 88 | Simulation Functions 89 | -------------------- 90 | .. automodule:: 91 | sigpy.sim 92 | 93 | .. autosummary:: 94 | :toctree: generated 95 | :nosignatures: 96 | 97 | sigpy.shepp_logan 98 | 99 | Thresholding Functions 100 | ---------------------- 101 | .. automodule:: 102 | sigpy.thresh 103 | 104 | .. autosummary:: 105 | :toctree: generated 106 | :nosignatures: 107 | 108 | sigpy.soft_thresh 109 | sigpy.hard_thresh 110 | sigpy.l1_proj 111 | sigpy.l2_proj 112 | sigpy.linf_proj 113 | sigpy.psd_proj 114 | 115 | Utility Functions 116 | ----------------- 117 | .. automodule:: 118 | sigpy.util 119 | 120 | .. autosummary:: 121 | :toctree: generated 122 | :nosignatures: 123 | 124 | sigpy.resize 125 | sigpy.flip 126 | sigpy.circshift 127 | sigpy.downsample 128 | sigpy.upsample 129 | sigpy.dirac 130 | sigpy.randn 131 | sigpy.triang 132 | sigpy.hanning 133 | sigpy.monte_carlo_sure 134 | sigpy.axpy 135 | sigpy.xpay 136 | 137 | Wavelet Functions 138 | ----------------- 139 | .. automodule:: 140 | sigpy.wavelet 141 | 142 | .. autosummary:: 143 | :toctree: generated 144 | :nosignatures: 145 | 146 | sigpy.fwt 147 | sigpy.iwt 148 | -------------------------------------------------------------------------------- /docs/core_alg.rst: -------------------------------------------------------------------------------- 1 | Iterative Algorithms (`sigpy.alg`) 2 | ================================== 3 | 4 | .. automodule:: 5 | sigpy.alg 6 | 7 | 8 | The Algorithm Class 9 | ------------------- 10 | .. autosummary:: 11 | :toctree: generated 12 | :nosignatures: 13 | 14 | sigpy.alg.Alg 15 | 16 | First-order Gradient Methods 17 | ---------------------------- 18 | .. autosummary:: 19 | :toctree: generated 20 | :nosignatures: 21 | 22 | sigpy.alg.GradientMethod 23 | sigpy.alg.ConjugateGradient 24 | sigpy.alg.PrimalDualHybridGradient 25 | sigpy.alg.ADMM 26 | sigpy.alg.SDMM 27 | 28 | 29 | Other Methods 30 | ------------- 31 | .. autosummary:: 32 | :toctree: generated 33 | :nosignatures: 34 | 35 | sigpy.alg.NewtonsMethod 36 | sigpy.alg.PowerMethod 37 | sigpy.alg.AltMin 38 | sigpy.alg.AugmentedLagrangianMethod 39 | -------------------------------------------------------------------------------- /docs/core_app.rst: -------------------------------------------------------------------------------- 1 | Apps (`sigpy.app`) 2 | ================== 3 | 4 | .. automodule:: 5 | sigpy.app 6 | 7 | The App Class 8 | ------------- 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | sigpy.app.App 14 | 15 | Apps 16 | ---- 17 | .. autosummary:: 18 | :toctree: generated 19 | :nosignatures: 20 | 21 | sigpy.app.MaxEig 22 | sigpy.app.LinearLeastSquares 23 | -------------------------------------------------------------------------------- /docs/core_linop.rst: -------------------------------------------------------------------------------- 1 | Linear Operators (`sigpy.linop`) 2 | ================================ 3 | 4 | .. automodule:: 5 | sigpy.linop 6 | 7 | The Linear Operator Class 8 | ------------------------- 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | sigpy.linop.Linop 14 | 15 | Linop Manipulation 16 | ------------------ 17 | 18 | The following are classes that take in Linops and compose them to form a new Linop. 19 | 20 | .. 
autosummary:: 21 | :toctree: generated 22 | :nosignatures: 23 | 24 | sigpy.linop.Conj 25 | sigpy.linop.Add 26 | sigpy.linop.Compose 27 | sigpy.linop.Hstack 28 | sigpy.linop.Vstack 29 | sigpy.linop.Diag 30 | 31 | Basic Linops 32 | ------------ 33 | .. autosummary:: 34 | :toctree: generated 35 | :nosignatures: 36 | 37 | sigpy.linop.Embed 38 | sigpy.linop.Identity 39 | sigpy.linop.Reshape 40 | sigpy.linop.Slice 41 | sigpy.linop.Transpose 42 | 43 | Computing Related Linops 44 | ------------------------ 45 | .. autosummary:: 46 | :toctree: generated 47 | :nosignatures: 48 | 49 | sigpy.linop.ToDevice 50 | sigpy.linop.AllReduce 51 | sigpy.linop.AllReduceAdjoint 52 | 53 | Convolution Linops 54 | ------------------ 55 | .. autosummary:: 56 | :toctree: generated 57 | :nosignatures: 58 | 59 | sigpy.linop.ConvolveData 60 | sigpy.linop.ConvolveDataAdjoint 61 | sigpy.linop.ConvolveFilter 62 | sigpy.linop.ConvolveFilterAdjoint 63 | 64 | Fourier Linops 65 | -------------- 66 | .. autosummary:: 67 | :toctree: generated 68 | :nosignatures: 69 | 70 | sigpy.linop.FFT 71 | sigpy.linop.IFFT 72 | sigpy.linop.NUFFT 73 | sigpy.linop.NUFFTAdjoint 74 | 75 | Multiplication Linops 76 | --------------------- 77 | .. autosummary:: 78 | :toctree: generated 79 | :nosignatures: 80 | 81 | sigpy.linop.MatMul 82 | sigpy.linop.RightMatMul 83 | sigpy.linop.Multiply 84 | 85 | Interapolation Linops 86 | --------------------- 87 | .. autosummary:: 88 | :toctree: generated 89 | :nosignatures: 90 | 91 | sigpy.linop.Interpolate 92 | sigpy.linop.Gridding 93 | 94 | Array Manipulation Linops 95 | ------------------------- 96 | .. autosummary:: 97 | :toctree: generated 98 | :nosignatures: 99 | 100 | sigpy.linop.Resize 101 | sigpy.linop.Flip 102 | sigpy.linop.Downsample 103 | sigpy.linop.Upsample 104 | sigpy.linop.Circshift 105 | sigpy.linop.Sum 106 | sigpy.linop.Tile 107 | sigpy.linop.FiniteDifference 108 | 109 | Wavelet Transform Linops 110 | ------------------------ 111 | .. autosummary:: 112 | :toctree: generated 113 | :nosignatures: 114 | 115 | sigpy.linop.Wavelet 116 | sigpy.linop.InverseWavelet 117 | 118 | Block Reshape Linops 119 | -------------------- 120 | .. autosummary:: 121 | :toctree: generated 122 | :nosignatures: 123 | 124 | sigpy.linop.ArrayToBlocks 125 | sigpy.linop.BlocksToArray 126 | 127 | 128 | -------------------------------------------------------------------------------- /docs/core_prox.rst: -------------------------------------------------------------------------------- 1 | Proximal Operators (`sigpy.prox`) 2 | ================================= 3 | 4 | .. automodule:: 5 | sigpy.prox 6 | 7 | The Proximal Operator Class 8 | --------------------------- 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | sigpy.prox.Prox 14 | 15 | Prox Manipulation 16 | ----------------- 17 | .. autosummary:: 18 | :toctree: generated 19 | :nosignatures: 20 | 21 | sigpy.prox.Conj 22 | sigpy.prox.Stack 23 | sigpy.prox.UnitaryTransform 24 | 25 | Basic Proxs 26 | ----------- 27 | .. 
autosummary:: 28 | :toctree: generated 29 | :nosignatures: 30 | 31 | sigpy.prox.NoOp 32 | sigpy.prox.BoxConstraint 33 | sigpy.prox.L1Proj 34 | sigpy.prox.L1Reg 35 | sigpy.prox.L2Proj 36 | sigpy.prox.L2Reg 37 | sigpy.prox.LInfProj 38 | sigpy.prox.PsdProj 39 | -------------------------------------------------------------------------------- /docs/figures/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/docs/figures/architecture.png -------------------------------------------------------------------------------- /docs/figures/device.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/docs/figures/device.png -------------------------------------------------------------------------------- /docs/figures/multiprocess_desired.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/docs/figures/multiprocess_desired.png -------------------------------------------------------------------------------- /docs/figures/multiprocess_mpi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/docs/figures/multiprocess_mpi.png -------------------------------------------------------------------------------- /docs/guide_basic.rst: -------------------------------------------------------------------------------- 1 | **This guide is still under construction** 2 | 3 | Basic Usage 4 | ----------- 5 | 6 | SigPy is designed to have as little learning curve as possible. Since almost all Python users already use NumPy, SigPy operates on NumPy arrays directly on CPU, and avoids defining any redundant functions. When NumPy implementation is slow, SigPy uses Numba instead to translate Python functions to optimized machine code at runtime. For example, gridding functions in SigPy are implemented using Numba. For GPU, SigPy operates on CuPy arrays, which have the same interface as NumPy but are implemented in CUDA. 7 | 8 | SigPy does not bundle CuPy installation by default. 9 | To enable CUDA support, you must install CuPy as an additional step. 10 | 11 | While we try to make this documentation as self-contained as possible, 12 | we refer you to the `NumPy documentation `_, 13 | and `CuPy documentation `_ 14 | for general questions about NumPy/CuPy arrays and functions. 15 | 16 | In the following, we will use the following abbreviations: 17 | 18 | >>> import numpy as np 19 | >>> import cupy as cp 20 | >>> import sigpy as sp 21 | 22 | 23 | Choosing Computing Device 24 | ========================= 25 | 26 | SigPy provides a device class :class:`sigpy.Device` to allow you to specify the current computing device for functions and arrays. 27 | It extends the ``Device`` class from CuPy. 28 | Similar approach is also used by machine learning packages, such as TensorFlow, and PyTorch. 29 | 30 | For example to create an array on GPU 1, we can do: 31 | 32 | >>> with sp.Device(1): 33 | >>> x_on_gpu1 = cp.array([1, 2, 3, 4]) 34 | 35 | Note that this can also be done with ``cupy.cuda.Device``, and you can choose to use it as well. 
36 | The main difference is that :class:`sigpy.Device` maps -1 to CPU, and makes it easier to develop CPU/GPU generic code.
37 | 
38 | .. image:: figures/device.png
39 |     :align: center
40 | 
41 | To transfer an array between devices, we can use :func:`sigpy.to_device`. For example, to transfer a NumPy array to GPU 1, we can do:
42 | 
43 | >>> x = np.array([1, 2, 3, 4])
44 | >>> x_on_gpu1 = sp.to_device(x, 1)
45 | 
46 | Finally, we can use :func:`sigpy.Device.xp` to choose NumPy or CuPy adaptively.
47 | For example, given a device id,
48 | the following code creates an array on the appropriate device using the appropriate module:
49 | 
50 | >>> device = sp.Device(id)
51 | >>> xp = device.xp  # Returns NumPy if id == -1, otherwise returns CuPy
52 | >>> with device:
53 | >>>     x = xp.array([1, 2, 3, 4])
54 | 
-------------------------------------------------------------------------------- /docs/guide_iter.rst: --------------------------------------------------------------------------------
1 | Building iterative methods
2 | ==========================
3 | 
4 | SigPy provides four abstraction classes (Linop, Prox, Alg, and App) for optimization-based iterative methods. This abstraction is inspired by a similar structure in BART.
5 | 
6 | .. image:: figures/architecture.png
7 |     :align: center
8 | 
9 | The Linop class abstracts a linear operator, and supports adjoint, addition, composing, and stacking. Prepackaged Linops include FFT, NUFFT, wavelet, and common array manipulation functions. In particular, given a Linop ``A``, the following operations can be performed:
10 | 
11 | >>> A.H  # adjoint
12 | >>> A.H * A  # compose
13 | >>> A.H * A + lamda * I  # addition and scalar multiplication
14 | >>> Hstack([A, B])  # horizontal stack
15 | >>> Vstack([A, B])  # vertical stack
16 | >>> Diag([A, B])  # diagonal stack
17 | 
18 | The Prox class abstracts a proximal operator, and can do stacking and conjugation. Prepackaged Proxs include L1/L2 regularization and projection functions. In particular, given a proximal operator ``proxg``, the following operations can be performed:
19 | 
20 | >>> Conj(proxg)  # convex conjugate
21 | >>> UnitaryTransform(proxg, A)  # A.H * proxg * A
22 | >>> Stack([proxg1, proxg2])  # diagonal stack
23 | 
24 | The Alg class abstracts iterative algorithms. Prepackaged Algs include conjugate gradient, (accelerated/proximal) gradient method, and primal dual hybrid gradient. A typical usage is as follows:
25 | 
26 | >>> while not alg.done():
27 | >>>     alg.update()
28 | 
29 | Finally, the App class wraps the above three classes into a final deliverable application. Users can run an App without knowing the internal implementation. A typical usage of an App is as follows:
30 | 
31 | >>> out = app.run()
32 | 
-------------------------------------------------------------------------------- /docs/guide_multi_devices.rst: --------------------------------------------------------------------------------
1 | **This guide is still under construction**
2 | 
3 | Using Multi-CPU/GPU
4 | -------------------
5 | 
6 | SigPy uses MPI and MPI4Py for multi-CPU/GPU programming. We note that this is still under heavy development.
7 | 
8 | Although MPI may incur some overhead (for example, redundant memory usage) for shared memory systems,
9 | we find an MPI solution to be the simplest for multi-threading in Python.
10 | Another benefit is that an MPI-parallelized code can run on both shared memory and distributed memory systems.
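In code, this pattern can look like the following minimal sketch (assuming ``mpi4py`` is installed and the script is launched with ``mpirun``; the rank-to-device mapping below is only an illustration):

>>> from mpi4py import MPI
>>> import sigpy as sp
>>> rank = MPI.COMM_WORLD.Get_rank()
>>> device = sp.Device(rank)  # e.g. rank 0 uses GPU 0, rank 1 uses GPU 1
>>> xp = device.xp
>>> with device:
>>>     x = xp.array([1, 2, 3, 4])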
11 | 12 | For example, if we consider the following shared memory configuration (one multi-core CPU and two GPUs), 13 | and want to run the blue and red tasks concurrently: 14 | 15 | .. image:: figures/multiprocess_desired.png 16 | :align: center 17 | 18 | Then, using MPI, we can split the tasks to two MPI nodes as follows: 19 | 20 | .. image:: figures/multiprocess_mpi.png 21 | :align: center 22 | 23 | Note that tasks on each MPI node can run on any CPU/GPU device, and in our example, the blue task uses CPU and GPU 0, and 24 | the red task uses CPU and GPU 1. 25 | 26 | SigPy provides a communicator class :class:`sigpy.Communicator` that can be used to synchronize variables between ranks. 27 | It extends the ``Communicator`` class from ChainerMN. 28 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | 3 | .. toctree:: 4 | :hidden: 5 | :caption: User Guide 6 | 7 | guide_basic 8 | guide_multi_devices 9 | guide_iter 10 | 11 | .. toctree:: 12 | :hidden: 13 | :caption: Core API Reference 14 | 15 | core 16 | core_linop 17 | core_prox 18 | core_alg 19 | core_app 20 | 21 | .. toctree:: 22 | :hidden: 23 | :caption: MRI API Reference 24 | 25 | mri 26 | mri_linop 27 | mri_app 28 | mri_rf 29 | 30 | .. toctree:: 31 | :hidden: 32 | :caption: Plot API Reference 33 | 34 | plot 35 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=sigpy 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /docs/mri.rst: -------------------------------------------------------------------------------- 1 | MRI Functions (`sigpy.mri`) 2 | =========================== 3 | 4 | .. automodule:: 5 | sigpy.mri 6 | 7 | Preconditioner Functions 8 | ------------------------ 9 | .. automodule:: 10 | sigpy.mri.precond 11 | 12 | .. autosummary:: 13 | :toctree: generated 14 | :nosignatures: 15 | 16 | sigpy.mri.kspace_precond 17 | sigpy.mri.circulant_precond 18 | 19 | Sampling Functions 20 | ------------------ 21 | .. automodule:: 22 | sigpy.mri.samp 23 | 24 | .. autosummary:: 25 | :toctree: generated 26 | :nosignatures: 27 | 28 | sigpy.mri.poisson 29 | sigpy.mri.radial 30 | sigpy.mri.spiral 31 | 32 | Simulation Functions 33 | -------------------- 34 | .. automodule:: 35 | sigpy.mri.sim 36 | 37 | .. 
autosummary:: 38 | :toctree: generated 39 | :nosignatures: 40 | 41 | sigpy.mri.birdcage_maps 42 | 43 | Utility Functions 44 | ----------------- 45 | .. automodule:: 46 | sigpy.mri.util 47 | 48 | .. autosummary:: 49 | :toctree: generated 50 | :nosignatures: 51 | 52 | sigpy.mri.get_cov 53 | sigpy.mri.whiten 54 | -------------------------------------------------------------------------------- /docs/mri_app.rst: -------------------------------------------------------------------------------- 1 | MRI Apps (`sigpy.mri.app`) 2 | ========================== 3 | 4 | .. automodule:: 5 | sigpy.mri.app 6 | 7 | Regularized MRI Reconstruction 8 | ------------------------------ 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | sigpy.mri.app.SenseRecon 14 | sigpy.mri.app.L1WaveletRecon 15 | sigpy.mri.app.TotalVariationRecon 16 | 17 | Sensitivity Map Estimation 18 | -------------------------- 19 | .. autosummary:: 20 | :toctree: generated 21 | :nosignatures: 22 | 23 | sigpy.mri.app.EspiritCalib 24 | sigpy.mri.app.JsenseRecon 25 | -------------------------------------------------------------------------------- /docs/mri_linop.rst: -------------------------------------------------------------------------------- 1 | MRI Linear Operators (`sigpy.mri.linop`) 2 | ======================================== 3 | 4 | .. automodule:: 5 | sigpy.mri.linop 6 | 7 | .. autosummary:: 8 | :toctree: generated 9 | :nosignatures: 10 | 11 | sigpy.mri.linop.Sense 12 | sigpy.mri.linop.ConvSense 13 | sigpy.mri.linop.ConvImage 14 | -------------------------------------------------------------------------------- /docs/mri_rf.rst: -------------------------------------------------------------------------------- 1 | MRI RF Design (`sigpy.mri.rf`) 2 | ============================== 3 | 4 | .. automodule:: 5 | sigpy.mri.rf 6 | 7 | Adiabatic Pulse Design Functions 8 | -------------------------------- 9 | .. autosummary:: 10 | :toctree: generated 11 | :nosignatures: 12 | 13 | sigpy.mri.rf.adiabatic.bir4 14 | sigpy.mri.rf.adiabatic.hypsec 15 | sigpy.mri.rf.adiabatic.wurst 16 | sigpy.mri.rf.adiabatic.goia_wurst 17 | sigpy.mri.rf.adiabatic.bloch_siegert_fm 18 | 19 | B1-Selective Pulse Design Functions 20 | ----------------------------------- 21 | .. autosummary:: 22 | :toctree: generated 23 | :nosignatures: 24 | 25 | sigpy.mri.rf.b1sel.dz_b1_rf 26 | sigpy.mri.rf.b1sel.dz_b1_gslider_rf 27 | sigpy.mri.rf.b1sel.dz_b1_hadamard_rf 28 | 29 | RF Linear Operators 30 | -------------------------- 31 | .. autosummary:: 32 | :toctree: generated 33 | :nosignatures: 34 | 35 | sigpy.mri.rf.linop.PtxSpatialExplicit 36 | 37 | Pulse Multibanding Functions 38 | ---------------------------- 39 | .. autosummary:: 40 | :toctree: generated 41 | :nosignatures: 42 | 43 | sigpy.mri.rf.multiband.mb_rf 44 | sigpy.mri.rf.multiband.dz_pins 45 | 46 | Optimal Control Design Functions 47 | -------------------------------- 48 | .. autosummary:: 49 | :toctree: generated 50 | :nosignatures: 51 | 52 | sigpy.mri.rf.optcont.blochsim 53 | sigpy.mri.rf.optcont.deriv 54 | 55 | Parallel Transmit Pulse Designers 56 | --------------------------------- 57 | .. autosummary:: 58 | :toctree: generated 59 | :nosignatures: 60 | 61 | sigpy.mri.rf.ptx.stspa 62 | sigpy.mri.rf.ptx.stspk 63 | 64 | RF Shimming Functions 65 | -------------------------- 66 | .. 
autosummary:: 67 | :toctree: generated 68 | :nosignatures: 69 | 70 | sigpy.mri.rf.shim.calc_shims 71 | sigpy.mri.rf.shim.init_optimal_spectral 72 | sigpy.mri.rf.shim.init_circ_polar 73 | 74 | RF Pulse Simulation 75 | -------------------------- 76 | .. autosummary:: 77 | :toctree: generated 78 | :nosignatures: 79 | 80 | sigpy.mri.rf.sim.abrm 81 | sigpy.mri.rf.sim.abrm_nd 82 | sigpy.mri.rf.sim.abrm_hp 83 | sigpy.mri.rf.sim.abrm_ptx 84 | 85 | SLR Pulse Design Functions 86 | -------------------------------- 87 | .. autosummary:: 88 | :toctree: generated 89 | :nosignatures: 90 | 91 | sigpy.mri.rf.slr.dzrf 92 | sigpy.mri.rf.slr.root_flip 93 | sigpy.mri.rf.slr.dz_gslider_rf 94 | sigpy.mri.rf.slr.dz_gslider_b 95 | sigpy.mri.rf.slr.dz_hadamard_b 96 | sigpy.mri.rf.slr.dz_recursive_rf 97 | 98 | Trajectory and Gradient Design Functions 99 | ---------------------------------------- 100 | .. autosummary:: 101 | :toctree: generated 102 | :nosignatures: 103 | 104 | sigpy.mri.rf.trajgrad.min_trap_grad 105 | sigpy.mri.rf.trajgrad.trap_grad 106 | sigpy.mri.rf.trajgrad.spiral_varden 107 | sigpy.mri.rf.trajgrad.spiral_arch 108 | sigpy.mri.rf.trajgrad.epi 109 | sigpy.mri.rf.trajgrad.rosette 110 | sigpy.mri.rf.trajgrad.stack_of 111 | sigpy.mri.rf.trajgrad.spokes_grad 112 | sigpy.mri.rf.traj_complex_to_array 113 | sigpy.mri.rf.traj_array_to_complex 114 | sigpy.mri.rf.min_time_gradient 115 | 116 | RF Utility 117 | -------------------------- 118 | .. autosummary:: 119 | :toctree: generated 120 | :nosignatures: 121 | 122 | sigpy.mri.rf.util.dinf 123 | 124 | I/O 125 | -------------------------- 126 | .. autosummary:: 127 | :toctree: generated 128 | :nosignatures: 129 | 130 | sigpy.mri.rf.io.siemens_rf 131 | sigpy.mri.rf.io.signa 132 | sigpy.mri.rf.io.ge_rf_params 133 | sigpy.mri.rf.io.philips_rf_params 134 | -------------------------------------------------------------------------------- /docs/plot.rst: -------------------------------------------------------------------------------- 1 | Plot Functions (`sigpy.plot`) 2 | ============================= 3 | 4 | .. automodule:: 5 | sigpy.plot 6 | 7 | .. autosummary:: 8 | :toctree: generated 9 | :nosignatures: 10 | 11 | sigpy.plot.ImagePlot 12 | sigpy.plot.LinePlot 13 | sigpy.plot.ScatterPlot 14 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | numba 3 | numpy 4 | PyWavelets 5 | scipy 6 | tqdm 7 | 8 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "setuptools-scm"] 3 | 4 | [project] 5 | name = "sigpy" 6 | description = "Python package for signal reconstruction." 
7 | authors = [
8 |     {name = "Frank Ong", email = "frankong@berkeley.edu"},
9 | ]
10 | license = {text = "BSD"}
11 | dependencies = [
12 |     "numpy",
13 |     "pywavelets",
14 |     "numba",
15 |     "scipy",
16 |     "tqdm"
17 | ]
18 | classifiers=[
19 |     "Programming Language :: Python :: 3",
20 |     "License :: OSI Approved :: BSD License",
21 |     "Operating System :: OS Independent",
22 | ]
23 | dynamic = [
24 |     "version",
25 |     "readme"
26 | ]
27 | 
28 | [project.optional-dependencies]
29 | test = [
30 |     "pytest < 5.0.0",
31 |     "pytest-cov[all]",
32 |     "coverage"
33 | ]
34 | lint = [
35 |     "ruff",
36 |     "black",
37 |     "isort",
38 |     "sphinx",
39 |     "sphinx_rtd_theme"
40 | ]
41 | 
42 | 
43 | [project.urls]
44 | homepage = "https://github.com/mikgroup/sigpy"
45 | documentation = "https://sigpy.readthedocs.io/en/latest/"
46 | repository = "https://github.com/mikgroup/sigpy"
47 | 
48 | 
49 | [tool.black]
50 | line-length = 79
51 | 
52 | [tool.isort]
53 | profile = "black"
54 | 
55 | [tool.setuptools.dynamic]
56 | version = {attr = "sigpy.version.__version__"}
57 | readme = {file = "README.rst" }
58 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | numba
2 | numpy
3 | PyWavelets
4 | scipy
5 | tqdm
6 | 
-------------------------------------------------------------------------------- /run_tests.sh: --------------------------------------------------------------------------------
1 | set -e
2 | rm -rf docs/generated/
3 | black .
4 | isort .
5 | ruff check .
6 | coverage run -m unittest
7 | sphinx-build -W docs docs/_build/html
8 | 
-------------------------------------------------------------------------------- /setup.cfg: --------------------------------------------------------------------------------
1 | [bumpversion]
2 | current_version = 0.1.27
3 | commit = True
4 | tag = True
5 | 
6 | [bumpversion:file:sigpy/version.py]
7 | 
8 | [bumpversion:file:tests/test_version.py]
9 | 
10 | [bumpversion:file:conda.recipe/meta.yaml]
11 | 
12 | [bumpversion:file:docs/conf.py]
13 | 
14 | [flake8]
15 | exclude = .eggs,*.egg,.asv,doc,.git
16 | ignore = E121,E123,E126,E129,E226,E24,E704,E741,W503,W504
17 | 
-------------------------------------------------------------------------------- /setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup
2 | 
3 | setup()
4 | 
-------------------------------------------------------------------------------- /sigpy/__init__.py: --------------------------------------------------------------------------------
1 | """The core module contains functions and classes for signal processing.
2 | 
3 | SigPy provides simple interfaces to commonly used signal processing functions,
4 | including convolution, FFT, NUFFT, wavelet transform, and thresholdings.
5 | All functions, except wavelet transform, can run on both CPU and GPU.
6 | 
7 | These functions are wrapped into higher level classes (Linop and Prox)
8 | that can be used in conjunction with Alg to form an App.
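For example, the compressed sensing example from the README can be assembled as
follows (a small sketch; ``mat``, ``y``, and ``n`` are assumed to be a given
measurement matrix, observation vector, and problem size)::

    A = sigpy.linop.MatMul([n, 1], mat)
    proxg = sigpy.prox.L1Reg([n, 1], lamda=0.001)
    x_hat = sigpy.app.LinearLeastSquares(A, y, proxg=proxg).run()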
9 | 
10 | """
11 | from sigpy import (
12 |     alg,
13 |     app,
14 |     backend,
15 |     block,
16 |     config,
17 |     conv,
18 |     fourier,
19 |     interp,
20 |     linop,
21 |     prox,
22 |     pytorch,
23 |     sim,
24 |     thresh,
25 |     util,
26 |     wavelet,
27 | )
28 | from sigpy.backend import *  # noqa
29 | from sigpy.block import *  # noqa
30 | from sigpy.conv import *  # noqa
31 | from sigpy.fourier import *  # noqa
32 | from sigpy.interp import *  # noqa
33 | from sigpy.pytorch import *  # noqa
34 | from sigpy.sim import *  # noqa
35 | from sigpy.thresh import *  # noqa
36 | from sigpy.util import *  # noqa
37 | from sigpy.wavelet import *  # noqa
38 | 
39 | from .version import __version__  # noqa
40 | 
41 | __all__ = ["alg", "app", "config", "linop", "prox"]
42 | __all__.extend(backend.__all__)
43 | __all__.extend(block.__all__)
44 | __all__.extend(conv.__all__)
45 | __all__.extend(interp.__all__)
46 | __all__.extend(fourier.__all__)
47 | __all__.extend(pytorch.__all__)
48 | __all__.extend(sim.__all__)
49 | __all__.extend(thresh.__all__)
50 | __all__.extend(util.__all__)
51 | __all__.extend(wavelet.__all__)
52 | 
-------------------------------------------------------------------------------- /sigpy/config.py: --------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """Configuration.
3 | 
4 | This module contains flags to turn on and off optional modules.
5 | 
6 | """
7 | import warnings
8 | from importlib import util
9 | 
10 | cupy_enabled = util.find_spec("cupy") is not None
11 | if cupy_enabled:
12 |     try:
13 |         import cupy  # noqa
14 |     except ImportError as e:
15 |         warnings.warn(
16 |             f"Importing cupy failed. "
17 |             f"For more details, see the error stack below:\n{e}"
18 |         )
19 |         cupy_enabled = False
20 | 
21 | if cupy_enabled:  # pragma: no cover
22 |     try:
23 |         cudnn_enabled = util.find_spec("cupy.cuda.cudnn") is not None
24 |         if cudnn_enabled:
25 |             from cupy import cudnn  # noqa: F401
26 |     except ImportError as e:
27 |         warnings.warn(
28 |             f"Importing cupy.cuda.cudnn failed. "
29 |             f"For more details, see the error stack below:\n{e}"
30 |         )
31 |         cudnn_enabled = False
32 |     try:
33 |         nccl_enabled = util.find_spec("cupy.cuda.nccl") is not None
34 |         if nccl_enabled:
35 |             from cupy.cuda import nccl  # noqa: F401
36 |     except ImportError as e:
37 |         warnings.warn(
38 |             f"Importing cupy.cuda.nccl failed. "
39 |             f"For more details, see the error stack below:\n{e}"
40 |         )
41 |         nccl_enabled = False
42 | else:
43 |     cudnn_enabled = False
44 |     nccl_enabled = False
45 | 
46 | mpi4py_enabled = util.find_spec("mpi4py") is not None
47 | 
48 | # This is to catch an import error when the cudnn in cupy (system) and pytorch
49 | # (built in) are in conflict.
50 | if util.find_spec("torch") is not None:
51 |     try:
52 |         import torch  # noqa
53 | 
54 |         pytorch_enabled = True
55 |     except ImportError:
56 |         print("Warning: PyTorch is installed but cannot be imported.")
57 |         pytorch_enabled = False
58 | else:
59 |     pytorch_enabled = False
60 | 
-------------------------------------------------------------------------------- /sigpy/mri/__init__.py: --------------------------------------------------------------------------------
1 | """The module contains functions and classes for MRI reconstruction.
2 | 
3 | It provides convenient simulation and sampling functions,
4 | such as the Poisson-disc sampling function. It also
5 | provides functions to compute preconditioners,
6 | and density compensation factors.
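For example, density compensation factors for a non-Cartesian trajectory can be
estimated with the Pipe-Menon method (a minimal sketch; ``coord`` is assumed to
be a k-space coordinate array)::

    dcf = sigpy.mri.pipe_menon_dcf(coord)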
7 | 8 | """ 9 | from sigpy.mri import app, dcf, linop, precond, samp, sim, util 10 | from sigpy.mri.dcf import * # noqa 11 | from sigpy.mri.precond import * # noqa 12 | from sigpy.mri.samp import * # noqa 13 | from sigpy.mri.sim import * # noqa 14 | from sigpy.mri.util import * # noqa 15 | 16 | __all__ = ["app", "linop"] 17 | __all__.extend(dcf.__all__) 18 | __all__.extend(precond.__all__) 19 | __all__.extend(samp.__all__) 20 | __all__.extend(sim.__all__) 21 | __all__.extend(util.__all__) 22 | -------------------------------------------------------------------------------- /sigpy/mri/dcf.py: -------------------------------------------------------------------------------- 1 | """Density compensation functions. 2 | 3 | """ 4 | from tqdm.auto import tqdm 5 | 6 | import sigpy as sp 7 | 8 | __all__ = ["pipe_menon_dcf"] 9 | 10 | 11 | def pipe_menon_dcf( 12 | coord, 13 | img_shape=None, 14 | device=sp.cpu_device, 15 | max_iter=30, 16 | n=128, 17 | beta=8, 18 | width=4, 19 | show_pbar=True, 20 | ): 21 | r"""Compute Pipe Menon density compensation factor. 22 | 23 | Perform the following iteration: 24 | 25 | .. math:: 26 | 27 | w = \frac{w}{|G^H G w|} 28 | 29 | with :math:`G` as the gridding operator. 30 | 31 | Args: 32 | coord (array): k-space coordinates. 33 | img_shape (None or list): Image shape. 34 | device (Device): computing device. 35 | max_iter (int): number of iterations. 36 | n (int): Kaiser-Bessel sampling numbers for gridding operator. 37 | beta (float): Kaiser-Bessel kernel parameter. 38 | width (float): Kaiser-Bessel kernel width. 39 | show_pbar (bool): show progress bar. 40 | 41 | Returns: 42 | array: density compensation factor. 43 | 44 | References: 45 | Pipe, James G., and Padmanabhan Menon. 46 | Sampling Density Compensation in MRI: 47 | Rationale and an Iterative Numerical Solution. 48 | Magnetic Resonance in Medicine 41, no. 1 (1999): 179–86. 49 | 50 | 51 | """ 52 | device = sp.Device(device) 53 | xp = device.xp 54 | 55 | with device: 56 | w = xp.ones(coord.shape[:-1], dtype=coord.dtype) 57 | if img_shape is None: 58 | img_shape = sp.estimate_shape(coord) 59 | 60 | G = sp.linop.Gridding( 61 | img_shape, coord, param=beta, width=width, kernel="kaiser_bessel" 62 | ) 63 | with tqdm( 64 | total=max_iter, desc="PipeMenonDCF", disable=not show_pbar 65 | ) as pbar: 66 | for it in range(max_iter): 67 | GHGw = G.H * G * w 68 | w /= xp.abs(GHGw) 69 | resid = xp.abs(GHGw - 1).max().item() 70 | 71 | pbar.set_postfix(resid="{0:.2E}".format(resid)) 72 | pbar.update() 73 | 74 | return w 75 | -------------------------------------------------------------------------------- /sigpy/mri/linop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI linear operators. 3 | 4 | This module mainly contains the Sense linear operator, 5 | which integrates multi-channel coil sensitivity maps and 6 | discrete Fourier transform. 7 | 8 | """ 9 | import sigpy as sp 10 | 11 | 12 | def Sense( 13 | mps, 14 | coord=None, 15 | weights=None, 16 | tseg=None, 17 | ishape=None, 18 | coil_batch_size=None, 19 | comm=None, 20 | transp_nufft=False, 21 | ): 22 | """Sense linear operator. 23 | 24 | Args: 25 | mps (array): sensitivity maps of length = number of channels. 26 | coord (None or array): coordinates. 27 | weights (None or array): k-space weights. 28 | Useful for soft-gating or density compensation. 29 | tseg (None or Dictionary): parameters for time-segmented off-resonance 30 | correction. 
Parameters are 'b0' (array), 'dt' (float), 31 | 'lseg' (int), and 'n_bins' (int). Lseg is the number of 32 | time segments used, and n_bins is the number of histogram bins. 33 | ishape (None or tuple): image shape. 34 | coil_batch_size (None or int): batch size for processing multi-channel. 35 | When None, process all coils at the same time. 36 | Useful for saving memory. 37 | comm (None or `sigpy.Communicator`): communicator 38 | for distributed computing. 39 | 40 | """ 41 | # Get image shape and dimension. 42 | num_coils = len(mps) 43 | if ishape is None: 44 | ishape = mps.shape[1:] 45 | img_ndim = mps.ndim - 1 46 | else: 47 | img_ndim = len(ishape) 48 | 49 | # Serialize linop if coil_batch_size is smaller than num_coils. 50 | num_coils = len(mps) 51 | if coil_batch_size is None: 52 | coil_batch_size = num_coils 53 | 54 | if coil_batch_size < len(mps): 55 | num_coil_batches = (num_coils + coil_batch_size - 1) // coil_batch_size 56 | A = sp.linop.Vstack( 57 | [ 58 | Sense( 59 | mps[c * coil_batch_size : ((c + 1) * coil_batch_size)], 60 | coord=coord, 61 | weights=weights, 62 | ishape=ishape, 63 | ) 64 | for c in range(num_coil_batches) 65 | ], 66 | axis=0, 67 | ) 68 | 69 | if comm is not None: 70 | C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True) 71 | A = A * C 72 | 73 | return A 74 | 75 | # Create Sense linear operator 76 | S = sp.linop.Multiply(ishape, mps) 77 | if tseg is None: 78 | if coord is None: 79 | F = sp.linop.FFT(S.oshape, axes=range(-img_ndim, 0)) 80 | else: 81 | if transp_nufft is False: 82 | F = sp.linop.NUFFT(S.oshape, coord) 83 | else: 84 | F = sp.linop.NUFFT(S.oshape, -coord).H 85 | 86 | A = F * S 87 | 88 | # If B0 provided, perform time-segmented off-resonance compensation 89 | else: 90 | if transp_nufft is False: 91 | F = sp.linop.NUFFT(S.oshape, coord) 92 | else: 93 | F = sp.linop.NUFFT(S.oshape, -coord).H 94 | time = len(coord) * tseg["dt"] 95 | b, ct = sp.mri.util.tseg_off_res_b_ct( 96 | tseg["b0"], tseg["n_bins"], tseg["lseg"], tseg["dt"], time 97 | ) 98 | for ii in range(tseg["lseg"]): 99 | Bi = sp.linop.Multiply(F.oshape, b[:, ii]) 100 | Cti = sp.linop.Multiply(S.ishape, ct[:, ii].reshape(S.ishape)) 101 | 102 | # operation below is effectively A = A + Bi * F(Cti * S) 103 | if ii == 0: 104 | A = Bi * F * S * Cti 105 | else: 106 | A = A + Bi * F * S * Cti 107 | 108 | if weights is not None: 109 | with sp.get_device(weights): 110 | P = sp.linop.Multiply(F.oshape, weights**0.5) 111 | 112 | A = P * A 113 | 114 | if comm is not None: 115 | C = sp.linop.AllReduceAdjoint(ishape, comm, in_place=True) 116 | A = A * C 117 | 118 | A.repr_str = "Sense" 119 | return A 120 | 121 | 122 | def ConvSense( 123 | img_ker_shape, mps_ker, coord=None, weights=None, grd_shape=None, comm=None 124 | ): 125 | """Convolution linear operator with sensitivity maps kernel in k-space. 126 | 127 | Args: 128 | img_ker_shape (tuple of ints): image kernel shape. 129 | mps_ker (array): sensitivity maps kernel. 130 | coord (array): coordinates. 131 | grd_shape (None or list): Shape of grid. 
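# Hedged usage sketch for the Sense operator defined above (the phantom and
# birdcage simulators stand in for measured data):
import sigpy as sp
import sigpy.mri as mr

img = sp.shepp_logan((64, 64))
mps = mr.birdcage_maps((8, 64, 64))   # [num_coils] + image shape
A = mr.linop.Sense(mps)               # Cartesian Sense operator
ksp = A(img)                          # multi-channel k-space, shape (8, 64, 64)
img_adj = A.H(ksp)                    # adjoint: coil-combined image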
132 | 133 | """ 134 | ndim = len(img_ker_shape) 135 | num_coils = mps_ker.shape[0] 136 | mps_ker = mps_ker.reshape((num_coils, 1) + mps_ker.shape[1:]) 137 | R = sp.linop.Reshape((1,) + tuple(img_ker_shape), img_ker_shape) 138 | C = sp.linop.ConvolveData( 139 | R.oshape, mps_ker, mode="valid", multi_channel=True 140 | ) 141 | A = C * R 142 | 143 | if coord is not None: 144 | if grd_shape is None: 145 | grd_shape = sp.estimate_shape(coord) 146 | else: 147 | grd_shape = list(grd_shape) 148 | 149 | grd_shape = [num_coils] + grd_shape 150 | iF = sp.linop.IFFT(grd_shape, axes=range(-ndim, 0)) 151 | N = sp.linop.NUFFT(grd_shape, coord) 152 | A = N * iF * A 153 | 154 | if weights is not None: 155 | with sp.get_device(weights): 156 | P = sp.linop.Multiply(A.oshape, weights**0.5) 157 | 158 | A = P * A 159 | 160 | if comm is not None: 161 | C = sp.linop.AllReduceAdjoint(img_ker_shape, comm, in_place=True) 162 | A = A * C 163 | 164 | return A 165 | 166 | 167 | def ConvImage( 168 | mps_ker_shape, img_ker, coord=None, weights=None, grd_shape=None 169 | ): 170 | """Convolution linear operator with image kernel in k-space. 171 | 172 | Args: 173 | mps_ker_shape (tuple of ints): sensitivity maps kernel shape. 174 | img_ker (array): image kernel. 175 | coord (array): coordinates. 176 | grd_shape (None or list): Shape of grid. 177 | 178 | """ 179 | ndim = img_ker.ndim 180 | num_coils = mps_ker_shape[0] 181 | img_ker = img_ker.reshape((1,) + img_ker.shape) 182 | R = sp.linop.Reshape( 183 | (num_coils, 1) + tuple(mps_ker_shape[1:]), mps_ker_shape 184 | ) 185 | C = sp.linop.ConvolveFilter( 186 | R.oshape, img_ker, mode="valid", multi_channel=True 187 | ) 188 | A = C * R 189 | 190 | if coord is not None: 191 | num_coils = mps_ker_shape[0] 192 | if grd_shape is None: 193 | grd_shape = sp.estimate_shape(coord) 194 | else: 195 | grd_shape = list(grd_shape) 196 | 197 | grd_shape = [num_coils] + grd_shape 198 | iF = sp.linop.IFFT(grd_shape, axes=range(-ndim, 0)) 199 | N = sp.linop.NUFFT(grd_shape, coord) 200 | A = N * iF * A 201 | 202 | if weights is not None: 203 | with sp.get_device(weights): 204 | P = sp.linop.Multiply(A.oshape, weights**0.5) 205 | A = P * A 206 | 207 | return A 208 | -------------------------------------------------------------------------------- /sigpy/mri/precond.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI preconditioners. 3 | """ 4 | import sigpy as sp 5 | 6 | __all__ = ["kspace_precond", "circulant_precond"] 7 | 8 | 9 | def kspace_precond( 10 | mps, weights=None, coord=None, lamda=0, device=sp.cpu_device, oversamp=1.25 11 | ): 12 | r"""Compute a diagonal preconditioner in k-space. 13 | 14 | Considers the optimization problem: 15 | 16 | .. math:: 17 | \min_P \| P A A^H - I \|_F^2 18 | 19 | where A is the Sense operator. 20 | 21 | Args: 22 | mps (array): sensitivity maps of shape [num_coils] + image shape. 23 | weights (array): k-space weights. 24 | coord (array): k-space coordinates of shape [...] + [ndim]. 25 | lamda (float): regularization. 26 | 27 | Returns: 28 | array: k-space preconditioner of same shape as k-space. 
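# Hedged usage sketch for kspace_precond (maps and mask are simulated; the
# poisson helper's signature is assumed):
import sigpy.mri as mr

mps = mr.birdcage_maps((8, 64, 64))
mask = mr.poisson((64, 64), 8)
p = mr.kspace_precond(mps, weights=mask)   # same shape as the multi-coil k-space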
29 | 30 | """ 31 | dtype = mps.dtype 32 | 33 | if weights is not None: 34 | weights = sp.to_device(weights, device) 35 | 36 | device = sp.Device(device) 37 | xp = device.xp 38 | 39 | mps_shape = list(mps.shape) 40 | img_shape = mps_shape[1:] 41 | img2_shape = [i * 2 for i in img_shape] 42 | ndim = len(img_shape) 43 | 44 | scale = sp.prod(img2_shape) ** 1.5 / sp.prod(img_shape) 45 | with device: 46 | if coord is None: 47 | idx = (slice(None, None, 2),) * ndim 48 | 49 | ones = xp.zeros(img2_shape, dtype=dtype) 50 | if weights is None: 51 | ones[idx] = 1 52 | else: 53 | ones[idx] = weights**0.5 54 | 55 | psf = sp.ifft(ones) 56 | else: 57 | coord2 = coord * 2 58 | ones = xp.ones(coord.shape[:-1], dtype=dtype) 59 | if weights is not None: 60 | ones *= weights**0.5 61 | 62 | psf = sp.nufft_adjoint(ones, coord2, img2_shape, oversamp=oversamp) 63 | 64 | p_inv = [] 65 | for mps_i in mps: 66 | mps_i = sp.to_device(mps_i, device) 67 | mps_i_norm2 = xp.linalg.norm(mps_i) ** 2 68 | xcorr_fourier = 0 69 | for mps_j in mps: 70 | mps_j = sp.to_device(mps_j, device) 71 | xcorr_fourier += ( 72 | xp.abs(sp.fft(mps_i * xp.conj(mps_j), img2_shape)) ** 2 73 | ) 74 | 75 | xcorr = sp.ifft(xcorr_fourier) 76 | xcorr *= psf 77 | if coord is None: 78 | p_inv_i = sp.fft(xcorr)[idx] 79 | else: 80 | p_inv_i = sp.nufft(xcorr, coord2, oversamp=oversamp) 81 | 82 | if weights is not None: 83 | p_inv_i *= weights**0.5 84 | 85 | p_inv.append(p_inv_i * scale / mps_i_norm2) 86 | 87 | p_inv = (xp.abs(xp.stack(p_inv, axis=0)) + lamda) / (1 + lamda) 88 | p_inv[p_inv == 0] = 1 89 | p = 1 / p_inv 90 | 91 | return p.astype(dtype) 92 | 93 | 94 | def circulant_precond( 95 | mps, weights=None, coord=None, lamda=0, device=sp.cpu_device 96 | ): 97 | r"""Compute circulant preconditioner. 98 | 99 | Considers the optimization problem: 100 | 101 | .. math:: 102 | \min_P \| A^H A - F P F^H \|_2^2 103 | 104 | where A is the Sense operator, 105 | and F is a unitary Fourier transform operator. 106 | 107 | Args: 108 | mps (array): sensitivity maps of shape [num_coils] + image shape. 109 | weights (array): k-space weights. 110 | coord (array): k-space coordinates of shape [...] + [ndim]. 111 | lamda (float): regularization. 112 | 113 | Returns: 114 | array: circulant preconditioner of image shape. 
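# Hedged sketch for circulant_precond above; unlike the k-space variant it
# returns an image-shaped array (maps simulated purely for illustration):
import sigpy.mri as mr

p = mr.circulant_precond(mr.birdcage_maps((8, 64, 64)))   # shape (64, 64)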
115 | 116 | """ 117 | if coord is not None: 118 | coord = sp.to_device(coord, device) 119 | 120 | if weights is not None: 121 | weights = sp.to_device(weights, device) 122 | 123 | dtype = mps.dtype 124 | device = sp.Device(device) 125 | xp = device.xp 126 | 127 | mps_shape = list(mps.shape) 128 | img_shape = mps_shape[1:] 129 | img2_shape = [i * 2 for i in img_shape] 130 | ndim = len(img_shape) 131 | 132 | scale = sp.prod(img2_shape) ** 1.5 / sp.prod(img_shape) ** 2 133 | with device: 134 | idx = (slice(None, None, 2),) * ndim 135 | if coord is None: 136 | ones = xp.zeros(img2_shape, dtype=dtype) 137 | if weights is None: 138 | ones[idx] = 1 139 | else: 140 | ones[idx] = weights**0.5 141 | 142 | psf = sp.ifft(ones) 143 | else: 144 | coord2 = coord * 2 145 | ones = xp.ones(coord.shape[:-1], dtype=dtype) 146 | if weights is not None: 147 | ones *= weights**0.5 148 | 149 | psf = sp.nufft_adjoint(ones, coord2, img2_shape) 150 | 151 | p_inv = 0 152 | for mps_i in mps: 153 | mps_i = sp.to_device(mps_i, device) 154 | xcorr_fourier = xp.abs(sp.fft(xp.conj(mps_i), img2_shape)) ** 2 155 | xcorr = sp.ifft(xcorr_fourier) 156 | xcorr *= psf 157 | p_inv_i = sp.fft(xcorr) 158 | p_inv_i = p_inv_i[idx] 159 | p_inv_i *= scale 160 | if weights is not None: 161 | p_inv_i *= weights**0.5 162 | 163 | p_inv += p_inv_i 164 | 165 | p_inv += lamda 166 | p_inv[p_inv == 0] = 1 167 | p = 1 / p_inv 168 | 169 | return p.astype(dtype) 170 | -------------------------------------------------------------------------------- /sigpy/mri/rf/__init__.py: -------------------------------------------------------------------------------- 1 | """This MRI submodule contains functions and classes for MRI pulse design. 2 | 3 | It contains functions to design a variety of RF pulses for MRI, such as SLR, 4 | adiabatic, parallel transmit, multibanded, and others. The submodule also 5 | includes other functions to assist with pulse design, such as I/O functions, 6 | trajectory/gradient designers, and Bloch simulators. 7 | 8 | Explore RF design tutorials at `sigpy-rf-tutorials`_. These are primarily 9 | Jupyter Notebooks, and provide more detailed instruction on pulse design 10 | workflow and function use. 11 | 12 | See in-progress features at `sigpy-rf`_. 13 | 14 | .. _sigpy-rf-tutorials: https://github.com/jonbmartin/sigpy-rf-tutorials 15 | .. 
_sigpy-rf: https://github.com/jonbmartin/sigpy-rf 16 | 17 | """ 18 | from sigpy.mri import linop 19 | from sigpy.mri.rf import ( 20 | adiabatic, 21 | b1sel, 22 | io, 23 | multiband, 24 | optcont, 25 | ptx, 26 | shim, 27 | sim, 28 | slr, 29 | trajgrad, 30 | util, 31 | ) 32 | from sigpy.mri.rf.adiabatic import * # noqa 33 | from sigpy.mri.rf.b1sel import * # noqa 34 | from sigpy.mri.rf.io import * # noqa 35 | from sigpy.mri.rf.linop import * # noqa 36 | from sigpy.mri.rf.multiband import * # noqa 37 | from sigpy.mri.rf.optcont import * # noqa 38 | from sigpy.mri.rf.ptx import * # noqa 39 | from sigpy.mri.rf.shim import * # noqa 40 | from sigpy.mri.rf.sim import * # noqa 41 | from sigpy.mri.rf.slr import * # noqa 42 | from sigpy.mri.rf.trajgrad import * # noqa 43 | from sigpy.mri.rf.util import * # noqa 44 | 45 | __all__ = ["linop"] 46 | __all__.extend(adiabatic.__all__) 47 | __all__.extend(b1sel.__all__) 48 | __all__.extend(io.__all__) 49 | __all__.extend(multiband.__all__) 50 | __all__.extend(optcont.__all__) 51 | __all__.extend(ptx.__all__) 52 | __all__.extend(sim.__all__) 53 | __all__.extend(shim.__all__) 54 | __all__.extend(slr.__all__) 55 | __all__.extend(trajgrad.__all__) 56 | __all__.extend(util.__all__) 57 | -------------------------------------------------------------------------------- /sigpy/mri/rf/adiabatic.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Adiabatic Pulse Design functions. 3 | 4 | """ 5 | import numpy as np 6 | 7 | __all__ = ["bir4", "hypsec", "wurst", "goia_wurst", "bloch_siegert_fm"] 8 | 9 | 10 | def bir4(n, beta, kappa, theta, dw0): 11 | r"""Design a BIR-4 adiabatic pulse. 12 | 13 | BIR-4 is equivalent to two BIR-1 pulses back-to-back. 14 | 15 | Args: 16 | n (int): number of samples (should be a multiple of 4). 17 | beta (float): AM waveform parameter. 18 | kappa (float): FM waveform parameter. 19 | theta (float): flip angle in radians. 20 | dw0: FM waveform scaling (radians/s). 21 | 22 | Returns: 23 | 2-element tuple containing 24 | 25 | - **a** (*array*): AM waveform. 26 | - **om** (*array*): FM waveform (radians/s). 27 | 28 | References: 29 | Staewen, R.S. et al. (1990). '3-D FLASH Imaging using a single surface 30 | coil and a new adiabatic pulse, BIR-4'. 31 | Invest. Radiology, 25:559-567. 32 | """ 33 | 34 | dphi = np.pi + theta / 2 35 | 36 | t = np.arange(0, n) / n 37 | 38 | a1 = np.tanh(beta * (1 - 4 * t[: n // 4])) 39 | a2 = np.tanh(beta * (4 * t[n // 4 : n // 2] - 1)) 40 | a3 = np.tanh(beta * (3 - 4 * t[n // 2 : 3 * n // 4])) 41 | a4 = np.tanh(beta * (4 * t[3 * n // 4 :] - 3)) 42 | 43 | a = np.concatenate((a1, a2, a3, a4)).astype(np.complex64) 44 | a[n // 4 : 3 * n // 4] = a[n // 4 : 3 * n // 4] * np.exp(1j * dphi) 45 | 46 | om1 = dw0 * np.tan(kappa * 4 * t[: n // 4]) / np.tan(kappa) 47 | om2 = dw0 * np.tan(kappa * (4 * t[n // 4 : n // 2] - 2)) / np.tan(kappa) 48 | om3 = ( 49 | dw0 * np.tan(kappa * (4 * t[n // 2 : 3 * n // 4] - 2)) / np.tan(kappa) 50 | ) 51 | om4 = dw0 * np.tan(kappa * (4 * t[3 * n // 4 :] - 4)) / np.tan(kappa) 52 | 53 | om = np.concatenate((om1, om2, om3, om4)) 54 | 55 | return a, om 56 | 57 | 58 | def hypsec(n=512, beta=800, mu=4.9, dur=0.012): 59 | r"""Design a hyperbolic secant adiabatic pulse. 60 | 61 | mu * beta becomes the amplitude of the frequency sweep 62 | 63 | Args: 64 | n (int): number of samples (should be a multiple of 4). 65 | beta (float): AM waveform parameter. 66 | mu (float): a constant, determines amplitude of frequency sweep. 
67 | dur (float): pulse time (s). 68 | 69 | Returns: 70 | 2-element tuple containing 71 | 72 | - **a** (*array*): AM waveform. 73 | - **om** (*array*): FM waveform (radians/s). 74 | 75 | References: 76 | Baum, J., Tycko, R. and Pines, A. (1985). 'Broadband and adiabatic 77 | inversion of a two-level system by phase-modulated pulses'. 78 | Phys. Rev. A., 32:3435-3447. 79 | """ 80 | 81 | t = np.arange(-n // 2, n // 2) / n * dur 82 | 83 | a = np.cosh(beta * t) ** -1 84 | om = -mu * beta * np.tanh(beta * t) 85 | 86 | return a, om 87 | 88 | 89 | def wurst(n=512, n_fac=40, bw=40e3, dur=2e-3): 90 | r"""Design a WURST (wideband, uniform rate, smooth truncation) adiabatic 91 | inversion pulse 92 | 93 | Args: 94 | n (int): number of samples (should be a multiple of 4). 95 | n_fac (int): power to exponentiate to within AM term. ~20 or greater is 96 | typical. 97 | bw (float): pulse bandwidth. 98 | dur (float): pulse time (s). 99 | 100 | 101 | Returns: 102 | 2-element tuple containing 103 | 104 | - **a** (*array*): AM waveform. 105 | - **om** (*array*): FM waveform (radians/s). 106 | 107 | References: 108 | Kupce, E. and Freeman, R. (1995). 'Stretched Adiabatic Pulses for 109 | Broadband Spin Inversion'. 110 | J. Magn. Reson. Ser. A., 117:246-256. 111 | """ 112 | 113 | t = np.arange(0, n) * dur / n 114 | 115 | a = 1 - np.power(np.abs(np.cos(np.pi * t / dur)), n_fac) 116 | om = np.linspace(-bw / 2, bw / 2, n) * 2 * np.pi 117 | 118 | return a, om 119 | 120 | 121 | def goia_wurst( 122 | n=512, dur=3.5e-3, f=0.9, n_b1=16, m_grad=4, b1_max=817, bw=20000 123 | ): 124 | r"""Design a GOIA (gradient offset independent adiabaticity) WURST 125 | inversion pulse 126 | 127 | Args: 128 | n (int): number of samples. 129 | dur (float): pulse duration (s). 130 | f (float): [0,1] gradient modulation factor 131 | n_b1 (int): order for B1 modulation 132 | m_grad (int): order for gradient modulation 133 | b1_max (float): maximum b1 (Hz) 134 | bw (float): pulse bandwidth (Hz) 135 | 136 | Returns: 137 | 3-element tuple containing: 138 | 139 | - **a** (*array*): AM waveform (Hz) 140 | - **om** (*array*): FM waveform (Hz) 141 | - **g** (*array*): normalized gradient waveform 142 | 143 | References: 144 | O. C. Andronesi, S. Ramadan, E.-M. Ratai, D. Jennings, C. E. Mountford, 145 | A. G. Sorenson. 146 | J Magn Reson, 203:283-293, 2010. 147 | 148 | """ 149 | 150 | t = np.arange(0, n) * dur / n 151 | 152 | a = b1_max * (1 - np.abs(np.sin(np.pi / 2 * (2 * t / dur - 1))) ** n_b1) 153 | g = (1 - f) + f * np.abs(np.sin(np.pi / 2 * (2 * t / dur - 1))) ** m_grad 154 | om = np.cumsum((a**2) / g) * dur / n 155 | om = om - om[n // 2 + 1] 156 | om = g * om 157 | om = om / np.max(np.abs(om)) * bw / 2 158 | 159 | return a, om, g 160 | 161 | 162 | def bloch_siegert_fm( 163 | n=512, dur=2e-3, b1p=20.0, k=42.0, gamma=2 * np.pi * 42.58 164 | ): 165 | r""" 166 | U-shaped FM waveform for adiabatic Bloch-Siegert :math:`B_1^{+}` mapping 167 | and spatial encoding. 168 | 169 | Args: 170 | n (int): number of time points 171 | dur (float): duration in seconds 172 | b1p (float): nominal amplitude of constant AM waveform 173 | k (float): design parameter that affects max in-band 174 | perturbation 175 | gamma (float): gyromagnetic ratio 176 | 177 | Returns: 178 | om (array): FM waveform (radians/s). 179 | 180 | References: 181 | M. M. Khalighi, B. K. Rutt, and A. B. Kerr. 182 | Adiabatic RF pulse design for Bloch-Siegert B1+ mapping. 183 | Magn Reson Med, 70(3):829–835, 2013. 184 | 185 | M. Jankiewicz, J. C. Gore, and W. A. Grissom. 
186 | Improved encoding pulses for Bloch-Siegert B1+ mapping. 187 | J Magn Reson, 226:79–87, 2013. 188 | 189 | """ 190 | 191 | t = np.arange(1, n // 2) * dur / n 192 | 193 | om = gamma * b1p / np.sqrt((1 - gamma * b1p / k * t) ** -2 - 1) 194 | om = np.concatenate((om, om[::-1])) 195 | 196 | return om 197 | -------------------------------------------------------------------------------- /sigpy/mri/rf/io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI waveform import/export files. 3 | """ 4 | 5 | import struct 6 | 7 | import numpy as np 8 | 9 | __all__ = ["signa", "ge_rf_params", "philips_rf_params", "siemens_rf"] 10 | 11 | 12 | def siemens_rf( 13 | pulse, rfbw, rfdurms, pulsename, minslice=0.5, maxslice=320.0, comment=None 14 | ): 15 | """Write a .pta text file for Siemens PulseTool. 16 | 17 | Args: 18 | pulse (array): complex-valued RF pulse array with maximum of 4096 19 | points. 20 | rfbw (float): bandwidth of RF pulse in Hz 21 | rfdurms (float): duration of RF pulse in ms 22 | pulsename (string): '.', e.g. 'Sigpy.SincPulse' 23 | minslice (float): minimum slice thickness [mm] 24 | maxslice (float): maximum slice thickness [mm] 25 | comment (string): a comment that can be seen in Siemens PulseTool 26 | 27 | Note this has only been tested on MAGNETOM Verio running (VB17) 28 | 29 | Open pulsetool from the IDEA command line. Open the extrf.dat file and add 30 | this .pta file using the import function 31 | 32 | Recommended to make a copy and renaming extrf.dat prior to making changes. 33 | 34 | After saving a new pulse to _extrf.dat and copying it to 35 | the scanner, you will need to re-boot the host for it to load changes. 36 | 37 | """ 38 | 39 | # get the number of points in RF waveform 40 | npts = pulse.size 41 | assert npts <= 4096, ( 42 | "RF pulse must have less than 4096 points for" " Siemens VB17" 43 | ) 44 | 45 | if comment is None: 46 | comment = "" 47 | 48 | # Calculate reference gradient value. 49 | # This is necessary for proper calculation of slice-select gradient 50 | # amplitude using the .getGSAmplitude() method for the external RF class. 51 | # See the IDEA documentation for more details on this. 52 | refgrad = 1000.0 * rfbw * (rfdurms / 5.12) / (42.577e06 * (10.0 / 1000.0)) 53 | 54 | rffile = open(pulsename + ".pta", "w") 55 | rffile.write("PULSENAME: {}\n".format(pulsename)) 56 | rffile.write("COMMENT: {}\n".format(comment)) 57 | rffile.write("REFGRAD: {:6.5f}\n".format(refgrad)) 58 | rffile.write("MINSLICE: {:6.5f}\n".format(minslice)) 59 | rffile.write("MAXSLICE: {:6.5f}\n".format(maxslice)) 60 | 61 | # the following are related to SAR calcs and will be calculated by 62 | # PulseTool upon loading the pulse 63 | rffile.write("AMPINT: \n") 64 | rffile.write("POWERINT: \n") 65 | rffile.write("ABSINT: \n\n") 66 | 67 | # magnitude must be between 0 and 1 68 | mxmag = np.max(np.abs(pulse)) 69 | for n in range(npts): 70 | mag = np.abs(pulse[n]) / mxmag # magnitude at current point 71 | mag = np.squeeze(mag) 72 | pha = np.angle(pulse[n]) # phase at current point 73 | pha = np.squeeze(pha) 74 | rffile.write("{:10.9f}\t{:10.9f}\t; ({:d})\n".format(mag, pha, n)) 75 | rffile.close() 76 | 77 | 78 | def signa(wav, filename, scale=-1): 79 | """Write a binary waveform in the GE format. 80 | 81 | Args: 82 | wav (array): waveform (gradient or RF), may be complex-valued. 83 | filename (string): filename to write to. 
84 | scale (float): scaling factor to apply (default = waveform's max) 85 | 86 | Adapted from John Pauly's RF Tools signa() MATLAB function 87 | 88 | """ 89 | 90 | wmax = int("7ffe", 16) 91 | 92 | if not np.iscomplexobj(wav): 93 | if scale == -1: 94 | scale = 1 / np.max(np.abs(wav)) 95 | 96 | # scale up to fit in a short integer 97 | wav = wav * scale * wmax 98 | 99 | # mask off low bit, since it would be an EOS otherwise 100 | wav = 2 * np.round(wav / 2) 101 | 102 | fid = open(filename, "wb") 103 | 104 | for x in np.nditer(wav): 105 | fid.write(struct.pack(">h", int(x.item()))) 106 | 107 | fid.close() 108 | 109 | else: 110 | if scale == -1: 111 | scale = 1 / np.max( 112 | (np.max(np.abs(np.real(wav))), np.max(np.abs(np.imag(wav)))) 113 | ) 114 | 115 | # scale up to fit in a short integer 116 | wav = wav * scale * wmax 117 | 118 | # mask off low bit, since it would be an EOS otherwise 119 | wav = 2 * np.round(wav / 2) 120 | 121 | fid = open(filename + ".r", "wb") 122 | 123 | for x in np.nditer(wav): 124 | fid.write(struct.pack(">h", int(np.real(x)))) 125 | 126 | fid.close() 127 | 128 | fid = open(filename + ".i", "wb") 129 | 130 | for x in np.nditer(wav): 131 | fid.write(struct.pack(">h", int(np.imag(x)))) 132 | 133 | fid.close() 134 | 135 | 136 | def ge_rf_params(rf, dt=4e-6): 137 | """Calculate RF pulse parameters for deployment 138 | on a GE scanner. 139 | 140 | Args: 141 | rf (array): RF pulse samples 142 | dt (scalar): RF dwell time (seconds) 143 | 144 | Adapted from Adam Kerr's rf_save() MATLAB function 145 | 146 | """ 147 | 148 | print("GE RF Pulse Parameters:") 149 | 150 | n = len(rf) 151 | rfn = rf / np.max(np.abs(rf)) 152 | 153 | abswidth = np.sum(np.abs(rfn)) / n 154 | print("abswidth = ", abswidth) 155 | 156 | effwidth = np.sum(np.abs(rfn) ** 2) / n 157 | print("effwidth = ", effwidth) 158 | 159 | print("area = ", abswidth) 160 | 161 | pon = np.abs(rfn) > 0 162 | temp_pw = 0 163 | max_pw = 0 164 | for i in range(0, len(rfn)): 165 | if pon[i] == 0 & temp_pw > 0: 166 | max_pw = np.max(max_pw, temp_pw) 167 | temp_pw = 0 168 | max_pw = max_pw / n 169 | 170 | dty_cyc = np.sum(np.abs(rfn) > 0.2236) / n 171 | if dty_cyc < max_pw: 172 | dty_cyc = max_pw 173 | print("dtycyc = ", dty_cyc) 174 | print("maxpw = ", max_pw) 175 | 176 | max_b1 = np.max(np.abs(rf)) 177 | print("max_b1 = ", max_b1) 178 | 179 | int_b1_sqr = np.sum(np.abs(rf) ** 2) * dt * 1e3 180 | print("int_b1_sqr = ", int_b1_sqr) 181 | 182 | rms_b1 = np.sqrt(np.sum(np.abs(rf) ** 2)) / n 183 | print("max_rms_b1 = ", rms_b1) 184 | 185 | 186 | def philips_rf_params(rf): 187 | """Calculate RF pulse parameters for deployment 188 | on a Philips scanner. 
189 | 190 | Args: 191 | rf (array): RF pulse samples (assumed real-valued) 192 | 193 | """ 194 | 195 | print("Philips RF Pulse Parameters") 196 | 197 | n = len(rf) 198 | rfn = rf / np.max(np.abs(rf)) 199 | 200 | am_c_teff = np.sum(rfn * 32767) / (32767 * n) 201 | print("am_c_teff = ", am_c_teff) 202 | 203 | am_c_trms = np.sum((rfn * 32767) ** 2) / (32767**2 * n) 204 | print("am_c_trms = ", am_c_trms) 205 | 206 | am_c_tabs = np.sum(np.abs(rfn) * 32767) / (32767 * n) 207 | print("am_c_tabs = ", am_c_tabs) 208 | 209 | # assume that the isodelay point is at the peak 210 | am_c_sym = np.argmax(np.abs(rfn)) / n 211 | print("am_c_sym = ", am_c_sym) 212 | -------------------------------------------------------------------------------- /sigpy/mri/rf/linop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI pulse-design-specific linear operators. 3 | """ 4 | import sigpy as sp 5 | from sigpy import backend 6 | 7 | 8 | def PtxSpatialExplicit(sens, coord, dt, img_shape, b0=None, ret_array=False): 9 | """Explicit spatial-domain pulse design linear operator. 10 | Linear operator relates rf pulses to desired magnetization. 11 | Equivalent matrix has dimensions [Ns Nt]. 12 | 13 | Args: 14 | sens (array): sensitivity maps. [nc dim dim] 15 | coord (None or array): coordinates. [nt 2] 16 | dt (float): hardware sampling dt. 17 | img_shape (None or tuple): image shape. 18 | b0 (array): 2D array, B0 inhomogeneity map. 19 | ret_array (bool): if true, return explicit numpy array. 20 | Else return linop. 21 | 22 | Returns: 23 | SigPy linop with A.repr_string 'pTx spatial explicit', or numpy array 24 | if selected with 'ret_array' 25 | 26 | 27 | References: 28 | Grissom, W., Yip, C., Zhang, Z., Stenger, V. A., Fessler, J. A. 29 | & Noll, D. C.(2006). 30 | Spatial Domain Method for the Design of RF Pulses in Multicoil 31 | Parallel Excitation. Magnetic resonance in medicine, 56, 620-629. 
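# Hedged usage sketch for PtxSpatialExplicit above, a toy 2-channel 2D problem;
# the all-zeros trajectory is used only to keep the sketch short:
import numpy as np
import sigpy.mri as mr
from sigpy.mri.rf.linop import PtxSpatialExplicit

dim, nc, nt, dt = 32, 2, 500, 4e-6
sens = mr.birdcage_maps((nc, dim, dim))
coord = np.zeros((nt, 2))                  # illustrative excitation k-space trajectory
A = PtxSpatialExplicit(sens, coord, dt, img_shape=(dim, dim))
# A maps an (nc, nt) pulse array to a (dim, dim) excitation pattern; A.H can
# be used for adjoint-based (e.g., least-squares) pulse design.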
32 | """ 33 | three_d = False 34 | if len(img_shape) >= 3: 35 | three_d = True 36 | 37 | device = backend.get_device(sens) 38 | xp = device.xp 39 | with device: 40 | nc = sens.shape[0] 41 | dur = dt * coord.shape[0] # duration of pulse, in s 42 | 43 | # create time vector 44 | t = xp.expand_dims(xp.linspace(0, dur, coord.shape[0]), axis=1) 45 | 46 | # row-major order 47 | # x L to R, y T to B 48 | x_ = xp.linspace( 49 | -img_shape[0] / 2, img_shape[0] - img_shape[0] / 2, img_shape[0] 50 | ) 51 | y_ = xp.linspace( 52 | img_shape[1] / 2, -(img_shape[1] - img_shape[1] / 2), img_shape[1] 53 | ) 54 | if three_d: 55 | z_ = xp.linspace( 56 | -img_shape[2] / 2, 57 | img_shape[2] - img_shape[2] / 2, 58 | img_shape[2], 59 | ) 60 | x, y, z = xp.meshgrid(x_, y_, z_, indexing="ij") 61 | else: 62 | x, y = xp.meshgrid(x_, y_, indexing="ij") 63 | 64 | # create explicit Ns * Nt system matrix, for 3d or 2d problem 65 | if three_d: 66 | if b0 is None: 67 | AExplicit = xp.exp( 68 | 1j 69 | * ( 70 | xp.outer(x.flatten(), coord[:, 0]) 71 | + xp.outer(y.flatten(), coord[:, 1]) 72 | + xp.outer(z.flatten(), coord[:, 2]) 73 | ) 74 | ) 75 | else: 76 | AExplicit = xp.exp( 77 | 1j * 2 * xp.pi * xp.transpose(b0.flatten() * (t - dur)) 78 | + 1j 79 | * ( 80 | xp.outer(x.flatten(), coord[:, 0]) 81 | + xp.outer(y.flatten(), coord[:, 1]) 82 | + xp.outer(z.flatten(), coord[:, 2]) 83 | ) 84 | ) 85 | else: 86 | if b0 is None: 87 | AExplicit = xp.exp( 88 | 1j 89 | * ( 90 | xp.outer(x.flatten(), coord[:, 0]) 91 | + xp.outer(y.flatten(), coord[:, 1]) 92 | ) 93 | ) 94 | else: 95 | AExplicit = xp.exp( 96 | 1j * 2 * xp.pi * xp.transpose(b0.flatten() * (t - dur)) 97 | + 1j 98 | * ( 99 | xp.outer(x.flatten(), coord[:, 0]) 100 | + xp.outer(y.flatten(), coord[:, 1]) 101 | ) 102 | ) 103 | 104 | # add sensitivities to system matrix 105 | AFullExplicit = xp.empty(AExplicit.shape) 106 | for ii in range(nc): 107 | if three_d: 108 | tmp = xp.squeeze(sens[ii, :, :, :]).flatten() 109 | else: 110 | tmp = sens[ii, :, :].flatten() 111 | D = xp.transpose(xp.tile(tmp, [coord.shape[0], 1])) 112 | AFullExplicit = xp.concatenate( 113 | (AFullExplicit, D * AExplicit), axis=1 114 | ) 115 | 116 | # remove 1st empty AExplicit entries 117 | AFullExplicit = AFullExplicit[:, coord.shape[0] :] 118 | A = sp.linop.MatMul((coord.shape[0] * nc, 1), AFullExplicit) 119 | 120 | # Finally, adjustment of input/output dimensions to be consistent with 121 | # the existing Sense linop operator. [nc x nt] in, [dim x dim] out 122 | Ro = sp.linop.Reshape(ishape=A.oshape, oshape=sens.shape[1:]) 123 | Ri = sp.linop.Reshape( 124 | ishape=(nc, coord.shape[0]), oshape=(coord.shape[0] * nc, 1) 125 | ) 126 | A = Ro * A * Ri 127 | 128 | A.repr_str = "pTx spatial explicit" 129 | 130 | # output a sigpy linop or a numpy array 131 | if ret_array: 132 | return A.linops[1].mat 133 | else: 134 | return A 135 | -------------------------------------------------------------------------------- /sigpy/mri/rf/optcont.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Optimal Control Pulse Design functions. 3 | """ 4 | from sigpy import backend 5 | 6 | __all__ = ["blochsim", "deriv"] 7 | 8 | 9 | def blochsim(rf, x, g): 10 | r"""1D RF pulse simulation, with simultaneous RF + gradient rotations. 11 | Assume x has inverse spatial units of g, and g has gamma*dt applied and 12 | assume x = [...,Ndim], g = [Ndim,Nt]. 13 | 14 | Args: 15 | rf (array): rf waveform input. 16 | x (array): spatial locations. 
17 | g (array): gradient waveform. 18 | 19 | Returns: 20 | array: SLR alpha parameter 21 | array: SLR beta parameter 22 | """ 23 | 24 | device = backend.get_device(rf) 25 | xp = device.xp 26 | with device: 27 | a = xp.ones(xp.shape(x)[0], dtype=xp.complex128) 28 | b = xp.zeros(xp.shape(x)[0], dtype=xp.complex128) 29 | for mm in range(0, xp.size(rf), 1): # loop over time 30 | # apply RF 31 | c = xp.cos(xp.abs(rf[mm]) / 2) 32 | s = 1j * xp.exp(1j * xp.angle(rf[mm])) * xp.sin(xp.abs(rf[mm]) / 2) 33 | at = a * c - b * xp.conj(s) 34 | bt = a * s + b * c 35 | a = at 36 | b = bt 37 | 38 | # apply gradient 39 | if g.ndim > 1: 40 | z = xp.exp(-1j * x @ g[mm, :]) 41 | else: 42 | z = xp.exp(-1j * x * g[mm]) 43 | b = b * z 44 | 45 | # apply total phase accrual 46 | if g.ndim > 1: 47 | z = xp.exp(1j / 2 * x @ xp.sum(g, 0)) 48 | else: 49 | z = xp.exp(1j / 2 * x * xp.sum(g)) 50 | a = a * z 51 | b = b * z 52 | 53 | return a, b 54 | 55 | 56 | def deriv(rf, x, g, auxa, auxb, af, bf): 57 | r"""1D RF pulse simulation, with simultaneous RF + gradient rotations. 58 | 59 | 'rf', 'g', and 'x' should have consistent units. 60 | 61 | Args: 62 | rf (array): rf waveform input. 63 | x (array): spatial locations. 64 | g (array): gradient waveform. 65 | auxa (None or array): auxa 66 | auxb (array): auxb 67 | af (array): forward sim a. 68 | bf( array): forward sim b. 69 | 70 | Returns: 71 | array: SLR alpha parameter 72 | array: SLR beta parameter 73 | """ 74 | 75 | device = backend.get_device(rf) 76 | xp = device.xp 77 | with device: 78 | drf = xp.zeros(xp.shape(rf), dtype=xp.complex128) 79 | ar = xp.ones(xp.shape(af), dtype=xp.complex128) 80 | br = xp.zeros(xp.shape(bf), dtype=xp.complex128) 81 | 82 | for mm in range(xp.size(rf) - 1, -1, -1): 83 | # calculate gradient blip phase 84 | if g.ndim > 1: 85 | z = xp.exp(1j / 2 * x @ g[mm, :]) 86 | else: 87 | z = xp.exp(1j / 2 * x * g[mm]) 88 | 89 | # strip off gradient blip from forward sim 90 | af = af * xp.conj(z) 91 | bf = bf * z 92 | 93 | # add gradient blip to backward sim 94 | ar = ar * z 95 | br = br * z 96 | 97 | # strip off the curent rf rotation from forward sim 98 | c = xp.cos(xp.abs(rf[mm]) / 2) 99 | s = 1j * xp.exp(1j * xp.angle(rf[mm])) * xp.sin(xp.abs(rf[mm]) / 2) 100 | at = af * c + bf * xp.conj(s) 101 | bt = -af * s + bf * c 102 | af = at 103 | bf = bt 104 | 105 | # calculate derivatives wrt rf[mm] 106 | db1 = xp.conj(1j / 2 * br * bf) * auxb 107 | db2 = xp.conj(1j / 2 * af) * ar * auxb 108 | drf[mm] = xp.sum(db2 + xp.conj(db1)) 109 | if auxa is not None: 110 | da1 = xp.conj(1j / 2 * bf * ar) * auxa 111 | da2 = 1j / 2 * xp.conj(af) * br * auxa 112 | drf[mm] += xp.sum(da2 + xp.conj(da1)) 113 | 114 | # add current rf rotation to backward sim 115 | art = ar * c - xp.conj(br) * s 116 | brt = br * c + xp.conj(ar) * s 117 | ar = art 118 | br = brt 119 | 120 | return drf 121 | -------------------------------------------------------------------------------- /sigpy/mri/rf/shim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI RF shimming. 3 | """ 4 | 5 | import numpy as np 6 | 7 | import sigpy as sp 8 | from sigpy import backend 9 | from sigpy.mri import rf as rf 10 | 11 | __all__ = ["calc_shims", "init_optimal_spectral", "init_circ_polar"] 12 | 13 | 14 | def calc_shims(shim_roi, sens, x0, dt, lamb=0, max_iter=50): 15 | """RF shim designer. Uses the Gerchberg Saxton algorithm. 16 | 17 | Args: 18 | shim_roi (array): region within volume to be shimmed. Mask of 1's and 19 | 0's. 
[dim_x dim_y dim_z] 20 | sens (array): sensitivity maps. [Nc dim_x dim_y dim_z] 21 | x0 (array) initial guess for shim values. [Nc 1] 22 | dt (float): hardware sampling dwell time. 23 | lamb (float): regularization term. 24 | max_iter (int): max number of iterations. 25 | 26 | Returns: 27 | Vector of complex shim weights. 28 | """ 29 | 30 | k1 = np.expand_dims(np.array((0, 0, 0)), 0) 31 | A = rf.PtxSpatialExplicit( 32 | sens, coord=k1, dt=dt, img_shape=shim_roi.shape, ret_array=False 33 | ) 34 | 35 | alg_method = sp.alg.GerchbergSaxton( 36 | A, shim_roi, x0, max_iter=max_iter, tol=10e-9, lamb=lamb 37 | ) 38 | while not alg_method.done(): 39 | alg_method.update() 40 | 41 | return alg_method.x 42 | 43 | 44 | def init_optimal_spectral(A, sens, preproc=False): 45 | """Function to return initial shim weights based on an optimal spectral 46 | method, an eigenvector-based method. 47 | 48 | Args: 49 | A (linop): sigpy Linear operator. 50 | sens (array): sensitivity maps. [Nc dim_x dim_y] 51 | preproc (bool): option to apply preprocessing function before \ 52 | finding eigenvectors 53 | 54 | Returns: 55 | Vector of complex shim weights. 56 | 57 | References: 58 | Chandra, R., Zhong, Z., Hontz, J., McCulloch, V., Studer, C., 59 | Goldstein, T. (2017) 'PhasePack: A Phase Retrieval Library.' 60 | arXiv:1711.10175. 61 | """ 62 | device = backend.get_device(sens) 63 | xp = device.xp 64 | with device: 65 | if hasattr(A, "repr_str") and A.repr_str == "pTx spatial explicit": 66 | Anum = A.linops[1].mat 67 | else: 68 | Anum = A 69 | 70 | sens = sens.flatten() 71 | n = Anum.shape[1] 72 | Anumt = xp.transpose(Anum) 73 | 74 | m = sens.size 75 | y = sens**2 76 | 77 | # normalize the measurements 78 | delta = m / n 79 | ymean = y / xp.mean(y) 80 | 81 | # apply pre-processing function 82 | yplus = xp.amax(y) 83 | Y = (1 / m) * Anumt @ Anum 84 | 85 | if preproc: 86 | T = (yplus - 1) / (yplus + xp.sqrt(delta) - 1) 87 | 88 | # unnormalize 89 | T *= ymean 90 | T = xp.transpose(xp.expand_dims(T, axis=1)) 91 | 92 | for mm in range(m): 93 | col = Anum[mm, :] 94 | aat = col * xp.transpose(col) 95 | Y = Y + (1 / m) * T[mm] * aat 96 | 97 | w, v = xp.linalg.eigh(Y) 98 | 99 | return xp.expand_dims(v[:, 0], 1) 100 | 101 | 102 | def init_circ_polar(sens): 103 | """Function to return circularly polarized initial shim weights. Provides 104 | shim weights that set the phase to be even in the middle of the sens 105 | profiles. 106 | 107 | Args: 108 | sens (array): sensitivity maps. [Nc dim_x dim_y] 109 | 110 | Returns: 111 | Vector of complex shim weights. 112 | """ 113 | dim = sens.shape[1] 114 | device = backend.get_device(sens) 115 | xp = device.xp 116 | with device: 117 | # As a rough approximation, assume that the center of sens profile is 118 | # also the center of the object within the profile to be imaged. 119 | phs = xp.angle(sens[:, xp.int(dim / 2), xp.int(dim / 2)]) 120 | phs_wt = xp.exp(-phs * 1j) 121 | 122 | return xp.expand_dims(phs_wt, 1) 123 | -------------------------------------------------------------------------------- /sigpy/mri/rf/sim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """RF Pulse Simulation Functions. 3 | 4 | """ 5 | from sigpy import backend 6 | 7 | __all__ = ["abrm", "abrm_nd", "abrm_hp", "abrm_ptx"] 8 | 9 | 10 | def abrm(rf, x, balanced=False): 11 | r"""1D RF pulse simulation, with simultaneous RF + gradient rotations. 12 | 13 | Args: 14 | rf (array): rf waveform input. 15 | x (array): spatial locations. 
16 | balanced (bool): toggles application of rewinder. 17 | 18 | Returns: 19 | 2-element tuple containing 20 | 21 | - **a** (*array*): SLR alpha parameter. 22 | - **b** (*array*): SLR beta parameter. 23 | 24 | References: 25 | Pauly, J., Le Roux, Patrick., Nishimura, D., and Macovski, A.(1991). 26 | 'Parameter Relations for the Shinnar-LeRoux Selective Excitation 27 | Pulse Design Algorithm'. 28 | IEEE Transactions on Medical Imaging, Vol 10, No 1, 53-65. 29 | """ 30 | 31 | device = backend.get_device(rf) 32 | xp = device.xp 33 | with device: 34 | eps = 1e-16 35 | 36 | g = xp.ones(xp.size(rf)) * 2 * xp.pi / xp.size(rf) 37 | 38 | a = xp.ones(xp.size(x), dtype=xp.complex128) 39 | b = xp.zeros(xp.size(x), dtype=xp.complex128) 40 | for mm in range(xp.size(rf)): 41 | om = x * g[mm] 42 | phi = xp.sqrt(xp.abs(rf[mm]) ** 2 + om**2) + eps 43 | n = xp.column_stack( 44 | (xp.real(rf[mm]) / phi, xp.imag(rf[mm]) / phi, om / phi) 45 | ) 46 | av = xp.cos(phi / 2) - 1j * n[:, 2] * xp.sin(phi / 2) 47 | bv = -1j * (n[:, 0] + 1j * n[:, 1]) * xp.sin(phi / 2) 48 | at = av * a - xp.conj(bv) * b 49 | bt = bv * a + xp.conj(av) * b 50 | a = at 51 | b = bt 52 | 53 | if balanced: # apply a rewinder 54 | g = -2 * xp.pi / 2 55 | om = x * g 56 | phi = xp.abs(om) + eps 57 | nz = om / phi 58 | av = xp.cos(phi / 2) - 1j * nz * xp.sin(phi / 2) 59 | a = av * a 60 | b = xp.conj(av) * b 61 | 62 | return a, b 63 | 64 | 65 | def abrm_nd(rf, x, g): 66 | r"""N-dim RF pulse simulation 67 | 68 | Assumes that x has inverse spatial units of g, and g has gamma*dt applied. 69 | 70 | Assumes dimensions x = [...,Ndim], g = [Ndim,Nt]. 71 | 72 | Args: 73 | rf (array): rf waveform input. 74 | x (array): spatial locations. 75 | g (array): gradient array. 76 | 77 | Returns: 78 | 2-element tuple containing 79 | 80 | - **a** (*array*): SLR alpha parameter. 81 | - **b** (*array*): SLR beta parameter. 82 | 83 | References: 84 | Pauly, J., Le Roux, Patrick., Nishimura, D., and Macovski, A.(1991). 85 | 'Parameter Relations for the Shinnar-LeRoux Selective Excitation 86 | Pulse Design Algorithm'. 87 | IEEE Transactions on Medical Imaging, Vol 10, No 1, 53-65. 88 | """ 89 | 90 | device = backend.get_device(rf) 91 | xp = device.xp 92 | with device: 93 | eps = 1e-16 94 | 95 | a = xp.ones(xp.shape(x)[0], dtype=xp.complex128) 96 | b = xp.zeros(xp.shape(x)[0], dtype=xp.complex128) 97 | for mm in range(xp.size(rf)): 98 | om = x @ g[mm, :] 99 | phi = xp.sqrt(xp.abs(rf[mm]) ** 2 + om**2) 100 | n = xp.column_stack( 101 | ( 102 | xp.real(rf[mm]) / (phi + eps), 103 | xp.imag(rf[mm]) / (phi + eps), 104 | om / (phi + eps), 105 | ) 106 | ) 107 | av = xp.cos(phi / 2) - 1j * n[:, 2] * xp.sin(phi / 2) 108 | bv = -1j * (n[:, 0] + 1j * n[:, 1]) * xp.sin(phi / 2) 109 | at = av * a - xp.conj(bv) * b 110 | bt = bv * a + xp.conj(av) * b 111 | a = at 112 | b = bt 113 | 114 | return a, b 115 | 116 | 117 | def abrm_hp(rf, gamgdt, xx, dom0dt=0): 118 | r"""1D RF pulse simulation, with non-simultaneous RF + gradient rotations. 119 | 120 | Args: 121 | rf (array): rf pulse samples in radians. 122 | gamdt (array): gradient samples in radians/(units of xx). 123 | xx (array): spatial locations. 124 | dom0dt (float): off-resonance phase in radians. 125 | 126 | Returns: 127 | 2-element tuple containing 128 | 129 | - **a** (*array*): SLR alpha parameter. 130 | - **b** (*array*): SLR beta parameter. 131 | 132 | References: 133 | Pauly, J., Le Roux, Patrick., Nishimura, D., and Macovski, A.(1991). 
134 | 'Parameter Relations for the Shinnar-LeRoux Selective Excitation 135 | Pulse Design Algorithm'. 136 | IEEE Transactions on Medical Imaging, Vol 10, No 1, 53-65. 137 | """ 138 | 139 | device = backend.get_device(rf) 140 | xp = device.xp 141 | with device: 142 | Ns = xp.shape(xx) 143 | Ns = Ns[0] # Ns: # of spatial locs 144 | Nt = xp.shape(gamgdt) 145 | Nt = Nt[0] # Nt: # time points 146 | 147 | a = xp.ones((Ns,)) 148 | b = xp.zeros((Ns,)) 149 | 150 | for ii in xp.arange(Nt): 151 | # apply phase accural 152 | z = xp.exp(-1j * (xx * gamgdt[ii,] + dom0dt)) 153 | b = b * z 154 | 155 | # apply rf 156 | C = xp.cos(xp.abs(rf[ii]) / 2) 157 | S = 1j * xp.exp(1j * xp.angle(rf[ii])) * xp.sin(xp.abs(rf[ii]) / 2) 158 | at = a * C - b * xp.conj(S) 159 | bt = a * S + b * C 160 | 161 | a = at 162 | b = bt 163 | 164 | z = xp.exp(1j / 2 * (xx * xp.sum(gamgdt, axis=0) + Nt * dom0dt)) 165 | a = a * z 166 | b = b * z 167 | 168 | return a, b 169 | 170 | 171 | def abrm_ptx(b1, x, g, dt, fmap=None, sens=None): 172 | r"""N-dim RF pulse simulation 173 | 174 | Assumes that x has inverse spatial units of g, and g has gamma*dt applied. 175 | 176 | Assumes dimensions rf = [Nc, Nt], x = [...,Ndim], g = [Ndim,Nt], and 177 | sens = [Nc, dim, dim]. 178 | 179 | Args: 180 | b1 (array): rf waveform input samples in radians. 181 | x (array): spatial locations (m). 182 | g (array): gradient array (mT/m with gamma*dt applied). 183 | dt (float): hardware dwell time (s). 184 | fmap (array): off-resonance map (Hz). 185 | sens (array or None): B1+ sensitivity matrix. If None, creates matrix 186 | of 1's. Input size [Nc dim dim] 187 | 188 | 189 | Returns: 190 | 4-element tuple containing 191 | 192 | - **a** (*array*): SLR alpha parameter. 193 | - **b** (*array*): SLR beta parameter. 194 | - **m** (*array*): transverse magnetization. 195 | - **mz** (*array*): longitudinal magnetization. 196 | 197 | References: 198 | Pauly, J., Le Roux, Patrick., Nishimura, D., and Macovski, A.(1991). 199 | 'Parameter Relations for the Shinnar-LeRoux Selective Excitation 200 | Pulse Design Algorithm'. 201 | IEEE Transactions on Medical Imaging, Vol 10, No 1, 53-65. 202 | 203 | Grissom, W., Xu, D., Kerr, A., Fessler, J. and Noll, D. (2009). 'Fast 204 | large-tip-angle multidimensional and parallel RF pulse design in MRI' 205 | IEEE Trans Med Imaging, Vol 28, No 10, 1548-59. 
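# Hedged sketch: simulating a small-tip pulse with abrm(), the single-channel
# simulator defined earlier in this module, and converting the Cayley-Klein
# parameters to transverse magnetization (the 2*conj(a)*b relation is the
# usual SLR convention, assumed here):
import numpy as np
from sigpy.mri.rf.sim import abrm

n = 256
rf_pulse = np.hamming(n) * (np.pi / 6) / np.sum(np.hamming(n))  # ~30 degree flip
x = np.linspace(-10, 10, 400)            # spatial/frequency axis
a, b = abrm(rf_pulse, x)
mxy = 2 * np.conj(a) * b                 # excitation profile from Mz = 1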
206 | """ 207 | 208 | device = backend.get_device(b1) 209 | xp = device.xp 210 | with device: 211 | gam = 267.522 * 1e6 / 1000 # rad/s/mT 212 | 213 | dim = int(xp.sqrt(x.shape[0])) 214 | Ns = dim * dim 215 | Nc = b1.shape[0] 216 | Nt = b1.shape[1] 217 | dim = int(xp.sqrt(x.shape[0])) 218 | 219 | if sens is None: 220 | sens = xp.ones((dim * dim, Nc)) 221 | else: 222 | sens = xp.transpose(sens) 223 | sens = xp.reshape(sens, (dim * dim, Nc)) 224 | 225 | bxy = sens @ b1 226 | bz = x @ xp.transpose(g) 227 | 228 | if fmap is not None and xp.sum(xp.abs(fmap)) != 0: 229 | rep_b0 = xp.repeat(xp.expand_dims(fmap.flatten(), 0), Nt, axis=0) 230 | bz += xp.transpose(rep_b0 / gam * 2 * xp.pi) 231 | 232 | statea = xp.ones((Ns, 1)) 233 | stateb = xp.zeros((Ns, 1)) 234 | a = xp.ones(xp.shape(x)[0], dtype=xp.complex128) 235 | b = xp.zeros(xp.shape(x)[0], dtype=xp.complex128) 236 | for mm in range(Nt): 237 | phi = dt * gam * xp.sqrt(xp.abs(bxy[:, mm]) ** 2 + bz[:, mm] ** 2) 238 | with xp.errstate(divide="ignore"): 239 | normfact = dt * gam * (phi**-1) 240 | normfact[xp.isinf(normfact)] = 0 241 | nxy = normfact * bxy[:, mm] 242 | nxy[xp.isinf(nxy)] = 0 243 | nz = normfact * bz[:, mm] 244 | nz[xp.isinf(nz)] = 0 245 | cp = xp.cos(phi / 2) 246 | sp = xp.sin(phi / 2) 247 | alpha = xp.expand_dims(cp + 1j * nz * sp, 1) 248 | beta = xp.expand_dims(1j * xp.conj(nxy) * sp, 1) 249 | 250 | tmpa = xp.multiply(alpha, statea) + xp.multiply(beta, stateb) 251 | tmpb = -xp.conj(beta) * statea + xp.conj(alpha) * stateb 252 | 253 | statea, stateb = tmpa, tmpb 254 | 255 | # NOT returning all states: 256 | a = statea 257 | b = -xp.conj(stateb) 258 | 259 | mxy0 = 0 + 1j * 0 260 | mz0 = 1 261 | m = mz0 * xp.conj(statea) * stateb 262 | m += mxy0 * xp.conj(statea) ** 2 263 | m -= xp.conj(mxy0) * (stateb**2) 264 | mz = mz0 * (statea * xp.conj(statea) - stateb * xp.conj(stateb)) 265 | mz += 2 * xp.real( 266 | mxy0 * xp.conj(statea) * xp.negative(xp.conj(stateb)) 267 | ) 268 | 269 | return a, b, m, mz 270 | -------------------------------------------------------------------------------- /sigpy/mri/rf/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI RF utilities. 3 | """ 4 | 5 | import numpy as np 6 | 7 | __all__ = ["dinf"] 8 | 9 | 10 | def dinf(d1=0.01, d2=0.01): 11 | """Calculate D infinity for a linear phase filter. 12 | 13 | Args: 14 | d1 (float): passband ripple level in M0**-1. 15 | d2 (float): stopband ripple level in M0**-1. 16 | 17 | Returns: 18 | float: D infinity. 19 | 20 | References: 21 | Pauly J, Le Roux P, Nishimra D, Macovski A. Parameter relations for the 22 | Shinnar-Le Roux selective excitation pulse design algorithm. 23 | IEEE Tr Medical Imaging 1991; 10(1):53-65. 24 | 25 | """ 26 | 27 | a1 = 5.309e-3 28 | a2 = 7.114e-2 29 | a3 = -4.761e-1 30 | a4 = -2.66e-3 31 | a5 = -5.941e-1 32 | a6 = -4.278e-1 33 | 34 | l10d1 = np.log10(d1) 35 | l10d2 = np.log10(d2) 36 | 37 | d = (a1 * l10d1 * l10d1 + a2 * l10d1 + a3) * l10d2 + ( 38 | a4 * l10d1 * l10d1 + a5 * l10d1 + a6 39 | ) 40 | 41 | return d 42 | -------------------------------------------------------------------------------- /sigpy/mri/sim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI simulation functions. 3 | """ 4 | import numpy as np 5 | 6 | __all__ = ["birdcage_maps"] 7 | 8 | 9 | def birdcage_maps(shape, r=1.5, nzz=8, dtype=np.complex128): 10 | """Simulates birdcage coil sensitivies. 
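# Hedged sketch: dinf() above is commonly combined with the time-bandwidth
# product to obtain the fractional transition width of an SLR filter
# (relation taken from the cited SLR paper):
from sigpy.mri.rf.util import dinf

d1, d2, tb = 0.01, 0.01, 4     # passband/stopband ripple, time-bandwidth product
w = dinf(d1, d2) / tb          # fractional transition width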
11 | 12 | Args: 13 | shape (tuple of ints): sensitivity maps shape, 14 | can be of length 3, and 4. 15 | r (float): relative radius of birdcage. 16 | nzz (int): number of coils per ring. 17 | dtype (Dtype): data type. 18 | 19 | Returns: 20 | array. 21 | """ 22 | 23 | if len(shape) == 3: 24 | nc, ny, nx = shape 25 | c, y, x = np.mgrid[:nc, :ny, :nx] 26 | 27 | coilx = r * np.cos(c * (2 * np.pi / nc)) 28 | coily = r * np.sin(c * (2 * np.pi / nc)) 29 | coil_phs = -c * (2 * np.pi / nc) 30 | 31 | x_co = (x - nx / 2.0) / (nx / 2.0) - coilx 32 | y_co = (y - ny / 2.0) / (ny / 2.0) - coily 33 | rr = np.sqrt(x_co**2 + y_co**2) 34 | phi = np.arctan2(x_co, -y_co) + coil_phs 35 | out = (1.0 / rr) * np.exp(1j * phi) 36 | 37 | elif len(shape) == 4: 38 | nc, nz, ny, nx = shape 39 | c, z, y, x = np.mgrid[:nc, :nz, :ny, :nx] 40 | 41 | coilx = r * np.cos(c * (2 * np.pi / nzz)) 42 | coily = r * np.sin(c * (2 * np.pi / nzz)) 43 | coilz = np.floor(c / nzz) - 0.5 * (np.ceil(nc / nzz) - 1) 44 | coil_phs = -(c + np.floor(c / nzz)) * (2 * np.pi / nzz) 45 | 46 | x_co = (x - nx / 2.0) / (nx / 2.0) - coilx 47 | y_co = (y - ny / 2.0) / (ny / 2.0) - coily 48 | z_co = (z - nz / 2.0) / (nz / 2.0) - coilz 49 | rr = (x_co**2 + y_co**2 + z_co**2) ** 0.5 50 | phi = np.arctan2(x_co, -y_co) + coil_phs 51 | out = (1 / rr) * np.exp(1j * phi) 52 | else: 53 | raise ValueError("Can only generate shape with length 3 or 4") 54 | 55 | rss = sum(abs(out) ** 2, 0) ** 0.5 56 | out /= rss 57 | 58 | return out.astype(dtype) 59 | -------------------------------------------------------------------------------- /sigpy/mri/util.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """MRI utilities. 3 | """ 4 | import numpy as np 5 | 6 | import sigpy as sp 7 | 8 | __all__ = ["get_cov", "whiten", "tseg_off_res_b_ct", "apply_tseg"] 9 | 10 | 11 | def get_cov(noise): 12 | """Get covariance matrix from noise measurements. 13 | 14 | Args: 15 | noise (array): Noise measurements of shape [num_coils, ...] 16 | 17 | Returns: 18 | array: num_coils x num_coils covariance matrix. 19 | 20 | """ 21 | num_coils = noise.shape[0] 22 | X = noise.reshape([num_coils, -1]) 23 | X -= np.mean(X, axis=-1, keepdims=True) 24 | cov = np.matmul(X, X.T.conjugate()) 25 | 26 | return cov 27 | 28 | 29 | def whiten(ksp, cov): 30 | """Whitens k-space measurements. 31 | 32 | Args: 33 | ksp (array): k-space measurements of shape [num_coils, ...] 34 | cov (array): num_coils x num_coils covariance matrix. 35 | 36 | Returns: 37 | array: whitened k-space array. 38 | 39 | """ 40 | num_coils = ksp.shape[0] 41 | 42 | x = ksp.reshape([num_coils, -1]) 43 | 44 | L = np.linalg.cholesky(cov) 45 | x_w = np.linalg.solve(L, x) 46 | ksp_w = x_w.reshape(ksp.shape) 47 | 48 | return ksp_w 49 | 50 | 51 | def tseg_off_res_b_ct(b0, bins, lseg, dt, T): 52 | """Creates B and Ct matrices needed for time-segmented off-resonance 53 | compensation. 54 | 55 | Args: 56 | b0 (array): inhomogeneity matrix. 57 | bins (int): number of histogram bins to use. 58 | lseg (int): number of time segments. 59 | dt (float): hardware dwell time (ms). 60 | T (float): length of pulse (ms). 61 | 62 | Returns: 63 | 2-element tuple containing 64 | 65 | - **B** (*array*): temporal interpolator. 66 | - **Ct** (*array*): off-resonance phase at each time segment center. 
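# Hedged sketch of the noise pre-whitening helpers defined above (the noise
# and k-space arrays are synthetic stand-ins):
import numpy as np
from sigpy.mri.util import get_cov, whiten

num_coils = 8
noise = (np.random.randn(num_coils, 1024)
         + 1j * np.random.randn(num_coils, 1024))
ksp = (np.random.randn(num_coils, 64, 64)
       + 1j * np.random.randn(num_coils, 64, 64))

cov = get_cov(noise)        # (8, 8) coil covariance matrix
ksp_w = whiten(ksp, cov)    # whitened k-space, same shape as ksp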
67 | """ 68 | 69 | # create time vector 70 | t = np.linspace(0, T, int(T / dt)) 71 | hist_wt, bin_edges = np.histogram( 72 | np.imag(2j * np.pi * np.concatenate(b0)), bins 73 | ) 74 | 75 | # Build B and Ct 76 | bin_centers = bin_edges[1:] - bin_edges[1] / 2 77 | zk = 0 + 1j * bin_centers 78 | tl = np.linspace(0, lseg, lseg) / lseg * T / 1000 # time seg centers 79 | # calculate off-resonance phase @ each time seg, for hist bins 80 | ch = np.exp(-np.expand_dims(tl, axis=1) @ np.expand_dims(zk, axis=0)) 81 | w = np.diag(np.sqrt(hist_wt)) 82 | p = np.linalg.pinv(w @ np.transpose(ch)) @ w 83 | b = p @ np.exp( 84 | -np.expand_dims(zk, axis=1) @ np.expand_dims(t, axis=0) / 1000 85 | ) 86 | b = np.transpose(b) 87 | b0_v = np.expand_dims(2j * np.pi * np.concatenate(b0), axis=0) 88 | ct = np.transpose(np.exp(-np.expand_dims(tl, axis=1) @ b0_v)) 89 | 90 | return b, ct 91 | 92 | 93 | def apply_tseg(array_in, coord, b, ct, fwd=True): 94 | """Apply the temporal interpolator and phase shift maps calculated 95 | 96 | Args: 97 | array_in (array): array to apply correction to. 98 | coord (array): coordinates for noncartesian trajectories. [Nt 2]. 99 | b (array): temporal interpolator. 100 | ct (array): off-resonance phase at each time segment center. 101 | fwd (Boolean): indicates forward direction (img -> kspace) or 102 | backward (kspace->img) 103 | 104 | Returns: 105 | out (array): array with correction applied. 106 | """ 107 | 108 | # get number of time segments from B input. 109 | lseg = b.shape[1] 110 | dim = array_in.shape[0] 111 | 112 | out = 0 113 | if fwd: 114 | for ii in range(lseg): 115 | ctd = np.reshape(ct[:, ii] * array_in.flatten(), (dim, dim)) 116 | out = out + b[:, ii] * sp.fourier.nufft(ctd, coord * 20) 117 | 118 | else: 119 | for ii in range(lseg): 120 | ctd = np.reshape( 121 | np.conj(ct[:, ii]) * array_in.flatten(), (dim, dim) 122 | ) 123 | out = out + sp.fourier.nufft(ctd, coord * 20) * np.conj(b[:, ii]) 124 | 125 | return np.expand_dims(out, 1) 126 | -------------------------------------------------------------------------------- /sigpy/pytorch.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Functions for interoperability between sigpy and pytorch. 3 | 4 | """ 5 | import numpy as np 6 | 7 | from sigpy import backend, config 8 | 9 | __all__ = ["to_pytorch", "from_pytorch", "to_pytorch_function"] 10 | 11 | 12 | def to_pytorch(array, requires_grad=True): # pragma: no cover 13 | """Zero-copy conversion from numpy/cupy array to pytorch tensor. 14 | 15 | For complex array input, returns a tensor with shape + [2], 16 | where tensor[..., 0] and tensor[..., 1] represent the real 17 | and imaginary. 18 | 19 | Args: 20 | array (numpy/cupy array): input. 21 | requires_grad(bool): Set .requires_grad output tensor 22 | Returns: 23 | PyTorch tensor. 24 | 25 | """ 26 | import torch 27 | from torch.utils.dlpack import from_dlpack 28 | 29 | device = backend.get_device(array) 30 | if not np.issubdtype(array.dtype, np.floating): 31 | with device: 32 | shape = array.shape 33 | array = array.view(dtype=array.real.dtype) 34 | array = array.reshape(shape + (2,)) 35 | 36 | if device == backend.cpu_device: 37 | tensor = torch.from_numpy(array) 38 | else: 39 | tensor = from_dlpack(array.toDlpack()) 40 | 41 | tensor.requires_grad = requires_grad 42 | return tensor.contiguous() 43 | 44 | 45 | def from_pytorch(tensor, iscomplex=False): # pragma: no cover 46 | """Zero-copy conversion from pytorch tensor to numpy/cupy array. 
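# Hedged round-trip sketch for to_pytorch/from_pytorch (requires PyTorch; the
# complex handling follows the shape + [2] convention described above):
import numpy as np
import sigpy as sp

x = np.ones(4, dtype=np.complex64)
t = sp.to_pytorch(x)                      # torch tensor of shape (4, 2)
y = sp.from_pytorch(t, iscomplex=True)    # back to a complex numpy view, shape (4,)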
47 | 48 | If iscomplex, then tensor must have the last dimension as 2, 49 | and the output will be viewed as a complex valued array. 50 | 51 | Args: 52 | tensor (PyTorch tensor): input. 53 | iscomplex (bool): whether input represents complex valued tensor. 54 | 55 | Returns: 56 | Numpy/cupy array. 57 | 58 | """ 59 | from torch.utils.dlpack import to_dlpack 60 | 61 | device = tensor.device 62 | if device.type == "cpu": 63 | output = tensor.detach().contiguous().numpy() 64 | else: 65 | if config.cupy_enabled: 66 | import cupy as cp 67 | 68 | output = cp.fromDlpack(to_dlpack(tensor.contiguous())) 69 | else: 70 | raise TypeError( 71 | "CuPy not installed, " 72 | "but trying to convert GPU PyTorch Tensor." 73 | ) 74 | 75 | if iscomplex: 76 | if output.shape[-1] != 2: 77 | raise ValueError( 78 | "shape[-1] must be 2 when iscomplex is " 79 | "specified, but got {}".format(output.shape) 80 | ) 81 | 82 | with backend.get_device(output): 83 | if output.dtype == np.float32: 84 | output = output.view(np.complex64) 85 | elif output.dtype == np.float64: 86 | output = output.view(np.complex128) 87 | 88 | output = output.reshape(output.shape[:-1]) 89 | 90 | return output 91 | 92 | 93 | def to_pytorch_function( 94 | linop, input_iscomplex=False, output_iscomplex=False 95 | ): # pragma: no cover 96 | """Convert SigPy Linop to PyTorch Function. 97 | 98 | The returned function can be treated as a native 99 | pytorch function performing the linop operator. 100 | The function can be backpropagated, applied on GPU arrays, 101 | and has minimal overhead as the underlying arrays 102 | are shared without copying. 103 | For complex valued input/output, the appropriate options 104 | should be set when calling the function. 105 | 106 | Args: 107 | linop (Linop): linear operator to be converted. 108 | input_iscomplex (bool): whether the PyTorch input 109 | represents complex tensor. 110 | output_iscomplex (bool): whether the PyTorch output 111 | represents complex tensor. 112 | 113 | Returns: 114 | torch.autograd.Function: equivalent PyTorch Function. 115 | 116 | """ 117 | import torch 118 | 119 | class LinopFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, input): 122 | return to_pytorch( 123 | linop(from_pytorch(input, iscomplex=input_iscomplex)) 124 | ) 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | return to_pytorch( 129 | linop.H(from_pytorch(grad_output, iscomplex=output_iscomplex)) 130 | ) 131 | 132 | return LinopFunction 133 | -------------------------------------------------------------------------------- /sigpy/sim.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Functions for simulations. 3 | 4 | """ 5 | import numpy as np 6 | 7 | __all__ = ["shepp_logan"] 8 | 9 | 10 | def shepp_logan(shape, dtype=np.complex128): 11 | """Generates a Shepp Logan phantom with a given shape and dtype. 12 | 13 | Args: 14 | shape (tuple of ints): shape, can be of length 2 or 3. 15 | dtype (Dtype): data type. 16 | 17 | Returns: 18 | array. 
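# Hedged sketch: wrapping a sigpy Linop as a PyTorch autograd Function with
# to_pytorch_function above (requires PyTorch; an FFT Linop is used for brevity):
import sigpy as sp

img = sp.shepp_logan((32, 32))
F = sp.linop.FFT(img.shape)
FFTFunction = sp.to_pytorch_function(F, input_iscomplex=True,
                                     output_iscomplex=True)
x = sp.to_pytorch(img)                    # (32, 32, 2) real-view tensor
y = FFTFunction.apply(x)                  # forward pass; backward uses F.H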
19 | 20 | """ 21 | return phantom(shape, sl_amps, sl_scales, sl_offsets, sl_angles, dtype) 22 | 23 | 24 | sl_amps = [1, -0.8, -0.2, -0.2, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1] 25 | 26 | sl_scales = [ 27 | [0.6900, 0.920, 0.810], # white big 28 | [0.6624, 0.874, 0.780], # gray big 29 | [0.1100, 0.310, 0.220], # right black 30 | [0.1600, 0.410, 0.280], # left black 31 | [0.2100, 0.250, 0.410], # gray center blob 32 | [0.0460, 0.046, 0.050], 33 | [0.0460, 0.046, 0.050], 34 | [0.0460, 0.046, 0.050], # left small dot 35 | [0.0230, 0.023, 0.020], # mid small dot 36 | [0.0230, 0.023, 0.020], 37 | ] 38 | 39 | sl_offsets = [ 40 | [0.0, 0.0, 0], 41 | [0.0, -0.0184, 0], 42 | [0.22, 0.0, 0], 43 | [-0.22, 0.0, 0], 44 | [0.0, 0.35, -0.15], 45 | [0.0, 0.1, 0.25], 46 | [0.0, -0.1, 0.25], 47 | [-0.08, -0.605, 0], 48 | [0.0, -0.606, 0], 49 | [0.06, -0.605, 0], 50 | ] 51 | 52 | sl_angles = [ 53 | [0, 0, 0], 54 | [0, 0, 0], 55 | [-18, 0, 10], 56 | [18, 0, 10], 57 | [0, 0, 0], 58 | [0, 0, 0], 59 | [0, 0, 0], 60 | [0, 0, 0], 61 | [0, 0, 0], 62 | [0, 0, 0], 63 | ] 64 | 65 | 66 | def phantom(shape, amps, scales, offsets, angles, dtype): 67 | """ 68 | Generate a cube of given shape using a list of ellipsoid 69 | parameters. 70 | """ 71 | 72 | if len(shape) == 2: 73 | ndim = 2 74 | shape = (1, shape[-2], shape[-1]) 75 | 76 | elif len(shape) == 3: 77 | ndim = 3 78 | 79 | else: 80 | raise ValueError("Incorrect dimension") 81 | 82 | out = np.zeros(shape, dtype=dtype) 83 | 84 | z, y, x = np.mgrid[ 85 | -(shape[-3] // 2) : ((shape[-3] + 1) // 2), 86 | -(shape[-2] // 2) : ((shape[-2] + 1) // 2), 87 | -(shape[-1] // 2) : ((shape[-1] + 1) // 2), 88 | ] 89 | 90 | coords = np.stack( 91 | ( 92 | x.ravel() / shape[-1] * 2, 93 | y.ravel() / shape[-2] * 2, 94 | z.ravel() / shape[-3] * 2, 95 | ) 96 | ) 97 | 98 | for amp, scale, offset, angle in zip(amps, scales, offsets, angles): 99 | ellipsoid(amp, scale, offset, angle, coords, out) 100 | 101 | if ndim == 2: 102 | return out[0, :, :] 103 | 104 | else: 105 | return out 106 | 107 | 108 | def ellipsoid(amp, scale, offset, angle, coords, out): 109 | """ 110 | Generate a cube containing an ellipsoid defined by its parameters. 111 | If out is given, fills the given cube instead of creating a new 112 | one. 113 | """ 114 | R = rotation_matrix(angle) 115 | coords = (np.matmul(R, coords) - np.reshape(offset, (3, 1))) / np.reshape( 116 | scale, (3, 1) 117 | ) 118 | 119 | r2 = np.sum(coords**2, axis=0).reshape(out.shape) 120 | 121 | out[r2 <= 1] += amp 122 | 123 | 124 | def rotation_matrix(angle): 125 | cphi = np.cos(np.radians(angle[0])) 126 | sphi = np.sin(np.radians(angle[0])) 127 | ctheta = np.cos(np.radians(angle[1])) 128 | stheta = np.sin(np.radians(angle[1])) 129 | cpsi = np.cos(np.radians(angle[2])) 130 | spsi = np.sin(np.radians(angle[2])) 131 | alpha = [ 132 | [ 133 | cpsi * cphi - ctheta * sphi * spsi, 134 | cpsi * sphi + ctheta * cphi * spsi, 135 | spsi * stheta, 136 | ], 137 | [ 138 | -spsi * cphi - ctheta * sphi * cpsi, 139 | -spsi * sphi + ctheta * cphi * cpsi, 140 | cpsi * stheta, 141 | ], 142 | [stheta * sphi, -stheta * cphi, ctheta], 143 | ] 144 | return np.array(alpha) 145 | -------------------------------------------------------------------------------- /sigpy/thresh.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Thresholding functions. 
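Aside: a quick usage sketch for the shepp_logan generator defined above. The tests later in this repository call the same function through the package top level as sp.shepp_logan; the sizes here are arbitrary.

import numpy as np
import sigpy as sp

img2d = sp.shepp_logan((128, 128))                      # (128, 128), complex128 by default
img3d = sp.shepp_logan((32, 32, 32), dtype=np.float32)  # 3D phantom, float32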
3 | """ 4 | import numba as nb 5 | import numpy as np 6 | 7 | from sigpy import backend, config, util 8 | 9 | __all__ = [ 10 | "soft_thresh", 11 | "hard_thresh", 12 | "l1_proj", 13 | "l2_proj", 14 | "linf_proj", 15 | "psd_proj", 16 | ] 17 | 18 | 19 | def soft_thresh(lamda, input): 20 | r"""Soft threshold. 21 | 22 | Performs: 23 | 24 | .. math:: 25 | (| x | - \lambda)_+ \text{sgn}(x) 26 | 27 | Args: 28 | lamda (float, or array): Threshold parameter. 29 | input (array) 30 | 31 | Returns: 32 | array: soft-thresholded result. 33 | 34 | """ 35 | device = backend.get_device(input) 36 | xp = device.xp 37 | if xp == np: 38 | return _soft_thresh(lamda, input) 39 | else: # pragma: no cover 40 | if np.isscalar(lamda): 41 | lamda = backend.to_device(lamda, device) 42 | 43 | return _soft_thresh_cuda(lamda, input) 44 | 45 | 46 | def hard_thresh(lamda, input): 47 | """Hard threshold. 48 | 49 | Args: 50 | lamda (float, or array): Threshold parameter. 51 | input (array) 52 | 53 | Returns: 54 | array: hard-thresholded result. 55 | 56 | """ 57 | device = backend.get_device(input) 58 | xp = device.xp 59 | if xp == np: 60 | return _hard_thresh(lamda, input) 61 | else: # pragma: no cover 62 | if np.isscalar(lamda): 63 | lamda = backend.to_device(lamda, device) 64 | 65 | return _hard_thresh_cuda(lamda, input) 66 | 67 | 68 | def l1_proj(eps, input): 69 | """Projection onto L1 ball. 70 | 71 | Args: 72 | eps (float, or array): L1 ball scaling. 73 | input (array) 74 | 75 | Returns: 76 | array: Result. 77 | 78 | References: 79 | J. Duchi, S. Shalev-Shwartz, and Y. Singer, "Efficient projections onto 80 | the l1-ball for learning in high dimensions" 2008. 81 | 82 | """ 83 | xp = backend.get_array_module(input) 84 | shape = input.shape 85 | input = input.ravel() 86 | 87 | if xp.linalg.norm(input, 1) < eps: 88 | return input 89 | else: 90 | size = len(input) 91 | s = xp.sort(xp.abs(input))[::-1] 92 | st = (xp.cumsum(s) - eps) / (xp.arange(size) + 1) 93 | idx = xp.flatnonzero((s - st) > 0).max() 94 | return soft_thresh(st[idx], input.reshape(shape)) 95 | 96 | 97 | def l2_proj(eps, input, axes=None): 98 | """Projection onto L2 ball. 99 | 100 | Args: 101 | eps (float, or array): L2 ball scaling. 102 | input (array) 103 | 104 | Returns: 105 | array: Result. 106 | 107 | """ 108 | axes = util._normalize_axes(axes, input.ndim) 109 | 110 | xp = backend.get_array_module(input) 111 | norm = xp.sum(xp.abs(input) ** 2, axis=axes, keepdims=True) ** 0.5 112 | mask = norm < eps 113 | output = mask * input + (1 - mask) * (eps * input / (norm + mask)) 114 | 115 | return output 116 | 117 | 118 | def linf_proj(eps, input, bias=None): 119 | """Projection onto L-infinity ball. 120 | 121 | Args: 122 | eps (float, or array): l-infinity ball scaling. 123 | input (array) 124 | 125 | Returns: 126 | array: Result. 127 | 128 | """ 129 | if bias is not None: 130 | input = input - bias 131 | 132 | output = input - soft_thresh(eps, input) 133 | 134 | if bias is not None: 135 | output += bias 136 | 137 | return output 138 | 139 | 140 | def psd_proj(input): 141 | """Projection onto postiive semi-definite matrices. 142 | 143 | Args: 144 | input (array): a two-dimensional matrix. 145 | 146 | Returns: 147 | array: Result. 
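Aside: a small numeric check of the thresholding and projection operators defined above (a sketch; the input values are arbitrary).

import numpy as np
from sigpy.thresh import soft_thresh, hard_thresh, l2_proj

x = np.array([-3.0, -0.5, 0.0, 0.5, 3.0])
print(soft_thresh(1.0, x))   # [-2. -0.  0.  0.  2.]
print(hard_thresh(1.0, x))   # [-3.  0.  0.  0.  3.]

y = np.array([3.0, 4.0])
print(l2_proj(1.0, y))       # [0.6 0.8], i.e. y scaled onto the unit L2 ball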
148 | 149 | """ 150 | xp = backend.get_array_module(input) 151 | w, v = xp.linalg.eig((input + xp.conj(input).T) / 2) 152 | w[w < 0] = 0 153 | return (v * w) @ v.conjugate().T 154 | 155 | 156 | @nb.vectorize # pragma: no cover 157 | def _soft_thresh(lamda, input): 158 | abs_input = abs(input) 159 | if abs_input == 0: 160 | sign = 0 161 | else: 162 | sign = input / abs_input 163 | 164 | mag = abs_input - lamda 165 | mag = (abs(mag) + mag) / 2 166 | 167 | return mag * sign 168 | 169 | 170 | @nb.vectorize # pragma: no cover 171 | def _hard_thresh(lamda, input): 172 | abs_input = abs(input) 173 | if abs_input > lamda: 174 | return input 175 | else: 176 | return 0 177 | 178 | 179 | if config.cupy_enabled: # pragma: no cover 180 | import cupy as cp 181 | 182 | _soft_thresh_cuda = cp.ElementwiseKernel( 183 | "S lamda, T input", 184 | "T output", 185 | """ 186 | S abs_input = abs(input); 187 | T sign; 188 | if (abs_input == 0) 189 | sign = 0; 190 | else 191 | sign = input / (T) abs_input; 192 | S mag = abs_input - lamda; 193 | mag = (abs(mag) + mag) / 2.; 194 | 195 | output = (T) mag * sign; 196 | """, 197 | name="soft_thresh", 198 | ) 199 | 200 | _hard_thresh_cuda = cp.ElementwiseKernel( 201 | "S lamda, T input", 202 | "T output", 203 | """ 204 | S abs_input = abs(input); 205 | if (abs_input > lamda) 206 | output = input; 207 | else 208 | output = 0; 209 | """, 210 | name="hard_thresh", 211 | ) 212 | -------------------------------------------------------------------------------- /sigpy/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.27" 2 | -------------------------------------------------------------------------------- /sigpy/wavelet.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """Wavelet transform functions. 3 | """ 4 | import numpy as np 5 | import pywt 6 | 7 | from sigpy import backend, util 8 | 9 | __all__ = ["fwt", "iwt"] 10 | 11 | 12 | def get_wavelet_shape(shape, wave_name="db4", axes=None, level=None): 13 | zshape = [((i + 1) // 2) * 2 for i in shape] 14 | 15 | tmp = pywt.wavedecn( 16 | np.zeros(zshape), wave_name, mode="zero", axes=axes, level=level 17 | ) 18 | tmp, coeff_slices = pywt.coeffs_to_array(tmp, axes=axes) 19 | oshape = tmp.shape 20 | 21 | return oshape, coeff_slices 22 | 23 | 24 | def fwt(input, wave_name="db4", axes=None, level=None): 25 | """Forward wavelet transform. 26 | 27 | Args: 28 | input (array): Input array. 29 | axes (None or tuple of int): Axes to perform wavelet transform. 30 | wave_name (str): Wavelet name. 31 | level (None or int): Number of wavelet levels. 32 | """ 33 | device = backend.get_device(input) 34 | input = backend.to_device(input, backend.cpu_device) 35 | 36 | zshape = [((i + 1) // 2) * 2 for i in input.shape] 37 | zinput = util.resize(input, zshape) 38 | 39 | coeffs = pywt.wavedecn( 40 | zinput, wave_name, mode="zero", axes=axes, level=level 41 | ) 42 | output, _ = pywt.coeffs_to_array(coeffs, axes=axes) 43 | 44 | output = backend.to_device(output, device) 45 | return output 46 | 47 | 48 | def iwt(input, oshape, coeff_slices, wave_name="db4", axes=None, level=None): 49 | """Inverse wavelet transform. 50 | 51 | Args: 52 | input (array): Input array. 53 | oshape (tuple of ints): Output shape. 54 | coeff_slices (list of slice): Slices to split coefficients. 55 | axes (None or tuple of int): Axes to perform wavelet transform. 56 | wave_name (str): Wavelet name. 57 | level (None or int): Number of wavelet levels. 
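Aside: the psd_proj function above symmetrizes its input and clips negative eigenvalues; a tiny sketch with a hand-picked diagonal matrix.

import numpy as np
from sigpy.thresh import psd_proj

M = np.array([[2.0, 0.0],
              [0.0, -1.0]])
print(psd_proj(M))   # [[2. 0.]
                     #  [0. 0.]]  -- the -1 eigenvalue is set to zero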
58 | """ 59 | device = backend.get_device(input) 60 | input = backend.to_device(input, backend.cpu_device) 61 | 62 | input = pywt.array_to_coeffs(input, coeff_slices, output_format="wavedecn") 63 | output = pywt.waverecn(input, wave_name, mode="zero", axes=axes) 64 | output = util.resize(output, oshape) 65 | 66 | output = backend.to_device(output, device) 67 | return output 68 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/tests/__init__.py -------------------------------------------------------------------------------- /tests/learn/test_app.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy.learn import app 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestApp(unittest.TestCase): 13 | def test_ConvSparseDecom(self): 14 | lamda = 1e-9 15 | L = np.array([[1, 1], [1, -1]], dtype=np.float64) / 2**0.5 16 | y = np.array([[1, 1]], dtype=np.float64) / 2**0.5 17 | 18 | R = app.ConvSparseDecom(y, L, lamda=lamda).run() 19 | 20 | npt.assert_allclose(R, [[[1], [0]]]) 21 | 22 | def test_ConvSparseCoefficients(self): 23 | lamda = 1e-10 24 | L = np.array([[1, 1], [1, -1]], dtype=np.float64) / 2**0.5 25 | y = np.array([[1, 1]], dtype=np.float64) / 2**0.5 26 | 27 | R_j = app.ConvSparseCoefficients(y, L, lamda=lamda) 28 | 29 | npt.assert_allclose(R_j[:], [[[1], [0]]]) 30 | npt.assert_allclose(R_j[0, :], [[1], [0]]) 31 | npt.assert_allclose(R_j[:, 0], [[1]]) 32 | npt.assert_allclose(R_j[:, :, 0], [[1, 0]]) 33 | 34 | def test_ConvSparseCoding(self): 35 | num_atoms = 1 36 | filt_width = 2 37 | batch_size = 1 38 | y = np.array([[1, 1]], dtype=np.float64) / 2**0.5 39 | lamda = 1e-10 40 | alpha = 1 41 | 42 | L, _ = app.ConvSparseCoding( 43 | y, 44 | num_atoms, 45 | filt_width, 46 | batch_size, 47 | alpha=alpha, 48 | lamda=lamda, 49 | max_iter=100, 50 | ).run() 51 | 52 | npt.assert_allclose( 53 | np.abs(L), [[1 / 2**0.5, 1 / 2**0.5]], atol=0.1, rtol=0.1 54 | ) 55 | 56 | def test_LinearRegression(self): 57 | n = 2 58 | k = 5 59 | m = 4 60 | batch_size = n 61 | 62 | X = np.random.randn(n, k) 63 | y = np.random.randn(n, m) 64 | 65 | alpha = 1 / np.linalg.svd(X, compute_uv=False)[0] ** 2 66 | mat = app.LinearRegression(X, y, batch_size, alpha, max_iter=300).run() 67 | mat_lstsq = np.linalg.lstsq(X, y, rcond=-1)[0] 68 | 69 | npt.assert_allclose(mat, mat_lstsq, atol=1e-2, rtol=1e-2) 70 | -------------------------------------------------------------------------------- /tests/learn/test_util.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy.learn import util 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestUtil(unittest.TestCase): 13 | def test_labels_to_scores(self): 14 | labels = np.array([0, 1, 2]) 15 | 16 | scores = util.labels_to_scores(labels) 17 | 18 | npt.assert_allclose(scores, [[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 19 | 20 | def test_scores_to_labels(self): 21 | scores = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) 22 | 23 | labels = util.scores_to_labels(scores) 24 | 25 | npt.assert_allclose(labels, [0, 1, 2]) 26 | 
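Aside: tying the wavelet helpers together, a round-trip sketch with fwt, iwt, and get_wavelet_shape (the db4 wavelet and the 64x64 input are arbitrary choices).

import numpy as np
from sigpy import wavelet

x = np.random.randn(64, 64).astype(np.float32)

oshape, coeff_slices = wavelet.get_wavelet_shape(x.shape, wave_name="db4")
w = wavelet.fwt(x, wave_name="db4")                       # w.shape == oshape
x_rec = wavelet.iwt(w, x.shape, coeff_slices, wave_name="db4")

print(np.allclose(x, x_rec, atol=1e-5))                   # True: fwt and iwt invert each other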
-------------------------------------------------------------------------------- /tests/mri/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/tests/mri/__init__.py -------------------------------------------------------------------------------- /tests/mri/rf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mikgroup/sigpy/5da0e8605f166be41e520ef0ef913482487611d8/tests/mri/rf/__init__.py -------------------------------------------------------------------------------- /tests/mri/rf/test_adiabatic.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy.mri import rf 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestAdiabatic(unittest.TestCase): 13 | def test_bir4(self): 14 | # test an excitation bir4 pulse 15 | n = 1176 16 | dt = 4e-6 17 | dw0 = 100 * np.pi / dt / n 18 | beta = 10 19 | kappa = np.arctan(20) 20 | flip = np.pi / 4 21 | [am_bir, om_bir] = rf.adiabatic.bir4(n, beta, kappa, flip, dw0) 22 | 23 | # check relatively homogeneous over range of B1 values 24 | b1 = np.arange(0.2, 0.8, 0.1) 25 | b1 = np.reshape(b1, (np.size(b1), 1)) 26 | a = np.zeros(np.shape(b1), dtype=np.complex128) 27 | b = np.zeros(np.shape(b1), dtype=np.complex128) 28 | 29 | for ii in range(0, np.size(b1)): 30 | [a[ii], b[ii]] = rf.sim.abrm_nd( 31 | 2 * np.pi * dt * 4258 * b1[ii] * am_bir, 32 | np.ones(1), 33 | dt * np.reshape(om_bir, (np.size(om_bir), 1)), 34 | ) 35 | 36 | mxy = 2 * np.multiply(np.conj(a), b) 37 | 38 | test = np.ones(mxy.shape) * 0.7 # magnetization value we expect 39 | 40 | npt.assert_array_almost_equal(np.abs(mxy), test, 2) 41 | 42 | def test_hyp_ex(self): 43 | # test an inversion adiabatic hyp pulse 44 | n = 512 45 | beta = 800 46 | mu = 4.9 47 | dur = 0.012 48 | [am_sech, om_sech] = rf.adiabatic.hypsec(n, beta, mu, dur) 49 | 50 | # check relatively homogeneous over range of B1 values 51 | b1 = np.arange(0.2, 0.8, 0.1) 52 | b1 = np.reshape(b1, (np.size(b1), 1)) 53 | 54 | a = np.zeros(np.shape(b1), dtype=np.complex128) 55 | b = np.zeros(np.shape(b1), dtype=np.complex128) 56 | for ii in range(0, np.size(b1)): 57 | [a[ii], b[ii]] = rf.sim.abrm_nd( 58 | 2 * np.pi * (dur / n) * 4258 * b1[ii] * am_sech, 59 | np.ones(1), 60 | dur / n * np.reshape(om_sech, (np.size(om_sech), 1)), 61 | ) 62 | mz = 1 - 2 * np.abs(b) ** 2 63 | 64 | test = np.ones(mz.shape) * -1 # magnetization value we expect 65 | 66 | npt.assert_array_almost_equal(mz, test, 2) 67 | 68 | def test_goia_wurst(self): 69 | # test a goia-wurst adiabatic pulse 70 | n = 512 71 | dur = 3.5e-3 72 | f = 0.9 73 | n_b1 = 16 74 | m_grad = 4 75 | [_, om_goia, g_goia] = rf.adiabatic.goia_wurst(n, dur, f, n_b1, m_grad) 76 | 77 | # test midpoint of goia pulse. 
Expect 1-f g, 0.1 fm 78 | npt.assert_almost_equal(g_goia[int(len(g_goia) / 2)], 1 - f, 2) 79 | npt.assert_almost_equal(g_goia[int(len(om_goia) / 2)], 0.1, 2) 80 | -------------------------------------------------------------------------------- /tests/mri/rf/test_b1sel.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy.mri.rf as rf 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestB1sel(unittest.TestCase): 13 | def test_b1sel_generic(self): 14 | dt = 2e-6 # sampling period 15 | d1 = 0.01 # passband ripple 16 | d2 = 0.01 # stopband ripple 17 | tb = 4 # time-bandwidth product 18 | ptype = "ex" # 'st', 'ex', 'inv' or 'sat' 19 | pbw = 0.5 # gauss, passband width 20 | pbc = 5 # gauss, passband center 21 | flip = np.pi / 4 # radians, flip angle 22 | 23 | [rf_am, rf_fm] = rf.b1sel.dz_b1_rf( 24 | dt, tb, ptype, flip, pbw, pbc, d1, d2 25 | ) 26 | b1 = np.arange( 27 | 0, 2 * pbc, 2 * pbc / np.size(rf_am) * 4 28 | ) # b1 grid we simulate the pulse over 29 | b1 = np.reshape(b1, (np.size(b1), 1)) 30 | [a, b] = rf.sim.abrm_nd( 31 | 2 * np.pi * dt * rf_fm, 32 | b1, 33 | 2 * np.pi * 4258 * dt * np.reshape(rf_am, (np.size(rf_am), 1)), 34 | ) 35 | mxy = -2 * np.real(a * b) + 1j * np.imag(np.conj(a) ** 2 - b**2) 36 | 37 | pts = np.array([mxy[10], mxy[int(len(b1) / 2)], mxy[len(b1) - 10]]) 38 | npt.assert_almost_equal(abs(pts), np.array([0, 0.7, 0]), decimal=1) 39 | 40 | def test_b1sel_gslider(self): 41 | g = 5 42 | flip = np.pi / 2 43 | ptype = "ex" # 'ex' or 'st' 44 | tb = 12 45 | d1 = 0.01 46 | d2 = 0.01 47 | pbc = 1 # gauss, passband center 48 | pbw = 0.25 # passband width 49 | dt = 2e-6 # seconds, sampling rate 50 | [om1, dom] = rf.b1sel.dz_b1_gslider_rf( 51 | dt, g, tb, ptype, flip, pbw, pbc, d1, d2 52 | ) 53 | 54 | n = np.shape(om1)[0] 55 | b1 = np.arange( 56 | 0, 2 * pbc, 2 * pbc / n * 4 57 | ) # b1 grid we simulate the pulse over 58 | b1 = np.reshape(b1, (np.size(b1), 1)) 59 | [a, b] = rf.sim.abrm_nd( 60 | 2 * np.pi * dt * dom[:, 0], 61 | b1, 62 | 2 * np.pi * 4258 * dt * np.reshape(om1[:, 0], (n, 1)), 63 | ) 64 | mxy = -2 * np.real(a * b) + 1j * np.imag(np.conj(a) ** 2 - b**2) 65 | 66 | pts = np.array([mxy[10], mxy[int(len(b1) / 2)], mxy[len(b1) - 10]]) 67 | npt.assert_almost_equal(abs(pts), np.array([0, 1, 0]), decimal=2) 68 | -------------------------------------------------------------------------------- /tests/mri/rf/test_linop.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | import sigpy as sp 6 | import sigpy.mri.rf as rf 7 | from sigpy.mri.rf import linop 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | def check_linop_adjoint(A, dtype=np.float32, device=sp.cpu_device): 14 | device = sp.Device(device) 15 | x = sp.randn(A.ishape, dtype=dtype, device=device) 16 | y = sp.randn(A.oshape, dtype=dtype, device=device) 17 | 18 | xp = device.xp 19 | with device: 20 | lhs = xp.vdot(A * x, y) 21 | rhs = xp.vdot(x, A.H * y) 22 | 23 | xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5) 24 | 25 | 26 | class TestLinop(unittest.TestCase): 27 | def test_spatial_explicit_model(self): 28 | dim = 3 29 | img_shape = [dim, dim, dim] 30 | mps_shape = [8, dim, dim, dim] 31 | 32 | dt = 4e-6 33 | 34 | k = sp.mri.spiral( 35 | fov=dim / 2, 36 | N=dim, 37 | f_sampling=1, 38 | R=1, 39 | ninterleaves=1, 40 | alpha=1, 41 | gm=0.03, 42 | sm=200, 43 
| ) 44 | k = rf.stack_of(k, 3, 0.1) 45 | 46 | mps = sp.randn(mps_shape, dtype=np.complex64) 47 | 48 | A = linop.PtxSpatialExplicit(mps, k, dt, img_shape) 49 | 50 | check_linop_adjoint(A, dtype=np.complex64) 51 | -------------------------------------------------------------------------------- /tests/mri/rf/test_multiband.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy.mri.rf as rf 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestMultiband(unittest.TestCase): 13 | def test_multiband(self): 14 | # slr pulse 15 | tb = 8 16 | N = 512 17 | d1 = 0.01 18 | d2 = 0.01 19 | p_type = "ex" 20 | f_type = "ls" 21 | pulse = rf.slr.dzrf(N, tb, p_type, f_type, d1, d2, True) 22 | 23 | # multiband it 24 | n_bands = 3 25 | phs_type = "phs_mod" # phsMod, ampMod, or quadMod 26 | band_sep = 5 * tb # separate by 5 slice widths 27 | mb_pulse = rf.multiband.mb_rf(pulse, n_bands, band_sep, phs_type) 28 | 29 | # simulate it 30 | [a, b] = rf.sim.abrm( 31 | mb_pulse, np.arange(-20 * tb, 20 * tb, 40 * tb / 2000), True 32 | ) 33 | mxy = 2 * np.multiply(np.conj(a), b) 34 | 35 | pts = np.array([mxy[750], mxy[850], mxy[1000], mxy[1150], mxy[1250]]) 36 | npt.assert_almost_equal(abs(pts), np.array([1, 0, 1, 0, 1]), decimal=2) 37 | 38 | def test_pins(self): 39 | # pins pulse specs 40 | tb = 8 41 | d1 = 0.01 42 | d2 = 0.01 43 | p_type = "ex" 44 | f_type = "ls" 45 | 46 | sl_sep = 3 # cm 47 | sl_thick = 0.3 # cm 48 | g_max = 4 # gauss/cm 49 | g_slew = 18000 # gauss/cm/s 50 | dt = 4e-6 # seconds, dwell time 51 | b1_max = 0.18 # gauss 52 | [rf_pins, g_pins] = rf.multiband.dz_pins( 53 | tb, 54 | sl_sep, 55 | sl_thick, 56 | g_max, 57 | g_slew, 58 | dt, 59 | b1_max, 60 | p_type, 61 | f_type, 62 | d1, 63 | d2, 64 | ) 65 | 66 | # simulate it 67 | x = np.reshape(np.arange(-1000, 1000), (2000, 1)) / 1000 * 12 # cm 68 | [a, b] = rf.sim.abrm_nd( 69 | 2 * np.pi * dt * 4258 * rf_pins, 70 | x, 71 | np.reshape(g_pins, (np.size(g_pins), 1)) * 4258 * dt * 2 * np.pi, 72 | ) 73 | mxy = 2 * np.conj(a) * b 74 | 75 | pts = np.array([mxy[100], mxy[1000], mxy[1900]]) 76 | npt.assert_almost_equal(abs(pts), np.array([0, 1, 0]), decimal=2) 77 | -------------------------------------------------------------------------------- /tests/mri/rf/test_ptx.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | from scipy.ndimage import gaussian_filter 6 | 7 | import sigpy as sp 8 | from sigpy.mri import linop, rf, sim 9 | 10 | if __name__ == "__main__": 11 | unittest.main() 12 | 13 | 14 | class TestPtx(unittest.TestCase): 15 | @staticmethod 16 | def problem_2d(dim): 17 | img_shape = [dim, dim] 18 | sens_shape = [8, dim, dim] 19 | 20 | # target - slightly blurred circle 21 | x, y = np.ogrid[ 22 | -img_shape[0] / 2 : img_shape[0] - img_shape[0] / 2, 23 | -img_shape[1] / 2 : img_shape[1] - img_shape[1] / 2, 24 | ] 25 | circle = x * x + y * y <= int(img_shape[0] / 6) ** 2 26 | target = np.zeros(img_shape) 27 | target[circle] = 1 28 | target = gaussian_filter(target, 1) 29 | target = target.astype(np.complex64) 30 | 31 | sens = sim.birdcage_maps(sens_shape) 32 | 33 | return target, sens 34 | 35 | @staticmethod 36 | def problem_3d(dim, Nz): 37 | Nc = 8 38 | img_shape = [dim, dim, Nz] 39 | sens_shape = [Nc, dim, dim, Nz] 40 | 41 | # target - slightly blurred circle 42 | x, y, z = np.ogrid[ 43 | 
-img_shape[0] / 2 : img_shape[0] - img_shape[0] / 2, 44 | -img_shape[1] / 2 : img_shape[1] - img_shape[1] / 2, 45 | -img_shape[2] / 2 : img_shape[2] - img_shape[2] / 2, 46 | ] 47 | circle = x * x + y * y + z * z <= int(img_shape[0] / 5) ** 2 48 | target = np.zeros(img_shape) 49 | target[circle] = 1 50 | target = gaussian_filter(target, 1) 51 | target = target.astype(np.complex64) 52 | sens = sp.mri.sim.birdcage_maps(sens_shape) 53 | 54 | return target, sens 55 | 56 | def test_stspa_radial(self): 57 | target, sens = self.problem_2d(8) 58 | 59 | # makes dim*dim*2 trajectory 60 | traj = sp.mri.radial( 61 | (sens.shape[1], sens.shape[1], 2), 62 | target.shape, 63 | golden=True, 64 | dtype=np.float32, 65 | ) 66 | # reshape to be Nt*2 trajectory 67 | traj = np.reshape(traj, [traj.shape[0] * traj.shape[1], 2]) 68 | 69 | A = linop.Sense(sens, coord=traj, weights=None, ishape=target.shape).H 70 | 71 | pulses = rf.stspa( 72 | target, 73 | sens, 74 | traj, 75 | dt=4e-6, 76 | alpha=1, 77 | b0=None, 78 | st=None, 79 | explicit=False, 80 | max_iter=100, 81 | tol=1e-4, 82 | ) 83 | 84 | npt.assert_array_almost_equal(A * pulses, target, 1e-3) 85 | 86 | def test_stspa_spiral(self): 87 | target, sens = self.problem_2d(8) 88 | 89 | fov = 0.55 90 | gts = 6.4e-6 91 | gslew = 190 92 | gamp = 40 93 | R = 1 94 | dx = 0.025 # in m 95 | # construct a trajectory 96 | g, k, t, s = rf.spiral_arch(fov / R, dx, gts, gslew, gamp) 97 | 98 | A = linop.Sense(sens, coord=k, ishape=target.shape).H 99 | 100 | pulses = rf.stspa( 101 | target, 102 | sens, 103 | k, 104 | dt=4e-6, 105 | alpha=1, 106 | b0=None, 107 | st=None, 108 | explicit=False, 109 | max_iter=100, 110 | tol=1e-4, 111 | ) 112 | 113 | npt.assert_array_almost_equal(A * pulses, target, 1e-3) 114 | 115 | def test_stspa_2d_explicit(self): 116 | target, sens = self.problem_2d(8) 117 | dim = target.shape[0] 118 | g, k1, t, s = rf.spiral_arch(0.24, dim, 4e-6, 200, 0.035) 119 | k1 = k1 / dim 120 | 121 | A = rf.PtxSpatialExplicit( 122 | sens, k1, dt=4e-6, img_shape=target.shape, b0=None 123 | ) 124 | pulses = sp.mri.rf.stspa( 125 | target, 126 | sens, 127 | st=None, 128 | coord=k1, 129 | dt=4e-6, 130 | max_iter=100, 131 | alpha=10, 132 | tol=1e-4, 133 | phase_update_interval=200, 134 | explicit=True, 135 | ) 136 | 137 | npt.assert_array_almost_equal(A * pulses, target, 1e-3) 138 | 139 | def test_stspa_3d_explicit(self): 140 | nz = 4 141 | target, sens = self.problem_3d(3, nz) 142 | dim = target.shape[0] 143 | 144 | g, k1, t, s = rf.spiral_arch(0.24, dim, 4e-6, 200, 0.035) 145 | k1 = k1 / dim 146 | 147 | k1 = rf.stack_of(k1, nz, 0.1) 148 | A = rf.linop.PtxSpatialExplicit( 149 | sens, k1, dt=4e-6, img_shape=target.shape, b0=None 150 | ) 151 | 152 | pulses = sp.mri.rf.stspa( 153 | target, 154 | sens, 155 | st=None, 156 | coord=k1, 157 | dt=4e-6, 158 | max_iter=30, 159 | alpha=10, 160 | tol=1e-3, 161 | phase_update_interval=200, 162 | explicit=True, 163 | ) 164 | 165 | npt.assert_array_almost_equal(A * pulses, target, 1e-3) 166 | 167 | def test_stspa_3d_nonexplicit(self): 168 | nz = 3 169 | target, sens = self.problem_3d(3, nz) 170 | dim = target.shape[0] 171 | 172 | g, k1, t, s = rf.spiral_arch(0.24, dim, 4e-6, 200, 0.035) 173 | k1 = k1 / dim 174 | 175 | k1 = rf.stack_of(k1, nz, 0.1) 176 | A = sp.mri.linop.Sense( 177 | sens, k1, weights=None, tseg=None, ishape=target.shape 178 | ).H 179 | 180 | pulses = sp.mri.rf.stspa( 181 | target, 182 | sens, 183 | st=None, 184 | coord=k1, 185 | dt=4e-6, 186 | max_iter=30, 187 | alpha=10, 188 | tol=1e-3, 189 | phase_update_interval=200, 190 
| explicit=False, 191 | ) 192 | 193 | npt.assert_array_almost_equal(A * pulses, target, 1e-3) 194 | 195 | def test_spokes(self): 196 | # spokes problem definition: 197 | dim = 20 # size of the b1 matrix loaded 198 | n_spokes = 5 199 | fov = 20 # cm 200 | dx_max = 2 # cm 201 | gts = 4e-6 202 | sl_thick = 5 # slice thickness, mm 203 | tbw = 4 204 | dgdtmax = 18000 # g/cm/s 205 | gmax = 2 # g/cm 206 | 207 | _, sens = self.problem_2d(dim) 208 | roi = np.zeros((dim, dim)) 209 | radius = dim // 2 210 | cx, cy = dim // 2, dim // 2 211 | y, x = np.ogrid[-radius:radius, -radius:radius] 212 | index = x**2 + y**2 <= radius**2 213 | roi[cy - radius : cy + radius, cx - radius : cx + radius][index] = 1 214 | sens = sens * roi 215 | 216 | [pulses, g] = rf.stspk( 217 | roi, 218 | sens, 219 | n_spokes, 220 | fov, 221 | dx_max, 222 | gts, 223 | sl_thick, 224 | tbw, 225 | dgdtmax, 226 | gmax, 227 | alpha=1, 228 | ) 229 | 230 | # should give the number of pulses corresponding to number of TX ch 231 | npt.assert_equal(np.shape(pulses)[0], np.shape(sens)[0]) 232 | # should hit the max gradient constraint 233 | npt.assert_almost_equal(gmax, np.max(g), decimal=3) 234 | -------------------------------------------------------------------------------- /tests/mri/rf/test_sim.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy.mri.rf as rf 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestSim(unittest.TestCase): 13 | def test_abrm(self): 14 | # also provides testing of SLR excitation. Check ex profile sim. 15 | tb = 8 16 | N = 128 17 | d1 = 0.01 18 | d2 = 0.01 19 | ptype = "ex" 20 | ftype = "ls" 21 | 22 | pulse = rf.slr.dzrf(N, tb, ptype, ftype, d1, d2, False) 23 | [a, b] = rf.sim.abrm(pulse, np.arange(-2 * tb, 2 * tb, 0.01), True) 24 | Mxy = 2 * np.multiply(np.conj(a), b) 25 | 26 | pts = np.array( 27 | [ 28 | Mxy[int(len(Mxy) / 2 - len(Mxy) / 3)], 29 | Mxy[int(len(Mxy) / 2)], 30 | Mxy[int(len(Mxy) / 2 + len(Mxy) / 3)], 31 | ] 32 | ) 33 | 34 | npt.assert_almost_equal(abs(pts), np.array([0, 1, 0]), decimal=2) 35 | -------------------------------------------------------------------------------- /tests/mri/rf/test_slr.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy as sp 7 | import sigpy.mri.rf as rf 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestSlr(unittest.TestCase): 14 | def test_st(self): 15 | # check to make sure profile roughly matches anticipated within d1, d2 16 | N = 128 17 | tb = 16 18 | filts = ["ls", "ms", "pm", "min", "max"] 19 | for idx, filt in enumerate(filts): 20 | pulse = sp.mri.rf.dzrf( 21 | N, tb, ptype="st", ftype=filt, d1=0.01, d2=0.01 22 | ) 23 | 24 | m = np.abs(sp.fft(pulse, norm=None)) 25 | 26 | pts = np.array( 27 | [m[int(N / 2 - 10)], m[int(N / 2)], m[int(N / 2 + 10)]] 28 | ) 29 | npt.assert_almost_equal(pts, np.array([0, 1, 0]), decimal=2) 30 | 31 | def test_inv(self): 32 | # also provides testing of sim. Check inv profile. 
33 | tb = 8 34 | N = 128 35 | d1 = 0.01 36 | d2 = 0.01 37 | ptype = "ex" 38 | filts = ["min", "max"] # filts produce inconsistent inversions 39 | 40 | for idx, filt in enumerate(filts): 41 | pulse = rf.slr.dzrf(N, tb, ptype, filt, d1, d2) 42 | 43 | [_, b] = rf.sim.abrm(pulse, np.arange(-2 * tb, 2 * tb, 0.01)) 44 | mz = 1 - 2 * np.abs(b) ** 2 45 | 46 | pts = np.array( 47 | [ 48 | mz[int(len(mz) / 2 - len(mz) / 3)], 49 | mz[int(len(mz) / 2)], 50 | mz[int(len(mz) / 2 + len(mz) / 3)], 51 | ] 52 | ) 53 | 54 | npt.assert_almost_equal(pts, np.array([1, -0.2, 1]), decimal=1) 55 | 56 | def test_root_flipped(self): 57 | tb = 12 58 | N = 128 59 | d1 = 0.01 60 | d2 = 0.001 61 | flip = np.pi / 2 62 | ptype = "ex" 63 | [bsf, d1, d2] = rf.slr.calc_ripples(ptype, d1, d2) 64 | b = bsf * rf.slr.dzmp(N, tb, d1, d2) 65 | b = b[::-1] 66 | [pulse, _] = rf.slr.root_flip(b, d1, flip, tb, verbose=False) 67 | 68 | [_, b] = rf.sim.abrm(pulse, np.arange(-2 * tb, 2 * tb, 0.01)) 69 | mz = 1 - 2 * np.abs(b) ** 2 70 | 71 | pts = np.array( 72 | [ 73 | mz[int(len(mz) / 2 - len(mz) / 3)], 74 | mz[int(len(mz) / 2)], 75 | mz[int(len(mz) / 2 + len(mz) / 3)], 76 | ] 77 | ) 78 | 79 | npt.assert_almost_equal(pts, np.array([1, 0.2, 1]), decimal=1) 80 | 81 | def test_recursive(self): 82 | # Design the pulses 83 | nseg = 3 # number of EPI segments/RF Pulses 84 | tb = 4 85 | n = 200 86 | se_seq = True 87 | tb_ref = 8 # time-bandwidth of ref pulse 88 | [pulses, _] = rf.slr.dz_recursive_rf(nseg, tb, n, se_seq, tb_ref) 89 | 90 | mz = np.ones(np.size(np.arange(-4 * tb, 4 * tb, 0.01))) 91 | for ii in range(0, nseg): 92 | [a, b] = rf.sim.abrm( 93 | pulses[:, ii], np.arange(-4 * tb, 4 * tb, 0.01), True 94 | ) 95 | mxy = 2 * mz * np.multiply(np.conj(a), b) 96 | mz = mz * (1 - 2 * np.abs(b) ** 2) 97 | 98 | pts = np.array( 99 | [ 100 | mxy[int(len(mxy) / 2 - len(mxy) / 3)], 101 | mxy[int(len(mxy) / 2)], 102 | mxy[int(len(mxy) / 2 + len(mxy) / 3)], 103 | ] 104 | ) 105 | 106 | npt.assert_almost_equal(abs(pts), np.array([0, 0.5, 0]), decimal=1) 107 | 108 | def test_gslider(self): 109 | n = 512 110 | g = 5 111 | ex_flip = 90 * np.pi / 180 112 | tb = 12 113 | d1 = 0.01 114 | d2 = 0.01 115 | phi = np.pi 116 | 117 | pulses = rf.slr.dz_gslider_rf( 118 | n, g, ex_flip, phi, tb, d1, d2, cancel_alpha_phs=True 119 | ) 120 | 121 | for gind in range(1, g + 1): 122 | [a, b] = rf.sim.abrm( 123 | pulses[:, gind - 1], np.arange(-2 * tb, 2 * tb, 0.01), True 124 | ) 125 | mxy = 2 * np.multiply(np.conj(a), b) 126 | 127 | pts = np.array( 128 | [ 129 | mxy[int(len(mxy) / 2 - len(mxy) / 3)], 130 | mxy[int(len(mxy) / 2)], 131 | mxy[int(len(mxy) / 2 + len(mxy) / 3)], 132 | ] 133 | ) 134 | 135 | npt.assert_almost_equal(abs(pts), np.array([0, 1, 0]), decimal=2) 136 | -------------------------------------------------------------------------------- /tests/mri/rf/test_trajgrad.py: -------------------------------------------------------------------------------- 1 | import math 2 | import unittest 3 | 4 | import numpy as np 5 | import numpy.testing as npt 6 | 7 | import sigpy.mri.rf as rf 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestTrajGrad(unittest.TestCase): 14 | def test_min_gradient(self): 15 | t = np.linspace(0, 1, 1000) 16 | kx = np.sin(2.0 * math.pi * t) 17 | ky = np.cos(2.0 * math.pi * t) 18 | kz = t 19 | k = np.stack((kx, ky, kz), axis=-1) 20 | 21 | (g, k, s, t) = rf.min_time_gradient( 22 | k, 0.0, 0.0, gmax=4, smax=15, dt=4e-3, gamma=4.257 23 | ) 24 | 25 | npt.assert_almost_equal(np.max(t), 0.916, decimal=4) 26 | 27 | def 
test_trap_grad(self): 28 | dt = 4e-6 # s 29 | area = 200 * dt 30 | dgdt = 18000 # g/cm/s 31 | gmax = 2 # g/cm 32 | 33 | trap, _ = rf.trap_grad(area, gmax, dgdt, dt) 34 | 35 | npt.assert_almost_equal(area, np.sum(trap) * dt, decimal=3) 36 | npt.assert_almost_equal(gmax, np.max(trap), decimal=1) 37 | 38 | def test_min_trap_grad(self): 39 | dt = 4e-6 # s 40 | area = 200 * dt 41 | dgdt = 18000 # g/cm/s 42 | gmax = 2 # g/cm 43 | 44 | trap, _ = rf.min_trap_grad(area, gmax, dgdt, dt) 45 | 46 | npt.assert_almost_equal(area, np.sum(trap) * dt, decimal=3) 47 | npt.assert_almost_equal(gmax, np.max(trap), decimal=1) 48 | -------------------------------------------------------------------------------- /tests/mri/test_app.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy as sp 7 | from sigpy.mri import app, sim 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestApp(unittest.TestCase): 14 | def shepp_logan_setup(self): 15 | img_shape = [6, 6] 16 | mps_shape = [4, 6, 6] 17 | 18 | img = sp.shepp_logan(img_shape) 19 | mps = sim.birdcage_maps(mps_shape) 20 | 21 | mask = np.zeros(img_shape) 22 | mask[:, ::2] = 1 23 | 24 | ksp = mask * sp.fft(mps * img, axes=[-2, -1]) 25 | return img, mps, ksp 26 | 27 | def test_shepp_logan_SenseRecon(self): 28 | img, mps, ksp = self.shepp_logan_setup() 29 | lamda = 0 30 | 31 | for solver in [ 32 | "ConjugateGradient", 33 | "GradientMethod", 34 | "PrimalDualHybridGradient", 35 | "ADMM", 36 | ]: 37 | for coil_batch_size in [None, 1, 2, 3]: 38 | with self.subTest(solver=solver): 39 | img_rec = app.SenseRecon( 40 | ksp, 41 | mps, 42 | lamda, 43 | solver=solver, 44 | coil_batch_size=coil_batch_size, 45 | show_pbar=False, 46 | ).run() 47 | npt.assert_allclose(img, img_rec, atol=1e-2, rtol=1e-2) 48 | 49 | if sp.config.mpi4py_enabled: 50 | 51 | def test_shepp_logan_SenseRecon_with_comm(self): 52 | img, mps, ksp = self.shepp_logan_setup() 53 | lamda = 0 54 | comm = sp.Communicator() 55 | ksp = ksp[comm.rank :: comm.size] 56 | mps = mps[comm.rank :: comm.size] 57 | 58 | for solver in [ 59 | "ConjugateGradient", 60 | "GradientMethod", 61 | "PrimalDualHybridGradient", 62 | "ADMM", 63 | ]: 64 | with self.subTest(solver=solver): 65 | img_rec = app.SenseRecon( 66 | ksp, 67 | mps, 68 | lamda, 69 | comm=comm, 70 | solver=solver, 71 | show_pbar=False, 72 | ).run() 73 | npt.assert_allclose(img, img_rec, atol=1e-2, rtol=1e-2) 74 | 75 | def test_shepp_logan_L1WaveletRecon(self): 76 | img, mps, ksp = self.shepp_logan_setup() 77 | lamda = 0 78 | 79 | for solver in ["GradientMethod", "PrimalDualHybridGradient", "ADMM"]: 80 | with self.subTest(solver=solver): 81 | img_rec = app.L1WaveletRecon( 82 | ksp, mps, lamda, solver=solver, show_pbar=False 83 | ).run() 84 | npt.assert_allclose(img, img_rec, atol=1e-2, rtol=1e-2) 85 | 86 | def test_shepp_logan_TotalVariationRecon(self): 87 | img, mps, ksp = self.shepp_logan_setup() 88 | lamda = 0 89 | for solver in ["PrimalDualHybridGradient", "ADMM"]: 90 | with self.subTest(solver=solver): 91 | img_rec = app.TotalVariationRecon( 92 | ksp, 93 | mps, 94 | lamda, 95 | solver=solver, 96 | max_iter=1000, 97 | show_pbar=False, 98 | ).run() 99 | 100 | npt.assert_allclose(img, img_rec, atol=1e-2, rtol=1e-2) 101 | 102 | def test_ones_JsenseRecon(self): 103 | img_shape = [6, 6] 104 | mps_shape = [4, 6, 6] 105 | 106 | img = np.ones(img_shape, dtype=np.complex128) 107 | mps = sim.birdcage_maps(mps_shape) 108 | 
ksp = sp.fft(mps * img, axes=[-2, -1]) 109 | 110 | _app = app.JsenseRecon( 111 | ksp, mps_ker_width=6, ksp_calib_width=6, show_pbar=False 112 | ) 113 | mps_rec = _app.run() 114 | 115 | npt.assert_allclose(mps, mps_rec, atol=1e-2, rtol=1e-2) 116 | 117 | def test_espirit_maps(self): 118 | # 2D 119 | mps_shape = [8, 16, 16] 120 | mps = sim.birdcage_maps(mps_shape) 121 | ksp = sp.fft(mps, axes=[-1, -2]) 122 | mps_rec = app.EspiritCalib(ksp, show_pbar=False).run() 123 | 124 | np.testing.assert_allclose( 125 | np.abs(mps)[:, 4:-4, 4:-4], 126 | np.abs(mps_rec[:, 4:-4, 4:-4]), 127 | rtol=1e-2, 128 | atol=1e-2, 129 | ) 130 | 131 | # 3D 132 | mps_shape = [8, 16, 16, 16] 133 | mps = sim.birdcage_maps(mps_shape) 134 | ksp = sp.fft(mps, axes=[-1, -2, -3]) 135 | mps_rec = app.EspiritCalib(ksp, show_pbar=False).run() 136 | 137 | np.testing.assert_allclose( 138 | np.abs(mps)[:, 4:-4, 4:-4, 4:-4], 139 | np.abs(mps_rec[:, 4:-4, 4:-4, 4:-4]), 140 | rtol=1e-2, 141 | atol=1e-2, 142 | ) 143 | 144 | def test_espirit_maps_eig(self): 145 | # 2D 146 | mps_shape = [8, 16, 16] 147 | mps = sim.birdcage_maps(mps_shape) 148 | ksp = sp.fft(mps, axes=[-1, -2]) 149 | mps_rec, eig_val = app.EspiritCalib( 150 | ksp, output_eigenvalue=True, show_pbar=False 151 | ).run() 152 | 153 | np.testing.assert_allclose(eig_val, 1, rtol=0.01, atol=0.01) 154 | 155 | # 3D 156 | mps_shape = [8, 16, 16, 16] 157 | mps = sim.birdcage_maps(mps_shape) 158 | ksp = sp.fft(mps, axes=[-1, -2, -3]) 159 | mps_rec, eig_val = app.EspiritCalib( 160 | ksp, output_eigenvalue=True, show_pbar=False 161 | ).run() 162 | 163 | np.testing.assert_allclose(eig_val, 1, rtol=0.01, atol=0.01) 164 | -------------------------------------------------------------------------------- /tests/mri/test_dcf.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy as sp 7 | from sigpy.mri import dcf, samp 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestApp(unittest.TestCase): 14 | def shepp_logan_setup(self): 15 | img_shape = [16, 16] 16 | coord_shape = [int(16 * np.pi), 16, 2] 17 | 18 | img = sp.shepp_logan(img_shape) 19 | coord = samp.radial(coord_shape, img_shape) 20 | ksp = sp.nufft(img, coord) 21 | return img, coord, ksp 22 | 23 | def test_shepp_logan_dcf(self): 24 | img, coord, ksp = self.shepp_logan_setup() 25 | pm_dcf = dcf.pipe_menon_dcf(coord, show_pbar=False) 26 | img_dcf = sp.nufft_adjoint(ksp * pm_dcf, coord, oshape=img.shape) 27 | img_dcf /= np.abs(img_dcf).max() 28 | npt.assert_allclose(img, img_dcf, atol=1, rtol=1e-1) 29 | -------------------------------------------------------------------------------- /tests/mri/test_linop.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy as sp 7 | from sigpy.mri import linop 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | def check_linop_adjoint(A, dtype=float, device=sp.cpu_device): 14 | device = sp.Device(device) 15 | x = sp.randn(A.ishape, dtype=dtype, device=device) 16 | y = sp.randn(A.oshape, dtype=dtype, device=device) 17 | 18 | xp = device.xp 19 | with device: 20 | lhs = xp.vdot(A * x, y) 21 | rhs = xp.vdot(x, A.H * y) 22 | 23 | xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5) 24 | 25 | 26 | class TestLinop(unittest.TestCase): 27 | def test_sense_model(self): 28 | img_shape = [16, 16] 29 | 
mps_shape = [8, 16, 16] 30 | 31 | img = sp.randn(img_shape, dtype=np.complex128) 32 | mps = sp.randn(mps_shape, dtype=np.complex128) 33 | 34 | A = linop.Sense(mps) 35 | 36 | check_linop_adjoint(A, dtype=np.complex128) 37 | 38 | npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img) 39 | 40 | def test_sense_model_batch(self): 41 | img_shape = [16, 16] 42 | mps_shape = [8, 16, 16] 43 | 44 | img = sp.randn(img_shape, dtype=np.complex128) 45 | mps = sp.randn(mps_shape, dtype=np.complex128) 46 | 47 | for coil_batch_size in [None, 1, 2, 3]: 48 | A = linop.Sense(mps, coil_batch_size=coil_batch_size) 49 | check_linop_adjoint(A, dtype=np.complex128) 50 | npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]), A * img) 51 | 52 | def test_noncart_sense_model(self): 53 | img_shape = [16, 16] 54 | mps_shape = [8, 16, 16] 55 | 56 | img = sp.randn(img_shape, dtype=np.complex128) 57 | mps = sp.randn(mps_shape, dtype=np.complex128) 58 | 59 | y, x = np.mgrid[:16, :16] 60 | coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1) 61 | coord = coord.astype(np.float64) 62 | 63 | A = linop.Sense(mps, coord=coord) 64 | check_linop_adjoint(A, dtype=np.complex128) 65 | npt.assert_allclose( 66 | sp.fft(img * mps, axes=[-1, -2]).ravel(), 67 | (A * img).ravel(), 68 | atol=0.1, 69 | rtol=0.1, 70 | ) 71 | 72 | def test_sense_tseg_off_res_model(self): 73 | img_shape = [16, 16] 74 | mps_shape = [8, 16, 16] 75 | 76 | img = sp.randn(img_shape, dtype=np.complex128) 77 | mps = sp.randn(mps_shape, dtype=np.complex128) 78 | 79 | y, x = np.mgrid[:16, :16] 80 | coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1) 81 | coord = coord.astype(np.float64) 82 | 83 | d = np.sqrt(x * x + y * y) 84 | sigma, mu, a = 2, 0.25, 400 85 | b0 = a * np.exp(-((d - mu) ** 2 / (2.0 * sigma**2))) 86 | tseg = {"b0": b0, "dt": 4e-6, "lseg": 1, "n_bins": 10} 87 | 88 | F = sp.linop.NUFFT(mps_shape, coord) 89 | b, ct = sp.mri.util.tseg_off_res_b_ct( 90 | b0=b0, bins=10, lseg=1, dt=4e-6, T=coord.shape[0] * 4e-6 91 | ) 92 | B1 = sp.linop.Multiply(F.oshape, b.T) 93 | Ct1 = sp.linop.Multiply(img_shape, ct.reshape(img_shape)) 94 | S = sp.linop.Multiply(img_shape, mps) 95 | 96 | A = linop.Sense(mps, coord=coord, tseg=tseg) 97 | 98 | check_linop_adjoint(A, dtype=np.complex128) 99 | npt.assert_allclose(B1 * F * S * Ct1 * img, A * img) 100 | 101 | def test_noncart_sense_model_batch(self): 102 | img_shape = [16, 16] 103 | mps_shape = [8, 16, 16] 104 | 105 | img = sp.randn(img_shape, dtype=np.complex128) 106 | mps = sp.randn(mps_shape, dtype=np.complex128) 107 | 108 | y, x = np.mgrid[:16, :16] 109 | coord = np.stack([np.ravel(y - 8), np.ravel(x - 8)], axis=1) 110 | coord = coord.astype(np.float64) 111 | 112 | for coil_batch_size in [None, 1, 2, 3]: 113 | A = linop.Sense(mps, coord=coord, coil_batch_size=coil_batch_size) 114 | check_linop_adjoint(A, dtype=np.complex128) 115 | npt.assert_allclose( 116 | sp.fft(img * mps, axes=[-1, -2]).ravel(), 117 | (A * img).ravel(), 118 | atol=0.1, 119 | rtol=0.1, 120 | ) 121 | 122 | if sp.config.mpi4py_enabled: 123 | 124 | def test_sense_model_with_comm(self): 125 | img_shape = [16, 16] 126 | mps_shape = [8, 16, 16] 127 | comm = sp.Communicator() 128 | 129 | img = sp.randn(img_shape, dtype=np.complex128) 130 | mps = sp.randn(mps_shape, dtype=np.complex128) 131 | comm.allreduce(img) 132 | comm.allreduce(mps) 133 | ksp = sp.fft(img * mps, axes=[-1, -2]) 134 | 135 | A = linop.Sense(mps[comm.rank :: comm.size], comm=comm) 136 | 137 | npt.assert_allclose( 138 | A.H(ksp[comm.rank :: comm.size]), 139 | 
np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0), 140 | ) 141 | -------------------------------------------------------------------------------- /tests/mri/test_precond.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | import sigpy as sp 7 | from sigpy.mri import linop, precond 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestPrecond(unittest.TestCase): 14 | def test_kspace_precond_cart(self): 15 | nc = 4 16 | n = 10 17 | shape = (nc, n) 18 | mps = sp.randn(shape, dtype=np.complex128) 19 | mps /= np.linalg.norm(mps, axis=0, keepdims=True) 20 | weights = sp.randn([n]) >= 0 21 | 22 | A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps) 23 | 24 | AAH = np.zeros((nc, n, nc, n), dtype=np.complex128) 25 | for d in range(nc): 26 | for j in range(n): 27 | x = np.zeros((nc, n), dtype=np.complex128) 28 | x[d, j] = 1.0 29 | AAHx = A(A.H(x)) 30 | 31 | for c in range(nc): 32 | for i in range(n): 33 | AAH[c, i, d, j] = AAHx[c, i] 34 | 35 | p_expected = np.ones((nc, n), dtype=np.complex128) 36 | for c in range(nc): 37 | for i in range(n): 38 | if weights[i]: 39 | p_expected_inv_ic = 0 40 | for d in range(nc): 41 | for j in range(n): 42 | p_expected_inv_ic += abs( 43 | AAH[c, i, d, j] 44 | ) ** 2 / abs(AAH[c, i, c, i]) 45 | 46 | p_expected[c, i] = 1 / p_expected_inv_ic 47 | 48 | p = precond.kspace_precond(mps, weights=weights) 49 | npt.assert_allclose( 50 | p[:, weights == 1], 51 | p_expected[:, weights == 1], 52 | atol=1e-6, 53 | rtol=1e-6, 54 | ) 55 | 56 | def test_kspace_precond_noncart(self): 57 | n = 10 58 | nc = 3 59 | shape = [nc, n] 60 | mps = sp.randn(shape, dtype=np.complex128) 61 | mps /= np.linalg.norm(mps, axis=0, keepdims=True) 62 | coord = sp.randn([n, 1], dtype=np.float64) 63 | 64 | A = linop.Sense(mps, coord=coord) 65 | 66 | AAH = np.zeros((nc, n, nc, n), dtype=np.complex128) 67 | for d in range(nc): 68 | for j in range(n): 69 | x = np.zeros(shape, dtype=np.complex128) 70 | x[d, j] = 1.0 71 | AAHx = A(A.H(x)) 72 | for c in range(nc): 73 | for i in range(n): 74 | AAH[c, i, d, j] = AAHx[c, i] 75 | 76 | p_expected = np.zeros([nc, n], dtype=np.complex128) 77 | for c in range(nc): 78 | for i in range(n): 79 | p_expected_inv_ic = 0 80 | for d in range(nc): 81 | for j in range(n): 82 | p_expected_inv_ic += abs(AAH[c, i, d, j]) ** 2 / abs( 83 | AAH[c, i, c, i] 84 | ) 85 | 86 | p_expected[c, i] = 1 / p_expected_inv_ic 87 | 88 | p = precond.kspace_precond(mps, coord=coord) 89 | npt.assert_allclose(p, p_expected, atol=1e-2, rtol=1e-2) 90 | 91 | def test_kspace_precond_simple_cart(self): 92 | # Check identity 93 | mps_shape = [1, 1] 94 | mps = np.ones(mps_shape, dtype=np.complex128) 95 | p = precond.kspace_precond(mps) 96 | npt.assert_allclose(p, np.ones(mps_shape), atol=1e-6, rtol=1e-6) 97 | 98 | # Check scaling 99 | mps_shape = [1, 3] 100 | mps = np.ones(mps_shape, dtype=np.complex128) 101 | p = precond.kspace_precond(mps) 102 | npt.assert_allclose(p, np.ones(mps_shape), atol=1e-6, rtol=1e-6) 103 | 104 | # Check 2d 105 | mps_shape = [1, 3, 3] 106 | mps = np.ones(mps_shape, dtype=np.complex128) 107 | p = precond.kspace_precond(mps) 108 | npt.assert_allclose(p, np.ones(mps_shape), atol=1e-6, rtol=1e-6) 109 | 110 | # Check weights 111 | mps_shape = [1, 3] 112 | mps = np.ones(mps_shape, dtype=np.complex128) 113 | weights = np.array([1, 0, 1], dtype=np.complex128) 114 | p = precond.kspace_precond(mps, weights=weights) 115 | 
npt.assert_allclose(p, [[1, 1, 1]], atol=1e-6, rtol=1e-6) 116 | 117 | def test_kspace_precond_simple_noncart(self): 118 | # Check identity 119 | mps_shape = [1, 1] 120 | 121 | mps = np.ones(mps_shape, dtype=np.complex128) 122 | coord = np.array([[0.0]]) 123 | p = precond.kspace_precond(mps, coord=coord) 124 | npt.assert_allclose(p, [[1.0]], atol=1, rtol=1e-1) 125 | 126 | mps_shape = [1, 3] 127 | 128 | mps = np.ones(mps_shape, dtype=np.complex128) 129 | coord = np.array([[0.0], [-1], [1]]) 130 | p = precond.kspace_precond(mps, coord=coord) 131 | npt.assert_allclose(p, [[1.0, 1.0, 1.0]], atol=1, rtol=1e-1) 132 | 133 | def test_circulant_precond_cart(self): 134 | nc = 4 135 | n = 10 136 | shape = (nc, n) 137 | mps = sp.randn(shape, dtype=np.complex128) 138 | mps /= np.linalg.norm(mps, axis=0, keepdims=True) 139 | weights = sp.randn([n]) >= 0 140 | 141 | A = sp.linop.Multiply(shape, weights**0.5) * linop.Sense(mps) 142 | F = sp.linop.FFT([n]) 143 | 144 | p_expected = np.zeros(n, dtype=np.complex128) 145 | for i in range(n): 146 | if weights[i]: 147 | x = np.zeros(n, dtype=np.complex128) 148 | x[i] = 1.0 149 | p_expected[i] = 1 / F(A.H(A(F.H(x))))[i] 150 | 151 | p = precond.circulant_precond(mps, weights=weights) 152 | npt.assert_allclose( 153 | p[weights == 1], p_expected[weights == 1], atol=1e-6, rtol=1e-6 154 | ) 155 | 156 | def test_circulant_precond_noncart(self): 157 | nc = 4 158 | n = 10 159 | shape = [nc, n] 160 | mps = np.ones(shape, dtype=np.complex128) 161 | mps /= np.linalg.norm(mps, axis=0, keepdims=True) 162 | coord = sp.randn([n, 1], dtype=np.float64) 163 | 164 | A = linop.Sense(mps, coord=coord) 165 | F = sp.linop.FFT([n]) 166 | 167 | p_expected = np.zeros(n, dtype=np.complex128) 168 | for i in range(n): 169 | x = np.zeros(n, dtype=np.complex128) 170 | x[i] = 1.0 171 | p_expected[i] = 1 / F(A.H(A(F.H(x))))[i] 172 | 173 | p = precond.circulant_precond(mps, coord=coord) 174 | npt.assert_allclose(p, p_expected, atol=1e-1, rtol=1e-1) 175 | -------------------------------------------------------------------------------- /tests/mri/test_samp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy.mri import samp 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestPoisson(unittest.TestCase): 13 | """Test poisson undersampling defined in `sigpy.mri.samp.poisson`.""" 14 | 15 | def test_numpy_random_state(self): 16 | """Verify that random state is unchanged when seed is specified.""" 17 | np.random.seed(0) 18 | expected_state = np.random.get_state() 19 | 20 | _ = samp.poisson((120, 120), accel=6, seed=80) 21 | 22 | state = np.random.get_state() 23 | assert (expected_state[1] == state[1]).all() 24 | 25 | def test_reproducibility(self): 26 | """Verify that poisson is reproducible.""" 27 | np.random.seed(45) 28 | mask1 = samp.poisson((120, 120), accel=6, seed=80) 29 | 30 | # Changing internal numpy state should not affect mask. 
31 | np.random.seed(20) 32 | mask2 = samp.poisson((120, 120), accel=6, seed=80) 33 | 34 | npt.assert_allclose(mask2, mask1) 35 | 36 | def test_poisson_accel(self): 37 | """Verify that poisson generates the correct acceleration.""" 38 | for x in [60, 120]: 39 | for y in [60, 120]: 40 | for tol in [0.1, 0.2]: 41 | for accel in [4, 5, 6, 7, 8]: 42 | mask = samp.poisson( 43 | (x, y), accel=accel, seed=80, tol=tol 44 | ) 45 | assert abs(mask.size / np.sum(mask) - accel) < tol 46 | -------------------------------------------------------------------------------- /tests/test_alg.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import alg, linop 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestAlg(unittest.TestCase): 13 | def Ax_setup(self, n): 14 | A = np.eye(n) + 0.1 * np.ones([n, n]) 15 | x = np.arange(n, dtype=np.float32) 16 | return A, x 17 | 18 | def Ax_y_setup(self, n, lamda): 19 | A, x = self.Ax_setup(n) 20 | y = A @ x 21 | x_numpy = np.linalg.solve(A.T @ A + lamda * np.eye(n), A.T @ y) 22 | 23 | return A, x_numpy, y 24 | 25 | def test_PowerMethod(self): 26 | n = 5 27 | A, x = self.Ax_setup(n) 28 | x_hat = np.random.random([n, 1]) 29 | alg_method = alg.PowerMethod(lambda x: A.T @ A @ x, x_hat) 30 | while not alg_method.done(): 31 | alg_method.update() 32 | 33 | s_numpy = np.linalg.svd(A, compute_uv=False)[0] 34 | s_sigpy = np.linalg.norm(A @ x_hat) 35 | npt.assert_allclose(s_numpy, s_sigpy, atol=1e-3) 36 | 37 | def test_GradientMethod(self): 38 | n = 5 39 | lamda = 0.1 40 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 41 | 42 | # Compute step-size 43 | lipschitz = np.linalg.svd( 44 | A.T @ A + lamda * np.eye(n), compute_uv=False 45 | )[0] 46 | alpha = 1.0 / lipschitz 47 | 48 | for accelerate in [True, False]: 49 | for proxg in [None, lambda alpha, x: x / (1 + lamda * alpha)]: 50 | with self.subTest(accelerate=accelerate, proxg=proxg): 51 | x_sigpy = np.zeros([n]) 52 | 53 | def gradf(x): 54 | gradf_x = A.T @ (A @ x - y) 55 | if proxg is None: 56 | gradf_x += lamda * x 57 | 58 | return gradf_x 59 | 60 | alg_method = alg.GradientMethod( 61 | gradf, 62 | x_sigpy, 63 | alpha, 64 | accelerate=accelerate, 65 | proxg=proxg, 66 | max_iter=1000, 67 | ) 68 | 69 | while not alg_method.done(): 70 | alg_method.update() 71 | 72 | npt.assert_allclose(x_sigpy, x_numpy) 73 | 74 | def test_ConjugateGradient(self): 75 | n = 5 76 | lamda = 0.1 77 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 78 | x = np.zeros([n]) 79 | alg_method = alg.ConjugateGradient( 80 | lambda x: A.T @ A @ x + lamda * x, A.T @ y, x, max_iter=1000 81 | ) 82 | while not alg_method.done(): 83 | alg_method.update() 84 | 85 | npt.assert_allclose(x, x_numpy) 86 | 87 | def test_PrimalDualHybridGradient(self): 88 | n = 5 89 | lamda = 0.1 90 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 91 | 92 | # Compute step-size 93 | lipschitz = np.linalg.svd(np.matmul(A.T, A), compute_uv=False)[0] 94 | tau = 1.0 / lipschitz 95 | sigma = 1.0 96 | 97 | x = np.zeros([n]) 98 | u = np.zeros([n]) 99 | alg_method = alg.PrimalDualHybridGradient( 100 | lambda alpha, u: (u - alpha * y) / (1 + alpha), 101 | lambda alpha, x: x / (1 + lamda * alpha), 102 | lambda x: A @ x, 103 | lambda x: A.T @ x, 104 | x, 105 | u, 106 | tau, 107 | sigma, 108 | max_iter=1000, 109 | ) 110 | while not alg_method.done(): 111 | alg_method.update() 112 | 113 | npt.assert_allclose(x, x_numpy) 114 | 115 | def 
test_AugmentedLagrangianMethod(self): 116 | n = 5 117 | lamda = 0.1 118 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 119 | 120 | # Solve 1 / 2 \| A x - y \|_2^2 + lamda * \| z \|_2^2 s.t. x = z 121 | mu = 1 122 | x_z = np.zeros([2 * n]) 123 | v = np.zeros([n]) 124 | 125 | def minL(): 126 | x = x_z[:n] 127 | z = x_z[n:] 128 | x[:] = np.linalg.solve( 129 | A.T @ A + mu * np.eye(n), A.T @ y - v + mu * z 130 | ) 131 | z[:] = (mu * x + v) / (mu + lamda) 132 | 133 | def h(x_z): 134 | x = x_z[:n] 135 | z = x_z[n:] 136 | return x - z 137 | 138 | alg_method = alg.AugmentedLagrangianMethod( 139 | minL, None, h, x_z, None, v, mu 140 | ) 141 | while not alg_method.done(): 142 | alg_method.update() 143 | 144 | x = x_z[:n] 145 | npt.assert_allclose(x, x_numpy) 146 | 147 | def test_NewtonsMethod(self): 148 | n = 5 149 | lamda = 0.1 150 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 151 | 152 | def gradf(x): 153 | gradf_x = A.T @ (A @ x - y) 154 | gradf_x += lamda * x 155 | 156 | return gradf_x 157 | 158 | def inv_hessf(x): 159 | Id = np.eye(n) 160 | return lambda x: np.linalg.pinv(A.T @ A + lamda * Id) @ x 161 | 162 | for beta in [1, 0.5]: 163 | with self.subTest(beta=beta): 164 | if beta < 1: 165 | 166 | def f(x): 167 | f_x = 1 / 2 * np.linalg.norm(A @ x - y) ** 2 168 | f_x += lamda / 2 * np.linalg.norm(x) ** 2 169 | 170 | return f_x 171 | 172 | else: 173 | f = None 174 | 175 | x = np.zeros(n) 176 | alg_method = alg.NewtonsMethod( 177 | gradf, inv_hessf, x, beta=beta, f=f 178 | ) 179 | while not alg_method.done(): 180 | alg_method.update() 181 | 182 | npt.assert_allclose(x, x_numpy) 183 | 184 | def test_GerchbergSaxton(self): 185 | n = 10 186 | lamda = 0.1 187 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 188 | y = np.expand_dims(np.csingle(y), 1) 189 | x_numpy = np.expand_dims(x_numpy, 1) 190 | A = np.csingle(A) 191 | A = linop.MatMul(y.shape, A) 192 | x0 = np.zeros(A.ishape, dtype=np.complex128) 193 | 194 | alg_method = alg.GerchbergSaxton( 195 | A, y, x0, max_iter=100, tol=10e-9, lamb=lamda 196 | ) 197 | 198 | while not alg_method.done(): 199 | alg_method.update() 200 | 201 | phs = np.conj(x_numpy * alg_method.x / abs(x_numpy * alg_method.x)) 202 | npt.assert_allclose(alg_method.x * phs, x_numpy, rtol=1e-6) 203 | 204 | def test_SDMM(self): 205 | n = 5 206 | lamda = 0.1 207 | A, x_numpy, y = self.Ax_y_setup(n, lamda) 208 | y = np.expand_dims(y, 1) 209 | A = linop.MatMul(np.expand_dims(x_numpy, 1).shape, A) 210 | 211 | c_norm = None 212 | c_max = None 213 | mu = 10**8 # big mu ok since no constraints used 214 | rho_norm = 1 215 | rho_max = 1 216 | lam = 0.1 217 | cg_iters = 5 218 | max_iter = 10 219 | 220 | L = [] 221 | c = [1] 222 | rho = [1] 223 | for ii in range(len(y) - 1): 224 | c.append(0.00012**2) 225 | rho.append(0.001) 226 | 227 | alg_method = alg.SDMM( 228 | A, 229 | y, 230 | lam, 231 | L=L, 232 | c=c, 233 | c_max=c_max, 234 | c_norm=c_norm, 235 | mu=mu, 236 | rho=rho, 237 | rho_max=rho_max, 238 | rho_norm=rho_norm, 239 | eps_pri=10**-5, 240 | eps_dual=10**-2, 241 | max_cg_iter=cg_iters, 242 | max_iter=max_iter, 243 | ) 244 | 245 | while not alg_method.done(): 246 | alg_method.update() 247 | 248 | npt.assert_allclose(np.squeeze(abs(alg_method.x)), x_numpy) 249 | -------------------------------------------------------------------------------- /tests/test_app.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import app, linop, prox, util 7 | 8 | if __name__ == 
"__main__": 9 | unittest.main() 10 | 11 | 12 | class TestApp(unittest.TestCase): 13 | def test_MaxEig(self): 14 | n = 5 15 | mat = util.randn([n, n]) 16 | A = linop.MatMul([n, 1], mat) 17 | s = np.linalg.svd(mat, compute_uv=False) 18 | 19 | npt.assert_allclose( 20 | app.MaxEig(A.H * A, max_iter=1000, show_pbar=False).run(), 21 | s[0] ** 2, 22 | atol=1e-3, 23 | ) 24 | 25 | def test_LinearLeastSquares(self): 26 | n = 5 27 | _A = np.eye(n) + 0.1 * np.ones([n, n]) 28 | A = linop.MatMul([n, 1], _A) 29 | x = np.arange(n).reshape([n, 1]) 30 | y = A(x) 31 | 32 | for z in [None, x.copy()]: 33 | for lamda in [0, 0.1]: 34 | for proxg in [None, prox.L2Reg([n, 1], lamda)]: 35 | for solver in [ 36 | "GradientMethod", 37 | "PrimalDualHybridGradient", 38 | "ConjugateGradient", 39 | "ADMM", 40 | ]: 41 | with self.subTest( 42 | proxg=proxg, solver=solver, lamda=lamda, z=z 43 | ): 44 | AHA = _A.T @ _A + lamda * np.eye(n) 45 | AHy = _A.T @ y 46 | if proxg is not None: 47 | AHA += lamda * np.eye(n) 48 | 49 | if z is not None: 50 | AHy = _A.T @ y + lamda * z 51 | 52 | x_numpy = np.linalg.solve(AHA, AHy) 53 | if ( 54 | solver == "ConjugateGradient" 55 | and proxg is not None 56 | ): 57 | with self.assertRaises(ValueError): 58 | app.LinearLeastSquares( 59 | A, 60 | y, 61 | solver=solver, 62 | lamda=lamda, 63 | proxg=proxg, 64 | z=z, 65 | show_pbar=False, 66 | ).run() 67 | else: 68 | x_rec = app.LinearLeastSquares( 69 | A, 70 | y, 71 | solver=solver, 72 | lamda=lamda, 73 | proxg=proxg, 74 | z=z, 75 | show_pbar=False, 76 | ).run() 77 | 78 | npt.assert_allclose(x_rec, x_numpy, atol=1e-3) 79 | 80 | def test_precond_LinearLeastSquares(self): 81 | n = 5 82 | _A = np.eye(n) + 0.01 * util.randn([n, n]) 83 | A = linop.MatMul([n, 1], _A) 84 | x = util.randn([n, 1]) 85 | y = A(x) 86 | x_lstsq = np.linalg.lstsq(_A, y, rcond=-1)[0] 87 | p = 1 / (np.sum(abs(_A) ** 2, axis=0).reshape([n, 1])) 88 | 89 | P = linop.Multiply([n, 1], p) 90 | x_rec = app.LinearLeastSquares(A, y, show_pbar=False).run() 91 | npt.assert_allclose(x_rec, x_lstsq, atol=1e-3) 92 | 93 | alpha = 1 / app.MaxEig(P * A.H * A, show_pbar=False).run() 94 | x_rec = app.LinearLeastSquares( 95 | A, 96 | y, 97 | solver="GradientMethod", 98 | alpha=alpha, 99 | max_power_iter=100, 100 | max_iter=1000, 101 | show_pbar=False, 102 | ).run() 103 | npt.assert_allclose(x_rec, x_lstsq, atol=1e-3) 104 | 105 | tau = p 106 | x_rec = app.LinearLeastSquares( 107 | A, 108 | y, 109 | solver="PrimalDualHybridGradient", 110 | max_iter=1000, 111 | tau=tau, 112 | show_pbar=False, 113 | ).run() 114 | npt.assert_allclose(x_rec, x_lstsq, atol=1e-3) 115 | 116 | def test_dual_precond_LinearLeastSquares(self): 117 | n = 5 118 | _A = np.eye(n) + 0.1 * util.randn([n, n]) 119 | A = linop.MatMul([n, 1], _A) 120 | x = util.randn([n, 1]) 121 | y = A(x) 122 | x_lstsq = np.linalg.lstsq(_A, y, rcond=-1)[0] 123 | 124 | d = 1 / np.sum(abs(_A) ** 2, axis=1, keepdims=True).reshape([n, 1]) 125 | x_rec = app.LinearLeastSquares( 126 | A, 127 | y, 128 | solver="PrimalDualHybridGradient", 129 | max_iter=1000, 130 | sigma=d, 131 | show_pbar=False, 132 | ).run() 133 | npt.assert_allclose(x_rec, x_lstsq, atol=1e-3) 134 | -------------------------------------------------------------------------------- /tests/test_block.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from sigpy import block, config 6 | 7 | if config.cupy_enabled: 8 | import cupy as cp 9 | 10 | if __name__ == "__main__": 11 | unittest.main() 12 | 13 | 14 | 
class TestInterp(unittest.TestCase): 15 | def test_array_to_blocks(self): 16 | xps = [np] 17 | if config.cupy_enabled: 18 | xps.append(cp) 19 | 20 | for xp in xps: 21 | for dtype in [np.float32, np.complex64]: 22 | for ndim in [1, 2, 3]: 23 | with self.subTest(xp=xp, dtype=dtype, ndim=ndim): 24 | input = xp.array( 25 | [0, 1, 2, 3, 4, 5], dtype=dtype 26 | ).reshape([6] + [1] * (ndim - 1)) 27 | 28 | blk_shape = [1] + [1] * (ndim - 1) 29 | blk_strides = [1] + [1] * (ndim - 1) 30 | output = xp.array( 31 | [[0], [1], [2], [3], [4], [5]], dtype=dtype 32 | ).reshape( 33 | [6] + [1] * (ndim - 1) + [1] + [1] * (ndim - 1) 34 | ) 35 | xp.testing.assert_allclose( 36 | output, 37 | block.array_to_blocks( 38 | input, blk_shape, blk_strides 39 | ), 40 | ) 41 | 42 | blk_shape = [2] + [1] * (ndim - 1) 43 | blk_strides = [1] + [1] * (ndim - 1) 44 | output = xp.array( 45 | [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]], 46 | dtype=dtype, 47 | ).reshape( 48 | [5] + [1] * (ndim - 1) + [2] + [1] * (ndim - 1) 49 | ) 50 | xp.testing.assert_allclose( 51 | output, 52 | block.array_to_blocks( 53 | input, blk_shape, blk_strides 54 | ), 55 | ) 56 | 57 | blk_shape = [2] + [1] * (ndim - 1) 58 | blk_strides = [2] + [1] * (ndim - 1) 59 | output = xp.array( 60 | [[0, 1], [2, 3], [4, 5]], dtype=dtype 61 | ).reshape( 62 | [3] + [1] * (ndim - 1) + [2] + [1] * (ndim - 1) 63 | ) 64 | xp.testing.assert_allclose( 65 | output, 66 | block.array_to_blocks( 67 | input, blk_shape, blk_strides 68 | ), 69 | ) 70 | 71 | blk_shape = [3] + [1] * (ndim - 1) 72 | blk_strides = [2] + [1] * (ndim - 1) 73 | output = xp.array( 74 | [[0, 1, 2], [2, 3, 4]], dtype=dtype 75 | ).reshape( 76 | [2] + [1] * (ndim - 1) + [3] + [1] * (ndim - 1) 77 | ) 78 | xp.testing.assert_allclose( 79 | output, 80 | block.array_to_blocks( 81 | input, blk_shape, blk_strides 82 | ), 83 | ) 84 | 85 | def test_blocks_to_array(self): 86 | xps = [np] 87 | if config.cupy_enabled: 88 | xps.append(cp) 89 | 90 | for xp in xps: 91 | for dtype in [np.float32, np.complex64]: 92 | for ndim in [1, 2, 3]: 93 | with self.subTest(xp=xp, dtype=dtype, ndim=ndim): 94 | shape = [6] + [1] * (ndim - 1) 95 | 96 | blk_shape = [1] + [1] * (ndim - 1) 97 | blk_strides = [1] + [1] * (ndim - 1) 98 | input = xp.array( 99 | [[0], [1], [2], [3], [4], [5]], dtype=dtype 100 | ).reshape( 101 | [6] + [1] * (ndim - 1) + [1] + [1] * (ndim - 1) 102 | ) 103 | output = xp.array( 104 | [0, 1, 2, 3, 4, 5], dtype=dtype 105 | ).reshape([6] + [1] * (ndim - 1)) 106 | xp.testing.assert_allclose( 107 | output, 108 | block.blocks_to_array( 109 | input, shape, blk_shape, blk_strides 110 | ), 111 | ) 112 | 113 | blk_shape = [2] + [1] * (ndim - 1) 114 | blk_strides = [1] + [1] * (ndim - 1) 115 | input = xp.array( 116 | [[0, 1], [1, 2], [2, 3], [3, 4], [4, 5]], 117 | dtype=dtype, 118 | ).reshape( 119 | [5] + [1] * (ndim - 1) + [2] + [1] * (ndim - 1) 120 | ) 121 | output = xp.array( 122 | [0, 2, 4, 6, 8, 5], dtype=dtype 123 | ).reshape([6] + [1] * (ndim - 1)) 124 | xp.testing.assert_allclose( 125 | output, 126 | block.blocks_to_array( 127 | input, shape, blk_shape, blk_strides 128 | ), 129 | ) 130 | 131 | blk_shape = [2] + [1] * (ndim - 1) 132 | blk_strides = [2] + [1] * (ndim - 1) 133 | input = xp.array( 134 | [[0, 1], [2, 3], [4, 5]], dtype=dtype 135 | ).reshape( 136 | [3] + [1] * (ndim - 1) + [2] + [1] * (ndim - 1) 137 | ) 138 | output = xp.array( 139 | [0, 1, 2, 3, 4, 5], dtype=dtype 140 | ).reshape([6] + [1] * (ndim - 1)) 141 | xp.testing.assert_allclose( 142 | output, 143 | block.blocks_to_array( 144 | input, shape, 
blk_shape, blk_strides 145 | ), 146 | ) 147 | 148 | blk_shape = [3] + [1] * (ndim - 1) 149 | blk_strides = [2] + [1] * (ndim - 1) 150 | input = xp.array( 151 | [[0, 1, 2], [2, 3, 4]], dtype=dtype 152 | ).reshape( 153 | [2] + [1] * (ndim - 1) + [3] + [1] * (ndim - 1) 154 | ) 155 | output = xp.array( 156 | [0, 1, 4, 3, 4, 0], dtype=dtype 157 | ).reshape([6] + [1] * (ndim - 1)) 158 | xp.testing.assert_allclose( 159 | output, 160 | block.blocks_to_array( 161 | input, shape, blk_shape, blk_strides 162 | ), 163 | ) 164 | -------------------------------------------------------------------------------- /tests/test_interp.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | from sigpy import config, interp 6 | 7 | if config.cupy_enabled: 8 | import cupy as cp 9 | 10 | if __name__ == "__main__": 11 | unittest.main() 12 | 13 | 14 | class TestInterp(unittest.TestCase): 15 | def test_interpolate_spline(self): 16 | xps = [np] 17 | if config.cupy_enabled: 18 | xps.append(cp) 19 | 20 | batch = 2 21 | for xp in xps: 22 | for ndim in [1, 2, 3]: 23 | for dtype in [np.float32, np.complex64]: 24 | with self.subTest(ndim=ndim, xp=xp, dtype=dtype): 25 | shape = [3] + [1] * (ndim - 1) 26 | coord = xp.array( 27 | [ 28 | [0.1] + [0] * (ndim - 1), 29 | [1.1] + [0] * (ndim - 1), 30 | [2.1] + [0] * (ndim - 1), 31 | ] 32 | ) 33 | 34 | input = xp.array([[0, 1.0, 0]] * batch, dtype=dtype) 35 | input = input.reshape([batch] + shape) 36 | output = interp.interpolate(input, coord) 37 | output_expected = xp.array([[0.1, 0.9, 0]] * batch) 38 | xp.testing.assert_allclose( 39 | output, output_expected, atol=1e-7 40 | ) 41 | 42 | def test_gridding_spline(self): 43 | xps = [np] 44 | if config.cupy_enabled: 45 | xps.append(cp) 46 | 47 | batch = 2 48 | for xp in xps: 49 | for ndim in [1, 2, 3]: 50 | for dtype in [np.float32, np.complex64]: 51 | with self.subTest(ndim=ndim, xp=xp, dtype=dtype): 52 | shape = [3] + [1] * (ndim - 1) 53 | coord = xp.array( 54 | [ 55 | [0.1] + [0] * (ndim - 1), 56 | [1.1] + [0] * (ndim - 1), 57 | [2.1] + [0] * (ndim - 1), 58 | ] 59 | ) 60 | 61 | input = xp.array([[0, 1.0, 0]] * batch, dtype=dtype) 62 | output = interp.gridding(input, coord, [batch] + shape) 63 | output_expected = xp.array( 64 | [[0, 0.9, 0.1]] * batch 65 | ).reshape([batch] + shape) 66 | xp.testing.assert_allclose( 67 | output, output_expected, atol=1e-7 68 | ) 69 | 70 | if config.cupy_enabled: 71 | 72 | def test_interpolate_cpu_gpu(self): 73 | for ndim in [1, 2, 3]: 74 | for dtype in [np.float32, np.complex64]: 75 | with self.subTest(ndim=ndim, dtype=dtype): 76 | shape = [2, 20] + [1] * (ndim - 1) 77 | coord = np.random.random([10, ndim]) 78 | 79 | input = np.random.random(shape).astype(dtype=dtype) 80 | output_cpu = interp.interpolate( 81 | input, coord, kernel="kaiser_bessel" 82 | ) 83 | output_gpu = interp.interpolate( 84 | cp.array(input), 85 | cp.array(coord), 86 | kernel="kaiser_bessel", 87 | ).get() 88 | np.testing.assert_allclose( 89 | output_cpu, output_gpu, atol=1e-5 90 | ) 91 | 92 | def test_gridding_cpu_gpu(self): 93 | for ndim in [1, 2, 3]: 94 | for dtype in [np.float32, np.complex64]: 95 | with self.subTest(ndim=ndim, dtype=dtype): 96 | shape = [2, 20] + [1] * (ndim - 1) 97 | coord = np.random.random([10, ndim]) 98 | 99 | input = np.random.random( 100 | [2, 10] + [1] * (ndim - 1) 101 | ).astype(dtype=dtype) 102 | output_cpu = interp.gridding( 103 | input, coord, shape, kernel="kaiser_bessel" 104 | ) 105 | output_gpu = interp.gridding( 
106 | cp.array(input), 107 | cp.array(coord), 108 | shape, 109 | kernel="kaiser_bessel", 110 | ).get() 111 | np.testing.assert_allclose( 112 | output_cpu, output_gpu, atol=1e-5 113 | ) 114 | -------------------------------------------------------------------------------- /tests/test_prox.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import linop, prox, util 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestProx(unittest.TestCase): 13 | def test_L1Reg(self): 14 | shape = [6] 15 | lamda = 1.0 16 | P = prox.L1Reg(shape, lamda) 17 | phase = np.exp(1j * np.random.random(shape)) 18 | x = np.array([-3.0, -2.0, -1.0, 0, 1.0, 2.0]) * phase 19 | y = P(1.0, x) 20 | z = np.array([-2.0, -1.0, -0.0, 0, 0.0, 1.0]) * phase 21 | npt.assert_allclose(y, z) 22 | 23 | def test_L1Proj(self): 24 | shape = [6] 25 | epsilon = 1.0 26 | P = prox.L1Proj(shape, epsilon) 27 | x = util.randn(shape) 28 | y = P(1.0, x) 29 | z = 1.0 if np.linalg.norm(x, 1) > 1.0 else np.linalg.norm(x, 1) 30 | npt.assert_allclose(np.linalg.norm(y, 1), z) 31 | 32 | x = util.randn(shape) * 0.0001 33 | y = P(1.0, x) 34 | z = 1.0 if np.linalg.norm(x, 1) > 1.0 else np.linalg.norm(x, 1) 35 | npt.assert_allclose(np.linalg.norm(y, 1), z) 36 | 37 | def test_UnitaryTransform(self): 38 | shape = [6] 39 | lamda = 1.0 40 | A = linop.FFT(shape) 41 | P = prox.UnitaryTransform(prox.L2Reg(shape, lamda), A) 42 | x = util.randn(shape) 43 | y = P(0.1, x) 44 | npt.assert_allclose(y, x / (1 + lamda * 0.1), atol=1e-6, rtol=1e-6) 45 | 46 | def test_L2Reg(self): 47 | shape = [6] 48 | lamda = 1.0 49 | P = prox.L2Reg(shape, lamda) 50 | x = util.randn(shape) 51 | y = P(0.1, x) 52 | npt.assert_allclose(y, x / (1 + lamda * 0.1)) 53 | 54 | def test_L2Proj(self): 55 | shape = [6] 56 | epsilon = 1.0 57 | P = prox.L2Proj(shape, epsilon) 58 | x = util.randn(shape) * 10 59 | y = P(1.0, x) 60 | npt.assert_allclose(y, x / np.linalg.norm(x.ravel())) 61 | 62 | def test_LInfProj(self): 63 | shape = [5] 64 | epsilon = 0.6 65 | P = prox.LInfProj(shape, epsilon) 66 | x = np.array([-1, -0.5, 0, 0.5, 1]) 67 | y = P(1.0, x) 68 | npt.assert_allclose(y, [-0.6, -0.5, 0, 0.5, 0.6]) 69 | 70 | def test_PsdProj(self): 71 | shape = [3, 3] 72 | P = prox.PsdProj(shape) 73 | x = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -2]]) 74 | y = P(None, x) 75 | npt.assert_allclose(y, np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]])) 76 | 77 | def test_BoxConstraint(self): 78 | shape = [5] 79 | P = prox.BoxConstraint(shape, -1, 1) 80 | x = np.array([-2, -1, 0, 1, 2]) 81 | y = P(None, x) 82 | npt.assert_allclose(y, [-1, -1, 0, 1, 1]) 83 | 84 | P = prox.BoxConstraint(shape, [-1, 0, -1, -1, -1], [1, 1, 1, 0, 1]) 85 | x = np.array([-2, -1, 0, 1, 2]) 86 | y = P(None, x) 87 | npt.assert_allclose(y, [-1, 0, 0, 0, 1]) 88 | -------------------------------------------------------------------------------- /tests/test_pytorch.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import backend, config, linop, pytorch 7 | 8 | if config.pytorch_enabled: 9 | import torch 10 | 11 | if __name__ == "__main__": 12 | unittest.main() 13 | 14 | devices = [backend.cpu_device] 15 | if config.cupy_enabled: 16 | devices.append(backend.Device(0)) 17 | 18 | class TestPytorch(unittest.TestCase): 19 | def test_to_pytorch(self): 20 | for dtype in 
[np.float32, np.float64]: 21 | for device in devices: 22 | with self.subTest(device=device, dtype=dtype): 23 | xp = device.xp 24 | array = xp.array([1, 2, 3], dtype=dtype) 25 | tensor = pytorch.to_pytorch(array) 26 | array[0] = 0 27 | torch.testing.assert_allclose( 28 | tensor, 29 | torch.tensor( 30 | [0, 2, 3], 31 | dtype=tensor.dtype, 32 | device=tensor.device, 33 | ), 34 | ) 35 | 36 | def test_to_pytorch_complex(self): 37 | for dtype in [np.complex64, np.complex128]: 38 | for device in devices: 39 | with self.subTest(device=device, dtype=dtype): 40 | xp = device.xp 41 | array = xp.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=dtype) 42 | tensor = pytorch.to_pytorch(array) 43 | array[0] = 0 44 | torch.testing.assert_allclose( 45 | tensor, 46 | torch.tensor( 47 | [[0, 0], [2, 2], [3, 3]], 48 | dtype=tensor.dtype, 49 | device=tensor.device, 50 | ), 51 | ) 52 | 53 | def test_from_pytorch(self): 54 | for dtype in [torch.float32, torch.float64]: 55 | for device in devices: 56 | with self.subTest(device=device, dtype=dtype): 57 | if device == backend.cpu_device: 58 | torch_device = torch.device("cpu") 59 | else: 60 | torch_device = torch.device("cuda:0") 61 | 62 | tensor = torch.tensor( 63 | [1, 2, 3], dtype=dtype, device=torch_device 64 | ) 65 | array = pytorch.from_pytorch(tensor) 66 | array[0] = 0 67 | np.testing.assert_allclose( 68 | tensor.cpu().numpy(), [0, 2, 3] 69 | ) 70 | 71 | def test_from_pytorch_complex(self): 72 | for dtype in [torch.float32, torch.float64]: 73 | for device in devices: 74 | with self.subTest(device=device, dtype=dtype): 75 | if device == backend.cpu_device: 76 | torch_device = torch.device("cpu") 77 | else: 78 | torch_device = torch.device("cuda:0") 79 | 80 | tensor = torch.tensor( 81 | [[1, 1], [2, 2], [3, 3]], 82 | dtype=dtype, 83 | device=torch_device, 84 | ) 85 | array = pytorch.from_pytorch(tensor, iscomplex=True) 86 | xp = device.xp 87 | xp.testing.assert_array_equal( 88 | array, [1 + 1j, 2 + 2j, 3 + 3j] 89 | ) 90 | array[0] -= 1 91 | np.testing.assert_allclose( 92 | tensor.cpu().numpy(), [[0, 1], [2, 2], [3, 3]] 93 | ) 94 | 95 | def test_to_pytorch_function(self): 96 | A = linop.Resize([5], [3]) 97 | x = np.array([1, 2, 3], dtype=np.float64) 98 | y = np.ones([5]) 99 | 100 | with self.subTest("forward"): 101 | f = pytorch.to_pytorch_function(A).apply 102 | x_torch = pytorch.to_pytorch(x) 103 | npt.assert_allclose(f(x_torch).detach().numpy(), A(x)) 104 | 105 | with self.subTest("adjoint"): 106 | y_torch = pytorch.to_pytorch(y) 107 | loss = (f(x_torch) - y_torch).pow(2).sum() / 2 108 | loss.backward() 109 | npt.assert_allclose( 110 | x_torch.grad.detach().numpy(), A.H(A(x) - y) 111 | ) 112 | 113 | def test_to_pytorch_function_complex(self): 114 | A = linop.FFT([3]) 115 | x = np.array([1 + 1j, 2 + 2j, 3 + 3j], dtype=np.complex128) 116 | y = np.ones([3], dtype=np.complex128) 117 | 118 | with self.subTest("forward"): 119 | f = pytorch.to_pytorch_function( 120 | A, input_iscomplex=True, output_iscomplex=True 121 | ).apply 122 | x_torch = pytorch.to_pytorch(x) 123 | npt.assert_allclose( 124 | f(x_torch).detach().numpy().ravel(), A(x).view(np.float64) 125 | ) 126 | 127 | with self.subTest("adjoint"): 128 | y_torch = pytorch.to_pytorch(y) 129 | loss = (f(x_torch) - y_torch).pow(2).sum() / 2 130 | loss.backward() 131 | npt.assert_allclose( 132 | x_torch.grad.detach().numpy().ravel(), 133 | A.H(A(x) - y).view(np.float64), 134 | ) 135 | -------------------------------------------------------------------------------- /tests/test_thresh.py: 
-------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import config, thresh 7 | 8 | if config.cupy_enabled: 9 | import cupy as cp 10 | 11 | if __name__ == "__main__": 12 | unittest.main() 13 | 14 | 15 | class TestThresh(unittest.TestCase): 16 | def test_l2_proj(self): 17 | x = np.ones(5) 18 | y = np.full(5, 1 / 5**0.5) 19 | npt.assert_allclose(thresh.l2_proj(1, x), y) 20 | 21 | x = np.ones(5) 22 | y = np.ones(5) 23 | npt.assert_allclose(thresh.l2_proj(5**0.5, x), y) 24 | 25 | x = np.ones(5) 26 | y = np.ones(5) 27 | npt.assert_allclose(thresh.l2_proj(10, x), y) 28 | 29 | def test_linf_proj(self): 30 | x = np.ones(5) 31 | y = np.ones(5) 32 | npt.assert_allclose(thresh.linf_proj(1.1, x), y) 33 | 34 | x = np.ones(5) 35 | y = np.ones(5) * 0.1 36 | npt.assert_allclose(thresh.linf_proj(0.1, x), y) 37 | 38 | x = [-2, -1, 0, 1, 2] 39 | y = [-1, -1, 0, 1, 1] 40 | npt.assert_allclose(thresh.linf_proj(1, x), y) 41 | 42 | def test_psd_proj(self): 43 | x = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -2]]) 44 | y = np.array([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) 45 | npt.assert_allclose(thresh.psd_proj(x), y) 46 | 47 | def test_soft_thresh(self): 48 | x = np.array([-2, -1.5, -1, 0.5, 0, 0.5, 1, 1.5, 2]) 49 | y = np.array([-1, -0.5, 0, 0, 0, 0, 0, 0.5, 1]) 50 | 51 | npt.assert_allclose(thresh.soft_thresh(1, x), y) 52 | 53 | def test_hard_thresh(self): 54 | x = np.array([-2, -1.5, -1, 0.5, 0, 0.5, 1, 1.5, 2]) 55 | y = np.array([-2, -1.5, 0, 0, 0, 0, 0, 1.5, 2]) 56 | 57 | npt.assert_allclose(thresh.hard_thresh(1, x), y) 58 | 59 | if config.cupy_enabled: 60 | 61 | def test_soft_thresh_cuda(self): 62 | x = cp.array([-2, -1.5, -1, 0.5, 0, 0.5, 1, 1.5, 2]) 63 | y = cp.array([-1, -0.5, 0, 0, 0, 0, 0, 0.5, 1]) 64 | lamda = cp.array([1.0]) 65 | 66 | cp.testing.assert_allclose(thresh.soft_thresh(lamda, x), y) 67 | 68 | def test_hard_thresh_cuda(self): 69 | x = cp.array([-2, -1.5, -1, 0.5, 0, 0.5, 1, 1.5, 2]) 70 | y = cp.array([-2, -1.5, 0, 0, 0, 0, 0, 1.5, 2]) 71 | lamda = cp.array([1.0]) 72 | 73 | cp.testing.assert_allclose(thresh.hard_thresh(lamda, x), y) 74 | -------------------------------------------------------------------------------- /tests/test_util.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import unittest 3 | 4 | import numpy as np 5 | import numpy.testing as npt 6 | 7 | from sigpy import backend, util 8 | 9 | if __name__ == "__main__": 10 | unittest.main() 11 | 12 | 13 | class TestUtil(unittest.TestCase): 14 | def test_device(self): 15 | device = backend.Device(-1) 16 | pickle.dumps(device) 17 | 18 | def test_dirac(self): 19 | output = util.dirac([5]) 20 | truth = [0, 0, 1, 0, 0] 21 | npt.assert_allclose(output, truth) 22 | 23 | output = util.dirac([4]) 24 | truth = [0, 0, 1, 0] 25 | npt.assert_allclose(output, truth) 26 | 27 | def test_triang(self): 28 | npt.assert_allclose(util.triang([3]), [0.5, 1, 0.5]) 29 | npt.assert_allclose(util.triang([4]), [0.25, 0.75, 0.75, 0.25]) 30 | 31 | def test_hanning(self): 32 | npt.assert_allclose(util.hanning([4]), [0, 0.5, 1, 0.5]) 33 | npt.assert_allclose(util.hanning([5]), [0, 0.5, 1, 0.5, 0]) 34 | 35 | def test_resize(self): 36 | # Zero-pad 37 | x = np.array([1, 2, 3]) 38 | oshape = [5] 39 | y = util.resize(x, oshape) 40 | npt.assert_allclose(y, [0, 1, 2, 3, 0]) 41 | 42 | x = np.array([1, 2, 3]) 43 | oshape = [4] 44 | y = util.resize(x, oshape) 45 | npt.assert_allclose(y, [0, 1, 2, 3]) 
46 | 47 | x = np.array([1, 2]) 48 | oshape = [5] 49 | y = util.resize(x, oshape) 50 | npt.assert_allclose(y, [0, 1, 2, 0, 0]) 51 | 52 | x = np.array([1, 2]) 53 | oshape = [4] 54 | y = util.resize(x, oshape) 55 | npt.assert_allclose(y, [0, 1, 2, 0]) 56 | 57 | # Zero-pad non centered 58 | x = np.array([1, 2, 3]) 59 | oshape = [5] 60 | y = util.resize(x, oshape, oshift=[0]) 61 | npt.assert_allclose(y, [1, 2, 3, 0, 0]) 62 | 63 | # Crop 64 | x = np.array([0, 1, 2, 3, 0]) 65 | oshape = [3] 66 | y = util.resize(x, oshape) 67 | npt.assert_allclose(y, [1, 2, 3]) 68 | 69 | x = np.array([0, 1, 2, 3]) 70 | oshape = [3] 71 | y = util.resize(x, oshape) 72 | npt.assert_allclose(y, [1, 2, 3]) 73 | 74 | x = np.array([0, 1, 2, 0, 0]) 75 | oshape = [2] 76 | y = util.resize(x, oshape) 77 | npt.assert_allclose(y, [1, 2]) 78 | 79 | x = np.array([0, 1, 2, 0]) 80 | oshape = [2] 81 | y = util.resize(x, oshape) 82 | npt.assert_allclose(y, [1, 2]) 83 | 84 | # Crop non centered 85 | x = np.array([1, 2, 3, 0, 0]) 86 | oshape = [3] 87 | y = util.resize(x, oshape, ishift=[0]) 88 | npt.assert_allclose(y, [1, 2, 3]) 89 | 90 | def test_downsample(self): 91 | x = np.array([1, 2, 3, 4, 5]) 92 | y = util.downsample(x, [2]) 93 | npt.assert_allclose(y, [1, 3, 5]) 94 | 95 | def test_upsample(self): 96 | x = np.array([1, 2, 3]) 97 | y = util.upsample(x, [5], [2]) 98 | npt.assert_allclose(y, [1, 0, 2, 0, 3]) 99 | 100 | def test_circshift(self): 101 | input = np.array([0, 1, 2, 3]) 102 | axes = [0] 103 | shift = [1] 104 | npt.assert_allclose(util.circshift(input, shift, axes), [3, 0, 1, 2]) 105 | 106 | input = np.array([[0, 1, 2], [3, 4, 5]]) 107 | axes = [-1] 108 | shift = [2] 109 | npt.assert_allclose( 110 | util.circshift(input, shift, axes), [[1, 2, 0], [4, 5, 3]] 111 | ) 112 | 113 | input = np.array([[0, 1, 2], [3, 4, 5]]) 114 | axes = [-2] 115 | shift = [1] 116 | npt.assert_allclose( 117 | util.circshift(input, shift, axes), [[3, 4, 5], [0, 1, 2]] 118 | ) 119 | 120 | def test_monte_carlo_sure(self): 121 | x = np.ones([100000], dtype=np.float64) 122 | sigma = 0.1 123 | noise = 0.1 * util.randn([100000], dtype=np.float64) 124 | y = x + noise 125 | 126 | def f(y): 127 | return y 128 | 129 | npt.assert_allclose( 130 | sigma**2, util.monte_carlo_sure(f, y, sigma), atol=1e-3 131 | ) 132 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | from sigpy import version 4 | 5 | if __name__ == "__main__": 6 | unittest.main() 7 | 8 | 9 | class TestVersion(unittest.TestCase): 10 | def test_version(self): 11 | assert version.__version__ == "0.1.27" 12 | -------------------------------------------------------------------------------- /tests/test_wavelet.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | import numpy.testing as npt 5 | 6 | from sigpy import wavelet 7 | 8 | if __name__ == "__main__": 9 | unittest.main() 10 | 11 | 12 | class TestWavelet(unittest.TestCase): 13 | def test_fwt(self): 14 | n = 8 15 | input = np.zeros(n, dtype=np.float32) 16 | input[0] = 1 17 | npt.assert_allclose( 18 | wavelet.fwt(input, level=1, wave_name="haar"), 19 | [1 / 2**0.5, 0, 0, 0, 1 / 2**0.5, 0, 0, 0], 20 | ) 21 | 22 | def test_fwt_iwt(self): 23 | for n in range(5, 11): 24 | input = np.zeros(n, dtype=np.float32) 25 | input[0] = 1 26 | _, coeff_slices = wavelet.get_wavelet_shape([n]) 27 | npt.assert_allclose( 28 
| wavelet.iwt(wavelet.fwt(input), [n], coeff_slices), input 29 | ) 30 | --------------------------------------------------------------------------------
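The test files above exercise SigPy's Linop, Prox, Alg, and App interfaces. As a reference, the following is a minimal usage sketch distilled from the calls made in tests/test_app.py and tests/test_alg.py; it is not a file from the repository, and the problem size, the 0.01 regularization weight, and the variable names are illustrative assumptions only.

import numpy as np

from sigpy import alg, app, linop, prox, util

n = 5
mat = np.eye(n) + 0.1 * util.randn([n, n])   # small, well-conditioned system
A = linop.MatMul([n, 1], mat)                # wrap the matrix as a Linop
x_true = util.randn([n, 1])
y = A(x_true)                                # simulated measurements

# High-level App interface: plain least squares, then a variant with an
# L1 penalty supplied through proxg (as in the test_LinearLeastSquares combos).
x_ls = app.LinearLeastSquares(A, y, show_pbar=False).run()
x_l1 = app.LinearLeastSquares(
    A, y, proxg=prox.L1Reg([n, 1], 0.01), solver="GradientMethod", show_pbar=False
).run()

# Low-level Alg interface: the update()/done() loop used throughout test_alg.py,
# here solving the normal equations A^H A x = A^H y in place on x_cg.
x_cg = np.zeros([n, 1])
cg = alg.ConjugateGradient(lambda x: A.H(A(x)), A.H(y), x_cg, max_iter=100)
while not cg.done():
    cg.update()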