├── .gitignore ├── LICENSE ├── README.md ├── deepem ├── data │ ├── augment │ │ ├── flip_rotate.py │ │ ├── flyem │ │ │ └── aug_mip1.py │ │ ├── grayscale_warping.py │ │ ├── kasthuri11 │ │ │ ├── aug_v2.py │ │ │ ├── aug_v2_valid.py │ │ │ └── aug_v2_valid_no-interp.py │ │ ├── no_aug.py │ │ └── pinky_basil │ │ │ ├── aug_mip0_v0.py │ │ │ ├── aug_mip0_valid0.py │ │ │ ├── aug_mip1_v0.py │ │ │ ├── aug_mip1_v1.py │ │ │ ├── aug_mip1_v2.py │ │ │ ├── aug_mip1_v2_2.py │ │ │ └── aug_mip1_v3.py │ ├── dataset │ │ ├── flyem │ │ │ ├── cremi_b.py │ │ │ ├── cremi_b_mip1.py │ │ │ ├── cremi_b_segd.py │ │ │ ├── cremi_dodam_mip1.py │ │ │ ├── focused_annotation.py │ │ │ ├── focused_annotation_mito.py │ │ │ ├── focused_annotation_v1.py │ │ │ ├── focused_v0.py │ │ │ ├── focused_v1.py │ │ │ ├── sparse_annotation.py │ │ │ └── sparse_v0.py │ │ ├── kasthuri11 │ │ │ └── train216_val40_test100 │ │ │ │ ├── mip1 │ │ │ │ ├── seg-m0.py │ │ │ │ └── seg-m0b.py │ │ │ │ ├── no_pad.py │ │ │ │ ├── seg-m0.py │ │ │ │ └── seg-m0b.py │ │ ├── minnie │ │ │ └── pinky_basil_minnie.py │ │ ├── pinky │ │ │ ├── base.py │ │ │ └── mip0_padded_x512_y512_z32.py │ │ └── pinky_basil │ │ │ ├── mip0_padded_x512_y512_z32.py │ │ │ └── mip1_padded_x512_y512_z32.py │ ├── modifier │ │ └── crop2x.py │ └── sampler │ │ ├── aff.py │ │ ├── aff_dynamic_bdr.py │ │ ├── aff_glia.py │ │ ├── aff_mit.py │ │ ├── aff_mye.py │ │ ├── aff_mye_blv1.py │ │ ├── aff_mye_blv1_fld0.py │ │ ├── aff_mye_blv2.py │ │ ├── aff_psd_mye.py │ │ ├── aff_psd_mye_blv.py │ │ ├── aff_syn_mye.py │ │ ├── aff_syn_mye_blv.py │ │ ├── mit.py │ │ └── psd.py ├── loss │ ├── __init__.py │ ├── affinity.py │ └── loss.py ├── models │ ├── layers.py │ ├── rsunet.py │ ├── rsunet_act.py │ ├── rsunet_deprecated.py │ ├── updown.py │ ├── updown_act.py │ └── updown_deprecated.py ├── test │ ├── cv_utils.py │ ├── forward.py │ ├── fwd_utils.py │ ├── mask.py │ ├── model.py │ ├── option.py │ ├── run.py │ └── utils.py ├── train │ ├── data.py │ ├── logger.py │ ├── model.py │ ├── option.py │ ├── 
run.py │ └── utils.py └── utils │ ├── py_utils.py │ └── torch_utils.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/ 2 | experiments/ 3 | */**/__pycache__ 4 | */**/*.pyc 5 | */**/*~ 6 | *~ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kisuk Lee, Nicholas Turner, Kyle Luther 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepEM 2 | Deep Learning for EM Connectomics 3 | 4 | ## Citation 5 | [Lee et al. 
2017](https://arxiv.org/abs/1706.00120) 6 | ``` 7 | @article{lee2017superhuman, 8 | author = {Kisuk Lee and 9 | Jonathan Zung and 10 | Peter Li and 11 | Viren Jain and 12 | H. Sebastian Seung}, 13 | title = {Superhuman Accuracy on the {SNEMI3D} Connectomics Challenge}, 14 | journal = {arXiv preprint arXiv:1706.00120}, 15 | year = {2017}, 16 | } 17 | ``` 18 | [Dorkenwald et al. 2019](https://www.biorxiv.org/content/10.1101/2019.12.29.890319v1) 19 | ``` 20 | @article {Dorkenwald2019.12.29.890319, 21 | author = {Dorkenwald, Sven and Turner, Nicholas L. and Macrina, Thomas and Lee, Kisuk and Lu, Ran and Wu, Jingpeng and Bodor, Agnes L. and Bleckert, Adam A. and Brittain, Derrick and Kemnitz, Nico and Silversmith, William M. and Ih, Dodam and Zung, Jonathan and Zlateski, Aleksandar and Tartavull, Ignacio and Yu, Szi-Chieh and Popovych, Sergiy and Wong, William and Castro, Manuel and Jordan, Chris S. and Wilson, Alyssa M. and Froudarakis, Emmanouil and Buchanan, JoAnn and Takeno, Marc and Torres, Russel and Mahalingam, Gayathri and Collman, Forrest and Schneider-Mizell, Casey and Bumbarger, Daniel J. and Li, Yang and Becker, Lynne and Suckow, Shelby and Reimer, Jacob and Tolias, Andreas S. and da Costa, Nuno Ma{\c c}arico and Reid, R. Clay and Seung, H. 
Sebastian}, 22 | title = {Binary and analog variation of synapses between cortical pyramidal neurons}, 23 | elocation-id = {2019.12.29.890319}, 24 | year = {2019}, 25 | doi = {10.1101/2019.12.29.890319}, 26 | publisher = {Cold Spring Harbor Laboratory}, 27 | URL = {https://www.biorxiv.org/content/early/2019/12/31/2019.12.29.890319}, 28 | eprint = {https://www.biorxiv.org/content/early/2019/12/31/2019.12.29.890319.full.pdf}, 29 | journal = {bioRxiv} 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /deepem/data/augment/flip_rotate.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, **kwargs): 5 | augs = list() 6 | 7 | # Recompute connected components 8 | if recompute: 9 | augs.append(Label()) 10 | 11 | # Flip & rotate 12 | augs.append(FlipRotate()) 13 | 14 | return Compose(augs) 15 | -------------------------------------------------------------------------------- /deepem/data/augment/flyem/aug_mip1.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, 5 | random=False, **kwargs): 6 | augs = list() 7 | 8 | # Brightness & contrast purterbation 9 | augs.append( 10 | MixedGrayscale2D( 11 | contrast_factor=0.5, 12 | brightness_factor=0.5, 13 | prob=1, skip=0.3)) 14 | 15 | # Mutually exclusive augmentations 16 | mutex = list() 17 | 18 | # (1) Misalingment 19 | trans = Compose([Misalign((0, 5), margin=1), 20 | Misalign((0,15), margin=1), 21 | Misalign((0,25), margin=1)]) 22 | slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), 23 | SlipMisalign((0,15), interp=True, margin=1), 24 | SlipMisalign((0,25), interp=True, margin=1)]) 25 | mutex.append(Blend([trans,slip], props=[0.7,0.3])) 26 | 27 | # (2) Misalignment + missing section 28 | if is_train: 29 | 
mutex.append(Blend([ 30 | MisalignPlusMissing((3,15), value=0, random=random), 31 | MisalignPlusMissing((3,15), value=0, random=False) 32 | ])) 33 | else: 34 | mutex.append(MisalignPlusMissing((3,15), value=0, random=False)) 35 | 36 | # (3) Missing section 37 | if missing > 0: 38 | if is_train: 39 | mutex.append(Blend([ 40 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), 41 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), 42 | MissingSection(maxsec=missing, individual=False, value=0, random=random), 43 | ])) 44 | else: 45 | mutex.append( 46 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) 47 | ) 48 | 49 | # (4) Lost section 50 | if lost: 51 | if is_train: 52 | mutex.append(Blend([ 53 | LostSection(1), 54 | LostPlusMissing(value=0, random=random), 55 | LostPlusMissing(value=0, random=False) 56 | ])) 57 | 58 | # Mutually exclusive augmentations 59 | augs.append(Blend(mutex)) 60 | 61 | # Box 62 | if is_train: 63 | if box == 'noise': 64 | augs.append( 65 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 66 | density=0.3, skip=0.1) 67 | ) 68 | elif box == 'fill': 69 | augs.append( 70 | FillBox(dims=(5,25), margin=(1,5,5), 71 | density=0.3, skip=0.1) 72 | ) 73 | 74 | # Out-of-focus section 75 | if blur > 0: 76 | augs.append(MixedBlurrySection(maxsec=blur)) 77 | 78 | # Warping 79 | if is_train: 80 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 81 | 82 | # Flip & rotate 83 | augs.append(FlipRotate()) 84 | 85 | return Compose(augs) 86 | -------------------------------------------------------------------------------- /deepem/data/augment/grayscale_warping.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, grayscale=False, warping=False, 5 | **kwargs): 6 | augs = list() 7 | 8 | # Recompute connected components 9 | if recompute: 10 | 
augs.append(Label()) 11 | 12 | # Brightness & contrast purterbation 13 | if is_train and grayscale: 14 | augs.append( 15 | MixedGrayscale2D( 16 | contrast_factor=0.5, 17 | brightness_factor=0.5, 18 | prob=1, skip=0.3)) 19 | 20 | # Warping 21 | if is_train and warping: 22 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0)) 23 | 24 | # Flip & rotate 25 | augs.append(FlipRotate()) 26 | 27 | return Compose(augs) 28 | -------------------------------------------------------------------------------- /deepem/data/augment/kasthuri11/aug_v2.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, 5 | blur=0, warping=False, misalign=0, box=None, mip=0, 6 | random=False, **kwargs): 7 | augs = list() 8 | 9 | # MIP factor 10 | mip_f = pow(2,mip) 11 | 12 | if is_train: 13 | # Brightness & contrast purterbation 14 | if grayscale: 15 | augs.append( 16 | MixedGrayscale2D( 17 | contrast_factor=0.5, 18 | brightness_factor=0.5, 19 | prob=1, skip=0.3)) 20 | 21 | # Mutually exclusive augmentations 22 | mutex = list() 23 | 24 | # Misalignment 25 | if misalign > 0: 26 | mutex.append(Blend([ 27 | Misalign((0,misalign)), 28 | SlipMisalign((0,misalign), interp=True), 29 | None], 30 | props=[0.5,0.2,0.3] 31 | )) 32 | 33 | # Missing section 34 | if missing > 0: 35 | mutex.append( 36 | MixedMissingSection(maxsec=missing, individual=True, random=random, skip=0.1) 37 | ) 38 | 39 | if misalign > 0 or missing > 0: 40 | augs.append(Blend(mutex)) 41 | 42 | # Box occlusion 43 | if box == 'fill': 44 | dims = (6//mip_f, 30//mip_f) 45 | margin = (1, 6//mip_f, 6//mip_f) 46 | aniso = 30/(6*mip_f) 47 | augs.append( 48 | FillBox(dims=dims, margin=margin, density=0.3, individual=True, 49 | aniso=aniso, skip=0.1) 50 | ) 51 | 52 | # Out-of-focus section 53 | if blur > 0: 54 | augs.append(MixedBlurrySection(maxsec=blur)) 55 | 56 | # Warping 57 | if warping: 58 | 
augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0)) 59 | 60 | # Flip & rotate 61 | augs.append(FlipRotate()) 62 | 63 | # Recompute connected components 64 | if recompute: 65 | augs.append(Label()) 66 | 67 | return Compose(augs) 68 | -------------------------------------------------------------------------------- /deepem/data/augment/kasthuri11/aug_v2_valid.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, 5 | blur=0, warping=False, misalign=0, box=None, mip=0, 6 | random=False, **kwargs): 7 | augs = list() 8 | 9 | # MIP factor 10 | mip_f = pow(2,mip) 11 | 12 | 13 | # Brightness & contrast purterbation 14 | if grayscale: 15 | augs.append( 16 | MixedGrayscale2D( 17 | contrast_factor=0.5, 18 | brightness_factor=0.5, 19 | prob=1, skip=0.3)) 20 | 21 | # Mutually exclusive augmentations 22 | mutex = list() 23 | 24 | # Misalignment 25 | if misalign > 0: 26 | mutex.append(Blend([ 27 | Misalign((0,misalign)), 28 | SlipMisalign((0,misalign), interp=True), 29 | None], 30 | props=[0.5,0.2,0.3] 31 | )) 32 | 33 | # Missing section 34 | if missing > 0: 35 | mutex.append( 36 | MixedMissingSection(maxsec=missing, individual=True, random=random, skip=0.1) 37 | ) 38 | 39 | if misalign > 0 or missing > 0: 40 | augs.append(Blend(mutex)) 41 | 42 | # Box occlusion 43 | if is_train: 44 | if box == 'fill': 45 | dims = (6//mip_f, 30//mip_f) 46 | margin = (1, 6//mip_f, 6//mip_f) 47 | aniso = 30/(6*mip_f) 48 | augs.append( 49 | FillBox(dims=dims, margin=margin, density=0.3, individual=True, 50 | aniso=aniso, skip=0.1) 51 | ) 52 | 53 | # Out-of-focus section 54 | if blur > 0: 55 | augs.append(MixedBlurrySection(maxsec=blur)) 56 | 57 | # Warping 58 | if is_train: 59 | if warping: 60 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0)) 61 | 62 | # Flip & rotate 63 | augs.append(FlipRotate()) 64 | 65 | # Recompute connected components 
66 | if recompute: 67 | augs.append(Label()) 68 | 69 | return Compose(augs) 70 | -------------------------------------------------------------------------------- /deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, 5 | blur=0, warping=False, misalign=0, box=None, mip=0, 6 | random=False, **kwargs): 7 | augs = list() 8 | 9 | # MIP factor 10 | mip_f = pow(2,mip) 11 | 12 | 13 | # Brightness & contrast purterbation 14 | if grayscale: 15 | augs.append( 16 | MixedGrayscale2D( 17 | contrast_factor=0.5, 18 | brightness_factor=0.5, 19 | prob=1, skip=0.3)) 20 | 21 | # Mutually exclusive augmentations 22 | mutex = list() 23 | 24 | # Misalignment 25 | if misalign > 0: 26 | mutex.append(Blend([ 27 | Misalign((0,misalign)), 28 | SlipMisalign((0,misalign), interp=False), 29 | None], 30 | props=[0.5,0.2,0.3] 31 | )) 32 | 33 | # Missing section 34 | if missing > 0: 35 | mutex.append( 36 | MixedMissingSection(maxsec=missing, individual=True, random=random, skip=0.1) 37 | ) 38 | 39 | if misalign > 0 or missing > 0: 40 | augs.append(Blend(mutex)) 41 | 42 | # Box occlusion 43 | if is_train: 44 | if box == 'fill': 45 | dims = (6//mip_f, 30//mip_f) 46 | margin = (1, 6//mip_f, 6//mip_f) 47 | aniso = 30/(6*mip_f) 48 | augs.append( 49 | FillBox(dims=dims, margin=margin, density=0.3, individual=True, 50 | aniso=aniso, skip=0.1) 51 | ) 52 | 53 | # Out-of-focus section 54 | if blur > 0: 55 | augs.append(MixedBlurrySection(maxsec=blur)) 56 | 57 | # Warping 58 | if is_train: 59 | if warping: 60 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0)) 61 | 62 | # Flip & rotate 63 | augs.append(FlipRotate()) 64 | 65 | # Recompute connected components 66 | if recompute: 67 | augs.append(Label()) 68 | 69 | return Compose(augs) 70 | 
-------------------------------------------------------------------------------- /deepem/data/augment/no_aug.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, recompute=False, **kwargs): 5 | augs = list() 6 | 7 | # Recompute connected components 8 | if recompute: 9 | augs.append(Label()) 10 | 11 | # Flip & rotate 12 | if not is_train: 13 | augs.append(FlipRotate()) 14 | 15 | return Compose(augs) 16 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip0_v0.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, random=False, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,10), margin=1), SlipMisalign((0,10), interp=interp, margin=1)], 9 | props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,30), margin=1), SlipMisalign((0,30), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,50), margin=1), SlipMisalign((0,50), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 24 | augs = list() 25 | 26 | # Box 27 | if is_train: 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(10,50), margin=(1,10,10), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(10,50), margin=(1,10,10), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 48 | to_blend.append(Compose([m1,m2,m3])) 49 | if is_train: 50 | to_blend.append(Blend([ 51 | 
MisalignPlusMissing((5,30), value=1, random=random), 52 | MisalignPlusMissing((5,30), value=1, random=False) 53 | ])) 54 | else: 55 | to_blend.append(MisalignPlusMissing((5,30), value=1, random=False)) 56 | if missing > 0: 57 | if is_train: 58 | to_blend.append(Blend([ 59 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=False), 60 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=random), 61 | MissingSection(maxsec=missing, individual=False, value=1, random=random) 62 | ])) 63 | else: 64 | to_blend.append( 65 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=False) 66 | ) 67 | if lost: 68 | if is_train: 69 | to_blend.append(Blend([ 70 | LostSection(1), 71 | LostPlusMissing(value=1, random=random), 72 | LostPlusMissing(value=1, random=False) 73 | ])) 74 | augs.append(Blend(to_blend)) 75 | 76 | # Out-of-focus 77 | if blur > 0: 78 | augs.append(MixedBlurrySection(maxsec=blur)) 79 | 80 | # Warping 81 | if is_train: 82 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 83 | 84 | # Flip & rotate 85 | augs.append(FlipRotate()) 86 | 87 | return Compose(augs) 88 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip0_valid0.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, random=False, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,10), margin=1), SlipMisalign((0,10), interp=interp, margin=1)], 9 | props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,30), margin=1), SlipMisalign((0,30), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,50), margin=1), SlipMisalign((0,50), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 
24 | augs = list() 25 | 26 | if is_train: 27 | # Box 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(10,50), margin=(1,10,10), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(10,50), margin=(1,10,10), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 48 | to_blend.append(Compose([m1,m2,m3])) 49 | to_blend.append(Blend([ 50 | MisalignPlusMissing((5,30), value=1, random=random), 51 | MisalignPlusMissing((5,30), value=1, random=False) 52 | ])) 53 | if missing > 0: 54 | to_blend.append(Blend([ 55 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=random), 56 | MixedMissingSection(maxsec=missing, individual=False, value=1, random=random), 57 | MixedMissingSection(maxsec=missing, individual=False, value=1, random=False) 58 | ])) 59 | if lost: 60 | to_blend.append(Blend([ 61 | LostSection(1), 62 | LostPlusMissing(value=1, random=random), 63 | LostPlusMissing(value=1, random=False) 64 | ])) 65 | augs.append(Blend(to_blend)) 66 | 67 | # Out-of-focus 68 | if blur > 0: 69 | augs.append(MixedBlurrySection(maxsec=blur)) 70 | 71 | # Warping 72 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 73 | 74 | # Flip & rotate 75 | augs.append(FlipRotate()) 76 | 77 | return Compose(augs) 78 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip1_v0.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,5), margin=1), SlipMisalign((0,5), interp=interp, margin=1)], 9 
| props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,15), margin=1), SlipMisalign((0,15), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,25), margin=1), SlipMisalign((0,25), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 24 | augs = list() 25 | 26 | # Box 27 | if is_train: 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(5,25), margin=(1,5,5), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 48 | to_blend.append(Compose([m1,m2,m3])) 49 | to_blend.append(MisalignPlusMissing((3,15), random=is_train)) 50 | if missing > 0: 51 | to_blend.append(MixedMissingSection( 52 | maxsec=missing, individual=False, random=is_train)) 53 | if lost: 54 | to_blend.append(Blend([ 55 | Compose([LostSection(1), LostSection(1)]), 56 | LostPlusMissing(random=is_train) 57 | ])) 58 | augs.append(Blend(to_blend)) 59 | 60 | # Out-of-focus 61 | if blur > 0: 62 | augs.append(MixedBlurrySection(maxsec=blur)) 63 | 64 | # Warping 65 | if is_train: 66 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 67 | 68 | # Flip & rotate 69 | augs.append(FlipRotate()) 70 | 71 | return Compose(augs) 72 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip1_v1.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,5), 
margin=1), SlipMisalign((0,5), interp=interp, margin=1)], 9 | props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,15), margin=1), SlipMisalign((0,15), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,25), margin=1), SlipMisalign((0,25), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 24 | augs = list() 25 | 26 | # Box 27 | if is_train: 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(5,25), margin=(1,5,5), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 48 | to_blend.append(Compose([m1,m2,m3])) 49 | to_blend.append(MisalignPlusMissing((3,15), random=False)) 50 | if missing > 0: 51 | if is_train: 52 | to_blend.append(Blend([ 53 | MixedMissingSection(maxsec=missing, individual=True, random=True), 54 | MixedMissingSection(maxsec=missing, individual=False, random=is_train), 55 | MixedMissingSection(maxsec=missing, individual=False, random=False) 56 | ])) 57 | else: 58 | to_blend.append( 59 | MixedMissingSection(maxsec=missing, individual=False, random=False) 60 | ) 61 | if lost: 62 | to_blend.append(Blend([ 63 | Compose([LostSection(1), LostSection(1)]), 64 | LostPlusMissing(random=False) 65 | ])) 66 | augs.append(Blend(to_blend)) 67 | 68 | # Out-of-focus 69 | if blur > 0: 70 | augs.append(MixedBlurrySection(maxsec=blur)) 71 | 72 | # Warping 73 | if is_train: 74 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 75 | 76 | # Flip & rotate 77 | augs.append(FlipRotate()) 78 | 79 | return Compose(augs) 80 | 
-------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip1_v2.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, random=False, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,5), margin=1), SlipMisalign((0,5), interp=interp, margin=1)], 9 | props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,15), margin=1), SlipMisalign((0,15), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,25), margin=1), SlipMisalign((0,25), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 24 | augs = list() 25 | 26 | # Box 27 | if is_train: 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(5,25), margin=(1,5,5), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 48 | to_blend.append(Compose([m1,m2,m3])) 49 | if is_train: 50 | to_blend.append(Blend([ 51 | MisalignPlusMissing((3,15), value=1, random=random), 52 | MisalignPlusMissing((3,15), value=1, random=False) 53 | ])) 54 | else: 55 | to_blend.append(MisalignPlusMissing((3,15), value=1, random=False)) 56 | if missing > 0: 57 | if is_train: 58 | to_blend.append(Blend([ 59 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=False), 60 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=random), 61 | MissingSection(maxsec=missing, individual=False, value=1, random=random) 62 | ])) 63 | else: 
64 | to_blend.append( 65 | MixedMissingSection(maxsec=missing, individual=True, value=1, random=False) 66 | ) 67 | if lost: 68 | if is_train: 69 | to_blend.append(Blend([ 70 | LostSection(1), 71 | LostPlusMissing(value=1, random=random), 72 | LostPlusMissing(value=1, random=False) 73 | ])) 74 | augs.append(Blend(to_blend)) 75 | 76 | # Out-of-focus 77 | if blur > 0: 78 | augs.append(MixedBlurrySection(maxsec=blur)) 79 | 80 | # Warping 81 | if is_train: 82 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 83 | 84 | # Flip & rotate 85 | augs.append(FlipRotate()) 86 | 87 | return Compose(augs) 88 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip1_v2_2.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, interp=False, missing=7, blur=7, 5 | lost=True, random=False, **kwargs): 6 | # Mild misalignment 7 | m1 = Blend( 8 | [Misalign((0,5), margin=1), SlipMisalign((0,5), interp=interp, margin=1)], 9 | props=[0.7,0.3] 10 | ) 11 | 12 | # Medium misalignment 13 | m2 = Blend( 14 | [Misalign((0,15), margin=1), SlipMisalign((0,15), interp=interp, margin=1)], 15 | props=[0.7,0.3] 16 | ) 17 | 18 | # Large misalignment 19 | m3 = Blend( 20 | [Misalign((0,25), margin=1), SlipMisalign((0,25), interp=interp, margin=1)], 21 | props=[0.7,0.3] 22 | ) 23 | 24 | augs = list() 25 | 26 | # Box 27 | if is_train: 28 | if box == 'noise': 29 | augs.append( 30 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 31 | density=0.3, skip=0.1) 32 | ) 33 | elif box == 'fill': 34 | augs.append( 35 | FillBox(dims=(5,25), margin=(1,5,5), 36 | density=0.3, skip=0.1) 37 | ) 38 | 39 | # Brightness & contrast purterbation 40 | augs.append( 41 | MixedGrayscale2D( 42 | contrast_factor=0.5, 43 | brightness_factor=0.5, 44 | prob=1, skip=0.3)) 45 | 46 | # Missing section & misalignment 47 | to_blend = list() 
48 | to_blend.append(Compose([m1,m2,m3])) 49 | if is_train: 50 | to_blend.append(Blend([ 51 | MisalignPlusMissing((3,15), value=0, random=random), 52 | MisalignPlusMissing((3,15), value=0, random=False) 53 | ])) 54 | else: 55 | to_blend.append(MisalignPlusMissing((3,15), value=0, random=False)) 56 | if missing > 0: 57 | if is_train: 58 | to_blend.append(Blend([ 59 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), 60 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), 61 | MissingSection(maxsec=missing, individual=False, value=0, random=random) 62 | ])) 63 | else: 64 | to_blend.append( 65 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) 66 | ) 67 | if lost: 68 | if is_train: 69 | to_blend.append(Blend([ 70 | LostSection(1), 71 | LostPlusMissing(value=0, random=random), 72 | LostPlusMissing(value=0, random=False) 73 | ])) 74 | augs.append(Blend(to_blend)) 75 | 76 | # Out-of-focus 77 | if blur > 0: 78 | augs.append(MixedBlurrySection(maxsec=blur)) 79 | 80 | # Warping 81 | if is_train: 82 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 83 | 84 | # Flip & rotate 85 | augs.append(FlipRotate()) 86 | 87 | return Compose(augs) 88 | -------------------------------------------------------------------------------- /deepem/data/augment/pinky_basil/aug_mip1_v3.py: -------------------------------------------------------------------------------- 1 | from augmentor import * 2 | 3 | 4 | def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, 5 | random=False, **kwargs): 6 | augs = list() 7 | 8 | # Box 9 | if is_train: 10 | if box == 'noise': 11 | augs.append( 12 | NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), 13 | density=0.3, skip=0.1) 14 | ) 15 | elif box == 'fill': 16 | augs.append( 17 | FillBox(dims=(5,25), margin=(1,5,5), 18 | density=0.3, skip=0.1) 19 | ) 20 | 21 | # Brightness & contrast purterbation 22 | augs.append( 23 | MixedGrayscale2D( 24 | 
contrast_factor=0.5, 25 | brightness_factor=0.5, 26 | prob=1, skip=0.3)) 27 | 28 | # Missing section & misalignment 29 | to_blend = list() 30 | # Misalingments 31 | trans = Compose([Misalign((0, 5), margin=1), 32 | Misalign((0,15), margin=1), 33 | Misalign((0,25), margin=1)]) 34 | 35 | # Out-of-alignments 36 | slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), 37 | SlipMisalign((0,15), interp=True, margin=1), 38 | SlipMisalign((0,25), interp=True, margin=1)]) 39 | to_blend.append(Blend([trans,slip], props=[0.7,0.3])) 40 | if is_train: 41 | to_blend.append(Blend([ 42 | MisalignPlusMissing((3,15), value=0, random=random), 43 | MisalignPlusMissing((3,15), value=0, random=False) 44 | ])) 45 | else: 46 | to_blend.append(MisalignPlusMissing((3,15), value=0, random=False)) 47 | if missing > 0: 48 | if is_train: 49 | to_blend.append(Blend([ 50 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), 51 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), 52 | MissingSection(maxsec=missing, individual=False, value=0, random=random), 53 | ])) 54 | else: 55 | to_blend.append( 56 | MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) 57 | ) 58 | if lost: 59 | if is_train: 60 | to_blend.append(Blend([ 61 | LostSection(1), 62 | LostPlusMissing(value=0, random=random), 63 | LostPlusMissing(value=0, random=False) 64 | ])) 65 | augs.append(Blend(to_blend)) 66 | 67 | # Out-of-focus 68 | if blur > 0: 69 | augs.append(MixedBlurrySection(maxsec=blur)) 70 | 71 | # Warping 72 | if is_train: 73 | augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) 74 | 75 | # Flip & rotate 76 | augs.append(FlipRotate()) 77 | 78 | return Compose(augs) 79 | -------------------------------------------------------------------------------- /deepem/data/dataset/flyem/cremi_b.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dataprovider3.emio as 
emio 4 | 5 | 6 | # New CREMI 7 | data_dir = 'flyem/ground_truth/cremi_b/mip1/padded_x256_y256_z16' 8 | data_info = { 9 | 'cremi_b':{ 10 | 'img': 'img.h5', 11 | 'seg': 'seg.h5', 12 | 'seg_d3_b0': 'seg_d3_b0.h5', 13 | 'glia': 'glia.h5', 14 | 'msk': 'msk.h5', 15 | 'dir': '', 16 | 'loc': True, 17 | }, 18 | } 19 | 20 | 21 | def load_data(base_dir, data_ids=None, **kwargs): 22 | if data_ids is None: 23 | data_ids = data_info.keys() 24 | data = dict() 25 | base = os.path.expanduser(base_dir) 26 | dpath = os.path.join(base, data_dir) 27 | for data_id in data_ids: 28 | info = data_info[data_id] 29 | data[data_id] = load_dataset(dpath, data_id, info, **kwargs) 30 | return data 31 | 32 | 33 | def load_dataset(dpath, tag, info, class_keys=[], **kwargs): 34 | assert len(class_keys) > 0 35 | dset = dict() 36 | 37 | # Image 38 | fpath = os.path.join(dpath, info['dir'], info['img']) 39 | print(fpath) 40 | dset['img'] = emio.imread(fpath).astype('float32') 41 | dset['img'] /= 255.0 42 | 43 | # Mask 44 | fpath = os.path.join(dpath, info['dir'], 'msk_train.h5') 45 | print(fpath) 46 | dset['msk_train'] = emio.imread(fpath).astype('uint8') 47 | fpath = os.path.join(dpath, info['dir'], 'msk_val.h5') 48 | print(fpath) 49 | dset['msk_val'] = emio.imread(fpath).astype('uint8') 50 | 51 | # Segmentation 52 | if 'aff' in class_keys: 53 | fpath = os.path.join(dpath, info['dir'], info['seg']) 54 | print(fpath) 55 | dset['seg'] = emio.imread(fpath).astype('uint32') 56 | 57 | # Glia 58 | if 'glia' in class_keys: 59 | fpath = os.path.join(dpath, info['dir'], info['glia']) 60 | print(fpath) 61 | dset['glia'] = emio.imread(fpath).astype('uint8') 62 | 63 | # Additoinal info 64 | dset['loc'] = info['loc'] 65 | 66 | return dset 67 | -------------------------------------------------------------------------------- /deepem/data/dataset/flyem/cremi_b_mip1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import 
def load_data(base_dir, data_ids=None, **kwargs):
    """Load the CREMI-B (mip1) datasets rooted at ``base_dir``.

    Unknown ``data_ids`` are silently skipped.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    base = os.path.expanduser(base_dir)
    dpath = os.path.join(base, data_dir)
    for data_id in data_ids:
        if data_id in data_info:
            info = data_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
    return data


def load_dataset(dpath, tag, info, class_keys=None, glia_mask=False, **kwargs):
    """Load one dataset's image, mask, segmentation, and optional labels.

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        class_keys: Non-empty list of output class names ('bdr', 'glia', ...).
        glia_mask: If True, keep an unmodified mask copy as 'gmsk' and
            zero out glia ids (``info['glia_ids']``) in 'msk'.
        **kwargs: Ignored.

    Returns:
        Dict with 'img', 'msk', 'seg', optional 'bdr'/'glia'/'gmsk',
        and the 'loc' flag.
    """
    # Avoid a shared mutable default argument.
    class_keys = [] if class_keys is None else class_keys
    assert len(class_keys) > 0
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype('float32')
    dset['img'] /= 255.0

    # Mask
    fpath = os.path.join(dpath, info['dir'], info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype('uint8')

    # Segmentation: always the 'seg_d3_b0' variant for this dataset.
    fpath = os.path.join(dpath, info['dir'], info['seg_d3_b0'])
    print(fpath)
    dset['seg'] = emio.imread(fpath).astype('uint32')

    # Boundary
    if 'bdr' in class_keys:
        fpath = os.path.join(dpath, info['dir'], info['bdr'])
        print(fpath)
        bdr = emio.imread(fpath).astype('uint32')
        dset['bdr'] = bdr

    # Glia
    if 'glia' in class_keys:
        fpath = os.path.join(dpath, info['dir'], info['glia'])
        print(fpath)
        dset['glia'] = emio.imread(fpath).astype('uint8')

    # Glia mask
    if glia_mask:
        # Use original mask for glia detection.
        assert 'msk' in dset
        dset['gmsk'] = np.copy(dset['msk'])

        # Mask out glia.
        assert 'seg' in dset
        gmsk = ~np.isin(dset['seg'], info['glia_ids'])
        dset['msk'] &= gmsk

    # Additional info
    dset['loc'] = info['loc']

    return dset
def load_data(base_dir, data_ids=None, **kwargs):
    """Load the CREMI-Dodam (mip1) datasets rooted at ``base_dir``.

    Unknown ``data_ids`` are silently skipped.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    base = os.path.expanduser(base_dir)
    dpath = os.path.join(base, data_dir)
    for data_id in data_ids:
        if data_id in data_info:
            info = data_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one dataset's image, segmentation, and train/val masks.

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        class_keys: Non-empty list of output class names; only validated.
        **kwargs: Ignored.

    Returns:
        Dict with 'img', 'seg', 'msk_train', 'msk_val', and 'loc'.
    """
    # Avoid a shared mutable default argument.
    class_keys = [] if class_keys is None else class_keys
    assert len(class_keys) > 0
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype('float32')
    dset['img'] /= 255.0

    # Segmentation
    fpath = os.path.join(dpath, info['dir'], info['seg'])
    print(fpath)
    dset['seg'] = emio.imread(fpath).astype('uint32')

    # Train/validation masks; file names derive from info['msk'] stem.
    fpath = os.path.join(dpath, info['dir'], info['msk'] + '_train.h5')
    print(fpath)
    dset['msk_train'] = emio.imread(fpath).astype('uint8')
    fpath = os.path.join(dpath, info['dir'], info['msk'] + '_val.h5')
    print(fpath)
    dset['msk_val'] = emio.imread(fpath).astype('uint8')

    # Additional info
    dset['loc'] = info['loc']

    return dset
def load_data(base_dir, data_ids=None, **kwargs):
    """Load the focused-annotation datasets rooted at ``base_dir``.

    Unknown ``data_ids`` are silently skipped.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    base = os.path.expanduser(base_dir)
    dpath = os.path.join(base, data_dir)
    for data_id in data_ids:
        if data_id in data_info:
            info = data_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one focused-annotation volume with per-dataset special cases.

    Special-case handling driven by ``info``:
      - 'lamellae': those segment ids are relabeled as background (seg=0).
      - 'rosetta', 'esophagus', 'glia_msk': those ids are masked out (msk=0).

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        class_keys: Non-empty list of output class names ('glia', ...).
        **kwargs: Ignored.

    Returns:
        Dict with 'img', 'msk', 'seg', optional 'glia', and 'loc'.
    """
    # Avoid a shared mutable default argument.
    class_keys = [] if class_keys is None else class_keys
    assert len(class_keys) > 0
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype(np.float32)
    dset['img'] /= 255.0

    # Mask
    fpath = os.path.join(dpath, info['dir'], info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype(np.uint8)

    # Segmentation: always the 'seg_d3_b0' variant for this dataset.
    fpath = os.path.join(dpath, info['dir'], info['seg_d3_b0'])
    print(fpath)
    dset['seg'] = emio.imread(fpath).astype(np.uint32)

    # Special case: lamellar structures become background.
    if 'lamellae' in info:
        idx = np.isin(dset['seg'], info['lamellae'])
        dset['seg'][idx] = 0

    # Glia: fall back to an all-zero volume when not annotated.
    if 'glia' in class_keys:
        if 'glia' in info:
            fpath = os.path.join(dpath, info['dir'], info['glia'])
            print(fpath)
            dset['glia'] = emio.imread(fpath).astype(np.uint8)
        else:
            dset['glia'] = np.zeros_like(dset['msk'])

    # Mask out non-neuronal structures.
    if 'rosetta' in info:
        idx = np.isin(dset['seg'], info['rosetta'])
        dset['msk'][idx] = 0

    if 'esophagus' in info:
        idx = np.isin(dset['seg'], info['esophagus'])
        dset['msk'][idx] = 0

    if 'glia_msk' in info:
        idx = np.isin(dset['seg'], info['glia_msk'])
        dset['msk'][idx] = 0

    # Additional info
    dset['loc'] = info['loc']

    return dset
'loc': True, 19 | }, 20 | 'vol002':{ 21 | 'img': 'img.h5', 22 | 'msk': 'msk.h5', 23 | 'seg': 'seg.h5', 24 | 'seg_d3_b0': 'seg_d3_b0.h5', 25 | 'mit': 'mit.h5', 26 | 'dir': 'vol002/mip1/padded_x512_y512_z20', 27 | 'loc': True, 28 | }, 29 | 'vol003':{ 30 | 'img': 'img.h5', 31 | 'msk': 'msk.h5', 32 | 'seg': 'seg.h5', 33 | 'glia': 'glia.h5', 34 | 'seg_d3_b0': 'seg_d3_b0.h5', 35 | 'mit': 'mit.h5', 36 | 'dir': 'vol003/mip1/padded_x512_y512_z20', 37 | 'trachea': [16], 38 | 'glia_msk': [78,79,80,84,85,87,89,93,97,101], 39 | 'loc': True, 40 | }, 41 | 'vol004':{ 42 | 'img': 'img.h5', 43 | 'msk': 'msk.h5', 44 | 'seg': 'seg.h5', 45 | 'glia': 'glia.h5', 46 | 'seg_d3_b0': 'seg_d3_b0.h5', 47 | 'dir': 'vol004/mip1/padded_x512_y512_z20', 48 | 'rosetta': [1], 49 | 'trachea': [26], 50 | 'loc': True, 51 | }, 52 | 'vol005':{ 53 | 'img': 'img.h5', 54 | 'msk': 'msk.h5', 55 | 'seg': 'seg.h5', 56 | 'glia': 'glia.h5', 57 | 'seg_d3_b0': 'seg_d3_b0.h5', 58 | 'mit': 'mit.h5', 59 | 'dir': 'vol005/mip1/padded_x512_y512_z20', 60 | 'esophagus': [1], 61 | 'trachea': [10], 62 | 'loc': True, 63 | }, 64 | 'vol006':{ 65 | 'img': 'img.h5', 66 | 'msk': 'msk.h5', 67 | 'seg': 'seg.h5', 68 | 'seg_d3_b0': 'seg_d3_b0.h5', 69 | 'mit': 'mit.h5', 70 | 'dir': 'vol006/mip1/padded_x512_y512_z20', 71 | 'glia_msk': [108,122], 72 | 'loc': True, 73 | }, 74 | 'vol007':{ 75 | 'img': 'img.h5', 76 | 'msk': 'msk.h5', 77 | 'seg': 'seg.h5', 78 | 'glia': 'glia.h5', 79 | 'mit': 'mit.h5', 80 | 'seg_d3_b0': 'seg_d3_b0.h5', 81 | 'dir': 'vol007/mip1/padded_x512_y512_z20', 82 | 'loc': True, 83 | }, 84 | 'vol008':{ 85 | 'img': 'img.h5', 86 | 'msk': 'msk.h5', 87 | 'seg': 'seg.h5', 88 | 'glia': 'glia.h5', 89 | 'seg_d3_b0': 'seg_d3_b0.h5', 90 | 'mit': 'mit.h5', 91 | 'dir': 'vol008/mip1/padded_x512_y512_z20', 92 | 'trachea': [67], 93 | 'glia_msk': [62,74,75,76,77,79,80], 94 | 'loc': True, 95 | }, 96 | 'vol009':{ 97 | 'img': 'img.h5', 98 | 'msk': 'msk.h5', 99 | 'seg': 'seg.h5', 100 | 'glia': 'glia.h5', 101 | 'seg_d3_b0': 'seg_d3_b0.h5', 
102 | 'mit': 'mit.h5', 103 | 'dir': 'vol009/mip1/padded_x512_y512_z20', 104 | 'glia_msk': [52], 105 | 'loc': True, 106 | }, 107 | 'vol010':{ 108 | 'img': 'img.h5', 109 | 'msk': 'msk.h5', 110 | 'seg': 'seg.h5', 111 | 'glia': 'glia.h5', 112 | 'seg_d3_b0': 'seg_d3_b0.h5', 113 | 'mit': 'mit.h5', 114 | 'dir': 'vol010/mip1/padded_x512_y512_z20', 115 | 'glia_msk': [5,18,20], 116 | 'dark_cell': [1], 117 | 'loc': True, 118 | }, 119 | 'vol011':{ 120 | 'img': 'img.h5', 121 | 'msk': 'msk.h5', 122 | 'seg': 'seg.h5', 123 | 'glia': 'glia.h5', 124 | 'seg_d3_b0': 'seg_d3_b0.h5', 125 | 'mit': 'mit.h5', 126 | 'dir': 'vol011/mip1/padded_x512_y512_z20', 127 | 'lamellae': [52], 128 | 'trachea': [9], 129 | 'glia_msk': [37,38,40,42,43], 130 | 'loc': True, 131 | }, 132 | } 133 | 134 | 135 | def load_data(base_dir, data_ids=None, **kwargs): 136 | if data_ids is None: 137 | data_ids = data_info.keys() 138 | data = dict() 139 | base = os.path.expanduser(base_dir) 140 | dpath = os.path.join(base, data_dir) 141 | for data_id in data_ids: 142 | if data_id in data_info: 143 | info = data_info[data_id] 144 | data[data_id] = load_dataset(dpath, data_id, info, **kwargs) 145 | return data 146 | 147 | 148 | def load_dataset(dpath, tag, info, class_keys=[], **kwargs): 149 | assert len(class_keys) > 0 150 | dset = dict() 151 | 152 | # Image 153 | fpath = os.path.join(dpath, info['dir'], info['img']) 154 | print(fpath) 155 | dset['img'] = emio.imread(fpath).astype(np.float32) 156 | dset['img'] /= 255.0 157 | 158 | # Mask 159 | fpath = os.path.join(dpath, info['dir'], info['msk']) 160 | print(fpath) 161 | dset['msk'] = emio.imread(fpath).astype(np.uint8) 162 | 163 | # Mitochondria 164 | if 'mit' in info: 165 | fpath = os.path.join(dpath, info['dir'], info['mit']) 166 | print(fpath) 167 | dset['mit'] = (emio.imread(fpath) > 0).astype(np.uint8) 168 | else: 169 | dset['mit'] = np.zeros_like(dset['msk']) 170 | 171 | # Additoinal info 172 | dset['loc'] = info['loc'] 173 | 174 | return dset 175 | 
-------------------------------------------------------------------------------- /deepem/data/dataset/flyem/focused_annotation_v1.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import dataprovider3.emio as emio 5 | 6 | 7 | # Focused annotation 8 | data_dir = 'flyem/ground_truth/focused_annotation' 9 | data_info = { 10 | 'vol001':{ 11 | 'img': 'img.h5', 12 | 'msk': 'msk.h5', 13 | 'seg': 'seg.h5', 14 | 'glia': 'glia.h5', 15 | 'seg_d3_b0': 'seg_d3_b0.h5', 16 | 'dir': 'vol001/mip1/padded_x512_y512_z20', 17 | 'lamellae': [287], 18 | 'loc': True, 19 | }, 20 | 'vol002':{ 21 | 'img': 'img.h5', 22 | 'msk': 'msk.h5', 23 | 'seg': 'seg.h5', 24 | 'seg_d3_b0': 'seg_d3_b0.h5', 25 | 'mit': 'mit.h5', 26 | 'dir': 'vol002/mip1/padded_x512_y512_z20', 27 | 'loc': True, 28 | }, 29 | 'vol003':{ 30 | 'img': 'img.h5', 31 | 'msk': 'msk.h5', 32 | 'seg': 'seg.h5', 33 | 'glia': 'glia.h5', 34 | 'seg_d3_b0': 'seg_d3_b0.h5', 35 | 'mit': 'mit.h5', 36 | 'dir': 'vol003/mip1/padded_x512_y512_z20', 37 | 'trachea': [16], 38 | 'glia_msk': [78,79,80,84,85,87,89,93,97,101], 39 | 'loc': True, 40 | }, 41 | 'vol004':{ 42 | 'img': 'img.h5', 43 | 'msk': 'msk.h5', 44 | 'seg': 'seg.h5', 45 | 'glia': 'glia.h5', 46 | 'seg_d3_b0': 'seg_d3_b0.h5', 47 | 'dir': 'vol004/mip1/padded_x512_y512_z20', 48 | 'rosetta': [1], 49 | 'trachea': [26], 50 | 'loc': True, 51 | }, 52 | 'vol005':{ 53 | 'img': 'img.h5', 54 | 'msk': 'msk.h5', 55 | 'seg': 'seg.h5', 56 | 'glia': 'glia.h5', 57 | 'seg_d3_b0': 'seg_d3_b0.h5', 58 | 'mit': 'mit.h5', 59 | 'dir': 'vol005/mip1/padded_x512_y512_z20', 60 | 'esophagus': [1], 61 | 'trachea': [10], 62 | 'loc': True, 63 | }, 64 | 'vol006':{ 65 | 'img': 'img.h5', 66 | 'msk': 'msk.h5', 67 | 'seg': 'seg.h5', 68 | 'seg_d3_b0': 'seg_d3_b0.h5', 69 | 'mit': 'mit.h5', 70 | 'dir': 'vol006/mip1/padded_x512_y512_z20', 71 | 'glia_msk': [108,122], 72 | 'loc': True, 73 | }, 74 | 'vol007':{ 75 | 'img': 'img.h5', 76 | 'msk': 'msk.h5', 77 
def load_data(base_dir, data_ids=None, **kwargs):
    """Load the focused-annotation (v1) datasets rooted at ``base_dir``.

    Unknown ``data_ids`` are silently skipped.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    base = os.path.expanduser(base_dir)
    dpath = os.path.join(base, data_dir)
    for data_id in data_ids:
        if data_id in data_info:
            info = data_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one focused-annotation (v1) volume with special cases.

    Special-case handling driven by ``info``:
      - 'lamellae': those segment ids are relabeled as background (seg=0).
      - 'rosetta', 'esophagus', 'glia_msk', 'dark_cell': ids masked out
        (msk=0).

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        class_keys: Non-empty list of output class names ('glia', ...).
        **kwargs: Ignored.

    Returns:
        Dict with 'img', 'msk', 'seg', optional 'glia', and 'loc'.
    """
    # Avoid a shared mutable default argument.
    class_keys = [] if class_keys is None else class_keys
    assert len(class_keys) > 0
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype(np.float32)
    dset['img'] /= 255.0

    # Mask
    fpath = os.path.join(dpath, info['dir'], info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype(np.uint8)

    # Segmentation: always the 'seg_d3_b0' variant for this dataset.
    fpath = os.path.join(dpath, info['dir'], info['seg_d3_b0'])
    print(fpath)
    dset['seg'] = emio.imread(fpath).astype(np.uint32)

    # Special case: lamellar structures become background.
    if 'lamellae' in info:
        idx = np.isin(dset['seg'], info['lamellae'])
        dset['seg'][idx] = 0

    # Glia: fall back to an all-zero volume when not annotated.
    if 'glia' in class_keys:
        if 'glia' in info:
            fpath = os.path.join(dpath, info['dir'], info['glia'])
            print(fpath)
            dset['glia'] = emio.imread(fpath).astype(np.uint8)
        else:
            dset['glia'] = np.zeros_like(dset['msk'])

    # Mask out non-neuronal structures.
    if 'rosetta' in info:
        idx = np.isin(dset['seg'], info['rosetta'])
        dset['msk'][idx] = 0

    if 'esophagus' in info:
        idx = np.isin(dset['seg'], info['esophagus'])
        dset['msk'][idx] = 0

    if 'glia_msk' in info:
        idx = np.isin(dset['seg'], info['glia_msk'])
        dset['msk'][idx] = 0

    if 'dark_cell' in info:
        idx = np.isin(dset['seg'], info['dark_cell'])
        dset['msk'][idx] = 0

    # Additional info
    dset['loc'] = info['loc']

    return dset
def load_data(base_dir, data_ids=None, **kwargs):
    """Load the sparse-annotation superset volumes under ``base_dir``.

    Only the pseudo-id 'sparse_superset' is recognized; every existing
    ``svolNNN`` subdirectory is loaded under that single key.
    """
    if data_ids is None:
        data_ids = ['sparse_superset']
    data = dict()
    root = os.path.join(os.path.expanduser(base_dir), data_dir)
    if 'sparse_superset' in data_ids:
        for key in data_keys:
            vol_dir = os.path.join(root, key)
            # Missing volumes are simply skipped.
            if os.path.exists(vol_dir):
                data[key] = load_dataset(vol_dir, **kwargs)
    return {'sparse_superset': data}


def load_dataset(dpath, **kwargs):
    """Load a single sparse-annotation volume and clean up its labels.

    Reserved segment ids: 1 = background (masked out), 2 = membrane
    swirl (relabeled to 0), 3 = large lamellar structure (masked out).

    Returns:
        Dict with 'img', 'msk', 'seg', and 'loc'.
    """
    dset = dict()

    # Image, normalized to [0, 1].
    img_path = os.path.join(dpath, data_info['dir'], data_info['img'])
    print(img_path)
    dset['img'] = emio.imread(img_path).astype(np.float32)
    dset['img'] /= 255.0

    # Mask
    msk_path = os.path.join(dpath, data_info['dir'], data_info['msk'])
    print(msk_path)
    dset['msk'] = emio.imread(msk_path).astype(np.uint8)

    # Segmentation ('seg_d3_b0' variant).
    seg_path = os.path.join(dpath, data_info['dir'], data_info['seg_d3_b0'])
    print(seg_path)
    dset['seg'] = emio.imread(seg_path).astype(np.uint32)

    # Background mask
    dset['msk'][dset['seg'] == 1] = 0

    # Membrane swirl
    dset['seg'][dset['seg'] == 2] = 0

    # Large lamellar structure
    dset['msk'][dset['seg'] == 3] = 0

    # Additional info
    dset['loc'] = data_info['loc']

    return dset
def load_data(data_dir, data_ids=None, bdr=False, mye=False, **kwargs):
    """Load the Kasthuri11 (mip1, seg-m0) datasets under ``data_dir``.

    Args:
        data_dir: Base data directory (``~`` is expanded).
        data_ids: Iterable of dataset names; defaults to all of ``data_info``.
        bdr: Load the boundary volume as well.
        mye: Load the myelin volume as well.
        **kwargs: Ignored.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    dpath = os.path.expanduser(data_dir)
    for data_id in data_ids:
        info = data_info[data_id]
        data[data_id] = load_dataset(dpath, data_id, info, bdr=bdr, mye=mye)
    return data


