├── .github ├── ISSUE_TEMPLATE │ └── bug_report.md └── workflows │ └── python-publish.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── brainextractor ├── __init__.py ├── helpers.py ├── main.py └── scripts │ ├── __init__.py │ ├── brainextractor.py │ └── brainextractor_render.py ├── pyproject.toml ├── setup.cfg └── setup.py /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. Go to '...' 16 | 2. Click on '....' 17 | 3. Scroll down to '....' 18 | 4. See error 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. iOS] 28 | - Version [e.g. 22] 29 | - Python Version 30 | 31 | **Additional context** 32 | Add any other context about the problem here. 33 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 
8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # ignore data 132 | data/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "black" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Andrew Van 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # brainextractor 2 | A re-implementation of FSL's Brain Extraction Tool in Python. 3 | 4 | Follows the algorithm as described in: 5 | 6 | ``` 7 | Smith SM. Fast robust automated brain extraction. Hum Brain Mapp. 8 | 2002 Nov;17(3):143-55. doi: 10.1002/hbm.10062. PMID: 12391568; PMCID: PMC6871816. 9 | ``` 10 | 11 | This code was originally made for a [course project](https://www.cse.wustl.edu/~taoju/cse554/). 12 | 13 | https://user-images.githubusercontent.com/3641187/190677589-be019bc6-60e4-4e96-8c71-266285ab0755.mp4 14 | 15 | ## Install 16 | 17 | To install, use `pip` to install this repo: 18 | 19 | ``` 20 | # install from pypi 21 | pip install brainextractor 22 | 23 | # install repo with pip 24 | pip install git+https://github.com/vanandrew/brainextractor@main 25 | 26 | # install from local copy 27 | pip install /path/to/local/repo 28 | ``` 29 | 30 | > **__NOTE:__** It is recommended to use `brainextractor` on **Python 3.7** and above. 
31 | 32 | ## Usage 33 | 34 | To extract a brain mask from an image, you can call: 35 | 36 | ``` 37 | # basic usage 38 | brainextractor [input_image] [output_image] 39 | 40 | # example 41 | brainextractor /path/to/test_image.nii.gz /path/to/some_output_image.nii.gz 42 | ``` 43 | 44 | You can adjust the fractional intensity with the `-f` flag: 45 | 46 | ``` 47 | # with custom set threshold 48 | brainextractor [input_image] [output_image] -f [threshold] 49 | 50 | # example 51 | brainextractor /path/to/test_image.nii.gz /path/to/some_output_image.nii.gz -f 0.4 52 | ``` 53 | 54 | To view the deformation process (as in the video above), you can use the `-w` flag to 55 | write the surfaces to a file. Then use `brainextractor_render` to view them: 56 | 57 | ``` 58 | # writes surfaces to file 59 | brainextractor [input_image] [output_image] -w [surfaces_file] 60 | 61 | # load surfaces and render 62 | brainextractor_render [surfaces_file] 63 | 64 | # example 65 | brainextractor /path/to/test_image.nii.gz /path/to/some_output_image.nii.gz -w /path/to/surface_file.surfaces 66 | 67 | brainextractor_render /path/to/surface_file.surfaces 68 | ``` 69 | 70 | If you need an explanation of the options at any time, simply run the `--help` flag: 71 | 72 | ``` 73 | brainextractor --help 74 | ``` 75 | 76 | If you need to call Brainextractor directly from python: 77 | ```python 78 | # import the nibabel library so we can read in a nifti image 79 | import nibabel as nib 80 | # import the BrainExtractor class 81 | from brainextractor import BrainExtractor 82 | 83 | # read in the image file first 84 | input_img = nib.load("/content/MNI.nii.gz") 85 | 86 | # create a BrainExtractor object using the input_img as input 87 | # we just use the default arguments here, but look at the 88 | # BrainExtractor class in the code for the full argument list 89 | bet = BrainExtractor(img=input_img) 90 | 91 | # run the brain extraction 92 | # this will by default run for 1000 iterations 93 | # I recommend 
looking at the run method to see how it works 94 | bet.run() 95 | 96 | # save the computed mask out to file 97 | bet.save_mask("/content/MNI_mask.nii.gz") 98 | ``` 99 | -------------------------------------------------------------------------------- /brainextractor/__init__.py: -------------------------------------------------------------------------------- 1 | from .main import BrainExtractor 2 | -------------------------------------------------------------------------------- /brainextractor/helpers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helper functions 3 | """ 4 | import numpy as np 5 | import trimesh 6 | from numba import jit 7 | 8 | 9 | def sphere(shape: list, radius: float, position: list): 10 | """ 11 | Creates a binary sphere: a boolean array of `shape` that is True wherever the Euclidean distance to `position` is <= `radius` 12 | """ 13 | # assume shape and position are both a 3-tuple of int or float 14 | # the units are pixels / voxels (px for short) 15 | # radius is an int or float in px 16 | semisizes = (radius,) * 3 17 | 18 | # generate the grid for the support points 19 | # centered at the position indicated by position 20 | grid = [slice(-x0, dim - x0) for x0, dim in zip(position, shape)] 21 | position = np.ogrid[grid] 22 | # calculate the distance of all points from `position` center 23 | # scaled by the radius 24 | arr = np.zeros(shape, dtype=float) 25 | for x_i, semisize in zip(position, semisizes): 26 | # this can be generalized for exponent != 2 27 | # in which case `(x_i / semisize)` 28 | # would become `np.abs(x_i / semisize)` 29 | arr += (x_i / semisize) ** 2 30 | 31 | # the inner part of the sphere has a scaled squared distance at or below 1 32 | return arr <= 1.0 33 | 34 | 35 | @jit(nopython=True, cache=True) 36 | def closest_integer_point(vertex: np.ndarray): 37 | """ 38 | Gives the closest integer grid point (as an int64 array) to a 3-D vertex, comparing the 8 surrounding floor / floor+1 corners by Euclidean distance 39 | """ 40 | # get neighboring grid points to search 41 | x = vertex[0] 42 | y = vertex[1] 43 | z = vertex[2] 44 | x0 = np.floor(x) 45 | y0 = np.floor(y) 46 | z0 =
np.floor(z) 47 | x1 = x0 + 1 48 | y1 = y0 + 1 49 | z1 = z0 + 1 50 | 51 | # initialize min euclidean distance (every candidate corner is within sqrt(3), so 99 is a safe upper bound) 52 | min_euclid = 99 53 | 54 | # loop through each neighbor point 55 | for i in [x0, x1]: 56 | for j in [y0, y1]: 57 | for k in [z0, z1]: 58 | # compare coordinate and store if min euclid distance 59 | coords = np.array([i, j, k]) 60 | dist = l2norm(vertex - coords) 61 | if dist < min_euclid: 62 | min_euclid = dist 63 | final_coords = coords 64 | 65 | # return the final coords 66 | return final_coords.astype(np.int64) 67 | 68 | 69 | @jit(nopython=True, cache=True) 70 | def bresenham3d(v0: np.ndarray, v1: np.ndarray): 71 | """ 72 | Bresenham's algorithm for a 3-D line between integer points v0 and v1; returns a (d0 + 1, 3) int64 array starting at v0, where d0 is the largest absolute axis difference 73 | 74 | https://www.geeksforgeeks.org/bresenhams-algorithm-for-3-d-line-drawing/ 75 | """ 76 | # initialize axis differences 77 | 78 | dx = np.abs(v1[0] - v0[0]) 79 | dy = np.abs(v1[1] - v0[1]) 80 | dz = np.abs(v1[2] - v0[2]) 81 | xs = 1 if (v1[0] > v0[0]) else -1 82 | ys = 1 if (v1[1] > v0[1]) else -1 83 | zs = 1 if (v1[2] > v0[2]) else -1 84 | 85 | # determine the driving axis (largest difference); a0/a1/a2 are axis indices and s0/s1/s2 their step signs 86 | if dx >= dy and dx >= dz: 87 | d0 = dx 88 | d1 = dy 89 | d2 = dz 90 | s0 = xs 91 | s1 = ys 92 | s2 = zs 93 | a0 = 0 94 | a1 = 1 95 | a2 = 2 96 | elif dy >= dx and dy >= dz: 97 | d0 = dy 98 | d1 = dx 99 | d2 = dz 100 | s0 = ys 101 | s1 = xs 102 | s2 = zs 103 | a0 = 1 104 | a1 = 0 105 | a2 = 2 106 | elif dz >= dx and dz >= dy: 107 | d0 = dz 108 | d1 = dx 109 | d2 = dy 110 | s0 = zs 111 | s1 = xs 112 | s2 = ys 113 | a0 = 2 114 | a1 = 0 115 | a2 = 1 116 | 117 | # create line array 118 | line = np.zeros((d0 + 1, 3), dtype=np.int64) 119 | line[0] = v0 120 | 121 | # get points 122 | p1 = 2 * d1 - d0 123 | p2 = 2 * d2 - d0 124 | for i in range(d0): 125 | c = line[i].copy() 126 | c[a0] += s0 127 | if p1 >= 0: 128 | c[a1] += s1 129 | p1 -= 2 * d0 130 | if p2 >= 0: 131 | c[a2] += s2 132 | p2 -= 2 * d0 133 | p1 += 2 * d1 134 | p2 += 2 * d2 135 | line[i + 1] = c 136 | 137 | # return the array of line points 138 | return line 139 | 140 | 141 |
@jit(nopython=True, cache=True) 142 | def l2norm(vec: np.ndarray): 143 | """ 144 | Computes the l2 norm for 3d vector 145 | """ 146 | return np.sqrt(vec[0] ** 2 + vec[1] ** 2 + vec[2] ** 2) 147 | 148 | 149 | @jit(nopython=True, cache=True) 150 | def l2normarray(array: np.ndarray): 151 | """ 152 | Computes the l2 norm for several 3d vectors 153 | """ 154 | return np.sqrt(array[:, 0] ** 2 + array[:, 1] ** 2 + array[:, 2] ** 2) 155 | 156 | 157 | def diagonal_dot(a: np.ndarray, b: np.ndarray): 158 | """ 159 | Dot product by row of a and b. 160 | There are a lot of ways to do this though 161 | performance varies very widely. This method 162 | uses a dot product to sum the row and avoids 163 | function calls if at all possible. 164 | """ 165 | a = np.asanyarray(a) 166 | return np.dot(a * b, [1.0] * a.shape[1]) 167 | -------------------------------------------------------------------------------- /brainextractor/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Main BrainExtractor class 3 | """ 4 | import os 5 | import warnings 6 | import numpy as np 7 | import nibabel as nib 8 | import trimesh 9 | from numba import jit 10 | from numba.typed import List 11 | from .helpers import sphere, closest_integer_point, bresenham3d, l2norm, l2normarray, diagonal_dot 12 | 13 | 14 | class BrainExtractor: 15 | """ 16 | Implemenation of the FSL Brain Extraction Tool 17 | 18 | This class takes in a Nifti1Image class and generates 19 | the brain surface and mask. 
20 | """ 21 | 22 | def __init__( 23 | self, 24 | img: nib.Nifti1Image, 25 | t02t: float = 0.02, 26 | t98t: float = 0.98, 27 | bt: float = 0.5, 28 | d1: float = 20.0, # mm 29 | d2: float = 10.0, # mm 30 | rmin: float = 3.33, # mm 31 | rmax: float = 10.0, # mm 32 | ): 33 | """ 34 | Initialization of Brain Extractor 35 | 36 | Computes image range/thresholds and 37 | estimates the brain radius 38 | """ 39 | print("Initializing...") 40 | 41 | # get image resolution 42 | res = img.header["pixdim"][1] 43 | if not np.allclose(res, img.header["pixdim"][1:4], rtol=1e-3): 44 | warnings.warn( 45 | "The voxels in this image are non-isotropic! \ 46 | Brain extraction settings may not be valid!" 47 | ) 48 | 49 | # store brain extraction parameters 50 | print("Parameters: bt=%f, d1=%f, d2=%f, rmin=%f, rmax=%f" % (bt, d1, d2, rmin, rmax)) 51 | self.bt = bt 52 | self.d1 = d1 / res 53 | self.d2 = d2 / res 54 | self.rmin = rmin / res 55 | self.rmax = rmax / res 56 | 57 | # compute E, F constants 58 | self.E = (1.0 / rmin + 1.0 / rmax) / 2.0 59 | self.F = 6.0 / (1.0 / rmin - 1.0 / rmax) 60 | 61 | # store the image 62 | self.img = img 63 | 64 | # store conveinent references 65 | self.data = img.get_fdata() # 3D data 66 | self.rdata = img.get_fdata().ravel() # flattened data 67 | self.shape = img.shape # 3D shape 68 | self.rshape = np.multiply.reduce(img.shape) # flattened shape 69 | 70 | # get thresholds from histogram 71 | sorted_data = np.sort(self.rdata) 72 | self.tmin = np.min(sorted_data) 73 | self.t2 = sorted_data[np.ceil(t02t * self.rshape).astype(np.int64) + 1] 74 | self.t98 = sorted_data[np.ceil(t98t * self.rshape).astype(np.int64) + 1] 75 | self.tmax = np.max(sorted_data) 76 | self.t = (self.t98 - self.t2) * 0.1 + self.t2 77 | print("tmin: %f, t2: %f, t: %f, t98: %f, tmax: %f" % (self.tmin, self.t2, self.t, self.t98, self.tmax)) 78 | 79 | # find the center of mass of image 80 | ic, jc, kc = np.meshgrid( 81 | np.arange(self.shape[0]), np.arange(self.shape[1]), 
np.arange(self.shape[2]), indexing="ij", copy=False 82 | ) 83 | cdata = np.clip(self.rdata, self.t2, self.t98) * (self.rdata > self.t) 84 | ci = np.average(ic.ravel(), weights=cdata) 85 | cj = np.average(jc.ravel(), weights=cdata) 86 | ck = np.average(kc.ravel(), weights=cdata) 87 | self.c = np.array([ci, cj, ck]) 88 | print("Center-of-Mass: {}".format(self.c)) 89 | 90 | # compute 1/2 head radius with spherical formula 91 | self.r = 0.5 * np.cbrt(3 * np.sum(self.rdata > self.t) / (4 * np.pi)) 92 | print("Head Radius: %f" % (2 * self.r)) 93 | 94 | # get median value within estimated head sphere 95 | self.tm = np.median(self.data[sphere(self.shape, 2 * self.r, self.c)]) 96 | print("Median within Head Radius: %f" % self.tm) 97 | 98 | # generate initial surface 99 | print("Initializing surface...") 100 | self.surface = trimesh.creation.icosphere(subdivisions=4, radius=self.r) 101 | self.surface = self.surface.apply_transform([[1, 0, 0, ci], [0, 1, 0, cj], [0, 0, 1, ck], [0, 0, 0, 1]]) 102 | 103 | # update the surface attributes 104 | self.num_vertices = self.surface.vertices.shape[0] 105 | self.num_faces = self.surface.faces.shape[0] 106 | self.vertices = np.array(self.surface.vertices) 107 | self.faces = np.array(self.surface.faces) 108 | self.vertex_neighbors_idx = List([np.array(i) for i in self.surface.vertex_neighbors]) 109 | # compute location of vertices in face array 110 | self.face_vertex_idxs = np.zeros((self.num_vertices, 6, 2), dtype=np.int64) 111 | for v in range(self.num_vertices): 112 | f, i = np.asarray(self.faces == v).nonzero() 113 | self.face_vertex_idxs[v, : i.shape[0], 0] = f 114 | self.face_vertex_idxs[v, : i.shape[0], 1] = i 115 | if i.shape[0] == 5: 116 | self.face_vertex_idxs[v, 5, 0] = -1 117 | self.face_vertex_idxs[v, 5, 1] = -1 118 | self.update_surface_attributes() 119 | print("Brain extractor initialization complete!") 120 | 121 | @staticmethod 122 | @jit(nopython=True, cache=True) 123 | def compute_face_normals(num_faces, faces, 
vertices): 124 | """ 125 | Compute face normals 126 | """ 127 | face_normals = np.zeros((num_faces, 3)) 128 | for i, f in enumerate(faces): 129 | local_v = vertices[f] 130 | a = local_v[1] - local_v[0] 131 | b = local_v[2] - local_v[0] 132 | face_normals[i] = np.array( 133 | (a[1] * b[2] - a[2] * b[1], a[2] * b[0] - a[0] * b[2], a[0] * b[1] - a[1] * b[0]) 134 | ) 135 | face_normals[i] /= l2norm(face_normals[i]) 136 | return face_normals 137 | 138 | @staticmethod 139 | def compute_face_angles(triangles: np.ndarray): 140 | """ 141 | Compute angles in triangles of each face 142 | """ 143 | # don't copy triangles 144 | triangles = np.asanyarray(triangles, dtype=np.float64) 145 | 146 | # get a unit vector for each edge of the triangle 147 | u = triangles[:, 1] - triangles[:, 0] 148 | u /= l2normarray(u)[:, np.newaxis] 149 | v = triangles[:, 2] - triangles[:, 0] 150 | v /= l2normarray(v)[:, np.newaxis] 151 | w = triangles[:, 2] - triangles[:, 1] 152 | w /= l2normarray(w)[:, np.newaxis] 153 | 154 | # run the cosine and per-row dot product 155 | result = np.zeros((len(triangles), 3), dtype=np.float64) 156 | # clip to make sure we don't float error past 1.0 157 | result[:, 0] = np.arccos(np.clip(diagonal_dot(u, v), -1, 1)) 158 | result[:, 1] = np.arccos(np.clip(diagonal_dot(-u, w), -1, 1)) 159 | # the third angle is just the remaining 160 | result[:, 2] = np.pi - result[:, 0] - result[:, 1] 161 | 162 | # a triangle with any zero angles is degenerate 163 | # so set all of the angles to zero in that case 164 | result[(result < 1e-8).any(axis=1), :] = 0.0 165 | return result 166 | 167 | @staticmethod 168 | @jit(nopython=True, cache=True) 169 | def compute_vertex_normals( 170 | num_vertices: int, 171 | faces: np.ndarray, 172 | face_normals: np.ndarray, 173 | face_angles: np.ndarray, 174 | face_vertex_idxs: np.ndarray, 175 | ): 176 | """ 177 | Computes vertex normals 178 | 179 | Sums face normals connected to vertex, weighting 180 | by the angle the vertex makes with the face 
181 | """ 182 | vertex_normals = np.zeros((num_vertices, 3)) 183 | for vertex_idx in range(num_vertices): 184 | face_idxs = np.asarray([f for f in face_vertex_idxs[vertex_idx, :, 0] if f != -1]) 185 | inface_idxs = np.asarray([f for f in face_vertex_idxs[vertex_idx, :, 1] if f != -1]) 186 | surrounding_angles = face_angles.ravel()[face_idxs * 3 + inface_idxs] 187 | vertex_normals[vertex_idx] = np.dot(surrounding_angles / surrounding_angles.sum(), face_normals[face_idxs]) 188 | vertex_normals[vertex_idx] /= l2norm(vertex_normals[vertex_idx]) 189 | return vertex_normals 190 | 191 | def rebuild_surface(self): 192 | """ 193 | Rebuilds the surface mesh for given updated vertices 194 | """ 195 | self.update_surface_attributes() 196 | self.surface = trimesh.Trimesh(vertices=self.vertices, faces=self.faces) 197 | 198 | @staticmethod 199 | @jit(nopython=True, cache=True) 200 | def update_surf_attr(vertices: np.ndarray, neighbors_idx: list): 201 | # the neighbors array is tricky because it doesn't 202 | # have the structure of a nice rectangular array 203 | # we initialize it to be the largest size (6) then we 204 | # can make a check for valid vertices later with neighbors size 205 | neighbors = np.zeros((vertices.shape[0], 6, 3)) 206 | neighbors_size = np.zeros(vertices.shape[0], dtype=np.int8) 207 | for i, ni in enumerate(neighbors_idx): 208 | for j, vi in enumerate(ni): 209 | neighbors[i, j, :] = vertices[vi] 210 | neighbors_size[i] = j + 1 211 | 212 | # compute centroids 213 | centroids = np.zeros((vertices.shape[0], 3)) 214 | for i, (n, s) in enumerate(zip(neighbors, neighbors_size)): 215 | centroids[i, 0] = np.mean(n[:s, 0]) 216 | centroids[i, 1] = np.mean(n[:s, 1]) 217 | centroids[i, 2] = np.mean(n[:s, 2]) 218 | 219 | # return optimized surface attributes 220 | return neighbors, neighbors_size, centroids 221 | 222 | def update_surface_attributes(self): 223 | """ 224 | Updates attributes related to the surface 225 | """ 226 | self.triangles = self.vertices[self.faces] 
227 | self.face_normals = self.compute_face_normals(self.num_faces, self.faces, self.vertices) 228 | self.face_angles = self.compute_face_angles(self.triangles) 229 | self.vertex_normals = self.compute_vertex_normals( 230 | self.num_vertices, self.faces, self.face_normals, self.face_angles, self.face_vertex_idxs 231 | ) 232 | self.vertex_neighbors, self.vertex_neighbors_size, self.vertex_neighbors_centroids = self.update_surf_attr( 233 | self.vertices, self.vertex_neighbors_idx 234 | ) 235 | self.l = self.get_mean_intervertex_distance(self.vertices, self.vertex_neighbors, self.vertex_neighbors_size) 236 | 237 | @staticmethod 238 | @jit(nopython=True, cache=True) 239 | def get_mean_intervertex_distance(vertices: np.ndarray, neighbors: np.ndarray, sizes: np.ndarray): 240 | """ 241 | Computes the mean intervertex distance across the entire surface 242 | """ 243 | mivd = np.zeros(vertices.shape[0]) 244 | for v in range(vertices.shape[0]): 245 | vecs = vertices[v] - neighbors[v, : sizes[v]] 246 | vd = np.zeros(vecs.shape[0]) 247 | for i in range(vecs.shape[0]): 248 | vd[i] = l2norm(vecs[i]) 249 | mivd[v] = np.mean(vd) 250 | return np.mean(mivd) 251 | 252 | def run(self, iterations: int = 1000, deformation_path: str = None): 253 | """ 254 | Runs the extraction step. 255 | 256 | This deforms the surface based on the method outlined in" 257 | 258 | Smith SM. Fast robust automated brain extraction. Hum Brain Mapp. 259 | 2002 Nov;17(3):143-55. doi: 10.1002/hbm.10062. PMID: 12391568; 260 | PMCID: PMC6871816. 
261 | 262 | """ 263 | print("Running surface deformation...") 264 | # initialize s_vectors 265 | s_vectors = np.zeros(self.vertices.shape) 266 | 267 | # initialize s_vector normal/tangent 268 | s_n = np.zeros(self.vertices.shape) 269 | s_t = np.zeros(self.vertices.shape) 270 | 271 | # initialize u components 272 | u1 = np.zeros(self.vertices.shape) 273 | u2 = np.zeros(self.vertices.shape) 274 | u3 = np.zeros(self.vertices.shape) 275 | u = np.zeros(self.vertices.shape) 276 | 277 | # if deformation path defined 278 | if deformation_path: 279 | import zipfile 280 | 281 | zip_file = zipfile.ZipFile(deformation_path, "w") 282 | 283 | # surface deformation loop 284 | for i in range(iterations): 285 | print("Iteration: %d" % i, end="\r") 286 | # run one step of deformation 287 | self.step_of_deformation( 288 | self.data, 289 | self.vertices, 290 | self.vertex_normals, 291 | self.vertex_neighbors_centroids, 292 | self.l, 293 | self.t2, 294 | self.t, 295 | self.tm, 296 | self.t98, 297 | self.E, 298 | self.F, 299 | self.bt, 300 | self.d1, 301 | self.d2, 302 | s_vectors, 303 | s_n, 304 | s_t, 305 | u1, 306 | u2, 307 | u3, 308 | u, 309 | ) 310 | # update vertices 311 | self.vertices += u 312 | if deformation_path: # write to stl if enabled 313 | surface_file = "surface{:0>5d}.stl".format(i) 314 | dirpath = os.path.dirname(deformation_path) 315 | self.rebuild_surface() 316 | self.save_surface(os.path.join(dirpath, surface_file)) 317 | zip_file.write(os.path.join(dirpath, surface_file), surface_file) 318 | os.remove(os.path.join(dirpath, surface_file)) 319 | else: # just update the surface attributes 320 | self.update_surface_attributes() 321 | 322 | # close zip file 323 | if deformation_path: 324 | zip_file.close() 325 | 326 | # update the surface 327 | self.rebuild_surface() 328 | print("") 329 | print("Complete.") 330 | 331 | @staticmethod 332 | @jit(nopython=True, cache=True) 333 | def step_of_deformation( 334 | data: np.ndarray, 335 | vertices: np.ndarray, 336 | normals: 
np.ndarray, 337 | neighbors_centroids: np.ndarray, 338 | l: float, 339 | t2: float, 340 | t: float, 341 | tm: float, 342 | t98: float, 343 | E: float, 344 | F: float, 345 | bt: float, 346 | d1: float, 347 | d2: float, 348 | s_vectors: np.ndarray, 349 | s_n: np.ndarray, 350 | s_t: np.ndarray, 351 | u1: np.ndarray, 352 | u2: np.ndarray, 353 | u3: np.ndarray, 354 | u: np.ndarray, 355 | ): 356 | """ 357 | Finds a single displacement step for the surface 358 | """ 359 | # loop over vertices 360 | for i, vertex in enumerate(vertices): 361 | # compute s vector 362 | s_vectors[i] = neighbors_centroids[i] - vertex 363 | 364 | # split s vector into normal and tangent components 365 | s_n[i] = np.dot(s_vectors[i], normals[i]) * normals[i] 366 | s_t[i] = s_vectors[i] - s_n[i] 367 | 368 | # set component u1 369 | u1[i] = 0.5 * s_t[i] 370 | 371 | # compute local radius of curvature 372 | r = (l ** 2) / (2 * l2norm(s_n[i])) 373 | 374 | # compute f2 375 | f2 = (1 + np.tanh(F * (1 / r - E))) / 2 376 | 377 | # set component u2 378 | u2[i] = f2 * s_n[i] 379 | 380 | # get endpoints directed interior (distance set by d1 and d2) 381 | e1 = closest_integer_point(vertex - d1 * normals[i]) 382 | e2 = closest_integer_point(vertex - d2 * normals[i]) 383 | 384 | # get lines created by e1/e2 385 | c = closest_integer_point(vertex) 386 | i1 = bresenham3d(c, e1) 387 | i2 = bresenham3d(c, e2) 388 | 389 | # get Imin/Imax 390 | linedata1 = [data[d[0], d[1], d[2]] for d in i1] 391 | linedata1.append(tm) 392 | linedata1 = np.asarray(linedata1) 393 | Imin = np.max(np.asarray([t2, np.min(linedata1)])) 394 | linedata2 = [data[d[0], d[1], d[2]] for d in i2] 395 | linedata2.append(t) 396 | linedata2 = np.asarray(linedata2) 397 | Imax = np.min(np.asarray([tm, np.max(linedata2)])) 398 | 399 | # get tl 400 | tl = (Imax - t2) * bt + t2 401 | 402 | # compute f3 403 | f3 = 0.05 * 2 * (Imin - tl) / (Imax - t2) * l 404 | 405 | # get component u3 406 | u3[i] = f3 * normals[i] 407 | 408 | # get displacement vector 
409 | u[:, :] = u1 + u2 + u3 410 | 411 | @staticmethod 412 | def check_bound(img_min: int, img_max: int, img_start: int, img_end: int, vol_start: int, vol_end: int): 413 | if img_min < img_start: 414 | vol_start = vol_start + (img_start - img_min) 415 | img_min = 0 416 | if img_max > img_end: 417 | vol_end = vol_end - (img_max - img_end) 418 | img_max = img_end 419 | return img_min, img_max, img_start, img_end, vol_start, vol_end 420 | 421 | def compute_mask(self): 422 | """ 423 | Convert surface mesh to volume 424 | """ 425 | vol = self.surface.voxelized(1) 426 | vol = vol.fill() 427 | self.mask = np.zeros(self.shape) 428 | bounds = vol.bounds 429 | 430 | # adjust bounds to handle data outside the field of view 431 | 432 | # get the bounds of the volumized surface mesh 433 | x_min = int(vol.bounds[0, 0]) if vol.bounds[0, 0] > 0 else int(vol.bounds[0, 0]) - 1 434 | x_max = int(vol.bounds[1, 0]) if vol.bounds[1, 0] > 0 else int(vol.bounds[1, 0]) - 1 435 | y_min = int(vol.bounds[0, 1]) if vol.bounds[0, 1] > 0 else int(vol.bounds[0, 1]) - 1 436 | y_max = int(vol.bounds[1, 1]) if vol.bounds[1, 1] > 0 else int(vol.bounds[1, 1]) - 1 437 | z_min = int(vol.bounds[0, 2]) if vol.bounds[0, 2] > 0 else int(vol.bounds[0, 2]) - 1 438 | z_max = int(vol.bounds[1, 2]) if vol.bounds[1, 2] > 0 else int(vol.bounds[1, 2]) - 1 439 | 440 | # get the extents of the original image 441 | x_start = 0 442 | y_start = 0 443 | z_start = 0 444 | x_end = int(self.shape[0]) 445 | y_end = int(self.shape[1]) 446 | z_end = int(self.shape[2]) 447 | 448 | # get the extents of the volumized surface 449 | x_vol_start = 0 450 | y_vol_start = 0 451 | z_vol_start = 0 452 | x_vol_end = int(vol.matrix.shape[0]) 453 | y_vol_end = int(vol.matrix.shape[1]) 454 | z_vol_end = int(vol.matrix.shape[2]) 455 | 456 | # if the volumized surface mesh is outside the extents of the original image 457 | # we need to crop this volume to fit the image 458 | x_min, x_max, x_start, x_end, x_vol_start, x_vol_end = 
self.check_bound( 459 | x_min, x_max, x_start, x_end, x_vol_start, x_vol_end 460 | ) 461 | y_min, y_max, y_start, y_end, y_vol_start, y_vol_end = self.check_bound( 462 | y_min, y_max, y_start, y_end, y_vol_start, y_vol_end 463 | ) 464 | z_min, z_max, z_start, z_end, z_vol_start, z_vol_end = self.check_bound( 465 | z_min, z_max, z_start, z_end, z_vol_start, z_vol_end 466 | ) 467 | self.mask[x_min:x_max, y_min:y_max, z_min:z_max] = vol.matrix[ 468 | x_vol_start:x_vol_end, y_vol_start:y_vol_end, z_vol_start:z_vol_end 469 | ] 470 | return self.mask 471 | 472 | def save_mask(self, filename: str): 473 | """ 474 | Saves brain extraction to nifti file 475 | """ 476 | mask = self.compute_mask() 477 | nib.Nifti1Image(mask, self.img.affine).to_filename(filename) 478 | 479 | def save_surface(self, filename: str): 480 | """ 481 | Save surface in .stl 482 | """ 483 | self.surface.export(filename) 484 | -------------------------------------------------------------------------------- /brainextractor/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/vanandrew/brainextractor/4fb768703561d02842f0afa005f3e34f347da754/brainextractor/scripts/__init__.py -------------------------------------------------------------------------------- /brainextractor/scripts/brainextractor.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import argparse 5 | import nibabel as nib 6 | from brainextractor.main import BrainExtractor 7 | 8 | 9 | def main(): 10 | # create command line parser 11 | parser = argparse.ArgumentParser( 12 | description="A Reimplementation of FSL's Brain Extraction Tool", 13 | epilog="Author: Andrew Van, vanandrew@wustl.edu, 12/15/2020", 14 | ) 15 | parser.add_argument("input_img", help="Input image to brain extract") 16 | parser.add_argument("output_img", help="Output image to write out") 17 | 
parser.add_argument("-w", "--write_surface_deform", help="Path to write out surface files at each deformation step") 18 | parser.add_argument( 19 | "-f", 20 | "--fractional_threshold", 21 | type=float, 22 | default=0.5, 23 | help="Main threshold parameter for controlling brain/background (Default: 0.5)", 24 | ) 25 | parser.add_argument( 26 | "-n", "--iterations", type=int, default=1000, help="Number of iterations to run (Default: 1000)" 27 | ) 28 | parser.add_argument( 29 | "-t", 30 | "--histogram_threshold", 31 | nargs=2, 32 | type=float, 33 | default=[0.02, 0.98], 34 | help="Sets min/max of histogram (Default: 0.02, 0.98)", 35 | ) 36 | parser.add_argument( 37 | "-d", 38 | "--search_distance", 39 | nargs=2, 40 | type=float, 41 | default=[20.0, 10.0], 42 | help="Sets search distance for max/min of image along vertex normals (Default: 20.0, 10.0)", 43 | ) 44 | parser.add_argument( 45 | "-r", 46 | "--radius_of_curvatures", 47 | nargs=2, 48 | type=float, 49 | default=[3.33, 10.0], 50 | help="Sets min/max radius of curvature for surface (Default: 3.33, 10.0)", 51 | ) 52 | 53 | # parse arguments 54 | args = parser.parse_args() 55 | 56 | # load input image 57 | input_img = os.path.abspath(args.input_img) 58 | img = nib.load(input_img) 59 | 60 | # create brain extractor 61 | bet = BrainExtractor( 62 | img=img, 63 | t02t=args.histogram_threshold[0], 64 | t98t=args.histogram_threshold[1], 65 | bt=args.fractional_threshold, 66 | d1=args.search_distance[0], 67 | d2=args.search_distance[1], 68 | rmin=args.radius_of_curvatures[0], 69 | rmax=args.radius_of_curvatures[1], 70 | ) 71 | 72 | # create output path for surface files if defined 73 | if args.write_surface_deform: 74 | deformation_path = os.path.abspath(args.write_surface_deform) 75 | try: 76 | os.remove(deformation_path) 77 | except FileNotFoundError: 78 | pass 79 | os.makedirs(os.path.dirname(deformation_path), exist_ok=True) 80 | 81 | # run brain extractor 82 | bet.run(iterations=args.iterations, 
deformation_path=deformation_path if args.write_surface_deform else None) 83 | 84 | # make dirs to output directory as needed 85 | output_img = os.path.abspath(args.output_img) 86 | os.makedirs(os.path.dirname(output_img), exist_ok=True) 87 | 88 | # write mask to file 89 | print("Saving mask...") 90 | bet.save_mask(output_img) 91 | print("Mask saved.") 92 | -------------------------------------------------------------------------------- /brainextractor/scripts/brainextractor_render.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import os 4 | import argparse 5 | import trimesh 6 | import pyrender 7 | import time 8 | import zipfile 9 | 10 | 11 | def render(surface_path: str, video_path: str = None, loop: bool = False): 12 | """ 13 | Create rendering of surface deformation 14 | """ 15 | # open surfaces file 16 | surface_dir = os.path.dirname(surface_path) 17 | surfaces = zipfile.ZipFile(surface_path, "r") 18 | 19 | # get surface list 20 | surface_list = surfaces.namelist() 21 | iterations = len(surface_list) 22 | 23 | # get center of mesh (use first surface file) 24 | surfaces.extract(surface_list[0], path=surface_dir) 25 | mesh = pyrender.Mesh.from_trimesh(trimesh.load(os.path.join(surface_dir, surface_list[0]))) 26 | center = mesh.centroid 27 | os.remove(os.path.join(surface_dir, surface_list[0])) 28 | 29 | # read in surfaces 30 | nodes = list() 31 | for mesh_path in surface_list: 32 | surfaces.extract(mesh_path, path=surface_dir) 33 | print("Reading in surface data: %s" % mesh_path, end="\r") 34 | mesh = trimesh.load(os.path.join(surface_dir, mesh_path)) 35 | os.remove(os.path.join(surface_dir, mesh_path)) 36 | mesh.apply_transform([[1, 0, 0, -center[0]], [0, 1, 0, -center[1]], [0, 0, 1, -center[2]], [0, 0, 0, 1]]) 37 | dm = pyrender.Mesh.from_trimesh(mesh) 38 | node = pyrender.Node(scale=[0.01, 0.01, 0.01], mesh=dm) 39 | nodes.append(node) 40 | print("") 41 | 42 | # create scene 43 | scene = 
pyrender.Scene(bg_color=[0, 0, 0, 0]) 44 | v = pyrender.Viewer( 45 | scene, 46 | viewport_size=(1280, 720), 47 | use_direct_lighting=True, 48 | all_wireframe=True, 49 | run_in_thread=True, 50 | caption=[ 51 | { 52 | "location": 3, 53 | "text": "", 54 | "font_name": "OpenSans-Regular", 55 | "font_pt": 40, 56 | "color": [200, 200, 200, 255], 57 | "scale": 1.0, 58 | "align": 0, 59 | } 60 | ], 61 | record=bool(video_path), 62 | rotate=True, 63 | rotate_rate=0.25, 64 | rotate_axis=[0, 1, 0], 65 | ) 66 | 67 | # display surfaces frame by frame 68 | iteration_limit = iterations 69 | if loop: 70 | iterations *= int(3600 / (iterations * 0.033333333333333)) 71 | try: 72 | for i in range(iterations): 73 | c = i - (i // iteration_limit) * iteration_limit 74 | it = "Iteration %d" % c 75 | print(it, end="\r") 76 | v.render_lock.acquire() 77 | v.viewer_flags["caption"][0]["text"] = it 78 | if c > 0: 79 | scene.remove_node(nodes[c - 1]) 80 | elif i > 0 and i % iteration_limit == 0: 81 | scene.remove_node(nodes[iteration_limit - 1]) 82 | scene.add_node(nodes[c]) 83 | v.render_lock.release() 84 | time.sleep(0.033333333333333) 85 | v.close_external() 86 | except KeyboardInterrupt: 87 | pass 88 | print("") 89 | 90 | # save video 91 | if video_path: 92 | dirpath = os.path.dirname(video_path) 93 | os.makedirs(dirpath, exist_ok=True) 94 | print("Saving video to file...") 95 | v.save_gif(os.path.join(dirpath, "temp.gif")) 96 | os.system("ffmpeg -i {} {}".format(os.path.join(dirpath, "temp.gif"), video_path)) 97 | os.remove(os.path.join(dirpath, "temp.gif")) 98 | print("{} successfully saved.".format(os.path.basename(video_path))) 99 | 100 | 101 | def main(): 102 | # create command line parser 103 | parser = argparse.ArgumentParser( 104 | description="Renders surface deformation evolution", 105 | epilog="Author: Andrew Van, vanandrew@wustl.edu, 12/15/2020", 106 | ) 107 | parser.add_argument("surfaces", help="Surfaces to render") 108 | parser.add_argument("-s", "--save_mp4", help="Saves an 
mp4 output") 109 | parser.add_argument("-l", "--loop", action="store_true", help="Loop the render (1 hour)") 110 | 111 | # parse arguments 112 | args = parser.parse_args() 113 | 114 | # call render function 115 | render( 116 | surface_path=os.path.abspath(args.surfaces), 117 | video_path=os.path.abspath(args.save_mp4) if args.save_mp4 else None, 118 | loop=args.loop, 119 | ) 120 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "brainextractor" 3 | description = "brain extraction in python" 4 | readme = "README.md" 5 | requires-python = ">=3.7" 6 | license = "MIT" 7 | authors = [{ name = "Andrew Van", email = "vanandrew@wustl.edu" }] 8 | keywords = [ 9 | "python", 10 | "image-processing", 11 | "neuroscience", 12 | "neuroimaging", 13 | "segmentation", 14 | "fsl", 15 | "brain-extraction", 16 | "skull-stripping", 17 | ] 18 | classifiers = [ 19 | "Programming Language :: Python :: 3", 20 | "Programming Language :: C++", 21 | "Topic :: Scientific/Engineering :: Image Processing", 22 | ] 23 | urls = { github = "https://github.com/vanandrew/brainextractor" } 24 | version = "0.2.2" 25 | dynamic = ["scripts"] 26 | dependencies = [ 27 | "numba >= 0.51.2", 28 | "nibabel >= 3.2.1", 29 | "trimesh >= 3.8.15", 30 | "numpy >= 1.19.4", 31 | "scipy >= 1.5.4", 32 | "pyrender >= 0.1.43", 33 | ] 34 | 35 | [project.optional-dependencies] 36 | dev = [ 37 | "black >= 22.0", 38 | ] 39 | 40 | [build-system] 41 | requires = ["setuptools", "wheel"] 42 | build-backend = "setuptools.build_meta" 43 | 44 | [tool.setuptools] 45 | zip-safe = true 46 | 47 | [tool.black] 48 | line-length = 120 49 | target-version = ["py37", "py38", "py39", "py310"] 50 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | 
ignore = E203, W503, E741 3 | max-line-length = 120 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import site 3 | from pathlib import Path 4 | from setuptools import setup 5 | 6 | # This line enables user based installation when using pip in editable mode with the latest 7 | # pyproject.toml config 8 | site.ENABLE_USER_SITE = "--user" in sys.argv[1:] 9 | THISDIR = Path(__file__).parent 10 | 11 | # get scripts path 12 | scripts_path = THISDIR / "brainextractor" / "scripts" 13 | 14 | setup( 15 | entry_points={ 16 | "console_scripts": [ 17 | f"{f.stem}=brainextractor.scripts.{f.stem}:main" 18 | for f in scripts_path.glob("*.py") 19 | if f.name not in "__init__.py" 20 | ] 21 | } 22 | ) 23 | --------------------------------------------------------------------------------