├── README.md
├── deepseg1.sh
├── example_brain_t1.nii.gz
├── hippodeep.py
├── pyproject.toml
└── torchparams
    ├── hippodeep.pt
    ├── params_head_00075_00000.pt
    └── paramsaffineta_00079_00000.pt


/README.md:
--------------------------------------------------------------------------------
# hippodeep
Brain Hippocampus Segmentation

This program segments the Hippocampus of raw brain T1 images in a few seconds.

![blink_rotated](https://user-images.githubusercontent.com/590921/75311442-1a705a00-589a-11ea-9cb6-d889fb226516.gif)

It relies on a Convolutional Neural Network pre-trained on thousands of images from multiple large cohorts, and is therefore quite robust to subject- and MR-contrast variation.
For more details on its creation, refer to the corresponding manuscript at http://dx.doi.org/10.1016/j.media.2017.11.004

This official hippodeep version is a modern PyTorch port of the original Theano version, which is technically obsolete. While the hippocampal segmentation model is exactly the same as described in the paper, the pre- and post-processing steps have been improved, and thus results may differ slightly. The deprecated Theano repo is still available at https://github.com/bthyreau/hippodeep


## Requirement

This program requires Python 3 with the PyTorch library.

No GPU is required.

It should work on most distributions and platforms that support PyTorch, though it has mainly been tested on Linux CentOS 6.x ~ 8.x, Ubuntu 18.04 ~ 22.04 and macOS 10.13 ~ 12, using PyTorch versions from 1.0.0 to 2.6.0.

## Installation and Usage

Just clone or download this repository.

If you have the uv packaging tool ( https://docs.astral.sh/uv/ ), you can do

`uv run hippodeep.py example_brain_t1.nii.gz`

which should take care of downloading the dependencies in the first run.

Otherwise, you need to configure a Python 3 environment on your machine: in addition to PyTorch, the code requires scipy and nibabel. A possible way to install Python from scratch is to use Anaconda (anaconda.com) to create an environment, then
- install scipy (`conda install scipy` or `pip install scipy`) and nibabel (`pip install nibabel`)
- get PyTorch for Python from `https://pytorch.org/get-started/locally/`. CUDA is not necessary.
- Then, to use the program, call it with `python hippodeep.py example_brain_t1.nii.gz`, possibly changing 'python' to 'python3' depending on your exact setup.

## Results

To process multiple subjects, pass them as multiple arguments, e.g.:

`python hippodeep.py subject_*.nii.gz` (or the equivalent with `uv run`)

The resulting segmentations should be stored as `example_brain_t1_mask_L.nii.gz` (or R for right) and `example_brain_t1_brain_mask.nii.gz`. The mask volumes (in mm^3) are stored in a CSV file named `example_brain_t1_hippoLR_volumes.csv`. If more than one input was specified, a summary table named `all_subjects_hippo_report.csv` is created in the directory of the last input file.
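
As a rough sketch of how the summary table might be consumed downstream, the snippet below reads `all_subjects_hippo_report.csv` with Python's standard `csv` module and adds the total hippocampal volume as a fraction of eTIV. The column names (`filename`, `eTIV`, `hippoL`, `hippoR`) match the header written by `hippodeep.py`; the function name `load_hippo_report` and the `hippo_fraction` field are illustrative assumptions and not part of hippodeep.

```python
import csv


def load_hippo_report(path):
    """Read the multi-subject summary table and return one dict per subject.

    Hypothetical helper, not part of hippodeep itself.
    """
    rows = []
    with open(path, newline="") as f:
        for rec in csv.DictReader(f):
            etiv = float(rec["eTIV"])
            left = float(rec["hippoL"])
            right = float(rec["hippoR"])
            rows.append({
                "filename": rec["filename"],
                "eTIV": etiv,
                "hippoL": left,
                "hippoR": right,
                # total hippocampal volume relative to intracranial volume
                "hippo_fraction": (left + right) / etiv,
            })
    return rows
```

For example, `load_hippo_report("all_subjects_hippo_report.csv")` returns a list that can be sorted or filtered by `hippo_fraction`. Note that this simple reader assumes input filenames do not themselves contain commas.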

## License
MIT License

--------------------------------------------------------------------------------
/deepseg1.sh:
--------------------------------------------------------------------------------
python "$(dirname "$0")/hippodeep.py" "$@"

--------------------------------------------------------------------------------
/example_brain_t1.nii.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bthyreau/hippodeep_pytorch/b90dde48db6bcc673cfa0a74bc5a6f7093d0b11c/example_brain_t1.nii.gz

--------------------------------------------------------------------------------
/hippodeep.py:
--------------------------------------------------------------------------------
import torch
import nibabel
import numpy as np
import os, sys, time
import scipy.ndimage
import torch.nn as nn
import torch.nn.functional as F
from numpy.linalg import inv
try:
    import resource  # Unix-only module, used for the peak-memory report
except ImportError:
    pass

# monkey-patch for back-compatibility with older (~1.0.0) torch
try:
    import inspect
    if not "align_corners" in inspect.signature(F.grid_sample).parameters:
        old_grid_sample = torch.nn.functional.grid_sample
        F.grid_sample = lambda *x, **k : old_grid_sample(*x)
except Exception:
    pass

if len(sys.argv[1:]) == 0:
    print("Need to pass one or more T1 image filenames as argument")
    sys.exit(1)

print("Using all available CPU threads")
if 0: # otherwise, set a limit (useful for running multiple instances)
    torch.set_num_threads(4)


class HeadModel(nn.Module):
    def __init__(self):
        super(HeadModel, self).__init__()
        self.conv0a = nn.Conv3d(1, 8, 3, padding=1)
        self.conv0b = nn.Conv3d(8, 8, 3, padding=1)
        self.bn0a = nn.BatchNorm3d(8)

        self.ma1 = nn.MaxPool3d(2)
        self.conv1a = nn.Conv3d(8, 16, 3, padding=1)
        self.conv1b = nn.Conv3d(16, 24, 3, padding=1)
        self.bn1a = nn.BatchNorm3d(24)

        self.ma2 = nn.MaxPool3d(2)
        self.conv2a = nn.Conv3d(24, 24, 3, padding=1)
        self.conv2b = nn.Conv3d(24, 32, 3, padding=1)
        self.bn2a = nn.BatchNorm3d(32)

        self.ma3 = nn.MaxPool3d(2)
        self.conv3a = nn.Conv3d(32, 48, 3, padding=1)
        self.conv3b = nn.Conv3d(48, 48, 3, padding=1)
        self.bn3a = nn.BatchNorm3d(48)


        self.conv2u = nn.Conv3d(48, 24, 3, padding=1)
        self.conv2v = nn.Conv3d(24+32, 24, 3, padding=1)
        self.bn2u = nn.BatchNorm3d(24)


        self.conv1u = nn.Conv3d(24, 24, 3, padding=1)
        self.conv1v = nn.Conv3d(24+24, 24, 3, padding=1)
        self.bn1u = nn.BatchNorm3d(24)


        self.conv0u = nn.Conv3d(24, 16, 3, padding=1)
        self.conv0v = nn.Conv3d(16+8, 8, 3, padding=1)
        self.bn0u = nn.BatchNorm3d(8)

        self.conv1x = nn.Conv3d(8, 4, 1, padding=0)

    def forward(self, x):
        x = F.elu(self.conv0a(x))
        self.li0 = x = F.elu(self.bn0a(self.conv0b(x)))

        x = self.ma1(x)
        x = F.elu(self.conv1a(x))
        self.li1 = x = F.elu(self.bn1a(self.conv1b(x)))

        x = self.ma2(x)
        x = F.elu(self.conv2a(x))
        self.li2 = x = F.elu(self.bn2a(self.conv2b(x)))

        x = self.ma3(x)
        x = F.elu(self.conv3a(x))
        self.li3 = x = F.elu(self.bn3a(self.conv3b(x)))

        x = F.interpolate(x, scale_factor=2, mode="nearest")

        x = F.elu(self.conv2u(x))
        x = torch.cat([x, self.li2], 1)
        x = F.elu(self.bn2u(self.conv2v(x)))

        self.lo1 = x
        x = F.interpolate(x, scale_factor=2, mode="nearest")

        x = F.elu(self.conv1u(x))
        x = torch.cat([x, self.li1], 1)
        x = F.elu(self.bn1u(self.conv1v(x)))

        x = F.interpolate(x, scale_factor=2, mode="nearest")
        self.la1 = x

        x = F.elu(self.conv0u(x))
        x = torch.cat([x, self.li0], 1)
        x = F.elu(self.bn0u(self.conv0v(x)))

        self.out = x = self.conv1x(x)
        x = torch.sigmoid(x)
        return x




class ModelAff(nn.Module):
    def __init__(self):
        super(ModelAff, self).__init__()
        self.convaff1 = nn.Conv3d(2, 16, 3, padding=1)
        self.maaff1 = nn.MaxPool3d(2)
        self.convaff2 = nn.Conv3d(16, 16, 3, padding=1)
        self.bnaff2 = nn.LayerNorm([32, 32, 32])

        self.maaff2 = nn.MaxPool3d(2)
        self.convaff3 = nn.Conv3d(16, 32, 3, padding=1)
        self.bnaff3 = nn.LayerNorm([16, 16, 16])

        self.maaff3 = nn.MaxPool3d(2)
        self.convaff4 = nn.Conv3d(32, 64, 3, padding=1)
        self.maaff4 = nn.MaxPool3d(2)
        self.bnaff4 = nn.LayerNorm([8, 8, 8])
        self.convaff5 = nn.Conv3d(64, 128, 1, padding=0)
        self.convaff6 = nn.Conv3d(128, 12, 4, padding=0)

        gsx, gsy, gsz = 64, 64, 64
        gx, gy, gz = np.linspace(-1, 1, gsx), np.linspace(-1, 1, gsy), np.linspace(-1,1, gsz)
        grid = np.meshgrid(gx, gy, gz) # Y, X, Z
        grid = np.stack([grid[2], grid[1], grid[0], np.ones_like(grid[0])], axis=3)
        netgrid = np.swapaxes(grid, 0, 1)[...,[2,1,0,3]]

        self.register_buffer('grid', torch.tensor(netgrid.astype("float32"), requires_grad = False))
        self.register_buffer('diagA', torch.eye(4, dtype=torch.float32))

    def forward(self, outc1):
        x = outc1
        x = F.relu(self.convaff1(x))
        x = self.maaff1(x)
        x = F.relu(self.bnaff2(self.convaff2(x)))
        x = self.maaff2(x)
        x = F.relu(self.bnaff3(self.convaff3(x)))
        x = self.maaff3(x)
        x = F.relu(self.bnaff4(self.convaff4(x)))
        x = self.maaff4(x)
        x = F.relu(self.convaff5(x))
        x = self.convaff6(x)

        x = x.view(-1, 3, 4)
        x = torch.cat([x, x[:,0:1] * 0], dim=1)
        self.tA = torch.transpose(x + self.diagA, 1, 2)

        wgrid = self.grid @ self.tA[:,None,None]
        gout = F.grid_sample(outc1, wgrid[...,[2,1,0]], align_corners=True)
        return gout, self.tA

    def resample_other(self, other):
        with torch.no_grad():
            wgrid = self.grid @ self.tA[:,None,None]
            gout = F.grid_sample(other, wgrid[...,[2,1,0]], align_corners=True)
            return gout



def bbox_world(affine, shape):
    s = shape[0]-1, shape[1]-1, shape[2]-1
    bbox = [[0,0,0], [s[0],0,0], [0,s[1],0], [0,0,s[2]], [s[0],s[1],0], [s[0],0,s[2]], [0,s[1],s[2]], [s[0],s[1],s[2]]]
    w = affine @ np.column_stack([bbox, [1]*8]).T
    return w.T

bbox_one = np.array([[-1,-1,-1,1], [1, -1, -1, 1], [-1, 1, -1, 1], [-1, -1, 1, 1], [1, 1, -1, 1], [1, -1, 1, 1], [-1, 1, 1, 1], [1,1,1,1]])

affine64_mni = \
np.array([[ -2.85714293, -0.        ,  0.        ,   90. ],
          [ -0.        ,  3.42857146, -0.        , -126. ],
          [  0.        ,  0.        ,  2.85714293,  -72. ],
          [  0.        ,  0.        ,  0.        ,    1. ]])


scriptpath = os.path.dirname(os.path.realpath(__file__))

device = torch.device("cpu")
net = HeadModel()
net.to(device)
net.load_state_dict(torch.load(scriptpath + "/torchparams/params_head_00075_00000.pt", map_location=device))
net.eval()

netAff = ModelAff()
netAff.load_state_dict(torch.load(scriptpath + "/torchparams/paramsaffineta_00079_00000.pt", map_location=device), strict=False)
netAff.to(device)
netAff.eval()



class HippoModel(nn.Module):
    def __init__(self):
        super(HippoModel, self).__init__()
        self.conv0a_0 = l = nn.Conv3d(1, 16, (1,1,3), padding=0)
        self.conv0a_1 = l = nn.Conv3d(16, 16, (1,3,1), padding=0)
        self.conv0a = nn.Conv3d(16, 16, (3,1,1), padding=0)

        self.convf1 = nn.Conv3d(16, 48, (3,3,3), padding=0)

        self.maxpool1 = nn.MaxPool3d(2)

        self.bn1 = nn.BatchNorm3d(48, momentum=1)
        self.bn1.training = False
        self.convout0 = nn.Conv3d(48, 48, (3,3,3), padding=1)
        self.convout1 = nn.Conv3d(48, 48, (3,3,3), padding=1)

        self.maxpool2 = nn.MaxPool3d(2)

        self.bn2 = nn.BatchNorm3d(48, momentum=1)
        self.bn2.training = False

        self.convout2p = nn.Conv3d(48, 48, (3,3,3), padding=1)
        self.convout2 = nn.Conv3d(48, 48, (3,3,3), padding=1)

        self.convlx3 = nn.Conv3d(48, 48, (3,3,3), padding=1)

        self.convlx5 = nn.Conv3d(48, 48, (3,3,3), padding=1)

        self.convlx7 = nn.Conv3d(48, 16, (3,3,3), padding=1)

        self.convlx8 = nn.Conv3d(16, 1, 1, padding=0)

        self.blur = nn.Conv3d(1, 1, 7, padding=3)

        self.conv_extract = nn.Conv3d(48, 47, 3, padding=1)
        self.convmix = nn.Conv3d(48, 16, 3, padding=1)
        self.convout1x = nn.Conv3d(16, 1, 1, padding=0)

    def forward(self, x):
        x = F.relu(self.conv0a_0(x))
        x = F.relu(self.conv0a_1(x))
        x = F.relu(self.conv0a(x))
        self.out_conv_f1 = x = F.relu(self.convf1(x))

        self.out_maxpool1 = x = self.maxpool1(x)
        x = self.bn1(x)
        x = F.relu(self.convout0(x))
        x = self.convout1(x)
        x = x + self.out_maxpool1
        x = F.relu(x)

        self.out_maxpool2 = x = self.maxpool2(x)
        x = self.bn2(x)
        x = F.relu(self.convout2p(x))
        x = self.convout2(x)
        x = x + self.out_maxpool2
        x = F.relu(x)

        self.lx2 = F.interpolate(x, scale_factor=2, mode="nearest")

        x = F.relu(self.convlx3(x))
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = F.relu(self.convlx5(x))
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        x = F.relu(self.convlx7(x))
        self.out_output1 = x = torch.sigmoid(self.convlx8(x))

        x = torch.sigmoid(self.blur(x))
        x = x * self.out_conv_f1
        x = F.leaky_relu(self.conv_extract(x))
        x = torch.cat([self.out_output1, x], dim=1)

        x = F.relu(self.convmix(x))
        self.out_output2 = x = torch.sigmoid(self.convout1x(x))
        #x = torch.cat([self.out_output2, self.out_output1], dim=1)

        return x

hipponet = HippoModel()
# map_location added for consistency with the two loads above
hipponet.load_state_dict(torch.load(scriptpath + "/torchparams/hippodeep.pt", map_location=device))


OUTPUT_RES64 = False
OUTPUT_NATIVE = True
OUTPUT_DEBUG = False

allsubjects_scalar_report = []

mul_homo = lambda g, Mt : g @ Mt[:3,:3].astype(np.float32) + Mt[3,:3].astype(np.float32)

def indices_unitary(dimensions, dtype):
    dimensions = tuple(dimensions)
    N = len(dimensions)
    shape = (1,)*N
    res = np.empty((N,)+dimensions, dtype=dtype)
    for i, dim in enumerate(dimensions):
        res[i] = np.linspace(-1, 1, dim, dtype=dtype).reshape( shape[:i] + (dim,) + shape[i+1:] )
    return res

def main():
    for fname in sys.argv[1:]:
        if "_mask" in fname:
            print("Skipping %s because the filename contains _mask in it" % fname)
            continue
        Ti = time.time()
        try:
            print("Loading image " + fname)
            outfilename = fname.replace(".mnc", ".nii").replace(".mgz", ".nii").replace(".nii.gz", ".nii").replace(".nii", "_tiv.nii.gz")
            img = nibabel.load(fname)

            if type(img) is nibabel.nifti1.Nifti1Image:
                img._affine = img.get_qform() # for ANTs compatibility

            if type(img) is nibabel.Nifti1Image:
                if img.header["qform_code"] == 0:
                    if img.header["sform_code"] == 0:
                        print(" *** Error: the header of this nifti file has no qform_code defined.")
                        print(" Fix the header manually or reconvert from the original DICOM.")
                        if not OUTPUT_DEBUG:
                            continue

                if not np.allclose(img.get_sform(), img.get_qform()):
                    img._affine = img.get_qform() # simplify later ANTs compatibility
                    print("This image has an sform defined, ignoring it - work in scanner space using the qform")

        except:
            open(fname + ".warning.txt", "a").write("can't open the file\n")
            print(" *** Error: can't open file. Skip")
            continue

        d = img.get_fdata(caching="unchanged", dtype=np.float32)
        while len(d.shape) > 3:
            print("Warning: this looks like a time series. Averaging it")
            open(fname + ".warning.txt", "a").write("dim not 3. Averaging last dimension\n")
            d = d.mean(-1)

        d = (d - d.mean()) / d.std()

        o1 = nibabel.orientations.io_orientation(img.affine)
        o2 = np.array([[ 0., -1.], [ 1., 1.], [ 2., 1.]]) # We work in LAS space (same as the mni_icbm152 template)
        trn = nibabel.orientations.ornt_transform(o1, o2) # o1 to o2 (apply to o2 to obtain o1)
        trn_back = nibabel.orientations.ornt_transform(o2, o1)

        revaff1 = nibabel.orientations.inv_ornt_aff(trn, (1,1,1)) # mult on o1 to obtain o2
        revaff1i = nibabel.orientations.inv_ornt_aff(trn_back, (1,1,1)) # mult on o2 to obtain o1

        aff_orig64 = np.linalg.lstsq(bbox_world(np.identity(4), (64,64,64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T
        voxscale_native64 = np.abs(np.linalg.det(aff_orig64))
        revaff64i = nibabel.orientations.inv_ornt_aff(trn_back, (64,64,64))
        aff_reor64 = np.linalg.lstsq(bbox_world(revaff64i, (64,64,64)), bbox_world(img.affine, img.shape[:3]), rcond=None)[0].T

        wgridt = (netAff.grid @ torch.tensor(revaff1i, device=device, dtype=torch.float32))[None,...,[2,1,0]]
        d_orr = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None,None], wgridt, align_corners=True)

        if OUTPUT_DEBUG:
            nibabel.Nifti1Image(np.asarray(d_orr[0,0].cpu()), aff_reor64).to_filename(outfilename.replace("_tiv", "_orig_b64"))

        ## Head priors
        T = time.time()
        with torch.no_grad():
            out1t = net(d_orr)
        out1 = np.asarray(out1t.cpu())
        #print("Head Inference in ", time.time() - T)

        ## Output head priors
        scalar_output = []
        scalar_output_report = []


        # brain mask
        output = out1[0,0].astype("float32")

        out_cc, lab = scipy.ndimage.label(output > .01)
        #output *= (out_cc == np.bincount(out_cc.flat)[1:].argmax()+1)
        brainmask_cc = torch.tensor(output)

        vol = (output[output > .5]).sum() * voxscale_native64
        if OUTPUT_DEBUG:
            print(" Estimated intra-cranial volume (mm^3): %d" % vol)
        if 0:
            open(outfilename.replace("_tiv.nii.gz", "_eTIV.txt"), "w").write("%d\n" % vol)
        scalar_output.append(vol)
        scalar_output_report.append(vol)

        if OUTPUT_RES64:
            out = (output.clip(0, 1) * 255).astype("uint8")
            nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 0))

        if OUTPUT_NATIVE:
            # wgridt for native space
            gsx, gsy, gsz = img.shape[:3]
            # this is a big array, so use float16
            sgrid = np.rollaxis(indices_unitary((gsx,gsy,gsz), dtype=np.float16),0,4)
            wgridt = torch.as_tensor(mul_homo(sgrid, inv(revaff1i))[None,...,[2,1,0]], device=device, dtype=torch.float32)
            del sgrid

            dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None,None], wgridt, align_corners=True).cpu())[0,0]
            #nibabel.Nifti1Image(dnat, img.affine).to_filename(outfilename.replace("_tiv", "_tissues%d" % 0))
            nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_brain_mask"))
            vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
            print(" Estimated intra-cranial volume (mm^3) (native space): %d" % vol)
            scalar_output.append(vol)
            scalar_output_report[-1] = vol # authoritative, so overwrite previous
            del dnat

        if 1:
            # cerebrum mask
            output = out1[0,2].astype("float32")

            out_cc, lab = scipy.ndimage.label(output > .01)
            output *= (out_cc == np.bincount(out_cc.flat)[1:].argmax()+1)

            vol = (output[output > .5]).sum() * voxscale_native64
            if OUTPUT_DEBUG:
                print(" Estimated cerebrum volume (mm^3): %d" % vol)
            if 0:
                open(outfilename.replace("_tiv.nii.gz", "_eTIV_nocerebellum.txt"), "w").write("%d\n" % vol)
            scalar_output.append(vol)

            if OUTPUT_RES64:
                out = (output.clip(0, 1) * 255).astype("uint8")
                nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 2))
            if OUTPUT_NATIVE:
                dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None,None], wgridt, align_corners=True).cpu()[0,0])
                #nibabel.Nifti1Image(dnat, img.affine).to_filename(outfilename.replace("_tiv", "_tissues%d" % 2))
                nibabel.Nifti1Image((dnat > .5).astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_cerebrum_mask"))
                vol = (dnat > .5).sum() * np.abs(np.linalg.det(img.affine))
                print(" Estimated cerebrum volume (mm^3) (native space): %d" % vol)
                scalar_output.append(vol)
                del dnat

            # cortex
            output = out1[0,1].astype("float32")
            output[output < .01] = 0
            if OUTPUT_RES64:
                out = (output.clip(0, 1) * 255).astype("uint8")
                nibabel.Nifti1Image(out, aff_reor64, img.header).to_filename(outfilename.replace("_tiv", "_tissues%d_b64" % 1))
            if OUTPUT_NATIVE and OUTPUT_DEBUG:
                dnat = np.asarray(F.grid_sample(torch.as_tensor(output, dtype=torch.float32, device=device)[None,None], wgridt, align_corners=True).cpu()[0,0])
                nibabel.Nifti1Image(dnat, img.affine).to_filename(outfilename.replace("_tiv", "_tissues%d" % 1))
                del dnat


        ## MNI affine
        T = time.time()
        with torch.no_grad():
            wc1, tA = netAff(out1t[:,[1,3]] * brainmask_cc)

        wnat = np.linalg.lstsq(bbox_world(img.affine, img.shape[:3]), bbox_one @ revaff1, rcond=None)[0]
        wmni = np.linalg.lstsq(bbox_world(affine64_mni, (64,64,64)), bbox_one, rcond=None)[0]
        M = (wnat @ inv(np.asarray(tA[0].cpu())) @ inv(wmni)).T
        # [native world coord] @ M.T -> [mni world coord] , in LAS space

        if OUTPUT_DEBUG:
            # Output MNI, mostly for debug, save in box64, uint8
            out2 = np.asarray(wc1.to("cpu"))
            out2 = np.clip((out2 * 255), 0, 255).astype("uint8")
            nibabel.Nifti1Image(out2[0,0], affine64_mni).to_filename(outfilename.replace("_tiv", "_mniwrapc1"))
            del out2
            if 0:
                out2r = np.asarray(netAff.resample_other(d_orr).cpu())
                # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0 removed
                out2r = (out2r - out2r.min()) * 255 / np.ptp(out2r)
                nibabel.Nifti1Image(out2r[0,0].astype("uint8"), affine64_mni).to_filename(outfilename.replace("_tiv", "_mniwrap"))
                del out2r


        # output an ANTs-compatible matrix (antsApplyTransforms -t)
        f3 = np.array([[1, 1, -1, -1],[1, 1, -1, -1], [-1, -1, 1, 1], [1, 1, 1, 1]]) # ANTs LPS
        MI = inv(M) * f3
        txt = """#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: """
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI[:3,:3].tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0,3], MI[1,3], MI[2,3])
        if 0:
            open(outfilename.replace("_tiv.nii.gz", "_mni0Affine.txt"), "w").write(txt)

        u, s, vt = np.linalg.svd(MI[:3,:3])
        MI3rigid = u @ vt
        txt = """#Insight Transform File V1.0\nTransform: AffineTransform_float_3_3\nFixedParameters: 0 0 0\nParameters: """
        txt += " ".join(["%4.6f %4.6f %4.6f" % tuple(x) for x in MI3rigid.tolist()]) + " %4.6f %4.6f %4.6f\n" % (MI[0,3], MI[1,3], MI[2,3])
        if 0:
            open(outfilename.replace("_tiv.nii.gz", "_mni0Rigid.txt"), "w").write(txt)

        ## Hippodeep
        T = time.time()

        imgcroproi_affine = np.array([[ -1., -0., 0., 54.], [ -0., 1., -0., -59.], [0., 0., 1., -45.], [0., 0., 0., 1.]])
        imgcroproi_shape = (107, 72, 68)
        # coord in mm bbox
        gsx, gsy, gsz = 107, 72, 68
        sgrid = np.rollaxis(indices_unitary((gsx,gsy,gsz), dtype=np.float32),0,4)

        bboxnat = bbox_world(imgcroproi_affine, imgcroproi_shape) @ inv(M.T) @ wnat
        matzoom = np.linalg.lstsq(bbox_one, bboxnat, rcond=None)[0] # in -1..1 space
        # wgridt for hippo box
        wgridt = torch.tensor(mul_homo( sgrid, (matzoom @ revaff1i) )[None,...,[2,1,0]], device=device, dtype=torch.float32)
        del sgrid
        dout = F.grid_sample(torch.as_tensor(d, dtype=torch.float32, device=device)[None,None], wgridt, align_corners=True)
        # note: d was normalized from full-image
        d_in = np.asarray(dout[0,0].cpu()) # back to numpy since torch does not support negative step/strides

        if OUTPUT_RES64:
            # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0 removed
            d_in_u8 = (((d_in - d_in.min()) / np.ptp(d_in)) * 255).astype("uint8")
            nibabel.Nifti1Image(d_in_u8, imgcroproi_affine).to_filename(outfilename.replace("_tiv", "_affcrop"))

        d_in -= d_in.mean()
        d_in /= d_in.std()
        # split Left and Right (flipping Right)
        with torch.no_grad():
            hippoR = hipponet(torch.as_tensor(d_in[None, None, 6: 54:+1,: ,2:-2 ].copy()))
            hippoL = hipponet(torch.as_tensor(d_in[None, None,-7:-55:-1,: ,2:-2 ].copy()))

        hippoRL = np.vstack([np.asarray(hippoR.cpu()), np.asarray(hippoL.cpu())])
        #print("Hippo Inference in " + str(time.time() - T))

        # smoothly rescale (.5 ~ .75) to (.5 ~ 1.)
        hippoRL = np.clip(((hippoRL - .5) * 2 + .5), 0, 1) * (hippoRL > .5)
        # lots numpy/torch copy below, because torch raises errors on negative strides
        output = np.zeros((2, 107, 72, 68), np.float32)
        output[0, -7:-55:-1,: ,2:-2][2:-2,2:-2,2:-2] = np.clip(hippoRL[1] * 255, 0, 255)#* maskL
        output[1, 6: 54:+1,: ,2:-2][2:-2,2:-2,2:-2] = np.clip(hippoRL[0] * 255, 0, 255) # * maskR

        if OUTPUT_DEBUG:
            #outputfn = outfilename.replace(".nii.gz", "_outseg_L.nii.gz")
            #nibabel.Nifti1Image(output[0], imgcroproi_affine).to_filename(outputfn)
            #outputfn = outfilename.replace(".nii.gz", "_outseg_R.nii.gz")
            #nibabel.Nifti1Image(output[1], imgcroproi_affine).to_filename(outputfn)
            outputfn = outfilename.replace("_tiv", "_affcrop_outseg_mask")
            nibabel.Nifti1Image(output.sum(0), imgcroproi_affine).to_filename(outputfn)

        boxvols = hippoRL[[1,0]].reshape(2, -1).sum(1) * np.abs(np.linalg.det(imgcroproi_affine @ inv(M)))
        scalar_output.append(boxvols)

        if 1:

            def bbox_xyz(shape, affine):
                " returns the worldspace of the edge of the image "
                s = shape[0]-1, shape[1]-1, shape[2]-1
                bbox = [[0,0,0], [s[0],0,0], [0,s[1],0], [0,0,s[2]], [s[0],s[1],0], [s[0],0,s[2]], [0,s[1],s[2]], [s[0],s[1],s[2]]]
                return mul_homo(bbox, affine.T)

            def indices_xyz(shape, affine, offset_vox= np.array([0,0,0])):
                assert (len(shape) == 3)
                ind = np.indices(shape).astype(np.float32) + offset_vox.reshape(3, 1,1,1).astype(np.float32)
                return mul_homo(np.rollaxis(ind, 0, 4), affine.T)

            def xyz_to_DHW3(xyz, iaffine, srcshape):
                affine = np.linalg.inv(iaffine)
                ijk3 = mul_homo(xyz, affine.T)
                ijk3[...,0] /= srcshape[0] -1
                ijk3[...,1] /= srcshape[1] -1
                ijk3[...,2] /= srcshape[2] -1
                ijk3 = ijk3 * 2 - 1
                DHW3 = np.swapaxes(ijk3, 0, 2)
                return DHW3

            pts = bbox_xyz(imgcroproi_shape, imgcroproi_affine)
            pts = mul_homo(pts, np.linalg.inv(M).T)
            pts_ijk = mul_homo(pts, np.linalg.inv(img.affine).T)
            for i in range(3):
                np.clip(pts_ijk[:,i], 0, img.shape[i], out = pts_ijk[:,i])
            pmin = np.floor(np.min(pts_ijk, 0)).astype(int)
            pwidth = np.ceil(np.max(pts_ijk, 0)).astype(int) - pmin

            widx = indices_xyz(pwidth, img.affine, offset_vox=pmin)

            widx = mul_homo(widx, M.T)

            DHW3 = xyz_to_DHW3(widx, imgcroproi_affine, imgcroproi_shape)

            wdata = np.zeros(img.shape[:3], np.uint8)


            d = torch.tensor(output[0].T, dtype=torch.float32)
            outDHW = F.grid_sample(d[None,None], torch.tensor(DHW3[None]), align_corners=True)
            dnat = np.asarray(outDHW[0,0].permute(2,1,0))
            dnat[dnat < 32] = 0 # remove noise
            volsAA_L = dnat.sum() / 255. * np.abs(np.linalg.det(img.affine))
            wdata[pmin[0]:pmin[0]+pwidth[0], pmin[1]:pmin[1]+pwidth[1], pmin[2]:pmin[2]+pwidth[2]] = dnat.astype(np.uint8)
            nibabel.Nifti1Image(wdata.astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_mask_L"))

            d = torch.tensor(output[1].T, dtype=torch.float32)
            outDHW = F.grid_sample(d[None,None], torch.tensor(DHW3[None]), align_corners=True)
            dnat = np.asarray(outDHW[0,0].permute(2,1,0))
            dnat[dnat < 32] = 0 # remove noise
            volsAA_R = dnat.sum() / 255. * np.abs(np.linalg.det(img.affine))
            wdata[pmin[0]:pmin[0]+pwidth[0], pmin[1]:pmin[1]+pwidth[1], pmin[2]:pmin[2]+pwidth[2]] = dnat.astype(np.uint8)
            nibabel.Nifti1Image(wdata.astype("uint8"), img.affine).to_filename(outfilename.replace("_tiv", "_mask_R"))

            print(" Hippocampal volumes (L,R)", volsAA_L, volsAA_R)
            scalar_output.append([volsAA_L, volsAA_R])
            scalar_output_report.append([volsAA_L, volsAA_R])


        if OUTPUT_DEBUG:
            txt = "eTIV_mni,eTIV,cerebrum_mni,cerebrum,mni_hippoL,mni_hippoR,hippoL,hippoR\n"
            txt += "%4f,%4f,%4f,%4f,%4.4f,%4.4f,%4.4f,%4.4f\n" % (tuple(scalar_output[:4]) + tuple(scalar_output[4])+ tuple(scalar_output[5]))
            open(outfilename.replace("_tiv.nii.gz", "_scalars_hippo.csv"), "w").write(txt)

        if 1:
            txt = "eTIV,hippoL,hippoR\n"
            txt += "%4f,%4f,%4f\n" % (scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1])
            open(outfilename.replace("_tiv.nii.gz", "_hippoLR_volumes.csv"), "w").write(txt)

        if OUTPUT_RES64:
            print("fslview %s %s -t .5 &" % (outfilename.replace("_tiv", "_affcrop"), outfilename.replace("_tiv", "_affcrop_outseg_mask")))

        print(" Elapsed time for subject %4.2fs " % (time.time() - Ti))
        print(" To display using fsleyes or fslview, try:")
        print(" fsleyes %s %s -a 75 -cm Red-Yellow %s -a 75 -cm Blue-Lightblue &" % (fname, outfilename.replace("_tiv", "_mask_L"), outfilename.replace("_tiv", "_mask_R")))
        print(" fslview %s %s -t .5 %s -t .5 &" % (fname, outfilename.replace("_tiv", "_mask_L"), outfilename.replace("_tiv", "_mask_R")))


        allsubjects_scalar_report.append( (fname, scalar_output_report[0], scalar_output_report[1][0], scalar_output_report[1][1]) )

        try:
            print("Peak memory used (Gb) " + str(resource.getrusage(resource.RUSAGE_SELF)[2] / (1024.*1024)))
        except:
            pass

    print("Done")

    if len(sys.argv[1:]) > 1:
        outfilename = (os.path.dirname(fname) or ".") + "/all_subjects_hippo_report.csv"
        txt_entries = ["%s,%4f,%4f,%4f\n" % s for s in allsubjects_scalar_report]
        open(outfilename, "w").writelines( [ "filename,eTIV,hippoL,hippoR\n" ] + txt_entries)
        print("Volumes of all subjects saved as " + outfilename)

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[project]
name = "hippodeep"
version="1.0"
requires-python=">=3.8"
description="quick segmentation of the hippocampus from T1 MRI images"
dependencies = [
    "torch",
    "nibabel",
    "scipy"
]

[python.pip]
extra-index-urls = ["https://download.pytorch.org/whl/cpu"]

[project.scripts]
hippodeep = "hippodeep:main"

[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true

[tool.uv.sources]
torch = [
    { index = "pytorch-cpu" },
]

--------------------------------------------------------------------------------
/torchparams/hippodeep.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bthyreau/hippodeep_pytorch/b90dde48db6bcc673cfa0a74bc5a6f7093d0b11c/torchparams/hippodeep.pt

--------------------------------------------------------------------------------
/torchparams/params_head_00075_00000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bthyreau/hippodeep_pytorch/b90dde48db6bcc673cfa0a74bc5a6f7093d0b11c/torchparams/params_head_00075_00000.pt

--------------------------------------------------------------------------------
/torchparams/paramsaffineta_00079_00000.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bthyreau/hippodeep_pytorch/b90dde48db6bcc673cfa0a74bc5a6f7093d0b11c/torchparams/paramsaffineta_00079_00000.pt

--------------------------------------------------------------------------------