def load_dataset(dpath, tag, info, bdr=False, mye=False):
    """Load one dataset's image, segmentation, optional labels, and mask.

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        bdr: Load the boundary volume into 'bdr'.
        mye: Load the myelin volume into 'mye'.

    Returns:
        Dict with 'img', 'seg', 'msk', optional 'bdr'/'mye', and 'loc'.
    """
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    img = emio.imread(fpath).astype('float32') / 255.0
    dset['img'] = img

    # Segmentation
    fpath = os.path.join(dpath, info['dir'], info['seg'])
    print(fpath)
    seg = emio.imread(fpath).astype('uint32')
    dset['seg'] = seg

    # Boundary. Use a distinct local name so the boolean flag ``bdr``
    # is not shadowed by the loaded volume.
    if bdr:
        fpath = os.path.join(dpath, info['dir'], info['bdr'])
        print(fpath)
        bdr_vol = emio.imread(fpath).astype('uint32')
        dset['bdr'] = bdr_vol

    # Myelin (same shadowing consideration as above).
    if mye:
        fpath = os.path.join(dpath, info['dir'], info['mye'])
        print(fpath)
        mye_vol = emio.imread(fpath).astype('uint8')
        dset['mye'] = mye_vol

    # Train mask
    fpath = os.path.join(dpath, info['dir'], info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype('uint8')

    # Additional info
    dset['loc'] = info['loc']

    return dset
def load_data(data_dir, data_ids=None, **kwargs):
    """Load the Kasthuri11 (mip0, no-pad) datasets under ``data_dir``.

    Args:
        data_dir: Base data directory (``~`` is expanded).
        data_ids: Iterable of dataset names; defaults to all of ``data_info``.
        **kwargs: Forwarded to ``load_dataset``.
    """
    if data_ids is None:
        data_ids = data_info.keys()
    data = dict()
    dpath = os.path.expanduser(data_dir)
    for data_id in data_ids:
        info = data_info[data_id]
        data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one dataset's image, mask, and class-dependent label volumes.

    Args:
        dpath: Directory holding the dataset subdirectories.
        tag: Dataset name (unused; kept for interface compatibility).
        info: Per-dataset file-name/config entry from ``data_info``.
        class_keys: Non-empty list of output class names; 'aff'/'long'
            trigger segmentation loading, 'mye' triggers myelin loading.
        **kwargs: Ignored.

    Returns:
        Dict with 'img', 'msk', optional 'seg'/'mye', and 'loc'.
    """
    # Avoid a shared mutable default argument.
    class_keys = [] if class_keys is None else class_keys
    assert len(class_keys) > 0
    dset = dict()

    # Image, normalized to [0, 1].
    fpath = os.path.join(dpath, info['dir'], info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype('float32') / 255.0

    # Mask
    fpath = os.path.join(dpath, info['dir'], info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype('uint8')

    # Segmentation, needed for affinity or long-range affinity training.
    if 'aff' in class_keys or 'long' in class_keys:
        fpath = os.path.join(dpath, info['dir'], info['seg'])
        print(fpath)
        dset['seg'] = emio.imread(fpath).astype('uint32')

    # Myelin
    if 'mye' in class_keys:
        fpath = os.path.join(dpath, info['dir'], info['mye'])
        print(fpath)
        mye = emio.imread(fpath).astype('uint8')
        dset['mye'] = mye

    # Additional info
    dset['loc'] = info['loc']

    return dset
'train216_val40_test100/mip0/padded_x0_y0_z0/train/AC3', 25 | 'loc': False, 26 | }, 27 | 'val':{ 28 | 'img': 'img.h5', 29 | 'seg': 'segm0.h5', 30 | 'bdr': 'segm0.h5', 31 | 'msk': 'msk.h5', 32 | 'mye': 'mye.h5', 33 | 'dir': 'train216_val40_test100/mip0/padded_x0_y0_z0/val', 34 | 'loc': False, 35 | }, 36 | 'test':{ 37 | 'img': 'img.h5', 38 | 'seg': 'segm0.h5', 39 | 'bdr': 'segm0.h5', 40 | 'msk': 'msk.h5', 41 | 'mye': 'mye.h5', 42 | 'dir': 'train216_val40_test100/mip0/padded_x0_y0_z0/test', 43 | 'loc': False, 44 | }, 45 | } 46 | 47 | 48 | def load_data(data_dir, data_ids=None, bdr=False, mye=False, **kwargs): 49 | if data_ids is None: 50 | data_ids = data_info.keys() 51 | data = dict() 52 | dpath = os.path.expanduser(data_dir) 53 | for data_id in data_ids: 54 | info = data_info[data_id] 55 | data[data_id] = load_dataset(dpath, data_id, info, bdr=bdr, mye=mye) 56 | return data 57 | 58 | 59 | def load_dataset(dpath, tag, info, bdr=False, mye=False): 60 | dset = dict() 61 | 62 | # Image 63 | fpath = os.path.join(dpath, info['dir'], info['img']) 64 | print(fpath) 65 | img = emio.imread(fpath).astype('float32') / 255.0 66 | dset['img'] = img 67 | 68 | # Segmentation 69 | fpath = os.path.join(dpath, info['dir'], info['seg']) 70 | print(fpath) 71 | seg = emio.imread(fpath).astype('uint32') 72 | dset['seg'] = seg 73 | 74 | # Boundary (or affinity) 75 | if bdr: 76 | fpath = os.path.join(dpath, info['dir'], info['bdr']) 77 | print(fpath) 78 | seg = emio.imread(fpath).astype('uint32') 79 | dset['aff'] = seg 80 | dset['bdr'] = (seg == 0).astype('uint8') 81 | 82 | # Myelin 83 | if mye: 84 | fpath = os.path.join(dpath, info['dir'], info['mye']) 85 | print(fpath) 86 | mye = emio.imread(fpath).astype('uint8') 87 | dset['mye'] = mye 88 | 89 | # Train mask 90 | fpath = os.path.join(dpath, info['dir'], info['msk']) 91 | print(fpath) 92 | dset['msk'] = emio.imread(fpath).astype('uint8') 93 | 94 | # Additoinal info 95 | dset['loc'] = info['loc'] 96 | 97 | return dset 98 | 
def load_data(data_dir, data_ids=None, bdr=False, mye=False, **kwargs):
    """Load every requested Kasthuri11 (mip0, seg-m0b) dataset."""
    ids = data_info.keys() if data_ids is None else data_ids
    dpath = os.path.expanduser(data_dir)
    return {i: load_dataset(dpath, i, data_info[i], bdr=bdr, mye=mye)
            for i in ids}


def load_dataset(dpath, tag, info, bdr=False, mye=False):
    """Load one dataset's volumes; boundary/myelin are opt-in via flags."""
    root = os.path.join(dpath, info['dir'])

    def read(fname, dtype):
        # Echo each path as it is read, matching the module's logging style.
        fpath = os.path.join(root, fname)
        print(fpath)
        return emio.imread(fpath).astype(dtype)

    dset = dict()

    # Image, normalized to [0, 1]
    dset['img'] = read(info['img'], 'float32') / 255.0

    # Segmentation
    dset['seg'] = read(info['seg'], 'uint32')

    # Boundary (or affinity): raw ids plus a binary boundary map (id == 0)
    if bdr:
        ids = read(info['bdr'], 'uint32')
        dset['aff'] = ids
        dset['bdr'] = (ids == 0).astype('uint8')

    # Myelin
    if mye:
        dset['mye'] = read(info['mye'], 'uint8')

    # Train mask
    dset['msk'] = read(info['msk'], 'uint8')

    # Additional info
    dset['loc'] = info['loc']

    return dset
44 | dset['img'] = emio.imread(fpath).astype('float32') / 255.0 45 | 46 | # Mask 47 | if tag == 'stitched_vol19-vol34': 48 | # Train 49 | fpath = os.path.join(dpath, info['dir'], 'msk_train.h5') 50 | print(fpath) 51 | dset['msk_train'] = emio.imread(fpath).astype('uint8') 52 | # Validation 53 | fpath = os.path.join(dpath, info['dir'], 'msk_val.h5') 54 | print(fpath) 55 | dset['msk_val'] = emio.imread(fpath).astype('uint8') 56 | else: 57 | fpath = os.path.join(dpath, info['dir'], info['msk']) 58 | print(fpath) 59 | dset['msk'] = emio.imread(fpath).astype('uint8') 60 | 61 | # Segmentation 62 | if 'aff' in class_keys or 'long' in class_keys: 63 | fpath = os.path.join(dpath, info['dir'], info['seg']) 64 | print(fpath) 65 | dset['seg'] = emio.imread(fpath).astype('uint32') 66 | 67 | # Additoinal info 68 | dset['loc'] = info['loc'] 69 | 70 | return dset 71 | -------------------------------------------------------------------------------- /deepem/data/dataset/pinky/mip0_padded_x512_y512_z32.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import dataprovider3.emio as emio 5 | 6 | 7 | # Pinky dataset 8 | pinky_dir = 'pinky/ground_truth/mip0/padded_x512_y512_z32' 9 | pinky_info = { 10 | 'stitched_vol19-vol34':{ 11 | 'img': 'img.h5', 12 | 'seg': 'seg.h5', 13 | 'psd': 'psd.h5', 14 | 'msk': 'msk.h5', 15 | 'loc': True, 16 | }, 17 | 'stitched_vol40-vol41':{ 18 | 'img': 'img.h5', 19 | 'seg': 'seg.h5', 20 | 'psd': 'psd.h5', 21 | 'msk': 'msk.h5', 22 | 'loc': True, 23 | }, 24 | 'vol101':{ 25 | 'img': 'img.h5', 26 | 'seg': 'seg.h5', 27 | 'psd': 'psd.h5', 28 | 'msk': 'msk.h5', 29 | 'blv': 'blv.h5', 30 | 'loc': True, 31 | }, 32 | 'vol102':{ 33 | 'img': 'img.h5', 34 | 'seg': 'seg.h5', 35 | 'psd': 'psd.h5', 36 | 'msk': 'msk.h5', 37 | 'loc': True, 38 | }, 39 | 'vol103':{ 40 | 'img': 'img.h5', 41 | 'seg': 'seg.h5', 42 | 'psd': 'psd.h5', 43 | 'msk': 'msk.h5', 44 | 'loc': True, 45 | }, 46 | 'vol104':{ 47 | 
def load_data(data_dir, data_ids=None, **kwargs):
    """Load the requested pinky (mip0, padded) ground-truth datasets.

    Args:
        data_dir: Base directory containing the dataset tree.
        data_ids: Keys into ``pinky_info``; defaults to all datasets.
        **kwargs: Forwarded to ``load_dataset`` (e.g. ``class_keys``).

    Returns:
        Mapping from data_id to the loaded dataset dict.
    """
    if data_ids is None:
        data_ids = pinky_info.keys()

    data = dict()
    base = os.path.expanduser(data_dir)

    for data_id in data_ids:
        # Pinky
        if data_id in pinky_info:
            dpath = os.path.join(base, pinky_dir)
            info = pinky_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)

    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one dataset; ``class_keys`` selects which targets to load.

    Raises:
        ValueError: If ``class_keys`` is empty or None.
    """
    # Non-mutable default + explicit validation (assert is stripped by -O).
    if not class_keys:
        raise ValueError("class_keys must be a non-empty collection")
    dset = dict()

    # Image; an 'a'-suffixed tag is an alias that shares the base volume's
    # directory. endswith() is safe on an empty tag, unlike tag[-1].
    dname = tag[:-1] if tag.endswith('a') else tag
    fpath = os.path.join(dpath, dname, info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype('float32')
    dset['img'] /= 255.0

    # Mask
    fpath = os.path.join(dpath, dname, info['msk'])
    print(fpath)
    dset['msk'] = emio.imread(fpath).astype('uint8')

    # Special volume with additional train/val masks
    if dname == 'stitched_vol19-vol34':
        fpath = os.path.join(dpath, dname, 'msk_train.h5')
        print(fpath)
        dset['msk_train'] = emio.imread(fpath).astype('uint8')
        fpath = os.path.join(dpath, dname, 'msk_val.h5')
        print(fpath)
        dset['msk_val'] = emio.imread(fpath).astype('uint8')

    # Segmentation (needed for affinity targets)
    if 'aff' in class_keys or 'long' in class_keys:
        fpath = os.path.join(dpath, dname, info['seg'])
        print(fpath)
        dset['seg'] = emio.imread(fpath).astype('uint32')

    # Synapse (binary PSD map; zeros when the volume has no annotation)
    if 'psd' in class_keys:
        if 'psd' in info:
            fpath = os.path.join(dpath, dname, info['psd'])
            print(fpath)
            psd = (emio.imread(fpath) > 0).astype('uint8')
        else:
            psd = np.zeros(dset['img'].shape, dtype='uint8')
        dset['psd'] = psd

        # Volumes that ship a dedicated synapse mask
        special = ['stitched_vol40-vol41','vol101','vol102','vol103','vol104']
        if dname in special:
            fpath = os.path.join(dpath, dname, 'psd_msk.h5')
            print(fpath)
            psd_msk = emio.imread(fpath).astype('uint8')
        else:
            psd_msk = dset['msk']
        dset['psd_msk'] = psd_msk

    # Myelin (zeros when absent)
    if 'mye' in class_keys:
        if 'mye' in info:
            fpath = os.path.join(dpath, dname, info['mye'])
            print(fpath)
            mye = emio.imread(fpath).astype('uint8')
        else:
            mye = np.zeros(dset['img'].shape, dtype='uint8')
        dset['mye'] = mye

    # Blood vessel (zeros when absent)
    if 'blv' in class_keys:
        if 'blv' in info:
            fpath = os.path.join(dpath, dname, info['blv'])
            print(fpath)
            blv = emio.imread(fpath).astype('uint8')
        else:
            blv = np.zeros(dset['img'].shape, dtype='uint8')
        dset['blv'] = blv

    # Additional info
    dset['loc'] = info['loc']

    return dset
/deepem/data/dataset/pinky_basil/mip0_padded_x512_y512_z32.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | import dataprovider3.emio as emio 5 | 6 | 7 | # Basil dataset 8 | basil_dir = 'basil/ground_truth/mip0/padded_x512_y512_z32' 9 | basil_info = { 10 | 'vol001':{ 11 | 'img': 'img.h5', 12 | 'seg': 'seg.h5', 13 | 'psd': 'psd.h5', 14 | 'msk': 'msk.d128.h5', 15 | 'blv': 'blv.h5', 16 | 'loc': True, 17 | }, 18 | 'vol001a':{ 19 | 'img': 'img.h5', 20 | 'seg': 'seg.h5', 21 | 'psd': 'psd.h5', 22 | 'msk': 'msk.h5', 23 | 'blv': 'blv.h5', 24 | 'loc': True, 25 | }, 26 | 'vol002':{ 27 | 'img': 'img.h5', 28 | 'seg': 'seg.h5', 29 | 'psd': 'psd.h5', 30 | 'msk': 'msk.d128.h5', 31 | 'loc': True, 32 | }, 33 | 'vol002a':{ 34 | 'img': 'img.h5', 35 | 'seg': 'seg.h5', 36 | 'psd': 'psd.h5', 37 | 'msk': 'msk.h5', 38 | 'loc': True, 39 | }, 40 | 'vol003':{ 41 | 'img': 'img.h5', 42 | 'seg': 'seg.h5', 43 | 'psd': 'psd.h5', 44 | 'msk': 'msk.h5', 45 | 'mye': 'mye.h5', 46 | 'loc': True, 47 | }, 48 | 'vol004':{ 49 | 'img': 'img.h5', 50 | 'seg': 'seg.h5', 51 | 'psd': 'psd.h5', 52 | 'msk': 'msk.h5', 53 | 'loc': True, 54 | }, 55 | 'vol005':{ 56 | 'img': 'img.h5', 57 | 'seg': 'seg.h5', 58 | 'psd': 'psd.h5', 59 | 'msk': 'msk.h5', 60 | 'mye': 'mye.h5', 61 | 'loc': True, 62 | }, 63 | 'vol006':{ 64 | 'img': 'img.h5', 65 | 'seg': 'seg.h5', 66 | 'psd': 'psd.h5', 67 | 'msk': 'msk.h5', 68 | 'loc': True, 69 | }, 70 | 'vol008':{ 71 | 'img': 'img.h5', 72 | 'seg': 'seg.h5', 73 | 'psd': 'psd.h5', 74 | 'msk': 'msk.h5', 75 | 'loc': True, 76 | }, 77 | 'vol011':{ 78 | 'img': 'img.h5', 79 | 'seg': 'seg.h5', 80 | 'psd': 'psd.h5', 81 | 'msk': 'msk.h5', 82 | 'loc': True, 83 | }, 84 | } 85 | 86 | 87 | # Pinky dataset 88 | pinky_dir = 'pinky/ground_truth/mip0/padded_x512_y512_z32' 89 | pinky_info = { 90 | 'stitched_vol19-vol34':{ 91 | 'img': 'img.h5', 92 | 'seg': 'seg.h5', 93 | 'psd': 'psd.h5', 94 | 'msk': 'msk.h5', 95 | 'loc': True, 96 | }, 
def load_data(data_dir, data_ids=None, **kwargs):
    """Load the requested basil/pinky (mip0) ground-truth datasets.

    Args:
        data_dir: Base directory containing both dataset trees.
        data_ids: Dataset keys; defaults to every basil and pinky dataset.
        **kwargs: Forwarded to ``load_dataset`` (e.g. ``class_keys``).

    Returns:
        Mapping from data_id to the loaded dataset dict.
    """
    if data_ids is None:
        # BUG FIX: dict.keys() views cannot be concatenated with '+' in
        # Python 3 (TypeError); build explicit lists first.
        data_ids = list(basil_info.keys()) + list(pinky_info.keys())

    data = dict()
    base = os.path.expanduser(data_dir)

    for data_id in data_ids:
        # Basil
        if data_id in basil_info:
            dpath = os.path.join(base, basil_dir)
            info = basil_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)
        # Pinky
        if data_id in pinky_info:
            dpath = os.path.join(base, pinky_dir)
            info = pinky_info[data_id]
            data[data_id] = load_dataset(dpath, data_id, info, **kwargs)

    return data


def load_dataset(dpath, tag, info, class_keys=None, **kwargs):
    """Load one dataset; ``class_keys`` selects which targets to load.

    Raises:
        ValueError: If ``class_keys`` is empty or None.
    """
    # Non-mutable default + explicit validation (assert is stripped by -O).
    if not class_keys:
        raise ValueError("class_keys must be a non-empty collection")
    dset = dict()

    # Image; an 'a'-suffixed tag is an alias that shares the base volume's
    # directory. endswith() is safe on an empty tag, unlike tag[-1].
    dname = tag[:-1] if tag.endswith('a') else tag
    fpath = os.path.join(dpath, dname, info['img'])
    print(fpath)
    dset['img'] = emio.imread(fpath).astype('float32')
    dset['img'] /= 255.0

    # Mask
    if dname == 'stitched_vol19-vol34':
        # This volume ships separate train/val masks instead of 'msk'.
        fpath = os.path.join(dpath, dname, 'msk_train.h5')
        print(fpath)
        dset['msk_train'] = emio.imread(fpath).astype('uint8')
        fpath = os.path.join(dpath, dname, 'msk_val.h5')
        print(fpath)
        dset['msk_val'] = emio.imread(fpath).astype('uint8')
    else:
        fpath = os.path.join(dpath, dname, info['msk'])
        print(fpath)
        dset['msk'] = emio.imread(fpath).astype('uint8')

    # Segmentation (needed for affinity targets)
    if 'aff' in class_keys or 'long' in class_keys:
        fpath = os.path.join(dpath, dname, info['seg'])
        print(fpath)
        dset['seg'] = emio.imread(fpath).astype('uint32')

    # Synapse (binary PSD map; zeros when the volume has no annotation)
    if 'psd' in class_keys:
        if 'psd' in info:
            fpath = os.path.join(dpath, dname, info['psd'])
            print(fpath)
            psd = (emio.imread(fpath) > 0).astype('uint8')
        else:
            psd = np.zeros(dset['img'].shape, dtype='uint8')
        dset['psd'] = psd

        # Volumes that ship a dedicated synapse mask
        special = ['stitched_vol40-vol41','vol101','vol102','vol103','vol104']
        if dname in special:
            fpath = os.path.join(dpath, dname, 'psd_msk.h5')
            print(fpath)
            psd_msk = emio.imread(fpath).astype('uint8')
        else:
            # NOTE(review): for 'stitched_vol19-vol34' dset['msk'] is never
            # set (only msk_train/msk_val above), so requesting 'psd' for
            # that volume raises KeyError here -- confirm intended mask.
            psd_msk = dset['msk']
        dset['psd_msk'] = psd_msk

    # Myelin (zeros when absent)
    if 'mye' in class_keys:
        if 'mye' in info:
            fpath = os.path.join(dpath, dname, info['mye'])
            print(fpath)
            mye = emio.imread(fpath).astype('uint8')
        else:
            mye = np.zeros(dset['img'].shape, dtype='uint8')
        dset['mye'] = mye

    # Blood vessel (zeros when absent)
    if 'blv' in class_keys:
        if 'blv' in info:
            fpath = os.path.join(dpath, dname, info['blv'])
            print(fpath)
            blv = emio.imread(fpath).astype('uint8')
        else:
            blv = np.zeros(dset['img'].shape, dtype='uint8')
        dset['blv'] = blv

    # Additional info
    dset['loc'] = info['loc']

    return dset
'seg.h5', 51 | 'syn': 'syn.h5', 52 | 'msk': 'msk.h5', 53 | 'loc': True, 54 | }, 55 | 'vol005':{ 56 | 'img': 'img.h5', 57 | 'seg': 'seg.h5', 58 | 'syn': 'syn.h5', 59 | 'msk': 'msk.h5', 60 | 'mye': 'mye.h5', 61 | 'loc': True, 62 | }, 63 | 'vol006':{ 64 | 'img': 'img.h5', 65 | 'seg': 'seg.h5', 66 | 'syn': 'syn.h5', 67 | 'msk': 'msk.h5', 68 | 'loc': True, 69 | }, 70 | 'vol008':{ 71 | 'img': 'img.h5', 72 | 'seg': 'seg.h5', 73 | 'syn': 'syn.h5', 74 | 'msk': 'msk.h5', 75 | 'loc': True, 76 | }, 77 | 'vol011':{ 78 | 'img': 'img.h5', 79 | 'seg': 'seg.h5', 80 | 'syn': 'syn.h5', 81 | 'msk': 'msk.h5', 82 | 'loc': True, 83 | }, 84 | } 85 | 86 | 87 | # Pinky dataset 88 | pinky_dir = 'pinky/ground_truth/mip1/padded_x512_y512_z32' 89 | pinky_info = { 90 | 'stitched_vol19-vol34':{ 91 | 'img': 'img.h5', 92 | 'seg': 'seg.h5', 93 | 'syn': 'syn.h5', 94 | 'msk': 'msk.h5', 95 | 'loc': True, 96 | }, 97 | 'stitched_vol40-vol41':{ 98 | 'img': 'img.h5', 99 | 'seg': 'seg.h5', 100 | 'syn': 'syn.h5', 101 | 'msk': 'msk.h5', 102 | 'loc': True, 103 | }, 104 | 'vol101':{ 105 | 'img': 'img.h5', 106 | 'seg': 'seg.h5', 107 | 'syn': 'syn.h5', 108 | 'msk': 'msk.h5', 109 | 'blv': 'blv.h5', 110 | 'loc': True, 111 | }, 112 | 'vol102':{ 113 | 'img': 'img.h5', 114 | 'seg': 'seg.h5', 115 | 'syn': 'syn.h5', 116 | 'msk': 'msk.h5', 117 | 'loc': True, 118 | }, 119 | 'vol103':{ 120 | 'img': 'img.h5', 121 | 'seg': 'seg.h5', 122 | 'syn': 'syn.h5', 123 | 'msk': 'msk.h5', 124 | 'loc': True, 125 | }, 126 | 'vol104':{ 127 | 'img': 'img.h5', 128 | 'seg': 'seg.h5', 129 | 'syn': 'syn.h5', 130 | 'msk': 'msk.h5', 131 | 'loc': True, 132 | }, 133 | 'vol401':{ 134 | 'img': 'img.h5', 135 | 'seg': 'seg.h5', 136 | 'syn': 'syn.h5', 137 | 'msk': 'msk.h5', 138 | 'mye': 'mye.h5', 139 | 'blv': 'blv.h5', 140 | 'loc': True, 141 | }, 142 | 'vol501':{ 143 | 'img': 'img.h5', 144 | 'seg': 'seg.h5', 145 | 'syn': 'syn.h5', 146 | 'msk': 'msk.d128.h5', 147 | 'loc': True, 148 | }, 149 | 'vol501a':{ 150 | 'img': 'img.h5', 151 | 'seg': 'seg.h5', 152 
| 'syn': 'syn.h5', 153 | 'msk': 'msk.h5', 154 | 'loc': True, 155 | }, 156 | 'vol502':{ 157 | 'img': 'img.h5', 158 | 'seg': 'seg.h5', 159 | 'syn': 'syn.h5', 160 | 'msk': 'msk.h5', 161 | 'mye': 'mye.h5', 162 | 'loc': True, 163 | }, 164 | 'vol503':{ 165 | 'img': 'img.h5', 166 | 'seg': 'seg.h5', 167 | 'syn': 'syn.h5', 168 | 'msk': 'msk.h5', 169 | 'blv': 'blv.h5', 170 | 'loc': True, 171 | }, 172 | 'vol201':{ 173 | 'img': 'img.h5', 174 | 'seg': 'seg.h5', 175 | 'msk': 'msk.d128.h5', 176 | 'blv': 'blv.h5', 177 | 'loc': True, 178 | }, 179 | 'vol201a':{ 180 | 'img': 'img.h5', 181 | 'seg': 'seg.h5', 182 | 'msk': 'msk.h5', 183 | 'blv': 'blv.h5', 184 | 'loc': True, 185 | }, 186 | } 187 | 188 | 189 | def load_data(data_dir, data_ids=None, **kwargs): 190 | if data_ids is None: 191 | data_ids = basil_info.keys() + pinky_info.keys() 192 | 193 | data = dict() 194 | base = os.path.expanduser(data_dir) 195 | 196 | for data_id in data_ids: 197 | # Basil 198 | if data_id in basil_info: 199 | dpath = os.path.join(base, basil_dir) 200 | info = basil_info[data_id] 201 | data[data_id] = load_dataset(dpath, data_id, info, **kwargs) 202 | # Pinky 203 | if data_id in pinky_info: 204 | dpath = os.path.join(base, pinky_dir) 205 | info = pinky_info[data_id] 206 | data[data_id] = load_dataset(dpath, data_id, info, **kwargs) 207 | 208 | return data 209 | 210 | 211 | def load_dataset(dpath, tag, info, class_keys=[], **kwargs): 212 | assert len(class_keys) > 0 213 | dset = dict() 214 | 215 | # Image 216 | dname = tag[:-1] if tag[-1] == 'a' else tag 217 | fpath = os.path.join(dpath, dname, info['img']) 218 | print(fpath) 219 | dset['img'] = emio.imread(fpath).astype('float32') 220 | dset['img'] /= 255.0 221 | 222 | # Mask 223 | if dname == 'stitched_vol19-vol34': 224 | fpath = os.path.join(dpath, dname, 'msk_train.h5') 225 | print(fpath) 226 | dset['msk_train'] = emio.imread(fpath).astype('uint8') 227 | fpath = os.path.join(dpath, dname, 'msk_val.h5') 228 | print(fpath) 229 | dset['msk_val'] = 
emio.imread(fpath).astype('uint8') 230 | else: 231 | fpath = os.path.join(dpath, dname, info['msk']) 232 | print(fpath) 233 | dset['msk'] = emio.imread(fpath).astype('uint8') 234 | 235 | # Segmentation 236 | if 'aff' in class_keys or 'long' in class_keys: 237 | fpath = os.path.join(dpath, dname, info['seg']) 238 | print(fpath) 239 | dset['seg'] = emio.imread(fpath).astype('uint32') 240 | 241 | # Synapse (distillation) 242 | if 'syn' in class_keys: 243 | if 'syn' in info: 244 | fpath = os.path.join(dpath, dname, info['syn']) 245 | print(fpath) 246 | syn = emio.imread(fpath).astype('float32') 247 | else: 248 | syn = np.zeros(dset['img'].shape, dtype='float32') 249 | dset['syn'] = syn 250 | 251 | # Myelin 252 | if 'mye' in class_keys: 253 | if 'mye' in info: 254 | fpath = os.path.join(dpath, dname, info['mye']) 255 | print(fpath) 256 | mye = emio.imread(fpath).astype('uint8') 257 | else: 258 | mye = np.zeros(dset['img'].shape, dtype='uint8') 259 | dset['mye'] = mye 260 | 261 | # Blood vessel 262 | if 'blv' in class_keys: 263 | if 'blv' in info: 264 | fpath = os.path.join(dpath, dname, info['blv']) 265 | print(fpath) 266 | blv = emio.imread(fpath).astype('uint8') 267 | else: 268 | blv = np.zeros(dset['img'].shape, dtype='uint8') 269 | dset['blv'] = blv 270 | 271 | # Additoinal info 272 | dset['loc'] = info['loc'] 273 | 274 | return dset 275 | -------------------------------------------------------------------------------- /deepem/data/modifier/crop2x.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deepem.utils.torch_utils import crop_center 4 | 5 | 6 | class Modifier(object): 7 | def __call__(self, sample): 8 | if np.random.rand() < 0.5: 9 | return sample 10 | for k, v in sample.items(): 11 | cropsz = (v.shape[-3], v.shape[-2]//2, v.shape[-1]//2) 12 | sample[k] = crop_center(v, cropsz).contiguous() 13 | return sample 14 | 
def get_spec(in_spec, out_spec):
    """Build a sampling spec from input/output specs.

    Keeps the last three (spatial) dims of every entry, and adds a
    matching ``<key>_mask`` entry for every output key.
    """
    spec = dict()
    for key, shape in in_spec.items():
        spec[key] = shape[-3:]
    for key, shape in out_spec.items():
        dims = tuple(shape[-3:])
        spec[key] = dims
        spec[key + '_mask'] = dims
    return spec


class Sampler(object):
    """Draws augmented affinity-training samples from a DataProvider."""

    def __init__(self, data, spec, is_train, aug=None, prob=None):
        self.is_train = is_train
        # Long-range affinity reuses the short-range target; strip it from
        # the sampling spec here and re-attach it in postprocess().
        self.long_range = 'long_range' in spec
        if self.long_range:
            del spec['long_range']
            del spec['long_range_mask']
        self.build(data, spec, aug, prob)

    def __call__(self):
        return self.postprocess(self.dataprovider())

    def postprocess(self, sample):
        assert 'affinity' in sample
        if self.long_range:
            # Reference (not a copy) of the short-range affinity target.
            sample['long_range'] = sample['affinity']
            sample['long_range_mask'] = sample['affinity_mask']
        return self.to_float32(Augment.to_tensor(sample))

    def to_float32(self, sample):
        # In-place dtype conversion of every array in the sample.
        for key in sample:
            sample[key] = sample[key].astype('float32')
        return sample

    def build(self, data, spec, aug, prob):
        dp = DataProvider(spec)
        keys = data.keys()
        for key in keys:
            if 'superset' in key:
                dp.add_dataset(self.build_datasuperset(key, data[key]))
            else:
                dp.add_dataset(self.build_dataset(key, data[key]))
        dp.set_augment(aug)
        dp.set_imgs(['input'])
        dp.set_segs(['affinity'])
        dp.set_sampling_weights(p=[prob[k] for k in keys] if prob else None)
        self.dataprovider = dp
        print(dp)

    def build_datasuperset(self, tag, data):
        # A superset groups several datasets under one sampling slot.
        superset = DataSuperset(tag=tag)
        for key in data.keys():
            superset.add_dataset(self.build_dataset(key, data[key]))
        return superset

    def build_dataset(self, tag, data):
        msk = self.get_mask(data)
        dset = Dataset(tag=tag)
        dset.add_data(key='input', data=data['img'])
        dset.add_data(key='affinity', data=data['seg'])
        dset.add_mask(key='affinity_mask', data=msk, loc=data['loc'])
        return dset

    def get_mask(self, data):
        # Prefer the split-specific mask; fall back to the generic one.
        key = 'msk_train' if self.is_train else 'msk_val'
        if key in data:
            return data[key]
        assert 'msk' in data
        return data['msk']
def get_spec(in_spec, out_spec):
    """Build a sampling spec from input/output specs.

    Keeps the last three (spatial) dims of every entry, and adds a
    matching ``<key>_mask`` entry for every output key.
    """
    spec = dict()
    for key, shape in in_spec.items():
        spec[key] = shape[-3:]
    for key, shape in out_spec.items():
        dims = tuple(shape[-3:])
        spec[key] = dims
        spec[key + '_mask'] = dims
    return spec


class Sampler(object):
    """Affinity sampler that recomputes the boundary target dynamically
    after augmentation."""

    def __init__(self, data, spec, is_train, aug=None, prob=None):
        self.is_train = is_train
        # Long-range affinity reuses the short-range target; strip it from
        # the sampling spec here and re-attach it in postprocess().
        self.long_range = 'long_range' in spec
        if self.long_range:
            del spec['long_range']
            del spec['long_range_mask']
        self.build(data, spec, aug, prob)

    def __call__(self):
        return self.postprocess(self.dataprovider())

    def postprocess(self, sample):
        assert 'affinity' in sample

        # Recompute the boundary target from the augmented segmentation.
        seg = np.squeeze(sample['affinity'])
        sample['affinity'] = create_border(seg.astype('uint32'))

        if self.long_range:
            # Reference (not a copy) of the short-range affinity target.
            sample['long_range'] = sample['affinity']
            sample['long_range_mask'] = sample['affinity_mask']
        return self.to_float32(Augment.to_tensor(sample))

    def to_float32(self, sample):
        # In-place dtype conversion of every array in the sample.
        for key in sample:
            sample[key] = sample[key].astype('float32')
        return sample

    def build(self, data, spec, aug, prob):
        dp = DataProvider(spec)
        keys = data.keys()
        for key in keys:
            dp.add_dataset(self.build_dataset(key, data[key]))
        dp.set_augment(aug)
        dp.set_imgs(['input'])
        dp.set_segs(['affinity'])
        dp.set_sampling_weights(p=[prob[k] for k in keys] if prob else None)
        self.dataprovider = dp
        print(dp)

    def build_dataset(self, tag, data):
        msk = self.get_mask(data)
        dset = Dataset(tag=tag)
        dset.add_data(key='input', data=data['img'])
        dset.add_data(key='affinity', data=data['seg'])
        dset.add_mask(key='affinity_mask', data=msk, loc=data['loc'])
        return dset

    def get_mask(self, data):
        # Prefer the split-specific mask; fall back to the generic one.
        key = 'msk_train' if self.is_train else 'msk_val'
        if key in data:
            return data[key]
        assert 'msk' in data
        return data['msk']
self.is_train = is_train 21 | if 'long_range' in spec: 22 | self.long_range = True 23 | del spec['long_range'] 24 | del spec['long_range_mask'] 25 | else: 26 | self.long_range = False 27 | self.build(data, spec, aug, prob) 28 | 29 | def __call__(self): 30 | sample = self.dataprovider() 31 | return self.postprocess(sample) 32 | 33 | def postprocess(self, sample): 34 | assert 'affinity' in sample 35 | assert 'glia' in sample 36 | 37 | # TODO: Copy or Ref? 38 | if self.long_range: 39 | sample['long_range'] = sample['affinity'] 40 | sample['long_range_mask'] = sample['affinity_mask'] 41 | 42 | sample = Augment.to_tensor(sample) 43 | return self.to_float32(sample) 44 | 45 | def to_float32(self, sample): 46 | for k, v in sample.items(): 47 | sample[k] = v.astype('float32') 48 | return sample 49 | 50 | def build(self, data, spec, aug, prob): 51 | dp = DataProvider(spec) 52 | keys = data.keys() 53 | for k in keys: 54 | dp.add_dataset(self.build_dataset(k, data[k])) 55 | dp.set_augment(aug) 56 | dp.set_imgs(['input']) 57 | dp.set_segs(['affinity']) 58 | if prob: 59 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 60 | else: 61 | dp.set_sampling_weights(p=None) 62 | self.dataprovider = dp 63 | print(dp) 64 | 65 | def build_dataset(self, tag, data): 66 | img = data['img'] 67 | seg = data['seg'] 68 | loc = data['loc'] 69 | msk = self.get_mask(data) 70 | glia = data['glia'] 71 | gmsk = data['gmsk'] if 'gmsk' in data else msk 72 | 73 | # Create Dataset. 
74 | dset = Dataset(tag=tag) 75 | dset.add_data(key='input', data=img) 76 | dset.add_data(key='affinity', data=seg) 77 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 78 | dset.add_data(key='glia', data=glia) 79 | dset.add_mask(key='glia_mask', data=gmsk) 80 | 81 | return dset 82 | 83 | def get_mask(self, data): 84 | key = 'msk_train' if self.is_train else 'msk_val' 85 | if key in data: 86 | return data[key] 87 | assert 'msk' in data 88 | return data['msk'] 89 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_mit.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | if 'long_range' in spec: 22 | self.long_range = True 23 | del spec['long_range'] 24 | del spec['long_range_mask'] 25 | else: 26 | self.long_range = False 27 | self.build(data, spec, aug, prob) 28 | 29 | def __call__(self): 30 | sample = self.dataprovider() 31 | return self.postprocess(sample) 32 | 33 | def postprocess(self, sample): 34 | assert 'affinity' in sample 35 | assert 'mitochondria' in sample 36 | 37 | # TODO: Copy or Ref? 
38 | if self.long_range: 39 | sample['long_range'] = sample['affinity'] 40 | sample['long_range_mask'] = sample['affinity_mask'] 41 | 42 | sample = Augment.to_tensor(sample) 43 | return self.to_float32(sample) 44 | 45 | def to_float32(self, sample): 46 | for k, v in sample.items(): 47 | sample[k] = v.astype('float32') 48 | return sample 49 | 50 | def build(self, data, spec, aug, prob): 51 | dp = DataProvider(spec) 52 | keys = data.keys() 53 | for k in keys: 54 | dp.add_dataset(self.build_dataset(k, data[k])) 55 | dp.set_augment(aug) 56 | dp.set_imgs(['input']) 57 | dp.set_segs(['affinity']) 58 | if prob: 59 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 60 | else: 61 | dp.set_sampling_weights(p=None) 62 | self.dataprovider = dp 63 | print(dp) 64 | 65 | def build_dataset(self, tag, data): 66 | img = data['img'] 67 | seg = data['seg'] 68 | mit = data['mit'] 69 | loc = data['loc'] 70 | msk = self.get_mask(data) 71 | 72 | # Create Dataset. 73 | dset = Dataset(tag=tag) 74 | dset.add_data(key='input', data=img) 75 | dset.add_data(key='affinity', data=seg) 76 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 77 | dset.add_data(key='mitochondria', data=mit) 78 | dset.add_mask(key='mitochondria_mask', data=msk) 79 | 80 | return dset 81 | 82 | def get_mask(self, data): 83 | key = 'msk_train' if self.is_train else 'msk_val' 84 | if key in data: 85 | return data[key] 86 | assert 'msk' in data 87 | return data['msk'] 88 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_mye.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return 
spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | if 'long_range' in spec: 22 | self.long_range = True 23 | del spec['long_range'] 24 | del spec['long_range_mask'] 25 | else: 26 | self.long_range = False 27 | self.build(data, spec, aug, prob) 28 | 29 | def __call__(self): 30 | sample = self.dataprovider() 31 | return self.postprocess(sample) 32 | 33 | def postprocess(self, sample): 34 | assert 'affinity' in sample 35 | assert 'myelin' in sample 36 | 37 | # TODO: Copy or Ref? 38 | if self.long_range: 39 | sample['long_range'] = sample['affinity'] 40 | sample['long_range_mask'] = sample['affinity_mask'] 41 | 42 | sample = Augment.to_tensor(sample) 43 | return self.to_float32(sample) 44 | 45 | def to_float32(self, sample): 46 | for k, v in sample.items(): 47 | sample[k] = v.astype('float32') 48 | return sample 49 | 50 | def build(self, data, spec, aug, prob): 51 | dp = DataProvider(spec) 52 | keys = data.keys() 53 | for k in keys: 54 | dp.add_dataset(self.build_dataset(k, data[k])) 55 | dp.set_augment(aug) 56 | dp.set_imgs(['input']) 57 | dp.set_segs(['affinity']) 58 | if prob: 59 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 60 | else: 61 | dp.set_sampling_weights(p=None) 62 | self.dataprovider = dp 63 | print(dp) 64 | 65 | def build_dataset(self, tag, data): 66 | img = data['img'] 67 | seg = data['seg'] 68 | mye = data['mye'] 69 | loc = data['loc'] 70 | msk = self.get_mask(data) 71 | 72 | # Create Dataset. 
73 | dset = Dataset(tag=tag) 74 | dset.add_data(key='input', data=img) 75 | dset.add_data(key='affinity', data=seg) 76 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 77 | dset.add_data(key='myelin', data=mye) 78 | dset.add_mask(key='myelin_mask', data=msk) 79 | 80 | return dset 81 | 82 | def get_mask(self, data): 83 | key = 'msk_train' if self.is_train else 'msk_val' 84 | if key in data: 85 | return data[key] 86 | assert 'msk' in data 87 | return data['msk'] 88 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_mye_blv1.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from augmentor import Augment 4 | from dataprovider3 import DataProvider, Dataset 5 | 6 | 7 | def get_spec(in_spec, out_spec): 8 | spec = dict() 9 | # Input spec 10 | for k, v in in_spec.items(): 11 | spec[k] = v[-3:] 12 | # Output spec 13 | for k, v in out_spec.items(): 14 | dim = tuple(v[-3:]) 15 | spec[k] = dim 16 | spec[k+'_mask'] = dim 17 | return spec 18 | 19 | 20 | class Sampler(object): 21 | def __init__(self, data, spec, is_train, aug=None, prob=None): 22 | self.is_train = is_train 23 | if 'long_range' in spec: 24 | self.long_range = True 25 | del spec['long_range'] 26 | del spec['long_range_mask'] 27 | else: 28 | self.long_range = False 29 | self.build(data, spec, aug, prob) 30 | 31 | def __call__(self): 32 | sample = self.dataprovider() 33 | return self.postprocess(sample) 34 | 35 | def postprocess(self, sample): 36 | assert 'affinity' in sample 37 | assert 'myelin' in sample 38 | assert 'blood_vessel' in sample 39 | 40 | # TODO: Copy or Ref? 
41 | if self.long_range: 42 | sample['long_range'] = sample['affinity'] 43 | sample['long_range_mask'] = sample['affinity_mask'] 44 | 45 | # Blood vessel 46 | blv = sample['blood_vessel'] 47 | sample['blood_vessel'] = (blv == 2) # Endothelia 48 | 49 | sample = Augment.to_tensor(sample) 50 | return self.to_float32(sample) 51 | 52 | def to_float32(self, sample): 53 | for k, v in sample.items(): 54 | sample[k] = v.astype('float32') 55 | return sample 56 | 57 | def build(self, data, spec, aug, prob): 58 | dp = DataProvider(spec) 59 | keys = data.keys() 60 | for k in keys: 61 | dp.add_dataset(self.build_dataset(k, data[k])) 62 | dp.set_augment(aug) 63 | dp.set_imgs(['input']) 64 | dp.set_segs(['affinity']) 65 | if prob: 66 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 67 | else: 68 | dp.set_sampling_weights(p=None) 69 | self.dataprovider = dp 70 | print(dp) 71 | 72 | def build_dataset(self, tag, data): 73 | img = data['img'] 74 | seg = data['seg'] 75 | mye = data['mye'] 76 | blv = data['blv'] 77 | loc = data['loc'] 78 | msk = self.get_mask(data) 79 | 80 | # Create Dataset. 
81 | dset = Dataset(tag=tag) 82 | dset.add_data(key='input', data=img) 83 | dset.add_data(key='affinity', data=seg) 84 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 85 | dset.add_data(key='myelin', data=mye) 86 | dset.add_mask(key='myelin_mask', data=msk) 87 | dset.add_data(key='blood_vessel', data=blv) 88 | dset.add_mask(key='blood_vessel_mask', data=msk) 89 | 90 | return dset 91 | 92 | def get_mask(self, data): 93 | key = 'msk_train' if self.is_train else 'msk_val' 94 | if key in data: 95 | return data[key] 96 | assert 'msk' in data 97 | return data['msk'] 98 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_mye_blv1_fld0.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from augmentor import Augment 4 | from dataprovider3 import DataProvider, Dataset 5 | 6 | 7 | def get_spec(in_spec, out_spec): 8 | spec = dict() 9 | # Input spec 10 | for k, v in in_spec.items(): 11 | spec[k] = v[-3:] 12 | # Output spec 13 | for k, v in out_spec.items(): 14 | dim = tuple(v[-3:]) 15 | spec[k] = dim 16 | spec[k+'_mask'] = dim 17 | return spec 18 | 19 | 20 | class Sampler(object): 21 | def __init__(self, data, spec, is_train, aug=None, prob=None): 22 | self.is_train = is_train 23 | if 'long_range' in spec: 24 | self.long_range = True 25 | del spec['long_range'] 26 | del spec['long_range_mask'] 27 | else: 28 | self.long_range = False 29 | self.build(data, spec, aug, prob) 30 | 31 | def __call__(self): 32 | sample = self.dataprovider() 33 | return self.postprocess(sample) 34 | 35 | def postprocess(self, sample): 36 | assert 'affinity' in sample 37 | assert 'myelin' in sample 38 | assert 'blood_vessel' in sample 39 | 40 | # TODO: Copy or Ref? 
41 | if self.long_range: 42 | sample['long_range'] = sample['affinity'] 43 | sample['long_range_mask'] = sample['affinity_mask'] 44 | 45 | # Blood vessel 46 | blv = sample['blood_vessel'] 47 | sample['blood_vessel'] = (blv == 2) # Endothelia 48 | 49 | sample = Augment.to_tensor(sample) 50 | return self.to_float32(sample) 51 | 52 | def to_float32(self, sample): 53 | for k, v in sample.items(): 54 | sample[k] = v.astype('float32') 55 | return sample 56 | 57 | def build(self, data, spec, aug, prob): 58 | dp = DataProvider(spec) 59 | keys = data.keys() 60 | for k in keys: 61 | dp.add_dataset(self.build_dataset(k, data[k])) 62 | dp.set_augment(aug) 63 | dp.set_imgs(['input']) 64 | dp.set_segs(['affinity']) 65 | if prob: 66 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 67 | else: 68 | dp.set_sampling_weights(p=None) 69 | self.dataprovider = dp 70 | print(dp) 71 | 72 | def build_dataset(self, tag, data): 73 | img = data['img'] 74 | seg = data['seg'] 75 | mye = data['mye'] 76 | blv = data['blv'] 77 | loc = data['loc'] 78 | msk = self.get_mask(data) 79 | 80 | # Create Dataset. 
81 | dset = Dataset(tag=tag) 82 | dset.add_data(key='input', data=img) 83 | dset.add_data(key='affinity', data=seg) 84 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 85 | dset.add_data(key='myelin', data=mye) 86 | dset.add_mask(key='myelin_mask', data=msk) 87 | dset.add_data(key='blood_vessel', data=blv) 88 | dset.add_mask(key='blood_vessel_mask', data=msk) 89 | 90 | return dset 91 | 92 | def get_mask(self, data): 93 | key = 'msk_train' if self.is_train else 'msk_val' 94 | if key in data: 95 | return data[key] 96 | assert 'msk' in data 97 | return data['msk'] 98 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_mye_blv2.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from augmentor import Augment 4 | from dataprovider3 import DataProvider, Dataset 5 | 6 | 7 | def get_spec(in_spec, out_spec): 8 | spec = dict() 9 | # Input spec 10 | for k, v in in_spec.items(): 11 | spec[k] = v[-3:] 12 | # Output spec 13 | for k, v in out_spec.items(): 14 | dim = tuple(v[-3:]) 15 | spec[k] = dim 16 | spec[k+'_mask'] = dim 17 | return spec 18 | 19 | 20 | class Sampler(object): 21 | def __init__(self, data, spec, is_train, aug=None, prob=None): 22 | self.is_train = is_train 23 | if 'long_range' in spec: 24 | self.long_range = True 25 | del spec['long_range'] 26 | del spec['long_range_mask'] 27 | else: 28 | self.long_range = False 29 | self.build(data, spec, aug, prob) 30 | 31 | def __call__(self): 32 | sample = self.dataprovider() 33 | return self.postprocess(sample) 34 | 35 | def postprocess(self, sample): 36 | assert 'affinity' in sample 37 | assert 'myelin' in sample 38 | assert 'blood_vessel' in sample 39 | 40 | # TODO: Copy or Ref? 
41 | if self.long_range: 42 | sample['long_range'] = sample['affinity'] 43 | sample['long_range_mask'] = sample['affinity_mask'] 44 | 45 | # Blood vessel 46 | blv = sample['blood_vessel'] 47 | sample['blood_vessel'] = (blv > 0) # Endothelia + lumen 48 | 49 | sample = Augment.to_tensor(sample) 50 | return self.to_float32(sample) 51 | 52 | def to_float32(self, sample): 53 | for k, v in sample.items(): 54 | sample[k] = v.astype('float32') 55 | return sample 56 | 57 | def build(self, data, spec, aug, prob): 58 | dp = DataProvider(spec) 59 | keys = data.keys() 60 | for k in keys: 61 | dp.add_dataset(self.build_dataset(k, data[k])) 62 | dp.set_augment(aug) 63 | dp.set_imgs(['input']) 64 | dp.set_segs(['affinity']) 65 | if prob: 66 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 67 | else: 68 | dp.set_sampling_weights(p=None) 69 | self.dataprovider = dp 70 | print(dp) 71 | 72 | def build_dataset(self, tag, data): 73 | img = data['img'] 74 | seg = data['seg'] 75 | mye = data['mye'] 76 | blv = data['blv'] 77 | loc = data['loc'] 78 | msk = self.get_mask(data) 79 | 80 | # Create Dataset. 
81 | dset = Dataset(tag=tag) 82 | dset.add_data(key='input', data=img) 83 | dset.add_data(key='affinity', data=seg) 84 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 85 | dset.add_data(key='myelin', data=mye) 86 | dset.add_mask(key='myelin_mask', data=msk) 87 | dset.add_data(key='blood_vessel', data=blv) 88 | dset.add_mask(key='blood_vessel_mask', data=msk) 89 | 90 | return dset 91 | 92 | def get_mask(self, data): 93 | key = 'msk_train' if self.is_train else 'msk_val' 94 | if key in data: 95 | return data[key] 96 | assert 'msk' in data 97 | return data['msk'] 98 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_psd_mye.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | if 'long_range' in spec: 22 | self.long_range = True 23 | del spec['long_range'] 24 | del spec['long_range_mask'] 25 | else: 26 | self.long_range = False 27 | self.build(data, spec, aug, prob) 28 | 29 | def __call__(self): 30 | sample = self.dataprovider() 31 | return self.postprocess(sample) 32 | 33 | def postprocess(self, sample): 34 | assert 'affinity' in sample 35 | assert 'synapse' in sample 36 | assert 'myelin' in sample 37 | 38 | # TODO: Copy or Ref? 
39 | if self.long_range: 40 | sample['long_range'] = sample['affinity'] 41 | sample['long_range_mask'] = sample['affinity_mask'] 42 | 43 | sample = Augment.to_tensor(sample) 44 | return self.to_float32(sample) 45 | 46 | def to_float32(self, sample): 47 | for k, v in sample.items(): 48 | sample[k] = v.astype('float32') 49 | return sample 50 | 51 | def build(self, data, spec, aug, prob): 52 | dp = DataProvider(spec) 53 | keys = data.keys() 54 | for k in keys: 55 | dp.add_dataset(self.build_dataset(k, data[k])) 56 | dp.set_augment(aug) 57 | dp.set_imgs(['input']) 58 | dp.set_segs(['affinity']) 59 | if prob: 60 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 61 | else: 62 | dp.set_sampling_weights(p=None) 63 | self.dataprovider = dp 64 | print(dp) 65 | 66 | def build_dataset(self, tag, data): 67 | img = data['img'] 68 | seg = data['seg'] 69 | psd = data['psd'] 70 | psd_msk = data['psd_msk'] 71 | mye = data['mye'] 72 | loc = data['loc'] 73 | msk = self.get_mask(data) 74 | 75 | # Create Dataset. 
76 | dset = Dataset(tag=tag) 77 | dset.add_data(key='input', data=img) 78 | dset.add_data(key='affinity', data=seg) 79 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 80 | dset.add_data(key='synapse', data=psd) 81 | dset.add_mask(key='synapse_mask', data=psd_msk) 82 | dset.add_data(key='myelin', data=mye) 83 | dset.add_mask(key='myelin_mask', data=msk) 84 | 85 | return dset 86 | 87 | def get_mask(self, data): 88 | key = 'msk_train' if self.is_train else 'msk_val' 89 | if key in data: 90 | return data[key] 91 | assert 'msk' in data 92 | return data['msk'] 93 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_psd_mye_blv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from augmentor import Augment 4 | from dataprovider3 import DataProvider, Dataset 5 | 6 | 7 | def get_spec(in_spec, out_spec): 8 | spec = dict() 9 | # Input spec 10 | for k, v in in_spec.items(): 11 | spec[k] = v[-3:] 12 | # Output spec 13 | for k, v in out_spec.items(): 14 | dim = tuple(v[-3:]) 15 | spec[k] = dim 16 | spec[k+'_mask'] = dim 17 | return spec 18 | 19 | 20 | class Sampler(object): 21 | def __init__(self, data, spec, is_train, aug=None, prob=None): 22 | self.is_train = is_train 23 | if 'long_range' in spec: 24 | self.long_range = True 25 | del spec['long_range'] 26 | del spec['long_range_mask'] 27 | else: 28 | self.long_range = False 29 | self.build(data, spec, aug, prob) 30 | 31 | def __call__(self): 32 | sample = self.dataprovider() 33 | return self.postprocess(sample) 34 | 35 | def postprocess(self, sample): 36 | assert 'affinity' in sample 37 | assert 'synapse' in sample 38 | assert 'myelin' in sample 39 | assert 'blood_vessel' in sample 40 | 41 | # TODO: Copy or Ref? 
42 | if self.long_range: 43 | sample['long_range'] = sample['affinity'] 44 | sample['long_range_mask'] = sample['affinity_mask'] 45 | 46 | # Blood vessel 47 | blv = sample['blood_vessel'] 48 | blv = np.concatenate((blv==1, blv==2), axis=-4) 49 | blv_msk = sample['blood_vessel_mask'] 50 | blv_msk = np.concatenate((blv_msk, blv_msk), axis=-4) 51 | sample['blood_vessel'] = blv 52 | sample['blood_vessel_mask'] = blv_msk 53 | 54 | sample = Augment.to_tensor(sample) 55 | return self.to_float32(sample) 56 | 57 | def to_float32(self, sample): 58 | for k, v in sample.items(): 59 | sample[k] = v.astype('float32') 60 | return sample 61 | 62 | def build(self, data, spec, aug, prob): 63 | dp = DataProvider(spec) 64 | keys = data.keys() 65 | for k in keys: 66 | dp.add_dataset(self.build_dataset(k, data[k])) 67 | dp.set_augment(aug) 68 | dp.set_imgs(['input']) 69 | dp.set_segs(['affinity']) 70 | if prob: 71 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 72 | else: 73 | dp.set_sampling_weights(p=None) 74 | self.dataprovider = dp 75 | print(dp) 76 | 77 | def build_dataset(self, tag, data): 78 | img = data['img'] 79 | seg = data['seg'] 80 | psd = data['psd'] 81 | psd_msk = data['psd_msk'] 82 | mye = data['mye'] 83 | blv = data['blv'] 84 | loc = data['loc'] 85 | msk = self.get_mask(data) 86 | 87 | # Create Dataset. 
88 | dset = Dataset(tag=tag) 89 | dset.add_data(key='input', data=img) 90 | dset.add_data(key='affinity', data=seg) 91 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 92 | dset.add_data(key='synapse', data=psd) 93 | dset.add_mask(key='synapse_mask', data=psd_msk) 94 | dset.add_data(key='myelin', data=mye) 95 | dset.add_mask(key='myelin_mask', data=msk) 96 | dset.add_data(key='blood_vessel', data=blv) 97 | dset.add_mask(key='blood_vessel_mask', data=msk) 98 | 99 | return dset 100 | 101 | def get_mask(self, data): 102 | key = 'msk_train' if self.is_train else 'msk_val' 103 | if key in data: 104 | return data[key] 105 | assert 'msk' in data 106 | return data['msk'] 107 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_syn_mye.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | if 'long_range' in spec: 22 | self.long_range = True 23 | del spec['long_range'] 24 | del spec['long_range_mask'] 25 | else: 26 | self.long_range = False 27 | self.build(data, spec, aug, prob) 28 | 29 | def __call__(self): 30 | sample = self.dataprovider() 31 | return self.postprocess(sample) 32 | 33 | def postprocess(self, sample): 34 | assert 'affinity' in sample 35 | assert 'synapse' in sample 36 | assert 'myelin' in sample 37 | 38 | # TODO: Copy or Ref? 
39 | if self.long_range: 40 | sample['long_range'] = sample['affinity'] 41 | sample['long_range_mask'] = sample['affinity_mask'] 42 | 43 | sample = Augment.to_tensor(sample) 44 | return self.to_float32(sample) 45 | 46 | def to_float32(self, sample): 47 | for k, v in sample.items(): 48 | sample[k] = v.astype('float32') 49 | return sample 50 | 51 | def build(self, data, spec, aug, prob): 52 | dp = DataProvider(spec) 53 | keys = data.keys() 54 | for k in keys: 55 | dp.add_dataset(self.build_dataset(k, data[k])) 56 | dp.set_augment(aug) 57 | dp.set_imgs(['input']) 58 | dp.set_segs(['affinity']) 59 | if prob: 60 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 61 | else: 62 | dp.set_sampling_weights(p=None) 63 | self.dataprovider = dp 64 | print(dp) 65 | 66 | def build_dataset(self, tag, data): 67 | img = data['img'] 68 | seg = data['seg'] 69 | syn = data['syn'] 70 | mye = data['mye'] 71 | loc = data['loc'] 72 | msk = self.get_mask(data) 73 | 74 | # Create Dataset. 75 | dset = Dataset(tag=tag) 76 | dset.add_data(key='input', data=img) 77 | dset.add_data(key='affinity', data=seg) 78 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 79 | dset.add_data(key='synapse', data=syn) 80 | dset.add_mask(key='synapse_mask', data=msk) 81 | dset.add_data(key='myelin', data=mye) 82 | dset.add_mask(key='myelin_mask', data=msk) 83 | 84 | return dset 85 | 86 | def get_mask(self, data): 87 | key = 'msk_train' if self.is_train else 'msk_val' 88 | if key in data: 89 | return data[key] 90 | assert 'msk' in data 91 | return data['msk'] 92 | -------------------------------------------------------------------------------- /deepem/data/sampler/aff_syn_mye_blv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from augmentor import Augment 4 | from dataprovider3 import DataProvider, Dataset 5 | 6 | 7 | def get_spec(in_spec, out_spec): 8 | spec = dict() 9 | # Input spec 10 | for k, v in in_spec.items(): 11 | spec[k] = 
v[-3:] 12 | # Output spec 13 | for k, v in out_spec.items(): 14 | dim = tuple(v[-3:]) 15 | spec[k] = dim 16 | spec[k+'_mask'] = dim 17 | return spec 18 | 19 | 20 | class Sampler(object): 21 | def __init__(self, data, spec, is_train, aug=None, prob=None): 22 | self.is_train = is_train 23 | if 'long_range' in spec: 24 | self.long_range = True 25 | del spec['long_range'] 26 | del spec['long_range_mask'] 27 | else: 28 | self.long_range = False 29 | self.build(data, spec, aug, prob) 30 | 31 | def __call__(self): 32 | sample = self.dataprovider() 33 | return self.postprocess(sample) 34 | 35 | def postprocess(self, sample): 36 | assert 'affinity' in sample 37 | assert 'synapse' in sample 38 | assert 'myelin' in sample 39 | assert 'blood_vessel' in sample 40 | 41 | # TODO: Copy or Ref? 42 | if self.long_range: 43 | sample['long_range'] = sample['affinity'] 44 | sample['long_range_mask'] = sample['affinity_mask'] 45 | 46 | # Blood vessel 47 | blv = sample['blood_vessel'] 48 | blv = np.concatenate((blv==1, blv==2), axis=-4) 49 | blv_msk = sample['blood_vessel_mask'] 50 | blv_msk = np.concatenate((blv_msk, blv_msk), axis=-4) 51 | sample['blood_vessel'] = blv 52 | sample['blood_vessel_mask'] = blv_msk 53 | 54 | sample = Augment.to_tensor(sample) 55 | return self.to_float32(sample) 56 | 57 | def to_float32(self, sample): 58 | for k, v in sample.items(): 59 | sample[k] = v.astype('float32') 60 | return sample 61 | 62 | def build(self, data, spec, aug, prob): 63 | dp = DataProvider(spec) 64 | keys = data.keys() 65 | for k in keys: 66 | dp.add_dataset(self.build_dataset(k, data[k])) 67 | dp.set_augment(aug) 68 | dp.set_imgs(['input']) 69 | dp.set_segs(['affinity']) 70 | if prob: 71 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 72 | else: 73 | dp.set_sampling_weights(p=None) 74 | self.dataprovider = dp 75 | print(dp) 76 | 77 | def build_dataset(self, tag, data): 78 | img = data['img'] 79 | seg = data['seg'] 80 | syn = data['syn'] 81 | mye = data['mye'] 82 | blv = data['blv'] 
83 | loc = data['loc'] 84 | msk = self.get_mask(data) 85 | 86 | # Create Dataset. 87 | dset = Dataset(tag=tag) 88 | dset.add_data(key='input', data=img) 89 | dset.add_data(key='affinity', data=seg) 90 | dset.add_mask(key='affinity_mask', data=msk, loc=loc) 91 | dset.add_data(key='synapse', data=syn) 92 | dset.add_mask(key='synapse_mask', data=msk) 93 | dset.add_data(key='myelin', data=mye) 94 | dset.add_mask(key='myelin_mask', data=msk) 95 | dset.add_data(key='blood_vessel', data=blv) 96 | dset.add_mask(key='blood_vessel_mask', data=msk) 97 | 98 | return dset 99 | 100 | def get_mask(self, data): 101 | key = 'msk_train' if self.is_train else 'msk_val' 102 | if key in data: 103 | return data[key] 104 | assert 'msk' in data 105 | return data['msk'] 106 | -------------------------------------------------------------------------------- /deepem/data/sampler/mit.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | self.build(data, spec, aug, prob) 22 | 23 | def __call__(self): 24 | sample = self.dataprovider() 25 | return self.postprocess(sample) 26 | 27 | def postprocess(self, sample): 28 | assert 'mitochondria' in sample 29 | sample = Augment.to_tensor(sample) 30 | return self.to_float32(sample) 31 | 32 | def to_float32(self, sample): 33 | for k, v in sample.items(): 34 | sample[k] = v.astype('float32') 35 | return sample 36 | 37 | def build(self, data, spec, aug, prob): 38 | dp = DataProvider(spec) 39 | keys = data.keys() 40 | for k in keys: 41 
| dp.add_dataset(self.build_dataset(k, data[k])) 42 | dp.set_augment(aug) 43 | dp.set_imgs(['input']) 44 | dp.set_segs([]) 45 | if prob: 46 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 47 | else: 48 | dp.set_sampling_weights(p=None) 49 | self.dataprovider = dp 50 | print(dp) 51 | 52 | def build_dataset(self, tag, data): 53 | img = data['img'] 54 | mit = data['mit'] 55 | loc = data['loc'] 56 | msk = self.get_mask(data) 57 | 58 | # Create Dataset. 59 | dset = Dataset(tag=tag) 60 | dset.add_data(key='input', data=img) 61 | dset.add_data(key='mitochondria', data=mit) 62 | dset.add_mask(key='mitochondria_mask', data=msk, loc=loc) 63 | 64 | return dset 65 | 66 | def get_mask(self, data): 67 | key = 'msk_train' if self.is_train else 'msk_val' 68 | if key in data: 69 | return data[key] 70 | assert 'msk' in data 71 | return data['msk'] 72 | -------------------------------------------------------------------------------- /deepem/data/sampler/psd.py: -------------------------------------------------------------------------------- 1 | from augmentor import Augment 2 | from dataprovider3 import DataProvider, Dataset 3 | 4 | 5 | def get_spec(in_spec, out_spec): 6 | spec = dict() 7 | # Input spec 8 | for k, v in in_spec.items(): 9 | spec[k] = v[-3:] 10 | # Output spec 11 | for k, v in out_spec.items(): 12 | dim = tuple(v[-3:]) 13 | spec[k] = dim 14 | spec[k+'_mask'] = dim 15 | return spec 16 | 17 | 18 | class Sampler(object): 19 | def __init__(self, data, spec, is_train, aug=None, prob=None): 20 | self.is_train = is_train 21 | self.build(data, spec, aug, prob) 22 | 23 | def __call__(self): 24 | sample = self.dataprovider() 25 | return self.postprocess(sample) 26 | 27 | def postprocess(self, sample): 28 | assert 'synapse' in sample 29 | sample = Augment.to_tensor(sample) 30 | return self.to_float32(sample) 31 | 32 | def to_float32(self, sample): 33 | for k, v in sample.items(): 34 | sample[k] = v.astype('float32') 35 | return sample 36 | 37 | def build(self, data, spec, 
aug, prob): 38 | dp = DataProvider(spec) 39 | keys = data.keys() 40 | for k in keys: 41 | dp.add_dataset(self.build_dataset(k, data[k])) 42 | dp.set_augment(aug) 43 | dp.set_imgs(['input']) 44 | dp.set_segs([]) 45 | if prob: 46 | dp.set_sampling_weights(p=[prob[k] for k in keys]) 47 | else: 48 | dp.set_sampling_weights(p=None) 49 | self.dataprovider = dp 50 | print(dp) 51 | 52 | def build_dataset(self, tag, data): 53 | img = data['img'] 54 | psd = data['psd'] 55 | psd_msk = data['psd_msk'] 56 | loc = data['loc'] 57 | msk = self.get_mask(data) 58 | 59 | # Create Dataset. 60 | dset = Dataset(tag=tag) 61 | dset.add_data(key='input', data=img) 62 | dset.add_data(key='synapse', data=psd) 63 | dset.add_mask(key='synapse_mask', data=psd_msk, loc=loc) 64 | 65 | return dset 66 | 67 | def get_mask(self, data): 68 | key = 'msk_train' if self.is_train else 'msk_val' 69 | if key in data: 70 | return data[key] 71 | assert 'msk' in data 72 | return data['msk'] 73 | -------------------------------------------------------------------------------- /deepem/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .affinity import AffinityLoss 2 | from .loss import * 3 | -------------------------------------------------------------------------------- /deepem/loss/affinity.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from deepem.utils import torch_utils 7 | 8 | 9 | class EdgeSampler(object): 10 | def __init__(self, edges): 11 | self.edges = list(edges) 12 | 13 | def generate_edges(self): 14 | return list(self.edges) 15 | 16 | def generate_true_aff(self, obj, edge): 17 | o1, o2 = torch_utils.get_pair(obj, edge) 18 | ret = (((o1 == o2) + (o1 != 0) + (o2 != 0)) == 3) 19 | return ret.type(obj.type()) 20 | 21 | def generate_mask_aff(self, mask, edge): 22 | m1, m2 = torch_utils.get_pair(mask, edge) 23 | return (m1 * 
m2).type(mask.type()) 24 | 25 | 26 | class EdgeCRF(nn.Module): 27 | def __init__(self, criterion, size_average=False, class_balancing=False): 28 | super(EdgeCRF, self).__init__() 29 | self.criterion = criterion 30 | self.size_average = size_average 31 | self.balancing = class_balancing 32 | 33 | def forward(self, preds, targets, masks): 34 | assert len(preds) == len(targets) == len(masks) 35 | loss, nmsk = 0, 0 36 | for pred, target, mask in zip(preds, targets, masks): 37 | mask = self.class_balancing(target, mask) 38 | l, n = self.criterion(pred, target, mask) 39 | loss += l 40 | nmsk += n 41 | assert nmsk.item() >= 0 42 | 43 | if nmsk.item() == 0: 44 | loss = torch.tensor(0).type(torch.cuda.FloatTensor) 45 | return loss, nmsk 46 | 47 | if self.size_average: 48 | assert nmsk.item() > 0 49 | try: 50 | loss = loss / nmsk.item() 51 | nmsk = torch.tensor(1, dtype=nmsk.dtype, device=nmsk.device) 52 | except: 53 | import pdb; pdb.set_trace() 54 | raise 55 | 56 | return loss, nmsk 57 | 58 | def class_balancing(self, target, mask): 59 | if not self.balancing: 60 | return mask 61 | dtype = mask.type() 62 | m_int = mask * torch.eq(target, 1).type(dtype) 63 | m_ext = mask * torch.eq(target, 0).type(dtype) 64 | n_int = m_int.sum().item() 65 | n_ext = m_ext.sum().item() 66 | if n_int > 0 and n_ext > 0: 67 | m_int *= n_ext/(n_int + n_ext) 68 | m_ext *= n_int/(n_int + n_ext) 69 | return (m_int + m_ext).type(dtype) 70 | 71 | 72 | class AffinityLoss(nn.Module): 73 | def __init__(self, edges, criterion, size_average=False, 74 | class_balancing=False): 75 | super(AffinityLoss, self).__init__() 76 | self.sampler = EdgeSampler(edges) 77 | self.decoder = AffinityLoss.Decoder(edges) 78 | self.criterion = EdgeCRF( 79 | criterion, 80 | size_average=size_average, 81 | class_balancing=class_balancing 82 | ) 83 | 84 | def forward(self, preds, label, mask): 85 | pred_affs = list() 86 | true_affs = list() 87 | mask_affs = list() 88 | edges = self.sampler.generate_edges() 89 | for i, edge in 
enumerate(edges): 90 | try: 91 | pred_affs.append(self.decoder(preds, i)) 92 | true_affs.append(self.sampler.generate_true_aff(label, edge)) 93 | mask_affs.append(self.sampler.generate_mask_aff(mask, edge)) 94 | except: 95 | raise 96 | return self.criterion(pred_affs, true_affs, mask_affs) 97 | 98 | class Decoder(nn.Module): 99 | def __init__(self, edges): 100 | super(AffinityLoss.Decoder, self).__init__() 101 | assert len(edges) > 0 102 | self.edges = list(edges) 103 | 104 | def forward(self, x, i): 105 | num_channels = x.size(-4) 106 | assert num_channels == len(self.edges) 107 | assert i < num_channels and i >= 0 108 | edge = self.edges[i] 109 | return torch_utils.get_pair_first(x[...,[i],:,:,:], edge) 110 | -------------------------------------------------------------------------------- /deepem/loss/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | 8 | class BCELoss(nn.Module): 9 | """ 10 | Binary cross entropy loss with logits. 
11 | """ 12 | def __init__(self, size_average=True, margin0=0, margin1=0, inverse=True, 13 | **kwargs): 14 | super(BCELoss, self).__init__() 15 | self.bce = F.binary_cross_entropy_with_logits 16 | self.size_average = size_average 17 | self.margin0 = float(np.clip(margin0, 0, 1)) 18 | self.margin1 = float(np.clip(margin1, 0, 1)) 19 | self.inverse = inverse 20 | 21 | def forward(self, input, target, mask): 22 | # Number of valid voxels 23 | nmsk = (mask > 0).type(mask.dtype).sum() 24 | assert nmsk.item() >= 0 25 | if nmsk.item() == 0: 26 | loss = torch.tensor(0).type(torch.cuda.FloatTensor) 27 | return loss, nmsk 28 | 29 | # Margin 30 | m0, m1 = self.margin0, self.margin1 31 | if m0 > 0 or m1 > 0: 32 | if self.inverse: 33 | target[torch.eq(target, 1)] = 1 - m1 34 | target[torch.eq(target, 0)] = m0 35 | else: 36 | activ = torch.sigmoid(input) 37 | m_int = torch.ge(activ, 1 - m1) * torch.eq(target, 1) 38 | m_ext = torch.le(activ, m0) * torch.eq(target, 0) 39 | mask *= 1 - (m_int + m_ext).type(mask.dtype) 40 | 41 | loss = self.bce(input, target, weight=mask, size_average=False) 42 | 43 | if self.size_average: 44 | loss = loss / nmsk.item() 45 | nmsk = torch.tensor(1, dtype=nmsk.dtype, device=nmsk.device) 46 | 47 | return loss, nmsk 48 | 49 | 50 | class MSELoss(nn.Module): 51 | """ 52 | Mean squared error loss with (or without) logits. 
53 | """ 54 | def __init__(self, size_average=True, margin0=0, margin1=0, logits=True, 55 | **kwargs): 56 | super(MSELoss, self).__init__() 57 | self.mse = F.mse_loss 58 | self.size_average = size_average 59 | self.margin0 = float(np.clip(margin0, 0, 1)) 60 | self.margin1 = float(np.clip(margin1, 0, 1)) 61 | self.logits = logits 62 | 63 | def forward(self, input, target, mask): 64 | # Number of valid voxels 65 | nmsk = (mask > 0).type(mask.type()).sum() 66 | assert nmsk.item() >= 0 67 | if nmsk.item() == 0: 68 | loss = torch.tensor(0).type(torch.cuda.FloatTensor) 69 | return loss, nmsk 70 | 71 | activ = torch.sigmoid(input) if self.logits else input 72 | 73 | # Margin 74 | m0, m1 = self.margin0, self.margin1 75 | if m0 > 0 or m1 > 0: 76 | m_int = torch.ge(activ, 1 - m1) * torch.eq(target, 1) 77 | m_ext = torch.le(activ, m0) * torch.eq(target, 0) 78 | mask *= 1 - (m_int + m_ext).type(mask.dtype) 79 | 80 | loss = self.mse(activ, target, reduce=False) 81 | loss = (loss * mask).sum() 82 | 83 | if self.size_average: 84 | loss = loss / nmsk.item() 85 | nmsk = torch.tensor(1, dtype=nmsk.dtype, device=nmsk.device) 86 | 87 | return loss, nmsk 88 | -------------------------------------------------------------------------------- /deepem/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | from emvision.models.utils import pad_size 6 | 7 | from deepem.utils import torch_utils 8 | 9 | 10 | class Conv(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 12 | bias=False): 13 | super(Conv, self).__init__() 14 | padding = pad_size(kernel_size, 'same') 15 | self.conv = nn.Conv3d(in_channels, out_channels, 16 | kernel_size=kernel_size, stride=stride, padding=padding, bias=bias) 17 | nn.init.kaiming_normal_(self.conv.weight, nonlinearity='relu') 18 | if bias: 19 | nn.init.constant_(self.conv.bias, 0) 20 | 21 | def forward(self, x): 22 
| return self.conv(x) 23 | 24 | 25 | class Scale(nn.Module): 26 | def __init__(self, init_value=1.0): 27 | super(Scale, self).__init__() 28 | self.scale = nn.Parameter(torch.FloatTensor([init_value])) 29 | 30 | def forward(self, x): 31 | return x * self.scale 32 | 33 | 34 | class Crop(nn.Module): 35 | def __init__(self, cropsz): 36 | super(Crop, self).__init__() 37 | self.cropsz = tuple(cropsz) 38 | 39 | def forward(self, x): 40 | if self.cropsz is not None: 41 | for k, v in x.items(): 42 | cropsz = [int(v.shape[i]*self.cropsz[i]) for i in [-3,-2,-1]] 43 | x[k] = torch_utils.crop_center(v, cropsz) 44 | return x 45 | -------------------------------------------------------------------------------- /deepem/models/rsunet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | 6 | from deepem.models.layers import Conv 7 | 8 | 9 | def create_model(opt): 10 | if opt.width: 11 | width = opt.width 12 | depth = len(width) 13 | else: 14 | width = [16,32,64,128,256,512] 15 | depth = opt.depth 16 | if opt.group > 0: 17 | # Group normalization 18 | core = emvision.models.rsunet_gn(width=width[:depth], group=opt.group) 19 | else: 20 | # Batch (instance) normalization 21 | core = emvision.models.RSUNet(width=width[:depth]) 22 | return Model(core, opt.in_spec, opt.out_spec, width[0]) 23 | 24 | 25 | class InputBlock(nn.Sequential): 26 | def __init__(self, in_channels, out_channels, kernel_size): 27 | super(InputBlock, self).__init__() 28 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 29 | 30 | 31 | class OutputBlock(nn.Module): 32 | def __init__(self, in_channels, out_spec, kernel_size): 33 | super(OutputBlock, self).__init__() 34 | for k, v in out_spec.items(): 35 | out_channels = v[-4] 36 | self.add_module(k, 37 | Conv(in_channels, out_channels, kernel_size, bias=True)) 38 | 39 | def forward(self, x): 40 | return {k: m(x) for k, m in self.named_children()} 
41 | 42 | 43 | class Model(nn.Sequential): 44 | """ 45 | Residual Symmetric U-Net. 46 | """ 47 | def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5)): 48 | super(Model, self).__init__() 49 | 50 | assert len(in_spec)==1, "model takes a single input" 51 | in_channels = 1 52 | 53 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 54 | self.add_module('core', core) 55 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 56 | -------------------------------------------------------------------------------- /deepem/models/rsunet_act.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | from emvision.models import rsunet_act, rsunet_act_gn 6 | 7 | from deepem.models.layers import Conv, Crop 8 | 9 | 10 | def create_model(opt): 11 | if opt.width: 12 | width = opt.width 13 | depth = len(width) 14 | else: 15 | width = [16,32,64,128,256,512] 16 | depth = opt.depth 17 | if opt.group > 0: 18 | # Group normalization 19 | core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) 20 | else: 21 | # Batch normalization 22 | core = rsunet_act(width=width[:depth], act=opt.act) 23 | return Model(core, opt.in_spec, opt.out_spec, width[0], cropsz=opt.cropsz) 24 | 25 | 26 | class InputBlock(nn.Sequential): 27 | def __init__(self, in_channels, out_channels, kernel_size): 28 | super(InputBlock, self).__init__() 29 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 30 | 31 | 32 | class OutputBlock(nn.Module): 33 | def __init__(self, in_channels, out_spec, kernel_size): 34 | super(OutputBlock, self).__init__() 35 | for k, v in out_spec.items(): 36 | out_channels = v[-4] 37 | self.add_module(k, 38 | Conv(in_channels, out_channels, kernel_size, bias=True)) 39 | 40 | def forward(self, x): 41 | return {k: m(x) for k, m in self.named_children()} 42 | 43 | 44 | class Model(nn.Sequential): 45 | """ 46 
| Residual Symmetric U-Net. 47 | """ 48 | def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), 49 | cropsz=None): 50 | super(Model, self).__init__() 51 | 52 | assert len(in_spec)==1, "model takes a single input" 53 | in_channels = 1 54 | 55 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 56 | self.add_module('core', core) 57 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 58 | if cropsz is not None: 59 | self.add_module('crop', Crop(cropsz)) 60 | -------------------------------------------------------------------------------- /deepem/models/rsunet_deprecated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | 6 | from deepem.models.layers import Conv, Crop 7 | 8 | 9 | def create_model(opt): 10 | if opt.width: 11 | width = opt.width 12 | depth = len(width) 13 | else: 14 | width = [16,32,64,128,256,512] 15 | depth = opt.depth 16 | if opt.group > 0: 17 | # Group normalization 18 | core = emvision.models.rsunet_gn(width=width[:depth], group=opt.group) 19 | else: 20 | # Batch (instance) normalization 21 | core = emvision.models.RSUNet(width=width[:depth]) 22 | return Model(core, opt.in_spec, opt.out_spec, width[0], cropsz=opt.cropsz) 23 | 24 | 25 | class InputBlock(nn.Sequential): 26 | def __init__(self, in_channels, out_channels, kernel_size): 27 | super(InputBlock, self).__init__() 28 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 29 | 30 | 31 | class OutputBlock(nn.Module): 32 | def __init__(self, in_channels, out_spec, kernel_size): 33 | super(OutputBlock, self).__init__() 34 | for k, v in out_spec.items(): 35 | out_channels = v[-4] 36 | self.add_module(k, nn.Sequential( 37 | Conv(in_channels, out_channels, kernel_size, bias=True) 38 | )) 39 | 40 | def forward(self, x): 41 | return {k: m(x) for k, m in self.named_children()} 42 | 43 | 44 | class 
Model(nn.Sequential): 45 | """ 46 | Residual Symmetric U-Net. 47 | """ 48 | def __init__(self, core, in_spec, out_spec, out_channels, cropsz=None): 49 | super(Model, self).__init__() 50 | 51 | assert len(in_spec)==1, "model takes a single input" 52 | in_channels = 1 53 | io_kernel = (1,5,5) 54 | 55 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 56 | self.add_module('core', core) 57 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 58 | if cropsz is not None: 59 | self.add_module('crop', Crop(cropsz)) 60 | -------------------------------------------------------------------------------- /deepem/models/updown.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | 6 | from deepem.models.layers import Conv 7 | 8 | 9 | def create_model(opt): 10 | if opt.width: 11 | width = opt.width 12 | depth = len(width) 13 | else: 14 | width = [16,32,64,128,256,512] 15 | depth = opt.depth 16 | if opt.group > 0: 17 | # Group normalization 18 | core = emvision.models.rsunet_gn(width=width[:depth], group=opt.group) 19 | else: 20 | # Batch (instance) normalization 21 | core = emvision.models.RSUNet(width=width[:depth]) 22 | return Model(core, opt.in_spec, opt.out_spec, width[0]) 23 | 24 | 25 | class InputBlock(nn.Sequential): 26 | def __init__(self, in_channels, out_channels, kernel_size): 27 | super(InputBlock, self).__init__() 28 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 29 | 30 | 31 | class OutputBlock(nn.Module): 32 | def __init__(self, in_channels, out_spec, kernel_size): 33 | super(OutputBlock, self).__init__() 34 | for k, v in out_spec.items(): 35 | out_channels = v[-4] 36 | self.add_module(k, 37 | Conv(in_channels, out_channels, kernel_size, bias=True)) 38 | 39 | def forward(self, x): 40 | return {k: m(x) for k, m in self.named_children()} 41 | 42 | 43 | class DownBlock(nn.Sequential): 44 | def 
__init__(self, scale_factor=(1,2,2)): 45 | super(DownBlock, self).__init__() 46 | self.add_module('down', nn.AvgPool3d(scale_factor)) 47 | 48 | 49 | class UpBlock(nn.Module): 50 | def __init__(self, out_spec, scale_factor=(1,2,2)): 51 | super(UpBlock, self).__init__() 52 | for k, v in out_spec.items(): 53 | self.add_module(k, 54 | nn.Upsample(scale_factor=scale_factor, mode='trilinear')) 55 | 56 | def forward(self, x): 57 | return {k: m(x[k]) for k, m in self.named_children()} 58 | 59 | 60 | class Model(nn.Sequential): 61 | """ 62 | Residual Symmetric U-Net with down/upsampling in/output. 63 | """ 64 | def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), 65 | scale_factor=(1,2,2)): 66 | super(Model, self).__init__() 67 | 68 | assert len(in_spec)==1, "model takes a single input" 69 | in_channels = 1 70 | 71 | self.add_module('down', DownBlock(scale_factor=scale_factor)) 72 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 73 | self.add_module('core', core) 74 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 75 | self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) 76 | -------------------------------------------------------------------------------- /deepem/models/updown_act.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | from emvision.models import rsunet_act, rsunet_act_gn 6 | 7 | from deepem.models.layers import Conv, Crop 8 | 9 | 10 | def create_model(opt): 11 | if opt.width: 12 | width = opt.width 13 | depth = len(width) 14 | else: 15 | width = [16,32,64,128,256,512] 16 | depth = opt.depth 17 | if opt.group > 0: 18 | # Group normalization 19 | core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) 20 | else: 21 | # Batch normalization 22 | core = rsunet_act(width=width[:depth], act=opt.act) 23 | return Model(core, opt.in_spec, opt.out_spec, width[0], 
cropsz=opt.cropsz) 24 | 25 | 26 | class InputBlock(nn.Sequential): 27 | def __init__(self, in_channels, out_channels, kernel_size): 28 | super(InputBlock, self).__init__() 29 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 30 | 31 | 32 | class OutputBlock(nn.Module): 33 | def __init__(self, in_channels, out_spec, kernel_size): 34 | super(OutputBlock, self).__init__() 35 | for k, v in out_spec.items(): 36 | out_channels = v[-4] 37 | self.add_module(k, 38 | Conv(in_channels, out_channels, kernel_size, bias=True)) 39 | 40 | def forward(self, x): 41 | return {k: m(x) for k, m in self.named_children()} 42 | 43 | 44 | class DownBlock(nn.Sequential): 45 | def __init__(self, scale_factor=(1,2,2)): 46 | super(DownBlock, self).__init__() 47 | self.add_module('down', nn.AvgPool3d(scale_factor)) 48 | 49 | 50 | class UpBlock(nn.Module): 51 | def __init__(self, out_spec, scale_factor=(1,2,2)): 52 | super(UpBlock, self).__init__() 53 | for k, v in out_spec.items(): 54 | self.add_module(k, 55 | nn.Upsample(scale_factor=scale_factor, mode='trilinear')) 56 | 57 | def forward(self, x): 58 | return {k: m(x[k]) for k, m in self.named_children()} 59 | 60 | 61 | class Model(nn.Sequential): 62 | """ 63 | Residual Symmetric U-Net with down/upsampling in/output. 
64 | """ 65 | def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), 66 | scale_factor=(1,2,2), cropsz=None): 67 | super(Model, self).__init__() 68 | 69 | assert len(in_spec)==1, "model takes a single input" 70 | in_channels = 1 71 | 72 | self.add_module('down', DownBlock(scale_factor=scale_factor)) 73 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 74 | self.add_module('core', core) 75 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 76 | self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) 77 | if cropsz is not None: 78 | self.add_module('crop', Crop(cropsz)) 79 | -------------------------------------------------------------------------------- /deepem/models/updown_deprecated.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import emvision 5 | from emvision.models.layers import BilinearUp 6 | from deepem.models.layers import Conv, Crop 7 | 8 | 9 | def create_model(opt): 10 | if opt.width: 11 | width = opt.width 12 | depth = len(width) 13 | else: 14 | width = [16,32,64,128,256,512] 15 | depth = opt.depth 16 | if opt.group > 0: 17 | # Group normalization 18 | core = emvision.models.rsunet_gn(width=width[:depth], group=opt.group) 19 | else: 20 | # Batch (instance) normalization 21 | core = emvision.models.RSUNet(width=width[:depth]) 22 | return Model(core, opt.in_spec, opt.out_spec, width[0], cropsz=opt.cropsz) 23 | 24 | 25 | class InputBlock(nn.Sequential): 26 | def __init__(self, in_channels, out_channels, kernel_size): 27 | super(InputBlock, self).__init__() 28 | self.add_module('down', nn.AvgPool3d((1,2,2))) 29 | self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) 30 | 31 | 32 | class OutputBlock(nn.Module): 33 | def __init__(self, in_channels, out_spec, kernel_size): 34 | super(OutputBlock, self).__init__() 35 | for k, v in out_spec.items(): 36 | out_channels = v[-4] 37 | 
self.add_module(k, nn.Sequential( 38 | Conv(in_channels, out_channels, kernel_size, bias=True), 39 | BilinearUp(out_channels, out_channels) 40 | )) 41 | 42 | def forward(self, x): 43 | return {k: m(x) for k, m in self.named_children()} 44 | 45 | 46 | class Model(nn.Sequential): 47 | """ 48 | Residual Symmetric U-Net with down/upsampling in/output. 49 | """ 50 | def __init__(self, core, in_spec, out_spec, out_channels, cropsz=None): 51 | super(Model, self).__init__() 52 | 53 | assert len(in_spec)==1, "model takes a single input" 54 | in_channels = 1 55 | io_kernel = (1,5,5) 56 | 57 | self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) 58 | self.add_module('core', core) 59 | self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) 60 | if cropsz is not None: 61 | self.add_module('crop', Crop(cropsz)) 62 | -------------------------------------------------------------------------------- /deepem/test/cv_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import cloudvolume as cv 4 | from cloudvolume.lib import Vec, Bbox 5 | from taskqueue import LocalTaskQueue 6 | 7 | from deepem.utils import py_utils 8 | 9 | 10 | def make_info(num_channels, layer_type, dtype, shape, resolution, 11 | offset=(0,0,0), chunk_size=(64,64,64)): 12 | return cv.CloudVolume.create_new_info( 13 | num_channels, layer_type, dtype, 'raw', resolution, offset, shape, 14 | chunk_size=chunk_size) 15 | 16 | 17 | def cutout(opt, gs_path, dtype='uint8'): 18 | if '{}' in gs_path: 19 | gs_path = gs_path.format(*opt.keywords) 20 | print(gs_path) 21 | 22 | # CloudVolume. 
23 | cvol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, 24 | fill_missing=True, parallel=opt.parallel) 25 | 26 | # Cutout 27 | offset0 = cvol.mip_voxel_offset(0) 28 | if opt.center is not None: 29 | assert opt.size is not None 30 | opt.begin = tuple(x - (y//2) for x, y in zip(opt.center, opt.size)) 31 | opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) 32 | else: 33 | if not opt.begin: 34 | opt.begin = offset0 35 | if not opt.end: 36 | if not opt.size: 37 | opt.end = offset0 + cvol.mip_volume_size(0) 38 | else: 39 | opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) 40 | sl = [slice(x,y) for x, y in zip(opt.begin, opt.end)] 41 | print('begin = {}'.format(opt.begin)) 42 | print('end = {}'.format(opt.end)) 43 | 44 | # Coordinates 45 | print('mip 0 = {}'.format(sl)) 46 | sl = cvol.slices_from_global_coords(sl) 47 | print('mip {} = {}'.format(opt.in_mip, sl)) 48 | cutout = cvol[sl] 49 | 50 | # Transpose & squeeze 51 | cutout = cutout.transpose([3,2,1,0]) 52 | cutout = np.squeeze(cutout).astype(dtype) 53 | return cutout 54 | 55 | 56 | def ingest(data, opt, tag=None): 57 | # Neuroglancer format 58 | data = py_utils.to_tensor(data) 59 | data = data.transpose((3,2,1,0)) 60 | num_channels = data.shape[-1] 61 | shape = data.shape[:-1] 62 | 63 | # Offset 64 | if opt.offset is None: 65 | opt.offset = opt.begin 66 | 67 | # MIP level correction 68 | if opt.gs_input and opt.in_mip > 0: 69 | o = opt.offset 70 | p = pow(2,opt.in_mip) 71 | offset = (o[0]//p, o[1]//p, o[2]) 72 | else: 73 | offset = opt.offset 74 | 75 | # Patch offset correction (when output patch is smaller than input patch) 76 | patch_offset = (np.array(opt.inputsz) - np.array(opt.outputsz)) // 2 77 | offset = tuple(np.array(offset) + np.flip(patch_offset, 0)) 78 | 79 | # Create info 80 | info = make_info(num_channels, 'image', str(data.dtype), shape, 81 | opt.resolution, offset=offset, chunk_size=opt.chunk_size) 82 | print(info) 83 | gs_path = opt.gs_output 84 | if '{}' in 
opt.gs_output: 85 | if opt.keywords: 86 | gs_path = gs_path.format(*opt.keywords) 87 | else: 88 | if opt.center is not None: 89 | coord = "x{}_y{}_z{}".format(*opt.center) 90 | coord += "_s{}-{}-{}".format(*opt.size) 91 | else: 92 | coord = '_'.join(['{}-{}'.format(b,e) for b,e in zip(opt.begin,opt.end)]) 93 | gs_path = gs_path.format(coord) 94 | 95 | # Tagging 96 | if tag is not None: 97 | if gs_path[-1] == '/': 98 | gs_path += tag 99 | else: 100 | gs_path += ('/' + tag) 101 | 102 | print("gs_output:\n{}".format(gs_path)) 103 | cvol = cv.CloudVolume(gs_path, mip=0, info=info, 104 | parallel=opt.parallel) 105 | cvol[:,:,:,:] = data 106 | cvol.commit_info() 107 | 108 | # Downsample 109 | if opt.downsample: 110 | import igneous 111 | from igneous.task_creation import create_downsampling_tasks 112 | 113 | with LocalTaskQueue(parallel=opt.parallel) as tq: 114 | # create_downsampling_tasks(tq, gs_path, mip=0, fill_missing=True) 115 | tasks = create_downsampling_tasks(gs_path, mip=0, fill_missing=True) 116 | tq.insert_all(tasks) 117 | -------------------------------------------------------------------------------- /deepem/test/forward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | import torch 5 | 6 | from dataprovider3 import Dataset, ForwardScanner 7 | 8 | from deepem.test import fwd_utils 9 | 10 | 11 | class Forward(object): 12 | """ 13 | Forward scanning. 
14 | """ 15 | def __init__(self, opt): 16 | self.device = opt.device 17 | self.in_spec = dict(opt.in_spec) 18 | self.out_spec = dict(opt.out_spec) 19 | self.scan_spec = dict(opt.scan_spec) 20 | self.scan_params = dict(opt.scan_params) 21 | self.test_aug = opt.test_aug 22 | self.variance = opt.variance 23 | self.precomputed = (opt.blend == 'precomputed') 24 | 25 | def __call__(self, model, scanner): 26 | dataset = scanner.dataset 27 | 28 | # Test-time augmentation 29 | if self.test_aug: 30 | 31 | # For variance computation 32 | if self.variance: 33 | aug_out = dict() 34 | for k, v in scanner.outputs.data.items(): 35 | aug_out[k] = list() 36 | else: 37 | aug_out = None 38 | 39 | count = 0.0 40 | for aug in self.test_aug: 41 | # dec2bin 42 | rule = np.array([int(x) for x in bin(aug)[2:].zfill(4)]) 43 | print("Test-time augmentation {}".format(rule)) 44 | 45 | # Augment dataset. 46 | aug_dset = Dataset(spec=self.in_spec) 47 | for k, v in dataset.data.items(): 48 | aug_dset.add_data(k, fwd_utils.flip(v._data, rule=rule)) 49 | 50 | # Forward scan 51 | aug_scanner = self.make_forward_scanner(aug_dset) 52 | outputs = self.forward(model, aug_scanner) 53 | 54 | # Accumulate. 55 | for k, v in scanner.outputs.data.items(): 56 | print("Accumulate to {}...".format(k)) 57 | output = outputs.get_data(k) 58 | 59 | # Revert output. 60 | dst = (1,1,1) if k == 'affinity' else None 61 | reverted = fwd_utils.revert_flip(output, rule=rule, dst=dst) 62 | v._data += reverted 63 | 64 | # For variance computation 65 | if self.variance: 66 | aug_out[k].append(reverted) 67 | 68 | count += 1 69 | 70 | # Normalize. 71 | for k, v in scanner.outputs.data.items(): 72 | print("Normalize {}...".format(k)) 73 | if self.precomputed: 74 | v._data[...] /= count 75 | else: 76 | v._norm._data[...] 
= count 77 | 78 | return (scanner.outputs, aug_out) 79 | 80 | return (self.forward(model, scanner), None) 81 | 82 | #################################################################### 83 | ## Non-interface functions 84 | #################################################################### 85 | 86 | def forward(self, model, scanner): 87 | elapsed = list() 88 | t0 = time.time() 89 | with torch.no_grad(): 90 | inputs = scanner.pull() 91 | while inputs: 92 | inputs = self.to_torch(inputs) 93 | 94 | # Forward pass 95 | outputs = model(inputs) 96 | scanner.push(self.from_torch(outputs)) 97 | 98 | # Elapsed time 99 | elapsed.append(time.time() - t0) 100 | print("Elapsed: %.3f s" % elapsed[-1]) 101 | t0 = time.time() 102 | 103 | # Fetch next inputs 104 | inputs = scanner.pull() 105 | 106 | print("Elapsed: %.3f s/patch" % (sum(elapsed)/len(elapsed))) 107 | print("Throughput: %d voxel/s" % round(scanner.voxels()/sum(elapsed))) 108 | return scanner.outputs 109 | 110 | def to_torch(self, sample): 111 | inputs = dict() 112 | for k in sorted(self.in_spec): 113 | data = np.expand_dims(sample[k], axis=0) 114 | tensor = torch.from_numpy(data) 115 | inputs[k] = tensor.to(self.device) 116 | return inputs 117 | 118 | def from_torch(self, outputs): 119 | ret = dict() 120 | for k in sorted(self.out_spec): 121 | if k in self.scan_spec: 122 | scan_channels = self.scan_spec[k][-4] 123 | narrowed = outputs[k].narrow(1, 0, scan_channels) 124 | ret[k] = np.squeeze(narrowed.cpu().numpy(), axis=(0,)) 125 | return ret 126 | 127 | def make_forward_scanner(self, dataset): 128 | return ForwardScanner(dataset, self.scan_spec, **self.scan_params) 129 | -------------------------------------------------------------------------------- /deepem/test/fwd_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from deepem.utils import py_utils 4 | 5 | 6 | class Flip(object): 7 | def __call__(self, data, rule): 8 | """Flip data according to a 
specified rule. 9 | 10 | Args: 11 | data: 4D numpy array to be transformed. 12 | rule: Transform rule, specified as a Boolean array. 13 | [z-flip, y-flip, x-flip, xy-transpose] 14 | 15 | Returns: 16 | Transformed data. 17 | """ 18 | data = py_utils.to_tensor(data) 19 | assert np.size(rule)==4 20 | 21 | # z-flip 22 | if rule[0]: 23 | data = np.flip(data, axis=-3) 24 | # y-flip 25 | if rule[1]: 26 | data = np.flip(data, axis=-2) 27 | # x-flip 28 | if rule[2]: 29 | data = np.flip(data, axis=-1) 30 | # xy-transpose 31 | if rule[3]: 32 | data = data.transpose(0,1,3,2) 33 | 34 | # Prevent potential negative stride issues by copying. 35 | return np.copy(data) 36 | 37 | flip = Flip() 38 | 39 | 40 | def revert_flip(data, rule, dst=None): 41 | data = py_utils.to_tensor(data) 42 | assert np.size(rule)==4 43 | 44 | # Special treat for affinity. 45 | is_affinity = dst is not None 46 | if is_affinity: 47 | (dz,dy,dx) = dst 48 | assert data.shape[-4] >= 3 49 | assert dx and abs(dx) < data.shape[-1] 50 | assert dy and abs(dy) < data.shape[-2] 51 | assert dz and abs(dz) < data.shape[-3] 52 | 53 | # xy-transpose 54 | if rule[3]: 55 | data = data.transpose(0,1,3,2) 56 | # Swap x/y-affinity maps. 57 | if is_affinity: 58 | data[[0,1],...] = data[[1,0],...] 59 | 60 | # x-flip 61 | if rule[2]: 62 | data = np.flip(data, axis=-1) 63 | # Special treatment for x-affinity. 64 | if is_affinity: 65 | if dx > 0: 66 | data[0,:,:,dx:] = data[0,:,:,:-dx] 67 | data[0,:,:,:dx].fill(0) 68 | else: 69 | dx = abs(dx) 70 | data[0,:,:,:-dx] = data[0,:,:,dx:] 71 | data[0,:,:,-dx:].fill(0) 72 | 73 | # y-flip 74 | if rule[1]: 75 | data = np.flip(data, axis=-2) 76 | # Special treatment for y-affinity. 
import numpy as np

"""
Blending masks for overlapping-patch inference.

Adapted from https://github.com/seung-lab/chunkflow/.
"""

class PatchMask(np.ndarray):
    """Normalized blending weight for a single output patch."""
    def __new__(cls, patch_size, overlap):
        assert len(patch_size) == 3
        assert len(overlap) == 3
        mask = make_mask(patch_size, overlap)
        return np.asarray(mask).view(cls)


class AffinityMask(np.ndarray):
    """Per-edge blending weights for affinity outputs (one mask per edge)."""
    def __new__(cls, patch_size, overlap, edges, bump):
        assert len(patch_size) == 3
        assert len(overlap) == 3
        assert len(edges) > 0
        assert bump in ['zung', 'wu', 'wu_no_crust']
        masks = [make_mask(patch_size, overlap, edge=e, bump=bump) for e in edges]
        mask = np.stack(masks)
        return np.asarray(mask).view(cls)


def make_mask(patch_size, overlap, edge=None, bump='zung'):
    """Compute a normalized blending weight for one patch.

    The weight is normalized against the sum of bump maps of a 3x3x3 grid
    of overlapping neighbor patches, so overlapping contributions sum to 1.

    Args:
        patch_size: (z, y, x) patch size.
        overlap:    (z, y, x) overlap between adjacent patches.
        edge:       Optional affinity edge; edge-shifted voxels are zeroed.
        bump:       'zung', 'wu', or 'wu_no_crust'.

    Returns:
        float32 array of shape ``patch_size``.

    Raises:
        ValueError: if ``bump`` is not a recognized bump type.
    """
    # Stride between adjacent patches.
    stride = tuple(p - o for p, o in zip(patch_size, overlap))

    # Slices of a 3x3x3 grid of overlapping patches.
    offsets = [(z, y, x) for z in range(3) for y in range(3) for x in range(3)]
    slices = [tuple(slice(o * st, o * st + p)
                    for o, st, p in zip(off, stride, patch_size))
              for off in offsets]

    # Shape of the 3x3x3 overlapping grid, and the central patch region.
    shape = tuple(p + 2 * st for p, st in zip(patch_size, stride))
    base_mask = np.zeros(shape, dtype=np.float64)
    center = tuple(slice(st, st + p) for st, p in zip(stride, patch_size))

    if bump == 'zung':
        # Max logit over all overlapping patches (log-sum-exp stabilization).
        max_logit = np.full(shape, -np.inf, dtype=np.float64)
        logit = bump_logit_map(patch_size)
        for s in slices:
            max_logit[s] = np.maximum(max_logit[s], logit)

        # Accumulate bump maps of every neighbor patch.
        for s in slices:
            base_mask[s] += bump_map(logit, max_logit[s], edge=edge)

        # Normalized weight of the central patch.
        # NOTE(review): with a nonzero `edge`, masked positions can hit
        # 0/0 -> nan here (same as the original) — confirm downstream.
        weight = bump_map(logit, max_logit[center], edge=edge) / base_mask[center]

    elif bump in ('wu', 'wu_no_crust'):
        # The two Wu variants share identical normalization logic.
        bump_fn = bump_map_wu if bump == 'wu' else bump_map_wu_no_crust
        bmap = bump_fn(patch_size, edge=edge)
        for s in slices:
            base_mask[s] += bmap
        weight = bmap / base_mask[center]

    else:
        # Was `assert False` (stripped under -O); raise a real error instead.
        raise ValueError("invalid bump type: {}".format(bump))

    return np.asarray(weight, dtype=np.float32)


def bump_logit(z, y, x, t=1.5):
    """Unnormalized bump logit; sharply negative near the patch border."""
    return -(x*(1-x))**(-t) - (y*(1-y))**(-t) - (z*(1-z))**(-t)


def bump_logit_map(patch_size):
    """Evaluate bump_logit on a normalized (0,1)-open grid of patch_size."""
    x = range(patch_size[-1])
    y = range(patch_size[-2])
    z = range(patch_size[-3])
    zv, yv, xv = np.meshgrid(z, y, x, indexing='ij')
    xv = (xv + 1.0) / (patch_size[-1] + 1.0)
    yv = (yv + 1.0) / (patch_size[-2] + 1.0)
    zv = (zv + 1.0) / (patch_size[-3] + 1.0)
    return np.asarray(bump_logit(zv, yv, xv), dtype=np.float64)


def mask_edge(weight, edge=None):
    """Zero out the voxels invalidated by an affinity edge shift (in place)."""
    if edge is not None:
        assert len(edge) == 3
        z, y, x = edge
        assert abs(x) < weight.shape[-1]
        if x > 0:
            weight[:, :, :x] = 0
        elif x < 0:
            weight[:, :, x:] = 0
        assert abs(y) < weight.shape[-2]
        if y > 0:
            weight[:, :y, :] = 0
        elif y < 0:
            weight[:, y:, :] = 0
        assert abs(z) < weight.shape[-3]
        if z > 0:
            weight[:z, :, :] = 0
        elif z < 0:
            weight[z:, :, :] = 0
    return weight


def bump_map(logit, max_logit, edge=None):
    """Zung bump map: exp(logit - max_logit), optionally edge-masked."""
    weight = np.exp(logit - max_logit)
    return mask_edge(weight, edge=edge)


def bump_map_wu(patch_size, edge=None):
    """Wu blending: exp of negative inverse-distance-to-border in (-1, 1)."""
    x = range(patch_size[-1])
    y = range(patch_size[-2])
    z = range(patch_size[-3])
    zv, yv, xv = np.meshgrid(z, y, x, indexing='ij')
    xv = (xv + 1.0) / (patch_size[-1] + 1.0) * 2.0 - 1.0
    yv = (yv + 1.0) / (patch_size[-2] + 1.0) * 2.0 - 1.0
    zv = (zv + 1.0) / (patch_size[-3] + 1.0) * 2.0 - 1.0
    weight = np.exp(-1.0/(1.0 - xv*xv) +
                    -1.0/(1.0 - yv*yv) +
                    -1.0/(1.0 - zv*zv))
    weight = mask_edge(weight, edge=edge)
    return np.asarray(weight, dtype=np.float64)


def bump_map_wu_no_crust(patch_size, edge=None):
    """Wu blending with the one-voxel outer "crust" suppressed."""
    weight = bump_map_wu(patch_size, edge=edge)
    weight[[0, -1], :, :] = 0
    weight[:, [0, -1], :] = 0
    weight[:, :, [0, -1]] = 0
    return weight
14 | """ 15 | def __init__(self, model, opt): 16 | super(Model, self).__init__() 17 | self.device = opt.device 18 | self.model = model 19 | self.in_spec = dict(opt.in_spec) 20 | self.scan_spec = dict(opt.scan_spec) 21 | self.pretrain = opt.pretrain 22 | self.force_crop = opt.force_crop 23 | 24 | # Softer softmax 25 | if opt.temperature is None: 26 | self.temperature = None 27 | else: 28 | self.temperature = max(opt.temperature, 1.0) 29 | 30 | # Precomputed mask 31 | self.mask = dict() 32 | if opt.blend == 'precomputed': 33 | for k, v in opt.scan_spec.items(): 34 | patch_sz = v[-3:] 35 | if k == 'affinity': 36 | edges = opt.mask_edges 37 | mask = AffinityMask(patch_sz, opt.overlap, edges, opt.bump) 38 | else: 39 | mask = PatchMask(patch_sz, opt.overlap) 40 | mask = np.expand_dims(mask, axis=0) 41 | mask = np.expand_dims(mask, axis=0) 42 | self.mask[k] = torch.from_numpy(mask).to(opt.device) 43 | 44 | def forward(self, sample): 45 | inputs = [sample[k] for k in sorted(self.in_spec)] 46 | preds = self.model(*inputs) 47 | outputs = dict() 48 | for k, x in preds.items(): 49 | if self.temperature is None: 50 | outputs[k] = torch.sigmoid(x) 51 | else: 52 | outputs[k] = torch.sigmoid(x/self.temperature) 53 | 54 | # Narrowing 55 | output_channels = outputs[k].shape[-4] 56 | scan_channels = self.scan_spec[k][-4] 57 | assert output_channels >= scan_channels 58 | if output_channels > scan_channels: 59 | outputs[k] = outputs[k].narrow(-4, 0, scan_channels) 60 | 61 | # Precomputed mask 62 | if k in self.mask: 63 | outputs[k] *= self.mask[k] 64 | 65 | # Crop outputs. 
66 | if self.force_crop is not None: 67 | outputs[k] = torch_utils.crop_border(outputs[k], self.force_crop) 68 | 69 | return outputs 70 | 71 | 72 | def load(self, fpath): 73 | map_location = 'cpu' if self.device == 'cpu' else None 74 | chkpt = torch.load(fpath, map_location=map_location) 75 | # Backward compatibility 76 | state_dict = chkpt['state_dict'] if 'state_dict' in chkpt else chkpt 77 | if self.pretrain: 78 | model_dict = self.model.state_dict() 79 | state_dict = {k:v for k, v in state_dict.items() if k in model_dict} 80 | model_dict.update(state_dict) 81 | self.model.load_state_dict(model_dict) 82 | else: 83 | self.model.load_state_dict(state_dict) 84 | -------------------------------------------------------------------------------- /deepem/test/run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from deepem.test.forward import Forward 6 | from deepem.test.option import Options 7 | from deepem.test.utils import * 8 | 9 | 10 | def test(opt): 11 | # Model 12 | model = load_model(opt) 13 | 14 | # Forward scan 15 | forward = Forward(opt) 16 | 17 | if opt.gs_input: 18 | scanner = make_forward_scanner(opt) 19 | output, aug_out = forward(model, scanner) 20 | save_output(output, opt, aug_out=aug_out) 21 | else: 22 | for dname in opt.data_names: 23 | scanner = make_forward_scanner(opt, data_name=dname) 24 | output, _ = forward(model, scanner) 25 | save_output(output, opt, data_name=dname) 26 | 27 | 28 | if __name__ == "__main__": 29 | # Options 30 | opt = Options().parse() 31 | 32 | # GPU 33 | if not opt.cpu: 34 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_id 35 | torch.backends.cudnn.benchmark = not opt.no_autotune 36 | 37 | # Make directories. 38 | if not os.path.isdir(opt.exp_dir): 39 | os.makedirs(opt.exp_dir) 40 | if not os.path.isdir(opt.model_dir): 41 | os.makedirs(opt.model_dir) 42 | if not os.path.isdir(opt.fwd_dir): 43 | os.makedirs(opt.fwd_dir) 44 | 45 | # Run inference. 
46 | print("Running inference: {}".format(opt.exp_name)) 47 | test(opt) 48 | -------------------------------------------------------------------------------- /deepem/test/utils.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import numpy as np 3 | import os 4 | from types import SimpleNamespace 5 | 6 | from dataprovider3 import Dataset, ForwardScanner, emio 7 | 8 | from deepem.test.model import Model 9 | from deepem.utils import py_utils 10 | 11 | 12 | def load_model(opt): 13 | # Create a model. 14 | mod = imp.load_source('model', opt.model) 15 | model = Model(mod.create_model(opt), opt) 16 | 17 | # Load from a checkpoint, if any. 18 | if opt.chkpt_num > 0: 19 | model = load_chkpt(model, opt.model_dir, opt.chkpt_num) 20 | 21 | model = model.train() if opt.no_eval else model.eval() 22 | return model.to(opt.device) 23 | 24 | 25 | def load_chkpt(model, fpath, chkpt_num): 26 | print("LOAD CHECKPOINT: {} iters.".format(chkpt_num)) 27 | fname = os.path.join(fpath, "model{}.chkpt".format(chkpt_num)) 28 | model.load(fname) 29 | return model 30 | 31 | 32 | def make_forward_scanner(opt, data_name=None): 33 | # Cloud-volume 34 | if opt.gs_input: 35 | try: 36 | from deepem.test import cv_utils 37 | img = cv_utils.cutout(opt, opt.gs_input, dtype='uint8') 38 | 39 | # Optional input histogram normalization 40 | if opt.gs_input_norm: 41 | assert len(opt.gs_input_norm) == 2 42 | low, high = opt.gs_input_norm 43 | img = normalize_per_slice(img, lowerfract=low, upperfract=high) 44 | 45 | # [0, 255] -> [0.0, 1.0] 46 | img = (img/255.).astype('float32') 47 | 48 | # Optional input mask 49 | if opt.gs_input_mask: 50 | try: 51 | msk = cv_utils.cutout(opt, opt.gs_input_mask, dtype='uint8') 52 | img[msk > 0] = 0 53 | except: 54 | raise 55 | 56 | except ImportError: 57 | raise 58 | else: 59 | assert data_name is not None 60 | print(data_name) 61 | # Read an EM image. 
62 | if opt.dummy: 63 | img = np.random.rand(*opt.dummy_inputsz[-3:]).astype('float32') 64 | else: 65 | fpath = os.path.join(opt.data_dir, data_name, opt.input_name) 66 | img = emio.imread(fpath) 67 | img = (img/255.).astype('float32') 68 | 69 | # Border mirroring 70 | if opt.mirror: 71 | pad_width = [(x//2,x//2) for x in opt.mirror] 72 | img = np.pad(img, pad_width, 'reflect') 73 | 74 | # ForwardScanner 75 | dataset = Dataset(spec=opt.in_spec) 76 | dataset.add_data('input', img) 77 | return ForwardScanner(dataset, opt.scan_spec, **opt.scan_params) 78 | 79 | 80 | def save_output(output, opt, data_name=None, aug_out=None): 81 | for k in output.data: 82 | data = output.get_data(k) 83 | 84 | # Crop 85 | if opt.crop_border: 86 | data = py_utils.crop_border(data, opt.crop_border) 87 | if opt.crop_center: 88 | data = py_utils.crop_center(data, opt.crop_center) 89 | 90 | # Cloud-volume 91 | if opt.gs_output: 92 | try: 93 | tag = k 94 | if opt.tags is not None: 95 | if tag in opt.tags: 96 | tag = opt.tags[tag] 97 | 98 | from deepem.test import cv_utils 99 | cv_utils.ingest(data, opt, tag=tag) 100 | 101 | # Optional variance 102 | if aug_out is not None: 103 | variance = np.var(np.stack(aug_out[k]), axis=0) 104 | cv_utils.ingest(variance, opt, tag=(tag + '_var')) 105 | 106 | except ImportError: 107 | raise 108 | else: 109 | dname = data_name.replace('/', '_') 110 | fname = "{}_{}_{}".format(dname, k, opt.chkpt_num) 111 | if opt.out_prefix: 112 | fname = opt.out_prefix + '_' + fname 113 | if opt.out_tag: 114 | fname = fname + '_' + opt.out_tag 115 | fpath = os.path.join(opt.fwd_dir, fname + ".h5") 116 | emio.imsave(data, fpath) 117 | 118 | 119 | def histogram_per_slice(img): 120 | z = img.shape[-3] 121 | xy = img.shape[-2] * img.shape[-1] 122 | return np.apply_along_axis(np.bincount, axis=1, arr=img.reshape((z,xy)), 123 | minlength=255) 124 | 125 | 126 | def find_section_clamping_values(zlevel, lowerfract, upperfract): 127 | """Find int8 values that correspond to lowerfract 
& upperfract of zlevel histogram 128 | 129 | From igneous (https://github.com/seung-lab/igneous/blob/master/igneous/tasks/tasks.py#L547) 130 | """ 131 | filtered = np.copy(zlevel) 132 | 133 | # remove pure black from frequency counts as 134 | # it has no information in our images 135 | filtered[0] = 0 136 | 137 | cdf = np.zeros(shape=(len(filtered),), dtype=np.uint64) 138 | cdf[0] = filtered[0] 139 | for i in range(1, len(filtered)): 140 | cdf[i] = cdf[i - 1] + filtered[i] 141 | 142 | total = cdf[-1] 143 | 144 | if total == 0: 145 | return (0, 0) 146 | 147 | lower = 0 148 | for i, val in enumerate(cdf): 149 | if float(val) / float(total) > lowerfract: 150 | break 151 | lower = i 152 | 153 | upper = 0 154 | for i, val in enumerate(cdf): 155 | if float(val) / float(total) > upperfract: 156 | break 157 | upper = i 158 | 159 | return (lower, upper) 160 | 161 | 162 | def normalize_per_slice(img, lowerfract=0.01, upperfract=0.01): 163 | maxval = 255. 164 | hist = histogram_per_slice(img) 165 | img = img.astype(np.float32) 166 | for z in range(img.shape[-3]): 167 | lower, upper = find_section_clamping_values(hist[z], 168 | lowerfract=lowerfract, 169 | upperfract=1-upperfract) 170 | if lower == upper: 171 | continue 172 | 173 | im = img[z,:,:] 174 | im = (im - float(lower)) * (maxval / (float(upper) - float(lower))) 175 | img[z,:,:] = im 176 | 177 | img = np.round(img) 178 | return np.clip(img, 0., maxval).astype(np.uint8) -------------------------------------------------------------------------------- /deepem/train/data.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import numpy as np 3 | 4 | import torch 5 | import torch.utils.data 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | def worker_init_fn(worker_id): 10 | # Each worker already has its own random torch state. 
import numpy as np

import torch
import torch.utils.data
from torch.utils.data import DataLoader


def worker_init_fn(worker_id):
    """Seed numpy in each DataLoader worker.

    Each worker already has its own random torch state; derive the numpy
    seed from it so numpy-based augmentation differs across workers.
    """
    seed = torch.IntTensor(1).random_()[0]
    np.random.seed(seed)


def _load_source(name, fpath):
    """Load a module from a file path.

    Replacement for ``imp.load_source`` (the ``imp`` module was removed
    in Python 3.12).
    """
    import importlib.util
    spec = importlib.util.spec_from_file_location(name, fpath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


class Dataset(torch.utils.data.Dataset):
    """Fixed-size dataset that draws every item from a sampler callable."""
    def __init__(self, sampler, size):
        super(Dataset, self).__init__()
        self.sampler = sampler
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # idx is ignored; every item is a fresh sample from the sampler.
        return self.sampler()


class Data(object):
    """Wraps a sampler script + DataLoader into a single callable that
    yields CUDA-resident training samples."""
    def __init__(self, opt, data, is_train=True, prob=None):
        self.build(opt, data, is_train, prob)

    def __call__(self):
        sample = next(self.dataiter)
        if self.is_train:
            sample = self.modifier(sample)
        for k in sample:
            is_input = k in self.inputs
            # Only inputs need gradients; non-inputs can transfer async.
            sample[k].requires_grad_(is_input)
            sample[k] = sample[k].cuda(non_blocking=(not is_input))
        return sample

    def requires_grad(self, key):
        """Whether `key` is a trainable input in training mode."""
        return self.is_train and (key in self.inputs)

    def build(self, opt, data, is_train, prob):
        # Data augmentation (optional script)
        if opt.augment:
            aug_mod = _load_source('augment', opt.augment)
            aug = aug_mod.get_augmentation(is_train, **opt.aug_params)
        else:
            aug = None

        # Data sampler script
        mod = _load_source('sampler', opt.sampler)
        spec = mod.get_spec(opt.in_spec, opt.out_spec)
        sampler = mod.Sampler(data, spec, is_train, aug, prob=prob)

        # Sample modifier (identity by default)
        self.modifier = lambda x: x
        if opt.modifier is not None:
            self.modifier = _load_source('modifier', opt.modifier).Modifier()

        # Data loader sized to cover the remaining training iterations.
        size = (opt.max_iter - opt.chkpt_num) * opt.batch_size
        dataset = Dataset(sampler, size)
        dataloader = DataLoader(dataset,
                                batch_size=opt.batch_size,
                                num_workers=opt.num_workers,
                                pin_memory=True,
                                worker_init_fn=worker_init_fn)

        # Attributes
        self.dataiter = iter(dataloader)
        self.inputs = opt.in_spec.keys()
        self.is_train = is_train
import os
import sys
import datetime
from collections import OrderedDict
import numpy as np

import torch


class Logger(object):
    """TensorBoard logger with running train/test statistic accumulation.

    NOTE(review): SummaryWriter (tensorboardX), make_grid (torchvision),
    and the deepem utils are imported lazily inside the methods that use
    them.
    """
    def __init__(self, opt):
        from tensorboardX import SummaryWriter
        self.monitor = {'train': Logger.Monitor(), 'test': Logger.Monitor()}
        self.log_dir = opt.log_dir
        self.writer = SummaryWriter(opt.log_dir)
        self.in_spec = dict(opt.in_spec)
        self.out_spec = dict(opt.out_spec)
        self.outputsz = opt.outputsz
        self.lr = opt.lr

        # Basic run metadata logging
        self.timestamp = datetime.datetime.now().strftime("%y%m%d_%H%M%S")
        self.log_params(vars(opt))
        self.log_command()
        self.log_command_args()

        # Blood vessel output channel count (for image padding below)
        self.blv_num_channels = opt.blv_num_channels

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        if self.writer:
            self.writer.close()

    def record(self, phase, loss, nmsk, **kwargs):
        """Accumulate one iteration's losses and mask norms."""
        monitor = self.monitor[phase]

        def to_scalar(v):
            # Reduce tensors to plain Python scalars.
            return v.item() if torch.is_tensor(v) else v

        for key in sorted(loss):
            monitor.add_to('vals', key, to_scalar(loss[key]))
            monitor.add_to('norm', key, to_scalar(nmsk[key]))

        # Extra scalars (e.g. elapsed time) are averaged per iteration.
        for key, val in kwargs.items():
            monitor.add_to('vals', key, val)
            monitor.add_to('norm', key, 1)

    def check(self, phase, iter_num):
        """Flush accumulated stats, then log and print them."""
        stats = self.monitor[phase].flush()
        self.log(phase, iter_num, stats)
        self.display(phase, iter_num, stats)

    def log(self, phase, iter_num, stats):
        for key, val in stats.items():
            self.writer.add_scalar('%s/%s' % (phase, key), val, iter_num)

    def display(self, phase, iter_num, stats):
        parts = ["[%s] Iter: %8d, " % (phase, iter_num)]
        parts += ["%s = %.3f, " % (key, val) for key, val in stats.items()]
        parts.append("(lr = %.6f). " % self.lr)
        print("".join(parts))

    class Monitor(object):
        """Accumulates value/normalizer pairs; flush() yields their ratios."""
        def __init__(self):
            self.vals = OrderedDict()
            self.norm = OrderedDict()

        def add_to(self, name, k, v):
            assert name in ('vals', 'norm')
            target = getattr(self, name)
            if k in target:
                target[k] += v
            else:
                target[k] = v

        def flush(self):
            averaged = OrderedDict(
                (k, self.vals[k] / self.norm[k]) for k in self.vals)
            self.vals = OrderedDict()
            self.norm = OrderedDict()
            return averaged

    def log_images(self, phase, iter_num, preds, sample):
        """Log input/prediction/label images for the first batch element."""
        from torchvision.utils import make_grid  # noqa: F401 (used in log_image)
        from deepem.utils import torch_utils, py_utils

        # Peep output size: all outputs must share one spatial size.
        first = sorted(self.out_spec)[0]
        cropsz = sample[first].shape[-3:]
        for k in sorted(self.out_spec):
            assert np.array_equal(sample[k].shape[-3:], cropsz)

        # Inputs, center-cropped to the output size.
        for k in sorted(self.in_spec):
            tensor = torch_utils.crop_center(sample[k][0, ...].cpu(), cropsz)
            self.log_image('{}/images/{}'.format(phase, k), tensor, iter_num)

        # Outputs
        for k in sorted(self.out_spec):
            if k == 'affinity':
                # Prediction (first three affinity channels)
                pred = torch.sigmoid(preds[k][0, 0:3, ...]).cpu()
                self.log_image('{}/images/{}'.format(phase, k), pred, iter_num)

                # Mask
                msk = sample[k + '_mask'][0, ...].cpu()
                self.log_image('{}/masks/{}'.format(phase, k), msk, iter_num)

                # Target segmentation rendered as random colors
                seg = sample[k][0, 0, ...].cpu().numpy().astype('uint32')
                rgb = torch.from_numpy(py_utils.seg2rgb(seg))
                self.log_image('{}/labels/{}'.format(phase, k), rgb, iter_num)

            elif k == 'blood_vessel':
                # Prediction, zero-padded to 3 channels when needed
                pred = torch.sigmoid(preds[k][0, ...]).cpu()
                if self.blv_num_channels == 2:
                    zero = torch.zeros_like(pred[[0], ...])
                    pred = torch.cat((pred, zero), dim=-4)
                self.log_image('{}/images/{}'.format(phase, k), pred, iter_num)

            elif k == 'myelin':
                # Prediction only (no target image)
                pred = torch.sigmoid(preds[k][0, ...]).cpu()
                self.log_image('{}/images/{}'.format(phase, k), pred, iter_num)

            elif k in ('mitochondria', 'synapse', 'glia'):
                # Prediction + binary target
                pred = torch.sigmoid(preds[k][0, ...]).cpu()
                self.log_image('{}/images/{}'.format(phase, k), pred, iter_num)
                target = sample[k][0, ...].cpu()
                self.log_image('{}/labels/{}'.format(phase, k), target, iter_num)

    def log_image(self, tag, tensor, iter_num):
        """Log a (C, Z, Y, X) tensor as a row of per-z-slice images."""
        from torchvision.utils import make_grid
        assert torch.is_tensor(tensor)
        depth = tensor.shape[-3]
        slices = [tensor[:, z, :, :] for z in range(depth)]
        grid = make_grid(slices, nrow=depth, padding=0)
        self.writer.add_image(tag, grid, iter_num)

    def log_params(self, params):
        fname = os.path.join(self.log_dir,
                             "{}_params.csv".format(self.timestamp))
        with open(fname, "w+") as f:
            f.writelines("{k}: {v}\n".format(k=k, v=v)
                         for k, v in params.items())

    def log_command(self):
        fname = os.path.join(self.log_dir,
                             "{}_command".format(self.timestamp))
        with open(fname, "w+") as f:
            f.write(" ".join(sys.argv))

    def log_command_args(self):
        fname = os.path.join(self.log_dir,
                             "{}_args.txt".format(self.timestamp))
        with open(fname, "w+") as f:
            f.writelines(arg + "\n" for arg in sys.argv[1:])
8 | """ 9 | 10 | def __init__(self, model, criteria, opt): 11 | super(Model, self).__init__() 12 | self.model = model 13 | self.criteria = criteria 14 | self.in_spec = dict(opt.in_spec) 15 | self.out_spec = dict(opt.out_spec) 16 | self.pretrain = opt.pretrain is not None 17 | 18 | def forward(self, sample): 19 | # Forward pass 20 | inputs = [sample[k] for k in sorted(self.in_spec)] 21 | preds = self.model(*inputs) 22 | 23 | # Loss evaluation 24 | try: 25 | losses, nmasks = self.eval_loss(preds, sample) 26 | except: 27 | import pdb; pdb.set_trace() 28 | raise 29 | return losses, nmasks, preds 30 | 31 | def eval_loss(self, preds, sample): 32 | losses, nmasks = dict(), dict() 33 | for k in self.out_spec: 34 | target = sample[k] 35 | mask = sample[k + '_mask'] 36 | loss, nmsk = self.criteria[k](preds[k], target, mask) 37 | # PyTorch 0.4.0-specific workaround 38 | losses[k] = loss.unsqueeze(0) 39 | nmasks[k] = nmsk.unsqueeze(0) 40 | return losses, nmasks 41 | 42 | def state_dict(self): 43 | return self.model.state_dict() 44 | 45 | def save(self, fpath): 46 | torch.save(self.model.state_dict(), fpath) 47 | 48 | def load(self, fpath): 49 | chkpt = torch.load(fpath) 50 | # Backward compatibility 51 | state_dict = chkpt['state_dict'] if 'state_dict' in chkpt else chkpt 52 | if self.pretrain: 53 | model_dict = self.model.state_dict() 54 | state_dict = {k:v for k, v in state_dict.items() if k in model_dict} 55 | model_dict.update(state_dict) 56 | self.model.load_state_dict(model_dict) 57 | else: 58 | self.model.load_state_dict(state_dict) 59 | -------------------------------------------------------------------------------- /deepem/train/run.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os 3 | import time 4 | 5 | import torch 6 | 7 | from deepem.train.logger import Logger 8 | from deepem.train.option import Options 9 | from deepem.train.utils import * 10 | 11 | 12 | def train(opt): 13 | # Model 14 | model = 
import os
import time

import torch

from deepem.train.logger import Logger
from deepem.train.option import Options
from deepem.train.utils import *


def train(opt):
    """Main training loop: build model/optimizer/data, then iterate."""
    model = load_model(opt)

    # Optimizer over trainable parameters only.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = load_optimizer(opt, trainable)

    # Data loaders
    train_loader, val_loader = load_data(opt)

    # Initial checkpoint
    save_chkpt(model, opt.model_dir, opt.chkpt_num, optimizer)

    print("========== BEGIN TRAINING LOOP ==========")
    with Logger(opt) as logger:

        # Timer
        t0 = time.time()

        for i in range(opt.chkpt_num, opt.max_iter):
            iter_num = i + 1

            # Load training samples.
            sample = train_loader()

            # Optimizer step
            optimizer.zero_grad()
            losses, nmasks, preds = forward(model, sample, opt)
            total_loss = sum(w * losses[k] for k, w in opt.loss_weight.items())
            total_loss.backward()
            optimizer.step()

            # Record keeping (elapsed wall time included)
            logger.record('train', losses, nmasks, elapsed=time.time() - t0)

            # Log & display averaged stats.
            if iter_num % opt.avgs_intv == 0 or i < opt.warm_up:
                logger.check('train', iter_num)

            # Image logging
            if iter_num % opt.imgs_intv == 0:
                logger.log_images('train', iter_num, preds, sample)

            # Evaluation loop
            if iter_num % opt.eval_intv == 0:
                eval_loop(iter_num, model, val_loader, opt, logger)

            # Model checkpoint
            if iter_num % opt.chkpt_intv == 0:
                save_chkpt(model, opt.model_dir, iter_num, optimizer)

            # Reset timer.
            t0 = time.time()


def eval_loop(iter_num, model, data_loader, opt, logger):
    """Run opt.eval_iter validation iterations and log averaged stats."""
    if not opt.no_eval:
        model.eval()

    print("---------- BEGIN EVALUATION LOOP ----------")
    with torch.no_grad():
        t0 = time.time()
        for _ in range(opt.eval_iter):
            sample = data_loader()
            losses, nmasks, preds = forward(model, sample, opt)

            # Record keeping, then restart the timer.
            logger.record('test', losses, nmasks, elapsed=time.time() - t0)
            t0 = time.time()

    # Log & display averaged stats.
    logger.check('test', iter_num)
    print("-------------------------------------------")

    model.train()


if __name__ == "__main__":

    # Options
    opt = Options().parse()

    # GPUs
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(opt.gpu_ids)

    # Make directories.
    for dpath in (opt.exp_dir, opt.log_dir, opt.model_dir):
        if not os.path.isdir(dpath):
            os.makedirs(dpath)

    # cuDNN auto-tuning
    torch.backends.cudnn.benchmark = not opt.no_autotune

    # Run experiment.
    print("Running experiment: {}".format(opt.exp_name))
    train(opt)
import os

import torch
from torch.nn.parallel import data_parallel


def _load_source(name, fpath):
    """Load a module from a file path.

    Replacement for ``imp.load_source`` (the ``imp`` module was removed
    in Python 3.12).
    """
    import importlib.util
    spec = importlib.util.spec_from_file_location(name, fpath)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def get_criteria(opt):
    """Build per-output loss criteria from the output spec."""
    import deepem.loss as loss
    criteria = dict()
    for k in opt.out_spec:
        if k in ('affinity', 'long_range'):
            if k == 'affinity':
                # Nearest-neighbor affinity edges (x, y, z).
                edges = [(0, 0, 1), (0, 1, 0), (1, 0, 0)]
            else:
                edges = list(opt.edges)
            assert len(edges) > 0
            params = dict(opt.loss_params)
            params['size_average'] = False
            criteria[k] = loss.AffinityLoss(
                edges,
                criterion=getattr(loss, opt.loss)(**params),
                size_average=opt.size_average,
                class_balancing=opt.class_balancing)
        else:
            params = dict(opt.loss_params)
            if opt.default_aux:
                # Auxiliary targets use plain BCE without margins.
                params['margin0'] = 0
                params['margin1'] = 0
                params['inverse'] = False
            criteria[k] = getattr(loss, 'BCELoss')(**params)
    return criteria


def load_model(opt):
    """Create the training model and optionally restore weights."""
    from deepem.train.model import Model
    mod = _load_source('model', opt.model)
    model = Model(mod.create_model(opt), get_criteria(opt), opt)

    if opt.pretrain:
        model.load(opt.pretrain)
    if opt.chkpt_num > 0:
        model = load_chkpt(model, opt.model_dir, opt.chkpt_num)

    return model.train().cuda()


def load_optimizer(opt, trainable):
    """Create the optimizer and optionally restore its state."""
    optimizer = getattr(torch.optim, opt.optim)(trainable, **opt.optim_params)

    if not opt.pretrain and opt.chkpt_num > 0:
        n = opt.chkpt_num
        fname = os.path.join(opt.model_dir, "model{}.chkpt".format(n))
        chkpt = torch.load(fname)
        if 'optimizer' in chkpt:
            print("LOAD OPTIM STATE: {} iters.".format(n))
            optimizer.load_state_dict(chkpt['optimizer'])
            # Restored state tensors must live on the same device as the
            # model parameters.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda()

    print(optimizer)
    return optimizer


def load_chkpt(model, fpath, chkpt_num):
    """Load model weights from ``model{chkpt_num}.chkpt`` under fpath."""
    print("LOAD CHECKPOINT: {} iters.".format(chkpt_num))
    fname = os.path.join(fpath, "model{}.chkpt".format(chkpt_num))
    model.load(fname)
    return model


def save_chkpt(model, fpath, chkpt_num, optimizer):
    """Save model weights + optimizer state + iteration number."""
    print("SAVE CHECKPOINT: {} iters.".format(chkpt_num))
    fname = os.path.join(fpath, "model{}.chkpt".format(chkpt_num))
    state = {'iter': chkpt_num,
             'state_dict': model.state_dict(),
             'optimizer': optimizer.state_dict()}
    torch.save(state, fname)


def load_data(opt):
    """Load datasets and build train/validation loaders."""
    from deepem.train.data import Data
    mod = _load_source('data', opt.data)
    data_ids = list(set().union(opt.train_ids, opt.val_ids))
    data = mod.load_data(opt.data_dir, data_ids=data_ids, **opt.data_params)

    def make_loader(ids, probs, is_train):
        # One loader over the subset of datasets listed in `ids`, with an
        # optional per-dataset sampling probability.
        subset = {k: data[k] for k in ids}
        prob = dict(zip(ids, probs)) if probs else None
        return Data(opt, subset, is_train=is_train, prob=prob)

    train_loader = make_loader(opt.train_ids, opt.train_prob, True)
    val_loader = make_loader(opt.val_ids, opt.val_prob, False)
    return train_loader, val_loader


def forward(model, sample, opt):
    """Forward pass (optionally data-parallel) with batch-averaged stats."""
    if len(opt.gpu_ids) > 1:
        losses, nmasks, preds = data_parallel(model, sample)
    else:
        losses, nmasks, preds = model(sample)

    # Average over minibatch (and over DataParallel replicas).
    losses = {k: v.mean() for k, v in losses.items()}
    nmasks = {k: v.mean() for k, v in nmasks.items()}

    return losses, nmasks, preds
import numpy as np


def to_volume(data):
    """Ensure that data is a numpy 3D array.

    2D input gains a leading singleton axis; 4D input must have a
    singleton leading axis, which is squeezed away.
    """
    assert isinstance(data, np.ndarray)
    if data.ndim == 2:
        data = data[np.newaxis, ...]
    elif data.ndim == 4:
        assert data.shape[0] == 1
        data = np.squeeze(data, axis=0)
    elif data.ndim != 3:
        raise RuntimeError("data must be a numpy 3D array")
    assert data.ndim == 3
    return data


def to_tensor(data):
    """Ensure that data is a numpy 4D array by prepending singleton axes."""
    assert isinstance(data, np.ndarray)
    if data.ndim == 2:
        data = data[np.newaxis, np.newaxis, ...]
    elif data.ndim == 3:
        data = data[np.newaxis, ...]
    elif data.ndim != 4:
        raise RuntimeError("data must be a numpy 4D array")
    assert data.ndim == 4
    return data


def seg2rgb(seg, border=True):
    """Render a segmentation as a random RGB colormap.

    Each segment id gets a random color; id 0 (border/background) is
    painted black when `border` is True. Returns a float32 array of
    shape ``(3,) + seg.shape``.
    """
    unq, unq_inv = np.unique(seg, return_inverse=True)

    # One random (R, G, B) color per unique id.
    colors = np.random.rand(3, len(unq))
    if border:
        colors[:, unq == 0] = 0

    mapped = colors[:, unq_inv.reshape(-1)]
    return mapped.reshape((3,) + seg.shape).astype(np.float32)
import numpy as np


def get_pair_first(arr, edge):
    """Crop `arr` to the first-voxel positions of an affinity edge."""
    shape = arr.size()[-3:]
    edge = np.array(edge)
    os1 = np.maximum(edge, 0)
    os2 = np.maximum(-edge, 0)
    return arr[..., os1[0]:shape[0]-os2[0],
                    os1[1]:shape[1]-os2[1],
                    os1[2]:shape[2]-os2[2]]


def get_pair(arr, edge):
    """Crop `arr` into the two aligned views joined by an affinity edge."""
    shape = arr.size()[-3:]
    edge = np.array(edge)
    os1 = np.maximum(edge, 0)
    os2 = np.maximum(-edge, 0)
    arr1 = arr[..., os1[0]:shape[0]-os2[0],
                    os1[1]:shape[1]-os2[1],
                    os1[2]:shape[2]-os2[2]]
    arr2 = arr[..., os2[0]:shape[0]-os1[0],
                    os2[1]:shape[1]-os1[1],
                    os2[2]:shape[2]-os1[2]]
    return arr1, arr2


def crop_border(v, size):
    """Crop ``size[i] // 2`` voxels from both ends of each spatial axis.

    FIX: a zero half-width (any size component < 2) previously produced
    the empty slice ``0:-0``, silently returning an empty tensor; such
    axes are now left uncropped.
    """
    assert all(a > b for a, b in zip(v.shape[-3:], size[-3:]))
    halves = [s // 2 for s in size[-3:]]
    slc = tuple(slice(h, -h) if h > 0 else slice(None) for h in halves)
    return v[(Ellipsis,) + slc]


def crop_center(v, size):
    """Center-crop the last three axes of `v` to `size`."""
    assert all(a >= b for a, b in zip(v.shape[-3:], size[-3:]))
    z, y, x = size[-3:]
    sz = (v.shape[-3] - z) // 2
    sy = (v.shape[-2] - y) // 2
    sx = (v.shape[-1] - x) // 2
    return v[..., sz:sz+z, sy:sy+y, sx:sx+x]